diff --git a/demos/sudoku/test_sudoku.cpp b/demos/sudoku/test_sudoku.cpp index 7713813..8080fd3 100644 --- a/demos/sudoku/test_sudoku.cpp +++ b/demos/sudoku/test_sudoku.cpp @@ -282,13 +282,11 @@ void testPuzzleSolving(const std::string& difficulty, const std::string& filenam int solvedCount = 0; auto start = std::chrono::high_resolution_clock::now(); - WFC::WFCStackAllocator allocator{}; - for (size_t i = 0; i < puzzles.size(); ++i) { Sudoku& sudoku = puzzles[i]; EXPECT_TRUE(sudoku.isValid()) << difficulty << " puzzle " << i << " is not valid"; - SudokuSolver::Run(sudoku, allocator); + SudokuSolver::Run(sudoku); EXPECT_TRUE(sudoku.isSolved()) << difficulty << " puzzle " << i << " was not solved. Puzzle string: " << sudoku.toString(); diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index daed5e3..d31ec67 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -22,6 +21,7 @@ #include "wfc_constrainer.hpp" #include "wfc_callbacks.hpp" #include "wfc_random.hpp" +#include "wfc_queue.hpp" namespace WFC { @@ -37,8 +37,8 @@ concept WorldType = requires(T world, size_t id, typename T::ValueType value) { * @brief Concept to validate constrainer function signature * The function must be callable with parameters: (WorldT&, size_t, WorldValue, Constrainer&) */ -template -concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { +template +concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { func(world, index, value, constrainer); }; @@ -73,28 +73,23 @@ public: constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; using WaveType = Wave; - using ConstrainerType = Constrainer; + using PropagationQueueType = WFCQueue; + using ConstrainerType = Constrainer; using MaskType = typename WaveType::ElementT; public: struct SolverState { - WorldT& world; - WFCQueue propagationQueue; - WaveType wave; - std::mt19937& rng; - RandomSelectorT& randomSelector; - WFCStackAllocator& allocator; - size_t& iterations; + WorldT& m_world; + PropagationQueueType m_propagationQueue{}; + RandomSelectorT m_randomSelector{}; + WFCStackAllocator m_allocator{}; + size_t m_iterations{}; - SolverState(WorldT& world, size_t variableAmount, std::mt19937& rng, RandomSelectorT& randomSelector, WFCStackAllocator& allocator, size_t& iterations) - : world(world) - , propagationQueue{ WFCStackAllocatorAdapter(allocator) } - , wave{ WorldSize, variableAmount, allocator } - , rng(rng) - , randomSelector(randomSelector) - , allocator(allocator) - , iterations(iterations) + SolverState(WorldT& world, uint32_t seed) + : m_world(world) + , m_propagationQueue{ WorldSize ? WorldSize : static_cast(world.size()) } + , m_randomSelector(seed) {} SolverState(const SolverState& other) = default; @@ -104,35 +99,12 @@ public: WFC() = delete; // dont make an instance of this class, only use the static methods. public: - + static bool Run(WorldT& world, uint32_t seed = std::random_device{}()) { - WFCStackAllocator allocator{}; - return Run(world, allocator, seed); - } - - static bool Run(WorldT& world, WFCStackAllocator& allocator, uint32_t seed = std::random_device{}()) - { - allocator.reset(); - constexpr_assert(allocator.getUsed() == 0, "Allocator must be empty"); - - size_t iterations = 0; - auto random = std::mt19937{ seed }; - RandomSelectorT randomSelector{ seed }; - SolverState state - { - world, - ConstrainerFunctionMapT::size(), - random, - randomSelector, - allocator, - iterations - }; + SolverState state{ world, seed }; bool result = Run(state); - allocator.reset(); - constexpr_assert(allocator.getUsed() == 0, "Allocator must be empty"); - return result; } @@ -142,43 +114,46 @@ public: */ static bool Run(SolverState& state) { - PropogateInitialValues(state); + WaveType wave{ WorldSize, VariableIDMapT::ValuesRegisteredAmount, state.m_allocator }; - if (RunLoop(state)) { + PropogateInitialValues(state, wave); - PopulateWorld(state); + if (RunLoop(state, wave)) { + + PopulateWorld(state, wave); return true; } return false; } - static bool RunLoop(SolverState& state) + static bool RunLoop(SolverState& state, WaveType& wave) { - for (; state.iterations < 1024 * 8; ++state.iterations) + static constexpr size_t MaxIterations = 1024 * 8; + for (; state.m_iterations < MaxIterations; ++state.m_iterations) { - if (!Propagate(state)) + if (!Propagate(state, wave)) return false; - if (state.wave.HasContradiction()) + if (wave.HasContradiction()) { if constexpr (CallbacksT::HasContradictionCallback()) { - PopulateWorld(state); - typename CallbacksT::ContradictionCallback{}(state.world); + PopulateWorld(state, wave); + typename CallbacksT::ContradictionCallback{}(state.m_world); } return false; } - if (state.wave.IsFullyCollapsed()) + if (wave.IsFullyCollapsed()) return true; if constexpr (CallbacksT::HasBranchCallback()) { - PopulateWorld(state); - typename CallbacksT::BranchCallback{}(state.world); + PopulateWorld(state, wave); + typename CallbacksT::BranchCallback{}(state.m_world); } - if (Branch(state)) + if (Branch(state, wave)) return true; } return false; @@ -189,9 +164,9 @@ public: * @param cellId The cell ID * @return The value if collapsed, std::nullopt otherwise */ - static std::optional GetValue(SolverState& state, int cellId) { - if (state.wave.IsCollapsed(cellId)) { - auto variableId = state.wave.GetVariableID(cellId); + static std::optional GetValue(WaveType& wave, int cellId) { + if (wave.IsCollapsed(cellId)) { + auto variableId = wave.GetVariableID(cellId); return VariableIDMapT::GetValue(variableId); } return std::nullopt; @@ -202,10 +177,10 @@ public: * @param cellId The cell ID * @return Set of possible values */ - static const std::vector GetPossibleValues(SolverState& state, int cellId) + static const std::vector GetPossibleValues(WaveType& wave, int cellId) { std::vector possibleValues; - MaskType mask = state.wave.GetMask(cellId); + MaskType mask = wave.GetMask(cellId); for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) { if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i)); } @@ -213,29 +188,29 @@ public: } private: - static void CollapseCell(SolverState& state, size_t cellId, uint16_t value) + static void CollapseCell(SolverState& state, WaveType& wave, size_t cellId, uint16_t value) { - constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (MaskType(1) << value)); - state.wave.Collapse(cellId, 1 << value); - constexpr_assert(state.wave.IsCollapsed(cellId)); + constexpr_assert(!wave.IsCollapsed(cellId) || wave.GetMask(cellId) == (MaskType(1) << value)); + wave.Collapse(cellId, 1 << value); + constexpr_assert(wave.IsCollapsed(cellId)); if constexpr (CallbacksT::HasCellCollapsedCallback()) { - PopulateWorld(state); - typename CallbacksT::CellCollapsedCallback{}(state.world); + PopulateWorld(state, wave); + typename CallbacksT::CellCollapsedCallback{}(state.m_world); } } - static bool Branch(SolverState& state) + static bool Branch(SolverState& state, WaveType& wave) { - constexpr_assert(state.propagationQueue.empty()); + constexpr_assert(state.m_propagationQueue.empty()); // Find cell with minimum entropy > 1 size_t minEntropyCell = static_cast(-1); size_t minEntropy = static_cast(-1); - for (size_t i = 0; i < state.wave.size(); ++i) { - size_t entropy = state.wave.Entropy(i); + for (size_t i = 0; i < wave.size(); ++i) { + size_t entropy = wave.Entropy(i); if (entropy > 1 && entropy < minEntropy) { minEntropy = entropy; minEntropyCell = i; @@ -243,12 +218,12 @@ private: } if (minEntropyCell == static_cast(-1)) return false; - constexpr_assert(!state.wave.IsCollapsed(minEntropyCell)); + constexpr_assert(!wave.IsCollapsed(minEntropyCell)); // create a list of possible values - uint16_t availableValues = static_cast(state.wave.Entropy(minEntropyCell)); + uint16_t availableValues = static_cast(wave.Entropy(minEntropyCell)); std::array possibleValues; // inplace vector - MaskType mask = state.wave.GetMask(minEntropyCell); + MaskType mask = wave.GetMask(minEntropyCell); for (size_t i = 0; i < availableValues; ++i) { uint16_t index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit @@ -269,29 +244,31 @@ private: valueArray[i] = VariableIDMapT::GetValue(possibleValues[i]); } std::span currentPossibleValues(valueArray.data(), availableValues); - size_t randomIndex = state.randomSelector(currentPossibleValues); + size_t randomIndex = state.m_randomSelector(currentPossibleValues); size_t selectedValue = possibleValues[randomIndex]; { // copy the state and branch out - auto stackFrame = state.allocator.createFrame(); - SolverState newState(state); - CollapseCell(newState, minEntropyCell, static_cast(selectedValue)); - newState.propagationQueue.push(minEntropyCell); + auto stackFrame = state.m_allocator.createFrame(); + auto queueFrame = state.m_propagationQueue.createBranchPoint(); + + auto newWave = wave; + CollapseCell(state, newWave, minEntropyCell, static_cast(selectedValue)); + state.m_propagationQueue.push(minEntropyCell); - if (RunLoop(newState)) + if (RunLoop(state, newWave)) { // copy the solution to the original state - state.wave = newState.wave; + wave = newWave; return true; } } // remove the failure state from the wave - constexpr_assert((state.wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) != 0, "Possible value was not set"); - state.wave.Collapse(minEntropyCell, ~(1 << selectedValue)); - constexpr_assert((state.wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); + constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) != 0, "Possible value was not set"); + wave.Collapse(minEntropyCell, ~(1 << selectedValue)); + constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); // swap replacement value with the last value std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); @@ -300,47 +277,46 @@ private: return false; } - static bool Propagate(SolverState& state) + static bool Propagate(SolverState& state, WaveType& wave) { - while (!state.propagationQueue.empty()) + while (!state.m_propagationQueue.empty()) { - size_t cellId = state.propagationQueue.front(); - state.propagationQueue.pop(); + size_t cellId = state.m_propagationQueue.pop(); - if (state.wave.IsContradicted(cellId)) return false; + if (wave.IsContradicted(cellId)) return false; - constexpr_assert(state.wave.IsCollapsed(cellId), "Cell was not collapsed"); + constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed"); - uint16_t variableID = state.wave.GetVariableID(cellId); - ConstrainerType constrainer(state.wave, state.propagationQueue); + uint16_t variableID = wave.GetVariableID(cellId); + ConstrainerType constrainer(wave, state.m_propagationQueue); using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, ConstrainerType&); - ConstrainerFunctionMapT::template GetFunction(variableID)(state.world, cellId, WorldValue{VariableIDMapT::GetValue(variableID), variableID}, constrainer); + ConstrainerFunctionMapT::template GetFunction(variableID)(state.m_world, cellId, WorldValue{VariableIDMapT::GetValue(variableID), variableID}, constrainer); } return true; } - static void PopulateWorld(SolverState& state) + static void PopulateWorld(SolverState& state, WaveType& wave) { - for (size_t i = 0; i < state.wave.size(); ++i) + for (size_t i = 0; i < wave.size(); ++i) { - if (state.wave.IsCollapsed(i)) - state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i))); + if (wave.IsCollapsed(i)) + state.m_world.setValue(i, VariableIDMapT::GetValue(wave.GetVariableID(i))); } } - static void PropogateInitialValues(SolverState& state) + static void PropogateInitialValues(SolverState& state, WaveType& wave) { auto allValues = VariableIDMapT::GetAllValues(); - for (size_t i = 0; i < state.wave.size(); ++i) + for (size_t i = 0; i < wave.size(); ++i) { for (size_t j = 0; j < allValues.size(); ++j) { - if (state.world.getValue(i) == allValues[j]) + if (state.m_world.getValue(i) == allValues[j]) { - CollapseCell(state, static_cast(i), static_cast(j)); - state.propagationQueue.push(i); + CollapseCell(state, wave, static_cast(i), static_cast(j)); + state.m_propagationQueue.push(i); break; } } diff --git a/include/nd-wfc/wfc_allocator.hpp b/include/nd-wfc/wfc_allocator.hpp index 7fcd187..c12d081 100644 --- a/include/nd-wfc/wfc_allocator.hpp +++ b/include/nd-wfc/wfc_allocator.hpp @@ -383,16 +383,4 @@ public: WFCStackAllocator* m_allocator; }; -/** - * @brief Stack-allocated vector using WFCStackAllocator - */ -template -using WFCVector = std::vector>; - -/** - * @brief Stack-allocated queue using WFCStackAllocator - */ -template -using WFCQueue = std::queue>>; - } // namespace WFC diff --git a/include/nd-wfc/wfc_builder.hpp b/include/nd-wfc/wfc_builder.hpp index b85366e..9d42351 100644 --- a/include/nd-wfc/wfc_builder.hpp +++ b/include/nd-wfc/wfc_builder.hpp @@ -24,13 +24,14 @@ public: constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; using WaveType = Wave; - using ConstrainerType = Constrainer; + using PropagationQueueType = WFCQueue; + using ConstrainerType = Constrainer; template using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>; template - requires ConstrainerFunction + requires ConstrainerFunction using DefineConstrainer = Builder +template class Constrainer { public: using IDMapT = typename WaveT::IDMapT; @@ -69,7 +70,7 @@ public: using MaskType = typename BitContainerT::StorageType; public: - Constrainer(WaveT& wave, WFCQueue& propagationQueue) + Constrainer(WaveT& wave, PropagationQueueT& propagationQueue) : m_wave(wave) , m_propagationQueue(propagationQueue) {} @@ -120,7 +121,7 @@ private: private: WaveT& m_wave; - WFCQueue& m_propagationQueue; + PropagationQueueT& m_propagationQueue; }; } \ No newline at end of file diff --git a/include/nd-wfc/wfc_queue.hpp b/include/nd-wfc/wfc_queue.hpp new file mode 100644 index 0000000..b442ddf --- /dev/null +++ b/include/nd-wfc/wfc_queue.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "nd-wfc/wfc_utils.hpp" + +namespace WFC +{ + +template +class WFCQueue { +public: + using ContainerType = std::conditional_t, std::array>; + +public: + WFCQueue() = default; + WFCQueue(const WFCQueue&) = delete; + WFCQueue(WFCQueue&&) = delete; + WFCQueue& operator=(const WFCQueue&) = delete; + WFCQueue& operator=(WFCQueue&&) = delete; + + constexpr WFCQueue(size_t size) + { + if constexpr (Size == 0) + { + m_container.resize(size); + } + } + +public: + constexpr std::span data() const { return std::span(m_container.data(), Size); } + constexpr std::span data() { return std::span(m_container.data(), Size); } + + constexpr std::span FilledData() const { return std::span(m_container.data() + m_front, m_back - m_front); } + constexpr std::span FilledData() { return std::span(m_container.data() + m_front, m_back - m_front); } + + constexpr size_t size() const { return m_container.size(); } + +public: + constexpr bool empty() const { return m_front == m_back; } + constexpr bool full() const { return m_back == size(); } + constexpr bool has(StorageType value) const { return std::find(m_container.begin(), m_container.begin() + m_back, value) != m_container.begin() + m_back; } + +public: + constexpr void push(const StorageType &value) + { + constexpr_assert(!full()); + constexpr_assert(!has(value)); + + m_container[m_back++] = value; + } + + constexpr StorageType pop() + { + constexpr_assert(!empty()); + + return m_container[m_front++]; + } + +public: + struct BranchPoint + { + constexpr BranchPoint(WFCQueue& queue) + : m_queue(queue) + , m_front(queue.m_front) + , m_back(queue.m_back) + {} + + constexpr ~BranchPoint() + { + m_queue.m_front = m_front; + m_queue.m_back = m_back; + } + + WFCQueue& m_queue; + size_t m_front; + size_t m_back; + }; + +public: + constexpr BranchPoint createBranchPoint() + { + return BranchPoint(*this); + } + +private: + ContainerType m_container{}; + size_t m_front = 0; + size_t m_back = 0; +}; + +} // namespace WFC \ No newline at end of file