diff --git a/include/nd-wfc/wfc.h b/include/nd-wfc/wfc.h index 6f49023..959b8f9 100644 --- a/include/nd-wfc/wfc.h +++ b/include/nd-wfc/wfc.h @@ -9,4 +9,5 @@ #include "wfc.hpp" #include "worlds.hpp" +#include "wfc_builder.hpp" diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index 0af8a9d..3d845e7 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -14,174 +14,17 @@ #include #include +#include "wfc_utils.hpp" +#include "wfc_variable_map.hpp" #include "wfc_allocator.hpp" +#include "wfc_bit_container.hpp" +#include "wfc_wave.hpp" +#include "wfc_constrainer.hpp" +#include "wfc_callbacks.hpp" +#include "wfc_random.hpp" namespace WFC { -inline constexpr void constexpr_assert(bool condition, const char* message = "") { - if (!condition) throw message; -} - -inline int FindNthSetBit(size_t num, int n) { - constexpr_assert(n < std::popcount(num), "index is out of range"); - int bitCount = 0; - while (num) { - if (bitCount == n) { - return std::countr_zero(num); // Index of the current set bit - } - bitCount++; - num &= (num - 1); // turn of lowest set bit - } - return bitCount; -} - -template -concept WorldType = requires(T world, size_t id, typename T::ValueType value) { - { world.size() } -> std::convertible_to; - { world.setValue(id, value) }; - { world.getValue(id) } -> std::convertible_to; - typename T::ValueType; -}; - -/** - * @brief Class to map variable values to indices at compile time - * - * This class is used to map variable values to indices at compile time. - * It is a compile-time map of variable values to indices. - */ -template -class VariableIDMap { -public: - - using Type = VarT; - static constexpr size_t ValuesRegisteredAmount = sizeof...(Values); - - using MaskType = typename std::conditional< - ValuesRegisteredAmount <= 8, - uint8_t, - typename std::conditional< - ValuesRegisteredAmount <= 16, - uint16_t, - typename std::conditional< - ValuesRegisteredAmount <= 32, - uint32_t, - uint64_t - >::type - >::type - >::type; - - template - using Merge = VariableIDMap; - - template - static consteval bool HasValue() - { - constexpr VarT arr[] = {Values...}; - constexpr size_t size = sizeof...(Values); - - for (size_t i = 0; i < size; ++i) - if (arr[i] == Value) - return true; - return false; - } - - template - static consteval size_t GetIndex() - { - static_assert(HasValue(), "Value was not defined"); - constexpr VarT arr[] = {Values...}; - constexpr size_t size = ValuesRegisteredAmount; - - for (size_t i = 0; i < size; ++i) - if (arr[i] == Value) - return i; - - return static_cast(-1); // This line is unreachable if value is found - } - - template - static consteval MaskType GetMask() - { - return (0 | ... | (1 << GetIndex())); - } - - static std::span GetAllValues() - { - static const VarT allValues[] - { - Values... - }; - return std::span{ allValues, ValuesRegisteredAmount }; - } - - static VarT GetValue(size_t index) { - constexpr_assert(index < ValuesRegisteredAmount); - return GetAllValues()[index]; - } - - static consteval VarT GetValueConsteval(size_t index) - { - constexpr VarT arr[] = {Values...}; - return arr[index]; - } - - static consteval size_t size() { return ValuesRegisteredAmount; } -}; - -template -struct ConstrainerFunctionMap { -public: - static consteval size_t size() { return sizeof...(ConstrainerFunctions); } - - using TupleType = std::tuple; - - template - static ConstrainerFunctionPtrT GetFunction(size_t index) - { - static_assert((std::is_empty_v && ...), "Lambdas must not have any captures"); - static ConstrainerFunctionPtrT functions[] = { - static_cast(ConstrainerFunctions{}) ... - }; - return functions[index]; - } -}; - -// Helper to select the correct constrainer function based on the index and the value -template -using MergedConstrainerElementSelector = - std::conditional_t(), // if the value is in the selected IDs - NewConstrainerFunctionT, - std::conditional_t<(I < ConstrainerFunctionMapT::size()), // if the index is within the size of the tuple - std::tuple_element_t, - EmptyFunctionT - > - >; - -// Helper to make a merged constrainer function map -template -auto MakeMergedConstrainerIDMap(std::index_sequence,VariableIDMapT*, ConstrainerFunctionMapT*, NewConstrainerFunctionT*, SelectedIDsVariableIDMapT*, EmptyFunctionT*) - -> ConstrainerFunctionMap...>; - -// Main alias for the merged constrainer function map -template -using MergedConstrainerFunctionMap = decltype( - MakeMergedConstrainerIDMap(std::make_index_sequence{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr) -); - template struct WorldValue { @@ -199,150 +42,37 @@ public: uint16_t InternalIndex{}; }; -template -class Wave { -public: - Wave() = default; - Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter(allocator)) - { - for (auto& wave : m_data) wave = (1 << variableAmount) - 1; - } - - Wave(const Wave& other) = default; - -public: - void Collapse(size_t index, MaskType mask) { m_data[index] &= mask; } - size_t size() const { return m_data.size(); } - size_t Entropy(size_t index) const { return std::popcount(m_data[index]); } - bool IsCollapsed(size_t index) const { return Entropy(index) == 1; } - bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](MaskType value) { return std::popcount(value) == 1; }); } - bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](MaskType value) { return value == 0; }); } - bool IsContradicted(size_t index) const { return m_data[index] == 0; } - uint16_t GetVariableID(size_t index) const { return static_cast(std::countr_zero(m_data[index])); } - MaskType GetMask(size_t index) const { return m_data[index]; } - -private: - WFCVector m_data; +template +concept WorldType = requires(T world, size_t id, typename T::ValueType value) { + { world.size() } -> std::convertible_to; + { world.setValue(id, value) }; + { world.getValue(id) } -> std::convertible_to; + typename T::ValueType; }; /** - * @brief Constrainer class used in constraint functions to limit possible values for other cells - */ -template -class Constrainer { -public: - using MaskType = typename VariableIDMapT::MaskType; - -public: - Constrainer(Wave& wave, WFCQueue& propagationQueue) - : m_wave(wave) - , m_propagationQueue(propagationQueue) - {} - - /** - * @brief Constrain a cell to exclude specific values - * @param cellId The ID of the cell to constrain - * @param forbiddenValues The set of forbidden values for this cell - */ - template - void Exclude(size_t cellId) { - static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided"); - ApplyMask(cellId, ~VariableIDMapT::template GetMask()); - } - - void Exclude(WorldValue value, size_t cellId) { - ApplyMask(cellId, ~(1 << value.InternalIndex)); - } - - /** - * @brief Constrain a cell to only allow one specific value - * @param cellId The ID of the cell to constrain - * @param value The only allowed value for this cell - */ - template - void Only(size_t cellId) { - static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided"); - ApplyMask(cellId, VariableIDMapT::template GetMask()); - } - - void Only(WorldValue value, size_t cellId) { - ApplyMask(cellId, 1 << value.InternalIndex); - } - -private: - void ApplyMask(size_t cellId, MaskType mask) { - bool wasCollapsed = m_wave.IsCollapsed(cellId); - - m_wave.Collapse(cellId, mask); - - bool collapsed = m_wave.IsCollapsed(cellId); - if (!wasCollapsed && collapsed) { - m_propagationQueue.push(cellId); - } - } - -private: - Wave& m_wave; - WFCQueue& m_propagationQueue; +* @brief Concept to validate constrainer function signature +* The function must be callable with parameters: (WorldT&, size_t, WorldValue, Constrainer&) +*/ +template +concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { + func(world, index, value, constrainer); }; /** - * @brief Variable definition with its constraint function - */ -template -struct VariableData { - VarT value{}; - std::function, Constrainer&)> constraintFunc{}; - - VariableData() = default; - VariableData(VarT value, std::function, Constrainer&)> constraintFunc) - : value(value) - , constraintFunc(constraintFunc) - {} -}; - -/** - * @brief Empty callback function - * @param World The world type - */ -template -using EmptyCallback = decltype([](World&){}); - -/** - * @brief Callback struct - * @param WorldT The world type - * @param AllCellsCollapsedCallbackT The all cells collapsed callback type - * @param CellCollapsedCallbackT The cell collapsed callback type - * @param ContradictionCallbackT The contradiction callback type - * @param BranchCallbackT The branch callback type - */ -template , - typename ContradictionCallbackT = EmptyCallback, - typename BranchCallbackT = EmptyCallback -> -struct Callbacks -{ - using CellCollapsedCallback = CellCollapsedCallbackT; - using ContradictionCallback = ContradictionCallbackT; - using BranchCallback = BranchCallbackT; - - template - using SetCellCollapsedCallbackT = Callbacks; - template - using SetContradictionCallbackT = Callbacks; - template - using SetBranchCallbackT = Callbacks; - - static consteval bool HasCellCollapsedCallback() { return !std::is_same_v>; } - static consteval bool HasContradictionCallback() { return !std::is_same_v>; } - static consteval bool HasBranchCallback() { return !std::is_same_v>; } +* @brief Concept to validate random selector function signature +* The function must be callable with parameters: (std::span) and return size_t +*/ +template +concept RandomSelectorFunction = requires(T func, std::span possibleValues) { + { func(possibleValues) } -> std::convertible_to; + { func.rng(static_cast(1)) } -> std::convertible_to; }; /** * @brief Main WFC class implementing the Wave Function Collapse algorithm */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, @@ -351,14 +81,14 @@ class WFC { public: static_assert(WorldType, "WorldT must satisfy World type requirements"); - using MaskType = typename VariableIDMapT::MaskType; + using ElementT = typename VariableIDMapT::ElementT; public: struct SolverState { WorldT& world; WFCQueue propagationQueue; - Wave wave; + Wave wave; std::mt19937& rng; RandomSelectorT& randomSelector; WFCStackAllocator& allocator; @@ -479,7 +209,7 @@ public: static const std::vector GetPossibleValues(SolverState& state, int cellId) { std::vector possibleValues; - MaskType mask = state.wave.GetMask(cellId); + ElementT mask = state.wave.GetMask(cellId); for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) { if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i)); } @@ -489,7 +219,7 @@ public: private: static void CollapseCell(SolverState& state, size_t cellId, uint16_t value) { - constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (1 << value)); + constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (ElementT(1) << value)); state.wave.Collapse(cellId, 1 << value); constexpr_assert(state.wave.IsCollapsed(cellId)); @@ -522,14 +252,14 @@ private: // create a list of possible values uint16_t availableValues = static_cast(state.wave.Entropy(minEntropyCell)); std::array possibleValues; // inplace vector - MaskType mask = state.wave.GetMask(minEntropyCell); + ElementT mask = state.wave.GetMask(minEntropyCell); for (size_t i = 0; i < availableValues; ++i) { uint16_t index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds"); possibleValues[i] = index; - constexpr_assert(((mask & (1 << index)) != 0), "Possible value was not set"); + constexpr_assert(((mask & (ElementT(1) << index)) != 0), "Possible value was not set"); mask = mask & (mask - 1); // turn off lowest set bit } @@ -563,9 +293,9 @@ private: } // remove the failure state from the wave - constexpr_assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) != 0, "Possible value was not set"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) != 0, "Possible value was not set"); state.wave.Collapse(minEntropyCell, ~(1 << selectedValue)); - constexpr_assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0, "Wave was not collapsed correctly"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); // swap replacement value with the last value std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); @@ -622,236 +352,4 @@ private: } }; -/** - * @brief Concept to validate constrainer function signature - * The function must be callable with parameters: (WorldT&, size_t, WorldValue, Constrainer&) - */ -template -concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { - func(world, index, value, constrainer); -}; - -/** - * @brief Concept to validate random selector function signature - * The function must be callable with parameters: (std::span) and return size_t - */ -template -concept RandomSelectorFunction = requires(T func, std::span possibleValues) { - { func(possibleValues) } -> std::convertible_to; - { func.rng(static_cast(1)) } -> std::convertible_to; -}; - -/** - * @brief Default constexpr random selector using a simple seed-based algorithm - * This provides a compile-time random selection that maintains state between calls - */ -template -class DefaultRandomSelector { -private: - mutable uint32_t m_seed; - -public: - constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {} - - constexpr size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - - // Simple linear congruential generator for constexpr compatibility - return static_cast(rng(possibleValues.size())); - } - - constexpr uint32_t rng(uint32_t max) { - m_seed = m_seed * 1103515245 + 12345; - return m_seed % max; - } -}; - -/** - * @brief Advanced random selector using std::mt19937 and std::uniform_int_distribution - * This provides high-quality randomization for runtime use - */ -template -class AdvancedRandomSelector { -private: - std::mt19937& m_rng; - -public: - explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {} - - size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - - return rng(possibleValues.size()); - } - - uint32_t rng(uint32_t max) { - std::uniform_int_distribution dist(0, max); - return dist(m_rng); - } -}; - -/** - * @brief Weight specification for a specific value - * @tparam Value The variable value - * @tparam Weight The 16-bit weight for this value - */ - template - struct Weight { - static constexpr VarT value = Value; - static constexpr uint16_t weight = WeightValue; - }; - - /** - * @brief Compile-time weights storage for weighted random selection - * @tparam VarT The variable type - * @tparam VariableIDMapT The variable ID map type - * @tparam DefaultWeight The default weight for values not explicitly specified - * @tparam WeightSpecs Variadic template parameters of Weight specifications - */ - template - class WeightsMap { - private: - static constexpr size_t NumWeights = sizeof...(WeightSpecs); - - // Helper to get weight for a specific value - static consteval uint16_t GetWeightForValue(VarT targetValue) { - // Check each weight spec to find the target value - uint16_t weight = DefaultWeight; - ((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...); - return weight; - } - - public: - /** - * @brief Get the weight for a specific value at compile time - * @tparam TargetValue The value to get weight for - * @return The weight for the value - */ - template - static consteval uint16_t GetWeight() { - return GetWeightForValue(TargetValue); - } - - /** - * @brief Get weights array for all registered values - * @return Array of weights corresponding to all registered values - */ - static consteval std::array GetWeightsArray() { - std::array weights{}; - - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i)); - } - - return weights; - } - - static consteval uint32_t GetTotalWeight() { - uint32_t totalWeight = 0; - auto weights = GetWeightsArray(); - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - totalWeight += weights[i]; - } - return totalWeight; - } - - static consteval std::array GetCumulativeWeightsArray() { - auto weights = GetWeightsArray(); - uint32_t totalWeight = 0; - std::array cumulativeWeights{}; - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - totalWeight += weights[i]; - cumulativeWeights[i] = totalWeight; - } - return cumulativeWeights; - } - }; - - /** - * @brief Weighted random selector that uses another random selector as backend - * @tparam VarT The variable type - * @tparam VariableIDMapT The variable ID map type - * @tparam BackendSelectorT The backend random selector type - * @tparam WeightsMapT The weights map type containing weight specifications - */ - template - class WeightedSelector { - private: - BackendSelectorT m_backendSelector; - const std::array m_weights; - const std::array m_cumulativeWeights; - - public: - explicit WeightedSelector(BackendSelectorT backendSelector) - : m_backendSelector(backendSelector) - , m_weights(WeightsMapT::GetWeightsArray()) - , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) - {} - - explicit WeightedSelector(uint32_t seed) - requires std::is_same_v> - : m_backendSelector(seed) - , m_weights(WeightsMapT::GetWeightsArray()) - , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) - {} - - size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value"); - - // Use backend selector to pick a random number in range [0, totalWeight) - uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back()); - - // Find which value this random value corresponds to - for (size_t i = 0; i < possibleValues.size(); ++i) { - if (randomValue <= m_cumulativeWeights[i]) { - return i; - } - } - - // Fallback (should not reach here) - return possibleValues.size() - 1; - } - }; - -/** - * @brief Builder class for creating WFC instances - */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector> -class Builder { -public: - - template - using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>; - - template - requires ConstrainerFunction - using DefineConstrainer = Builder, - decltype([](WorldT&, size_t, WorldValue, Constrainer&) {}) - >, CallbacksT, RandomSelectorT - >; - - template - using SetCellCollapsedCallback = Builder, RandomSelectorT>; - template - using SetContradictionCallback = Builder, RandomSelectorT>; - template - using SetBranchCallback = Builder, RandomSelectorT>; - - template - requires RandomSelectorFunction - using SetRandomSelector = Builder; - - template - using Weights = Builder>>; - - - using Build = WFC; -}; - } // namespace WFC diff --git a/include/nd-wfc/wfc_allocator.hpp b/include/nd-wfc/wfc_allocator.hpp index d3ee329..7fcd187 100644 --- a/include/nd-wfc/wfc_allocator.hpp +++ b/include/nd-wfc/wfc_allocator.hpp @@ -10,6 +10,8 @@ #include #include +#include "wfc_utils.hpp" + #define WFC_USE_STACK_ALLOCATOR inline void* allocate_aligned_memory(size_t alignment, size_t size) { diff --git a/include/nd-wfc/wfc_bit_container.hpp b/include/nd-wfc/wfc_bit_container.hpp new file mode 100644 index 0000000..60b322f --- /dev/null +++ b/include/nd-wfc/wfc_bit_container.hpp @@ -0,0 +1,172 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "wfc_utils.hpp" +#include "wfc_allocator.hpp" + +namespace WFC { + +namespace detail { + // Helper to determine the optimal storage type based on bits needed + template + struct OptimalStorageType { + static constexpr size_t bits_needed = Bits == 0 ? 0 : + (Bits <= 1) ? 1 : + (Bits <= 2) ? 2 : + (Bits <= 4) ? 4 : + (Bits <= 8) ? 8 : + (Bits <= 16) ? 16 : + (Bits <= 32) ? 32 : + (Bits <= 64) ? 64 : + ((Bits + 63) / 64) * 64; // Round up to multiple of 64 for >64 bits + + using type = std::conditional_t>>; + }; + + // Helper for multi-element storage (>64 bits) + template + struct StorageArray { + static constexpr size_t StorageBits = OptimalStorageType::bits_needed; + static constexpr size_t ArraySize = StorageBits > 64 ? (StorageBits / 64) : 1; + using element_type = std::conditional_t::type, uint64_t>; + using type = std::conditional_t>; + }; + + struct Empty{}; +} + +template::type>> +class BitContainer : private AllocatorT{ +public: + using StorageInfo = detail::OptimalStorageType; + using StorageArrayInfo = detail::StorageArray; + using StorageType = typename StorageArrayInfo::type; + using AllocatorType = AllocatorT; + + static constexpr size_t BitsPerElement = Bits; + static constexpr size_t StorageBits = StorageInfo::bits_needed; + static constexpr bool IsResizable = (Size == 0); + static constexpr bool IsMultiElement = (StorageBits > 64); + static constexpr bool IsSubByte = (StorageBits < 8); + static constexpr bool IsDefaultByteLayout = !IsMultiElement && !IsSubByte; + static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max(1u, StorageBits); + + using ContainerType = + std::conditional_t, + std::array>>; + +private: + ContainerType m_container; + +private: + // Mask for extracting bits + static constexpr auto get_Mask() + { + if constexpr (BitsPerElement == 0) + { + return uint64_t{0}; + } + else if constexpr (BitsPerElement >= 64) + { + return ~uint64_t{0}; + } + else + { + return (uint64_t{1} << BitsPerElement) - 1; + } + } + + static constexpr uint64_t Mask = get_Mask(); + +public: + + BitContainer() = default; + BitContainer(AllocatorT& allocator) : AllocatorT(allocator) {}; + explicit BitContainer(size_t initial_size, AllocatorT& allocator) requires (IsResizable) + : AllocatorT(allocator) + , m_container(initial_size, allocator) + {}; + explicit BitContainer(size_t, AllocatorT& allocator) requires (!IsResizable) + : AllocatorT(allocator) + , m_container() + {}; + +public: + // Size operations + constexpr size_t size() const noexcept + { + if constexpr (IsResizable) + { + return m_container.size(); + } + else + { + return Size; + } + } + constexpr std::span data() const { return std::span(m_container); } + constexpr std::span data() { return std::span(m_container); } + + constexpr void resize(size_t new_size) requires (IsResizable) { m_container.resize(new_size); } + constexpr void reserve(size_t capacity) requires (IsResizable) { m_container.reserve(capacity); } + +public: // Sub byte + struct SubTypeAccess + { + constexpr SubTypeAccess(uint8_t& data, uint8_t subIndex) : Data{ data }, Shift{ StorageBits * subIndex } {}; + + constexpr uint8_t GetValue() const { return ((Data >> Shift) & Mask); } + constexpr uint8_t SetValue(uint8_t val) { Clear(); return Data |= ((val & Mask) << Shift); } + constexpr void Clear() { Data &= ~Mask; } + + constexpr operator uint8_t() const { return GetValue(); } + + template constexpr uint8_t operator&=(T other) { return SetValue(GetValue() & other); } + template constexpr uint8_t operator|=(T other) { return SetValue(GetValue() | other); } + template constexpr uint8_t operator^=(T other) { return SetValue(GetValue() ^ other); } + template constexpr uint8_t operator<<=(T other) { return SetValue(GetValue() << other); } + template constexpr uint8_t operator>>=(T other) { return SetValue(GetValue() >> other); } + + uint8_t& Data; + uint8_t Shift; + }; + + constexpr const SubTypeAccess operator[](size_t index) const requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } + constexpr SubTypeAccess operator[](size_t index) requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } + +public: // MultiElement + struct MultiElementAccess + { + constexpr MultiElementAccess(StorageType& data) : Data{ data } {}; + + StorageType& Data; + }; + + constexpr const MultiElementAccess operator[](size_t index) const requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } + constexpr MultiElementAccess operator[](size_t index) requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } + +public: // default + constexpr const StorageType& operator[](size_t index) const requires(IsDefaultByteLayout) { return data()[index]; } + constexpr StorageType& operator[](size_t index) requires(IsDefaultByteLayout) { return data()[index]; } + +}; + +static_assert(BitContainer<1, 10>::ElementsPerByte == 8); +static_assert(BitContainer<2, 10>::ElementsPerByte == 4); +static_assert(BitContainer<4, 10>::ElementsPerByte == 2); +static_assert(BitContainer<8, 10>::ElementsPerByte == 1); + + +} // namespace WFC diff --git a/include/nd-wfc/wfc_builder.hpp b/include/nd-wfc/wfc_builder.hpp new file mode 100644 index 0000000..b30697f --- /dev/null +++ b/include/nd-wfc/wfc_builder.hpp @@ -0,0 +1,52 @@ +#pragma once + +namespace WFC { + +#include "wfc_utils.hpp" +#include "wfc_variable_map.hpp" +#include "wfc_constrainer.hpp" +#include "wfc_callbacks.hpp" +#include "wfc_random.hpp" +#include "wfc.hpp" + +/** +* @brief Builder class for creating WFC instances +*/ +template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector> +class Builder { +public: + + template + using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>; + + template + requires ConstrainerFunction + using DefineConstrainer = Builder, + decltype([](WorldT&, size_t, WorldValue, Constrainer&) {}) + >, CallbacksT, RandomSelectorT + >; + + template + using SetCellCollapsedCallback = Builder, RandomSelectorT>; + template + using SetContradictionCallback = Builder, RandomSelectorT>; + template + using SetBranchCallback = Builder, RandomSelectorT>; + + template + requires RandomSelectorFunction + using SetRandomSelector = Builder; + + template + using Weights = Builder>>; + + + using Build = WFC; +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_callbacks.hpp b/include/nd-wfc/wfc_callbacks.hpp new file mode 100644 index 0000000..f3a02ce --- /dev/null +++ b/include/nd-wfc/wfc_callbacks.hpp @@ -0,0 +1,44 @@ +#pragma once + +namespace WFC { + +/** +* @brief Empty callback function +* @param WorldT The world type +*/ +template +using EmptyCallback = decltype([](WorldT&){}); + +/** + * @brief Callback struct + * @param WorldT The world type + * @param AllCellsCollapsedCallbackT The all cells collapsed callback type + * @param CellCollapsedCallbackT The cell collapsed callback type + * @param ContradictionCallbackT The contradiction callback type + * @param BranchCallbackT The branch callback type + */ + template , + typename ContradictionCallbackT = EmptyCallback, + typename BranchCallbackT = EmptyCallback +> +struct Callbacks +{ + using CellCollapsedCallback = CellCollapsedCallbackT; + using ContradictionCallback = ContradictionCallbackT; + using BranchCallback = BranchCallbackT; + + template + using SetCellCollapsedCallbackT = Callbacks; + template + using SetContradictionCallbackT = Callbacks; + template + using SetBranchCallbackT = Callbacks; + + static consteval bool HasCellCollapsedCallback() { return !std::is_same_v>; } + static consteval bool HasContradictionCallback() { return !std::is_same_v>; } + static consteval bool HasBranchCallback() { return !std::is_same_v>; } +}; + + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_constrainer.hpp b/include/nd-wfc/wfc_constrainer.hpp new file mode 100644 index 0000000..100ae59 --- /dev/null +++ b/include/nd-wfc/wfc_constrainer.hpp @@ -0,0 +1,122 @@ +#pragma once + +#include "wfc_variable_map.hpp" + +namespace WFC { + +template +struct ConstrainerFunctionMap { +public: + static consteval size_t size() { return sizeof...(ConstrainerFunctions); } + + using TupleType = std::tuple; + + template + static ConstrainerFunctionPtrT GetFunction(size_t index) + { + static_assert((std::is_empty_v && ...), "Lambdas must not have any captures"); + static ConstrainerFunctionPtrT functions[] = { + static_cast(ConstrainerFunctions{}) ... + }; + return functions[index]; + } +}; + +// Helper to select the correct constrainer function based on the index and the value +template +using MergedConstrainerElementSelector = + std::conditional_t(), // if the value is in the selected IDs + NewConstrainerFunctionT, + std::conditional_t<(I < ConstrainerFunctionMapT::size()), // if the index is within the size of the tuple + std::tuple_element_t, + EmptyFunctionT + > + >; + +// Helper to make a merged constrainer function map +template +auto MakeMergedConstrainerIDMap(std::index_sequence,VariableIDMapT*, ConstrainerFunctionMapT*, NewConstrainerFunctionT*, SelectedIDsVariableIDMapT*, EmptyFunctionT*) + -> ConstrainerFunctionMap...>; + +// Main alias for the merged constrainer function map +template +using MergedConstrainerFunctionMap = decltype( + MakeMergedConstrainerIDMap(std::make_index_sequence{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr) +); + +/** + * @brief Constrainer class used in constraint functions to limit possible values for other cells + */ +template +class Constrainer { +public: + using MaskType = typename VariableIDMapT::MaskType; + +public: + Constrainer(Wave& wave, WFCQueue& propagationQueue) + : m_wave(wave) + , m_propagationQueue(propagationQueue) + {} + + /** + * @brief Constrain a cell to exclude specific values + * @param cellId The ID of the cell to constrain + * @param forbiddenValues The set of forbidden values for this cell + */ + template + void Exclude(size_t cellId) { + static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided"); + ApplyMask(cellId, ~VariableIDMapT::template GetMask()); + } + + void Exclude(WorldValue value, size_t cellId) { + ApplyMask(cellId, ~(1 << value.InternalIndex)); + } + + /** + * @brief Constrain a cell to only allow one specific value + * @param cellId The ID of the cell to constrain + * @param value The only allowed value for this cell + */ + template + void Only(size_t cellId) { + static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided"); + ApplyMask(cellId, VariableIDMapT::template GetMask()); + } + + void Only(WorldValue value, size_t cellId) { + ApplyMask(cellId, 1 << value.InternalIndex); + } + +private: + void ApplyMask(size_t cellId, MaskType mask) { + bool wasCollapsed = m_wave.IsCollapsed(cellId); + + m_wave.Collapse(cellId, mask); + + bool collapsed = m_wave.IsCollapsed(cellId); + if (!wasCollapsed && collapsed) { + m_propagationQueue.push(cellId); + } + } + +private: + Wave& m_wave; + WFCQueue& m_propagationQueue; +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_large_integers.hpp b/include/nd-wfc/wfc_large_integers.hpp new file mode 100644 index 0000000..eca0afa --- /dev/null +++ b/include/nd-wfc/wfc_large_integers.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace WFC { + +template +struct LargeInteger +{ + std::array m_data; + + template + constexpr LargeInteger operator+(const LargeInteger& other) const { + LargeInteger result; + for (size_t i = 0; i < std::max(Size, OtherSize); i++) { + result[i] = m_data[i] + other[i]; + } + return result; + } +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_random.hpp b/include/nd-wfc/wfc_random.hpp new file mode 100644 index 0000000..92cada7 --- /dev/null +++ b/include/nd-wfc/wfc_random.hpp @@ -0,0 +1,178 @@ +#pragma once + +namespace WFC { + +/** +* @brief Default constexpr random selector using a simple seed-based algorithm +* This provides a compile-time random selection that maintains state between calls +*/ +template +class DefaultRandomSelector { +private: + mutable uint32_t m_seed; + +public: + constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {} + + constexpr size_t operator()(std::span possibleValues) const { + constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); + + // Simple linear congruential generator for constexpr compatibility + return static_cast(rng(possibleValues.size())); + } + + constexpr uint32_t rng(uint32_t max) { + m_seed = m_seed * 1103515245 + 12345; + return m_seed % max; + } +}; + +/** +* @brief Advanced random selector using std::mt19937 and std::uniform_int_distribution +* This provides high-quality randomization for runtime use +*/ +template +class AdvancedRandomSelector { +private: + std::mt19937& m_rng; + +public: + explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {} + + size_t operator()(std::span possibleValues) const { + constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); + + return rng(possibleValues.size()); + } + + uint32_t rng(uint32_t max) { + std::uniform_int_distribution dist(0, max); + return dist(m_rng); + } +}; + +/** +* @brief Weight specification for a specific value +* @tparam Value The variable value +* @tparam Weight The 16-bit weight for this value +*/ +template +struct Weight { + static constexpr VarT value = Value; + static constexpr uint16_t weight = WeightValue; +}; + +/** +* @brief Compile-time weights storage for weighted random selection +* @tparam VarT The variable type +* @tparam VariableIDMapT The variable ID map type +* @tparam DefaultWeight The default weight for values not explicitly specified +* @tparam WeightSpecs Variadic template parameters of Weight specifications +*/ +template +class WeightsMap { +private: + static constexpr size_t NumWeights = sizeof...(WeightSpecs); + + // Helper to get weight for a specific value + static consteval uint16_t GetWeightForValue(VarT targetValue) { + // Check each weight spec to find the target value + uint16_t weight = DefaultWeight; + ((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...); + return weight; + } + +public: + /** + * @brief Get the weight for a specific value at compile time + * @tparam TargetValue The value to get weight for + * @return The weight for the value + */ + template + static consteval uint16_t GetWeight() { + return GetWeightForValue(TargetValue); + } + + /** + * @brief Get weights array for all registered values + * @return Array of weights corresponding to all registered values + */ + static consteval std::array GetWeightsArray() { + std::array weights{}; + + for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { + weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i)); + } + + return weights; + } + + static consteval uint32_t GetTotalWeight() { + uint32_t totalWeight = 0; + auto weights = GetWeightsArray(); + for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { + totalWeight += weights[i]; + } + return totalWeight; + } + + static consteval std::array GetCumulativeWeightsArray() { + auto weights = GetWeightsArray(); + uint32_t totalWeight = 0; + std::array cumulativeWeights{}; + for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { + totalWeight += weights[i]; + cumulativeWeights[i] = totalWeight; + } + return cumulativeWeights; + } +}; + +/** +* @brief Weighted random selector that uses another random selector as backend +* @tparam VarT The variable type +* @tparam VariableIDMapT The variable ID map type +* @tparam BackendSelectorT The backend random selector type +* @tparam WeightsMapT The weights map type containing weight specifications +*/ +template +class WeightedSelector { +private: + BackendSelectorT m_backendSelector; + const std::array m_weights; + const std::array m_cumulativeWeights; + +public: + explicit WeightedSelector(BackendSelectorT backendSelector) + : m_backendSelector(backendSelector) + , m_weights(WeightsMapT::GetWeightsArray()) + , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) + {} + + explicit WeightedSelector(uint32_t seed) + requires std::is_same_v> + : m_backendSelector(seed) + , m_weights(WeightsMapT::GetWeightsArray()) + , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) + {} + + size_t operator()(std::span possibleValues) const { + constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); + constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value"); + + // Use backend selector to pick a random number in range [0, totalWeight) + uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back()); + + // Find which value this random value corresponds to + for (size_t i = 0; i < possibleValues.size(); ++i) { + if (randomValue <= m_cumulativeWeights[i]) { + return i; + } + } + + // Fallback (should not reach here) + return possibleValues.size() - 1; + } +}; + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_utils.hpp b/include/nd-wfc/wfc_utils.hpp new file mode 100644 index 0000000..191020c --- /dev/null +++ b/include/nd-wfc/wfc_utils.hpp @@ -0,0 +1,27 @@ +#pragma once + +namespace WFC +{ + + + + + +inline constexpr void constexpr_assert(bool condition, const char* message = "") { + if (!condition) throw message; +} + +inline int FindNthSetBit(size_t num, int n) { + constexpr_assert(n < std::popcount(num), "index is out of range"); + int bitCount = 0; + while (num) { + if (bitCount == n) { + return std::countr_zero(num); // Index of the current set bit + } + bitCount++; + num &= (num - 1); // turn of lowest set bit + } + return bitCount; +} + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_variable_map.hpp b/include/nd-wfc/wfc_variable_map.hpp new file mode 100644 index 0000000..ab998c0 --- /dev/null +++ b/include/nd-wfc/wfc_variable_map.hpp @@ -0,0 +1,73 @@ +#pragma once + +#include "wfc_utils.hpp" + +namespace WFC { + +/** +* @brief Class to map variable values to indices at compile time +* +* This class is used to map variable values to indices at compile time. +* It is a compile-time map of variable values to indices. +*/ +template +class VariableIDMap { +public: + + using Type = VarT; + static constexpr size_t ValuesRegisteredAmount = sizeof...(Values); + + template + using Merge = VariableIDMap; + + template + static consteval bool HasValue() + { + constexpr VarT arr[] = {Values...}; + constexpr size_t size = sizeof...(Values); + + for (size_t i = 0; i < size; ++i) + if (arr[i] == Value) + return true; + return false; + } + + template + static consteval size_t GetIndex() + { + static_assert(HasValue(), "Value was not defined"); + constexpr VarT arr[] = {Values...}; + constexpr size_t size = ValuesRegisteredAmount; + + for (size_t i = 0; i < size; ++i) + if (arr[i] == Value) + return i; + + return static_cast(-1); // This line is unreachable if value is found + } + + static std::span GetAllValues() + { + static const VarT allValues[] + { + Values... + }; + return std::span{ allValues, ValuesRegisteredAmount }; + } + + static VarT GetValue(size_t index) { + constexpr_assert(index < ValuesRegisteredAmount); + return GetAllValues()[index]; + } + + static consteval VarT GetValueConsteval(size_t index) + { + constexpr VarT arr[] = {Values...}; + return arr[index]; + } + + static consteval size_t size() { return ValuesRegisteredAmount; } +}; + + +} \ No newline at end of file diff --git a/include/nd-wfc/wfc_wave.hpp b/include/nd-wfc/wfc_wave.hpp new file mode 100644 index 0000000..7456244 --- /dev/null +++ b/include/nd-wfc/wfc_wave.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "wfc_bit_container.hpp" +#include "wfc_variable_map.hpp" +#include "wfc_allocator.hpp" + +namespace WFC { + +template +class Wave { +public: + using BitContainerT = BitContainer; + using ElementT = typename BitContainerT::StorageType; + +public: + Wave() = default; + Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter(allocator)) + { + for (auto& wave : m_data) wave = (1 << variableAmount) - 1; + } + + Wave(const Wave& other) = default; + +public: + void Collapse(size_t index, ElementT mask) { m_data[index] &= mask; } + size_t size() const { return m_data.size(); } + size_t Entropy(size_t index) const { return std::popcount(m_data[index]); } + bool IsCollapsed(size_t index) const { return Entropy(index) == 1; } + bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](ElementT value) { return std::popcount(value) == 1; }); } + bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](ElementT value) { return value == 0; }); } + bool IsContradicted(size_t index) const { return m_data[index] == 0; } + uint16_t GetVariableID(size_t index) const { return static_cast(std::countr_zero(m_data[index])); } + ElementT GetMask(size_t index) const { return m_data[index]; } + +private: + BitContainerT m_data; +}; + +} \ No newline at end of file