removed weights

This commit is contained in:
Connor
2026-02-06 12:07:55 +09:00
parent 94cd003e96
commit ded7ebc285
4 changed files with 217 additions and 297 deletions

View File

@@ -23,7 +23,6 @@
#include "wfc_callbacks.hpp" #include "wfc_callbacks.hpp"
#include "wfc_random.hpp" #include "wfc_random.hpp"
#include "wfc_queue.hpp" #include "wfc_queue.hpp"
#include "wfc_weights.hpp"
namespace WFC { namespace WFC {
@@ -50,7 +49,6 @@ template<typename WorldT, typename VarT,
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>, typename CallbacksT = Callbacks<WorldT>,
typename RandomSelectorT = DefaultRandomSelector<VarT>, typename RandomSelectorT = DefaultRandomSelector<VarT>,
typename WeightsMapT = WeightsMap<VariableIDMapT>
> >
class WFC { class WFC {
public: public:
@@ -61,12 +59,11 @@ public:
// Try getting the world size, which is only available if the world type has a constexpr size() method // Try getting the world size, which is only available if the world type has a constexpr size() method
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0; constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>; using WaveType = Wave<VariableIDMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>; using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>; using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
using MaskType = typename WaveType::ElementT; using MaskType = typename WaveType::ElementT;
using VariableIDT = typename WaveType::VariableIDT; using VariableIDT = typename WaveType::VariableIDT;
using WeightsBufferType = BitContainer<WeightsMapT::GetMaxPrecision(), VariableIDMapT::size()>;
public: public:
struct SolverState struct SolverState
@@ -217,86 +214,25 @@ private:
return minEntropyCell; return minEntropyCell;
} }
using WeightsType = MinimumIntegerType<WeightsMapT::GetMaxValue()>;
using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast<uint32_t>(1))); using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast<uint32_t>(1)));
static WorldSizeT FindMinimumEntropyCellsWeighted(std::span<VariableIDT>& buffer, std::span<WeightsType>& weights, WaveType& wave)
{
constexpr size_t ElementsMaxWeight = std::min(WeightsMapT::GetMaxValue() * VariableIDMapT::size(), std::numeric_limits<RandomGeneratorReturnType>::max()) / VariableIDMapT::size();
auto accumulatedWeightedEntropyGetter = [&wave](size_t index) -> uint64_t
{
auto entropyFilter = [&wave](size_t index) -> bool { return wave.Entropy(index) > 1; };
auto weightedEntropyGetter = [&wave](size_t index) -> uint64_t { return wave.template GetWeight<ElementsMaxWeight>(index); };
auto view = std::views::iota(0, VariableIDMapT::size()) | std::views::filter(entropyFilter) | std::views::transform(weightedEntropyGetter);
return std::accumulate(view.begin(), view.end(), 0);
};
auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(accumulatedWeightedEntropyGetter));
VariableIDT availableValues = wave.Entropy(minEntropyCell);
MaskType mask = wave.GetMask(minEntropyCell);
for (size_t i = 0; i < availableValues; ++i)
{
VariableIDT index = static_cast<VariableIDT>(std::countr_zero(mask)); // get the index of the lowest set bit
constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds");
buffer[i] = index;
constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set");
weights[i] = wave.template GetWeight<ElementsMaxWeight>(index);
mask = mask & (mask - 1); // turn off lowest set bit
}
return minEntropyCell;
}
static bool Branch(SolverState& state, WaveType& wave) static bool Branch(SolverState& state, WaveType& wave)
{ {
constexpr_assert(state.m_propagationQueue.empty()); constexpr_assert(state.m_propagationQueue.empty());
std::array<VariableIDT, VariableIDMapT::size()> Buffer{}; std::array<VariableIDT, VariableIDMapT::size()> Buffer{};
std::array<WeightsType, VariableIDMapT::size()> WeightsBuffer{};
uint64_t accumulatedWeights = 0;
WorldSizeT minEntropyCell{}; WorldSizeT minEntropyCell{};
if constexpr (WeightsMapT::HasWeights()) minEntropyCell = FindMinimumEntropyCells(Buffer, wave);
{
minEntropyCell = FindMinimumEntropyCellsWeighted(Buffer, WeightsBuffer, wave);
accumulatedWeights = std::accumulate(WeightsBuffer.begin(), WeightsBuffer.end(), 0);
}
else
{
minEntropyCell = FindMinimumEntropyCells(Buffer, wave);
}
// randomly select a value from possible values // randomly select a value from possible values
while (Buffer.size()) while (Buffer.size())
{ {
size_t randomIndex; size_t randomIndex;
VariableIDT selectedValue; VariableIDT selectedValue;
if constexpr (WeightsMapT::HasWeights())
{ randomIndex = state.m_randomSelector.rng(Buffer.size());
auto randomWeight = state.m_randomSelector.rng(accumulatedWeights); selectedValue = Buffer[randomIndex];
for (size_t i = 0; i < WeightsBuffer.size(); ++i)
{
if (randomWeight < WeightsBuffer[i])
{
randomIndex = i;
selectedValue = Buffer[i];
break;
}
randomWeight -= WeightsBuffer[i];
}
accumulatedWeights -= WeightsBuffer[randomIndex];
}
else
{
randomIndex = state.m_randomSelector.rng(Buffer.size());
selectedValue = Buffer[randomIndex];
}
{ {
// copy the state and branch out // copy the state and branch out

View File

@@ -7,7 +7,6 @@ namespace WFC {
#include "wfc_constrainer.hpp" #include "wfc_constrainer.hpp"
#include "wfc_callbacks.hpp" #include "wfc_callbacks.hpp"
#include "wfc_random.hpp" #include "wfc_random.hpp"
#include "wfc_weights.hpp"
#include "wfc.hpp" #include "wfc.hpp"
/** /**
@@ -20,33 +19,32 @@ template<
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>, typename CallbacksT = Callbacks<WorldT>,
typename RandomSelectorT = DefaultRandomSelector<VarT>, typename RandomSelectorT = DefaultRandomSelector<VarT>,
typename WeightsMapT = WeightsMap<VariableIDMapT>,
typename SelectedValueT = void> typename SelectedValueT = void>
class Builder { class Builder {
public: public:
using WorldSizeT = decltype(WorldT{}.size()); using WorldSizeT = decltype(WorldT{}.size());
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0; constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>; using WaveType = Wave<VariableIDMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>; using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>; using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
template <VarT ... Values> template <VarT ... Values>
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>; using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
template <size_t RangeStart, size_t RangeEnd> template <size_t RangeStart, size_t RangeEnd>
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>; using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
template <size_t RangeEnd> template <size_t RangeEnd>
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, 0, RangeEnd>>; using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, 0, RangeEnd>>;
template <VarT ... Values> template <VarT ... Values>
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>; using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
template <size_t RangeStart, size_t RangeEnd> template <size_t RangeStart, size_t RangeEnd>
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>; using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
using EmptyConstrainerFunctionT = EmptyConstrainerFunction<WorldT, WorldSizeT, VarT, ConstrainerType>; using EmptyConstrainerFunctionT = EmptyConstrainerFunction<WorldT, WorldSizeT, VarT, ConstrainerType>;
@@ -60,7 +58,7 @@ public:
ConstrainerFunctionT, ConstrainerFunctionT,
SelectedValueT, SelectedValueT,
EmptyConstrainerFunctionT EmptyConstrainerFunctionT
>, CallbacksT, RandomSelectorT, WeightsMapT, SelectedValueT >, CallbacksT, RandomSelectorT, SelectedValueT
>; >;
template <typename ConstrainerFunctionT> template <typename ConstrainerFunctionT>
@@ -72,27 +70,22 @@ public:
ConstrainerFunctionT, ConstrainerFunctionT,
VariableIDMapT, VariableIDMapT,
EmptyConstrainerFunctionT EmptyConstrainerFunctionT
>, CallbacksT, RandomSelectorT, WeightsMapT >, CallbacksT, RandomSelectorT
>; >;
template <typename NewCellCollapsedCallbackT> template <typename NewCellCollapsedCallbackT>
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT, WeightsMapT>; using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT>;
template <typename NewContradictionCallbackT> template <typename NewContradictionCallbackT>
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT, WeightsMapT>; using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT>;
template <typename NewBranchCallbackT> template <typename NewBranchCallbackT>
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT, WeightsMapT>; using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT>;
template <typename NewRandomSelectorT> template <typename NewRandomSelectorT>
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT, WeightsMapT>; using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
template <EPrecision Precision>
using SetWeights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, typename WeightsMapT::template Merge<PrecisionEntry<SelectedValueT, static_cast<uint8_t>(Precision)>>, SelectedValueT>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT>;
}; };
} }

View File

@@ -6,15 +6,13 @@
namespace WFC { namespace WFC {
template <typename VariableIDMapT, typename WeightsMapT, size_t Size = 0> template <typename VariableIDMapT, size_t Size = 0>
class Wave { class Wave {
public: public:
using BitContainerT = BitContainer<VariableIDMapT::size(), Size>; using BitContainerT = BitContainer<VariableIDMapT::size(), Size>;
using ElementT = typename BitContainerT::StorageType; using ElementT = typename BitContainerT::StorageType;
using IDMapT = VariableIDMapT; using IDMapT = VariableIDMapT;
using WeightContainersT = typename WeightsMapT::template WeightContainersT<Size, WFCStackAllocator>;
using VariableIDT = typename VariableIDMapT::VariableIDT; using VariableIDT = typename VariableIDMapT::VariableIDT;
using WeightT = typename WeightsMapT::WeightT;
static constexpr size_t ElementsAmount = Size; static constexpr size_t ElementsAmount = Size;
@@ -38,15 +36,8 @@ public:
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); } uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
ElementT GetMask(size_t index) const { return m_data[index]; } ElementT GetMask(size_t index) const { return m_data[index]; }
void SetWeight(VariableIDT containerIndex, size_t elementIndex, double weight) { m_weights.SetValueFloat(containerIndex, elementIndex, weight); }
template <size_t MaxWeight>
WeightT GetWeight(VariableIDT containerIndex, size_t elementIndex) const { return m_weights.template GetValue<MaxWeight>(containerIndex, elementIndex); }
private: private:
BitContainerT m_data; BitContainerT m_data;
WeightContainersT m_weights;
}; };
} }

View File

@@ -1,234 +1,234 @@
#pragma once // #pragma once
#include <array> // #include <array>
#include <span> // #include <span>
#include <tuple> // #include <tuple>
#include "wfc_bit_container.hpp" // #include "wfc_bit_container.hpp"
#include "wfc_utils.hpp" // #include "wfc_utils.hpp"
namespace WFC { // namespace WFC {
template <typename VariableMap, uint8_t Precision> // template <typename VariableMap, uint8_t Precision>
struct PrecisionEntry // struct PrecisionEntry
{ // {
constexpr static uint8_t PrecisionValue = Precision; // constexpr static uint8_t PrecisionValue = Precision;
template <typename MainVariableMap> // template <typename MainVariableMap>
constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions) // constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions)
{ // {
constexpr auto SelectedEntries = VariableMap::GetAllValues(); // constexpr auto SelectedEntries = VariableMap::GetAllValues();
for (auto entry : SelectedEntries) // for (auto entry : SelectedEntries)
{ // {
precisions[MainVariableMap::GetIndex(entry)] = Precision; // precisions[MainVariableMap::GetIndex(entry)] = Precision;
} // }
return true; // return true;
} // }
}; // };
enum class EPrecision : uint8_t // enum class EPrecision : uint8_t
{ // {
Precision_0 = 0, // Precision_0 = 0,
Precision_2 = 2, // Precision_2 = 2,
Precision_4 = 4, // Precision_4 = 4,
Precision_8 = 8, // Precision_8 = 8,
Precision_16 = 16, // Precision_16 = 16,
Precision_32 = 32, // Precision_32 = 32,
Precision_64 = 64, // Precision_64 = 64,
}; // };
template <size_t Size, typename AllocatorT, EPrecision ... Precisions> // template <size_t Size, typename AllocatorT, EPrecision ... Precisions>
class WeightContainers // class WeightContainers
{ // {
private: // private:
template <EPrecision Precision> // template <EPrecision Precision>
using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>; // using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>;
using TupleT = std::tuple<BitContainerT<Precisions>...>; // using TupleT = std::tuple<BitContainerT<Precisions>...>;
TupleT m_WeightContainers; // TupleT m_WeightContainers;
static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)"); // static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)");
public: // public:
WeightContainers() = default; // WeightContainers() = default;
WeightContainers(size_t size) // WeightContainers(size_t size)
: m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... } // : m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... }
{} // {}
WeightContainers(size_t size, AllocatorT& allocator) // WeightContainers(size_t size, AllocatorT& allocator)
: m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... } // : m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... }
{} // {}
public: // public:
static constexpr size_t size() // static constexpr size_t size()
{ // {
return sizeof...(Precisions); // return sizeof...(Precisions);
} // }
/* // /*
template <typename ValueT> // template <typename ValueT>
void SetValue(size_t containerIndex, size_t index, ValueT value) // void SetValue(size_t containerIndex, size_t index, ValueT value)
{ // {
SetValueFunctions<ValueT>()[containerIndex](*this, index, value); // SetValueFunctions<ValueT>()[containerIndex](*this, index, value);
} // }
*/ // */
void SetValueFloat(size_t containerIndex, size_t index, double value) // void SetValueFloat(size_t containerIndex, size_t index, double value)
{ // {
SetFloatValueFunctions()[containerIndex](*this, index, value); // SetFloatValueFunctions()[containerIndex](*this, index, value);
} // }
template <size_t MaxWeight> // template <size_t MaxWeight>
uint64_t GetValue(size_t containerIndex, size_t index) // uint64_t GetValue(size_t containerIndex, size_t index)
{ // {
return GetValueFunctions<MaxWeight>()[containerIndex](*this, index); // return GetValueFunctions<MaxWeight>()[containerIndex](*this, index);
} // }
private: // private:
/* // /*
template <typename ValueT> // template <typename ValueT>
static constexpr auto& SetValueFunctions() // static constexpr auto& SetValueFunctions()
{ // {
return SetValueFunctions<ValueT>(std::make_index_sequence<size()>()); // return SetValueFunctions<ValueT>(std::make_index_sequence<size()>());
} // }
template <typename ValueT, size_t ... Is> // template <typename ValueT, size_t ... Is>
static constexpr auto& SetValueFunctions(std::index_sequence<Is...>) // static constexpr auto& SetValueFunctions(std::index_sequence<Is...>)
{ // {
static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions = // static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions =
{ // {
[] (WeightContainers& weightContainers, size_t index, ValueT value) { // [] (WeightContainers& weightContainers, size_t index, ValueT value) {
std::get<Is>(weightContainers.m_WeightContainers)[index] = value; // std::get<Is>(weightContainers.m_WeightContainers)[index] = value;
}, // },
... // ...
}; // };
return setValueFunctions; // return setValueFunctions;
} // }
*/ // */
static constexpr auto& SetFloatValueFunctions() // static constexpr auto& SetFloatValueFunctions()
{ // {
return SetFloatValueFunctions(std::make_index_sequence<size()>()); // return SetFloatValueFunctions(std::make_index_sequence<size()>());
} // }
template <size_t ... Is> // template <size_t ... Is>
static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>) // static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>)
{ // {
using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value); // using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value);
constexpr std::array<FunctionT, size()> setFloatValueFunctions // constexpr std::array<FunctionT, size()> setFloatValueFunctions
{ // {
[](WeightContainers& weightContainers, size_t index, double value) -> FunctionT { // [](WeightContainers& weightContainers, size_t index, double value) -> FunctionT {
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type; // using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>) // if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
{ // {
constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0"); // constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0");
std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue); // std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue);
} // }
} // }
... // ...
}; // };
return setFloatValueFunctions; // return setFloatValueFunctions;
} // }
template <size_t MaxWeight> // template <size_t MaxWeight>
static constexpr auto& GetValueFunctions() // static constexpr auto& GetValueFunctions()
{ // {
return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>()); // return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>());
} // }
template <size_t MaxWeight, size_t ... Is> // template <size_t MaxWeight, size_t ... Is>
static constexpr auto& GetValueFunctions(std::index_sequence<Is...>) // static constexpr auto& GetValueFunctions(std::index_sequence<Is...>)
{ // {
using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index); // using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index);
constexpr std::array<FunctionT, size()> getValueFunctions = // constexpr std::array<FunctionT, size()> getValueFunctions =
{ // {
[] (WeightContainers& weightContainers, size_t index) -> FunctionT { // [] (WeightContainers& weightContainers, size_t index) -> FunctionT {
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type; // using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>) // if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
{ // {
return MaxWeight / 2; // return MaxWeight / 2;
} // }
else // else
{ // {
constexpr size_t maxValue = BitContainerEntryT::MaxValue; // constexpr size_t maxValue = BitContainerEntryT::MaxValue;
if constexpr (maxValue <= MaxWeight) // if constexpr (maxValue <= MaxWeight)
{ // {
return std::get<Is>(weightContainers.m_WeightContainers)[index]; // return std::get<Is>(weightContainers.m_WeightContainers)[index];
} // }
else // else
{ // {
return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue; // return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue;
} // }
} // }
} // }
... // ...
}; // };
return getValueFunctions; // return getValueFunctions;
} // }
}; // };
/** // /**
* @brief Compile-time weights storage for weighted random selection // * @brief Compile-time weights storage for weighted random selection
* @tparam VarT The variable type // * @tparam VarT The variable type
* @tparam VariableIDMapT The variable ID map type // * @tparam VariableIDMapT The variable ID map type
* @tparam DefaultWeight The default weight for values not explicitly specified // * @tparam DefaultWeight The default weight for values not explicitly specified
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications // * @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
*/ // */
template <typename VariableIDMapT, typename ... PrecisionEntries> // template <typename VariableIDMapT, typename ... PrecisionEntries>
class WeightsMap { // class WeightsMap {
public: // public:
static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray() // static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray()
{ // {
std::array<uint8_t, VariableIDMapT::size()> precisionArray{}; // std::array<uint8_t, VariableIDMapT::size()> precisionArray{};
(PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...); // (PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...);
return precisionArray; // return precisionArray;
} // }
static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray() // static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray()
{ // {
constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray(); // constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray();
return precisionArray; // return precisionArray;
} // }
static constexpr size_t GetPrecision(size_t index) // static constexpr size_t GetPrecision(size_t index)
{ // {
return GetPrecisionArray()[index]; // return GetPrecisionArray()[index];
} // }
static constexpr uint8_t GetMaxPrecision() // static constexpr uint8_t GetMaxPrecision()
{ // {
return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...}); // return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...});
} // }
static constexpr uint8_t GetMaxValue() // static constexpr uint8_t GetMaxValue()
{ // {
return (1 << GetMaxPrecision()) - 1; // return (1 << GetMaxPrecision()) - 1;
} // }
static constexpr bool HasWeights() // static constexpr bool HasWeights()
{ // {
return sizeof...(PrecisionEntries) > 0; // return sizeof...(PrecisionEntries) > 0;
} // }
public: // public:
using VariablesT = VariableIDMapT; // using VariablesT = VariableIDMapT;
template<size_t Size, typename AllocatorT, size_t... Is> // template<size_t Size, typename AllocatorT, size_t... Is>
auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>) // auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>)
-> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>; // -> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>;
template <size_t Size, typename AllocatorT> // template <size_t Size, typename AllocatorT>
using WeightContainersT = decltype( // using WeightContainersT = decltype(
MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{}) // MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{})
); // );
template <typename PrecisionEntryT> // template <typename PrecisionEntryT>
using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>; // using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>;
using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType; // using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType;
}; // };
} // }