Callback System

This commit is contained in:
2025-08-31 17:42:06 +09:00
parent bd7b27eb18
commit 8dc7bcd618
9 changed files with 276 additions and 84 deletions

3
.cursorindexingignore Normal file
View File

@@ -0,0 +1,3 @@
# Don't index SpecStory auto-save files, but allow explicit context inclusion via @ references
.specstory/**

View File

@@ -6,6 +6,11 @@ set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
# Ensure consistent runtime library settings for MSVC
if(MSVC)
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL")
endif()
# Enable testing # Enable testing
enable_testing() enable_testing()

View File

@@ -17,14 +17,15 @@ else()
add_compile_options(-Wall -Wextra -Wpedantic) add_compile_options(-Wall -Wextra -Wpedantic)
endif() endif()
# Find Google Test (optional) # Use FetchContent to get Google Test for consistent builds
find_package(GTest) include(FetchContent)
if(GTest_FOUND) FetchContent_Declare(
set(HAS_GTEST TRUE) googletest
else() URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
set(HAS_GTEST FALSE) )
message(WARNING "Google Test not found. Tests will not be built.") FetchContent_MakeAvailable(googletest)
endif() set(HAS_GTEST TRUE)
message(STATUS "Using Google Test via FetchContent")
# Find Google Benchmark (optional) # Find Google Benchmark (optional)
find_package(benchmark) find_package(benchmark)
@@ -39,6 +40,14 @@ endif()
set(THREADS_PREFER_PTHREAD_FLAG ON) set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads) find_package(Threads)
# Ensure consistent runtime library settings for MSVC
if(MSVC)
# Force all targets to use the same runtime library
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL")
# Ensure Google Test uses the same runtime library
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
endif()
# Create the main executable # Create the main executable
add_executable(sudoku_demo add_executable(sudoku_demo
main.cpp main.cpp
@@ -69,6 +78,14 @@ target_link_libraries(sudoku_wfc_demo PRIVATE nd-wfc)
target_link_libraries(analyze_failing_puzzles PRIVATE nd-wfc) target_link_libraries(analyze_failing_puzzles PRIVATE nd-wfc)
target_link_libraries(debug_failing_puzzles PRIVATE nd-wfc) target_link_libraries(debug_failing_puzzles PRIVATE nd-wfc)
# Ensure consistent runtime library settings for all executables
if(MSVC)
target_compile_options(sudoku_demo PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
target_compile_options(sudoku_wfc_demo PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
target_compile_options(analyze_failing_puzzles PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
target_compile_options(debug_failing_puzzles PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
endif()
# Set output directory for sudoku_demo # Set output directory for sudoku_demo
set_target_properties(sudoku_demo PROPERTIES set_target_properties(sudoku_demo PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
@@ -123,7 +140,16 @@ if(HAS_GTEST)
test_sudoku.cpp test_sudoku.cpp
) )
target_link_libraries(sudoku_tests GTest::gtest GTest::gtest_main nd-wfc) target_link_libraries(sudoku_tests gtest gtest_main nd-wfc)
# Ensure consistent runtime library settings for test executable
if(MSVC)
target_compile_options(sudoku_tests PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
# Ensure Google Test targets use consistent runtime library
set_target_properties(gtest gtest_main gmock gmock_main PROPERTIES
MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL"
)
endif()
if(Threads_FOUND) if(Threads_FOUND)
target_link_libraries(sudoku_tests Threads::Threads) target_link_libraries(sudoku_tests Threads::Threads)
endif() endif()

View File

@@ -6,7 +6,7 @@ int main()
std::cout << "Sudoku Demo" << std::endl; std::cout << "Sudoku Demo" << std::endl;
// Create a simple sudoku puzzle // Create a simple sudoku puzzle
Sudoku sudoku("530070000600195000098000060800060003400803001700020006060000280000419005000080079"); Sudoku sudoku("140000050700200000000300204200080400080090020006050001809001000000006007050000069");
if (sudoku.isValid()) { if (sudoku.isValid()) {
std::cout << "Loaded valid sudoku puzzle:" << std::endl; std::cout << "Loaded valid sudoku puzzle:" << std::endl;

View File

@@ -200,11 +200,11 @@ public: // WFC Support
using ValueType = uint8_t; using ValueType = uint8_t;
ValueType getValue(size_t index) const { ValueType getValue(size_t index) const {
return board_.get(index); return board_.get(static_cast<int>(index));
} }
void setValue(size_t index, ValueType value) { void setValue(size_t index, ValueType value) {
board_.set(index, value); board_.set(static_cast<int>(index), value);
} }
constexpr size_t size() const { constexpr size_t size() const {
@@ -235,35 +235,37 @@ private:
static bool parseLine(const std::string& line, std::array<uint8_t, 81>& board); static bool parseLine(const std::string& line, std::array<uint8_t, 81>& board);
}; };
using SudokuSolver = WFC::Builder<Sudoku>
::DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>
::DefineConstrainer<decltype([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
size_t x = index % 9;
size_t y = index / 9;
// Add row constraints (same row, different columns) using SudokuSolverBuilder = WFC::Builder<Sudoku>
for (size_t i = 0; i < 9; ++i) { ::DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>
if (i != x) constrainer.Exclude(val, i + y * 9); ::DefineConstrainer<decltype([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
} size_t x = index % 9;
size_t y = index / 9;
// Add column constraints (same column, different rows) // Add row constraints (same row, different columns)
for (size_t i = 0; i < 9; ++i) { for (size_t i = 0; i < 9; ++i) {
if (i != y) constrainer.Exclude(val,x + i * 9); if (i != x) constrainer.Exclude(val, i + y * 9);
} }
// Add box constraints (3x3 box) // Add column constraints (same column, different rows)
size_t box_x = (x / 3) * 3; for (size_t i = 0; i < 9; ++i) {
size_t box_y = (y / 3) * 3; if (i != y) constrainer.Exclude(val,x + i * 9);
for (size_t j = 0; j < 3; ++j) { }
for (size_t k = 0; k < 3; ++k) {
size_t cell_x = box_x + j; // Add box constraints (3x3 box)
size_t cell_y = box_y + k; size_t box_x = (x / 3) * 3;
size_t cell_index = cell_x + cell_y * 9; size_t box_y = (y / 3) * 3;
if (cell_index != index) { for (size_t j = 0; j < 3; ++j) {
constrainer.Exclude(val, cell_index); for (size_t k = 0; k < 3; ++k) {
} size_t cell_x = box_x + j;
size_t cell_y = box_y + k;
size_t cell_index = cell_x + cell_y * 9;
if (cell_index != index) {
constrainer.Exclude(val, cell_index);
} }
} }
}
}), 1, 2, 3, 4, 5, 6, 7, 8, 9>
::Build; }), 1, 2, 3, 4, 5, 6, 7, 8, 9>;
using SudokuSolver = SudokuSolverBuilder::Build;

View File

@@ -1,28 +1,88 @@
#include <nd-wfc/wfc.hpp> #include <nd-wfc/wfc.hpp>
#include "sudoku.h" #include "sudoku.h"
#include <iostream> #include <iostream>
#include <fstream>
int main() // Helper function to load multiple puzzles from a file (one puzzle per line)
{ std::vector<Sudoku> loadPuzzlesFromFile(const std::string& filename) {
std::cout << "Running Sudoku WFC" << std::endl; std::vector<Sudoku> puzzles;
std::ifstream file(filename);
Sudoku sudokuWorld; if (!file.is_open()) {
sudokuWorld.setValue(0, 5); return puzzles;
sudokuWorld.setValue(80, 1); }
bool success = SudokuSolver::Run(sudokuWorld, true);
if (success) { std::string line;
std::cout << "Sudoku solved successfully!" << std::endl; while (std::getline(file, line)) {
// Print the solved sudoku // Remove whitespace
line.erase(std::remove_if(line.begin(), line.end(),
[](char c) { return std::isspace(c); }), line.end());
if (line.empty()) continue;
Sudoku sudoku;
if (sudoku.loadFromString(line)) {
puzzles.push_back(std::move(sudoku));
}
}
return puzzles;
}
using SudokuSolverCallback = SudokuSolverBuilder::SetCellCollapsedCallback<decltype([](Sudoku& sudoku)
{
static Sudoku LastSudoku{};
static int counter = 0;
for (size_t y = 0; y < 9; ++y) { for (size_t y = 0; y < 9; ++y) {
for (size_t x = 0; x < 9; ++x) { for (size_t x = 0; x < 9; ++x) {
std::cout << static_cast<int>(sudokuWorld.getValue(x + y * 9)) << " "; int current = static_cast<int>(sudoku.getValue(x + y * 9));
int last = static_cast<int>(LastSudoku.getValue(x + y * 9));
if (current != last) {
std::cout << "\033[31m" << current << "\033[0m ";
} else {
std::cout << current << " ";
}
if (x == 2 || x == 5) std::cout << "| "; if (x == 2 || x == 5) std::cout << "| ";
} }
std::cout << std::endl; std::cout << std::endl;
if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl; if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl;
} }
LastSudoku = sudoku;
std::cout << "Iteration: " << counter << std::endl;
counter++;
// std::cout << std::endl;
// std::cout << "Press Enter to continue..." << std::endl;
// std::cin.get();
})>
::Build;
int main()
{
std::cout << "Running Sudoku WFC" << std::endl;
Sudoku sudokuWorld{ "040280030010006007609070008000092000900000004000740000500020803400800010070035090" };
bool success = SudokuSolverCallback::Run(sudokuWorld, true);
bool solved = sudokuWorld.isSolved();
if (success && solved) {
std::cout << "Sudoku solved successfully!" << std::endl;
} else { } else {
std::cout << "Failed to solve sudoku!" << std::endl; std::cout << "Failed to solve sudoku!" << std::endl;
} }
// Print the solved sudoku
for (size_t y = 0; y < 9; ++y) {
for (size_t x = 0; x < 9; ++x) {
std::cout << static_cast<int>(sudokuWorld.getValue(x + y * 9)) << " ";
if (x == 2 || x == 5) std::cout << "| ";
}
std::cout << std::endl;
if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl;
}
} }

View File

@@ -282,7 +282,7 @@ TEST_F(SudokuTest, WFCIntegration)
// Tests loading and solving puzzles from data files // Tests loading and solving puzzles from data files
TEST_F(SudokuTest, LoadAndSolveEasyPuzzles) TEST_F(SudokuTest, LoadAndSolveEasyPuzzles)
{ {
std::vector<Sudoku> easyPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_easy.txt"); std::vector<Sudoku> easyPuzzles = loadPuzzlesFromFile("../data/Sudoku_easy.txt");
ASSERT_GT(easyPuzzles.size(), 0) << "No easy puzzles loaded"; ASSERT_GT(easyPuzzles.size(), 0) << "No easy puzzles loaded";
@@ -320,7 +320,7 @@ TEST_F(SudokuTest, LoadAndSolveEasyPuzzles)
TEST_F(SudokuTest, LoadAndSolveMediumPuzzles) TEST_F(SudokuTest, LoadAndSolveMediumPuzzles)
{ {
std::vector<Sudoku> mediumPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_medium.txt"); std::vector<Sudoku> mediumPuzzles = loadPuzzlesFromFile("../data/Sudoku_medium.txt");
ASSERT_GT(mediumPuzzles.size(), 0) << "No medium puzzles loaded"; ASSERT_GT(mediumPuzzles.size(), 0) << "No medium puzzles loaded";
@@ -358,7 +358,7 @@ TEST_F(SudokuTest, LoadAndSolveMediumPuzzles)
TEST_F(SudokuTest, LoadAndSolveHardPuzzles) TEST_F(SudokuTest, LoadAndSolveHardPuzzles)
{ {
std::vector<Sudoku> hardPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_hard.txt"); std::vector<Sudoku> hardPuzzles = loadPuzzlesFromFile("../data/Sudoku_hard.txt");
ASSERT_GT(hardPuzzles.size(), 0) << "No hard puzzles loaded"; ASSERT_GT(hardPuzzles.size(), 0) << "No hard puzzles loaded";
@@ -396,7 +396,7 @@ TEST_F(SudokuTest, LoadAndSolveHardPuzzles)
TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles) TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles)
{ {
std::vector<Sudoku> diabolicalPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_diabolical.txt"); std::vector<Sudoku> diabolicalPuzzles = loadPuzzlesFromFile("../data/Sudoku_diabolical.txt");
ASSERT_GT(diabolicalPuzzles.size(), 0) << "No diabolical puzzles loaded"; ASSERT_GT(diabolicalPuzzles.size(), 0) << "No diabolical puzzles loaded";
@@ -435,7 +435,7 @@ TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles)
// Test loading a single puzzle from each difficulty file // Test loading a single puzzle from each difficulty file
TEST_F(SudokuTest, LoadAndSolveFirstPuzzleFromEachFile) TEST_F(SudokuTest, LoadAndSolveFirstPuzzleFromEachFile)
{ {
const std::string dataPath = "/home/connor/repos/nd-wfc/demos/sudoku/data"; const std::string dataPath = "../data";
const std::vector<std::string> files = {"Sudoku_easy.txt", "Sudoku_medium.txt", "Sudoku_hard.txt", "Sudoku_diabolical.txt"}; const std::vector<std::string> files = {"Sudoku_easy.txt", "Sudoku_medium.txt", "Sudoku_hard.txt", "Sudoku_diabolical.txt"};
for (const auto& filename : files) { for (const auto& filename : files) {

View File

@@ -297,10 +297,51 @@ struct VariableData {
{} {}
}; };
/**
* @brief Empty callback function
* @param World The world type
*/
template <typename World>
using EmptyCallback = decltype([](World&){});
/**
* @brief Callback struct
* @param WorldT The world type
* @param AllCellsCollapsedCallbackT The all cells collapsed callback type
* @param CellCollapsedCallbackT The cell collapsed callback type
* @param ContradictionCallbackT The contradiction callback type
* @param BranchCallbackT The branch callback type
*/
template <typename WorldT,
typename CellCollapsedCallbackT = EmptyCallback<WorldT>,
typename ContradictionCallbackT = EmptyCallback<WorldT>,
typename BranchCallbackT = EmptyCallback<WorldT>
>
struct Callbacks
{
using CellCollapsedCallback = CellCollapsedCallbackT;
using ContradictionCallback = ContradictionCallbackT;
using BranchCallback = BranchCallbackT;
template <typename NewCellCollapsedCallbackT>
using SetCellCollapsedCallbackT = Callbacks<WorldT, NewCellCollapsedCallbackT, ContradictionCallbackT, BranchCallbackT>;
template <typename NewContradictionCallbackT>
using SetContradictionCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, NewContradictionCallbackT, BranchCallbackT>;
template <typename NewBranchCallbackT>
using SetBranchCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, ContradictionCallbackT, NewBranchCallbackT>;
static consteval bool HasCellCollapsedCallback() { return !std::is_same_v<CellCollapsedCallbackT, EmptyCallback<WorldT>>; }
static consteval bool HasContradictionCallback() { return !std::is_same_v<ContradictionCallbackT, EmptyCallback<WorldT>>; }
static consteval bool HasBranchCallback() { return !std::is_same_v<BranchCallbackT, EmptyCallback<WorldT>>; }
};
/** /**
* @brief Main WFC class implementing the Wave Function Collapse algorithm * @brief Main WFC class implementing the Wave Function Collapse algorithm
*/ */
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>> template<typename WorldT, typename VarT,
typename VariableIDMapT = VariableIDMap<VarT>,
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>>
class WFC { class WFC {
public: public:
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements"); static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
@@ -308,7 +349,8 @@ public:
using MaskType = typename VariableIDMapT::MaskType; using MaskType = typename VariableIDMapT::MaskType;
public: public:
struct SolverState { struct SolverState
{
WorldT& world; WorldT& world;
WFCQueue<size_t> propagationQueue; WFCQueue<size_t> propagationQueue;
Wave<MaskType> wave; Wave<MaskType> wave;
@@ -329,17 +371,25 @@ public:
}; };
public: public:
WFC() = delete; WFC() = delete; // dont make an instance of this class, only use the static methods.
public: public:
static bool Run(WorldT& world, bool propagateInitialValues = true) static bool Run(WorldT& world, bool propagateInitialValues = true, WFCStackAllocator* allocator = nullptr)
{ {
WFCStackAllocator allocator{}; //std::mt19937 random{ 212 };
// std::mt19937 random{ std::random_device{}() }; // Using random values fails 25% of the time std::mt19937 random{ std::random_device{}() };
std::mt19937 random{ 212 };
size_t iterations = 0; size_t iterations = 0;
SolverState state(world, ConstrainerFunctionMapT::size(), random, allocator, iterations); if (!allocator)
return Run(state, propagateInitialValues); {
WFCStackAllocator newAllocator{};
SolverState state(world, ConstrainerFunctionMapT::size(), random, newAllocator, iterations);
return Run(state, propagateInitialValues);
}
else
{
SolverState state(world, ConstrainerFunctionMapT::size(), random, *allocator, iterations);
return Run(state, propagateInitialValues);
}
} }
/** /**
@@ -363,18 +413,32 @@ public:
static bool RunLoop(SolverState& state) static bool RunLoop(SolverState& state)
{ {
for (; state.iterations < 256; ++state.iterations) for (; state.iterations < 1024; ++state.iterations)
{ {
if (!Propagate(state)) if (!Propagate(state))
return false; return false;
if (state.wave.HasContradiction()) if (state.wave.HasContradiction())
{
if constexpr (CallbacksT::HasContradictionCallback())
{
PopulateWorld(state);
typename CallbacksT::ContradictionCallback{}(state.world);
}
return false; return false;
}
if (state.wave.IsFullyCollapsed()) if (state.wave.IsFullyCollapsed())
return true; return true;
Branch(state); if constexpr (CallbacksT::HasBranchCallback())
{
PopulateWorld(state);
typename CallbacksT::BranchCallback{}(state.world);
}
if (Branch(state))
return true;
} }
return false; return false;
} }
@@ -408,6 +472,19 @@ public:
} }
private: private:
static void CollapseCell(SolverState& state, size_t cellId, uint16_t value)
{
assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (1 << value));
state.wave.Collapse(cellId, 1 << value);
assert(state.wave.IsCollapsed(cellId));
if constexpr (CallbacksT::HasCellCollapsedCallback())
{
PopulateWorld(state);
typename CallbacksT::CellCollapsedCallback{}(state.world);
}
}
static bool Branch(SolverState& state) static bool Branch(SolverState& state)
{ {
assert(state.propagationQueue.empty()); assert(state.propagationQueue.empty());
@@ -428,12 +505,12 @@ private:
assert(!state.wave.IsCollapsed(minEntropyCell)); assert(!state.wave.IsCollapsed(minEntropyCell));
// create a list of possible values // create a list of possible values
uint16_t availableValues = state.wave.Entropy(minEntropyCell); uint16_t availableValues = static_cast<uint16_t>(state.wave.Entropy(minEntropyCell));
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
MaskType mask = state.wave.GetMask(minEntropyCell); MaskType mask = state.wave.GetMask(minEntropyCell);
for (size_t i = 0; i < availableValues; ++i) for (size_t i = 0; i < availableValues; ++i)
{ {
uint16_t index = std::countr_zero(mask); // get the index of the lowest set bit uint16_t index = static_cast<uint16_t>(std::countr_zero(mask)); // get the index of the lowest set bit
assert(index < VariableIDMapT::ValuesRegisteredAmount && "Possible value went outside bounds"); assert(index < VariableIDMapT::ValuesRegisteredAmount && "Possible value went outside bounds");
possibleValues[i] = index; possibleValues[i] = index;
@@ -443,16 +520,17 @@ private:
} }
// randomly select a value from possible values // randomly select a value from possible values
for (size_t i = 0; i < availableValues; ++i) while (availableValues)
{ {
std::uniform_int_distribution<size_t> dist(0, availableValues - 1); std::uniform_int_distribution<size_t> dist(0, availableValues - 1);
size_t selectedValue = possibleValues[dist(state.rng)]; size_t randomIndex = dist(state.rng);
size_t selectedValue = possibleValues[randomIndex];
{ {
// copy the state and branch out // copy the state and branch out
auto stackFrame = state.allocator.createFrame(); auto stackFrame = state.allocator.createFrame();
SolverState newState(state); SolverState newState(state);
newState.wave.Collapse(minEntropyCell, 1 << selectedValue); CollapseCell(newState, minEntropyCell, static_cast<uint16_t>(selectedValue));
newState.propagationQueue.push(minEntropyCell); newState.propagationQueue.push(minEntropyCell);
if (RunLoop(newState)) if (RunLoop(newState))
@@ -470,7 +548,7 @@ private:
assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0 && "Wave was not collapsed correctly"); assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0 && "Wave was not collapsed correctly");
// swap replacement value with the last value // swap replacement value with the last value
std::swap(possibleValues[i], possibleValues[--availableValues]); std::swap(possibleValues[randomIndex], possibleValues[--availableValues]);
} }
return false; return false;
@@ -501,7 +579,8 @@ private:
{ {
for (size_t i = 0; i < state.wave.size(); ++i) for (size_t i = 0; i < state.wave.size(); ++i)
{ {
state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i))); if (state.wave.IsCollapsed(i))
state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i)));
} }
} }
@@ -514,7 +593,7 @@ private:
{ {
if (state.world.getValue(i) == allValues[j]) if (state.world.getValue(i) == allValues[j])
{ {
state.wave.Collapse(i, 1 << j); CollapseCell(state, static_cast<uint16_t>(i), static_cast<uint16_t>(j));
state.propagationQueue.push(i); state.propagationQueue.push(i);
break; break;
} }
@@ -523,28 +602,45 @@ private:
} }
}; };
/**
* @brief Concept to validate constrainer function signature
* The function must be callable with parameters: (WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)
*/
template <typename T, typename WorldT, typename VarT, typename VariableIDMapT>
concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue<VarT> value, Constrainer<VariableIDMapT>& constrainer) {
func(world, index, value, constrainer);
};
/** /**
* @brief Builder class for creating WFC instances * @brief Builder class for creating WFC instances
*/ */
template<typename WorldT, typename VarT = typename WorldT::ValueType, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>> template<typename WorldT, typename VarT = typename WorldT::ValueType, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>, typename CallbacksT = Callbacks<WorldT>>
class Builder { class Builder {
public: public:
template <VarT ... Values> template <VarT ... Values>
using DefineIDs = Builder<WorldT, VarT, typename VariableIDMapT::template Merge<Values...>, ConstrainerFunctionMapT>; using DefineIDs = Builder<WorldT, VarT, typename VariableIDMapT::template Merge<Values...>, ConstrainerFunctionMapT, CallbacksT>;
template <typename ConstrainerFunctionT, VarT ... CorrespondingValues> template <typename ConstrainerFunctionT, VarT ... CorrespondingValues>
using DefineConstrainer = Builder<WorldT, VarT, VariableIDMapT, requires ConstrainerFunction<ConstrainerFunctionT, WorldT, VarT, VariableIDMapT>
using DefineConstrainer = Builder<WorldT, VarT, VariableIDMapT,
MergedConstrainerFunctionMap< MergedConstrainerFunctionMap<
VariableIDMapT, VariableIDMapT,
ConstrainerFunctionMapT, ConstrainerFunctionMapT,
ConstrainerFunctionT, ConstrainerFunctionT,
VariableIDMap<VarT, CorrespondingValues...>, VariableIDMap<VarT, CorrespondingValues...>,
decltype([](WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&) {}) decltype([](WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&) {})
> >, CallbacksT
>; >;
template <typename NewCellCollapsedCallbackT>
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>>;
template <typename NewContradictionCallbackT>
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>>;
template <typename NewBranchCallbackT>
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT>; using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT>;
}; };
} // namespace WFC } // namespace WFC

View File

@@ -10,7 +10,7 @@
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
//#define WFC_USE_STACK_ALLOCATOR #define WFC_USE_STACK_ALLOCATOR
inline void* allocate_aligned_memory(size_t alignment, size_t size) { inline void* allocate_aligned_memory(size_t alignment, size_t size) {
#ifdef WFC_USE_STACK_ALLOCATOR #ifdef WFC_USE_STACK_ALLOCATOR