From 34a246bac4e00f99fc991f1461467ff0c1c899b5 Mon Sep 17 00:00:00 2001 From: Connor De Meyer Date: Thu, 4 Sep 2025 09:51:32 +0900 Subject: [PATCH 1/3] prompt description --- prompts/10-bit-container | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 prompts/10-bit-container diff --git a/prompts/10-bit-container b/prompts/10-bit-container new file mode 100644 index 0000000..4532a9a --- /dev/null +++ b/prompts/10-bit-container @@ -0,0 +1,10 @@ +We're using c++20, I want you to create a file wfc_bit_container.hpp in include\nd-wfc which contains a templated class whose goal is to minimize the amount of bits we use for masking. The template type should be a number (The number of bits stored) and should have the Size of the container (how many of these integers should we store). The size is 0 by default, which means it's resizable and it should use an std::vector instead of an std::array. +examples of given parameters: +- Bits: 2; Size: 8; -> 2 requires 2 bits, std::array +- Bits: 3; size: 4; -> 3 requires 4 bits, std::array +- Bits: 0; size 128; -> 0*128 == 0; std::array +- Bits: 32; size: 100; -> std::array +- Bits: 256; size: 0; -> std::vector +Make sure that the amount of bits used is a power of 2: 0, 1, 2, 4, 8, 16, 32, 64. If more than 64 bits are required, make sure it is a multiple of 64: 64, 128, 196, 256, etc. +We should be able to get/set values with an index, do bit operations like |, &, ^ on individual elements, do std::countl_zero, std::countl_one, std::countr_zero, std::countr_one & std::popcount. +This repo tries to be as optimized as possible. Make good use of `asserts` instead of `if conditions` wherever you are checking something. \ No newline at end of file From d9a83f88223f789f52e5da81781c77fede3a3efc Mon Sep 17 00:00:00 2001 From: cdemeyer-teachx Date: Sun, 7 Sep 2025 15:27:22 +0900 Subject: [PATCH 2/3] Initial changes --- include/nd-wfc/wfc.h | 1 + include/nd-wfc/wfc.hpp | 572 ++------------------------ include/nd-wfc/wfc_allocator.hpp | 2 + include/nd-wfc/wfc_bit_container.hpp | 172 ++++++++ include/nd-wfc/wfc_builder.hpp | 52 +++ include/nd-wfc/wfc_callbacks.hpp | 44 ++ include/nd-wfc/wfc_constrainer.hpp | 122 ++++++ include/nd-wfc/wfc_large_integers.hpp | 22 + include/nd-wfc/wfc_random.hpp | 178 ++++++++ include/nd-wfc/wfc_utils.hpp | 27 ++ include/nd-wfc/wfc_variable_map.hpp | 73 ++++ include/nd-wfc/wfc_wave.hpp | 39 ++ 12 files changed, 767 insertions(+), 537 deletions(-) create mode 100644 include/nd-wfc/wfc_bit_container.hpp create mode 100644 include/nd-wfc/wfc_builder.hpp create mode 100644 include/nd-wfc/wfc_callbacks.hpp create mode 100644 include/nd-wfc/wfc_constrainer.hpp create mode 100644 include/nd-wfc/wfc_large_integers.hpp create mode 100644 include/nd-wfc/wfc_random.hpp create mode 100644 include/nd-wfc/wfc_utils.hpp create mode 100644 include/nd-wfc/wfc_variable_map.hpp create mode 100644 include/nd-wfc/wfc_wave.hpp diff --git a/include/nd-wfc/wfc.h b/include/nd-wfc/wfc.h index 6f49023..959b8f9 100644 --- a/include/nd-wfc/wfc.h +++ b/include/nd-wfc/wfc.h @@ -9,4 +9,5 @@ #include "wfc.hpp" #include "worlds.hpp" +#include "wfc_builder.hpp" diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index 0af8a9d..3d845e7 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -14,174 +14,17 @@ #include #include +#include "wfc_utils.hpp" +#include "wfc_variable_map.hpp" #include "wfc_allocator.hpp" +#include "wfc_bit_container.hpp" +#include "wfc_wave.hpp" +#include "wfc_constrainer.hpp" +#include "wfc_callbacks.hpp" +#include "wfc_random.hpp" namespace WFC { -inline constexpr void constexpr_assert(bool condition, const char* message = "") { - if (!condition) throw message; -} - -inline int FindNthSetBit(size_t num, int n) { - constexpr_assert(n < std::popcount(num), "index is out of range"); - int bitCount = 0; - while (num) { - if (bitCount == n) { - return std::countr_zero(num); // Index of the current set bit - } - bitCount++; - num &= (num - 1); // turn of lowest set bit - } - return bitCount; -} - -template -concept WorldType = requires(T world, size_t id, typename T::ValueType value) { - { world.size() } -> std::convertible_to; - { world.setValue(id, value) }; - { world.getValue(id) } -> std::convertible_to; - typename T::ValueType; -}; - -/** - * @brief Class to map variable values to indices at compile time - * - * This class is used to map variable values to indices at compile time. - * It is a compile-time map of variable values to indices. - */ -template -class VariableIDMap { -public: - - using Type = VarT; - static constexpr size_t ValuesRegisteredAmount = sizeof...(Values); - - using MaskType = typename std::conditional< - ValuesRegisteredAmount <= 8, - uint8_t, - typename std::conditional< - ValuesRegisteredAmount <= 16, - uint16_t, - typename std::conditional< - ValuesRegisteredAmount <= 32, - uint32_t, - uint64_t - >::type - >::type - >::type; - - template - using Merge = VariableIDMap; - - template - static consteval bool HasValue() - { - constexpr VarT arr[] = {Values...}; - constexpr size_t size = sizeof...(Values); - - for (size_t i = 0; i < size; ++i) - if (arr[i] == Value) - return true; - return false; - } - - template - static consteval size_t GetIndex() - { - static_assert(HasValue(), "Value was not defined"); - constexpr VarT arr[] = {Values...}; - constexpr size_t size = ValuesRegisteredAmount; - - for (size_t i = 0; i < size; ++i) - if (arr[i] == Value) - return i; - - return static_cast(-1); // This line is unreachable if value is found - } - - template - static consteval MaskType GetMask() - { - return (0 | ... | (1 << GetIndex())); - } - - static std::span GetAllValues() - { - static const VarT allValues[] - { - Values... - }; - return std::span{ allValues, ValuesRegisteredAmount }; - } - - static VarT GetValue(size_t index) { - constexpr_assert(index < ValuesRegisteredAmount); - return GetAllValues()[index]; - } - - static consteval VarT GetValueConsteval(size_t index) - { - constexpr VarT arr[] = {Values...}; - return arr[index]; - } - - static consteval size_t size() { return ValuesRegisteredAmount; } -}; - -template -struct ConstrainerFunctionMap { -public: - static consteval size_t size() { return sizeof...(ConstrainerFunctions); } - - using TupleType = std::tuple; - - template - static ConstrainerFunctionPtrT GetFunction(size_t index) - { - static_assert((std::is_empty_v && ...), "Lambdas must not have any captures"); - static ConstrainerFunctionPtrT functions[] = { - static_cast(ConstrainerFunctions{}) ... - }; - return functions[index]; - } -}; - -// Helper to select the correct constrainer function based on the index and the value -template -using MergedConstrainerElementSelector = - std::conditional_t(), // if the value is in the selected IDs - NewConstrainerFunctionT, - std::conditional_t<(I < ConstrainerFunctionMapT::size()), // if the index is within the size of the tuple - std::tuple_element_t, - EmptyFunctionT - > - >; - -// Helper to make a merged constrainer function map -template -auto MakeMergedConstrainerIDMap(std::index_sequence,VariableIDMapT*, ConstrainerFunctionMapT*, NewConstrainerFunctionT*, SelectedIDsVariableIDMapT*, EmptyFunctionT*) - -> ConstrainerFunctionMap...>; - -// Main alias for the merged constrainer function map -template -using MergedConstrainerFunctionMap = decltype( - MakeMergedConstrainerIDMap(std::make_index_sequence{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr) -); - template struct WorldValue { @@ -199,150 +42,37 @@ public: uint16_t InternalIndex{}; }; -template -class Wave { -public: - Wave() = default; - Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter(allocator)) - { - for (auto& wave : m_data) wave = (1 << variableAmount) - 1; - } - - Wave(const Wave& other) = default; - -public: - void Collapse(size_t index, MaskType mask) { m_data[index] &= mask; } - size_t size() const { return m_data.size(); } - size_t Entropy(size_t index) const { return std::popcount(m_data[index]); } - bool IsCollapsed(size_t index) const { return Entropy(index) == 1; } - bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](MaskType value) { return std::popcount(value) == 1; }); } - bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](MaskType value) { return value == 0; }); } - bool IsContradicted(size_t index) const { return m_data[index] == 0; } - uint16_t GetVariableID(size_t index) const { return static_cast(std::countr_zero(m_data[index])); } - MaskType GetMask(size_t index) const { return m_data[index]; } - -private: - WFCVector m_data; +template +concept WorldType = requires(T world, size_t id, typename T::ValueType value) { + { world.size() } -> std::convertible_to; + { world.setValue(id, value) }; + { world.getValue(id) } -> std::convertible_to; + typename T::ValueType; }; /** - * @brief Constrainer class used in constraint functions to limit possible values for other cells - */ -template -class Constrainer { -public: - using MaskType = typename VariableIDMapT::MaskType; - -public: - Constrainer(Wave& wave, WFCQueue& propagationQueue) - : m_wave(wave) - , m_propagationQueue(propagationQueue) - {} - - /** - * @brief Constrain a cell to exclude specific values - * @param cellId The ID of the cell to constrain - * @param forbiddenValues The set of forbidden values for this cell - */ - template - void Exclude(size_t cellId) { - static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided"); - ApplyMask(cellId, ~VariableIDMapT::template GetMask()); - } - - void Exclude(WorldValue value, size_t cellId) { - ApplyMask(cellId, ~(1 << value.InternalIndex)); - } - - /** - * @brief Constrain a cell to only allow one specific value - * @param cellId The ID of the cell to constrain - * @param value The only allowed value for this cell - */ - template - void Only(size_t cellId) { - static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided"); - ApplyMask(cellId, VariableIDMapT::template GetMask()); - } - - void Only(WorldValue value, size_t cellId) { - ApplyMask(cellId, 1 << value.InternalIndex); - } - -private: - void ApplyMask(size_t cellId, MaskType mask) { - bool wasCollapsed = m_wave.IsCollapsed(cellId); - - m_wave.Collapse(cellId, mask); - - bool collapsed = m_wave.IsCollapsed(cellId); - if (!wasCollapsed && collapsed) { - m_propagationQueue.push(cellId); - } - } - -private: - Wave& m_wave; - WFCQueue& m_propagationQueue; +* @brief Concept to validate constrainer function signature +* The function must be callable with parameters: (WorldT&, size_t, WorldValue, Constrainer&) +*/ +template +concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { + func(world, index, value, constrainer); }; /** - * @brief Variable definition with its constraint function - */ -template -struct VariableData { - VarT value{}; - std::function, Constrainer&)> constraintFunc{}; - - VariableData() = default; - VariableData(VarT value, std::function, Constrainer&)> constraintFunc) - : value(value) - , constraintFunc(constraintFunc) - {} -}; - -/** - * @brief Empty callback function - * @param World The world type - */ -template -using EmptyCallback = decltype([](World&){}); - -/** - * @brief Callback struct - * @param WorldT The world type - * @param AllCellsCollapsedCallbackT The all cells collapsed callback type - * @param CellCollapsedCallbackT The cell collapsed callback type - * @param ContradictionCallbackT The contradiction callback type - * @param BranchCallbackT The branch callback type - */ -template , - typename ContradictionCallbackT = EmptyCallback, - typename BranchCallbackT = EmptyCallback -> -struct Callbacks -{ - using CellCollapsedCallback = CellCollapsedCallbackT; - using ContradictionCallback = ContradictionCallbackT; - using BranchCallback = BranchCallbackT; - - template - using SetCellCollapsedCallbackT = Callbacks; - template - using SetContradictionCallbackT = Callbacks; - template - using SetBranchCallbackT = Callbacks; - - static consteval bool HasCellCollapsedCallback() { return !std::is_same_v>; } - static consteval bool HasContradictionCallback() { return !std::is_same_v>; } - static consteval bool HasBranchCallback() { return !std::is_same_v>; } +* @brief Concept to validate random selector function signature +* The function must be callable with parameters: (std::span) and return size_t +*/ +template +concept RandomSelectorFunction = requires(T func, std::span possibleValues) { + { func(possibleValues) } -> std::convertible_to; + { func.rng(static_cast(1)) } -> std::convertible_to; }; /** * @brief Main WFC class implementing the Wave Function Collapse algorithm */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, @@ -351,14 +81,14 @@ class WFC { public: static_assert(WorldType, "WorldT must satisfy World type requirements"); - using MaskType = typename VariableIDMapT::MaskType; + using ElementT = typename VariableIDMapT::ElementT; public: struct SolverState { WorldT& world; WFCQueue propagationQueue; - Wave wave; + Wave wave; std::mt19937& rng; RandomSelectorT& randomSelector; WFCStackAllocator& allocator; @@ -479,7 +209,7 @@ public: static const std::vector GetPossibleValues(SolverState& state, int cellId) { std::vector possibleValues; - MaskType mask = state.wave.GetMask(cellId); + ElementT mask = state.wave.GetMask(cellId); for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) { if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i)); } @@ -489,7 +219,7 @@ public: private: static void CollapseCell(SolverState& state, size_t cellId, uint16_t value) { - constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (1 << value)); + constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (ElementT(1) << value)); state.wave.Collapse(cellId, 1 << value); constexpr_assert(state.wave.IsCollapsed(cellId)); @@ -522,14 +252,14 @@ private: // create a list of possible values uint16_t availableValues = static_cast(state.wave.Entropy(minEntropyCell)); std::array possibleValues; // inplace vector - MaskType mask = state.wave.GetMask(minEntropyCell); + ElementT mask = state.wave.GetMask(minEntropyCell); for (size_t i = 0; i < availableValues; ++i) { uint16_t index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds"); possibleValues[i] = index; - constexpr_assert(((mask & (1 << index)) != 0), "Possible value was not set"); + constexpr_assert(((mask & (ElementT(1) << index)) != 0), "Possible value was not set"); mask = mask & (mask - 1); // turn off lowest set bit } @@ -563,9 +293,9 @@ private: } // remove the failure state from the wave - constexpr_assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) != 0, "Possible value was not set"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) != 0, "Possible value was not set"); state.wave.Collapse(minEntropyCell, ~(1 << selectedValue)); - constexpr_assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0, "Wave was not collapsed correctly"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); // swap replacement value with the last value std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); @@ -622,236 +352,4 @@ private: } }; -/** - * @brief Concept to validate constrainer function signature - * The function must be callable with parameters: (WorldT&, size_t, WorldValue, Constrainer&) - */ -template -concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { - func(world, index, value, constrainer); -}; - -/** - * @brief Concept to validate random selector function signature - * The function must be callable with parameters: (std::span) and return size_t - */ -template -concept RandomSelectorFunction = requires(T func, std::span possibleValues) { - { func(possibleValues) } -> std::convertible_to; - { func.rng(static_cast(1)) } -> std::convertible_to; -}; - -/** - * @brief Default constexpr random selector using a simple seed-based algorithm - * This provides a compile-time random selection that maintains state between calls - */ -template -class DefaultRandomSelector { -private: - mutable uint32_t m_seed; - -public: - constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {} - - constexpr size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - - // Simple linear congruential generator for constexpr compatibility - return static_cast(rng(possibleValues.size())); - } - - constexpr uint32_t rng(uint32_t max) { - m_seed = m_seed * 1103515245 + 12345; - return m_seed % max; - } -}; - -/** - * @brief Advanced random selector using std::mt19937 and std::uniform_int_distribution - * This provides high-quality randomization for runtime use - */ -template -class AdvancedRandomSelector { -private: - std::mt19937& m_rng; - -public: - explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {} - - size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - - return rng(possibleValues.size()); - } - - uint32_t rng(uint32_t max) { - std::uniform_int_distribution dist(0, max); - return dist(m_rng); - } -}; - -/** - * @brief Weight specification for a specific value - * @tparam Value The variable value - * @tparam Weight The 16-bit weight for this value - */ - template - struct Weight { - static constexpr VarT value = Value; - static constexpr uint16_t weight = WeightValue; - }; - - /** - * @brief Compile-time weights storage for weighted random selection - * @tparam VarT The variable type - * @tparam VariableIDMapT The variable ID map type - * @tparam DefaultWeight The default weight for values not explicitly specified - * @tparam WeightSpecs Variadic template parameters of Weight specifications - */ - template - class WeightsMap { - private: - static constexpr size_t NumWeights = sizeof...(WeightSpecs); - - // Helper to get weight for a specific value - static consteval uint16_t GetWeightForValue(VarT targetValue) { - // Check each weight spec to find the target value - uint16_t weight = DefaultWeight; - ((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...); - return weight; - } - - public: - /** - * @brief Get the weight for a specific value at compile time - * @tparam TargetValue The value to get weight for - * @return The weight for the value - */ - template - static consteval uint16_t GetWeight() { - return GetWeightForValue(TargetValue); - } - - /** - * @brief Get weights array for all registered values - * @return Array of weights corresponding to all registered values - */ - static consteval std::array GetWeightsArray() { - std::array weights{}; - - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i)); - } - - return weights; - } - - static consteval uint32_t GetTotalWeight() { - uint32_t totalWeight = 0; - auto weights = GetWeightsArray(); - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - totalWeight += weights[i]; - } - return totalWeight; - } - - static consteval std::array GetCumulativeWeightsArray() { - auto weights = GetWeightsArray(); - uint32_t totalWeight = 0; - std::array cumulativeWeights{}; - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - totalWeight += weights[i]; - cumulativeWeights[i] = totalWeight; - } - return cumulativeWeights; - } - }; - - /** - * @brief Weighted random selector that uses another random selector as backend - * @tparam VarT The variable type - * @tparam VariableIDMapT The variable ID map type - * @tparam BackendSelectorT The backend random selector type - * @tparam WeightsMapT The weights map type containing weight specifications - */ - template - class WeightedSelector { - private: - BackendSelectorT m_backendSelector; - const std::array m_weights; - const std::array m_cumulativeWeights; - - public: - explicit WeightedSelector(BackendSelectorT backendSelector) - : m_backendSelector(backendSelector) - , m_weights(WeightsMapT::GetWeightsArray()) - , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) - {} - - explicit WeightedSelector(uint32_t seed) - requires std::is_same_v> - : m_backendSelector(seed) - , m_weights(WeightsMapT::GetWeightsArray()) - , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) - {} - - size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value"); - - // Use backend selector to pick a random number in range [0, totalWeight) - uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back()); - - // Find which value this random value corresponds to - for (size_t i = 0; i < possibleValues.size(); ++i) { - if (randomValue <= m_cumulativeWeights[i]) { - return i; - } - } - - // Fallback (should not reach here) - return possibleValues.size() - 1; - } - }; - -/** - * @brief Builder class for creating WFC instances - */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector> -class Builder { -public: - - template - using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>; - - template - requires ConstrainerFunction - using DefineConstrainer = Builder, - decltype([](WorldT&, size_t, WorldValue, Constrainer&) {}) - >, CallbacksT, RandomSelectorT - >; - - template - using SetCellCollapsedCallback = Builder, RandomSelectorT>; - template - using SetContradictionCallback = Builder, RandomSelectorT>; - template - using SetBranchCallback = Builder, RandomSelectorT>; - - template - requires RandomSelectorFunction - using SetRandomSelector = Builder; - - template - using Weights = Builder>>; - - - using Build = WFC; -}; - } // namespace WFC diff --git a/include/nd-wfc/wfc_allocator.hpp b/include/nd-wfc/wfc_allocator.hpp index d3ee329..7fcd187 100644 --- a/include/nd-wfc/wfc_allocator.hpp +++ b/include/nd-wfc/wfc_allocator.hpp @@ -10,6 +10,8 @@ #include #include +#include "wfc_utils.hpp" + #define WFC_USE_STACK_ALLOCATOR inline void* allocate_aligned_memory(size_t alignment, size_t size) { diff --git a/include/nd-wfc/wfc_bit_container.hpp b/include/nd-wfc/wfc_bit_container.hpp new file mode 100644 index 0000000..60b322f --- /dev/null +++ b/include/nd-wfc/wfc_bit_container.hpp @@ -0,0 +1,172 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "wfc_utils.hpp" +#include "wfc_allocator.hpp" + +namespace WFC { + +namespace detail { + // Helper to determine the optimal storage type based on bits needed + template + struct OptimalStorageType { + static constexpr size_t bits_needed = Bits == 0 ? 0 : + (Bits <= 1) ? 1 : + (Bits <= 2) ? 2 : + (Bits <= 4) ? 4 : + (Bits <= 8) ? 8 : + (Bits <= 16) ? 16 : + (Bits <= 32) ? 32 : + (Bits <= 64) ? 64 : + ((Bits + 63) / 64) * 64; // Round up to multiple of 64 for >64 bits + + using type = std::conditional_t>>; + }; + + // Helper for multi-element storage (>64 bits) + template + struct StorageArray { + static constexpr size_t StorageBits = OptimalStorageType::bits_needed; + static constexpr size_t ArraySize = StorageBits > 64 ? (StorageBits / 64) : 1; + using element_type = std::conditional_t::type, uint64_t>; + using type = std::conditional_t>; + }; + + struct Empty{}; +} + +template::type>> +class BitContainer : private AllocatorT{ +public: + using StorageInfo = detail::OptimalStorageType; + using StorageArrayInfo = detail::StorageArray; + using StorageType = typename StorageArrayInfo::type; + using AllocatorType = AllocatorT; + + static constexpr size_t BitsPerElement = Bits; + static constexpr size_t StorageBits = StorageInfo::bits_needed; + static constexpr bool IsResizable = (Size == 0); + static constexpr bool IsMultiElement = (StorageBits > 64); + static constexpr bool IsSubByte = (StorageBits < 8); + static constexpr bool IsDefaultByteLayout = !IsMultiElement && !IsSubByte; + static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max(1u, StorageBits); + + using ContainerType = + std::conditional_t, + std::array>>; + +private: + ContainerType m_container; + +private: + // Mask for extracting bits + static constexpr auto get_Mask() + { + if constexpr (BitsPerElement == 0) + { + return uint64_t{0}; + } + else if constexpr (BitsPerElement >= 64) + { + return ~uint64_t{0}; + } + else + { + return (uint64_t{1} << BitsPerElement) - 1; + } + } + + static constexpr uint64_t Mask = get_Mask(); + +public: + + BitContainer() = default; + BitContainer(AllocatorT& allocator) : AllocatorT(allocator) {}; + explicit BitContainer(size_t initial_size, AllocatorT& allocator) requires (IsResizable) + : AllocatorT(allocator) + , m_container(initial_size, allocator) + {}; + explicit BitContainer(size_t, AllocatorT& allocator) requires (!IsResizable) + : AllocatorT(allocator) + , m_container() + {}; + +public: + // Size operations + constexpr size_t size() const noexcept + { + if constexpr (IsResizable) + { + return m_container.size(); + } + else + { + return Size; + } + } + constexpr std::span data() const { return std::span(m_container); } + constexpr std::span data() { return std::span(m_container); } + + constexpr void resize(size_t new_size) requires (IsResizable) { m_container.resize(new_size); } + constexpr void reserve(size_t capacity) requires (IsResizable) { m_container.reserve(capacity); } + +public: // Sub byte + struct SubTypeAccess + { + constexpr SubTypeAccess(uint8_t& data, uint8_t subIndex) : Data{ data }, Shift{ StorageBits * subIndex } {}; + + constexpr uint8_t GetValue() const { return ((Data >> Shift) & Mask); } + constexpr uint8_t SetValue(uint8_t val) { Clear(); return Data |= ((val & Mask) << Shift); } + constexpr void Clear() { Data &= ~Mask; } + + constexpr operator uint8_t() const { return GetValue(); } + + template constexpr uint8_t operator&=(T other) { return SetValue(GetValue() & other); } + template constexpr uint8_t operator|=(T other) { return SetValue(GetValue() | other); } + template constexpr uint8_t operator^=(T other) { return SetValue(GetValue() ^ other); } + template constexpr uint8_t operator<<=(T other) { return SetValue(GetValue() << other); } + template constexpr uint8_t operator>>=(T other) { return SetValue(GetValue() >> other); } + + uint8_t& Data; + uint8_t Shift; + }; + + constexpr const SubTypeAccess operator[](size_t index) const requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } + constexpr SubTypeAccess operator[](size_t index) requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } + +public: // MultiElement + struct MultiElementAccess + { + constexpr MultiElementAccess(StorageType& data) : Data{ data } {}; + + StorageType& Data; + }; + + constexpr const MultiElementAccess operator[](size_t index) const requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } + constexpr MultiElementAccess operator[](size_t index) requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } + +public: // default + constexpr const StorageType& operator[](size_t index) const requires(IsDefaultByteLayout) { return data()[index]; } + constexpr StorageType& operator[](size_t index) requires(IsDefaultByteLayout) { return data()[index]; } + +}; + +static_assert(BitContainer<1, 10>::ElementsPerByte == 8); +static_assert(BitContainer<2, 10>::ElementsPerByte == 4); +static_assert(BitContainer<4, 10>::ElementsPerByte == 2); +static_assert(BitContainer<8, 10>::ElementsPerByte == 1); + + +} // namespace WFC diff --git a/include/nd-wfc/wfc_builder.hpp b/include/nd-wfc/wfc_builder.hpp new file mode 100644 index 0000000..b30697f --- /dev/null +++ b/include/nd-wfc/wfc_builder.hpp @@ -0,0 +1,52 @@ +#pragma once + +namespace WFC { + +#include "wfc_utils.hpp" +#include "wfc_variable_map.hpp" +#include "wfc_constrainer.hpp" +#include "wfc_callbacks.hpp" +#include "wfc_random.hpp" +#include "wfc.hpp" + +/** +* @brief Builder class for creating WFC instances +*/ +template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector> +class Builder { +public: + + template + using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>; + + template + requires ConstrainerFunction + using DefineConstrainer = Builder, + decltype([](WorldT&, size_t, WorldValue, Constrainer&) {}) + >, CallbacksT, RandomSelectorT + >; + + template + using SetCellCollapsedCallback = Builder, RandomSelectorT>; + template + using SetContradictionCallback = Builder, RandomSelectorT>; + template + using SetBranchCallback = Builder, RandomSelectorT>; + + template + requires RandomSelectorFunction + using SetRandomSelector = Builder; + + template + using Weights = Builder>>; + + + using Build = WFC; +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_callbacks.hpp b/include/nd-wfc/wfc_callbacks.hpp new file mode 100644 index 0000000..f3a02ce --- /dev/null +++ b/include/nd-wfc/wfc_callbacks.hpp @@ -0,0 +1,44 @@ +#pragma once + +namespace WFC { + +/** +* @brief Empty callback function +* @param WorldT The world type +*/ +template +using EmptyCallback = decltype([](WorldT&){}); + +/** + * @brief Callback struct + * @param WorldT The world type + * @param AllCellsCollapsedCallbackT The all cells collapsed callback type + * @param CellCollapsedCallbackT The cell collapsed callback type + * @param ContradictionCallbackT The contradiction callback type + * @param BranchCallbackT The branch callback type + */ + template , + typename ContradictionCallbackT = EmptyCallback, + typename BranchCallbackT = EmptyCallback +> +struct Callbacks +{ + using CellCollapsedCallback = CellCollapsedCallbackT; + using ContradictionCallback = ContradictionCallbackT; + using BranchCallback = BranchCallbackT; + + template + using SetCellCollapsedCallbackT = Callbacks; + template + using SetContradictionCallbackT = Callbacks; + template + using SetBranchCallbackT = Callbacks; + + static consteval bool HasCellCollapsedCallback() { return !std::is_same_v>; } + static consteval bool HasContradictionCallback() { return !std::is_same_v>; } + static consteval bool HasBranchCallback() { return !std::is_same_v>; } +}; + + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_constrainer.hpp b/include/nd-wfc/wfc_constrainer.hpp new file mode 100644 index 0000000..100ae59 --- /dev/null +++ b/include/nd-wfc/wfc_constrainer.hpp @@ -0,0 +1,122 @@ +#pragma once + +#include "wfc_variable_map.hpp" + +namespace WFC { + +template +struct ConstrainerFunctionMap { +public: + static consteval size_t size() { return sizeof...(ConstrainerFunctions); } + + using TupleType = std::tuple; + + template + static ConstrainerFunctionPtrT GetFunction(size_t index) + { + static_assert((std::is_empty_v && ...), "Lambdas must not have any captures"); + static ConstrainerFunctionPtrT functions[] = { + static_cast(ConstrainerFunctions{}) ... + }; + return functions[index]; + } +}; + +// Helper to select the correct constrainer function based on the index and the value +template +using MergedConstrainerElementSelector = + std::conditional_t(), // if the value is in the selected IDs + NewConstrainerFunctionT, + std::conditional_t<(I < ConstrainerFunctionMapT::size()), // if the index is within the size of the tuple + std::tuple_element_t, + EmptyFunctionT + > + >; + +// Helper to make a merged constrainer function map +template +auto MakeMergedConstrainerIDMap(std::index_sequence,VariableIDMapT*, ConstrainerFunctionMapT*, NewConstrainerFunctionT*, SelectedIDsVariableIDMapT*, EmptyFunctionT*) + -> ConstrainerFunctionMap...>; + +// Main alias for the merged constrainer function map +template +using MergedConstrainerFunctionMap = decltype( + MakeMergedConstrainerIDMap(std::make_index_sequence{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr) +); + +/** + * @brief Constrainer class used in constraint functions to limit possible values for other cells + */ +template +class Constrainer { +public: + using MaskType = typename VariableIDMapT::MaskType; + +public: + Constrainer(Wave& wave, WFCQueue& propagationQueue) + : m_wave(wave) + , m_propagationQueue(propagationQueue) + {} + + /** + * @brief Constrain a cell to exclude specific values + * @param cellId The ID of the cell to constrain + * @param forbiddenValues The set of forbidden values for this cell + */ + template + void Exclude(size_t cellId) { + static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided"); + ApplyMask(cellId, ~VariableIDMapT::template GetMask()); + } + + void Exclude(WorldValue value, size_t cellId) { + ApplyMask(cellId, ~(1 << value.InternalIndex)); + } + + /** + * @brief Constrain a cell to only allow one specific value + * @param cellId The ID of the cell to constrain + * @param value The only allowed value for this cell + */ + template + void Only(size_t cellId) { + static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided"); + ApplyMask(cellId, VariableIDMapT::template GetMask()); + } + + void Only(WorldValue value, size_t cellId) { + ApplyMask(cellId, 1 << value.InternalIndex); + } + +private: + void ApplyMask(size_t cellId, MaskType mask) { + bool wasCollapsed = m_wave.IsCollapsed(cellId); + + m_wave.Collapse(cellId, mask); + + bool collapsed = m_wave.IsCollapsed(cellId); + if (!wasCollapsed && collapsed) { + m_propagationQueue.push(cellId); + } + } + +private: + Wave& m_wave; + WFCQueue& m_propagationQueue; +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_large_integers.hpp b/include/nd-wfc/wfc_large_integers.hpp new file mode 100644 index 0000000..eca0afa --- /dev/null +++ b/include/nd-wfc/wfc_large_integers.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace WFC { + +template +struct LargeInteger +{ + std::array m_data; + + template + constexpr LargeInteger operator+(const LargeInteger& other) const { + LargeInteger result; + for (size_t i = 0; i < std::max(Size, OtherSize); i++) { + result[i] = m_data[i] + other[i]; + } + return result; + } +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_random.hpp b/include/nd-wfc/wfc_random.hpp new file mode 100644 index 0000000..92cada7 --- /dev/null +++ b/include/nd-wfc/wfc_random.hpp @@ -0,0 +1,178 @@ +#pragma once + +namespace WFC { + +/** +* @brief Default constexpr random selector using a simple seed-based algorithm +* This provides a compile-time random selection that maintains state between calls +*/ +template +class DefaultRandomSelector { +private: + mutable uint32_t m_seed; + +public: + constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {} + + constexpr size_t operator()(std::span possibleValues) const { + constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); + + // Simple linear congruential generator for constexpr compatibility + return static_cast(rng(possibleValues.size())); + } + + constexpr uint32_t rng(uint32_t max) { + m_seed = m_seed * 1103515245 + 12345; + return m_seed % max; + } +}; + +/** +* @brief Advanced random selector using std::mt19937 and std::uniform_int_distribution +* This provides high-quality randomization for runtime use +*/ +template +class AdvancedRandomSelector { +private: + std::mt19937& m_rng; + +public: + explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {} + + size_t operator()(std::span possibleValues) const { + constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); + + return rng(possibleValues.size()); + } + + uint32_t rng(uint32_t max) { + std::uniform_int_distribution dist(0, max); + return dist(m_rng); + } +}; + +/** +* @brief Weight specification for a specific value +* @tparam Value The variable value +* @tparam Weight The 16-bit weight for this value +*/ +template +struct Weight { + static constexpr VarT value = Value; + static constexpr uint16_t weight = WeightValue; +}; + +/** +* @brief Compile-time weights storage for weighted random selection +* @tparam VarT The variable type +* @tparam VariableIDMapT The variable ID map type +* @tparam DefaultWeight The default weight for values not explicitly specified +* @tparam WeightSpecs Variadic template parameters of Weight specifications +*/ +template +class WeightsMap { +private: + static constexpr size_t NumWeights = sizeof...(WeightSpecs); + + // Helper to get weight for a specific value + static consteval uint16_t GetWeightForValue(VarT targetValue) { + // Check each weight spec to find the target value + uint16_t weight = DefaultWeight; + ((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...); + return weight; + } + +public: + /** + * @brief Get the weight for a specific value at compile time + * @tparam TargetValue The value to get weight for + * @return The weight for the value + */ + template + static consteval uint16_t GetWeight() { + return GetWeightForValue(TargetValue); + } + + /** + * @brief Get weights array for all registered values + * @return Array of weights corresponding to all registered values + */ + static consteval std::array GetWeightsArray() { + std::array weights{}; + + for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { + weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i)); + } + + return weights; + } + + static consteval uint32_t GetTotalWeight() { + uint32_t totalWeight = 0; + auto weights = GetWeightsArray(); + for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { + totalWeight += weights[i]; + } + return totalWeight; + } + + static consteval std::array GetCumulativeWeightsArray() { + auto weights = GetWeightsArray(); + uint32_t totalWeight = 0; + std::array cumulativeWeights{}; + for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { + totalWeight += weights[i]; + cumulativeWeights[i] = totalWeight; + } + return cumulativeWeights; + } +}; + +/** +* @brief Weighted random selector that uses another random selector as backend +* @tparam VarT The variable type +* @tparam VariableIDMapT The variable ID map type +* @tparam BackendSelectorT The backend random selector type +* @tparam WeightsMapT The weights map type containing weight specifications +*/ +template +class WeightedSelector { +private: + BackendSelectorT m_backendSelector; + const std::array m_weights; + const std::array m_cumulativeWeights; + +public: + explicit WeightedSelector(BackendSelectorT backendSelector) + : m_backendSelector(backendSelector) + , m_weights(WeightsMapT::GetWeightsArray()) + , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) + {} + + explicit WeightedSelector(uint32_t seed) + requires std::is_same_v> + : m_backendSelector(seed) + , m_weights(WeightsMapT::GetWeightsArray()) + , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) + {} + + size_t operator()(std::span possibleValues) const { + constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); + constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value"); + + // Use backend selector to pick a random number in range [0, totalWeight) + uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back()); + + // Find which value this random value corresponds to + for (size_t i = 0; i < possibleValues.size(); ++i) { + if (randomValue <= m_cumulativeWeights[i]) { + return i; + } + } + + // Fallback (should not reach here) + return possibleValues.size() - 1; + } +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_utils.hpp b/include/nd-wfc/wfc_utils.hpp new file mode 100644 index 0000000..191020c --- /dev/null +++ b/include/nd-wfc/wfc_utils.hpp @@ -0,0 +1,27 @@ +#pragma once + +namespace WFC +{ + + + + + +inline constexpr void constexpr_assert(bool condition, const char* message = "") { + if (!condition) throw message; +} + +inline int FindNthSetBit(size_t num, int n) { + constexpr_assert(n < std::popcount(num), "index is out of range"); + int bitCount = 0; + while (num) { + if (bitCount == n) { + return std::countr_zero(num); // Index of the current set bit + } + bitCount++; + num &= (num - 1); // turn of lowest set bit + } + return bitCount; +} + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_variable_map.hpp b/include/nd-wfc/wfc_variable_map.hpp new file mode 100644 index 0000000..ab998c0 --- /dev/null +++ b/include/nd-wfc/wfc_variable_map.hpp @@ -0,0 +1,73 @@ +#pragma once + +#include "wfc_utils.hpp" + +namespace WFC { + +/** +* @brief Class to map variable values to indices at compile time +* +* This class is used to map variable values to indices at compile time. +* It is a compile-time map of variable values to indices. +*/ +template +class VariableIDMap { +public: + + using Type = VarT; + static constexpr size_t ValuesRegisteredAmount = sizeof...(Values); + + template + using Merge = VariableIDMap; + + template + static consteval bool HasValue() + { + constexpr VarT arr[] = {Values...}; + constexpr size_t size = sizeof...(Values); + + for (size_t i = 0; i < size; ++i) + if (arr[i] == Value) + return true; + return false; + } + + template + static consteval size_t GetIndex() + { + static_assert(HasValue(), "Value was not defined"); + constexpr VarT arr[] = {Values...}; + constexpr size_t size = ValuesRegisteredAmount; + + for (size_t i = 0; i < size; ++i) + if (arr[i] == Value) + return i; + + return static_cast(-1); // This line is unreachable if value is found + } + + static std::span GetAllValues() + { + static const VarT allValues[] + { + Values... + }; + return std::span{ allValues, ValuesRegisteredAmount }; + } + + static VarT GetValue(size_t index) { + constexpr_assert(index < ValuesRegisteredAmount); + return GetAllValues()[index]; + } + + static consteval VarT GetValueConsteval(size_t index) + { + constexpr VarT arr[] = {Values...}; + return arr[index]; + } + + static consteval size_t size() { return ValuesRegisteredAmount; } +}; + + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_wave.hpp b/include/nd-wfc/wfc_wave.hpp new file mode 100644 index 0000000..7456244 --- /dev/null +++ b/include/nd-wfc/wfc_wave.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "wfc_bit_container.hpp" +#include "wfc_variable_map.hpp" +#include "wfc_allocator.hpp" + +namespace WFC { + +template +class Wave { +public: + using BitContainerT = BitContainer; + using ElementT = typename BitContainerT::StorageType; + +public: + Wave() = default; + Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter(allocator)) + { + for (auto& wave : m_data) wave = (1 << variableAmount) - 1; + } + + Wave(const Wave& other) = default; + +public: + void Collapse(size_t index, ElementT mask) { m_data[index] &= mask; } + size_t size() const { return m_data.size(); } + size_t Entropy(size_t index) const { return std::popcount(m_data[index]); } + bool IsCollapsed(size_t index) const { return Entropy(index) == 1; } + bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](ElementT value) { return std::popcount(value) == 1; }); } + bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](ElementT value) { return value == 0; }); } + bool IsContradicted(size_t index) const { return m_data[index] == 0; } + uint16_t GetVariableID(size_t index) const { return static_cast(std::countr_zero(m_data[index])); } + ElementT GetMask(size_t index) const { return m_data[index]; } + +private: + BitContainerT m_data; +}; + +} \ No newline at end of file From bc9d7e3b9bbfb63fe3ef359a7dc09c1011e517aa Mon Sep 17 00:00:00 2001 From: cdemeyer-teachx Date: Wed, 10 Sep 2025 12:21:31 +0900 Subject: [PATCH 3/3] implementation + tests pass --- demos/sudoku/sudoku.h | 7 +- demos/sudoku/sudoku_wfc.cpp | 7 +- demos/sudoku/test_sudoku.cpp | 2 - include/nd-wfc/wfc.hpp | 58 ++- include/nd-wfc/wfc_bit_container.hpp | 123 +++++-- include/nd-wfc/wfc_builder.hpp | 16 +- include/nd-wfc/wfc_callbacks.hpp | 5 +- include/nd-wfc/wfc_constrainer.hpp | 24 +- include/nd-wfc/wfc_large_integers.hpp | 505 +++++++++++++++++++++++++- include/nd-wfc/wfc_random.hpp | 4 +- include/nd-wfc/wfc_utils.hpp | 17 + include/nd-wfc/wfc_variable_map.hpp | 6 + include/nd-wfc/wfc_wave.hpp | 4 +- 13 files changed, 693 insertions(+), 85 deletions(-) diff --git a/demos/sudoku/sudoku.h b/demos/sudoku/sudoku.h index b15bd8d..8d0ffd2 100644 --- a/demos/sudoku/sudoku.h +++ b/demos/sudoku/sudoku.h @@ -11,7 +11,7 @@ #include #include -#include +#include // 4-bit packed Sudoku board storage - optimal packing // 81 cells * 4 bits = 324 bits @@ -38,7 +38,7 @@ public: uint8_t result = (data[byteIndex] >> shiftAmount) & 0xF; // Debug assertion: ensure result is in valid range - WFC::constexpr_assert(result >= 0 && result <= 9, "Sudoku cell value must be between 0-9"); + WFC::constexpr_assert(result <= 9, "Sudoku cell value must be between 0-9"); return result; } @@ -49,7 +49,7 @@ public: // Optimization: (pos & 1) << 2 instead of (pos % 2) * 4 constexpr inline void set(int pos, uint8_t value) { // Assert that value is in valid Sudoku range (0-9) - WFC::constexpr_assert(value >= 0 && value <= 9, "Sudoku cell value must be between 0-9"); + WFC::constexpr_assert(value <= 9, "Sudoku cell value must be between 0-9"); int byteIndex = pos >> 1; // pos / 2 using right shift @@ -294,6 +294,7 @@ public: // WFC Support // Static assert to ensure correct size (now 56 bytes with solver additions) static_assert(sizeof(Sudoku) == 41, "Sudoku class must be exactly 41 bytes"); +static_assert(WFC::HasConstexprSize, "Sudoku class must have a constexpr size() method"); // Fast solution validator (stateless) class SudokuValidator { diff --git a/demos/sudoku/sudoku_wfc.cpp b/demos/sudoku/sudoku_wfc.cpp index 2bd8827..169dae5 100644 --- a/demos/sudoku/sudoku_wfc.cpp +++ b/demos/sudoku/sudoku_wfc.cpp @@ -58,16 +58,11 @@ using SudokuSolverCallback = SudokuSolverBuilder::SetCellCollapsedCallback ::Build; -Sudoku GetWorldConsteval() -{ - return Sudoku{ "6......3.......7....7463....7.8...2.4...9...1.9...7.8....9851....6.......1......9" }; -} - int main() { std::cout << "Running Sudoku WFC" << std::endl; - Sudoku sudokuWorld = GetWorldConsteval(); + Sudoku sudokuWorld = Sudoku{ "6......3.......7....7463....7.8...2.4...9...1.9...7.8....9851....6.......1......9" }; bool success = SudokuSolverCallback::Run(sudokuWorld, true); diff --git a/demos/sudoku/test_sudoku.cpp b/demos/sudoku/test_sudoku.cpp index 8d238ca..7713813 100644 --- a/demos/sudoku/test_sudoku.cpp +++ b/demos/sudoku/test_sudoku.cpp @@ -288,9 +288,7 @@ void testPuzzleSolving(const std::string& difficulty, const std::string& filenam Sudoku& sudoku = puzzles[i]; EXPECT_TRUE(sudoku.isValid()) << difficulty << " puzzle " << i << " is not valid"; - auto puzzleStart = std::chrono::high_resolution_clock::now(); SudokuSolver::Run(sudoku, allocator); - auto puzzleEnd = std::chrono::high_resolution_clock::now(); EXPECT_TRUE(sudoku.isSolved()) << difficulty << " puzzle " << i << " was not solved. Puzzle string: " << sudoku.toString(); diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index 3d845e7..daed5e3 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -25,23 +25,6 @@ namespace WFC { -template -struct WorldValue -{ -public: - WorldValue() = default; - WorldValue(VarT value, uint16_t internalIndex) - : Value(value) - , InternalIndex(internalIndex) - {} -public: - operator VarT() const { return Value; } - -public: - VarT Value{}; - uint16_t InternalIndex{}; -}; - template concept WorldType = requires(T world, size_t id, typename T::ValueType value) { { world.size() } -> std::convertible_to; @@ -64,15 +47,20 @@ concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, Worl * The function must be callable with parameters: (std::span) and return size_t */ template -concept RandomSelectorFunction = requires(T func, std::span possibleValues) { +concept RandomSelectorFunction = requires(const T& func, std::span possibleValues) { { func(possibleValues) } -> std::convertible_to; { func.rng(static_cast(1)) } -> std::convertible_to; }; +template +concept HasConstexprSize = requires { + { []() constexpr -> std::size_t { return WorldT{}.size(); }() }; +}; + /** * @brief Main WFC class implementing the Wave Function Collapse algorithm */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, @@ -81,14 +69,19 @@ class WFC { public: static_assert(WorldType, "WorldT must satisfy World type requirements"); - using ElementT = typename VariableIDMapT::ElementT; + // Try getting the world size, which is only available if the world type has a constexpr size() method + constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + + using WaveType = Wave; + using ConstrainerType = Constrainer; + using MaskType = typename WaveType::ElementT; public: struct SolverState { WorldT& world; WFCQueue propagationQueue; - Wave wave; + WaveType wave; std::mt19937& rng; RandomSelectorT& randomSelector; WFCStackAllocator& allocator; @@ -97,7 +90,7 @@ public: SolverState(WorldT& world, size_t variableAmount, std::mt19937& rng, RandomSelectorT& randomSelector, WFCStackAllocator& allocator, size_t& iterations) : world(world) , propagationQueue{ WFCStackAllocatorAdapter(allocator) } - , wave{ world.size(), variableAmount, allocator } + , wave{ WorldSize, variableAmount, allocator } , rng(rng) , randomSelector(randomSelector) , allocator(allocator) @@ -111,6 +104,7 @@ public: WFC() = delete; // dont make an instance of this class, only use the static methods. public: + static bool Run(WorldT& world, uint32_t seed = std::random_device{}()) { WFCStackAllocator allocator{}; @@ -134,10 +128,12 @@ public: allocator, iterations }; - return Run(state); + bool result = Run(state); allocator.reset(); constexpr_assert(allocator.getUsed() == 0, "Allocator must be empty"); + + return result; } /** @@ -209,7 +205,7 @@ public: static const std::vector GetPossibleValues(SolverState& state, int cellId) { std::vector possibleValues; - ElementT mask = state.wave.GetMask(cellId); + MaskType mask = state.wave.GetMask(cellId); for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) { if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i)); } @@ -219,7 +215,7 @@ public: private: static void CollapseCell(SolverState& state, size_t cellId, uint16_t value) { - constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (ElementT(1) << value)); + constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (MaskType(1) << value)); state.wave.Collapse(cellId, 1 << value); constexpr_assert(state.wave.IsCollapsed(cellId)); @@ -252,14 +248,14 @@ private: // create a list of possible values uint16_t availableValues = static_cast(state.wave.Entropy(minEntropyCell)); std::array possibleValues; // inplace vector - ElementT mask = state.wave.GetMask(minEntropyCell); + MaskType mask = state.wave.GetMask(minEntropyCell); for (size_t i = 0; i < availableValues; ++i) { uint16_t index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds"); possibleValues[i] = index; - constexpr_assert(((mask & (ElementT(1) << index)) != 0), "Possible value was not set"); + constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set"); mask = mask & (mask - 1); // turn off lowest set bit } @@ -293,9 +289,9 @@ private: } // remove the failure state from the wave - constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) != 0, "Possible value was not set"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) != 0, "Possible value was not set"); state.wave.Collapse(minEntropyCell, ~(1 << selectedValue)); - constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); // swap replacement value with the last value std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); @@ -316,9 +312,9 @@ private: constexpr_assert(state.wave.IsCollapsed(cellId), "Cell was not collapsed"); uint16_t variableID = state.wave.GetVariableID(cellId); - Constrainer constrainer(state.wave, state.propagationQueue); + ConstrainerType constrainer(state.wave, state.propagationQueue); - using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, Constrainer&); + using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, ConstrainerType&); ConstrainerFunctionMapT::template GetFunction(variableID)(state.world, cellId, WorldValue{VariableIDMapT::GetValue(variableID), variableID}, constrainer); } diff --git a/include/nd-wfc/wfc_bit_container.hpp b/include/nd-wfc/wfc_bit_container.hpp index 60b322f..509feb4 100644 --- a/include/nd-wfc/wfc_bit_container.hpp +++ b/include/nd-wfc/wfc_bit_container.hpp @@ -6,9 +6,11 @@ #include #include #include +#include #include "wfc_utils.hpp" #include "wfc_allocator.hpp" +#include "wfc_large_integers.hpp" namespace WFC { @@ -38,7 +40,7 @@ namespace detail { static constexpr size_t StorageBits = OptimalStorageType::bits_needed; static constexpr size_t ArraySize = StorageBits > 64 ? (StorageBits / 64) : 1; using element_type = std::conditional_t::type, uint64_t>; - using type = std::conditional_t>; + using type = std::conditional_t>; }; struct Empty{}; @@ -57,7 +59,6 @@ public: static constexpr bool IsResizable = (Size == 0); static constexpr bool IsMultiElement = (StorageBits > 64); static constexpr bool IsSubByte = (StorageBits < 8); - static constexpr bool IsDefaultByteLayout = !IsMultiElement && !IsSubByte; static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max(1u, StorageBits); using ContainerType = @@ -90,15 +91,30 @@ private: static constexpr uint64_t Mask = get_Mask(); +public: + static constexpr StorageType GetWaveMask() + { + return (StorageType{1} << BitsPerElement) - 1; + } + + static constexpr StorageType GetMask(std::span indices) + { + StorageType mask = 0; + for (const auto& index : indices) { + mask |= (StorageType{1} << index); + } + return mask; + } + public: BitContainer() = default; - BitContainer(AllocatorT& allocator) : AllocatorT(allocator) {}; - explicit BitContainer(size_t initial_size, AllocatorT& allocator) requires (IsResizable) + BitContainer(const AllocatorT& allocator) : AllocatorT(allocator) {}; + explicit BitContainer(size_t initial_size, const AllocatorT& allocator) requires (IsResizable) : AllocatorT(allocator) , m_container(initial_size, allocator) {}; - explicit BitContainer(size_t, AllocatorT& allocator) requires (!IsResizable) + explicit BitContainer(size_t, const AllocatorT& allocator) requires (!IsResizable) : AllocatorT(allocator) , m_container() {}; @@ -131,13 +147,15 @@ public: // Sub byte constexpr uint8_t SetValue(uint8_t val) { Clear(); return Data |= ((val & Mask) << Shift); } constexpr void Clear() { Data &= ~Mask; } + + constexpr SubTypeAccess& operator=(uint8_t other) { return SetValue(other); } constexpr operator uint8_t() const { return GetValue(); } - template constexpr uint8_t operator&=(T other) { return SetValue(GetValue() & other); } - template constexpr uint8_t operator|=(T other) { return SetValue(GetValue() | other); } - template constexpr uint8_t operator^=(T other) { return SetValue(GetValue() ^ other); } - template constexpr uint8_t operator<<=(T other) { return SetValue(GetValue() << other); } - template constexpr uint8_t operator>>=(T other) { return SetValue(GetValue() >> other); } + constexpr SubTypeAccess& operator&=(uint8_t other) { return SetValue(GetValue() & other); } + constexpr SubTypeAccess& operator|=(uint8_t other) { return SetValue(GetValue() | other); } + constexpr SubTypeAccess& operator^=(uint8_t other) { return SetValue(GetValue() ^ other); } + constexpr SubTypeAccess& operator<<=(uint8_t other) { return SetValue(GetValue() << other); } + constexpr SubTypeAccess& operator>>=(uint8_t other) { return SetValue(GetValue() >> other); } uint8_t& Data; uint8_t Shift; @@ -146,23 +164,86 @@ public: // Sub byte constexpr const SubTypeAccess operator[](size_t index) const requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } constexpr SubTypeAccess operator[](size_t index) requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } -public: // MultiElement - struct MultiElementAccess - { - constexpr MultiElementAccess(StorageType& data) : Data{ data } {}; +public: // default + constexpr const StorageType& operator[](size_t index) const requires(!IsSubByte) { return data()[index]; } + constexpr StorageType& operator[](size_t index) requires(!IsSubByte) { return data()[index]; } - StorageType& Data; +public: // iterators + template + class BitIterator { + public: + // Iterator traits + using iterator_category = std::random_access_iterator_tag; + using value_type = StorageType; + using difference_type = std::ptrdiff_t; + using pointer = std::conditional_t; + using reference = std::conditional_t; + + private: + using ContainerType = std::conditional_t; + + ContainerType* m_container{}; + size_t m_index{}; + + public: + // Constructor + constexpr BitIterator() = default; + constexpr BitIterator(ContainerType& container, size_t index) : m_container(&container), m_index(index) {} + + // Dereference + constexpr reference operator*() const { return (*m_container)[m_index]; } + constexpr pointer operator->() const { return &(*m_container)[m_index]; } + + // Element access + constexpr reference operator[](difference_type n) const { return (*m_container)[m_index + n]; } + + // Increment / Decrement + constexpr BitIterator& operator++() { ++m_index; return *this; } + constexpr BitIterator operator++(int) { BitIterator tmp = *this; ++m_index; return tmp; } + constexpr BitIterator& operator--() { --m_index; return *this; } + constexpr BitIterator operator--(int) { BitIterator tmp = *this; --m_index; return tmp; } + + // Arithmetic + constexpr BitIterator operator+(difference_type n) const { return BitIterator(*m_container, m_index + n); } + constexpr BitIterator operator-(difference_type n) const { return BitIterator(*m_container, m_index - n); } + constexpr difference_type operator-(const BitIterator& other) const { return static_cast(m_index) - static_cast(other.m_index); } + + // Assignment + constexpr BitIterator& operator+=(difference_type n) { m_index += n; return *this; } + constexpr BitIterator& operator-=(difference_type n) { m_index -= n; return *this; } + + // Comparison + constexpr bool operator==(const BitIterator& other) const { return m_index == other.m_index; } + constexpr bool operator!=(const BitIterator& other) const { return m_index != other.m_index; } + constexpr bool operator<(const BitIterator& other) const { return m_index < other.m_index; } + constexpr bool operator>(const BitIterator& other) const { return m_index > other.m_index; } + constexpr bool operator<=(const BitIterator& other) const { return m_index <= other.m_index; } + constexpr bool operator>=(const BitIterator& other) const { return m_index >= other.m_index; } + + // Conversion from non-const to const iterator + constexpr operator BitIterator() const { + return BitIterator(*m_container, m_index); + } }; - constexpr const MultiElementAccess operator[](size_t index) const requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } - constexpr MultiElementAccess operator[](size_t index) requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } - -public: // default - constexpr const StorageType& operator[](size_t index) const requires(IsDefaultByteLayout) { return data()[index]; } - constexpr StorageType& operator[](size_t index) requires(IsDefaultByteLayout) { return data()[index]; } + // Type aliases for convenience + using ConstIterator = BitIterator; + using Iterator = BitIterator; + constexpr Iterator begin() { return Iterator{*this, 0}; } + constexpr Iterator end() { return Iterator{*this, size()}; } + constexpr const ConstIterator begin() const { return ConstIterator{*this, 0}; } + constexpr const ConstIterator end() const { return ConstIterator{*this, size()}; } }; +// Free function for iterator addition +template ::type>, bool IsConst> +BitContainer::BitIterator operator+( + typename BitContainer::template BitIterator::difference_type n, + const typename BitContainer::template BitIterator& it) { + return it + n; +} + static_assert(BitContainer<1, 10>::ElementsPerByte == 8); static_assert(BitContainer<2, 10>::ElementsPerByte == 4); static_assert(BitContainer<4, 10>::ElementsPerByte == 2); diff --git a/include/nd-wfc/wfc_builder.hpp b/include/nd-wfc/wfc_builder.hpp index b30697f..b85366e 100644 --- a/include/nd-wfc/wfc_builder.hpp +++ b/include/nd-wfc/wfc_builder.hpp @@ -12,22 +12,32 @@ namespace WFC { /** * @brief Builder class for creating WFC instances */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector> +template< + typename WorldT, + typename VarT = typename WorldT::ValueType, + typename VariableIDMapT = VariableIDMap, + typename ConstrainerFunctionMapT = ConstrainerFunctionMap, + typename CallbacksT = Callbacks, + typename RandomSelectorT = DefaultRandomSelector> class Builder { public: + constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + + using WaveType = Wave; + using ConstrainerType = Constrainer; template using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>; template - requires ConstrainerFunction + requires ConstrainerFunction using DefineConstrainer = Builder, - decltype([](WorldT&, size_t, WorldValue, Constrainer&) {}) + decltype([](WorldT&, size_t, WorldValue, ConstrainerType&) {}) >, CallbacksT, RandomSelectorT >; diff --git a/include/nd-wfc/wfc_callbacks.hpp b/include/nd-wfc/wfc_callbacks.hpp index f3a02ce..8965b58 100644 --- a/include/nd-wfc/wfc_callbacks.hpp +++ b/include/nd-wfc/wfc_callbacks.hpp @@ -7,7 +7,10 @@ namespace WFC { * @param WorldT The world type */ template -using EmptyCallback = decltype([](WorldT&){}); +struct EmptyCallback +{ + void operator()(WorldT&) const {} +}; /** * @brief Callback struct diff --git a/include/nd-wfc/wfc_constrainer.hpp b/include/nd-wfc/wfc_constrainer.hpp index 100ae59..d40635e 100644 --- a/include/nd-wfc/wfc_constrainer.hpp +++ b/include/nd-wfc/wfc_constrainer.hpp @@ -61,13 +61,15 @@ using MergedConstrainerFunctionMap = decltype( /** * @brief Constrainer class used in constraint functions to limit possible values for other cells */ -template +template class Constrainer { public: - using MaskType = typename VariableIDMapT::MaskType; + using IDMapT = typename WaveT::IDMapT; + using BitContainerT = typename WaveT::BitContainerT; + using MaskType = typename BitContainerT::StorageType; public: - Constrainer(Wave& wave, WFCQueue& propagationQueue) + Constrainer(WaveT& wave, WFCQueue& propagationQueue) : m_wave(wave) , m_propagationQueue(propagationQueue) {} @@ -77,13 +79,14 @@ public: * @param cellId The ID of the cell to constrain * @param forbiddenValues The set of forbidden values for this cell */ - template + template void Exclude(size_t cellId) { static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided"); - ApplyMask(cellId, ~VariableIDMapT::template GetMask()); + auto indices = IDMapT::template ValuesToIndices(); + ApplyMask(cellId, ~BitContainerT::GetMask(indices)); } - void Exclude(WorldValue value, size_t cellId) { + void Exclude(WorldValue value, size_t cellId) { ApplyMask(cellId, ~(1 << value.InternalIndex)); } @@ -92,13 +95,14 @@ public: * @param cellId The ID of the cell to constrain * @param value The only allowed value for this cell */ - template + template void Only(size_t cellId) { static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided"); - ApplyMask(cellId, VariableIDMapT::template GetMask()); + auto indices = IDMapT::template ValuesToIndices(); + ApplyMask(cellId, BitContainerT::GetMask(indices)); } - void Only(WorldValue value, size_t cellId) { + void Only(WorldValue value, size_t cellId) { ApplyMask(cellId, 1 << value.InternalIndex); } @@ -115,7 +119,7 @@ private: } private: - Wave& m_wave; + WaveT& m_wave; WFCQueue& m_propagationQueue; }; diff --git a/include/nd-wfc/wfc_large_integers.hpp b/include/nd-wfc/wfc_large_integers.hpp index eca0afa..6a2484d 100644 --- a/include/nd-wfc/wfc_large_integers.hpp +++ b/include/nd-wfc/wfc_large_integers.hpp @@ -1,22 +1,517 @@ #pragma once #include +#include +#include +#include +#include +#include + +// Detect __uint128_t support +#if (defined(__SIZEOF_INT128__) || defined(__INTEL_COMPILER) || (defined(__GNUC__) && __GNUC__ >= 4)) && !defined(_MSC_VER) +#define WFC_HAS_UINT128 1 +#else +#define WFC_HAS_UINT128 0 +#endif namespace WFC { template struct LargeInteger { + static_assert(Size > 0, "Size must be greater than 0"); + std::array m_data; + // Constructors + constexpr LargeInteger() = default; + constexpr LargeInteger(const LargeInteger&) = default; + constexpr LargeInteger(LargeInteger&&) = default; + constexpr LargeInteger& operator=(const LargeInteger&) = default; + constexpr LargeInteger& operator=(LargeInteger&&) = default; + + // Constructor from uint64_t (for small values) + template && std::is_unsigned_v>> + constexpr explicit LargeInteger(T value) { + m_data.fill(0); + if constexpr (sizeof(T) <= sizeof(uint64_t)) { + m_data[0] = static_cast(value); + } else { + // Handle larger types if needed + static_assert(sizeof(T) <= sizeof(uint64_t), "Type too large for LargeInteger"); + } + } + + // Access operators + constexpr uint64_t& operator[](size_t index) { return m_data[index]; } + constexpr const uint64_t& operator[](size_t index) const { return m_data[index]; } + + // Helper function to get the larger size type template - constexpr LargeInteger operator+(const LargeInteger& other) const { - LargeInteger result; - for (size_t i = 0; i < std::max(Size, OtherSize); i++) { - result[i] = m_data[i] + other[i]; + using LargerType = LargeInteger; + + // Helper function to promote operands to the same size + template + constexpr auto promote(const LargeInteger& other) const { + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger lhs_promoted{}; + LargeInteger rhs_promoted{}; + + // Copy data, padding with zeros + for (size_t i = 0; i < Size; ++i) { + lhs_promoted[i] = m_data[i]; + } + for (size_t i = 0; i < OtherSize; ++i) { + rhs_promoted[i] = other[i]; + } + + return std::make_pair(lhs_promoted, rhs_promoted); + } + + // Arithmetic operators + template + constexpr LargerType operator+(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; + + uint64_t carry = 0; + for (size_t i = 0; i < ResultSize; ++i) { + uint64_t sum = lhs[i] + rhs[i] + carry; + result[i] = sum; + carry = (sum < lhs[i] || (carry && sum == lhs[i])) ? 1 : 0; + } + + return result; + } + + template + constexpr LargeInteger& operator+=(const LargeInteger& other) { + *this = *this + other; + return *this; + } + + template + constexpr LargerType operator-(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; + + uint64_t borrow = 0; + for (size_t i = 0; i < ResultSize; ++i) { + uint64_t diff = lhs[i] - rhs[i] - borrow; + result[i] = diff; + borrow = (lhs[i] < rhs[i] + borrow) ? 1 : 0; + } + + return result; + } + + template + constexpr LargeInteger& operator-=(const LargeInteger& other) { + *this = *this - other; + return *this; + } + + template + constexpr LargerType operator*(const LargeInteger& other) const { +#if WFC_HAS_UINT128 + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; // Multiplication can double the size + + for (size_t i = 0; i < ResultSize; ++i) { + uint64_t carry = 0; + for (size_t j = 0; j < ResultSize; ++j) { + __uint128_t product = static_cast<__uint128_t>(lhs[i]) * rhs[j] + result[i + j] + carry; + result[i + j] = static_cast(product); + carry = product >> 64; + } + size_t k = i + ResultSize; + while (carry && k < ResultSize * 2) { + __uint128_t sum = result[k] + carry; + result[k] = static_cast(sum); + carry = sum >> 64; + ++k; + } + } + + // Truncate to the larger of the original sizes + LargeInteger final_result{}; + for (size_t i = 0; i < ResultSize; ++i) { + final_result[i] = result[i]; + } + return final_result; +#else + throw std::runtime_error("LargeInteger multiplication requires __uint128_t support, which is not available on this compiler/platform"); +#endif + } + + template + constexpr LargeInteger& operator*=(const LargeInteger& other) { + *this = *this * other; + return *this; + } + + // Division and modulo (simplified implementation) + template + constexpr LargerType operator/(const LargeInteger& other) const { + // Simplified division - assumes other is not zero and result fits + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; + + // This is a very basic division implementation + // For a full implementation, you'd need proper long division + LargeInteger temp = lhs; + while (temp >= rhs) { + temp = temp - rhs; + result = result + LargeInteger{1}; + } + + return result; + } + + template + constexpr LargerType operator%(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger temp = lhs; + while (temp >= rhs) { + temp = temp - rhs; + } + return temp; + } + + // Unary operators + constexpr LargeInteger operator-() const { + LargeInteger result{}; + for (size_t i = 0; i < Size; ++i) { + result[i] = ~m_data[i] + 1; // Two's complement + } + return result; + } + + constexpr LargeInteger operator~() const { + LargeInteger result{}; + for (size_t i = 0; i < Size; ++i) { + result[i] = ~m_data[i]; + } + return result; + } + + // Bit operations + template + constexpr LargerType operator&(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.bitwise_op(rhs, std::bit_and{}); + } + + template + constexpr LargeInteger& operator&=(const LargeInteger& other) { + *this = *this & other; + return *this; + } + + template + constexpr LargerType operator|(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.bitwise_op(rhs, std::bit_or{}); + } + + template + constexpr LargeInteger& operator|=(const LargeInteger& other) { + *this = *this | other; + return *this; + } + + template + constexpr LargerType operator^(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.bitwise_op(rhs, std::bit_xor{}); + } + + template + constexpr LargeInteger& operator^=(const LargeInteger& other) { + *this = *this ^ other; + return *this; + } + + template + constexpr LargerType operator<<(size_t shift) const { + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result = *this; + + size_t word_shift = shift / 64; + size_t bit_shift = shift % 64; + + if (word_shift >= ResultSize) { + result.m_data.fill(0); + return result; + } + + // Shift words + for (size_t i = ResultSize - 1; i >= word_shift; --i) { + result[i] = result[i - word_shift]; + } + for (size_t i = 0; i < word_shift; ++i) { + result[i] = 0; + } + + // Shift bits + if (bit_shift > 0) { + uint64_t carry = 0; + for (size_t i = word_shift; i < ResultSize; ++i) { + uint64_t new_carry = result[i] >> (64 - bit_shift); + result[i] = (result[i] << bit_shift) | carry; + carry = new_carry; + } + } + + return result; + } + + template + constexpr LargeInteger& operator<<=(size_t shift) { + *this = *this << shift; + return *this; + } + + template + constexpr LargerType operator>>(size_t shift) const { + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result = *this; + + size_t word_shift = shift / 64; + size_t bit_shift = shift % 64; + + if (word_shift >= ResultSize) { + result.m_data.fill(0); + return result; + } + + // Shift words + for (size_t i = 0; i < ResultSize - word_shift; ++i) { + result[i] = result[i + word_shift]; + } + for (size_t i = ResultSize - word_shift; i < ResultSize; ++i) { + result[i] = 0; + } + + // Shift bits + if (bit_shift > 0) { + uint64_t carry = 0; + for (size_t i = ResultSize - word_shift - 1; i < ResultSize; --i) { + uint64_t new_carry = result[i] << (64 - bit_shift); + result[i] = (result[i] >> bit_shift) | carry; + carry = new_carry; + if (i == 0) break; + } + } + + return result; + } + + template + constexpr LargeInteger& operator>>=(size_t shift) { + *this = *this >> shift; + return *this; + } + + // Comparison operators + template + constexpr bool operator==(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.m_data == rhs.m_data; + } + + template + constexpr bool operator!=(const LargeInteger& other) const { + return !(*this == other); + } + + template + constexpr bool operator<(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + for (size_t i = lhs.m_data.size(); i > 0; --i) { + if (lhs.m_data[i-1] != rhs.m_data[i-1]) { + return lhs.m_data[i-1] < rhs.m_data[i-1]; + } + } + return false; + } + + template + constexpr bool operator<=(const LargeInteger& other) const { + return *this < other || *this == other; + } + + template + constexpr bool operator>(const LargeInteger& other) const { + return other < *this; + } + + template + constexpr bool operator>=(const LargeInteger& other) const { + return other <= *this; + } + + // std::bit library functions + constexpr int countl_zero() const { + for (size_t i = Size; i > 0; --i) { + if (m_data[i-1] != 0) { + return std::countl_zero(m_data[i-1]) + (Size - i) * 64; + } + } + return Size * 64; + } + + constexpr int countl_one() const { + for (size_t i = Size; i > 0; --i) { + if (m_data[i-1] != std::numeric_limits::max()) { + return std::countl_one(m_data[i-1]) + (Size - i) * 64; + } + } + return Size * 64; + } + + constexpr int countr_zero() const { + for (size_t i = 0; i < Size; ++i) { + if (m_data[i] != 0) { + return std::countr_zero(m_data[i]) + i * 64; + } + } + return Size * 64; + } + + constexpr int countr_one() const { + for (size_t i = 0; i < Size; ++i) { + if (m_data[i] != std::numeric_limits::max()) { + return std::countr_one(m_data[i]) + i * 64; + } + } + return Size * 64; + } + + constexpr int popcount() const { + int count = 0; + for (size_t i = 0; i < Size; ++i) { + count += std::popcount(m_data[i]); + } + return count; + } + + template + constexpr LargerType rotl(size_t shift) const { + shift %= (Size * 64); + return (*this << shift) | (*this >> ((Size * 64) - shift)); + } + + template + constexpr LargerType rotr(size_t shift) const { + shift %= (Size * 64); + return (*this >> shift) | (*this << ((Size * 64) - shift)); + } + + constexpr bool has_single_bit() const { + return popcount() == 1; + } + + constexpr LargeInteger bit_ceil() const { + if (*this == LargeInteger{0}) return LargeInteger{1}; + + LargeInteger result = *this; + result -= LargeInteger{1}; + result |= result >> 1; + result |= result >> 2; + result |= result >> 4; + result |= result >> 8; + result |= result >> 16; + result |= result >> 32; + + // Handle multi-word case + for (size_t i = 1; i < Size; ++i) { + if (result[i] != 0) { + // Find the highest set bit in the higher words + size_t highest_word = Size - 1; + for (size_t j = Size - 1; j > 0; --j) { + if (result[j] != 0) { + highest_word = j; + break; + } + } + // Set all lower words to 0 and the highest word to the power of 2 + for (size_t j = 0; j < highest_word; ++j) { + result[j] = 0; + } + result[highest_word] = uint64_t(1) << (63 - std::countl_zero(result[highest_word])); + break; + } + } + + result += LargeInteger{1}; + return result; + } + + constexpr LargeInteger bit_floor() const { + if (*this == LargeInteger{0}) return LargeInteger{0}; + + LargeInteger result = *this; + result |= result >> 1; + result |= result >> 2; + result |= result >> 4; + result |= result >> 8; + result |= result >> 16; + result |= result >> 32; + + // Handle multi-word case + for (size_t i = 1; i < Size; ++i) { + if (result[i] != 0) { + size_t highest_word = Size - 1; + for (size_t j = Size - 1; j > 0; --j) { + if (result[j] != 0) { + highest_word = j; + break; + } + } + for (size_t j = 0; j < highest_word; ++j) { + result[j] = 0; + } + result[highest_word] = uint64_t(1) << (63 - std::countl_zero(result[highest_word])); + return result; + } + } + + // Single word case + result = LargeInteger{uint64_t(1) << (63 - std::countl_zero(result[0]))}; + return result; + } + + constexpr int bit_width() const { + if (*this == LargeInteger{0}) return 0; + + for (size_t i = Size; i > 0; --i) { + if (m_data[i-1] != 0) { + return (i - 1) * 64 + 64 - std::countl_zero(m_data[i-1]); + } + } + return 0; + } + +private: + // Helper function for bitwise operations + template + constexpr LargeInteger bitwise_op(const LargeInteger& other, Op op) const { + LargeInteger result{}; + for (size_t i = 0; i < Size; ++i) { + result[i] = op(m_data[i], other[i]); } return result; } }; -} \ No newline at end of file +// Deduction guide for constructor from integral types +template +LargeInteger(T) -> LargeInteger<1>; + +} // namespace WFC \ No newline at end of file diff --git a/include/nd-wfc/wfc_random.hpp b/include/nd-wfc/wfc_random.hpp index 92cada7..fb713ba 100644 --- a/include/nd-wfc/wfc_random.hpp +++ b/include/nd-wfc/wfc_random.hpp @@ -21,7 +21,7 @@ public: return static_cast(rng(possibleValues.size())); } - constexpr uint32_t rng(uint32_t max) { + constexpr uint32_t rng(uint32_t max) const { m_seed = m_seed * 1103515245 + 12345; return m_seed % max; } @@ -45,7 +45,7 @@ public: return rng(possibleValues.size()); } - uint32_t rng(uint32_t max) { + uint32_t rng(uint32_t max) const { std::uniform_int_distribution dist(0, max); return dist(m_rng); } diff --git a/include/nd-wfc/wfc_utils.hpp b/include/nd-wfc/wfc_utils.hpp index 191020c..26a955d 100644 --- a/include/nd-wfc/wfc_utils.hpp +++ b/include/nd-wfc/wfc_utils.hpp @@ -24,4 +24,21 @@ inline int FindNthSetBit(size_t num, int n) { return bitCount; } +template +struct WorldValue +{ +public: + WorldValue() = default; + WorldValue(VarT value, uint16_t internalIndex) + : Value(value) + , InternalIndex(internalIndex) + {} +public: + operator VarT() const { return Value; } + +public: + VarT Value{}; + uint16_t InternalIndex{}; +}; + } \ No newline at end of file diff --git a/include/nd-wfc/wfc_variable_map.hpp b/include/nd-wfc/wfc_variable_map.hpp index ab998c0..a4d553c 100644 --- a/include/nd-wfc/wfc_variable_map.hpp +++ b/include/nd-wfc/wfc_variable_map.hpp @@ -67,6 +67,12 @@ public: } static consteval size_t size() { return ValuesRegisteredAmount; } + + template + static constexpr auto ValuesToIndices() -> std::array { + std::array indices = {GetIndex()...}; + return indices; + } }; diff --git a/include/nd-wfc/wfc_wave.hpp b/include/nd-wfc/wfc_wave.hpp index 7456244..7ace7b2 100644 --- a/include/nd-wfc/wfc_wave.hpp +++ b/include/nd-wfc/wfc_wave.hpp @@ -11,10 +11,12 @@ class Wave { public: using BitContainerT = BitContainer; using ElementT = typename BitContainerT::StorageType; + using IDMapT = VariableIDMapT; + static constexpr size_t ElementsAmount = Size; public: Wave() = default; - Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter(allocator)) + Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, allocator) { for (auto& wave : m_data) wave = (1 << variableAmount) - 1; }