removed weights

This commit is contained in:
Connor
2026-02-06 12:07:55 +09:00
parent 94cd003e96
commit ded7ebc285
4 changed files with 217 additions and 297 deletions

View File

@@ -23,7 +23,6 @@
#include "wfc_callbacks.hpp"
#include "wfc_random.hpp"
#include "wfc_queue.hpp"
#include "wfc_weights.hpp"
namespace WFC {
@@ -50,7 +49,6 @@ template<typename WorldT, typename VarT,
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>,
typename RandomSelectorT = DefaultRandomSelector<VarT>,
typename WeightsMapT = WeightsMap<VariableIDMapT>
>
class WFC {
public:
@@ -61,12 +59,11 @@ public:
// Try getting the world size, which is only available if the world type has a constexpr size() method
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
using WaveType = Wave<VariableIDMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
using MaskType = typename WaveType::ElementT;
using VariableIDT = typename WaveType::VariableIDT;
using WeightsBufferType = BitContainer<WeightsMapT::GetMaxPrecision(), VariableIDMapT::size()>;
public:
struct SolverState
@@ -217,86 +214,25 @@ private:
return minEntropyCell;
}
using WeightsType = MinimumIntegerType<WeightsMapT::GetMaxValue()>;
using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast<uint32_t>(1)));
static WorldSizeT FindMinimumEntropyCellsWeighted(std::span<VariableIDT>& buffer, std::span<WeightsType>& weights, WaveType& wave)
{
constexpr size_t ElementsMaxWeight = std::min(WeightsMapT::GetMaxValue() * VariableIDMapT::size(), std::numeric_limits<RandomGeneratorReturnType>::max()) / VariableIDMapT::size();
auto accumulatedWeightedEntropyGetter = [&wave](size_t index) -> uint64_t
{
auto entropyFilter = [&wave](size_t index) -> bool { return wave.Entropy(index) > 1; };
auto weightedEntropyGetter = [&wave](size_t index) -> uint64_t { return wave.template GetWeight<ElementsMaxWeight>(index); };
auto view = std::views::iota(0, VariableIDMapT::size()) | std::views::filter(entropyFilter) | std::views::transform(weightedEntropyGetter);
return std::accumulate(view.begin(), view.end(), 0);
};
auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(accumulatedWeightedEntropyGetter));
VariableIDT availableValues = wave.Entropy(minEntropyCell);
MaskType mask = wave.GetMask(minEntropyCell);
for (size_t i = 0; i < availableValues; ++i)
{
VariableIDT index = static_cast<VariableIDT>(std::countr_zero(mask)); // get the index of the lowest set bit
constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds");
buffer[i] = index;
constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set");
weights[i] = wave.template GetWeight<ElementsMaxWeight>(index);
mask = mask & (mask - 1); // turn off lowest set bit
}
return minEntropyCell;
}
static bool Branch(SolverState& state, WaveType& wave)
{
constexpr_assert(state.m_propagationQueue.empty());
std::array<VariableIDT, VariableIDMapT::size()> Buffer{};
std::array<WeightsType, VariableIDMapT::size()> WeightsBuffer{};
uint64_t accumulatedWeights = 0;
WorldSizeT minEntropyCell{};
if constexpr (WeightsMapT::HasWeights())
{
minEntropyCell = FindMinimumEntropyCellsWeighted(Buffer, WeightsBuffer, wave);
accumulatedWeights = std::accumulate(WeightsBuffer.begin(), WeightsBuffer.end(), 0);
}
else
{
minEntropyCell = FindMinimumEntropyCells(Buffer, wave);
}
// randomly select a value from possible values
while (Buffer.size())
{
size_t randomIndex;
VariableIDT selectedValue;
if constexpr (WeightsMapT::HasWeights())
{
auto randomWeight = state.m_randomSelector.rng(accumulatedWeights);
for (size_t i = 0; i < WeightsBuffer.size(); ++i)
{
if (randomWeight < WeightsBuffer[i])
{
randomIndex = i;
selectedValue = Buffer[i];
break;
}
randomWeight -= WeightsBuffer[i];
}
accumulatedWeights -= WeightsBuffer[randomIndex];
}
else
{
randomIndex = state.m_randomSelector.rng(Buffer.size());
selectedValue = Buffer[randomIndex];
}
{
// copy the state and branch out

View File

@@ -7,7 +7,6 @@ namespace WFC {
#include "wfc_constrainer.hpp"
#include "wfc_callbacks.hpp"
#include "wfc_random.hpp"
#include "wfc_weights.hpp"
#include "wfc.hpp"
/**
@@ -20,33 +19,32 @@ template<
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>,
typename RandomSelectorT = DefaultRandomSelector<VarT>,
typename WeightsMapT = WeightsMap<VariableIDMapT>,
typename SelectedValueT = void>
class Builder {
public:
using WorldSizeT = decltype(WorldT{}.size());
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
using WaveType = Wave<VariableIDMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
template <VarT ... Values>
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
template <size_t RangeStart, size_t RangeEnd>
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
template <size_t RangeEnd>
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, 0, RangeEnd>>;
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, 0, RangeEnd>>;
template <VarT ... Values>
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
template <size_t RangeStart, size_t RangeEnd>
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
using EmptyConstrainerFunctionT = EmptyConstrainerFunction<WorldT, WorldSizeT, VarT, ConstrainerType>;
@@ -60,7 +58,7 @@ public:
ConstrainerFunctionT,
SelectedValueT,
EmptyConstrainerFunctionT
>, CallbacksT, RandomSelectorT, WeightsMapT, SelectedValueT
>, CallbacksT, RandomSelectorT, SelectedValueT
>;
template <typename ConstrainerFunctionT>
@@ -72,27 +70,22 @@ public:
ConstrainerFunctionT,
VariableIDMapT,
EmptyConstrainerFunctionT
>, CallbacksT, RandomSelectorT, WeightsMapT
>, CallbacksT, RandomSelectorT
>;
template <typename NewCellCollapsedCallbackT>
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT, WeightsMapT>;
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT>;
template <typename NewContradictionCallbackT>
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT, WeightsMapT>;
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT>;
template <typename NewBranchCallbackT>
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT, WeightsMapT>;
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT>;
template <typename NewRandomSelectorT>
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT, WeightsMapT>;
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
template <EPrecision Precision>
using SetWeights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, typename WeightsMapT::template Merge<PrecisionEntry<SelectedValueT, static_cast<uint8_t>(Precision)>>, SelectedValueT>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
};
}

View File

@@ -6,15 +6,13 @@
namespace WFC {
template <typename VariableIDMapT, typename WeightsMapT, size_t Size = 0>
template <typename VariableIDMapT, size_t Size = 0>
class Wave {
public:
using BitContainerT = BitContainer<VariableIDMapT::size(), Size>;
using ElementT = typename BitContainerT::StorageType;
using IDMapT = VariableIDMapT;
using WeightContainersT = typename WeightsMapT::template WeightContainersT<Size, WFCStackAllocator>;
using VariableIDT = typename VariableIDMapT::VariableIDT;
using WeightT = typename WeightsMapT::WeightT;
static constexpr size_t ElementsAmount = Size;
@@ -38,15 +36,8 @@ public:
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
ElementT GetMask(size_t index) const { return m_data[index]; }
void SetWeight(VariableIDT containerIndex, size_t elementIndex, double weight) { m_weights.SetValueFloat(containerIndex, elementIndex, weight); }
template <size_t MaxWeight>
WeightT GetWeight(VariableIDT containerIndex, size_t elementIndex) const { return m_weights.template GetValue<MaxWeight>(containerIndex, elementIndex); }
private:
BitContainerT m_data;
WeightContainersT m_weights;
};
}

View File

@@ -1,234 +1,234 @@
#pragma once
// #pragma once
#include <array>
#include <span>
#include <tuple>
// #include <array>
// #include <span>
// #include <tuple>
#include "wfc_bit_container.hpp"
#include "wfc_utils.hpp"
// #include "wfc_bit_container.hpp"
// #include "wfc_utils.hpp"
namespace WFC {
// namespace WFC {
template <typename VariableMap, uint8_t Precision>
struct PrecisionEntry
{
// template <typename VariableMap, uint8_t Precision>
// struct PrecisionEntry
// {
constexpr static uint8_t PrecisionValue = Precision;
// constexpr static uint8_t PrecisionValue = Precision;
template <typename MainVariableMap>
constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions)
{
constexpr auto SelectedEntries = VariableMap::GetAllValues();
for (auto entry : SelectedEntries)
{
precisions[MainVariableMap::GetIndex(entry)] = Precision;
}
return true;
}
};
// template <typename MainVariableMap>
// constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions)
// {
// constexpr auto SelectedEntries = VariableMap::GetAllValues();
// for (auto entry : SelectedEntries)
// {
// precisions[MainVariableMap::GetIndex(entry)] = Precision;
// }
// return true;
// }
// };
enum class EPrecision : uint8_t
{
Precision_0 = 0,
Precision_2 = 2,
Precision_4 = 4,
Precision_8 = 8,
Precision_16 = 16,
Precision_32 = 32,
Precision_64 = 64,
};
// enum class EPrecision : uint8_t
// {
// Precision_0 = 0,
// Precision_2 = 2,
// Precision_4 = 4,
// Precision_8 = 8,
// Precision_16 = 16,
// Precision_32 = 32,
// Precision_64 = 64,
// };
template <size_t Size, typename AllocatorT, EPrecision ... Precisions>
class WeightContainers
{
private:
template <EPrecision Precision>
using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>;
// template <size_t Size, typename AllocatorT, EPrecision ... Precisions>
// class WeightContainers
// {
// private:
// template <EPrecision Precision>
// using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>;
using TupleT = std::tuple<BitContainerT<Precisions>...>;
TupleT m_WeightContainers;
// using TupleT = std::tuple<BitContainerT<Precisions>...>;
// TupleT m_WeightContainers;
static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)");
// static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)");
public:
WeightContainers() = default;
WeightContainers(size_t size)
: m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... }
{}
WeightContainers(size_t size, AllocatorT& allocator)
: m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... }
{}
// public:
// WeightContainers() = default;
// WeightContainers(size_t size)
// : m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... }
// {}
// WeightContainers(size_t size, AllocatorT& allocator)
// : m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... }
// {}
public:
static constexpr size_t size()
{
return sizeof...(Precisions);
}
// public:
// static constexpr size_t size()
// {
// return sizeof...(Precisions);
// }
/*
template <typename ValueT>
void SetValue(size_t containerIndex, size_t index, ValueT value)
{
SetValueFunctions<ValueT>()[containerIndex](*this, index, value);
}
*/
void SetValueFloat(size_t containerIndex, size_t index, double value)
{
SetFloatValueFunctions()[containerIndex](*this, index, value);
}
// /*
// template <typename ValueT>
// void SetValue(size_t containerIndex, size_t index, ValueT value)
// {
// SetValueFunctions<ValueT>()[containerIndex](*this, index, value);
// }
// */
// void SetValueFloat(size_t containerIndex, size_t index, double value)
// {
// SetFloatValueFunctions()[containerIndex](*this, index, value);
// }
template <size_t MaxWeight>
uint64_t GetValue(size_t containerIndex, size_t index)
{
return GetValueFunctions<MaxWeight>()[containerIndex](*this, index);
}
// template <size_t MaxWeight>
// uint64_t GetValue(size_t containerIndex, size_t index)
// {
// return GetValueFunctions<MaxWeight>()[containerIndex](*this, index);
// }
private:
/*
template <typename ValueT>
static constexpr auto& SetValueFunctions()
{
return SetValueFunctions<ValueT>(std::make_index_sequence<size()>());
}
// private:
// /*
// template <typename ValueT>
// static constexpr auto& SetValueFunctions()
// {
// return SetValueFunctions<ValueT>(std::make_index_sequence<size()>());
// }
template <typename ValueT, size_t ... Is>
static constexpr auto& SetValueFunctions(std::index_sequence<Is...>)
{
static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions =
{
[] (WeightContainers& weightContainers, size_t index, ValueT value) {
std::get<Is>(weightContainers.m_WeightContainers)[index] = value;
},
...
};
return setValueFunctions;
}
*/
static constexpr auto& SetFloatValueFunctions()
{
return SetFloatValueFunctions(std::make_index_sequence<size()>());
}
// template <typename ValueT, size_t ... Is>
// static constexpr auto& SetValueFunctions(std::index_sequence<Is...>)
// {
// static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions =
// {
// [] (WeightContainers& weightContainers, size_t index, ValueT value) {
// std::get<Is>(weightContainers.m_WeightContainers)[index] = value;
// },
// ...
// };
// return setValueFunctions;
// }
// */
// static constexpr auto& SetFloatValueFunctions()
// {
// return SetFloatValueFunctions(std::make_index_sequence<size()>());
// }
template <size_t ... Is>
static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>)
{
using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value);
constexpr std::array<FunctionT, size()> setFloatValueFunctions
{
[](WeightContainers& weightContainers, size_t index, double value) -> FunctionT {
// template <size_t ... Is>
// static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>)
// {
// using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value);
// constexpr std::array<FunctionT, size()> setFloatValueFunctions
// {
// [](WeightContainers& weightContainers, size_t index, double value) -> FunctionT {
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
{
constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0");
std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue);
}
}
...
};
return setFloatValueFunctions;
}
// using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
// if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
// {
// constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0");
// std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue);
// }
// }
// ...
// };
// return setFloatValueFunctions;
// }
template <size_t MaxWeight>
static constexpr auto& GetValueFunctions()
{
return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>());
}
// template <size_t MaxWeight>
// static constexpr auto& GetValueFunctions()
// {
// return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>());
// }
template <size_t MaxWeight, size_t ... Is>
static constexpr auto& GetValueFunctions(std::index_sequence<Is...>)
{
using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index);
constexpr std::array<FunctionT, size()> getValueFunctions =
{
[] (WeightContainers& weightContainers, size_t index) -> FunctionT {
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
// template <size_t MaxWeight, size_t ... Is>
// static constexpr auto& GetValueFunctions(std::index_sequence<Is...>)
// {
// using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index);
// constexpr std::array<FunctionT, size()> getValueFunctions =
// {
// [] (WeightContainers& weightContainers, size_t index) -> FunctionT {
// using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
{
return MaxWeight / 2;
}
else
{
constexpr size_t maxValue = BitContainerEntryT::MaxValue;
if constexpr (maxValue <= MaxWeight)
{
return std::get<Is>(weightContainers.m_WeightContainers)[index];
}
else
{
return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue;
}
}
}
...
};
return getValueFunctions;
}
};
// if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
// {
// return MaxWeight / 2;
// }
// else
// {
// constexpr size_t maxValue = BitContainerEntryT::MaxValue;
// if constexpr (maxValue <= MaxWeight)
// {
// return std::get<Is>(weightContainers.m_WeightContainers)[index];
// }
// else
// {
// return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue;
// }
// }
// }
// ...
// };
// return getValueFunctions;
// }
// };
/**
* @brief Compile-time weights storage for weighted random selection
* @tparam VarT The variable type
* @tparam VariableIDMapT The variable ID map type
* @tparam DefaultWeight The default weight for values not explicitly specified
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
*/
template <typename VariableIDMapT, typename ... PrecisionEntries>
class WeightsMap {
public:
static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray()
{
std::array<uint8_t, VariableIDMapT::size()> precisionArray{};
// /**
// * @brief Compile-time weights storage for weighted random selection
// * @tparam VarT The variable type
// * @tparam VariableIDMapT The variable ID map type
// * @tparam DefaultWeight The default weight for values not explicitly specified
// * @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
// */
// template <typename VariableIDMapT, typename ... PrecisionEntries>
// class WeightsMap {
// public:
// static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray()
// {
// std::array<uint8_t, VariableIDMapT::size()> precisionArray{};
(PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...);
// (PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...);
return precisionArray;
}
// return precisionArray;
// }
static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray()
{
constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray();
return precisionArray;
}
// static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray()
// {
// constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray();
// return precisionArray;
// }
static constexpr size_t GetPrecision(size_t index)
{
return GetPrecisionArray()[index];
}
// static constexpr size_t GetPrecision(size_t index)
// {
// return GetPrecisionArray()[index];
// }
static constexpr uint8_t GetMaxPrecision()
{
return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...});
}
// static constexpr uint8_t GetMaxPrecision()
// {
// return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...});
// }
static constexpr uint8_t GetMaxValue()
{
return (1 << GetMaxPrecision()) - 1;
}
// static constexpr uint8_t GetMaxValue()
// {
// return (1 << GetMaxPrecision()) - 1;
// }
static constexpr bool HasWeights()
{
return sizeof...(PrecisionEntries) > 0;
}
// static constexpr bool HasWeights()
// {
// return sizeof...(PrecisionEntries) > 0;
// }
public:
// public:
using VariablesT = VariableIDMapT;
// using VariablesT = VariableIDMapT;
template<size_t Size, typename AllocatorT, size_t... Is>
auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>)
-> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>;
// template<size_t Size, typename AllocatorT, size_t... Is>
// auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>)
// -> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>;
template <size_t Size, typename AllocatorT>
using WeightContainersT = decltype(
MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{})
);
// template <size_t Size, typename AllocatorT>
// using WeightContainersT = decltype(
// MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{})
// );
template <typename PrecisionEntryT>
using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>;
// template <typename PrecisionEntryT>
// using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>;
using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType;
};
// using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType;
// };
}
// }