From ded7ebc28505d28c99c06c2db33c298a177182c4 Mon Sep 17 00:00:00 2001 From: Connor Date: Fri, 6 Feb 2026 12:07:55 +0900 Subject: [PATCH] removed weights --- include/nd-wfc/wfc.hpp | 74 +----- include/nd-wfc/wfc_builder.hpp | 33 ++- include/nd-wfc/wfc_wave.hpp | 11 +- include/nd-wfc/wfc_weights.hpp | 396 ++++++++++++++++----------------- 4 files changed, 217 insertions(+), 297 deletions(-) diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index 54f477b..021cc0c 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -23,7 +23,6 @@ #include "wfc_callbacks.hpp" #include "wfc_random.hpp" #include "wfc_queue.hpp" -#include "wfc_weights.hpp" namespace WFC { @@ -50,7 +49,6 @@ template, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector, - typename WeightsMapT = WeightsMap > class WFC { public: @@ -61,12 +59,11 @@ public: // Try getting the world size, which is only available if the world type has a constexpr size() method constexpr static WorldSizeT WorldSize = HasConstexprSize ? WorldT{}.size() : 0; - using WaveType = Wave; + using WaveType = Wave; using PropagationQueueType = WFCQueue; using ConstrainerType = Constrainer; using MaskType = typename WaveType::ElementT; using VariableIDT = typename WaveType::VariableIDT; - using WeightsBufferType = BitContainer; public: struct SolverState @@ -217,86 +214,25 @@ private: return minEntropyCell; } - using WeightsType = MinimumIntegerType; using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast(1))); - static WorldSizeT FindMinimumEntropyCellsWeighted(std::span& buffer, std::span& weights, WaveType& wave) - { - constexpr size_t ElementsMaxWeight = std::min(WeightsMapT::GetMaxValue() * VariableIDMapT::size(), std::numeric_limits::max()) / VariableIDMapT::size(); - - auto accumulatedWeightedEntropyGetter = [&wave](size_t index) -> uint64_t - { - auto entropyFilter = [&wave](size_t index) -> bool { return wave.Entropy(index) > 1; }; - auto weightedEntropyGetter = [&wave](size_t index) -> uint64_t { return wave.template GetWeight(index); }; - - auto view = std::views::iota(0, VariableIDMapT::size()) | std::views::filter(entropyFilter) | std::views::transform(weightedEntropyGetter); - return std::accumulate(view.begin(), view.end(), 0); - }; - - auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(accumulatedWeightedEntropyGetter)); - - VariableIDT availableValues = wave.Entropy(minEntropyCell); - MaskType mask = wave.GetMask(minEntropyCell); - for (size_t i = 0; i < availableValues; ++i) - { - VariableIDT index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit - constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds"); - - buffer[i] = index; - constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set"); - - weights[i] = wave.template GetWeight(index); - - mask = mask & (mask - 1); // turn off lowest set bit - } - - return minEntropyCell; - } - static bool Branch(SolverState& state, WaveType& wave) { constexpr_assert(state.m_propagationQueue.empty()); std::array Buffer{}; - std::array WeightsBuffer{}; - uint64_t accumulatedWeights = 0; WorldSizeT minEntropyCell{}; - if constexpr (WeightsMapT::HasWeights()) - { - minEntropyCell = FindMinimumEntropyCellsWeighted(Buffer, WeightsBuffer, wave); - accumulatedWeights = std::accumulate(WeightsBuffer.begin(), WeightsBuffer.end(), 0); - } - else - { - minEntropyCell = FindMinimumEntropyCells(Buffer, wave); - } + minEntropyCell = FindMinimumEntropyCells(Buffer, wave); // randomly select a value from possible values while (Buffer.size()) { size_t randomIndex; VariableIDT selectedValue; - if constexpr (WeightsMapT::HasWeights()) - { - auto randomWeight = state.m_randomSelector.rng(accumulatedWeights); - for (size_t i = 0; i < WeightsBuffer.size(); ++i) - { - if (randomWeight < WeightsBuffer[i]) - { - randomIndex = i; - selectedValue = Buffer[i]; - break; - } - randomWeight -= WeightsBuffer[i]; - } - accumulatedWeights -= WeightsBuffer[randomIndex]; - } - else - { - randomIndex = state.m_randomSelector.rng(Buffer.size()); - selectedValue = Buffer[randomIndex]; - } + + randomIndex = state.m_randomSelector.rng(Buffer.size()); + selectedValue = Buffer[randomIndex]; { // copy the state and branch out diff --git a/include/nd-wfc/wfc_builder.hpp b/include/nd-wfc/wfc_builder.hpp index 34424dd..5908641 100644 --- a/include/nd-wfc/wfc_builder.hpp +++ b/include/nd-wfc/wfc_builder.hpp @@ -7,7 +7,6 @@ namespace WFC { #include "wfc_constrainer.hpp" #include "wfc_callbacks.hpp" #include "wfc_random.hpp" -#include "wfc_weights.hpp" #include "wfc.hpp" /** @@ -20,33 +19,32 @@ template< typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector, - typename WeightsMapT = WeightsMap, typename SelectedValueT = void> class Builder { public: using WorldSizeT = decltype(WorldT{}.size()); constexpr static WorldSizeT WorldSize = HasConstexprSize ? WorldT{}.size() : 0; - using WaveType = Wave; + using WaveType = Wave; using PropagationQueueType = WFCQueue; using ConstrainerType = Constrainer; template - using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap>; + using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap>; template - using DefineRange = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange>; + using DefineRange = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange>; template - using DefineRange0 = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange>; + using DefineRange0 = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange>; template - using Variable = Builder>; + using Variable = Builder>; template - using VariableRange = Builder>; + using VariableRange = Builder>; using EmptyConstrainerFunctionT = EmptyConstrainerFunction; @@ -60,7 +58,7 @@ public: ConstrainerFunctionT, SelectedValueT, EmptyConstrainerFunctionT - >, CallbacksT, RandomSelectorT, WeightsMapT, SelectedValueT + >, CallbacksT, RandomSelectorT, SelectedValueT >; template @@ -72,27 +70,22 @@ public: ConstrainerFunctionT, VariableIDMapT, EmptyConstrainerFunctionT - >, CallbacksT, RandomSelectorT, WeightsMapT + >, CallbacksT, RandomSelectorT >; template - using SetCellCollapsedCallback = Builder, RandomSelectorT, WeightsMapT>; + using SetCellCollapsedCallback = Builder, RandomSelectorT>; template - using SetContradictionCallback = Builder, RandomSelectorT, WeightsMapT>; + using SetContradictionCallback = Builder, RandomSelectorT>; template - using SetBranchCallback = Builder, RandomSelectorT, WeightsMapT>; + using SetBranchCallback = Builder, RandomSelectorT>; template - using SetRandomSelector = Builder; + using SetRandomSelector = Builder; - - template - using SetWeights = Builder(Precision)>>, SelectedValueT>; - - - using Build = WFC; + using Build = WFC; }; } \ No newline at end of file diff --git a/include/nd-wfc/wfc_wave.hpp b/include/nd-wfc/wfc_wave.hpp index 1816df9..5e146eb 100644 --- a/include/nd-wfc/wfc_wave.hpp +++ b/include/nd-wfc/wfc_wave.hpp @@ -6,15 +6,13 @@ namespace WFC { -template +template class Wave { public: using BitContainerT = BitContainer; using ElementT = typename BitContainerT::StorageType; using IDMapT = VariableIDMapT; - using WeightContainersT = typename WeightsMapT::template WeightContainersT; using VariableIDT = typename VariableIDMapT::VariableIDT; - using WeightT = typename WeightsMapT::WeightT; static constexpr size_t ElementsAmount = Size; @@ -38,15 +36,8 @@ public: uint16_t GetVariableID(size_t index) const { return static_cast(std::countr_zero(m_data[index])); } ElementT GetMask(size_t index) const { return m_data[index]; } - void SetWeight(VariableIDT containerIndex, size_t elementIndex, double weight) { m_weights.SetValueFloat(containerIndex, elementIndex, weight); } - - template - WeightT GetWeight(VariableIDT containerIndex, size_t elementIndex) const { return m_weights.template GetValue(containerIndex, elementIndex); } - - private: BitContainerT m_data; - WeightContainersT m_weights; }; } \ No newline at end of file diff --git a/include/nd-wfc/wfc_weights.hpp b/include/nd-wfc/wfc_weights.hpp index d32f7c1..93fca8c 100644 --- a/include/nd-wfc/wfc_weights.hpp +++ b/include/nd-wfc/wfc_weights.hpp @@ -1,234 +1,234 @@ -#pragma once +// #pragma once -#include -#include -#include +// #include +// #include +// #include -#include "wfc_bit_container.hpp" -#include "wfc_utils.hpp" +// #include "wfc_bit_container.hpp" +// #include "wfc_utils.hpp" -namespace WFC { +// namespace WFC { -template -struct PrecisionEntry -{ +// template +// struct PrecisionEntry +// { - constexpr static uint8_t PrecisionValue = Precision; +// constexpr static uint8_t PrecisionValue = Precision; - template - constexpr static bool UpdatePrecisions(std::span precisions) - { - constexpr auto SelectedEntries = VariableMap::GetAllValues(); - for (auto entry : SelectedEntries) - { - precisions[MainVariableMap::GetIndex(entry)] = Precision; - } - return true; - } -}; +// template +// constexpr static bool UpdatePrecisions(std::span precisions) +// { +// constexpr auto SelectedEntries = VariableMap::GetAllValues(); +// for (auto entry : SelectedEntries) +// { +// precisions[MainVariableMap::GetIndex(entry)] = Precision; +// } +// return true; +// } +// }; -enum class EPrecision : uint8_t -{ - Precision_0 = 0, - Precision_2 = 2, - Precision_4 = 4, - Precision_8 = 8, - Precision_16 = 16, - Precision_32 = 32, - Precision_64 = 64, -}; +// enum class EPrecision : uint8_t +// { +// Precision_0 = 0, +// Precision_2 = 2, +// Precision_4 = 4, +// Precision_8 = 8, +// Precision_16 = 16, +// Precision_32 = 32, +// Precision_64 = 64, +// }; -template -class WeightContainers -{ -private: - template - using BitContainerT = BitContainer(Precision), Size, AllocatorT>; +// template +// class WeightContainers +// { +// private: +// template +// using BitContainerT = BitContainer(Precision), Size, AllocatorT>; - using TupleT = std::tuple...>; - TupleT m_WeightContainers; +// using TupleT = std::tuple...>; +// TupleT m_WeightContainers; - static_assert(((static_cast(Precisions) <= static_cast(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)"); +// static_assert(((static_cast(Precisions) <= static_cast(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)"); -public: - WeightContainers() = default; - WeightContainers(size_t size) - : m_WeightContainers{ BitContainerT(size, AllocatorT()) ... } - {} - WeightContainers(size_t size, AllocatorT& allocator) - : m_WeightContainers{ BitContainerT(size, allocator) ... } - {} +// public: +// WeightContainers() = default; +// WeightContainers(size_t size) +// : m_WeightContainers{ BitContainerT(size, AllocatorT()) ... } +// {} +// WeightContainers(size_t size, AllocatorT& allocator) +// : m_WeightContainers{ BitContainerT(size, allocator) ... } +// {} -public: - static constexpr size_t size() - { - return sizeof...(Precisions); - } +// public: +// static constexpr size_t size() +// { +// return sizeof...(Precisions); +// } - /* - template - void SetValue(size_t containerIndex, size_t index, ValueT value) - { - SetValueFunctions()[containerIndex](*this, index, value); - } - */ - void SetValueFloat(size_t containerIndex, size_t index, double value) - { - SetFloatValueFunctions()[containerIndex](*this, index, value); - } +// /* +// template +// void SetValue(size_t containerIndex, size_t index, ValueT value) +// { +// SetValueFunctions()[containerIndex](*this, index, value); +// } +// */ +// void SetValueFloat(size_t containerIndex, size_t index, double value) +// { +// SetFloatValueFunctions()[containerIndex](*this, index, value); +// } - template - uint64_t GetValue(size_t containerIndex, size_t index) - { - return GetValueFunctions()[containerIndex](*this, index); - } +// template +// uint64_t GetValue(size_t containerIndex, size_t index) +// { +// return GetValueFunctions()[containerIndex](*this, index); +// } -private: -/* - template - static constexpr auto& SetValueFunctions() - { - return SetValueFunctions(std::make_index_sequence()); - } +// private: +// /* +// template +// static constexpr auto& SetValueFunctions() +// { +// return SetValueFunctions(std::make_index_sequence()); +// } - template - static constexpr auto& SetValueFunctions(std::index_sequence) - { - static constexpr std::array setValueFunctions = - { - [] (WeightContainers& weightContainers, size_t index, ValueT value) { - std::get(weightContainers.m_WeightContainers)[index] = value; - }, - ... - }; - return setValueFunctions; - } -*/ - static constexpr auto& SetFloatValueFunctions() - { - return SetFloatValueFunctions(std::make_index_sequence()); - } +// template +// static constexpr auto& SetValueFunctions(std::index_sequence) +// { +// static constexpr std::array setValueFunctions = +// { +// [] (WeightContainers& weightContainers, size_t index, ValueT value) { +// std::get(weightContainers.m_WeightContainers)[index] = value; +// }, +// ... +// }; +// return setValueFunctions; +// } +// */ +// static constexpr auto& SetFloatValueFunctions() +// { +// return SetFloatValueFunctions(std::make_index_sequence()); +// } - template - static constexpr auto& SetFloatValueFunctions(std::index_sequence) - { - using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value); - constexpr std::array setFloatValueFunctions - { - [](WeightContainers& weightContainers, size_t index, double value) -> FunctionT { +// template +// static constexpr auto& SetFloatValueFunctions(std::index_sequence) +// { +// using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value); +// constexpr std::array setFloatValueFunctions +// { +// [](WeightContainers& weightContainers, size_t index, double value) -> FunctionT { - using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element::type; - if constexpr (!std::is_same_v) - { - constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0"); - std::get(weightContainers.m_WeightContainers)[index] = static_cast(value * BitContainerEntryT::MaxValue); - } - } - ... - }; - return setFloatValueFunctions; - } +// using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element::type; +// if constexpr (!std::is_same_v) +// { +// constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0"); +// std::get(weightContainers.m_WeightContainers)[index] = static_cast(value * BitContainerEntryT::MaxValue); +// } +// } +// ... +// }; +// return setFloatValueFunctions; +// } - template - static constexpr auto& GetValueFunctions() - { - return GetValueFunctions(std::make_index_sequence()); - } +// template +// static constexpr auto& GetValueFunctions() +// { +// return GetValueFunctions(std::make_index_sequence()); +// } - template - static constexpr auto& GetValueFunctions(std::index_sequence) - { - using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index); - constexpr std::array getValueFunctions = - { - [] (WeightContainers& weightContainers, size_t index) -> FunctionT { - using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element::type; +// template +// static constexpr auto& GetValueFunctions(std::index_sequence) +// { +// using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index); +// constexpr std::array getValueFunctions = +// { +// [] (WeightContainers& weightContainers, size_t index) -> FunctionT { +// using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element::type; - if constexpr (std::is_same_v) - { - return MaxWeight / 2; - } - else - { - constexpr size_t maxValue = BitContainerEntryT::MaxValue; - if constexpr (maxValue <= MaxWeight) - { - return std::get(weightContainers.m_WeightContainers)[index]; - } - else - { - return static_cast(std::get(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue; - } - } - } - ... - }; - return getValueFunctions; - } -}; +// if constexpr (std::is_same_v) +// { +// return MaxWeight / 2; +// } +// else +// { +// constexpr size_t maxValue = BitContainerEntryT::MaxValue; +// if constexpr (maxValue <= MaxWeight) +// { +// return std::get(weightContainers.m_WeightContainers)[index]; +// } +// else +// { +// return static_cast(std::get(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue; +// } +// } +// } +// ... +// }; +// return getValueFunctions; +// } +// }; -/** -* @brief Compile-time weights storage for weighted random selection -* @tparam VarT The variable type -* @tparam VariableIDMapT The variable ID map type -* @tparam DefaultWeight The default weight for values not explicitly specified -* @tparam WeightSpecs Variadic template parameters of Weight specifications -*/ -template -class WeightsMap { -public: - static constexpr std::array GeneratePrecisionArray() - { - std::array precisionArray{}; +// /** +// * @brief Compile-time weights storage for weighted random selection +// * @tparam VarT The variable type +// * @tparam VariableIDMapT The variable ID map type +// * @tparam DefaultWeight The default weight for values not explicitly specified +// * @tparam WeightSpecs Variadic template parameters of Weight specifications +// */ +// template +// class WeightsMap { +// public: +// static constexpr std::array GeneratePrecisionArray() +// { +// std::array precisionArray{}; - (PrecisionEntries::template UpdatePrecisions(precisionArray) && ...); +// (PrecisionEntries::template UpdatePrecisions(precisionArray) && ...); - return precisionArray; - } +// return precisionArray; +// } - static constexpr std::array GetPrecisionArray() - { - constexpr std::array precisionArray = GeneratePrecisionArray(); - return precisionArray; - } +// static constexpr std::array GetPrecisionArray() +// { +// constexpr std::array precisionArray = GeneratePrecisionArray(); +// return precisionArray; +// } - static constexpr size_t GetPrecision(size_t index) - { - return GetPrecisionArray()[index]; - } +// static constexpr size_t GetPrecision(size_t index) +// { +// return GetPrecisionArray()[index]; +// } - static constexpr uint8_t GetMaxPrecision() - { - return std::max({PrecisionEntries::PrecisionValue ...}); - } +// static constexpr uint8_t GetMaxPrecision() +// { +// return std::max({PrecisionEntries::PrecisionValue ...}); +// } - static constexpr uint8_t GetMaxValue() - { - return (1 << GetMaxPrecision()) - 1; - } +// static constexpr uint8_t GetMaxValue() +// { +// return (1 << GetMaxPrecision()) - 1; +// } - static constexpr bool HasWeights() - { - return sizeof...(PrecisionEntries) > 0; - } +// static constexpr bool HasWeights() +// { +// return sizeof...(PrecisionEntries) > 0; +// } -public: +// public: - using VariablesT = VariableIDMapT; +// using VariablesT = VariableIDMapT; - template - auto MakeWeightContainersT(AllocatorT*, std::index_sequence) - -> WeightContainers; +// template +// auto MakeWeightContainersT(AllocatorT*, std::index_sequence) +// -> WeightContainers; - template - using WeightContainersT = decltype( - MakeWeightContainersT(static_cast(nullptr), std::make_index_sequence{}) - ); +// template +// using WeightContainersT = decltype( +// MakeWeightContainersT(static_cast(nullptr), std::make_index_sequence{}) +// ); - template - using Merge = WeightsMap; +// template +// using Merge = WeightsMap; - using WeightT = typename BitContainer>::StorageType; -}; +// using WeightT = typename BitContainer>::StorageType; +// }; -} \ No newline at end of file +// } \ No newline at end of file