diff --git a/demos/nonogram/nonogram.h b/demos/nonogram/nonogram.h index aca9d47..3b6111b 100644 --- a/demos/nonogram/nonogram.h +++ b/demos/nonogram/nonogram.h @@ -8,6 +8,8 @@ #include #include +#include + // Forward declarations struct NonogramHints; struct NonogramSolution; @@ -178,3 +180,7 @@ std::string trim(const std::string& str); bool startsWith(const std::string& str, const std::string& prefix); std::vector split(const std::string& str, char delimiter); std::vector parseNumberSequence(const std::string& str, char delimiter); + +using NonogramWFC = WFC::Builder + ::Define + ::Build; \ No newline at end of file diff --git a/demos/sudoku/sudoku.h b/demos/sudoku/sudoku.h index 305cff1..b54ec1d 100644 --- a/demos/sudoku/sudoku.h +++ b/demos/sudoku/sudoku.h @@ -279,15 +279,15 @@ private: public: // WFC Support using ValueType = uint8_t; - constexpr inline ValueType getValue(size_t index) const { + constexpr inline ValueType getValue(uint8_t index) const { return board_.get(static_cast(index)); } - constexpr inline void setValue(size_t index, ValueType value) { + constexpr inline void setValue(uint8_t index, ValueType value) { board_.set(static_cast(index), value); } - constexpr inline size_t size() const { + constexpr inline uint8_t size() const { return 81; } }; diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index 3f230f8..54f477b 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -22,60 +23,50 @@ #include "wfc_callbacks.hpp" #include "wfc_random.hpp" #include "wfc_queue.hpp" +#include "wfc_weights.hpp" namespace WFC { template -concept WorldType = requires(T world, size_t id, typename T::ValueType value) { - { world.size() } -> std::convertible_to; - { world.setValue(id, value) }; - { world.getValue(id) } -> std::convertible_to; +concept WorldType = requires(T world, typename T::ValueType value) { + { world.size() } -> std::is_integral; + { world.setValue(static_cast(0), value) }; + { world.getValue(static_cast(0)) } -> std::convertible_to; typename T::ValueType; }; -/** -* @brief Concept to validate constrainer function signature -* The function must be callable with parameters: (WorldT&, size_t, WorldValue, Constrainer&) -*/ template concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue value, Constrainer& constrainer) { func(world, index, value, constrainer); }; -/** -* @brief Concept to validate random selector function signature -* The function must be callable with parameters: (std::span) and return size_t -*/ -template -concept RandomSelectorFunction = requires(const T& func, std::span possibleValues) { - { func(possibleValues) } -> std::convertible_to; - { func.rng(static_cast(1)) } -> std::convertible_to; -}; - template concept HasConstexprSize = requires { { []() constexpr -> std::size_t { return WorldT{}.size(); }() }; }; -/** - * @brief Main WFC class implementing the Wave Function Collapse algorithm - */ template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, - typename RandomSelectorT = DefaultRandomSelector> + typename RandomSelectorT = DefaultRandomSelector, + typename WeightsMapT = WeightsMap +> class WFC { public: static_assert(WorldType, "WorldT must satisfy World type requirements"); - // Try getting the world size, which is only available if the world type has a constexpr size() method - constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + using WorldSizeT = decltype(WorldT{}.size()); - using WaveType = Wave; - using PropagationQueueType = WFCQueue; + // Try getting the world size, which is only available if the world type has a constexpr size() method + constexpr static WorldSizeT WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + + using WaveType = Wave; + using PropagationQueueType = WFCQueue; using ConstrainerType = Constrainer; using MaskType = typename WaveType::ElementT; + using VariableIDT = typename WaveType::VariableIDT; + using WeightsBufferType = BitContainer; public: struct SolverState @@ -88,7 +79,7 @@ public: SolverState(WorldT& world, uint32_t seed) : m_world(world) - , m_propagationQueue{ WorldSize ? WorldSize : static_cast(world.size()) } + , m_propagationQueue{ WorldSize ? WorldSize : static_cast(world.size()) } , m_randomSelector(seed) {} @@ -114,7 +105,7 @@ public: */ static bool Run(SolverState& state) { - WaveType wave{ WorldSize, VariableIDMapT::ValuesRegisteredAmount, state.m_allocator }; + WaveType wave{ WorldSize, VariableIDMapT::size(), state.m_allocator }; PropogateInitialValues(state, wave); @@ -188,7 +179,7 @@ public: } private: - static void CollapseCell(SolverState& state, WaveType& wave, size_t cellId, uint16_t value) + static void CollapseCell(SolverState& state, WaveType& wave, WorldSizeT cellId, VariableIDT value) { constexpr_assert(!wave.IsCollapsed(cellId) || wave.GetMask(cellId) == (MaskType(1) << value)); wave.Collapse(cellId, 1 << value); @@ -201,51 +192,111 @@ private: } } - static bool Branch(SolverState& state, WaveType& wave) + static WorldSizeT FindMinimumEntropyCells(std::span& buffer, WaveType& wave) { - constexpr_assert(state.m_propagationQueue.empty()); - - // Find cell with minimum entropy > 1 - size_t minEntropyCell = static_cast(-1); - size_t minEntropy = static_cast(-1); - - for (size_t i = 0; i < wave.size(); ++i) { - size_t entropy = wave.Entropy(i); - if (entropy > 1 && entropy < minEntropy) { - minEntropy = entropy; - minEntropyCell = i; - } - } - if (minEntropyCell == static_cast(-1)) return false; + auto entropyGetter = [&wave](size_t index) -> size_t { return wave.Entropy(index); }; + auto entropyFilter = [&wave](size_t entropy) -> bool { return entropy > 1; }; + auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(entropyGetter) | std::views::filter(entropyFilter)); constexpr_assert(!wave.IsCollapsed(minEntropyCell)); // create a list of possible values - uint16_t availableValues = static_cast(wave.Entropy(minEntropyCell)); - std::array possibleValues; // inplace vector + VariableIDT availableValues = wave.Entropy(minEntropyCell); MaskType mask = wave.GetMask(minEntropyCell); for (size_t i = 0; i < availableValues; ++i) { - uint16_t index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit - constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds"); + VariableIDT index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit + constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds"); - possibleValues[i] = index; + buffer[i] = index; constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set"); - + mask = mask & (mask - 1); // turn off lowest set bit } - // randomly select a value from possible values - while (availableValues) + return minEntropyCell; + } + + using WeightsType = MinimumIntegerType; + using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast(1))); + + static WorldSizeT FindMinimumEntropyCellsWeighted(std::span& buffer, std::span& weights, WaveType& wave) + { + constexpr size_t ElementsMaxWeight = std::min(WeightsMapT::GetMaxValue() * VariableIDMapT::size(), std::numeric_limits::max()) / VariableIDMapT::size(); + + auto accumulatedWeightedEntropyGetter = [&wave](size_t index) -> uint64_t + { + auto entropyFilter = [&wave](size_t index) -> bool { return wave.Entropy(index) > 1; }; + auto weightedEntropyGetter = [&wave](size_t index) -> uint64_t { return wave.template GetWeight(index); }; + + auto view = std::views::iota(0, VariableIDMapT::size()) | std::views::filter(entropyFilter) | std::views::transform(weightedEntropyGetter); + return std::accumulate(view.begin(), view.end(), 0); + }; + + auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(accumulatedWeightedEntropyGetter)); + + VariableIDT availableValues = wave.Entropy(minEntropyCell); + MaskType mask = wave.GetMask(minEntropyCell); + for (size_t i = 0; i < availableValues; ++i) { - // Create a span of the actual variable values for the random selector - std::array valueArray; - for (size_t i = 0; i < availableValues; ++i) { - valueArray[i] = VariableIDMapT::GetValue(possibleValues[i]); + VariableIDT index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit + constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds"); + + buffer[i] = index; + constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set"); + + weights[i] = wave.template GetWeight(index); + + mask = mask & (mask - 1); // turn off lowest set bit + } + + return minEntropyCell; + } + + static bool Branch(SolverState& state, WaveType& wave) + { + constexpr_assert(state.m_propagationQueue.empty()); + + std::array Buffer{}; + std::array WeightsBuffer{}; + uint64_t accumulatedWeights = 0; + WorldSizeT minEntropyCell{}; + + if constexpr (WeightsMapT::HasWeights()) + { + minEntropyCell = FindMinimumEntropyCellsWeighted(Buffer, WeightsBuffer, wave); + accumulatedWeights = std::accumulate(WeightsBuffer.begin(), WeightsBuffer.end(), 0); + } + else + { + minEntropyCell = FindMinimumEntropyCells(Buffer, wave); + } + + // randomly select a value from possible values + while (Buffer.size()) + { + size_t randomIndex; + VariableIDT selectedValue; + if constexpr (WeightsMapT::HasWeights()) + { + auto randomWeight = state.m_randomSelector.rng(accumulatedWeights); + for (size_t i = 0; i < WeightsBuffer.size(); ++i) + { + if (randomWeight < WeightsBuffer[i]) + { + randomIndex = i; + selectedValue = Buffer[i]; + break; + } + randomWeight -= WeightsBuffer[i]; + } + accumulatedWeights -= WeightsBuffer[randomIndex]; + } + else + { + randomIndex = state.m_randomSelector.rng(Buffer.size()); + selectedValue = Buffer[randomIndex]; } - std::span currentPossibleValues(valueArray.data(), availableValues); - size_t randomIndex = state.m_randomSelector(currentPossibleValues); - size_t selectedValue = possibleValues[randomIndex]; { // copy the state and branch out @@ -253,12 +304,12 @@ private: auto queueFrame = state.m_propagationQueue.createBranchPoint(); auto newWave = wave; - CollapseCell(state, newWave, minEntropyCell, static_cast(selectedValue)); + CollapseCell(state, newWave, minEntropyCell, selectedValue); state.m_propagationQueue.push(minEntropyCell); if (RunLoop(state, newWave)) { - // copy the solution to the original state + // move the solution to the original state wave = newWave; return true; @@ -271,7 +322,8 @@ private: constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); // swap replacement value with the last value - std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); + std::swap(Buffer[randomIndex], Buffer[Buffer.size() - 1]); + Buffer = Buffer.subspan(0, Buffer.size() - 1); } return false; @@ -281,16 +333,16 @@ private: { while (!state.m_propagationQueue.empty()) { - size_t cellId = state.m_propagationQueue.pop(); + WorldSizeT cellId = state.m_propagationQueue.pop(); if (wave.IsContradicted(cellId)) return false; constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed"); - uint16_t variableID = wave.GetVariableID(cellId); + VariableIDT variableID = wave.GetVariableID(cellId); ConstrainerType constrainer(wave, state.m_propagationQueue); - using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, ConstrainerType&); + using ConstrainerFunctionPtrT = void(*)(WorldT&, WorldSizeT, WorldValue, ConstrainerType&); ConstrainerFunctionMapT::template GetFunction(variableID)(state.m_world, cellId, WorldValue{VariableIDMapT::GetValue(variableID), variableID}, constrainer); } @@ -314,7 +366,7 @@ private: { if (state.m_world.getValue(i) == VariableIDMapT::GetValue(j)) { - CollapseCell(state, wave, static_cast(i), static_cast(j)); + CollapseCell(state, wave, static_cast(i), static_cast(j)); state.m_propagationQueue.push(i); break; } diff --git a/include/nd-wfc/wfc_bit_container.hpp b/include/nd-wfc/wfc_bit_container.hpp index 509feb4..1ba5df6 100644 --- a/include/nd-wfc/wfc_bit_container.hpp +++ b/include/nd-wfc/wfc_bit_container.hpp @@ -60,6 +60,7 @@ public: static constexpr bool IsMultiElement = (StorageBits > 64); static constexpr bool IsSubByte = (StorageBits < 8); static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max(1u, StorageBits); + static constexpr size_t MaxValue = (StorageType{1} << BitsPerElement) - 1; using ContainerType = std::conditional_t, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector, + typename WeightsMapT = WeightsMap, typename SelectedValueT = void> class Builder { public: - constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + using WorldSizeT = decltype(WorldT{}.size()); + constexpr static WorldSizeT WorldSize = HasConstexprSize ? WorldT{}.size() : 0; - using WaveType = Wave; - using PropagationQueueType = WFCQueue; + using WaveType = Wave; + using PropagationQueueType = WFCQueue; using ConstrainerType = Constrainer; template - using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap>; + using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap>; template - using DefineRange = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange>; + using DefineRange = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange>; template - using DefineRange0 = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange>; + using DefineRange0 = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange>; template - using Variable = Builder>; + using Variable = Builder>; template - using VariableRange = Builder>; + using VariableRange = Builder>; + using EmptyConstrainerFunctionT = EmptyConstrainerFunction; + template requires ConstrainerFunction using Constrain = Builder, ConstrainerType&) {}) - >, CallbacksT, RandomSelectorT, SelectedValueT + EmptyConstrainerFunctionT + >, CallbacksT, RandomSelectorT, WeightsMapT, SelectedValueT >; template @@ -66,27 +71,28 @@ public: ConstrainerFunctionMapT, ConstrainerFunctionT, VariableIDMapT, - decltype([](WorldT&, size_t, WorldValue, ConstrainerType&) {}) - >, CallbacksT, RandomSelectorT + EmptyConstrainerFunctionT + >, CallbacksT, RandomSelectorT, WeightsMapT >; template - using SetCellCollapsedCallback = Builder, RandomSelectorT>; + using SetCellCollapsedCallback = Builder, RandomSelectorT, WeightsMapT>; template - using SetContradictionCallback = Builder, RandomSelectorT>; + using SetContradictionCallback = Builder, RandomSelectorT, WeightsMapT>; template - using SetBranchCallback = Builder, RandomSelectorT>; + using SetBranchCallback = Builder, RandomSelectorT, WeightsMapT>; + template - requires RandomSelectorFunction - using SetRandomSelector = Builder; - - template - using Weights = Builder>>; + using SetRandomSelector = Builder; - using Build = WFC; + template + using SetWeights = Builder(Precision)>>, SelectedValueT>; + + + using Build = WFC; }; } \ No newline at end of file diff --git a/include/nd-wfc/wfc_constrainer.hpp b/include/nd-wfc/wfc_constrainer.hpp index e05f4c0..8c9a737 100644 --- a/include/nd-wfc/wfc_constrainer.hpp +++ b/include/nd-wfc/wfc_constrainer.hpp @@ -5,6 +5,12 @@ namespace WFC { +template +struct EmptyConstrainerFunction +{ + void operator()(WorldT&, WorldSizeT, WorldValue, ConstainerType&) const {} +}; + template struct ConstrainerFunctionMap { public: @@ -56,7 +62,7 @@ template using MergedConstrainerFunctionMap = decltype( - MakeMergedConstrainerIDMap(std::make_index_sequence{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr) + MakeMergedConstrainerIDMap(std::make_index_sequence{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr) ); /** diff --git a/include/nd-wfc/wfc_queue.hpp b/include/nd-wfc/wfc_queue.hpp index b442ddf..d6cb283 100644 --- a/include/nd-wfc/wfc_queue.hpp +++ b/include/nd-wfc/wfc_queue.hpp @@ -7,7 +7,7 @@ #include #include -#include "nd-wfc/wfc_utils.hpp" +#include "wfc_utils.hpp" namespace WFC { diff --git a/include/nd-wfc/wfc_random.hpp b/include/nd-wfc/wfc_random.hpp index 2319c19..5876580 100644 --- a/include/nd-wfc/wfc_random.hpp +++ b/include/nd-wfc/wfc_random.hpp @@ -14,13 +14,6 @@ private: public: constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {} - constexpr size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - - // Simple linear congruential generator for constexpr compatibility - return static_cast(rng(possibleValues.size())); - } - constexpr uint32_t rng(uint32_t max) const { m_seed = m_seed * 1103515245 + 12345; return m_seed % max; @@ -39,140 +32,10 @@ private: public: explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {} - size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - - return rng(possibleValues.size()); - } - uint32_t rng(uint32_t max) const { std::uniform_int_distribution dist(0, max); return dist(m_rng); } }; -/** -* @brief Weight specification for a specific value -* @tparam Value The variable value -* @tparam Weight The 16-bit weight for this value -*/ -template -struct Weight { - static constexpr VarT value = Value; - static constexpr uint16_t weight = WeightValue; -}; - -/** -* @brief Compile-time weights storage for weighted random selection -* @tparam VarT The variable type -* @tparam VariableIDMapT The variable ID map type -* @tparam DefaultWeight The default weight for values not explicitly specified -* @tparam WeightSpecs Variadic template parameters of Weight specifications -*/ -template -class WeightsMap { -private: - static constexpr size_t NumWeights = sizeof...(WeightSpecs); - - // Helper to get weight for a specific value - static consteval uint16_t GetWeightForValue(VarT targetValue) { - // Check each weight spec to find the target value - uint16_t weight = DefaultWeight; - ((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...); - return weight; - } - -public: - /** - * @brief Get the weight for a specific value at compile time - * @tparam TargetValue The value to get weight for - * @return The weight for the value - */ - template - static consteval uint16_t GetWeight() { - return GetWeightForValue(TargetValue); - } - - /** - * @brief Get weights array for all registered values - * @return Array of weights corresponding to all registered values - */ - static consteval std::array GetWeightsArray() { - std::array weights{}; - - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - weights[i] = GetWeightForValue(VariableIDMapT::GetValue(i)); - } - - return weights; - } - - static consteval uint32_t GetTotalWeight() { - uint32_t totalWeight = 0; - auto weights = GetWeightsArray(); - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - totalWeight += weights[i]; - } - return totalWeight; - } - - static consteval std::array GetCumulativeWeightsArray() { - auto weights = GetWeightsArray(); - uint32_t totalWeight = 0; - std::array cumulativeWeights{}; - for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) { - totalWeight += weights[i]; - cumulativeWeights[i] = totalWeight; - } - return cumulativeWeights; - } -}; - -/** -* @brief Weighted random selector that uses another random selector as backend -* @tparam VarT The variable type -* @tparam VariableIDMapT The variable ID map type -* @tparam BackendSelectorT The backend random selector type -* @tparam WeightsMapT The weights map type containing weight specifications -*/ -template -class WeightedSelector { -private: - BackendSelectorT m_backendSelector; - const std::array m_weights; - const std::array m_cumulativeWeights; - -public: - explicit WeightedSelector(BackendSelectorT backendSelector) - : m_backendSelector(backendSelector) - , m_weights(WeightsMapT::GetWeightsArray()) - , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) - {} - - explicit WeightedSelector(uint32_t seed) - requires std::is_same_v> - : m_backendSelector(seed) - , m_weights(WeightsMapT::GetWeightsArray()) - , m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray()) - {} - - size_t operator()(std::span possibleValues) const { - constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty"); - constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value"); - - // Use backend selector to pick a random number in range [0, totalWeight) - uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back()); - - // Find which value this random value corresponds to - for (size_t i = 0; i < possibleValues.size(); ++i) { - if (randomValue <= m_cumulativeWeights[i]) { - return i; - } - } - - // Fallback (should not reach here) - return possibleValues.size() - 1; - } -}; - } \ No newline at end of file diff --git a/include/nd-wfc/wfc_utils.hpp b/include/nd-wfc/wfc_utils.hpp index 26a955d..e9cd6ce 100644 --- a/include/nd-wfc/wfc_utils.hpp +++ b/include/nd-wfc/wfc_utils.hpp @@ -6,10 +6,32 @@ namespace WFC - -inline constexpr void constexpr_assert(bool condition, const char* message = "") { +# ifdef _DEBUG +inline constexpr void constexpr_assert(bool condition, const char* message = "") +{ if (!condition) throw message; } +#else +inline constexpr void constexpr_assert(bool condition, const char* message = "") +{ + (void)condition; + (void)message; +} +# endif + +template +using MinimumIntegerType = std::conditional_t::max(), uint8_t, + std::conditional_t::max(), uint16_t, + std::conditional_t::max(), uint32_t, + uint64_t>>>; + +template +using MinimumBitsType = std::conditional_t>>>; + inline int FindNthSetBit(size_t num, int n) { constexpr_assert(n < std::popcount(num), "index is out of range"); diff --git a/include/nd-wfc/wfc_variable_map.hpp b/include/nd-wfc/wfc_variable_map.hpp index 367fa5f..cb333fd 100644 --- a/include/nd-wfc/wfc_variable_map.hpp +++ b/include/nd-wfc/wfc_variable_map.hpp @@ -12,23 +12,25 @@ namespace WFC { * It is a compile-time map of variable values to indices. */ +template +using VariableIDType = std::conditional_t::max(), uint8_t, uint16_t>; + + template class VariableIDMap { public: - using Type = VarT; - static constexpr size_t ValuesRegisteredAmount = sizeof...(Values); - template using Merge = VariableIDMap; + using VariableIDT = VariableIDType; + template static consteval bool HasValue() { constexpr VarT arr[] = {Values...}; - constexpr size_t size = sizeof...(Values); - for (size_t i = 0; i < size; ++i) + for (size_t i = 0; i < size(); ++i) if (arr[i] == Value) return true; return false; @@ -39,9 +41,8 @@ public: { static_assert(HasValue(), "Value was not defined"); constexpr VarT arr[] = {Values...}; - constexpr size_t size = ValuesRegisteredAmount; - for (size_t i = 0; i < size; ++i) + for (size_t i = 0; i < size(); ++i) if (arr[i] == Value) return i; @@ -54,15 +55,15 @@ public: { Values... }; - return std::span{ allValues, ValuesRegisteredAmount }; + return std::span{ allValues, size() }; } static constexpr VarT GetValue(size_t index) { - constexpr_assert(index < ValuesRegisteredAmount); + constexpr_assert(index < size()); return GetAllValues()[index]; } - static consteval size_t size() { return ValuesRegisteredAmount; } + static consteval size_t size() { return sizeof...(Values); } template static constexpr auto ValuesToIndices() -> std::array { @@ -76,12 +77,13 @@ class VariableIDRange { public: using Type = VarT; + using VariableIDT = VariableIDType; static_assert(Start < End, "Start must be less than End"); static_assert(std::numeric_limits::min() <= Start, "VarT must be able to represent all values in the range"); static_assert(std::numeric_limits::max() >= End, "VarT must be able to represent all values in the range"); - static constexpr size_t ValuesRegisteredAmount = End - Start; + static constexpr size_t size() { return End - Start; } template static consteval bool HasValue() @@ -100,8 +102,6 @@ public: return Start + index; } - static consteval size_t size() { return End - Start; } - template static constexpr auto ValuesToIndices() -> std::array { std::array indices = {GetIndex()...}; diff --git a/include/nd-wfc/wfc_wave.hpp b/include/nd-wfc/wfc_wave.hpp index 7ace7b2..1816df9 100644 --- a/include/nd-wfc/wfc_wave.hpp +++ b/include/nd-wfc/wfc_wave.hpp @@ -6,12 +6,16 @@ namespace WFC { -template +template class Wave { public: - using BitContainerT = BitContainer; + using BitContainerT = BitContainer; using ElementT = typename BitContainerT::StorageType; using IDMapT = VariableIDMapT; + using WeightContainersT = typename WeightsMapT::template WeightContainersT; + using VariableIDT = typename VariableIDMapT::VariableIDT; + using WeightT = typename WeightsMapT::WeightT; + static constexpr size_t ElementsAmount = Size; public: @@ -34,8 +38,15 @@ public: uint16_t GetVariableID(size_t index) const { return static_cast(std::countr_zero(m_data[index])); } ElementT GetMask(size_t index) const { return m_data[index]; } + void SetWeight(VariableIDT containerIndex, size_t elementIndex, double weight) { m_weights.SetValueFloat(containerIndex, elementIndex, weight); } + + template + WeightT GetWeight(VariableIDT containerIndex, size_t elementIndex) const { return m_weights.template GetValue(containerIndex, elementIndex); } + + private: BitContainerT m_data; + WeightContainersT m_weights; }; } \ No newline at end of file diff --git a/include/nd-wfc/wfc_weights.hpp b/include/nd-wfc/wfc_weights.hpp new file mode 100644 index 0000000..d32f7c1 --- /dev/null +++ b/include/nd-wfc/wfc_weights.hpp @@ -0,0 +1,234 @@ +#pragma once + +#include +#include +#include + +#include "wfc_bit_container.hpp" +#include "wfc_utils.hpp" + +namespace WFC { + +template +struct PrecisionEntry +{ + + constexpr static uint8_t PrecisionValue = Precision; + + template + constexpr static bool UpdatePrecisions(std::span precisions) + { + constexpr auto SelectedEntries = VariableMap::GetAllValues(); + for (auto entry : SelectedEntries) + { + precisions[MainVariableMap::GetIndex(entry)] = Precision; + } + return true; + } +}; + +enum class EPrecision : uint8_t +{ + Precision_0 = 0, + Precision_2 = 2, + Precision_4 = 4, + Precision_8 = 8, + Precision_16 = 16, + Precision_32 = 32, + Precision_64 = 64, +}; + +template +class WeightContainers +{ +private: + template + using BitContainerT = BitContainer(Precision), Size, AllocatorT>; + + using TupleT = std::tuple...>; + TupleT m_WeightContainers; + + static_assert(((static_cast(Precisions) <= static_cast(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)"); + +public: + WeightContainers() = default; + WeightContainers(size_t size) + : m_WeightContainers{ BitContainerT(size, AllocatorT()) ... } + {} + WeightContainers(size_t size, AllocatorT& allocator) + : m_WeightContainers{ BitContainerT(size, allocator) ... } + {} + +public: + static constexpr size_t size() + { + return sizeof...(Precisions); + } + + /* + template + void SetValue(size_t containerIndex, size_t index, ValueT value) + { + SetValueFunctions()[containerIndex](*this, index, value); + } + */ + void SetValueFloat(size_t containerIndex, size_t index, double value) + { + SetFloatValueFunctions()[containerIndex](*this, index, value); + } + + template + uint64_t GetValue(size_t containerIndex, size_t index) + { + return GetValueFunctions()[containerIndex](*this, index); + } + +private: +/* + template + static constexpr auto& SetValueFunctions() + { + return SetValueFunctions(std::make_index_sequence()); + } + + template + static constexpr auto& SetValueFunctions(std::index_sequence) + { + static constexpr std::array setValueFunctions = + { + [] (WeightContainers& weightContainers, size_t index, ValueT value) { + std::get(weightContainers.m_WeightContainers)[index] = value; + }, + ... + }; + return setValueFunctions; + } +*/ + static constexpr auto& SetFloatValueFunctions() + { + return SetFloatValueFunctions(std::make_index_sequence()); + } + + template + static constexpr auto& SetFloatValueFunctions(std::index_sequence) + { + using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value); + constexpr std::array setFloatValueFunctions + { + [](WeightContainers& weightContainers, size_t index, double value) -> FunctionT { + + using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element::type; + if constexpr (!std::is_same_v) + { + constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0"); + std::get(weightContainers.m_WeightContainers)[index] = static_cast(value * BitContainerEntryT::MaxValue); + } + } + ... + }; + return setFloatValueFunctions; + } + + template + static constexpr auto& GetValueFunctions() + { + return GetValueFunctions(std::make_index_sequence()); + } + + template + static constexpr auto& GetValueFunctions(std::index_sequence) + { + using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index); + constexpr std::array getValueFunctions = + { + [] (WeightContainers& weightContainers, size_t index) -> FunctionT { + using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element::type; + + if constexpr (std::is_same_v) + { + return MaxWeight / 2; + } + else + { + constexpr size_t maxValue = BitContainerEntryT::MaxValue; + if constexpr (maxValue <= MaxWeight) + { + return std::get(weightContainers.m_WeightContainers)[index]; + } + else + { + return static_cast(std::get(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue; + } + } + } + ... + }; + return getValueFunctions; + } +}; + +/** +* @brief Compile-time weights storage for weighted random selection +* @tparam VarT The variable type +* @tparam VariableIDMapT The variable ID map type +* @tparam DefaultWeight The default weight for values not explicitly specified +* @tparam WeightSpecs Variadic template parameters of Weight specifications +*/ +template +class WeightsMap { +public: + static constexpr std::array GeneratePrecisionArray() + { + std::array precisionArray{}; + + (PrecisionEntries::template UpdatePrecisions(precisionArray) && ...); + + return precisionArray; + } + + static constexpr std::array GetPrecisionArray() + { + constexpr std::array precisionArray = GeneratePrecisionArray(); + return precisionArray; + } + + static constexpr size_t GetPrecision(size_t index) + { + return GetPrecisionArray()[index]; + } + + static constexpr uint8_t GetMaxPrecision() + { + return std::max({PrecisionEntries::PrecisionValue ...}); + } + + static constexpr uint8_t GetMaxValue() + { + return (1 << GetMaxPrecision()) - 1; + } + + static constexpr bool HasWeights() + { + return sizeof...(PrecisionEntries) > 0; + } + +public: + + using VariablesT = VariableIDMapT; + + template + auto MakeWeightContainersT(AllocatorT*, std::index_sequence) + -> WeightContainers; + + template + using WeightContainersT = decltype( + MakeWeightContainersT(static_cast(nullptr), std::make_index_sequence{}) + ); + + template + using Merge = WeightsMap; + + using WeightT = typename BitContainer>::StorageType; +}; + +} \ No newline at end of file