#pragma once

#include <cassert>
#include <cmath>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>

/// Weighted random picker over a fixed-capacity set of values.
///
/// Every entry carries a weight; an entry is selected when a weight in
/// [0, TotalWeight) falls inside [AccumulatedWeight, AccumulatedWeight + Weight).
/// The caller supplies the random value, which is reduced modulo TotalWeight.
///
/// Removing an entry swaps it to the end of the active range and shrinks
/// TotalEntries instead of destroying it, so Reset() restores all DataSize
/// entries (in their possibly-permuted order).
///
/// @tparam T           stored value type; must be default-constructible
///                     (entries are value-initialized in a heap array).
/// @tparam WeightType  arithmetic weight type (integral or floating point).
template <typename T, typename WeightType>
class RandomPickerT
{
    static_assert(std::is_arithmetic_v<WeightType>,
                  "RandomPickerT requires an arithmetic WeightType");

public:
    struct Entry
    {
        Entry() = default;
        Entry(const T& val, WeightType weight) : Val{ val }, Weight{ weight } {}

        T Val{};
        // Non-const so Entry remains copy-assignable; the constructors copy
        // into the array and removal swaps entries.
        WeightType Weight{};
        // Sum of the weights of all active entries before this one.
        WeightType AccumulatedWeight{};

        /// Exclusive upper bound of this entry's weight interval.
        WeightType GetWeightSum() const { return AccumulatedWeight + Weight; }
    };

public:
    std::unique_ptr<Entry[]> Entries{};
    WeightType TotalWeight{};
    uint32_t TotalEntries{};   // number of active (not removed) entries
    const uint32_t DataSize{}; // capacity supplied at construction

public:
    RandomPickerT() = default;

    /// Builds a picker from arbitrary values; @p getter maps each value to its weight.
    template <typename Container, typename WeightGetter>
    RandomPickerT(const Container& container, WeightGetter getter)
        : Entries{ std::make_unique<Entry[]>(container.size()) }
        , TotalEntries{ static_cast<uint32_t>(container.size()) }
        , DataSize{ static_cast<uint32_t>(container.size()) }
    {
        uint32_t counter{};
        for (const auto& value : container)
            Entries[counter++] = Entry{ value, getter(value) };
        RecalculateWeights();
    }

    /// Builds a picker from ready-made Entry objects; their AccumulatedWeight
    /// fields are ignored and recomputed.
    template <typename EntriesT>
    RandomPickerT(const EntriesT& entries)
        : Entries{ std::make_unique<Entry[]>(entries.size()) }
        , TotalEntries{ static_cast<uint32_t>(entries.size()) }
        , DataSize{ static_cast<uint32_t>(entries.size()) }
    {
        uint32_t counter{};
        for (const auto& entry : entries)
            Entries[counter++] = entry;
        RecalculateWeights();
    }

public:
    /// Picks a value without removing it.
    /// @param index     out: picked index, or -1 if nothing could be picked
    ///                  (empty picker, zero total weight, out-of-range value).
    /// @param randomVal random value, reduced modulo TotalWeight.
    /// @return the picked value, or a value-initialized T when index == -1.
    T GetRandom(int& index, WeightType randomVal) const
    {
        index = GetIndex(GetRandomWeight(randomVal));
        return index >= 0 ? Entries[index].Val : T{};
    }

    /// Maps a weight in [0, TotalWeight) to the index of the active entry whose
    /// interval contains it; returns -1 for any weight outside that range.
    int GetIndex(WeightType weight) const
    {
        if (TotalEntries == 0)
            return -1;
        if constexpr (std::is_signed_v<WeightType>)
        {
            if (weight < WeightType{})
                return -1;
        }
        // Upper bound is exclusive; the negated form also rejects NaN.
        if (!(weight < TotalWeight))
            return -1;
        if (TotalEntries == 1)
            return 0;
        return BinarySearch(weight, 0, static_cast<int>(TotalEntries) - 1);
    }

    /// Removes the active entry at @p index; invalid indices are ignored.
    void RemoveEntry(int index)
    {
        if (ValidateIndex(index))
            RemoveEntryInternal(index);
    }

    /// Picks a value and removes it from the active set.
    /// @param index out: index the value occupied before removal, or -1.
    /// @return the picked value, or a value-initialized T when index == -1.
    T GetAndRemoveRandom(int& index, WeightType randomVal)
    {
        index = GetIndex(GetRandomWeight(randomVal));
        if (index < 0)
            return T{};
        T returnVal = Entries[index].Val;
        RemoveEntryInternal(index);
        return returnVal;
    }

    T GetAndRemoveRandom(WeightType randomVal)
    {
        int index{};
        return GetAndRemoveRandom(index, randomVal);
    }

    /// Returns the value at an active index; asserts in debug builds and
    /// returns a value-initialized T in release builds for invalid indices.
    T Peek(int index) const
    {
        assert(ValidateIndex(index));
        if (!ValidateIndex(index))
            return T{};
        return Entries[index].Val;
    }

    /// Re-activates every entry supplied at construction.
    void Reset()
    {
        TotalEntries = DataSize;
        RecalculateWeights();
    }

    /// Replaces the value at an active index; its weight is re-read via @p getter.
    template <typename WeightGetter>
    void SetObjectAtIndex(int index, const T& val, WeightGetter getter)
    {
        if (!ValidateIndex(index))
            return;
        Entry& slot = Entries[index];
        const WeightType newWeight = getter(val);
        slot.Val = val;
        // Keep AccumulatedWeight intact; only a weight change requires the
        // prefix sums to be rebuilt.
        if (newWeight != slot.Weight)
        {
            slot.Weight = newWeight;
            RecalculateWeights();
        }
    }

    /// Replaces the value and weight at an active index. The caller's
    /// AccumulatedWeight is ignored so the prefix sums stay consistent.
    void SetEntryAtIndex(int index, const Entry& entry)
    {
        if (!ValidateIndex(index))
            return;
        Entry& slot = Entries[index];
        const bool recalculateWeights = entry.Weight != slot.Weight;
        slot.Val = entry.Val;
        slot.Weight = entry.Weight;
        if (recalculateWeights)
            RecalculateWeights();
    }

private:
    /// Rebuilds prefix sums and TotalWeight over the active entries.
    void RecalculateWeights()
    {
        WeightType accumulatedWeight{};
        for (uint32_t i{}; i < TotalEntries; ++i)
        {
            Entries[i].AccumulatedWeight = accumulatedWeight;
            accumulatedWeight += Entries[i].Weight;
        }
        TotalWeight = accumulatedWeight;
    }

    /// Reduces @p randomVal into [0, TotalWeight). With no positive total weight
    /// there is nothing to pick; returns 0, which GetIndex rejects, instead of
    /// dividing by zero. Negative inputs stay negative and are rejected likewise.
    WeightType GetRandomWeight(WeightType randomVal) const
    {
        if (!(TotalWeight > WeightType{}))
            return WeightType{};
        if constexpr (std::is_integral_v<WeightType>)
            return randomVal % TotalWeight;
        else
            return std::fmod(randomVal, TotalWeight);
    }

    /// Iterative binary search over active entries [min, max]; -1 if no
    /// entry's interval contains @p weight (terminates instead of recursing forever).
    int BinarySearch(WeightType weight, int min, int max) const
    {
        while (min <= max)
        {
            const int middle = min + (max - min) / 2;
            const Entry& entry = Entries[middle];
            if (weight < entry.AccumulatedWeight)
                max = middle - 1;
            else if (weight < entry.GetWeightSum())
                return middle;
            else
                min = middle + 1;
        }
        return -1;
    }

    /// Swaps the entry into the last active slot and deactivates it.
    void RemoveEntryInternal(int index)
    {
        if (!ValidateIndex(index))
            return;
        using std::swap;
        swap(Entries[index], Entries[TotalEntries - 1]);
        --TotalEntries;
        RecalculateWeights(); // yields TotalWeight == 0 once empty
    }

    bool ValidateIndex(int index) const
    {
        return index >= 0 && static_cast<uint32_t>(index) < TotalEntries;
    }
};

// NOTE(review): the alias template arguments were lost from the source and are
// reconstructed here from the alias names; confirm the weight signedness.
template <typename T> using RandomPicker32 = RandomPickerT<T, uint32_t>;
template <typename T> using RandomPicker64 = RandomPickerT<T, uint64_t>;
template <typename T> using RandomPickerF = RandomPickerT<T, float>;
template <typename T> using RandomPickerD = RandomPickerT<T, double>;