From 414ded7e09debc5d25750fd1bde0ffbf5e66a446 Mon Sep 17 00:00:00 2001 From: Connor Date: Fri, 6 Feb 2026 20:05:16 +0900 Subject: [PATCH] WFC class refactor --- demos/sudoku/sudoku_wfc.cpp | 2 +- demos/sudoku/test_sudoku.cpp | 4 +- include/nd-wfc/wfc.hpp | 532 +++++++++++++++------------- include/nd-wfc/wfc_builder.hpp | 2 +- include/nd-wfc/wfc_variable_map.hpp | 1 + 5 files changed, 282 insertions(+), 259 deletions(-) diff --git a/demos/sudoku/sudoku_wfc.cpp b/demos/sudoku/sudoku_wfc.cpp index 169dae5..6dc29c6 100644 --- a/demos/sudoku/sudoku_wfc.cpp +++ b/demos/sudoku/sudoku_wfc.cpp @@ -64,7 +64,7 @@ int main() Sudoku sudokuWorld = Sudoku{ "6......3.......7....7463....7.8...2.4...9...1.9...7.8....9851....6.......1......9" }; - bool success = SudokuSolverCallback::Run(sudokuWorld, true); + bool success = WFC::Run(sudokuWorld, true); bool solved = sudokuWorld.isSolved(); diff --git a/demos/sudoku/test_sudoku.cpp b/demos/sudoku/test_sudoku.cpp index 8080fd3..f4a4f06 100644 --- a/demos/sudoku/test_sudoku.cpp +++ b/demos/sudoku/test_sudoku.cpp @@ -33,7 +33,7 @@ protected: // Helper function to solve a puzzle using WFC void solvePuzzle(Sudoku& sudoku) { - SudokuSolver::Run(sudoku, true); + WFC::Run(sudoku, true); } }; @@ -286,7 +286,7 @@ void testPuzzleSolving(const std::string& difficulty, const std::string& filenam Sudoku& sudoku = puzzles[i]; EXPECT_TRUE(sudoku.isValid()) << difficulty << " puzzle " << i << " is not valid"; - SudokuSolver::Run(sudoku); + WFC::Run(sudoku); EXPECT_TRUE(sudoku.isSolved()) << difficulty << " puzzle " << i << " was not solved. Puzzle string: " << sudoku.toString(); diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index cd3c80d..4e86edf 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -44,265 +44,287 @@ concept HasConstexprSize = requires { { []() constexpr -> std::size_t { return WorldT{}.size(); }() }; }; -template, - typename ConstrainerFunctionMapT = ConstrainerFunctionMap, - typename CallbacksT = Callbacks, - typename RandomSelectorT = DefaultRandomSelector -> -class WFC { -public: +// Standalone SolverState struct +template > +struct SolverState { + using WorldType = WorldT; + using WorldSizeT = decltype(WorldT{}.size()); + static constexpr WorldSizeT WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + using PropagationQueueType = WFCQueue; + + WorldT& m_world; + PropagationQueueType m_propagationQueue{}; + RandomSelectorT m_randomSelector{}; + WFCStackAllocator m_allocator{}; + size_t m_iterations{}; + + SolverState(WorldT& world, uint32_t seed) + : m_world(world) + , m_propagationQueue{ WorldSize ? WorldSize : static_cast(world.size()) } + , m_randomSelector(seed) + {} + + SolverState(const SolverState& other) = default; +}; + +// Types-only config struct produced by Builder +template +struct WFCConfig { static_assert(WorldType, "WorldT must satisfy World type requirements"); using WorldSizeT = decltype(WorldT{}.size()); - - // Try getting the world size, which is only available if the world type has a constexpr size() method - constexpr static WorldSizeT WorldSize = HasConstexprSize ? WorldT{}.size() : 0; - + static constexpr WorldSizeT WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + using SolverStateType = SolverState; using WaveType = Wave; - using PropagationQueueType = WFCQueue; - using ConstrainerType = Constrainer; - using MaskType = typename WaveType::ElementT; - using VariableIDT = typename WaveType::VariableIDT; - -public: - struct SolverState - { - WorldT& m_world; - PropagationQueueType m_propagationQueue{}; - RandomSelectorT m_randomSelector{}; - WFCStackAllocator m_allocator{}; - size_t m_iterations{}; - - SolverState(WorldT& world, uint32_t seed) - : m_world(world) - , m_propagationQueue{ WorldSize ? WorldSize : static_cast(world.size()) } - , m_randomSelector(seed) - {} - - SolverState(const SolverState& other) = default; - }; - -public: - WFC() = delete; // dont make an instance of this class, only use the static methods. - -public: - - static bool Run(WorldT& world, uint32_t seed = std::random_device{}()) - { - SolverState state{ world, seed }; - bool result = Run(state); - - return result; - } - - /** - * @brief Run the WFC algorithm to generate a solution - * @return true if a solution was found, false if contradiction occurred - */ - static bool Run(SolverState& state) - { - WaveType wave{ WorldSize, VariableIDMapT::size(), state.m_allocator }; - - PropogateInitialValues(state, wave); - - if (RunLoop(state, wave)) { - - PopulateWorld(state, wave); - return true; - } - return false; - } - - static bool RunLoop(SolverState& state, WaveType& wave) - { - static constexpr size_t MaxIterations = 1024 * 8; - for (; state.m_iterations < MaxIterations; ++state.m_iterations) - { - if (!Propagate(state, wave)) - return false; - - if (wave.HasContradiction()) - { - if constexpr (CallbacksT::HasContradictionCallback()) - { - PopulateWorld(state, wave); - typename CallbacksT::ContradictionCallback{}(state.m_world); - } - return false; - } - - if (wave.IsFullyCollapsed()) - return true; - - if constexpr (CallbacksT::HasBranchCallback()) - { - PopulateWorld(state, wave); - typename CallbacksT::BranchCallback{}(state.m_world); - } - - if (Branch(state, wave)) - return true; - } - return false; - } - - /** - * @brief Get the value at a specific cell - * @param cellId The cell ID - * @return The value if collapsed, std::nullopt otherwise - */ - static std::optional GetValue(WaveType& wave, int cellId) { - if (wave.IsCollapsed(cellId)) { - auto variableId = wave.GetVariableID(cellId); - return VariableIDMapT::GetValue(variableId); - } - return std::nullopt; - } - - /** - * @brief Get all possible values for a cell - * @param cellId The cell ID - * @return Set of possible values - */ - static const std::vector GetPossibleValues(WaveType& wave, int cellId) - { - std::vector possibleValues; - MaskType mask = wave.GetMask(cellId); - for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) { - if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i)); - } - return possibleValues; - } - -private: - static void CollapseCell(SolverState& state, WaveType& wave, WorldSizeT cellId, VariableIDT value) - { - constexpr_assert(!wave.IsCollapsed(cellId) || wave.GetMask(cellId) == (MaskType(1) << value)); - wave.Collapse(cellId, 1 << value); - constexpr_assert(wave.IsCollapsed(cellId)); - - if constexpr (CallbacksT::HasCellCollapsedCallback()) - { - PopulateWorld(state, wave); - typename CallbacksT::CellCollapsedCallback{}(state.m_world); - } - } - - static bool Branch(SolverState& state, WaveType& wave) - { - constexpr_assert(state.m_propagationQueue.empty()); - - // Find cell with minimum entropy > 1 - WorldSizeT minEntropyCell = static_cast(-1); - size_t minEntropy = static_cast(-1); - - for (WorldSizeT i = 0; i < wave.size(); ++i) { - size_t entropy = wave.Entropy(i); - if (entropy > 1 && entropy < minEntropy) { - minEntropy = entropy; - minEntropyCell = i; - } - } - if (minEntropyCell == static_cast(-1)) return false; - - constexpr_assert(!wave.IsCollapsed(minEntropyCell)); - - // create a list of possible values - VariableIDT availableValues = static_cast(wave.Entropy(minEntropyCell)); - std::array possibleValues{}; - MaskType mask = wave.GetMask(minEntropyCell); - for (size_t i = 0; i < availableValues; ++i) - { - VariableIDT index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit - constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds"); - - possibleValues[i] = index; - constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set"); - - mask = mask & (mask - 1); // turn off lowest set bit - } - - // randomly select a value from possible values - while (availableValues) - { - size_t randomIndex = state.m_randomSelector.rng(availableValues); - VariableIDT selectedValue = possibleValues[randomIndex]; - - { - // copy the state and branch out - auto stackFrame = state.m_allocator.createFrame(); - auto queueFrame = state.m_propagationQueue.createBranchPoint(); - - auto newWave = wave; - CollapseCell(state, newWave, minEntropyCell, selectedValue); - state.m_propagationQueue.push(minEntropyCell); - - if (RunLoop(state, newWave)) - { - // move the solution to the original state - wave = newWave; - - return true; - } - } - - // remove the failure state from the wave - constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) != 0, "Possible value was not set"); - wave.Collapse(minEntropyCell, ~(1 << selectedValue)); - constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); - - // swap replacement value with the last value - std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); - } - - return false; - } - - static bool Propagate(SolverState& state, WaveType& wave) - { - while (!state.m_propagationQueue.empty()) - { - WorldSizeT cellId = state.m_propagationQueue.pop(); - - if (wave.IsContradicted(cellId)) return false; - - constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed"); - - VariableIDT variableID = wave.GetVariableID(cellId); - ConstrainerType constrainer(wave, state.m_propagationQueue); - - using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, ConstrainerType&); - - ConstrainerFunctionMapT::template GetFunction(variableID)(state.m_world, cellId, WorldValue{VariableIDMapT::GetValue(variableID), variableID}, constrainer); - } - return true; - } - - static void PopulateWorld(SolverState& state, WaveType& wave) - { - for (size_t i = 0; i < wave.size(); ++i) - { - if (wave.IsCollapsed(i)) - state.m_world.setValue(i, VariableIDMapT::GetValue(wave.GetVariableID(i))); - } - } - - static void PropogateInitialValues(SolverState& state, WaveType& wave) - { - for (size_t i = 0; i < wave.size(); ++i) - { - for (size_t j = 0; j < VariableIDMapT::size(); ++j) - { - if (state.m_world.getValue(i) == VariableIDMapT::GetValue(j)) - { - CollapseCell(state, wave, static_cast(i), static_cast(j)); - state.m_propagationQueue.push(i); - break; - } - } - } - } + using CallbacksType = CallbacksT; + using ConstrainerFunctionMapType = ConstrainerFunctionMapT; }; +// Forward declarations for mutually recursive functions +template +bool RunLoop(StateT& state, WaveT& wave); + +template +bool Branch(StateT& state, WaveT& wave); + +namespace detail { + +template +void PopulateWorld(StateT& state, WaveT& wave) +{ + using VariableIDMapT = typename WaveT::IDMapT; + for (size_t i = 0; i < wave.size(); ++i) + { + if (wave.IsCollapsed(i)) + state.m_world.setValue(i, VariableIDMapT::GetValue(wave.GetVariableID(i))); + } +} + +template +void CollapseCell(StateT& state, WaveT& wave, typename StateT::WorldSizeT cellId, typename WaveT::VariableIDT value) +{ + using MaskType = typename WaveT::ElementT; + constexpr_assert(!wave.IsCollapsed(cellId) || wave.GetMask(cellId) == (MaskType(1) << value)); + wave.Collapse(cellId, 1 << value); + constexpr_assert(wave.IsCollapsed(cellId)); + + if constexpr (CallbacksT::HasCellCollapsedCallback()) + { + PopulateWorld(state, wave); + typename CallbacksT::CellCollapsedCallback{}(state.m_world); + } +} + +template +void PropogateInitialValues(StateT& state, WaveT& wave) +{ + using VariableIDMapT = typename WaveT::IDMapT; + using WorldSizeT = typename StateT::WorldSizeT; + using VariableIDT = typename WaveT::VariableIDT; + for (size_t i = 0; i < wave.size(); ++i) + { + for (size_t j = 0; j < VariableIDMapT::size(); ++j) + { + if (state.m_world.getValue(i) == VariableIDMapT::GetValue(j)) + { + CollapseCell(state, wave, static_cast(i), static_cast(j)); + state.m_propagationQueue.push(i); + break; + } + } + } +} + +template +bool Propagate(StateT& state, WaveT& wave) +{ + using VariableIDMapT = typename WaveT::IDMapT; + using VarT = typename VariableIDMapT::Type; + using WorldSizeT = typename StateT::WorldSizeT; + using VariableIDT = typename WaveT::VariableIDT; + using PropagationQueueType = typename StateT::PropagationQueueType; + using ConstrainerType = Constrainer; + + while (!state.m_propagationQueue.empty()) + { + WorldSizeT cellId = state.m_propagationQueue.pop(); + + if (wave.IsContradicted(cellId)) return false; + + constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed"); + + VariableIDT variableID = wave.GetVariableID(cellId); + ConstrainerType constrainer(wave, state.m_propagationQueue); + + using WorldT = typename StateT::WorldType; + using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, ConstrainerType&); + + ConstrainerFunctionMapT::template GetFunction(variableID)(state.m_world, cellId, WorldValue{VariableIDMapT::GetValue(variableID), variableID}, constrainer); + } + return true; +} + +} // namespace detail + +template +bool Branch(StateT& state, WaveT& wave) +{ + using VariableIDMapT = typename WaveT::IDMapT; + using MaskType = typename WaveT::ElementT; + using WorldSizeT = typename StateT::WorldSizeT; + using VariableIDT = typename WaveT::VariableIDT; + + constexpr_assert(state.m_propagationQueue.empty()); + + // Find cell with minimum entropy > 1 + WorldSizeT minEntropyCell = static_cast(-1); + size_t minEntropy = static_cast(-1); + + for (WorldSizeT i = 0; i < wave.size(); ++i) { + size_t entropy = wave.Entropy(i); + if (entropy > 1 && entropy < minEntropy) { + minEntropy = entropy; + minEntropyCell = i; + } + } + if (minEntropyCell == static_cast(-1)) return false; + + constexpr_assert(!wave.IsCollapsed(minEntropyCell)); + + // create a list of possible values + VariableIDT availableValues = static_cast(wave.Entropy(minEntropyCell)); + std::array possibleValues{}; + MaskType mask = wave.GetMask(minEntropyCell); + for (size_t i = 0; i < availableValues; ++i) + { + VariableIDT index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit + constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds"); + + possibleValues[i] = index; + constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set"); + + mask = mask & (mask - 1); // turn off lowest set bit + } + + // randomly select a value from possible values + while (availableValues) + { + size_t randomIndex = state.m_randomSelector.rng(availableValues); + VariableIDT selectedValue = possibleValues[randomIndex]; + + { + // copy the state and branch out + auto stackFrame = state.m_allocator.createFrame(); + auto queueFrame = state.m_propagationQueue.createBranchPoint(); + + auto newWave = wave; + detail::CollapseCell(state, newWave, minEntropyCell, selectedValue); + state.m_propagationQueue.push(minEntropyCell); + + if (RunLoop(state, newWave)) + { + // move the solution to the original state + wave = newWave; + + return true; + } + } + + // remove the failure state from the wave + constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) != 0, "Possible value was not set"); + wave.Collapse(minEntropyCell, ~(1 << selectedValue)); + constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); + + // swap replacement value with the last value + std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); + } + + return false; +} + +template +bool RunLoop(StateT& state, WaveT& wave) +{ + static constexpr size_t MaxIterations = 1024 * 8; + for (; state.m_iterations < MaxIterations; ++state.m_iterations) + { + if (!detail::Propagate(state, wave)) + return false; + + if (wave.HasContradiction()) + { + if constexpr (CallbacksT::HasContradictionCallback()) + { + detail::PopulateWorld(state, wave); + typename CallbacksT::ContradictionCallback{}(state.m_world); + } + return false; + } + + if (wave.IsFullyCollapsed()) + return true; + + if constexpr (CallbacksT::HasBranchCallback()) + { + detail::PopulateWorld(state, wave); + typename CallbacksT::BranchCallback{}(state.m_world); + } + + if (Branch(state, wave)) + return true; + } + return false; +} + +template +bool Run(typename ConfigT::SolverStateType& state) +{ + using CallbacksT = typename ConfigT::CallbacksType; + using ConstrainerFunctionMapT = typename ConfigT::ConstrainerFunctionMapType; + using WaveType = typename ConfigT::WaveType; + using VariableIDMapT = typename WaveType::IDMapT; + + WaveType wave{ ConfigT::WorldSize, VariableIDMapT::size(), state.m_allocator }; + + detail::PropogateInitialValues(state, wave); + + if (RunLoop(state, wave)) { + detail::PopulateWorld(state, wave); + return true; + } + return false; +} + +template +bool Run(WorldT& world, uint32_t seed = std::random_device{}()) +{ + typename ConfigT::SolverStateType state{ world, seed }; + return Run(state); +} + +template +std::optional GetValue(WaveT& wave, int cellId) { + using VariableIDMapT = typename WaveT::IDMapT; + if (wave.IsCollapsed(cellId)) { + auto variableId = wave.GetVariableID(cellId); + return VariableIDMapT::GetValue(variableId); + } + return std::nullopt; +} + +template +const std::vector GetPossibleValues(WaveT& wave, int cellId) +{ + using VariableIDMapT = typename WaveT::IDMapT; + using VarT = typename VariableIDMapT::Type; + using MaskType = typename WaveT::ElementT; + std::vector possibleValues; + MaskType mask = wave.GetMask(cellId); + for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) { + if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i)); + } + return possibleValues; +} + } // namespace WFC diff --git a/include/nd-wfc/wfc_builder.hpp b/include/nd-wfc/wfc_builder.hpp index 5908641..a4309cc 100644 --- a/include/nd-wfc/wfc_builder.hpp +++ b/include/nd-wfc/wfc_builder.hpp @@ -85,7 +85,7 @@ public: template using SetRandomSelector = Builder; - using Build = WFC; + using Build = WFCConfig; }; } \ No newline at end of file diff --git a/include/nd-wfc/wfc_variable_map.hpp b/include/nd-wfc/wfc_variable_map.hpp index cb333fd..267e95e 100644 --- a/include/nd-wfc/wfc_variable_map.hpp +++ b/include/nd-wfc/wfc_variable_map.hpp @@ -19,6 +19,7 @@ using VariableIDType = std::conditional_t class VariableIDMap { public: + using Type = VarT; template using Merge = VariableIDMap;