removed weights
This commit is contained in:
@@ -23,7 +23,6 @@
|
|||||||
#include "wfc_callbacks.hpp"
|
#include "wfc_callbacks.hpp"
|
||||||
#include "wfc_random.hpp"
|
#include "wfc_random.hpp"
|
||||||
#include "wfc_queue.hpp"
|
#include "wfc_queue.hpp"
|
||||||
#include "wfc_weights.hpp"
|
|
||||||
|
|
||||||
namespace WFC {
|
namespace WFC {
|
||||||
|
|
||||||
@@ -50,7 +49,6 @@ template<typename WorldT, typename VarT,
|
|||||||
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
||||||
typename CallbacksT = Callbacks<WorldT>,
|
typename CallbacksT = Callbacks<WorldT>,
|
||||||
typename RandomSelectorT = DefaultRandomSelector<VarT>,
|
typename RandomSelectorT = DefaultRandomSelector<VarT>,
|
||||||
typename WeightsMapT = WeightsMap<VariableIDMapT>
|
|
||||||
>
|
>
|
||||||
class WFC {
|
class WFC {
|
||||||
public:
|
public:
|
||||||
@@ -61,12 +59,11 @@ public:
|
|||||||
// Try getting the world size, which is only available if the world type has a constexpr size() method
|
// Try getting the world size, which is only available if the world type has a constexpr size() method
|
||||||
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
||||||
|
|
||||||
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
|
using WaveType = Wave<VariableIDMapT, WorldSize>;
|
||||||
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
|
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
|
||||||
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
|
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
|
||||||
using MaskType = typename WaveType::ElementT;
|
using MaskType = typename WaveType::ElementT;
|
||||||
using VariableIDT = typename WaveType::VariableIDT;
|
using VariableIDT = typename WaveType::VariableIDT;
|
||||||
using WeightsBufferType = BitContainer<WeightsMapT::GetMaxPrecision(), VariableIDMapT::size()>;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
struct SolverState
|
struct SolverState
|
||||||
@@ -217,86 +214,25 @@ private:
|
|||||||
return minEntropyCell;
|
return minEntropyCell;
|
||||||
}
|
}
|
||||||
|
|
||||||
using WeightsType = MinimumIntegerType<WeightsMapT::GetMaxValue()>;
|
|
||||||
using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast<uint32_t>(1)));
|
using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast<uint32_t>(1)));
|
||||||
|
|
||||||
static WorldSizeT FindMinimumEntropyCellsWeighted(std::span<VariableIDT>& buffer, std::span<WeightsType>& weights, WaveType& wave)
|
|
||||||
{
|
|
||||||
constexpr size_t ElementsMaxWeight = std::min(WeightsMapT::GetMaxValue() * VariableIDMapT::size(), std::numeric_limits<RandomGeneratorReturnType>::max()) / VariableIDMapT::size();
|
|
||||||
|
|
||||||
auto accumulatedWeightedEntropyGetter = [&wave](size_t index) -> uint64_t
|
|
||||||
{
|
|
||||||
auto entropyFilter = [&wave](size_t index) -> bool { return wave.Entropy(index) > 1; };
|
|
||||||
auto weightedEntropyGetter = [&wave](size_t index) -> uint64_t { return wave.template GetWeight<ElementsMaxWeight>(index); };
|
|
||||||
|
|
||||||
auto view = std::views::iota(0, VariableIDMapT::size()) | std::views::filter(entropyFilter) | std::views::transform(weightedEntropyGetter);
|
|
||||||
return std::accumulate(view.begin(), view.end(), 0);
|
|
||||||
};
|
|
||||||
|
|
||||||
auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(accumulatedWeightedEntropyGetter));
|
|
||||||
|
|
||||||
VariableIDT availableValues = wave.Entropy(minEntropyCell);
|
|
||||||
MaskType mask = wave.GetMask(minEntropyCell);
|
|
||||||
for (size_t i = 0; i < availableValues; ++i)
|
|
||||||
{
|
|
||||||
VariableIDT index = static_cast<VariableIDT>(std::countr_zero(mask)); // get the index of the lowest set bit
|
|
||||||
constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds");
|
|
||||||
|
|
||||||
buffer[i] = index;
|
|
||||||
constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set");
|
|
||||||
|
|
||||||
weights[i] = wave.template GetWeight<ElementsMaxWeight>(index);
|
|
||||||
|
|
||||||
mask = mask & (mask - 1); // turn off lowest set bit
|
|
||||||
}
|
|
||||||
|
|
||||||
return minEntropyCell;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool Branch(SolverState& state, WaveType& wave)
|
static bool Branch(SolverState& state, WaveType& wave)
|
||||||
{
|
{
|
||||||
constexpr_assert(state.m_propagationQueue.empty());
|
constexpr_assert(state.m_propagationQueue.empty());
|
||||||
|
|
||||||
std::array<VariableIDT, VariableIDMapT::size()> Buffer{};
|
std::array<VariableIDT, VariableIDMapT::size()> Buffer{};
|
||||||
std::array<WeightsType, VariableIDMapT::size()> WeightsBuffer{};
|
|
||||||
uint64_t accumulatedWeights = 0;
|
|
||||||
WorldSizeT minEntropyCell{};
|
WorldSizeT minEntropyCell{};
|
||||||
|
|
||||||
if constexpr (WeightsMapT::HasWeights())
|
minEntropyCell = FindMinimumEntropyCells(Buffer, wave);
|
||||||
{
|
|
||||||
minEntropyCell = FindMinimumEntropyCellsWeighted(Buffer, WeightsBuffer, wave);
|
|
||||||
accumulatedWeights = std::accumulate(WeightsBuffer.begin(), WeightsBuffer.end(), 0);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
minEntropyCell = FindMinimumEntropyCells(Buffer, wave);
|
|
||||||
}
|
|
||||||
|
|
||||||
// randomly select a value from possible values
|
// randomly select a value from possible values
|
||||||
while (Buffer.size())
|
while (Buffer.size())
|
||||||
{
|
{
|
||||||
size_t randomIndex;
|
size_t randomIndex;
|
||||||
VariableIDT selectedValue;
|
VariableIDT selectedValue;
|
||||||
if constexpr (WeightsMapT::HasWeights())
|
|
||||||
{
|
randomIndex = state.m_randomSelector.rng(Buffer.size());
|
||||||
auto randomWeight = state.m_randomSelector.rng(accumulatedWeights);
|
selectedValue = Buffer[randomIndex];
|
||||||
for (size_t i = 0; i < WeightsBuffer.size(); ++i)
|
|
||||||
{
|
|
||||||
if (randomWeight < WeightsBuffer[i])
|
|
||||||
{
|
|
||||||
randomIndex = i;
|
|
||||||
selectedValue = Buffer[i];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
randomWeight -= WeightsBuffer[i];
|
|
||||||
}
|
|
||||||
accumulatedWeights -= WeightsBuffer[randomIndex];
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
randomIndex = state.m_randomSelector.rng(Buffer.size());
|
|
||||||
selectedValue = Buffer[randomIndex];
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
// copy the state and branch out
|
// copy the state and branch out
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ namespace WFC {
|
|||||||
#include "wfc_constrainer.hpp"
|
#include "wfc_constrainer.hpp"
|
||||||
#include "wfc_callbacks.hpp"
|
#include "wfc_callbacks.hpp"
|
||||||
#include "wfc_random.hpp"
|
#include "wfc_random.hpp"
|
||||||
#include "wfc_weights.hpp"
|
|
||||||
#include "wfc.hpp"
|
#include "wfc.hpp"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -20,33 +19,32 @@ template<
|
|||||||
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
||||||
typename CallbacksT = Callbacks<WorldT>,
|
typename CallbacksT = Callbacks<WorldT>,
|
||||||
typename RandomSelectorT = DefaultRandomSelector<VarT>,
|
typename RandomSelectorT = DefaultRandomSelector<VarT>,
|
||||||
typename WeightsMapT = WeightsMap<VariableIDMapT>,
|
|
||||||
typename SelectedValueT = void>
|
typename SelectedValueT = void>
|
||||||
class Builder {
|
class Builder {
|
||||||
public:
|
public:
|
||||||
using WorldSizeT = decltype(WorldT{}.size());
|
using WorldSizeT = decltype(WorldT{}.size());
|
||||||
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
||||||
|
|
||||||
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
|
using WaveType = Wave<VariableIDMapT, WorldSize>;
|
||||||
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
|
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
|
||||||
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
|
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
|
||||||
|
|
||||||
|
|
||||||
template <VarT ... Values>
|
template <VarT ... Values>
|
||||||
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
|
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
|
||||||
|
|
||||||
template <size_t RangeStart, size_t RangeEnd>
|
template <size_t RangeStart, size_t RangeEnd>
|
||||||
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
||||||
|
|
||||||
template <size_t RangeEnd>
|
template <size_t RangeEnd>
|
||||||
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, 0, RangeEnd>>;
|
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, 0, RangeEnd>>;
|
||||||
|
|
||||||
|
|
||||||
template <VarT ... Values>
|
template <VarT ... Values>
|
||||||
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
|
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
|
||||||
|
|
||||||
template <size_t RangeStart, size_t RangeEnd>
|
template <size_t RangeStart, size_t RangeEnd>
|
||||||
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
||||||
|
|
||||||
|
|
||||||
using EmptyConstrainerFunctionT = EmptyConstrainerFunction<WorldT, WorldSizeT, VarT, ConstrainerType>;
|
using EmptyConstrainerFunctionT = EmptyConstrainerFunction<WorldT, WorldSizeT, VarT, ConstrainerType>;
|
||||||
@@ -60,7 +58,7 @@ public:
|
|||||||
ConstrainerFunctionT,
|
ConstrainerFunctionT,
|
||||||
SelectedValueT,
|
SelectedValueT,
|
||||||
EmptyConstrainerFunctionT
|
EmptyConstrainerFunctionT
|
||||||
>, CallbacksT, RandomSelectorT, WeightsMapT, SelectedValueT
|
>, CallbacksT, RandomSelectorT, SelectedValueT
|
||||||
>;
|
>;
|
||||||
|
|
||||||
template <typename ConstrainerFunctionT>
|
template <typename ConstrainerFunctionT>
|
||||||
@@ -72,27 +70,22 @@ public:
|
|||||||
ConstrainerFunctionT,
|
ConstrainerFunctionT,
|
||||||
VariableIDMapT,
|
VariableIDMapT,
|
||||||
EmptyConstrainerFunctionT
|
EmptyConstrainerFunctionT
|
||||||
>, CallbacksT, RandomSelectorT, WeightsMapT
|
>, CallbacksT, RandomSelectorT
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
|
||||||
template <typename NewCellCollapsedCallbackT>
|
template <typename NewCellCollapsedCallbackT>
|
||||||
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT, WeightsMapT>;
|
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT>;
|
||||||
template <typename NewContradictionCallbackT>
|
template <typename NewContradictionCallbackT>
|
||||||
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT, WeightsMapT>;
|
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT>;
|
||||||
template <typename NewBranchCallbackT>
|
template <typename NewBranchCallbackT>
|
||||||
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT, WeightsMapT>;
|
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT>;
|
||||||
|
|
||||||
|
|
||||||
template <typename NewRandomSelectorT>
|
template <typename NewRandomSelectorT>
|
||||||
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT, WeightsMapT>;
|
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
|
||||||
|
|
||||||
|
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||||
template <EPrecision Precision>
|
|
||||||
using SetWeights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, typename WeightsMapT::template Merge<PrecisionEntry<SelectedValueT, static_cast<uint8_t>(Precision)>>, SelectedValueT>;
|
|
||||||
|
|
||||||
|
|
||||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT>;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -6,15 +6,13 @@
|
|||||||
|
|
||||||
namespace WFC {
|
namespace WFC {
|
||||||
|
|
||||||
template <typename VariableIDMapT, typename WeightsMapT, size_t Size = 0>
|
template <typename VariableIDMapT, size_t Size = 0>
|
||||||
class Wave {
|
class Wave {
|
||||||
public:
|
public:
|
||||||
using BitContainerT = BitContainer<VariableIDMapT::size(), Size>;
|
using BitContainerT = BitContainer<VariableIDMapT::size(), Size>;
|
||||||
using ElementT = typename BitContainerT::StorageType;
|
using ElementT = typename BitContainerT::StorageType;
|
||||||
using IDMapT = VariableIDMapT;
|
using IDMapT = VariableIDMapT;
|
||||||
using WeightContainersT = typename WeightsMapT::template WeightContainersT<Size, WFCStackAllocator>;
|
|
||||||
using VariableIDT = typename VariableIDMapT::VariableIDT;
|
using VariableIDT = typename VariableIDMapT::VariableIDT;
|
||||||
using WeightT = typename WeightsMapT::WeightT;
|
|
||||||
|
|
||||||
static constexpr size_t ElementsAmount = Size;
|
static constexpr size_t ElementsAmount = Size;
|
||||||
|
|
||||||
@@ -38,15 +36,8 @@ public:
|
|||||||
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
|
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
|
||||||
ElementT GetMask(size_t index) const { return m_data[index]; }
|
ElementT GetMask(size_t index) const { return m_data[index]; }
|
||||||
|
|
||||||
void SetWeight(VariableIDT containerIndex, size_t elementIndex, double weight) { m_weights.SetValueFloat(containerIndex, elementIndex, weight); }
|
|
||||||
|
|
||||||
template <size_t MaxWeight>
|
|
||||||
WeightT GetWeight(VariableIDT containerIndex, size_t elementIndex) const { return m_weights.template GetValue<MaxWeight>(containerIndex, elementIndex); }
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BitContainerT m_data;
|
BitContainerT m_data;
|
||||||
WeightContainersT m_weights;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,234 +1,234 @@
|
|||||||
#pragma once
|
// #pragma once
|
||||||
|
|
||||||
#include <array>
|
// #include <array>
|
||||||
#include <span>
|
// #include <span>
|
||||||
#include <tuple>
|
// #include <tuple>
|
||||||
|
|
||||||
#include "wfc_bit_container.hpp"
|
// #include "wfc_bit_container.hpp"
|
||||||
#include "wfc_utils.hpp"
|
// #include "wfc_utils.hpp"
|
||||||
|
|
||||||
namespace WFC {
|
// namespace WFC {
|
||||||
|
|
||||||
template <typename VariableMap, uint8_t Precision>
|
// template <typename VariableMap, uint8_t Precision>
|
||||||
struct PrecisionEntry
|
// struct PrecisionEntry
|
||||||
{
|
// {
|
||||||
|
|
||||||
constexpr static uint8_t PrecisionValue = Precision;
|
// constexpr static uint8_t PrecisionValue = Precision;
|
||||||
|
|
||||||
template <typename MainVariableMap>
|
// template <typename MainVariableMap>
|
||||||
constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions)
|
// constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions)
|
||||||
{
|
// {
|
||||||
constexpr auto SelectedEntries = VariableMap::GetAllValues();
|
// constexpr auto SelectedEntries = VariableMap::GetAllValues();
|
||||||
for (auto entry : SelectedEntries)
|
// for (auto entry : SelectedEntries)
|
||||||
{
|
// {
|
||||||
precisions[MainVariableMap::GetIndex(entry)] = Precision;
|
// precisions[MainVariableMap::GetIndex(entry)] = Precision;
|
||||||
}
|
// }
|
||||||
return true;
|
// return true;
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
enum class EPrecision : uint8_t
|
// enum class EPrecision : uint8_t
|
||||||
{
|
// {
|
||||||
Precision_0 = 0,
|
// Precision_0 = 0,
|
||||||
Precision_2 = 2,
|
// Precision_2 = 2,
|
||||||
Precision_4 = 4,
|
// Precision_4 = 4,
|
||||||
Precision_8 = 8,
|
// Precision_8 = 8,
|
||||||
Precision_16 = 16,
|
// Precision_16 = 16,
|
||||||
Precision_32 = 32,
|
// Precision_32 = 32,
|
||||||
Precision_64 = 64,
|
// Precision_64 = 64,
|
||||||
};
|
// };
|
||||||
|
|
||||||
template <size_t Size, typename AllocatorT, EPrecision ... Precisions>
|
// template <size_t Size, typename AllocatorT, EPrecision ... Precisions>
|
||||||
class WeightContainers
|
// class WeightContainers
|
||||||
{
|
// {
|
||||||
private:
|
// private:
|
||||||
template <EPrecision Precision>
|
// template <EPrecision Precision>
|
||||||
using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>;
|
// using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>;
|
||||||
|
|
||||||
using TupleT = std::tuple<BitContainerT<Precisions>...>;
|
// using TupleT = std::tuple<BitContainerT<Precisions>...>;
|
||||||
TupleT m_WeightContainers;
|
// TupleT m_WeightContainers;
|
||||||
|
|
||||||
static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)");
|
// static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)");
|
||||||
|
|
||||||
public:
|
// public:
|
||||||
WeightContainers() = default;
|
// WeightContainers() = default;
|
||||||
WeightContainers(size_t size)
|
// WeightContainers(size_t size)
|
||||||
: m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... }
|
// : m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... }
|
||||||
{}
|
// {}
|
||||||
WeightContainers(size_t size, AllocatorT& allocator)
|
// WeightContainers(size_t size, AllocatorT& allocator)
|
||||||
: m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... }
|
// : m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... }
|
||||||
{}
|
// {}
|
||||||
|
|
||||||
public:
|
// public:
|
||||||
static constexpr size_t size()
|
// static constexpr size_t size()
|
||||||
{
|
// {
|
||||||
return sizeof...(Precisions);
|
// return sizeof...(Precisions);
|
||||||
}
|
// }
|
||||||
|
|
||||||
/*
|
// /*
|
||||||
template <typename ValueT>
|
// template <typename ValueT>
|
||||||
void SetValue(size_t containerIndex, size_t index, ValueT value)
|
// void SetValue(size_t containerIndex, size_t index, ValueT value)
|
||||||
{
|
// {
|
||||||
SetValueFunctions<ValueT>()[containerIndex](*this, index, value);
|
// SetValueFunctions<ValueT>()[containerIndex](*this, index, value);
|
||||||
}
|
// }
|
||||||
*/
|
// */
|
||||||
void SetValueFloat(size_t containerIndex, size_t index, double value)
|
// void SetValueFloat(size_t containerIndex, size_t index, double value)
|
||||||
{
|
// {
|
||||||
SetFloatValueFunctions()[containerIndex](*this, index, value);
|
// SetFloatValueFunctions()[containerIndex](*this, index, value);
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <size_t MaxWeight>
|
// template <size_t MaxWeight>
|
||||||
uint64_t GetValue(size_t containerIndex, size_t index)
|
// uint64_t GetValue(size_t containerIndex, size_t index)
|
||||||
{
|
// {
|
||||||
return GetValueFunctions<MaxWeight>()[containerIndex](*this, index);
|
// return GetValueFunctions<MaxWeight>()[containerIndex](*this, index);
|
||||||
}
|
// }
|
||||||
|
|
||||||
private:
|
// private:
|
||||||
/*
|
// /*
|
||||||
template <typename ValueT>
|
// template <typename ValueT>
|
||||||
static constexpr auto& SetValueFunctions()
|
// static constexpr auto& SetValueFunctions()
|
||||||
{
|
// {
|
||||||
return SetValueFunctions<ValueT>(std::make_index_sequence<size()>());
|
// return SetValueFunctions<ValueT>(std::make_index_sequence<size()>());
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <typename ValueT, size_t ... Is>
|
// template <typename ValueT, size_t ... Is>
|
||||||
static constexpr auto& SetValueFunctions(std::index_sequence<Is...>)
|
// static constexpr auto& SetValueFunctions(std::index_sequence<Is...>)
|
||||||
{
|
// {
|
||||||
static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions =
|
// static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions =
|
||||||
{
|
// {
|
||||||
[] (WeightContainers& weightContainers, size_t index, ValueT value) {
|
// [] (WeightContainers& weightContainers, size_t index, ValueT value) {
|
||||||
std::get<Is>(weightContainers.m_WeightContainers)[index] = value;
|
// std::get<Is>(weightContainers.m_WeightContainers)[index] = value;
|
||||||
},
|
// },
|
||||||
...
|
// ...
|
||||||
};
|
// };
|
||||||
return setValueFunctions;
|
// return setValueFunctions;
|
||||||
}
|
// }
|
||||||
*/
|
// */
|
||||||
static constexpr auto& SetFloatValueFunctions()
|
// static constexpr auto& SetFloatValueFunctions()
|
||||||
{
|
// {
|
||||||
return SetFloatValueFunctions(std::make_index_sequence<size()>());
|
// return SetFloatValueFunctions(std::make_index_sequence<size()>());
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <size_t ... Is>
|
// template <size_t ... Is>
|
||||||
static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>)
|
// static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>)
|
||||||
{
|
// {
|
||||||
using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value);
|
// using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value);
|
||||||
constexpr std::array<FunctionT, size()> setFloatValueFunctions
|
// constexpr std::array<FunctionT, size()> setFloatValueFunctions
|
||||||
{
|
// {
|
||||||
[](WeightContainers& weightContainers, size_t index, double value) -> FunctionT {
|
// [](WeightContainers& weightContainers, size_t index, double value) -> FunctionT {
|
||||||
|
|
||||||
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
|
// using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
|
||||||
if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
|
// if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
|
||||||
{
|
// {
|
||||||
constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0");
|
// constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0");
|
||||||
std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue);
|
// std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
...
|
// ...
|
||||||
};
|
// };
|
||||||
return setFloatValueFunctions;
|
// return setFloatValueFunctions;
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <size_t MaxWeight>
|
// template <size_t MaxWeight>
|
||||||
static constexpr auto& GetValueFunctions()
|
// static constexpr auto& GetValueFunctions()
|
||||||
{
|
// {
|
||||||
return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>());
|
// return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>());
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <size_t MaxWeight, size_t ... Is>
|
// template <size_t MaxWeight, size_t ... Is>
|
||||||
static constexpr auto& GetValueFunctions(std::index_sequence<Is...>)
|
// static constexpr auto& GetValueFunctions(std::index_sequence<Is...>)
|
||||||
{
|
// {
|
||||||
using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index);
|
// using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index);
|
||||||
constexpr std::array<FunctionT, size()> getValueFunctions =
|
// constexpr std::array<FunctionT, size()> getValueFunctions =
|
||||||
{
|
// {
|
||||||
[] (WeightContainers& weightContainers, size_t index) -> FunctionT {
|
// [] (WeightContainers& weightContainers, size_t index) -> FunctionT {
|
||||||
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
|
// using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
|
||||||
|
|
||||||
if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
|
// if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
|
||||||
{
|
// {
|
||||||
return MaxWeight / 2;
|
// return MaxWeight / 2;
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
{
|
// {
|
||||||
constexpr size_t maxValue = BitContainerEntryT::MaxValue;
|
// constexpr size_t maxValue = BitContainerEntryT::MaxValue;
|
||||||
if constexpr (maxValue <= MaxWeight)
|
// if constexpr (maxValue <= MaxWeight)
|
||||||
{
|
// {
|
||||||
return std::get<Is>(weightContainers.m_WeightContainers)[index];
|
// return std::get<Is>(weightContainers.m_WeightContainers)[index];
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
{
|
// {
|
||||||
return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue;
|
// return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
...
|
// ...
|
||||||
};
|
// };
|
||||||
return getValueFunctions;
|
// return getValueFunctions;
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
/**
|
// /**
|
||||||
* @brief Compile-time weights storage for weighted random selection
|
// * @brief Compile-time weights storage for weighted random selection
|
||||||
* @tparam VarT The variable type
|
// * @tparam VarT The variable type
|
||||||
* @tparam VariableIDMapT The variable ID map type
|
// * @tparam VariableIDMapT The variable ID map type
|
||||||
* @tparam DefaultWeight The default weight for values not explicitly specified
|
// * @tparam DefaultWeight The default weight for values not explicitly specified
|
||||||
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
// * @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
||||||
*/
|
// */
|
||||||
template <typename VariableIDMapT, typename ... PrecisionEntries>
|
// template <typename VariableIDMapT, typename ... PrecisionEntries>
|
||||||
class WeightsMap {
|
// class WeightsMap {
|
||||||
public:
|
// public:
|
||||||
static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray()
|
// static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray()
|
||||||
{
|
// {
|
||||||
std::array<uint8_t, VariableIDMapT::size()> precisionArray{};
|
// std::array<uint8_t, VariableIDMapT::size()> precisionArray{};
|
||||||
|
|
||||||
(PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...);
|
// (PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...);
|
||||||
|
|
||||||
return precisionArray;
|
// return precisionArray;
|
||||||
}
|
// }
|
||||||
|
|
||||||
static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray()
|
// static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray()
|
||||||
{
|
// {
|
||||||
constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray();
|
// constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray();
|
||||||
return precisionArray;
|
// return precisionArray;
|
||||||
}
|
// }
|
||||||
|
|
||||||
static constexpr size_t GetPrecision(size_t index)
|
// static constexpr size_t GetPrecision(size_t index)
|
||||||
{
|
// {
|
||||||
return GetPrecisionArray()[index];
|
// return GetPrecisionArray()[index];
|
||||||
}
|
// }
|
||||||
|
|
||||||
static constexpr uint8_t GetMaxPrecision()
|
// static constexpr uint8_t GetMaxPrecision()
|
||||||
{
|
// {
|
||||||
return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...});
|
// return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...});
|
||||||
}
|
// }
|
||||||
|
|
||||||
static constexpr uint8_t GetMaxValue()
|
// static constexpr uint8_t GetMaxValue()
|
||||||
{
|
// {
|
||||||
return (1 << GetMaxPrecision()) - 1;
|
// return (1 << GetMaxPrecision()) - 1;
|
||||||
}
|
// }
|
||||||
|
|
||||||
static constexpr bool HasWeights()
|
// static constexpr bool HasWeights()
|
||||||
{
|
// {
|
||||||
return sizeof...(PrecisionEntries) > 0;
|
// return sizeof...(PrecisionEntries) > 0;
|
||||||
}
|
// }
|
||||||
|
|
||||||
public:
|
// public:
|
||||||
|
|
||||||
using VariablesT = VariableIDMapT;
|
// using VariablesT = VariableIDMapT;
|
||||||
|
|
||||||
template<size_t Size, typename AllocatorT, size_t... Is>
|
// template<size_t Size, typename AllocatorT, size_t... Is>
|
||||||
auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>)
|
// auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>)
|
||||||
-> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>;
|
// -> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>;
|
||||||
|
|
||||||
template <size_t Size, typename AllocatorT>
|
// template <size_t Size, typename AllocatorT>
|
||||||
using WeightContainersT = decltype(
|
// using WeightContainersT = decltype(
|
||||||
MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{})
|
// MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{})
|
||||||
);
|
// );
|
||||||
|
|
||||||
template <typename PrecisionEntryT>
|
// template <typename PrecisionEntryT>
|
||||||
using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>;
|
// using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>;
|
||||||
|
|
||||||
using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType;
|
// using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType;
|
||||||
};
|
// };
|
||||||
|
|
||||||
}
|
// }
|
||||||
Reference in New Issue
Block a user