From 8dc7bcd61872128c97fb1a202321d4a8b056611e Mon Sep 17 00:00:00 2001 From: Connor De Meyer Date: Sun, 31 Aug 2025 17:42:06 +0900 Subject: [PATCH] Callback System --- .cursorindexingignore | 3 + CMakeLists.txt | 5 + demos/sudoku/CMakeLists.txt | 44 +++++++-- demos/sudoku/main.cpp | 2 +- demos/sudoku/sudoku.h | 60 ++++++------ demos/sudoku/sudoku_wfc.cpp | 82 ++++++++++++++--- demos/sudoku/test_sudoku.cpp | 10 +- include/nd-wfc/wfc.hpp | 152 +++++++++++++++++++++++++------ include/nd-wfc/wfc_allocator.hpp | 2 +- 9 files changed, 276 insertions(+), 84 deletions(-) create mode 100644 .cursorindexingignore diff --git a/.cursorindexingignore b/.cursorindexingignore new file mode 100644 index 0000000..953908e --- /dev/null +++ b/.cursorindexingignore @@ -0,0 +1,3 @@ + +# Don't index SpecStory auto-save files, but allow explicit context inclusion via @ references +.specstory/** diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d763c0..cdead34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,11 @@ set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) +# Ensure consistent runtime library settings for MSVC +if(MSVC) + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL") +endif() + # Enable testing enable_testing() diff --git a/demos/sudoku/CMakeLists.txt b/demos/sudoku/CMakeLists.txt index ae2f7f2..f0a2c1c 100644 --- a/demos/sudoku/CMakeLists.txt +++ b/demos/sudoku/CMakeLists.txt @@ -17,14 +17,15 @@ else() add_compile_options(-Wall -Wextra -Wpedantic) endif() -# Find Google Test (optional) -find_package(GTest) -if(GTest_FOUND) - set(HAS_GTEST TRUE) -else() - set(HAS_GTEST FALSE) - message(WARNING "Google Test not found. Tests will not be built.") -endif() +# Use FetchContent to get Google Test for consistent builds +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +FetchContent_MakeAvailable(googletest) +set(HAS_GTEST TRUE) +message(STATUS "Using Google Test via FetchContent") # Find Google Benchmark (optional) find_package(benchmark) @@ -39,6 +40,14 @@ endif() set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads) +# Ensure consistent runtime library settings for MSVC +if(MSVC) + # Force all targets to use the same runtime library + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL") + # Ensure Google Test uses the same runtime library + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +endif() + # Create the main executable add_executable(sudoku_demo main.cpp @@ -69,6 +78,14 @@ target_link_libraries(sudoku_wfc_demo PRIVATE nd-wfc) target_link_libraries(analyze_failing_puzzles PRIVATE nd-wfc) target_link_libraries(debug_failing_puzzles PRIVATE nd-wfc) +# Ensure consistent runtime library settings for all executables +if(MSVC) + target_compile_options(sudoku_demo PRIVATE $<$:/MDd> $<$:/MD>) + target_compile_options(sudoku_wfc_demo PRIVATE $<$:/MDd> $<$:/MD>) + target_compile_options(analyze_failing_puzzles PRIVATE $<$:/MDd> $<$:/MD>) + target_compile_options(debug_failing_puzzles PRIVATE $<$:/MDd> $<$:/MD>) +endif() + # Set output directory for sudoku_demo set_target_properties(sudoku_demo PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin @@ -123,7 +140,16 @@ if(HAS_GTEST) test_sudoku.cpp ) - target_link_libraries(sudoku_tests GTest::gtest GTest::gtest_main nd-wfc) + target_link_libraries(sudoku_tests gtest gtest_main nd-wfc) + + # Ensure consistent runtime library settings for test executable + if(MSVC) + target_compile_options(sudoku_tests PRIVATE $<$:/MDd> $<$:/MD>) + # Ensure Google Test targets use consistent runtime library + set_target_properties(gtest gtest_main gmock gmock_main PROPERTIES + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL" + ) + endif() if(Threads_FOUND) target_link_libraries(sudoku_tests Threads::Threads) endif() diff --git a/demos/sudoku/main.cpp b/demos/sudoku/main.cpp index ae5c34f..1f0ee56 100644 --- a/demos/sudoku/main.cpp +++ b/demos/sudoku/main.cpp @@ -6,7 +6,7 @@ int main() std::cout << "Sudoku Demo" << std::endl; // Create a simple sudoku puzzle - Sudoku sudoku("530070000600195000098000060800060003400803001700020006060000280000419005000080079"); + Sudoku sudoku("140000050700200000000300204200080400080090020006050001809001000000006007050000069"); if (sudoku.isValid()) { std::cout << "Loaded valid sudoku puzzle:" << std::endl; diff --git a/demos/sudoku/sudoku.h b/demos/sudoku/sudoku.h index d22b0f7..a560138 100644 --- a/demos/sudoku/sudoku.h +++ b/demos/sudoku/sudoku.h @@ -200,11 +200,11 @@ public: // WFC Support using ValueType = uint8_t; ValueType getValue(size_t index) const { - return board_.get(index); + return board_.get(static_cast(index)); } void setValue(size_t index, ValueType value) { - board_.set(index, value); + board_.set(static_cast(index), value); } constexpr size_t size() const { @@ -235,35 +235,37 @@ private: static bool parseLine(const std::string& line, std::array& board); }; -using SudokuSolver = WFC::Builder - ::DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9> - ::DefineConstrainer val, auto& constrainer) { - size_t x = index % 9; - size_t y = index / 9; - // Add row constraints (same row, different columns) - for (size_t i = 0; i < 9; ++i) { - if (i != x) constrainer.Exclude(val, i + y * 9); - } +using SudokuSolverBuilder = WFC::Builder + ::DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9> + ::DefineConstrainer val, auto& constrainer) { + size_t x = index % 9; + size_t y = index / 9; - // Add column constraints (same column, different rows) - for (size_t i = 0; i < 9; ++i) { - if (i != y) constrainer.Exclude(val,x + i * 9); - } + // Add row constraints (same row, different columns) + for (size_t i = 0; i < 9; ++i) { + if (i != x) constrainer.Exclude(val, i + y * 9); + } - // Add box constraints (3x3 box) - size_t box_x = (x / 3) * 3; - size_t box_y = (y / 3) * 3; - for (size_t j = 0; j < 3; ++j) { - for (size_t k = 0; k < 3; ++k) { - size_t cell_x = box_x + j; - size_t cell_y = box_y + k; - size_t cell_index = cell_x + cell_y * 9; - if (cell_index != index) { - constrainer.Exclude(val, cell_index); - } + // Add column constraints (same column, different rows) + for (size_t i = 0; i < 9; ++i) { + if (i != y) constrainer.Exclude(val,x + i * 9); + } + + // Add box constraints (3x3 box) + size_t box_x = (x / 3) * 3; + size_t box_y = (y / 3) * 3; + for (size_t j = 0; j < 3; ++j) { + for (size_t k = 0; k < 3; ++k) { + size_t cell_x = box_x + j; + size_t cell_y = box_y + k; + size_t cell_index = cell_x + cell_y * 9; + if (cell_index != index) { + constrainer.Exclude(val, cell_index); } } - - }), 1, 2, 3, 4, 5, 6, 7, 8, 9> - ::Build; \ No newline at end of file + } + + }), 1, 2, 3, 4, 5, 6, 7, 8, 9>; + +using SudokuSolver = SudokuSolverBuilder::Build; \ No newline at end of file diff --git a/demos/sudoku/sudoku_wfc.cpp b/demos/sudoku/sudoku_wfc.cpp index b1c1dd3..9f61975 100644 --- a/demos/sudoku/sudoku_wfc.cpp +++ b/demos/sudoku/sudoku_wfc.cpp @@ -1,28 +1,88 @@ #include #include "sudoku.h" #include +#include -int main() -{ - std::cout << "Running Sudoku WFC" << std::endl; +// Helper function to load multiple puzzles from a file (one puzzle per line) +std::vector loadPuzzlesFromFile(const std::string& filename) { + std::vector puzzles; + std::ifstream file(filename); - Sudoku sudokuWorld; - sudokuWorld.setValue(0, 5); - sudokuWorld.setValue(80, 1); - bool success = SudokuSolver::Run(sudokuWorld, true); + if (!file.is_open()) { + return puzzles; + } - if (success) { - std::cout << "Sudoku solved successfully!" << std::endl; - // Print the solved sudoku + std::string line; + while (std::getline(file, line)) { + // Remove whitespace + line.erase(std::remove_if(line.begin(), line.end(), + [](char c) { return std::isspace(c); }), line.end()); + + if (line.empty()) continue; + + Sudoku sudoku; + if (sudoku.loadFromString(line)) { + puzzles.push_back(std::move(sudoku)); + } + } + + return puzzles; +} + +using SudokuSolverCallback = SudokuSolverBuilder::SetCellCollapsedCallback(sudokuWorld.getValue(x + y * 9)) << " "; + int current = static_cast(sudoku.getValue(x + y * 9)); + int last = static_cast(LastSudoku.getValue(x + y * 9)); + if (current != last) { + std::cout << "\033[31m" << current << "\033[0m "; + } else { + std::cout << current << " "; + } if (x == 2 || x == 5) std::cout << "| "; } std::cout << std::endl; if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl; } + LastSudoku = sudoku; + + std::cout << "Iteration: " << counter << std::endl; + counter++; + + // std::cout << std::endl; + // std::cout << "Press Enter to continue..." << std::endl; + // std::cin.get(); + })> + ::Build; + +int main() +{ + std::cout << "Running Sudoku WFC" << std::endl; + + Sudoku sudokuWorld{ "040280030010006007609070008000092000900000004000740000500020803400800010070035090" }; + + bool success = SudokuSolverCallback::Run(sudokuWorld, true); + + bool solved = sudokuWorld.isSolved(); + + if (success && solved) { + std::cout << "Sudoku solved successfully!" << std::endl; } else { std::cout << "Failed to solve sudoku!" << std::endl; } + + // Print the solved sudoku + for (size_t y = 0; y < 9; ++y) { + for (size_t x = 0; x < 9; ++x) { + std::cout << static_cast(sudokuWorld.getValue(x + y * 9)) << " "; + if (x == 2 || x == 5) std::cout << "| "; + } + std::cout << std::endl; + if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl; + } + + } \ No newline at end of file diff --git a/demos/sudoku/test_sudoku.cpp b/demos/sudoku/test_sudoku.cpp index e8d20f4..54a9f50 100644 --- a/demos/sudoku/test_sudoku.cpp +++ b/demos/sudoku/test_sudoku.cpp @@ -282,7 +282,7 @@ TEST_F(SudokuTest, WFCIntegration) // Tests loading and solving puzzles from data files TEST_F(SudokuTest, LoadAndSolveEasyPuzzles) { - std::vector easyPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_easy.txt"); + std::vector easyPuzzles = loadPuzzlesFromFile("../data/Sudoku_easy.txt"); ASSERT_GT(easyPuzzles.size(), 0) << "No easy puzzles loaded"; @@ -320,7 +320,7 @@ TEST_F(SudokuTest, LoadAndSolveEasyPuzzles) TEST_F(SudokuTest, LoadAndSolveMediumPuzzles) { - std::vector mediumPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_medium.txt"); + std::vector mediumPuzzles = loadPuzzlesFromFile("../data/Sudoku_medium.txt"); ASSERT_GT(mediumPuzzles.size(), 0) << "No medium puzzles loaded"; @@ -358,7 +358,7 @@ TEST_F(SudokuTest, LoadAndSolveMediumPuzzles) TEST_F(SudokuTest, LoadAndSolveHardPuzzles) { - std::vector hardPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_hard.txt"); + std::vector hardPuzzles = loadPuzzlesFromFile("../data/Sudoku_hard.txt"); ASSERT_GT(hardPuzzles.size(), 0) << "No hard puzzles loaded"; @@ -396,7 +396,7 @@ TEST_F(SudokuTest, LoadAndSolveHardPuzzles) TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles) { - std::vector diabolicalPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_diabolical.txt"); + std::vector diabolicalPuzzles = loadPuzzlesFromFile("../data/Sudoku_diabolical.txt"); ASSERT_GT(diabolicalPuzzles.size(), 0) << "No diabolical puzzles loaded"; @@ -435,7 +435,7 @@ TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles) // Test loading a single puzzle from each difficulty file TEST_F(SudokuTest, LoadAndSolveFirstPuzzleFromEachFile) { - const std::string dataPath = "/home/connor/repos/nd-wfc/demos/sudoku/data"; + const std::string dataPath = "../data"; const std::vector files = {"Sudoku_easy.txt", "Sudoku_medium.txt", "Sudoku_hard.txt", "Sudoku_diabolical.txt"}; for (const auto& filename : files) { diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index f5fab33..b6fac05 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -297,10 +297,51 @@ struct VariableData { {} }; +/** + * @brief Empty callback function + * @param World The world type + */ +template +using EmptyCallback = decltype([](World&){}); + +/** + * @brief Callback struct + * @param WorldT The world type + * @param AllCellsCollapsedCallbackT The all cells collapsed callback type + * @param CellCollapsedCallbackT The cell collapsed callback type + * @param ContradictionCallbackT The contradiction callback type + * @param BranchCallbackT The branch callback type + */ +template , + typename ContradictionCallbackT = EmptyCallback, + typename BranchCallbackT = EmptyCallback +> +struct Callbacks +{ + using CellCollapsedCallback = CellCollapsedCallbackT; + using ContradictionCallback = ContradictionCallbackT; + using BranchCallback = BranchCallbackT; + + template + using SetCellCollapsedCallbackT = Callbacks; + template + using SetContradictionCallbackT = Callbacks; + template + using SetBranchCallbackT = Callbacks; + + static consteval bool HasCellCollapsedCallback() { return !std::is_same_v>; } + static consteval bool HasContradictionCallback() { return !std::is_same_v>; } + static consteval bool HasBranchCallback() { return !std::is_same_v>; } +}; + /** * @brief Main WFC class implementing the Wave Function Collapse algorithm */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap> +template, + typename ConstrainerFunctionMapT = ConstrainerFunctionMap, + typename CallbacksT = Callbacks> class WFC { public: static_assert(WorldType, "WorldT must satisfy World type requirements"); @@ -308,7 +349,8 @@ public: using MaskType = typename VariableIDMapT::MaskType; public: - struct SolverState { + struct SolverState + { WorldT& world; WFCQueue propagationQueue; Wave wave; @@ -329,17 +371,25 @@ public: }; public: - WFC() = delete; + WFC() = delete; // dont make an instance of this class, only use the static methods. public: - static bool Run(WorldT& world, bool propagateInitialValues = true) + static bool Run(WorldT& world, bool propagateInitialValues = true, WFCStackAllocator* allocator = nullptr) { - WFCStackAllocator allocator{}; - // std::mt19937 random{ std::random_device{}() }; // Using random values fails 25% of the time - std::mt19937 random{ 212 }; + //std::mt19937 random{ 212 }; + std::mt19937 random{ std::random_device{}() }; size_t iterations = 0; - SolverState state(world, ConstrainerFunctionMapT::size(), random, allocator, iterations); - return Run(state, propagateInitialValues); + if (!allocator) + { + WFCStackAllocator newAllocator{}; + SolverState state(world, ConstrainerFunctionMapT::size(), random, newAllocator, iterations); + return Run(state, propagateInitialValues); + } + else + { + SolverState state(world, ConstrainerFunctionMapT::size(), random, *allocator, iterations); + return Run(state, propagateInitialValues); + } } /** @@ -363,18 +413,32 @@ public: static bool RunLoop(SolverState& state) { - for (; state.iterations < 256; ++state.iterations) + for (; state.iterations < 1024; ++state.iterations) { if (!Propagate(state)) return false; if (state.wave.HasContradiction()) + { + if constexpr (CallbacksT::HasContradictionCallback()) + { + PopulateWorld(state); + typename CallbacksT::ContradictionCallback{}(state.world); + } return false; + } if (state.wave.IsFullyCollapsed()) return true; - Branch(state); + if constexpr (CallbacksT::HasBranchCallback()) + { + PopulateWorld(state); + typename CallbacksT::BranchCallback{}(state.world); + } + + if (Branch(state)) + return true; } return false; } @@ -408,6 +472,19 @@ public: } private: + static void CollapseCell(SolverState& state, size_t cellId, uint16_t value) + { + assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (1 << value)); + state.wave.Collapse(cellId, 1 << value); + assert(state.wave.IsCollapsed(cellId)); + + if constexpr (CallbacksT::HasCellCollapsedCallback()) + { + PopulateWorld(state); + typename CallbacksT::CellCollapsedCallback{}(state.world); + } + } + static bool Branch(SolverState& state) { assert(state.propagationQueue.empty()); @@ -428,12 +505,12 @@ private: assert(!state.wave.IsCollapsed(minEntropyCell)); // create a list of possible values - uint16_t availableValues = state.wave.Entropy(minEntropyCell); + uint16_t availableValues = static_cast(state.wave.Entropy(minEntropyCell)); std::array possibleValues; // inplace vector MaskType mask = state.wave.GetMask(minEntropyCell); for (size_t i = 0; i < availableValues; ++i) { - uint16_t index = std::countr_zero(mask); // get the index of the lowest set bit + uint16_t index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit assert(index < VariableIDMapT::ValuesRegisteredAmount && "Possible value went outside bounds"); possibleValues[i] = index; @@ -443,16 +520,17 @@ private: } // randomly select a value from possible values - for (size_t i = 0; i < availableValues; ++i) + while (availableValues) { std::uniform_int_distribution dist(0, availableValues - 1); - size_t selectedValue = possibleValues[dist(state.rng)]; + size_t randomIndex = dist(state.rng); + size_t selectedValue = possibleValues[randomIndex]; { // copy the state and branch out auto stackFrame = state.allocator.createFrame(); SolverState newState(state); - newState.wave.Collapse(minEntropyCell, 1 << selectedValue); + CollapseCell(newState, minEntropyCell, static_cast(selectedValue)); newState.propagationQueue.push(minEntropyCell); if (RunLoop(newState)) @@ -470,7 +548,7 @@ private: assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0 && "Wave was not collapsed correctly"); // swap replacement value with the last value - std::swap(possibleValues[i], possibleValues[--availableValues]); + std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); } return false; @@ -501,7 +579,8 @@ private: { for (size_t i = 0; i < state.wave.size(); ++i) { - state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i))); + if (state.wave.IsCollapsed(i)) + state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i))); } } @@ -514,7 +593,7 @@ private: { if (state.world.getValue(i) == allValues[j]) { - state.wave.Collapse(i, 1 << j); + CollapseCell(state, static_cast(i), static_cast(j)); state.propagationQueue.push(i); break; } @@ -523,28 +602,45 @@ private: } }; +/** + * @brief Concept to validate constrainer function signature + * The function must be callable with parameters: (WorldT&, size_t, WorldValue, Constrainer&) + */ +template +concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { + func(world, index, value, constrainer); +}; + /** * @brief Builder class for creating WFC instances */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap> +template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks> class Builder { public: template - using DefineIDs = Builder, ConstrainerFunctionMapT>; + using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT>; template - using DefineConstrainer = Builder + using DefineConstrainer = Builder, + VariableIDMapT, + ConstrainerFunctionMapT, + ConstrainerFunctionT, + VariableIDMap, decltype([](WorldT&, size_t, WorldValue, Constrainer&) {}) - > + >, CallbacksT >; + + template + using SetCellCollapsedCallback = Builder>; + template + using SetContradictionCallback = Builder>; + template + using SetBranchCallback = Builder>; - using Build = WFC; + using Build = WFC; }; } // namespace WFC diff --git a/include/nd-wfc/wfc_allocator.hpp b/include/nd-wfc/wfc_allocator.hpp index fd64e04..d3ee329 100644 --- a/include/nd-wfc/wfc_allocator.hpp +++ b/include/nd-wfc/wfc_allocator.hpp @@ -10,7 +10,7 @@ #include #include -//#define WFC_USE_STACK_ALLOCATOR +#define WFC_USE_STACK_ALLOCATOR inline void* allocate_aligned_memory(size_t alignment, size_t size) { #ifdef WFC_USE_STACK_ALLOCATOR