compile-time only solver
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(sudoku_demo CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
# Add the main project as a subdirectory to get the nd-wfc library
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../src ${CMAKE_CURRENT_BINARY_DIR}/nd-wfc)
|
||||
|
||||
# Enable testing
|
||||
enable_testing()
|
||||
|
||||
@@ -17,7 +20,6 @@ endif()
|
||||
# Find Google Test (optional)
|
||||
find_package(GTest)
|
||||
if(GTest_FOUND)
|
||||
include_directories(${GTEST_INCLUDE_DIRS})
|
||||
set(HAS_GTEST TRUE)
|
||||
else()
|
||||
set(HAS_GTEST FALSE)
|
||||
@@ -57,6 +59,12 @@ add_executable(debug_failing_puzzles
|
||||
sudoku.cpp
|
||||
)
|
||||
|
||||
# Link all executables to the nd-wfc library
|
||||
target_link_libraries(sudoku_demo PRIVATE nd-wfc)
|
||||
target_link_libraries(sudoku_wfc_demo PRIVATE nd-wfc)
|
||||
target_link_libraries(analyze_failing_puzzles PRIVATE nd-wfc)
|
||||
target_link_libraries(debug_failing_puzzles PRIVATE nd-wfc)
|
||||
|
||||
# Set output directory for sudoku_demo
|
||||
set_target_properties(sudoku_demo PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
|
||||
@@ -83,20 +91,11 @@ set_target_properties(debug_failing_puzzles PROPERTIES
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
)
|
||||
|
||||
# Include directories
|
||||
# Include directories (current source directory for local headers)
|
||||
target_include_directories(sudoku_demo PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_include_directories(sudoku_wfc_demo PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
target_include_directories(analyze_failing_puzzles PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
target_include_directories(debug_failing_puzzles PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
target_include_directories(sudoku_wfc_demo PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_include_directories(analyze_failing_puzzles PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_include_directories(debug_failing_puzzles PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
# Optional: Enable optimizations for release builds
|
||||
if(CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||
@@ -120,7 +119,7 @@ if(HAS_GTEST)
|
||||
test_sudoku.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(sudoku_tests ${GTEST_LIBRARIES} pthread)
|
||||
target_link_libraries(sudoku_tests GTest::gtest GTest::gtest_main nd-wfc pthread)
|
||||
|
||||
# Set test output directory
|
||||
set_target_properties(sudoku_tests PROPERTIES
|
||||
@@ -130,10 +129,7 @@ if(HAS_GTEST)
|
||||
)
|
||||
|
||||
# Include directories for tests
|
||||
target_include_directories(sudoku_tests PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
target_include_directories(sudoku_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
# Add test to CTest
|
||||
add_test(NAME sudoku_tests COMMAND sudoku_tests)
|
||||
@@ -146,7 +142,7 @@ if(HAS_BENCHMARK)
|
||||
benchmark_sudoku.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(sudoku_benchmarks ${GTEST_LIBRARIES} benchmark::benchmark pthread)
|
||||
target_link_libraries(sudoku_benchmarks benchmark::benchmark nd-wfc pthread)
|
||||
|
||||
# Set benchmark output directory
|
||||
set_target_properties(sudoku_benchmarks PROPERTIES
|
||||
@@ -156,10 +152,7 @@ if(HAS_BENCHMARK)
|
||||
)
|
||||
|
||||
# Include directories for benchmarks
|
||||
target_include_directories(sudoku_benchmarks PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
target_include_directories(sudoku_benchmarks PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
endif()
|
||||
|
||||
# Installation (optional)
|
||||
|
||||
@@ -50,38 +50,6 @@ void saveFailingPuzzles(const std::vector<std::string>& failingPuzzles, const st
|
||||
}
|
||||
|
||||
int main() {
|
||||
// Create WFC solver
|
||||
auto sudokuSolver = WFC::Builder<Sudoku, uint8_t>()
|
||||
.DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>()
|
||||
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||
size_t x = index % 9;
|
||||
size_t y = index / 9;
|
||||
|
||||
// Add row constraints (same row, different columns)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != x) constrainer.Exclude(val, i + y * 9);
|
||||
}
|
||||
|
||||
// Add column constraints (same column, different rows)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != y) constrainer.Exclude(val,x + i * 9);
|
||||
}
|
||||
|
||||
// Add box constraints (3x3 box)
|
||||
size_t box_x = (x / 3) * 3;
|
||||
size_t box_y = (y / 3) * 3;
|
||||
for (size_t j = 0; j < 3; ++j) {
|
||||
for (size_t k = 0; k < 3; ++k) {
|
||||
size_t cell_x = box_x + j;
|
||||
size_t cell_y = box_y + k;
|
||||
size_t cell_index = cell_x + cell_y * 9;
|
||||
if (cell_index != index) {
|
||||
constrainer.Exclude(val, cell_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.build();
|
||||
|
||||
// File paths
|
||||
const std::string dataPath = "/home/connor/repos/nd-wfc/demos/sudoku/data";
|
||||
@@ -131,7 +99,7 @@ int main() {
|
||||
|
||||
// Try to solve
|
||||
auto puzzleStart = std::chrono::high_resolution_clock::now();
|
||||
sudokuSolver.Run(sudoku, false); // false = disable verbose output for speed
|
||||
SudokuSolver::Run(sudoku, false); // false = disable verbose output for speed
|
||||
auto puzzleEnd = std::chrono::high_resolution_clock::now();
|
||||
|
||||
bool solved = sudoku.isSolved();
|
||||
|
||||
@@ -88,37 +88,6 @@ void analyzeFailingPuzzle(const Sudoku& originalPuzzle, int puzzleIndex) {
|
||||
std::cout << std::endl;
|
||||
|
||||
// Try different solving approaches
|
||||
auto sudokuSolver = WFC::Builder<Sudoku, uint8_t>()
|
||||
.DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>()
|
||||
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||
size_t x = index % 9;
|
||||
size_t y = index / 9;
|
||||
|
||||
// Add row constraints (same row, different columns)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != x) constrainer.Exclude(val, i + y * 9);
|
||||
}
|
||||
|
||||
// Add column constraints (same column, different rows)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != y) constrainer.Exclude(val,x + i * 9);
|
||||
}
|
||||
|
||||
// Add box constraints (3x3 box)
|
||||
size_t box_x = (x / 3) * 3;
|
||||
size_t box_y = (y / 3) * 3;
|
||||
for (size_t j = 0; j < 3; ++j) {
|
||||
for (size_t k = 0; k < 3; ++k) {
|
||||
size_t cell_x = box_x + j;
|
||||
size_t cell_y = box_y + k;
|
||||
size_t cell_index = cell_x + cell_y * 9;
|
||||
if (cell_index != index) {
|
||||
constrainer.Exclude(val, cell_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.build();
|
||||
|
||||
// Test 1: Try solving with different configurations
|
||||
std::cout << "Testing different solving approaches:" << std::endl;
|
||||
@@ -128,7 +97,7 @@ void analyzeFailingPuzzle(const Sudoku& originalPuzzle, int puzzleIndex) {
|
||||
std::cout << "\nTest 1: Solving with verbose output..." << std::endl;
|
||||
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
sudokuSolver.Run(testPuzzle, true); // Enable verbose output
|
||||
SudokuSolver::Run(testPuzzle, true); // Enable verbose output
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
|
||||
|
||||
@@ -150,7 +119,7 @@ void analyzeFailingPuzzle(const Sudoku& originalPuzzle, int puzzleIndex) {
|
||||
std::cout << "Attempt " << attempt << ": ";
|
||||
|
||||
auto attemptStart = std::chrono::high_resolution_clock::now();
|
||||
sudokuSolver.Run(attemptPuzzle, false); // No verbose output
|
||||
SudokuSolver::Run(attemptPuzzle, false); // No verbose output
|
||||
auto attemptEnd = std::chrono::high_resolution_clock::now();
|
||||
auto attemptDuration = std::chrono::duration_cast<std::chrono::milliseconds>(attemptEnd - attemptStart);
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <nd-wfc/wfc.hpp>
|
||||
|
||||
// 4-bit packed Sudoku board storage - optimal packing
|
||||
// 81 cells * 4 bits = 324 bits
|
||||
@@ -233,3 +234,35 @@ public:
|
||||
private:
|
||||
static bool parseLine(const std::string& line, std::array<uint8_t, 81>& board);
|
||||
};
|
||||
|
||||
using SudokuSolver = WFC::Builder<Sudoku>
|
||||
::DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>
|
||||
::DefineConstrainer<decltype([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||
size_t x = index % 9;
|
||||
size_t y = index / 9;
|
||||
|
||||
// Add row constraints (same row, different columns)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != x) constrainer.Exclude(val, i + y * 9);
|
||||
}
|
||||
|
||||
// Add column constraints (same column, different rows)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != y) constrainer.Exclude(val,x + i * 9);
|
||||
}
|
||||
|
||||
// Add box constraints (3x3 box)
|
||||
size_t box_x = (x / 3) * 3;
|
||||
size_t box_y = (y / 3) * 3;
|
||||
for (size_t j = 0; j < 3; ++j) {
|
||||
for (size_t k = 0; k < 3; ++k) {
|
||||
size_t cell_x = box_x + j;
|
||||
size_t cell_y = box_y + k;
|
||||
size_t cell_index = cell_x + cell_y * 9;
|
||||
if (cell_index != index) {
|
||||
constrainer.Exclude(val, cell_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
}), 1, 2, 3, 4, 5, 6, 7, 8, 9>
|
||||
::Build;
|
||||
@@ -6,42 +6,10 @@ int main()
|
||||
{
|
||||
std::cout << "Running Sudoku WFC" << std::endl;
|
||||
|
||||
auto sudokuSolver = WFC::Builder<Sudoku, uint8_t>()
|
||||
.DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>()
|
||||
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||
size_t x = index % 9;
|
||||
size_t y = index / 9;
|
||||
|
||||
// Add row constraints (same row, different columns)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != x) constrainer.Exclude(val, i + y * 9);
|
||||
}
|
||||
|
||||
// Add column constraints (same column, different rows)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != y) constrainer.Exclude(val,x + i * 9);
|
||||
}
|
||||
|
||||
// Add box constraints (3x3 box)
|
||||
size_t box_x = (x / 3) * 3;
|
||||
size_t box_y = (y / 3) * 3;
|
||||
for (size_t j = 0; j < 3; ++j) {
|
||||
for (size_t k = 0; k < 3; ++k) {
|
||||
size_t cell_x = box_x + j;
|
||||
size_t cell_y = box_y + k;
|
||||
size_t cell_index = cell_x + cell_y * 9;
|
||||
if (cell_index != index) {
|
||||
constrainer.Exclude(val, cell_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.build();
|
||||
|
||||
Sudoku sudokuWorld;
|
||||
sudokuWorld.setValue(0, 5);
|
||||
sudokuWorld.setValue(80, 1);
|
||||
bool success = sudokuSolver.Run(sudokuWorld, true);
|
||||
bool success = SudokuSolver::Run(sudokuWorld, true);
|
||||
|
||||
if (success) {
|
||||
std::cout << "Sudoku solved successfully!" << std::endl;
|
||||
|
||||
@@ -8,38 +8,6 @@
|
||||
// Forward declaration for helper function
|
||||
std::vector<Sudoku> loadPuzzlesFromFile(const std::string& filename);
|
||||
|
||||
static auto sudokuTestSolver = WFC::Builder<Sudoku, uint8_t>()
|
||||
.DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>()
|
||||
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||
size_t x = index % 9;
|
||||
size_t y = index / 9;
|
||||
|
||||
// Add row constraints (same row, different columns)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != x) constrainer.Exclude(val, i + y * 9);
|
||||
}
|
||||
|
||||
// Add column constraints (same column, different rows)
|
||||
for (size_t i = 0; i < 9; ++i) {
|
||||
if (i != y) constrainer.Exclude(val,x + i * 9);
|
||||
}
|
||||
|
||||
// Add box constraints (3x3 box)
|
||||
size_t box_x = (x / 3) * 3;
|
||||
size_t box_y = (y / 3) * 3;
|
||||
for (size_t j = 0; j < 3; ++j) {
|
||||
for (size_t k = 0; k < 3; ++k) {
|
||||
size_t cell_x = box_x + j;
|
||||
size_t cell_y = box_y + k;
|
||||
size_t cell_index = cell_x + cell_y * 9;
|
||||
if (cell_index != index) {
|
||||
constrainer.Exclude(val, cell_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.build();
|
||||
|
||||
// Test fixture for Sudoku tests
|
||||
class SudokuTest : public ::testing::Test {
|
||||
protected:
|
||||
@@ -69,7 +37,7 @@ protected:
|
||||
}
|
||||
|
||||
Sudoku SolvePuzzle(Sudoku& sudoku) {
|
||||
sudokuTestSolver.Run(sudoku, true);
|
||||
SudokuSolver::Run(sudoku, true);
|
||||
return sudoku;
|
||||
}
|
||||
};
|
||||
@@ -307,7 +275,7 @@ TEST_F(SudokuTest, EdgeCases) {
|
||||
TEST_F(SudokuTest, WFCIntegration)
|
||||
{
|
||||
auto sudoku = createEasyPuzzle();
|
||||
sudokuTestSolver.Run(sudoku, true);
|
||||
SudokuSolver::Run(sudoku, true);
|
||||
EXPECT_TRUE(sudoku.isSolved());
|
||||
}
|
||||
|
||||
@@ -326,7 +294,7 @@ TEST_F(SudokuTest, LoadAndSolveEasyPuzzles)
|
||||
EXPECT_TRUE(sudoku.isValid()) << "Puzzle " << i << " is not valid";
|
||||
|
||||
auto puzzleStart = std::chrono::high_resolution_clock::now();
|
||||
sudokuTestSolver.Run(sudoku, true);
|
||||
SudokuSolver::Run(sudoku, true);
|
||||
auto puzzleEnd = std::chrono::high_resolution_clock::now();
|
||||
|
||||
EXPECT_TRUE(sudoku.isSolved()) << "Puzzle " << i << " was not solved";
|
||||
@@ -364,7 +332,7 @@ TEST_F(SudokuTest, LoadAndSolveMediumPuzzles)
|
||||
EXPECT_TRUE(sudoku.isValid()) << "Puzzle " << i << " is not valid";
|
||||
|
||||
auto puzzleStart = std::chrono::high_resolution_clock::now();
|
||||
sudokuTestSolver.Run(sudoku, true);
|
||||
SudokuSolver::Run(sudoku, true);
|
||||
auto puzzleEnd = std::chrono::high_resolution_clock::now();
|
||||
|
||||
EXPECT_TRUE(sudoku.isSolved()) << "Puzzle " << i << " was not solved";
|
||||
@@ -402,7 +370,7 @@ TEST_F(SudokuTest, LoadAndSolveHardPuzzles)
|
||||
EXPECT_TRUE(sudoku.isValid()) << "Puzzle " << i << " is not valid";
|
||||
|
||||
auto puzzleStart = std::chrono::high_resolution_clock::now();
|
||||
sudokuTestSolver.Run(sudoku, true);
|
||||
SudokuSolver::Run(sudoku, true);
|
||||
auto puzzleEnd = std::chrono::high_resolution_clock::now();
|
||||
|
||||
EXPECT_TRUE(sudoku.isSolved()) << "Puzzle " << i << " was not solved";
|
||||
@@ -440,7 +408,7 @@ TEST_F(SudokuTest, LoadAndSolveDiabolicalPuzzles)
|
||||
EXPECT_TRUE(sudoku.isValid()) << "Puzzle " << i << " is not valid";
|
||||
|
||||
auto puzzleStart = std::chrono::high_resolution_clock::now();
|
||||
sudokuTestSolver.Run(sudoku, true);
|
||||
SudokuSolver::Run(sudoku, true);
|
||||
auto puzzleEnd = std::chrono::high_resolution_clock::now();
|
||||
|
||||
EXPECT_TRUE(sudoku.isSolved()) << "Puzzle " << i << " was not solved";
|
||||
@@ -491,7 +459,7 @@ TEST_F(SudokuTest, LoadAndSolveFirstPuzzleFromEachFile)
|
||||
EXPECT_TRUE(puzzle.isValid()) << "Loaded puzzle from " << filename << " is not valid";
|
||||
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
sudokuTestSolver.Run(puzzle, true);
|
||||
SudokuSolver::Run(puzzle, true);
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
EXPECT_TRUE(puzzle.isSolved()) << "Failed to solve first puzzle from " << filename;
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#include <algorithm>
|
||||
#include <concepts>
|
||||
#include <bit>
|
||||
#include <span>
|
||||
#include <tuple>
|
||||
|
||||
#include "wfc_allocator.hpp"
|
||||
|
||||
@@ -37,17 +39,6 @@ concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
|
||||
typename T::ValueType;
|
||||
};
|
||||
|
||||
template <typename MaskType>
|
||||
class Wave;
|
||||
template <typename VariableIDMapT>
|
||||
class Constrainer;
|
||||
template<typename WorldT, typename VarT, typename VariableIDMapT>
|
||||
class WFC;
|
||||
template<typename WorldT, typename VarT, typename VariableIDMapT>
|
||||
class Variable;
|
||||
template <typename VarT>
|
||||
struct WorldValue;
|
||||
|
||||
/**
|
||||
* @brief Class to map variable values to indices at compile time
|
||||
*
|
||||
@@ -104,26 +95,89 @@ public:
|
||||
return static_cast<size_t>(-1); // This line is unreachable if value is found
|
||||
}
|
||||
|
||||
static VarT GetValue(size_t index) {
|
||||
assert(index < ValuesRegisteredAmount);
|
||||
constexpr VarT arr[] = {Values...};
|
||||
return arr[index];
|
||||
}
|
||||
|
||||
template <VarT ... MaskValues>
|
||||
static consteval MaskType GetMask()
|
||||
{
|
||||
return (0 | ... | (1 << GetIndex<MaskValues>()));
|
||||
}
|
||||
|
||||
static consteval std::array<VarT, ValuesRegisteredAmount> GetAllValues()
|
||||
static std::span<const VarT> GetAllValues()
|
||||
{
|
||||
return {Values...};
|
||||
static const VarT allValues[]
|
||||
{
|
||||
Values...
|
||||
};
|
||||
return std::span<const VarT>{ allValues, ValuesRegisteredAmount };
|
||||
}
|
||||
|
||||
static VarT GetValue(size_t index) {
|
||||
assert(index < ValuesRegisteredAmount);
|
||||
return GetAllValues()[index];
|
||||
}
|
||||
|
||||
static consteval VarT GetValueConsteval(size_t index)
|
||||
{
|
||||
constexpr VarT arr[] = {Values...};
|
||||
return arr[index];
|
||||
}
|
||||
|
||||
static consteval size_t size() { return ValuesRegisteredAmount; }
|
||||
};
|
||||
|
||||
template <typename ... ConstrainerFunctions>
|
||||
struct ConstrainerFunctionMap {
|
||||
public:
|
||||
static consteval size_t size() { return sizeof...(ConstrainerFunctions); }
|
||||
|
||||
using TupleType = std::tuple<ConstrainerFunctions...>;
|
||||
|
||||
template <typename ConstrainerFunctionPtrT>
|
||||
static ConstrainerFunctionPtrT GetFunction(size_t index)
|
||||
{
|
||||
static_assert((std::is_empty_v<ConstrainerFunctions> && ...), "Lambdas must not have any captures");
|
||||
static ConstrainerFunctionPtrT functions[] = {
|
||||
static_cast<ConstrainerFunctionPtrT>(ConstrainerFunctions{}) ...
|
||||
};
|
||||
return functions[index];
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to select the correct constrainer function based on the index and the value
|
||||
template<std::size_t I,
|
||||
typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT>
|
||||
using MergedConstrainerElementSelector = std::conditional_t<
|
||||
(I < ConstrainerFunctionMapT::size()), // if the index is within the size of the tuple
|
||||
std::conditional_t<SelectedIDsVariableIDMapT::template HasValue<VariableIDMapT::GetValueConsteval(I)>(), // if the value is in the selected IDs
|
||||
NewConstrainerFunctionT,
|
||||
std::tuple_element_t<std::min(I, ConstrainerFunctionMapT::size() - 1), typename ConstrainerFunctionMapT::TupleType>
|
||||
>,
|
||||
EmptyFunctionT
|
||||
>;
|
||||
|
||||
// Helper to make a merged constrainer function map
|
||||
template<typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT,
|
||||
std::size_t... Is>
|
||||
auto MakeMergedConstrainerIDMap(std::index_sequence<Is...>,VariableIDMapT*, ConstrainerFunctionMapT*, NewConstrainerFunctionT*, SelectedIDsVariableIDMapT*, EmptyFunctionT*)
|
||||
-> ConstrainerFunctionMap<MergedConstrainerElementSelector<Is, VariableIDMapT, ConstrainerFunctionMapT, NewConstrainerFunctionT, SelectedIDsVariableIDMapT, EmptyFunctionT>...>;
|
||||
|
||||
// Main alias for the merged constrainer function map
|
||||
template<typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT>
|
||||
using MergedConstrainerFunctionMap = decltype(
|
||||
MakeMergedConstrainerIDMap(std::make_index_sequence<VariableIDMapT::ValuesRegisteredAmount>{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr)
|
||||
);
|
||||
|
||||
template <typename VarT>
|
||||
struct WorldValue
|
||||
{
|
||||
@@ -246,7 +300,7 @@ struct VariableData {
|
||||
/**
|
||||
* @brief Main WFC class implementing the Wave Function Collapse algorithm
|
||||
*/
|
||||
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>>
|
||||
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>>
|
||||
class WFC {
|
||||
public:
|
||||
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
|
||||
@@ -275,18 +329,16 @@ public:
|
||||
};
|
||||
|
||||
public:
|
||||
WFC(std::vector<VariableData<WorldT, VarT, VariableIDMapT>>&& variables)
|
||||
: m_variables(std::move(variables))
|
||||
{}
|
||||
WFC() = delete;
|
||||
|
||||
public:
|
||||
bool Run(WorldT& world, bool propagateInitialValues = false)
|
||||
static bool Run(WorldT& world, bool propagateInitialValues = true)
|
||||
{
|
||||
WFCStackAllocator allocator{};
|
||||
// std::mt19937 random{ std::random_device{}() }; // Using random values fails 25% of the time
|
||||
std::mt19937 random{ 212 };
|
||||
size_t iterations = 0;
|
||||
SolverState state(world, m_variables.size(), random, allocator, iterations);
|
||||
SolverState state(world, ConstrainerFunctionMapT::size(), random, allocator, iterations);
|
||||
return Run(state, propagateInitialValues);
|
||||
}
|
||||
|
||||
@@ -294,7 +346,7 @@ public:
|
||||
* @brief Run the WFC algorithm to generate a solution
|
||||
* @return true if a solution was found, false if contradiction occurred
|
||||
*/
|
||||
bool Run(SolverState& state, bool propagateInitialValues = false)
|
||||
static bool Run(SolverState& state, bool propagateInitialValues = true)
|
||||
{
|
||||
if (propagateInitialValues)
|
||||
{
|
||||
@@ -309,7 +361,7 @@ public:
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RunLoop(SolverState& state)
|
||||
static bool RunLoop(SolverState& state)
|
||||
{
|
||||
for (; state.iterations < 256; ++state.iterations)
|
||||
{
|
||||
@@ -330,7 +382,7 @@ public:
|
||||
* @param cellId The cell ID
|
||||
* @return The value if collapsed, std::nullopt otherwise
|
||||
*/
|
||||
std::optional<VarT> GetValue(SolverState& state, int cellId) const {
|
||||
static std::optional<VarT> GetValue(SolverState& state, int cellId) {
|
||||
if (state.wave.IsCollapsed(cellId)) {
|
||||
auto variableId = state.wave.GetVariableID(cellId);
|
||||
return VariableIDMapT::GetValue(variableId);
|
||||
@@ -343,18 +395,18 @@ public:
|
||||
* @param cellId The cell ID
|
||||
* @return Set of possible values
|
||||
*/
|
||||
const std::vector<VarT> GetPossibleValues(SolverState& state, int cellId) const
|
||||
static const std::vector<VarT> GetPossibleValues(SolverState& state, int cellId)
|
||||
{
|
||||
std::vector<VarT> possibleValues;
|
||||
MaskType mask = state.wave.GetMask(cellId);
|
||||
for (size_t i = 0; i < m_variables.size(); ++i) {
|
||||
for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) {
|
||||
if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i));
|
||||
}
|
||||
return possibleValues;
|
||||
}
|
||||
|
||||
private:
|
||||
bool Branch(SolverState& state)
|
||||
static bool Branch(SolverState& state)
|
||||
{
|
||||
assert(state.propagationQueue.empty());
|
||||
|
||||
@@ -422,7 +474,7 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Propagate(SolverState& state)
|
||||
static bool Propagate(SolverState& state)
|
||||
{
|
||||
while (!state.propagationQueue.empty())
|
||||
{
|
||||
@@ -435,12 +487,15 @@ private:
|
||||
|
||||
uint16_t variableID = state.wave.GetVariableID(cellId);
|
||||
Constrainer<VariableIDMapT> constrainer(state.wave, state.propagationQueue);
|
||||
m_variables[variableID].constraintFunc(state.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
|
||||
|
||||
using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&);
|
||||
|
||||
ConstrainerFunctionMapT::template GetFunction<ConstrainerFunctionPtrT>(variableID)(state.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PopulateWorld(SolverState& state)
|
||||
static void PopulateWorld(SolverState& state)
|
||||
{
|
||||
for (size_t i = 0; i < state.wave.size(); ++i)
|
||||
{
|
||||
@@ -449,7 +504,7 @@ private:
|
||||
}
|
||||
|
||||
private:
|
||||
void PropogateInitialValues(SolverState& state)
|
||||
static void PropogateInitialValues(SolverState& state)
|
||||
{
|
||||
auto allValues = VariableIDMapT::GetAllValues();
|
||||
for (size_t i = 0; i < state.wave.size(); ++i)
|
||||
@@ -465,69 +520,30 @@ private:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables {};
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Builder class for creating WFC instances
|
||||
*/
|
||||
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>>
|
||||
template<typename WorldT, typename VarT = typename WorldT::ValueType, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>>
|
||||
class Builder {
|
||||
public:
|
||||
Builder() = default;
|
||||
Builder(std::vector<VariableData<WorldT, VarT, VariableIDMapT>>&& variables)
|
||||
: m_variables(std::move(variables))
|
||||
{}
|
||||
|
||||
public:
|
||||
template <VarT ... Values>
|
||||
auto DefineIDs()
|
||||
{
|
||||
using NewVariableIDMapT = typename VariableIDMapT::template Merge<Values...>;
|
||||
// reinterpret_cast is used to be able to move the variables with an outdated VariableIDMap to the new VariableIDMap. The previous indices still work.
|
||||
return Builder<WorldT, VarT, NewVariableIDMapT>(std::move(reinterpret_cast<std::vector<VariableData<WorldT, VarT, NewVariableIDMapT>>&>(m_variables)));
|
||||
}
|
||||
using DefineIDs = Builder<WorldT, VarT, typename VariableIDMapT::template Merge<Values...>, ConstrainerFunctionMapT>;
|
||||
|
||||
/**
|
||||
* @brief Add a variable with its constraint function
|
||||
* @param value The variable value
|
||||
* @param constraintFunc Function that defines constraints when this variable is placed
|
||||
* @return Reference to this builder for method chaining
|
||||
*/
|
||||
template <VarT ... Values>
|
||||
Builder& Variable(const std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc) {
|
||||
m_variables.resize(VariableIDMapT::ValuesRegisteredAmount);
|
||||
template <typename ConstrainerFunctionT, VarT ... CorrespondingValues>
|
||||
using DefineConstrainer = Builder<WorldT, VarT, VariableIDMapT,
|
||||
MergedConstrainerFunctionMap<
|
||||
VariableIDMapT,
|
||||
ConstrainerFunctionMapT,
|
||||
ConstrainerFunctionT,
|
||||
VariableIDMap<VarT, CorrespondingValues...>,
|
||||
decltype([](WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&) {})
|
||||
>
|
||||
>;
|
||||
|
||||
Variable_Internal<Values...>(constraintFunc);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Build the WFC instance
|
||||
* @param world The world instance to work with
|
||||
* @return A unique_ptr to the created WFC instance
|
||||
*/
|
||||
auto build() {
|
||||
return WFC<WorldT, VarT, VariableIDMapT>(std::move(m_variables));
|
||||
}
|
||||
|
||||
private:
|
||||
template <VarT Value, VarT ... Values>
|
||||
void Variable_Internal(const std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc)
|
||||
{
|
||||
m_variables[VariableIDMapT::template GetIndex<Value>()] = VariableData<WorldT, VarT, VariableIDMapT>{
|
||||
Value,
|
||||
constraintFunc
|
||||
};
|
||||
if constexpr (sizeof...(Values) > 0) {
|
||||
Variable_Internal<Values...>(constraintFunc);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables;
|
||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT>;
|
||||
};
|
||||
|
||||
} // namespace WFC
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
|
||||
void* allocate_aligned_memory(size_t alignment, size_t size) {
|
||||
inline void* allocate_aligned_memory(size_t alignment, size_t size) {
|
||||
void* ptr = nullptr;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
@@ -32,7 +32,7 @@ void* allocate_aligned_memory(size_t alignment, size_t size) {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void free_aligned_memory(void* ptr) {
|
||||
inline void free_aligned_memory(void* ptr) {
|
||||
#ifdef _MSC_VER
|
||||
_aligned_free(ptr);
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
|
||||
Reference in New Issue
Block a user