From 6fce648b0181cd0b1967cce431f0a05be894fb38 Mon Sep 17 00:00:00 2001 From: cdemeyer-teachx Date: Sun, 24 Aug 2025 20:03:11 +0900 Subject: [PATCH] sudoku class integration --- demos/sudoku/CMakeLists.txt | 1 + demos/sudoku/sudoku.h | 16 +++++++++++++ demos/sudoku/sudoku_wfc.cpp | 13 +++++++---- include/nd-wfc/wfc.hpp | 46 +++++++++++++++++++++++++++++-------- include/nd-wfc/worlds.hpp | 14 +++++++++++ 5 files changed, 75 insertions(+), 15 deletions(-) diff --git a/demos/sudoku/CMakeLists.txt b/demos/sudoku/CMakeLists.txt index 11f07e3..67f8301 100644 --- a/demos/sudoku/CMakeLists.txt +++ b/demos/sudoku/CMakeLists.txt @@ -42,6 +42,7 @@ add_executable(sudoku_demo # Create WFC demo executable add_executable(sudoku_wfc_demo sudoku_wfc.cpp + sudoku.cpp ) # Set output directory for sudoku_demo diff --git a/demos/sudoku/sudoku.h b/demos/sudoku/sudoku.h index 867c080..20552ed 100644 --- a/demos/sudoku/sudoku.h +++ b/demos/sudoku/sudoku.h @@ -193,6 +193,22 @@ private: } return false; } + +public: // WFC Support + using ValueType = uint8_t; + + ValueType getValue(size_t index) const { + return board_.get(index); + } + + void setValue(size_t index, ValueType value) { + board_.set(index, value); + } + + constexpr size_t size() const { + return 81; + } + }; // Static assert to ensure exactly 41 bytes diff --git a/demos/sudoku/sudoku_wfc.cpp b/demos/sudoku/sudoku_wfc.cpp index 9b70bb3..dac5f39 100644 --- a/demos/sudoku/sudoku_wfc.cpp +++ b/demos/sudoku/sudoku_wfc.cpp @@ -1,14 +1,15 @@ #include #include +#include "sudoku.h" #include int main() { std::cout << "Running Sudoku WFC" << std::endl; - auto sudokuSolver = WFC::Builder, uint8_t>() + auto sudokuSolver = WFC::Builder() .DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>() - .Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](WFC::Array2D&, size_t index, WFC::WorldValue val, auto& constrainer) { + .Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](Sudoku&, size_t index, WFC::WorldValue val, auto& constrainer) { size_t x = index % 9; size_t y = index / 9; @@ -38,15 +39,17 @@ int main() }) .build(); - WFC::Array2D sudokuWorld; - bool success = sudokuSolver.Run(sudokuWorld); + Sudoku sudokuWorld; + sudokuWorld.setValue(0, 5); + sudokuWorld.setValue(80, 1); + bool success = sudokuSolver.Run(sudokuWorld, true); if (success) { std::cout << "Sudoku solved successfully!" << std::endl; // Print the solved sudoku for (size_t y = 0; y < 9; ++y) { for (size_t x = 0; x < 9; ++x) { - std::cout << static_cast(sudokuWorld.at(static_cast(x), static_cast(y))) << " "; + std::cout << static_cast(sudokuWorld.getValue(x + y * 9)) << " "; if (x == 2 || x == 5) std::cout << "| "; } std::cout << std::endl; diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index bb34d3b..a54eabf 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -3,8 +3,6 @@ #include #include #include -#include -#include #include #include #include @@ -17,7 +15,6 @@ namespace WFC { inline int FindNthSetBit(size_t num, int n) { - auto popCount = std::popcount(num); assert(n < std::popcount(num) && "index is out of range"); int bitCount = 0; while (num) { @@ -27,7 +24,6 @@ inline int FindNthSetBit(size_t num, int n) { bitCount++; num &= (num - 1); // turn of lowest set bit } - assert(bitCount < popCount && "out of bounds"); return bitCount; } @@ -35,6 +31,7 @@ template concept WorldType = requires(T world, size_t id, typename T::ValueType value) { { world.size() } -> std::convertible_to; { world.setValue(id, value) }; + { world.getValue(id) } -> std::convertible_to; typename T::ValueType; }; @@ -96,7 +93,7 @@ public: { static_assert(HasValue(), "Value was not defined"); constexpr VarT arr[] = {Values...}; - constexpr size_t size = sizeof...(Values); + constexpr size_t size = ValuesRegisteredAmount; for (size_t i = 0; i < size; ++i) if (arr[i] == Value) @@ -106,7 +103,7 @@ public: } static VarT GetValue(size_t index) { - assert(index < sizeof...(Values)); + assert(index < ValuesRegisteredAmount); constexpr VarT arr[] = {Values...}; return arr[index]; } @@ -117,7 +114,12 @@ public: return (0 | ... | (1 << GetIndex())); } - static consteval size_t size() { return sizeof...(Values); } + static consteval std::array GetAllValues() + { + return {Values...}; + } + + static consteval size_t size() { return ValuesRegisteredAmount; } }; template @@ -267,18 +269,23 @@ public: {} public: - bool Run(WorldT& world) + bool Run(WorldT& world, bool propagateInitialValues = false) { WorldSolver worldSolver(world, m_variables); - return Run(worldSolver); + return Run(worldSolver, propagateInitialValues); } /** * @brief Run the WFC algorithm to generate a solution * @return true if a solution was found, false if contradiction occurred */ - bool Run(WorldSolver& worldSolver) + bool Run(WorldSolver& worldSolver, bool propagateInitialValues = false) { + if (propagateInitialValues) + { + PropogateInitialValues(worldSolver); + } + for (size_t i = 0; i < 1024; ++i) { Propagate(worldSolver); @@ -379,6 +386,25 @@ private: } } +private: + void PropogateInitialValues(WorldSolver& worldSolver) + { + auto allValues = VariableIDMapT::GetAllValues(); + for (size_t i = 0; i < worldSolver.wave.size(); ++i) + { + for (size_t j = 0; j < allValues.size(); ++j) + { + if (worldSolver.world.getValue(i) == allValues[j]) + { + worldSolver.wave.Collapse(i, 1 << j); + worldSolver.propagationQueue.push(i); + break; + } + } + } + } + +private: std::vector> m_variables {}; }; diff --git a/include/nd-wfc/worlds.hpp b/include/nd-wfc/worlds.hpp index 30cc416..c6ec047 100644 --- a/include/nd-wfc/worlds.hpp +++ b/include/nd-wfc/worlds.hpp @@ -86,6 +86,13 @@ public: data_[index] = value; } + /** + * @brief Get value at specific index + */ + T getValue(size_t index) const { + return data_[index]; + } + private: std::array data_; }; @@ -161,6 +168,13 @@ public: data_[index] = value; } + /** + * @brief Get value at specific index + */ + T getValue(size_t index) const { + return data_[index]; + } + private: std::array data_; };