198 lines
5.0 KiB
C++
198 lines
5.0 KiB
C++
#pragma once
|
|
|
|
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
|
|
|
|
/// Weighted random picker over a fixed pool of values.
///
/// Each active entry owns the half-open weight interval
/// [AccumulatedWeight, AccumulatedWeight + Weight). A caller-supplied raw random value is
/// reduced into [0, TotalWeight) and mapped to the entry whose interval contains it. With a
/// uniformly distributed random value, an entry is therefore picked with probability
/// proportional to its weight (up to the modulo bias of that reduction). Weights are assumed
/// to be non-negative.
///
/// Removing an entry swaps it behind the active range (the first TotalEntries slots) instead of
/// destroying it, so Reset() restores every original entry. Removal therefore changes the index
/// of the entry that was last in the active range.
///
/// @tparam T          Value type; must be default-constructible and copyable.
/// @tparam WeightType Arithmetic weight type (integral or floating point).
template <typename T, typename WeightType>
class RandomPickerT
{
public:
    /// One pickable value, its weight, and the running weight of all active entries before it.
    struct Entry
    {
        Entry() = default;
        Entry(const T& val, WeightType weight) : Val{ val }, Weight{ weight } {}

        T Val{};
        // Deliberately not const: entries must be assignable and swappable so they can be stored
        // in the array, overwritten by the Set* methods and moved aside on removal.
        WeightType Weight{};
        // Sum of the weights of all active entries before this one; maintained by the picker.
        WeightType AccumulatedWeight{};

        /// Exclusive end of this entry's weight interval.
        WeightType GetWeightSum() const { return AccumulatedWeight + Weight; }
    };

public:
    // Public state, kept consistent by the member functions below; modify with care.
    std::unique_ptr<Entry[]> Entries{};  // DataSize slots; the first TotalEntries are active
    WeightType TotalWeight{};            // sum of the weights of the active entries
    uint32_t TotalEntries{};             // number of active (pickable) entries
    uint32_t DataSize{};                 // number of entries the picker was built with (not const, so pickers stay move-assignable)

public:
    RandomPickerT() = default;

    /// Builds a picker from any sized, iterable container, using `getter(element)` as the weight.
    /// The getter's result is converted to WeightType.
    template <typename Container, typename WeightGetter>
    RandomPickerT(const Container& container, WeightGetter getter)
        : Entries{ std::make_unique<Entry[]>(container.size()) }
        , TotalEntries{ static_cast<uint32_t>(container.size()) }
        , DataSize{ static_cast<uint32_t>(container.size()) }
    {
        uint32_t counter{};
        for (const auto& element : container)
        {
            Entries[counter++] = Entry{ element, static_cast<WeightType>(getter(element)) };
        }

        RecalculateWeights();
    }

    /// Builds a picker from a sized, iterable collection of ready-made entries.
    /// Any AccumulatedWeight values they carry are recomputed.
    template <typename EntriesT>
    RandomPickerT(const EntriesT& entries)
        : Entries{ std::make_unique<Entry[]>(entries.size()) }
        , TotalEntries{ static_cast<uint32_t>(entries.size()) }
        , DataSize{ static_cast<uint32_t>(entries.size()) }
    {
        uint32_t counter{};
        for (const auto& entry : entries)
        {
            Entries[counter++] = entry;
        }

        RecalculateWeights();
    }

public:
    /// Picks an active entry using a raw random value.
    /// @param index     Receives the picked index. It is -1 when nothing can be picked: no active
    ///                  entries, zero total weight, or an unusable randomVal such as NaN.
    /// @param randomVal Raw random value, reduced modulo TotalWeight; negative values wrap around.
    /// @return The picked value, or a value-initialized T when index is -1.
    T GetRandom(int& index, WeightType randomVal) const
    {
        index = GetIndex(GetRandomWeight(randomVal));
        if (index < 0)
            return T{};

        return Entries[index].Val;
    }

    /// Same as GetRandom(int&, WeightType) but discards the picked index.
    T GetRandom(WeightType randomVal) const
    {
        int index{};
        return GetRandom(index, randomVal);
    }

    /// Maps a weight in [0, TotalWeight) to the index of the active entry whose interval holds it.
    /// @return The entry index, or -1 if there are no active entries or `weight` is out of range.
    int GetIndex(WeightType weight) const
    {
        if (TotalEntries == 0 || !IsWeightInRange(weight))
            return -1;

        return FindEntryIndex(weight);
    }

    /// Deactivates the entry at `index`; invalid indices are ignored.
    void RemoveEntry(int index)
    {
        if (ValidateIndex(index)) RemoveEntryInternal(index);
    }

    /// Picks an entry like GetRandom(int&, WeightType) and then removes it from the active range.
    /// @return The picked value, or a value-initialized T (with index -1) when nothing is pickable.
    T GetAndRemoveRandom(int& index, WeightType randomVal)
    {
        T picked = GetRandom(index, randomVal);
        if (index >= 0)
            RemoveEntryInternal(index);

        return picked;
    }

    /// Same as GetAndRemoveRandom(int&, WeightType) but discards the picked index.
    T GetAndRemoveRandom(WeightType randomVal)
    {
        int index{};
        return GetAndRemoveRandom(index, randomVal);
    }

    /// Returns the value of the active entry at `index`. An invalid index is an assertion failure.
    T Peek(int index) const
    {
        assert(ValidateIndex(index));
        return Entries[index].Val;
    }

    /// Re-activates every entry the picker was built with, including removed ones.
    void Reset()
    {
        TotalEntries = DataSize;
        RecalculateWeights();
    }

    /// Replaces the value at an active index and takes its weight from `getter(val)`.
    /// Invalid indices are ignored.
    template <typename WeightGetter>
    void SetObjectAtIndex(int index, const T& val, WeightGetter getter)
    {
        if (!ValidateIndex(index)) return;

        AssignEntry(Entries[index], val, static_cast<WeightType>(getter(val)));
    }

    /// Replaces the value and weight at an active index with those of `entry`.
    /// entry.AccumulatedWeight is ignored, because the picker maintains it. Invalid indices are ignored.
    void SetEntryAtIndex(int index, const Entry& entry)
    {
        if (!ValidateIndex(index)) return;

        AssignEntry(Entries[index], entry.Val, entry.Weight);
    }

private:
    /// Recomputes every active entry's AccumulatedWeight, and TotalWeight, from the weights.
    void RecalculateWeights()
    {
        WeightType accumulatedWeight{};
        for (uint32_t i{}; i < TotalEntries; ++i)
        {
            Entries[i].AccumulatedWeight = accumulatedWeight;
            accumulatedWeight += Entries[i].Weight;
        }
        TotalWeight = accumulatedWeight;
    }

    /// Overwrites a slot's value and weight while keeping its AccumulatedWeight valid.
    /// The existing AccumulatedWeight is kept as-is when the weight does not change.
    void AssignEntry(Entry& slot, const T& val, WeightType weight)
    {
        const bool weightChanged = weight != slot.Weight;
        slot.Val = val;
        slot.Weight = weight;

        if (weightChanged) RecalculateWeights();
    }

    /// True when 0 <= weight < TotalWeight.
    bool IsWeightInRange(WeightType weight) const
    {
        if constexpr (std::is_signed_v<WeightType>)
        {
            // Written as !(x >= 0) so that NaN is rejected for floating-point weights too.
            if (!(weight >= WeightType{}))
                return false;
        }
        return weight < TotalWeight;
    }

    /// Reduces a raw random value into [0, TotalWeight).
    /// Returns 0 when TotalWeight is not positive (GetIndex then reports -1). A NaN or infinite
    /// floating-point input stays NaN, and GetIndex rejects it.
    WeightType GetRandomWeight(WeightType randomVal) const
    {
        static_assert(std::is_arithmetic_v<WeightType>,
                      "RandomPickerT requires an integral or floating-point WeightType");

        // Avoid modulo by zero on an empty or all-zero-weight picker.
        if (!(TotalWeight > WeightType{}))
            return WeightType{};

        WeightType weight{};
        if constexpr (std::is_integral_v<WeightType>)
        {
            weight = static_cast<WeightType>(randomVal % TotalWeight);
        }
        else
        {
            weight = static_cast<WeightType>(std::fmod(randomVal, TotalWeight));
        }

        if constexpr (std::is_signed_v<WeightType>)
        {
            // Both % and fmod keep the sign of the dividend; wrap negatives into range.
            if (weight < WeightType{})
                weight += TotalWeight;
            // Floating-point rounding can land exactly on TotalWeight after the wrap.
            if (weight >= TotalWeight)
                weight = WeightType{};
        }
        return weight;
    }

    /// Binary search for the last active entry with AccumulatedWeight <= weight.
    /// Precondition (checked by GetIndex): TotalEntries > 0 and 0 <= weight < TotalWeight.
    /// A zero-weight entry shares its AccumulatedWeight with the entry after it. Because the
    /// search lands on the last entry of such a run, zero-weight entries are never returned.
    int FindEntryIndex(WeightType weight) const
    {
        const Entry* const first = Entries.get();
        const Entry* const last = first + TotalEntries;
        const Entry* const above = std::upper_bound(first, last, weight,
            [](const WeightType& w, const Entry& e) { return w < e.AccumulatedWeight; });

        // Entries[0].AccumulatedWeight == 0 <= weight, so `above` is past `first` and the result is >= 0.
        return static_cast<int>(above - first) - 1;
    }

    /// Deactivates the entry at `index`. It is swapped with the last active entry and the active
    /// range shrinks by one, so the removed entry stays in storage and Reset() can restore it.
    void RemoveEntryInternal(int index)
    {
        if (!ValidateIndex(index))
            return;

        const uint32_t lastActive = TotalEntries - 1;
        if (static_cast<uint32_t>(index) != lastActive)
            std::swap(Entries[index], Entries[lastActive]);
        --TotalEntries;

        // With no active entries left, this simply sets TotalWeight to 0.
        RecalculateWeights();
    }

    /// True when `index` addresses an active entry.
    bool ValidateIndex(int index) const
    {
        return index >= 0 && static_cast<uint32_t>(index) < TotalEntries;
    }
};
|
|
|
|
template <typename T>
|
|
using RandomPicker32 = RandomPickerT<T, uint32_t>;
|
|
template <typename T>
|
|
using RandomPicker64 = RandomPickerT<T, uint64_t>;
|
|
template <typename T>
|
|
using RandomPickerF = RandomPickerT<T, float>;
|
|
template <typename T>
|
|
using RandomPickerD = RandomPickerT<T, double>; |