Callback System
This commit is contained in:
3
.cursorindexingignore
Normal file
3
.cursorindexingignore
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
|
||||||
|
# Don't index SpecStory auto-save files, but allow explicit context inclusion via @ references
|
||||||
|
.specstory/**
|
||||||
@@ -6,6 +6,11 @@ set(CMAKE_CXX_STANDARD 20)
|
|||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||||
|
|
||||||
|
# Ensure consistent runtime library settings for MSVC
|
||||||
|
if(MSVC)
|
||||||
|
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Enable testing
|
# Enable testing
|
||||||
enable_testing()
|
enable_testing()
|
||||||
|
|
||||||
|
|||||||
@@ -17,14 +17,15 @@ else()
|
|||||||
add_compile_options(-Wall -Wextra -Wpedantic)
|
add_compile_options(-Wall -Wextra -Wpedantic)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Find Google Test (optional)
|
# Use FetchContent to get Google Test for consistent builds
|
||||||
find_package(GTest)
|
include(FetchContent)
|
||||||
if(GTest_FOUND)
|
FetchContent_Declare(
|
||||||
set(HAS_GTEST TRUE)
|
googletest
|
||||||
else()
|
URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
|
||||||
set(HAS_GTEST FALSE)
|
)
|
||||||
message(WARNING "Google Test not found. Tests will not be built.")
|
FetchContent_MakeAvailable(googletest)
|
||||||
endif()
|
set(HAS_GTEST TRUE)
|
||||||
|
message(STATUS "Using Google Test via FetchContent")
|
||||||
|
|
||||||
# Find Google Benchmark (optional)
|
# Find Google Benchmark (optional)
|
||||||
find_package(benchmark)
|
find_package(benchmark)
|
||||||
@@ -39,6 +40,14 @@ endif()
|
|||||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||||
find_package(Threads)
|
find_package(Threads)
|
||||||
|
|
||||||
|
# Ensure consistent runtime library settings for MSVC
|
||||||
|
if(MSVC)
|
||||||
|
# Force all targets to use the same runtime library
|
||||||
|
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL")
|
||||||
|
# Ensure Google Test uses the same runtime library
|
||||||
|
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Create the main executable
|
# Create the main executable
|
||||||
add_executable(sudoku_demo
|
add_executable(sudoku_demo
|
||||||
main.cpp
|
main.cpp
|
||||||
@@ -69,6 +78,14 @@ target_link_libraries(sudoku_wfc_demo PRIVATE nd-wfc)
|
|||||||
target_link_libraries(analyze_failing_puzzles PRIVATE nd-wfc)
|
target_link_libraries(analyze_failing_puzzles PRIVATE nd-wfc)
|
||||||
target_link_libraries(debug_failing_puzzles PRIVATE nd-wfc)
|
target_link_libraries(debug_failing_puzzles PRIVATE nd-wfc)
|
||||||
|
|
||||||
|
# Ensure consistent runtime library settings for all executables
|
||||||
|
if(MSVC)
|
||||||
|
target_compile_options(sudoku_demo PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
|
||||||
|
target_compile_options(sudoku_wfc_demo PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
|
||||||
|
target_compile_options(analyze_failing_puzzles PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
|
||||||
|
target_compile_options(debug_failing_puzzles PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Set output directory for sudoku_demo
|
# Set output directory for sudoku_demo
|
||||||
set_target_properties(sudoku_demo PROPERTIES
|
set_target_properties(sudoku_demo PROPERTIES
|
||||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
|
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
|
||||||
@@ -123,7 +140,16 @@ if(HAS_GTEST)
|
|||||||
test_sudoku.cpp
|
test_sudoku.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(sudoku_tests GTest::gtest GTest::gtest_main nd-wfc)
|
target_link_libraries(sudoku_tests gtest gtest_main nd-wfc)
|
||||||
|
|
||||||
|
# Ensure consistent runtime library settings for test executable
|
||||||
|
if(MSVC)
|
||||||
|
target_compile_options(sudoku_tests PRIVATE $<$<CONFIG:Debug>:/MDd> $<$<CONFIG:Release>:/MD>)
|
||||||
|
# Ensure Google Test targets use consistent runtime library
|
||||||
|
set_target_properties(gtest gtest_main gmock gmock_main PROPERTIES
|
||||||
|
MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
if(Threads_FOUND)
|
if(Threads_FOUND)
|
||||||
target_link_libraries(sudoku_tests Threads::Threads)
|
target_link_libraries(sudoku_tests Threads::Threads)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ int main()
|
|||||||
std::cout << "Sudoku Demo" << std::endl;
|
std::cout << "Sudoku Demo" << std::endl;
|
||||||
|
|
||||||
// Create a simple sudoku puzzle
|
// Create a simple sudoku puzzle
|
||||||
Sudoku sudoku("530070000600195000098000060800060003400803001700020006060000280000419005000080079");
|
Sudoku sudoku("140000050700200000000300204200080400080090020006050001809001000000006007050000069");
|
||||||
|
|
||||||
if (sudoku.isValid()) {
|
if (sudoku.isValid()) {
|
||||||
std::cout << "Loaded valid sudoku puzzle:" << std::endl;
|
std::cout << "Loaded valid sudoku puzzle:" << std::endl;
|
||||||
|
|||||||
@@ -200,11 +200,11 @@ public: // WFC Support
|
|||||||
using ValueType = uint8_t;
|
using ValueType = uint8_t;
|
||||||
|
|
||||||
ValueType getValue(size_t index) const {
|
ValueType getValue(size_t index) const {
|
||||||
return board_.get(index);
|
return board_.get(static_cast<int>(index));
|
||||||
}
|
}
|
||||||
|
|
||||||
void setValue(size_t index, ValueType value) {
|
void setValue(size_t index, ValueType value) {
|
||||||
board_.set(index, value);
|
board_.set(static_cast<int>(index), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr size_t size() const {
|
constexpr size_t size() const {
|
||||||
@@ -235,7 +235,8 @@ private:
|
|||||||
static bool parseLine(const std::string& line, std::array<uint8_t, 81>& board);
|
static bool parseLine(const std::string& line, std::array<uint8_t, 81>& board);
|
||||||
};
|
};
|
||||||
|
|
||||||
using SudokuSolver = WFC::Builder<Sudoku>
|
|
||||||
|
using SudokuSolverBuilder = WFC::Builder<Sudoku>
|
||||||
::DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>
|
::DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>
|
||||||
::DefineConstrainer<decltype([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
::DefineConstrainer<decltype([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||||
size_t x = index % 9;
|
size_t x = index % 9;
|
||||||
@@ -265,5 +266,6 @@ using SudokuSolver = WFC::Builder<Sudoku>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}), 1, 2, 3, 4, 5, 6, 7, 8, 9>
|
}), 1, 2, 3, 4, 5, 6, 7, 8, 9>;
|
||||||
::Build;
|
|
||||||
|
using SudokuSolver = SudokuSolverBuilder::Build;
|
||||||
@@ -1,18 +1,79 @@
|
|||||||
#include <nd-wfc/wfc.hpp>
|
#include <nd-wfc/wfc.hpp>
|
||||||
#include "sudoku.h"
|
#include "sudoku.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
// Helper function to load multiple puzzles from a file (one puzzle per line)
|
||||||
|
std::vector<Sudoku> loadPuzzlesFromFile(const std::string& filename) {
|
||||||
|
std::vector<Sudoku> puzzles;
|
||||||
|
std::ifstream file(filename);
|
||||||
|
|
||||||
|
if (!file.is_open()) {
|
||||||
|
return puzzles;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string line;
|
||||||
|
while (std::getline(file, line)) {
|
||||||
|
// Remove whitespace
|
||||||
|
line.erase(std::remove_if(line.begin(), line.end(),
|
||||||
|
[](char c) { return std::isspace(c); }), line.end());
|
||||||
|
|
||||||
|
if (line.empty()) continue;
|
||||||
|
|
||||||
|
Sudoku sudoku;
|
||||||
|
if (sudoku.loadFromString(line)) {
|
||||||
|
puzzles.push_back(std::move(sudoku));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return puzzles;
|
||||||
|
}
|
||||||
|
|
||||||
|
using SudokuSolverCallback = SudokuSolverBuilder::SetCellCollapsedCallback<decltype([](Sudoku& sudoku)
|
||||||
|
{
|
||||||
|
static Sudoku LastSudoku{};
|
||||||
|
static int counter = 0;
|
||||||
|
for (size_t y = 0; y < 9; ++y) {
|
||||||
|
for (size_t x = 0; x < 9; ++x) {
|
||||||
|
int current = static_cast<int>(sudoku.getValue(x + y * 9));
|
||||||
|
int last = static_cast<int>(LastSudoku.getValue(x + y * 9));
|
||||||
|
if (current != last) {
|
||||||
|
std::cout << "\033[31m" << current << "\033[0m ";
|
||||||
|
} else {
|
||||||
|
std::cout << current << " ";
|
||||||
|
}
|
||||||
|
if (x == 2 || x == 5) std::cout << "| ";
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl;
|
||||||
|
}
|
||||||
|
LastSudoku = sudoku;
|
||||||
|
|
||||||
|
std::cout << "Iteration: " << counter << std::endl;
|
||||||
|
counter++;
|
||||||
|
|
||||||
|
// std::cout << std::endl;
|
||||||
|
// std::cout << "Press Enter to continue..." << std::endl;
|
||||||
|
// std::cin.get();
|
||||||
|
})>
|
||||||
|
::Build;
|
||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
std::cout << "Running Sudoku WFC" << std::endl;
|
std::cout << "Running Sudoku WFC" << std::endl;
|
||||||
|
|
||||||
Sudoku sudokuWorld;
|
Sudoku sudokuWorld{ "040280030010006007609070008000092000900000004000740000500020803400800010070035090" };
|
||||||
sudokuWorld.setValue(0, 5);
|
|
||||||
sudokuWorld.setValue(80, 1);
|
|
||||||
bool success = SudokuSolver::Run(sudokuWorld, true);
|
|
||||||
|
|
||||||
if (success) {
|
bool success = SudokuSolverCallback::Run(sudokuWorld, true);
|
||||||
|
|
||||||
|
bool solved = sudokuWorld.isSolved();
|
||||||
|
|
||||||
|
if (success && solved) {
|
||||||
std::cout << "Sudoku solved successfully!" << std::endl;
|
std::cout << "Sudoku solved successfully!" << std::endl;
|
||||||
|
} else {
|
||||||
|
std::cout << "Failed to solve sudoku!" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
// Print the solved sudoku
|
// Print the solved sudoku
|
||||||
for (size_t y = 0; y < 9; ++y) {
|
for (size_t y = 0; y < 9; ++y) {
|
||||||
for (size_t x = 0; x < 9; ++x) {
|
for (size_t x = 0; x < 9; ++x) {
|
||||||
@@ -22,7 +83,6 @@ int main()
|
|||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl;
|
if (y == 2 || y == 5) std::cout << "------+-------+------" << std::endl;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
std::cout << "Failed to solve sudoku!" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@@ -282,7 +282,7 @@ TEST_F(SudokuTest, WFCIntegration)
|
|||||||
// Tests loading and solving puzzles from data files
|
// Tests loading and solving puzzles from data files
|
||||||
TEST_F(SudokuTest, LoadAndSolveEasyPuzzles)
|
TEST_F(SudokuTest, LoadAndSolveEasyPuzzles)
|
||||||
{
|
{
|
||||||
std::vector<Sudoku> easyPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_easy.txt");
|
std::vector<Sudoku> easyPuzzles = loadPuzzlesFromFile("../data/Sudoku_easy.txt");
|
||||||
|
|
||||||
ASSERT_GT(easyPuzzles.size(), 0) << "No easy puzzles loaded";
|
ASSERT_GT(easyPuzzles.size(), 0) << "No easy puzzles loaded";
|
||||||
|
|
||||||
@@ -320,7 +320,7 @@ TEST_F(SudokuTest, LoadAndSolveEasyPuzzles)
|
|||||||
|
|
||||||
TEST_F(SudokuTest, LoadAndSolveMediumPuzzles)
|
TEST_F(SudokuTest, LoadAndSolveMediumPuzzles)
|
||||||
{
|
{
|
||||||
std::vector<Sudoku> mediumPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_medium.txt");
|
std::vector<Sudoku> mediumPuzzles = loadPuzzlesFromFile("../data/Sudoku_medium.txt");
|
||||||
|
|
||||||
ASSERT_GT(mediumPuzzles.size(), 0) << "No medium puzzles loaded";
|
ASSERT_GT(mediumPuzzles.size(), 0) << "No medium puzzles loaded";
|
||||||
|
|
||||||
@@ -358,7 +358,7 @@ TEST_F(SudokuTest, LoadAndSolveMediumPuzzles)
|
|||||||
|
|
||||||
TEST_F(SudokuTest, LoadAndSolveHardPuzzles)
|
TEST_F(SudokuTest, LoadAndSolveHardPuzzles)
|
||||||
{
|
{
|
||||||
std::vector<Sudoku> hardPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_hard.txt");
|
std::vector<Sudoku> hardPuzzles = loadPuzzlesFromFile("../data/Sudoku_hard.txt");
|
||||||
|
|
||||||
ASSERT_GT(hardPuzzles.size(), 0) << "No hard puzzles loaded";
|
ASSERT_GT(hardPuzzles.size(), 0) << "No hard puzzles loaded";
|
||||||
|
|
||||||
@@ -396,7 +396,7 @@ TEST_F(SudokuTest, LoadAndSolveHardPuzzles)
|
|||||||
|
|
||||||
TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles)
|
TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles)
|
||||||
{
|
{
|
||||||
std::vector<Sudoku> diabolicalPuzzles = loadPuzzlesFromFile("/home/connor/repos/nd-wfc/demos/sudoku/data/Sudoku_diabolical.txt");
|
std::vector<Sudoku> diabolicalPuzzles = loadPuzzlesFromFile("../data/Sudoku_diabolical.txt");
|
||||||
|
|
||||||
ASSERT_GT(diabolicalPuzzles.size(), 0) << "No diabolical puzzles loaded";
|
ASSERT_GT(diabolicalPuzzles.size(), 0) << "No diabolical puzzles loaded";
|
||||||
|
|
||||||
@@ -435,7 +435,7 @@ TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles)
|
|||||||
// Test loading a single puzzle from each difficulty file
|
// Test loading a single puzzle from each difficulty file
|
||||||
TEST_F(SudokuTest, LoadAndSolveFirstPuzzleFromEachFile)
|
TEST_F(SudokuTest, LoadAndSolveFirstPuzzleFromEachFile)
|
||||||
{
|
{
|
||||||
const std::string dataPath = "/home/connor/repos/nd-wfc/demos/sudoku/data";
|
const std::string dataPath = "../data";
|
||||||
const std::vector<std::string> files = {"Sudoku_easy.txt", "Sudoku_medium.txt", "Sudoku_hard.txt", "Sudoku_diabolical.txt"};
|
const std::vector<std::string> files = {"Sudoku_easy.txt", "Sudoku_medium.txt", "Sudoku_hard.txt", "Sudoku_diabolical.txt"};
|
||||||
|
|
||||||
for (const auto& filename : files) {
|
for (const auto& filename : files) {
|
||||||
|
|||||||
@@ -297,10 +297,51 @@ struct VariableData {
|
|||||||
{}
|
{}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Empty callback function
|
||||||
|
* @param World The world type
|
||||||
|
*/
|
||||||
|
template <typename World>
|
||||||
|
using EmptyCallback = decltype([](World&){});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Callback struct
|
||||||
|
* @param WorldT The world type
|
||||||
|
* @param AllCellsCollapsedCallbackT The all cells collapsed callback type
|
||||||
|
* @param CellCollapsedCallbackT The cell collapsed callback type
|
||||||
|
* @param ContradictionCallbackT The contradiction callback type
|
||||||
|
* @param BranchCallbackT The branch callback type
|
||||||
|
*/
|
||||||
|
template <typename WorldT,
|
||||||
|
typename CellCollapsedCallbackT = EmptyCallback<WorldT>,
|
||||||
|
typename ContradictionCallbackT = EmptyCallback<WorldT>,
|
||||||
|
typename BranchCallbackT = EmptyCallback<WorldT>
|
||||||
|
>
|
||||||
|
struct Callbacks
|
||||||
|
{
|
||||||
|
using CellCollapsedCallback = CellCollapsedCallbackT;
|
||||||
|
using ContradictionCallback = ContradictionCallbackT;
|
||||||
|
using BranchCallback = BranchCallbackT;
|
||||||
|
|
||||||
|
template <typename NewCellCollapsedCallbackT>
|
||||||
|
using SetCellCollapsedCallbackT = Callbacks<WorldT, NewCellCollapsedCallbackT, ContradictionCallbackT, BranchCallbackT>;
|
||||||
|
template <typename NewContradictionCallbackT>
|
||||||
|
using SetContradictionCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, NewContradictionCallbackT, BranchCallbackT>;
|
||||||
|
template <typename NewBranchCallbackT>
|
||||||
|
using SetBranchCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, ContradictionCallbackT, NewBranchCallbackT>;
|
||||||
|
|
||||||
|
static consteval bool HasCellCollapsedCallback() { return !std::is_same_v<CellCollapsedCallbackT, EmptyCallback<WorldT>>; }
|
||||||
|
static consteval bool HasContradictionCallback() { return !std::is_same_v<ContradictionCallbackT, EmptyCallback<WorldT>>; }
|
||||||
|
static consteval bool HasBranchCallback() { return !std::is_same_v<BranchCallbackT, EmptyCallback<WorldT>>; }
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Main WFC class implementing the Wave Function Collapse algorithm
|
* @brief Main WFC class implementing the Wave Function Collapse algorithm
|
||||||
*/
|
*/
|
||||||
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>>
|
template<typename WorldT, typename VarT,
|
||||||
|
typename VariableIDMapT = VariableIDMap<VarT>,
|
||||||
|
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
||||||
|
typename CallbacksT = Callbacks<WorldT>>
|
||||||
class WFC {
|
class WFC {
|
||||||
public:
|
public:
|
||||||
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
|
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
|
||||||
@@ -308,7 +349,8 @@ public:
|
|||||||
using MaskType = typename VariableIDMapT::MaskType;
|
using MaskType = typename VariableIDMapT::MaskType;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
struct SolverState {
|
struct SolverState
|
||||||
|
{
|
||||||
WorldT& world;
|
WorldT& world;
|
||||||
WFCQueue<size_t> propagationQueue;
|
WFCQueue<size_t> propagationQueue;
|
||||||
Wave<MaskType> wave;
|
Wave<MaskType> wave;
|
||||||
@@ -329,18 +371,26 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
WFC() = delete;
|
WFC() = delete; // dont make an instance of this class, only use the static methods.
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static bool Run(WorldT& world, bool propagateInitialValues = true)
|
static bool Run(WorldT& world, bool propagateInitialValues = true, WFCStackAllocator* allocator = nullptr)
|
||||||
{
|
{
|
||||||
WFCStackAllocator allocator{};
|
//std::mt19937 random{ 212 };
|
||||||
// std::mt19937 random{ std::random_device{}() }; // Using random values fails 25% of the time
|
std::mt19937 random{ std::random_device{}() };
|
||||||
std::mt19937 random{ 212 };
|
|
||||||
size_t iterations = 0;
|
size_t iterations = 0;
|
||||||
SolverState state(world, ConstrainerFunctionMapT::size(), random, allocator, iterations);
|
if (!allocator)
|
||||||
|
{
|
||||||
|
WFCStackAllocator newAllocator{};
|
||||||
|
SolverState state(world, ConstrainerFunctionMapT::size(), random, newAllocator, iterations);
|
||||||
return Run(state, propagateInitialValues);
|
return Run(state, propagateInitialValues);
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
SolverState state(world, ConstrainerFunctionMapT::size(), random, *allocator, iterations);
|
||||||
|
return Run(state, propagateInitialValues);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Run the WFC algorithm to generate a solution
|
* @brief Run the WFC algorithm to generate a solution
|
||||||
@@ -363,18 +413,32 @@ public:
|
|||||||
|
|
||||||
static bool RunLoop(SolverState& state)
|
static bool RunLoop(SolverState& state)
|
||||||
{
|
{
|
||||||
for (; state.iterations < 256; ++state.iterations)
|
for (; state.iterations < 1024; ++state.iterations)
|
||||||
{
|
{
|
||||||
if (!Propagate(state))
|
if (!Propagate(state))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (state.wave.HasContradiction())
|
if (state.wave.HasContradiction())
|
||||||
|
{
|
||||||
|
if constexpr (CallbacksT::HasContradictionCallback())
|
||||||
|
{
|
||||||
|
PopulateWorld(state);
|
||||||
|
typename CallbacksT::ContradictionCallback{}(state.world);
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (state.wave.IsFullyCollapsed())
|
if (state.wave.IsFullyCollapsed())
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
Branch(state);
|
if constexpr (CallbacksT::HasBranchCallback())
|
||||||
|
{
|
||||||
|
PopulateWorld(state);
|
||||||
|
typename CallbacksT::BranchCallback{}(state.world);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Branch(state))
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -408,6 +472,19 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
static void CollapseCell(SolverState& state, size_t cellId, uint16_t value)
|
||||||
|
{
|
||||||
|
assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (1 << value));
|
||||||
|
state.wave.Collapse(cellId, 1 << value);
|
||||||
|
assert(state.wave.IsCollapsed(cellId));
|
||||||
|
|
||||||
|
if constexpr (CallbacksT::HasCellCollapsedCallback())
|
||||||
|
{
|
||||||
|
PopulateWorld(state);
|
||||||
|
typename CallbacksT::CellCollapsedCallback{}(state.world);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static bool Branch(SolverState& state)
|
static bool Branch(SolverState& state)
|
||||||
{
|
{
|
||||||
assert(state.propagationQueue.empty());
|
assert(state.propagationQueue.empty());
|
||||||
@@ -428,12 +505,12 @@ private:
|
|||||||
assert(!state.wave.IsCollapsed(minEntropyCell));
|
assert(!state.wave.IsCollapsed(minEntropyCell));
|
||||||
|
|
||||||
// create a list of possible values
|
// create a list of possible values
|
||||||
uint16_t availableValues = state.wave.Entropy(minEntropyCell);
|
uint16_t availableValues = static_cast<uint16_t>(state.wave.Entropy(minEntropyCell));
|
||||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
|
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
|
||||||
MaskType mask = state.wave.GetMask(minEntropyCell);
|
MaskType mask = state.wave.GetMask(minEntropyCell);
|
||||||
for (size_t i = 0; i < availableValues; ++i)
|
for (size_t i = 0; i < availableValues; ++i)
|
||||||
{
|
{
|
||||||
uint16_t index = std::countr_zero(mask); // get the index of the lowest set bit
|
uint16_t index = static_cast<uint16_t>(std::countr_zero(mask)); // get the index of the lowest set bit
|
||||||
assert(index < VariableIDMapT::ValuesRegisteredAmount && "Possible value went outside bounds");
|
assert(index < VariableIDMapT::ValuesRegisteredAmount && "Possible value went outside bounds");
|
||||||
|
|
||||||
possibleValues[i] = index;
|
possibleValues[i] = index;
|
||||||
@@ -443,16 +520,17 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// randomly select a value from possible values
|
// randomly select a value from possible values
|
||||||
for (size_t i = 0; i < availableValues; ++i)
|
while (availableValues)
|
||||||
{
|
{
|
||||||
std::uniform_int_distribution<size_t> dist(0, availableValues - 1);
|
std::uniform_int_distribution<size_t> dist(0, availableValues - 1);
|
||||||
size_t selectedValue = possibleValues[dist(state.rng)];
|
size_t randomIndex = dist(state.rng);
|
||||||
|
size_t selectedValue = possibleValues[randomIndex];
|
||||||
|
|
||||||
{
|
{
|
||||||
// copy the state and branch out
|
// copy the state and branch out
|
||||||
auto stackFrame = state.allocator.createFrame();
|
auto stackFrame = state.allocator.createFrame();
|
||||||
SolverState newState(state);
|
SolverState newState(state);
|
||||||
newState.wave.Collapse(minEntropyCell, 1 << selectedValue);
|
CollapseCell(newState, minEntropyCell, static_cast<uint16_t>(selectedValue));
|
||||||
newState.propagationQueue.push(minEntropyCell);
|
newState.propagationQueue.push(minEntropyCell);
|
||||||
|
|
||||||
if (RunLoop(newState))
|
if (RunLoop(newState))
|
||||||
@@ -470,7 +548,7 @@ private:
|
|||||||
assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0 && "Wave was not collapsed correctly");
|
assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0 && "Wave was not collapsed correctly");
|
||||||
|
|
||||||
// swap replacement value with the last value
|
// swap replacement value with the last value
|
||||||
std::swap(possibleValues[i], possibleValues[--availableValues]);
|
std::swap(possibleValues[randomIndex], possibleValues[--availableValues]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
@@ -501,6 +579,7 @@ private:
|
|||||||
{
|
{
|
||||||
for (size_t i = 0; i < state.wave.size(); ++i)
|
for (size_t i = 0; i < state.wave.size(); ++i)
|
||||||
{
|
{
|
||||||
|
if (state.wave.IsCollapsed(i))
|
||||||
state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i)));
|
state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -514,7 +593,7 @@ private:
|
|||||||
{
|
{
|
||||||
if (state.world.getValue(i) == allValues[j])
|
if (state.world.getValue(i) == allValues[j])
|
||||||
{
|
{
|
||||||
state.wave.Collapse(i, 1 << j);
|
CollapseCell(state, static_cast<uint16_t>(i), static_cast<uint16_t>(j));
|
||||||
state.propagationQueue.push(i);
|
state.propagationQueue.push(i);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -523,17 +602,27 @@ private:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Concept to validate constrainer function signature
|
||||||
|
* The function must be callable with parameters: (WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)
|
||||||
|
*/
|
||||||
|
template <typename T, typename WorldT, typename VarT, typename VariableIDMapT>
|
||||||
|
concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue<VarT> value, Constrainer<VariableIDMapT>& constrainer) {
|
||||||
|
func(world, index, value, constrainer);
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Builder class for creating WFC instances
|
* @brief Builder class for creating WFC instances
|
||||||
*/
|
*/
|
||||||
template<typename WorldT, typename VarT = typename WorldT::ValueType, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>>
|
template<typename WorldT, typename VarT = typename WorldT::ValueType, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>, typename CallbacksT = Callbacks<WorldT>>
|
||||||
class Builder {
|
class Builder {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
template <VarT ... Values>
|
template <VarT ... Values>
|
||||||
using DefineIDs = Builder<WorldT, VarT, typename VariableIDMapT::template Merge<Values...>, ConstrainerFunctionMapT>;
|
using DefineIDs = Builder<WorldT, VarT, typename VariableIDMapT::template Merge<Values...>, ConstrainerFunctionMapT, CallbacksT>;
|
||||||
|
|
||||||
template <typename ConstrainerFunctionT, VarT ... CorrespondingValues>
|
template <typename ConstrainerFunctionT, VarT ... CorrespondingValues>
|
||||||
|
requires ConstrainerFunction<ConstrainerFunctionT, WorldT, VarT, VariableIDMapT>
|
||||||
using DefineConstrainer = Builder<WorldT, VarT, VariableIDMapT,
|
using DefineConstrainer = Builder<WorldT, VarT, VariableIDMapT,
|
||||||
MergedConstrainerFunctionMap<
|
MergedConstrainerFunctionMap<
|
||||||
VariableIDMapT,
|
VariableIDMapT,
|
||||||
@@ -541,10 +630,17 @@ public:
|
|||||||
ConstrainerFunctionT,
|
ConstrainerFunctionT,
|
||||||
VariableIDMap<VarT, CorrespondingValues...>,
|
VariableIDMap<VarT, CorrespondingValues...>,
|
||||||
decltype([](WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&) {})
|
decltype([](WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&) {})
|
||||||
>
|
>, CallbacksT
|
||||||
>;
|
>;
|
||||||
|
|
||||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT>;
|
template <typename NewCellCollapsedCallbackT>
|
||||||
|
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>>;
|
||||||
|
template <typename NewContradictionCallbackT>
|
||||||
|
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>>;
|
||||||
|
template <typename NewBranchCallbackT>
|
||||||
|
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>>;
|
||||||
|
|
||||||
|
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT>;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace WFC
|
} // namespace WFC
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
//#define WFC_USE_STACK_ALLOCATOR
|
#define WFC_USE_STACK_ALLOCATOR
|
||||||
|
|
||||||
inline void* allocate_aligned_memory(size_t alignment, size_t size) {
|
inline void* allocate_aligned_memory(size_t alignment, size_t size) {
|
||||||
#ifdef WFC_USE_STACK_ALLOCATOR
|
#ifdef WFC_USE_STACK_ALLOCATOR
|
||||||
|
|||||||
Reference in New Issue
Block a user