diff options
Diffstat (limited to 'core/trainer_impl.hpp')
-rw-r--r-- | core/trainer_impl.hpp | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/core/trainer_impl.hpp b/core/trainer_impl.hpp new file mode 100644 index 0000000..d685d66 --- /dev/null +++ b/core/trainer_impl.hpp @@ -0,0 +1,162 @@ +#ifndef pAHg29dA1QrONep0XYIUhA0GC08 +#define pAHg29dA1QrONep0XYIUhA0GC08 + +#include <iostream> + +#include "trainer.hpp" +#include "model.hpp" + +#include "index.hpp" +#include "index_global.hpp" +#include "index_randomspike.hpp" +#include "index_spike.hpp" +#include "index_spike_arrival.hpp" + +namespace TrainerImpl { + +namespace MC = ModelConsts; + +const Time::type delays[6] = + { Time::epsilon()(), + MC::TrainerInputWindow + MC::TrainerReadoutDelay, + MC::TrainerReadoutWindow, + Time::epsilon()(), + Time::epsilon()(), + MC::TrainerInterTrialDelay }; + +const Time::type randDelays[6] = + { 0, + MC::TrainerReadoutRandDelay, + 0, 0, 0, + MC::TrainerInterTrialRandDelay }; + +TrainerT::TrainerT(RNG::seed_t seed) : + rng(seed), + delay(delays[0]), + reward(0), + performance(1.0 / MC::TrainerNumSymbols), + generation(0), + input(0), output(2), + state(0), + resetCounter(1) {} + +template <typename PC, typename MI, typename MQ> +struct Update { + static TrainerT eval(const TrainerT &old, PC &pc, + MI &indices, MQ &queues, Time t) { + TrainerT res = old; + res.generation++; + res.state++; + res.reward = 0; + + auto sendSpikes = [&](int mult, double window, int fanIn, double freq) -> void { + Time ct = t + Time::epsilon(); + while (ct < t + window) { + // determine dst + Ptr<Neuron> dst{uint16_t(RNG::integer(res.rng, fanIn) + * mult + % MC::NumExcitatory)}; + Ptr<Neuron> src(dst() + ((Ptr<Neuron>::ptr_t) maxNeurons/2)); + Ptr<RandomSpike> ptrNE(indices.template get<Index<RandomSpike>>().add(t, ct, src)); + queues.insert(t, ct, Event<RandomSpike>(ct, src, ptrNE)); + res.rng = RNG::next(res.rng); + + // determine next spike + ct += RNG::expo(res.rng, 1.0 / freq / fanIn); + res.rng = RNG::next(res.rng); + } + }; + + switch (res.state) { + case 1: // pre: let the network settle + case 7: { // pre: give last reward + // select and give input + res.input = RNG::integer(res.rng, MC::TrainerNumSymbols); + res.rng = RNG::next(res.rng); + res.state = 1; + const int mult_in[10] = // some arbitrarily selected prime number except for + {1, 997, 1013, 1021, 1033, 1049, 1061, 1069, 1091, 1097}; // symbol one + BOOST_STATIC_ASSERT((MC::TrainerNumSymbols <= 10)); + std::cerr << "INPUT " << t << std::endl; + sendSpikes(mult_in[res.input], + MC::TrainerInputWindow, MC::FanIn, MC::TrainerInputFreq); + } break; + + case 2: + res.resetCounter = 0; + break; + + case 4: case 6: // wait + break; + + case 3: { + // send readout signal + std::cerr << "READOUT SIGNAL " << t << std::endl; + if (MC::TrainerReadoutFreq > 0) + sendSpikes(967, // arbitrary prime + MC::TrainerReadoutWindow, MC::FanIn, MC::TrainerReadoutFreq); + } break; + + + case 5: { + // evaluate which symbol won, reward (TODO), wait for a longer time + const int mult_out[10] = // some arbitrarily selected prime number except for + {1, 937, 919, 907, 883, 877, 859, 853, 829, 823}; // symbol one + BOOST_STATIC_ASSERT((MC::TrainerNumSymbols <= 10)); + uint32_t freq[MC::TrainerNumSymbols + 1] = {}; + uint8_t maxIdx = MC::TrainerNumSymbols, + sndIdx = MC::TrainerNumSymbols; + uint16_t overlap = uint16_t(MC::FanIn / MC::NumExcitatory * MC::FanOut); + std::cerr << "READOUT"; + for (int o=0; o<MC::TrainerNumSymbols; o++) { + for (Ptr<Neuron>::ptr_t i = MC::FanIn - overlap; + i < MC::FanIn + MC::FanIn; + i++) { + Ptr<Neuron> src(uint16_t(mult_out[o] * i % MC::NumExcitatory)); + PLA_Get<SpikeCounter, ContinuousContext<Neuron> > + pla_get(t, ContinuousContext<Neuron>(src)); + freq[o] += pc.call(pla_get); + } + if (freq[o] > freq[maxIdx]) { + sndIdx = maxIdx; + maxIdx = o; + }else if (freq[o] > freq[sndIdx]) { + sndIdx = o; + } + std::cerr << " " << freq[o]; + } + bool correct = maxIdx == res.input; + res.output = maxIdx; + res.reward = (correct ? MC::TrainerReward : (-MC::TrainerPunish)) + * ((1.0 + MC::TrainerWinAdv) * freq[sndIdx] < freq[maxIdx]); + res.performance *= 0.9; + res.performance += correct * 0.1; + res.resetCounter = 1; + std::cerr << "\nWINNER @ " << t << " = " << uint16_t(maxIdx) + << " correct = " << uint16_t(res.input) + << " reward = " << res.reward + << "\n"; + } break; + + default: + assert(false); + } + + res.delay = delays[res.state] + + randDelays[res.state] * RNG::equi(res.rng); + res.rng = RNG::next(res.rng); + + return res; + } +}; + +/// dummy for sim_replay, must not be called +template <typename PC> +struct Update<PC, Void, Void> { + static TrainerT eval(const TrainerT&, PC &, Void&, Void&, Time) DO_NOT_CALL; +}; + +} // NS + + +#endif // pAHg29dA1QrONep0XYIUhA0GC08 |