Files
factory-hole-core/include/Util/RandomPicker.h
2026-02-09 00:53:38 +09:00

198 lines
5.0 KiB
C++

#pragma once
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
/// Weighted random picker over a fixed-capacity array of entries.
///
/// Each active entry owns the half-open weight interval
/// [AccumulatedWeight, AccumulatedWeight + Weight). A weight in [0, TotalWeight)
/// therefore selects exactly one entry, with probability proportional to its
/// Weight. Zero-weight entries own an empty interval and are never selected.
///
/// Removing an entry swaps it behind the active range instead of destroying it.
/// Reset() can therefore re-activate all DataSize entries, possibly in a
/// different order. Indices are positions in the current order and are not
/// stable across removals.
///
/// @tparam T          Stored value type; must be default-constructible and copyable.
/// @tparam WeightType Integral or floating-point weight type.
template <typename T, typename WeightType>
class RandomPickerT
{
public:
    /// A stored value, its weight, and the running weight sum of all active
    /// entries before it (maintained by RecalculateWeights).
    struct Entry
    {
        Entry() = default;
        Entry(const T& val, WeightType weight) : Val{ val }, Weight{ weight } {}

        T Val{};
        // Not const: a const member deletes copy/move assignment, which the
        // constructors, the setters and swap-based removal all depend on.
        WeightType Weight{};
        WeightType AccumulatedWeight{};

        /// Exclusive end of this entry's weight interval.
        WeightType GetWeightSum() const { return AccumulatedWeight + Weight; }
    };

public:
    std::unique_ptr<Entry[]> Entries{};
    WeightType TotalWeight{};      // Sum of the weights of the active entries.
    std::uint32_t TotalEntries{};  // Number of active (not removed) entries.
    std::uint32_t DataSize{};      // Number of entries allocated (active + removed).

public:
    RandomPickerT() = default;

    /// Builds a picker from any iterable container of values.
    /// @param container Values to store; its size() becomes DataSize.
    /// @param getter    Callable returning the weight for a value; the result is
    ///                  converted to WeightType.
    template <typename Container, typename WeightGetter>
    RandomPickerT(const Container& container, WeightGetter getter)
        : Entries{ std::make_unique<Entry[]>(container.size()) }
        , TotalEntries{ static_cast<std::uint32_t>(container.size()) }
        , DataSize{ static_cast<std::uint32_t>(container.size()) }
    {
        std::size_t slot = 0;
        for (const auto& item : container)
        {
            Entries[slot++] = Entry{ item, static_cast<WeightType>(getter(item)) };
        }
        RecalculateWeights();
    }

    /// Builds a picker from an iterable container of ready-made Entry objects.
    /// Any AccumulatedWeight values in the input are recomputed.
    template <typename EntriesT>
    RandomPickerT(const EntriesT& entries)
        : Entries{ std::make_unique<Entry[]>(entries.size()) }
        , TotalEntries{ static_cast<std::uint32_t>(entries.size()) }
        , DataSize{ static_cast<std::uint32_t>(entries.size()) }
    {
        std::size_t slot = 0;
        for (const auto& entry : entries)
        {
            Entries[slot++] = entry;
        }
        RecalculateWeights();
    }

public:
    /// Picks an active entry, using @p randomVal as the source of randomness.
    /// @param index     Out: index of the picked entry, or -1 if nothing could be
    ///                  picked (no active entries, zero total weight, or a
    ///                  reduced weight outside [0, TotalWeight), e.g. negative).
    /// @param randomVal Random number, reduced modulo TotalWeight (std::fmod for
    ///                  floating point). NOTE(review): for floating-point weights,
    ///                  callers presumably pass a value already scaled to
    ///                  [0, TotalWeight) - confirm.
    /// @return The picked value, or a default-constructed T when index is -1.
    T GetRandom(int& index, WeightType randomVal) const
    {
        index = GetIndex(GetRandomWeight(randomVal));
        return ValidateIndex(index) ? Entries[index].Val : T{};
    }

    /// Same as GetRandom(int&, WeightType) but discards the picked index.
    T GetRandom(WeightType randomVal) const
    {
        int index{};
        return GetRandom(index, randomVal);
    }

    /// Maps a weight to the active entry whose interval contains it.
    /// @return The entry index, or -1 if there are no active entries or
    ///         @p weight is outside [0, TotalWeight).
    int GetIndex(WeightType weight) const
    {
        // weight == TotalWeight lies past the last interval. `!(a < b)` also
        // rejects NaN.
        if (TotalEntries == 0 || IsNegative(weight) || !(weight < TotalWeight))
            return -1;
        return BinarySearch(weight);
    }

    /// Deactivates the entry at @p index; invalid indices are ignored.
    void RemoveEntry(int index)
    {
        if (ValidateIndex(index)) RemoveEntryInternal(index);
    }

    /// Picks an entry as GetRandom does, then deactivates it.
    /// @param index Out: the position the picked entry occupied before removal,
    ///              or -1 if nothing was picked (nothing is removed then).
    /// @return The picked value, or a default-constructed T when index is -1.
    T GetAndRemoveRandom(int& index, WeightType randomVal)
    {
        index = GetIndex(GetRandomWeight(randomVal));
        if (!ValidateIndex(index))
            return T{};
        T picked = Entries[index].Val;
        RemoveEntryInternal(index);
        return picked;
    }

    /// Same as GetAndRemoveRandom(int&, WeightType) but discards the index.
    T GetAndRemoveRandom(WeightType randomVal)
    {
        int index{};
        return GetAndRemoveRandom(index, randomVal);
    }

    /// Returns the value stored at an active @p index (asserted in debug builds).
    T Peek(int index) const
    {
        assert(ValidateIndex(index));
        return Entries[index].Val;
    }

    /// Re-activates every entry, including removed ones, and rebuilds the weights.
    void Reset()
    {
        TotalEntries = DataSize;
        RecalculateWeights();
    }

    /// Replaces the value at an active @p index and re-weights it via @p getter.
    /// Invalid indices are ignored.
    template <typename WeightGetter>
    void SetObjectAtIndex(int index, const T& val, WeightGetter getter)
    {
        if (!ValidateIndex(index)) return;
        const auto newWeight = static_cast<WeightType>(getter(val));
        Entry& slot = Entries[index];
        const bool weightChanged = newWeight != slot.Weight;
        // Update in place: AccumulatedWeight must survive when no recalculation runs.
        slot.Val = val;
        slot.Weight = newWeight;
        if (weightChanged) RecalculateWeights();
    }

    /// Replaces the value and weight at an active @p index with those of @p entry.
    /// The entry's AccumulatedWeight is ignored. Invalid indices are ignored.
    void SetEntryAtIndex(int index, const Entry& entry)
    {
        if (!ValidateIndex(index)) return;
        Entry& slot = Entries[index];
        const bool weightChanged = entry.Weight != slot.Weight;
        // Keep the slot's own AccumulatedWeight; the caller's copy may be stale.
        slot.Val = entry.Val;
        slot.Weight = entry.Weight;
        if (weightChanged) RecalculateWeights();
    }

private:
    /// True when @p value is below zero; always false for unsigned types.
    static bool IsNegative([[maybe_unused]] WeightType value)
    {
        if constexpr (std::is_signed_v<WeightType>)
            return value < WeightType{};
        else
            return false;
    }

    /// Recomputes each active entry's AccumulatedWeight and the TotalWeight.
    void RecalculateWeights()
    {
        WeightType running{};
        for (std::uint32_t i = 0; i < TotalEntries; ++i)
        {
            Entries[i].AccumulatedWeight = running;
            running += Entries[i].Weight;
        }
        TotalWeight = running;
    }

    /// Reduces @p randomVal into [0, TotalWeight) for non-negative inputs.
    WeightType GetRandomWeight(WeightType randomVal) const
    {
        // With zero total weight, `%` would divide by zero and fmod would yield
        // NaN. Returning 0 makes GetIndex report -1 instead.
        if (!(TotalWeight > WeightType{}))
            return WeightType{};
        if constexpr (std::is_integral_v<WeightType>)
        {
            return static_cast<WeightType>(randomVal % TotalWeight);
        }
        else
        {
            static_assert(std::is_floating_point_v<WeightType>,
                          "RandomPickerT requires an integral or floating-point WeightType");
            return static_cast<WeightType>(std::fmod(randomVal, TotalWeight));
        }
    }

    /// Iterative binary search over the active entries' weight intervals.
    /// @return The index whose interval contains @p weight, or -1 if none does.
    int BinarySearch(WeightType weight) const
    {
        int low = 0;
        int high = static_cast<int>(TotalEntries) - 1;
        while (low <= high)
        {
            const int middle = low + (high - low) / 2;
            const Entry& entry = Entries[middle];
            if (weight < entry.AccumulatedWeight)
                high = middle - 1;
            else if (weight < entry.GetWeightSum())
                return middle;
            else
                low = middle + 1;  // Also skips zero-weight entries at `weight`.
        }
        return -1;
    }

    /// Swaps the entry at @p index with the last active entry and shrinks the
    /// active range by one, so Reset() can bring it back later.
    void RemoveEntryInternal(int index)
    {
        if (!ValidateIndex(index))
            return;
        using std::swap;
        swap(Entries[index], Entries[TotalEntries - 1]);
        --TotalEntries;
        RecalculateWeights();  // Yields TotalWeight == 0 once nothing is active.
    }

    /// True when @p index refers to an active entry.
    bool ValidateIndex(int index) const
    {
        return index >= 0 && static_cast<std::uint32_t>(index) < TotalEntries;
    }
};
// Convenience aliases that fix the weight type of RandomPickerT.
// Unsigned integral weights (32-bit and 64-bit):
template <typename T> using RandomPicker32 = RandomPickerT<T, uint32_t>;
template <typename T> using RandomPicker64 = RandomPickerT<T, uint64_t>;
// Floating-point weights (single and double precision):
template <typename T> using RandomPickerF = RandomPickerT<T, float>;
template <typename T> using RandomPickerD = RandomPickerT<T, double>;