#ifndef pAHg29dA1QrONep0XYIUhA0GC08
#define pAHg29dA1QrONep0XYIUhA0GC08

// NOTE(review): the header name of the first #include is missing in this copy
// (presumably a system header, e.g. the one providing BOOST_STATIC_ASSERT or
// std::cerr -- confirm and restore).
#include
#include "trainer.hpp"
#include "model.hpp"
#include "index.hpp"
#include "index_global.hpp"
#include "index_randomspike.hpp"
#include "index_spike.hpp"
#include "index_spike_arrival.hpp"

/// Trainer implementation: a per-trial state machine that feeds a randomly
/// chosen input symbol to the network as randomly timed external spikes,
/// optionally sends a readout cue, then counts the activity of each symbol's
/// output group, picks a winner and sets reward/performance.
///
/// NOTE(review): several template argument lists appear to be missing in this
/// copy (e.g. `template struct Update`, `Ptr dst`, `indices.template get>()`,
/// `PLA_Get >`); they must be restored before this header can compile.
namespace TrainerImpl {

namespace MC = ModelConsts;

/// Base delay scheduled after eval() leaves the trainer in a given state,
/// indexed by TrainerT::state; index 0 is also the constructor's initial delay.
///
/// NOTE(review): this table (like randDelays) has 6 entries, covering states
/// 0-5 only. eval() also finishes in state 6 (`case 6` just breaks) and then
/// reads delays[6] / randDelays[6] -- an out-of-bounds read. Add an entry for
/// state 6 once its intended delay is known. (State 7 is reset to 1 before the
/// lookup, so it needs no entry.)
const Time::type delays[6] = {
  Time::epsilon()(),                                // 0: initial state
  MC::TrainerInputWindow + MC::TrainerReadoutDelay, // 1: input spikes sent
  MC::TrainerReadoutWindow,                         // 2: resetCounter cleared
  Time::epsilon()(),                                // 3: readout cue sent
  Time::epsilon()(),                                // 4: wait
  MC::TrainerInterTrialDelay                        // 5: winner evaluated
};

/// Scale of the random extra delay per state: eval() adds
/// randDelays[state] * RNG::equi(rng) to delays[state]
/// (RNG::equi is presumably uniform in [0, 1) -- confirm).
const Time::type randDelays[6] = {
  0, MC::TrainerReadoutRandDelay, 0, 0, 0, MC::TrainerInterTrialRandDelay
};

/// Creates a trainer in state 0 with delay delays[0]; the first eval() moves
/// it to state 1 (start of a trial). performance starts at 1 / TrainerNumSymbols,
/// the value expected from guessing uniformly among the symbols.
TrainerT::TrainerT(RNG::seed_t seed)
  : rng(seed), delay(delays[0]), reward(0),
    performance(1.0 / MC::TrainerNumSymbols), generation(0),
    input(0), output(2), state(0), resetCounter(1)
{}

/// Advances the trainer by one step and returns the updated copy.
///
/// State reached after the increment at the top of eval():
///   1 (and 7, folded back to 1): draw a random input symbol, send its spikes
///   2:    set resetCounter = 0
///   3:    send the readout cue (only if TrainerReadoutFreq > 0)
///   4, 6: wait
///   5:    count each symbol's output activity, pick the winner, set
///         output/reward/performance, set resetCounter = 1
/// reward is cleared at the start of every step, so a reward set in state 5
/// is only present in the trainer returned by that step.
/// The returned delay is delays[state] + randDelays[state] * RNG::equi(rng).
///
/// @param old      trainer before this step (copied, not modified)
/// @param pc       queried via pc.call(PLA_Get ...) for output activity in state 5
/// @param indices  generated spikes are registered with one of its indices
/// @param queues   generated spike events are inserted here
/// @param t        current simulation time
template struct Update {
  static TrainerT eval(const TrainerT &old, PC &pc, MI &indices, MQ &queues, Time t) {
    TrainerT res = old;
    res.generation++;
    res.state++;
    res.reward = 0;

    // Emits external spikes at times starting from t + epsilon and stopping
    // before t + window. Each spike targets neuron index
    // (RNG::integer(rng, fanIn) * mult) % NumExcitatory, so the multiplier
    // selects which group of (at most fanIn) neurons is driven. Inter-spike
    // intervals come from RNG::expo with parameter 1 / (freq * fanIn)
    // (presumably the mean, i.e. about freq spikes per target per time unit --
    // confirm). Each spike is added to an index and queued as an event at ct;
    // res.rng is advanced after every draw.
    auto sendSpikes = [&](int mult, double window, int fanIn, double freq) -> void {
      Time ct = t + Time::epsilon();
      while (ct < t + window) {
        // determine dst
        Ptr dst{uint16_t(RNG::integer(res.rng, fanIn) * mult % MC::NumExcitatory)};
        // The source id is the target id offset by maxNeurons/2
        // (presumably a block of virtual input sources -- confirm).
        Ptr src(dst() + ((Ptr::ptr_t) maxNeurons/2));
        Ptr ptrNE(indices.template get>().add(t, ct, src));
        queues.insert(t, ct, Event(ct, src, ptrNE));
        res.rng = RNG::next(res.rng);
        // determine next spike
        ct += RNG::expo(res.rng, 1.0 / freq / fanIn);
        res.rng = RNG::next(res.rng);
      }
    };

    switch (res.state) {
    case 1: // pre: let the network settle
    case 7: { // pre: give last reward
      // select and give input
      res.input = RNG::integer(res.rng, MC::TrainerNumSymbols);
      res.rng = RNG::next(res.rng);
      res.state = 1; // state 7 wraps around: a new trial starts in state 1
      const int mult_in[10] = // per-symbol multipliers: arbitrary primes, except for
        {1, 997, 1013, 1021, 1033, 1049, 1061, 1069, 1091, 1097}; // the first symbol (1)
      // mult_in only has entries for up to 10 symbols.
      BOOST_STATIC_ASSERT((MC::TrainerNumSymbols <= 10));
      std::cerr << "INPUT " << t << std::endl;
      sendSpikes(mult_in[res.input],
                 MC::TrainerInputWindow, MC::FanIn, MC::TrainerInputFreq);
    } break;
    case 2:
      res.resetCounter = 0;
      break;
    case 4:
    case 6: // wait
      // NOTE(review): after state 6 the delay lookup below reads delays[6],
      // past the end of the 6-entry table.
      break;
    case 3: { // send readout signal
      std::cerr << "READOUT SIGNAL " << t << std::endl;
      if (MC::TrainerReadoutFreq > 0)
        sendSpikes(967, // arbitrary prime
                   MC::TrainerReadoutWindow, MC::FanIn, MC::TrainerReadoutFreq);
    } break;
    case 5: { // evaluate which symbol won, reward (TODO), wait for a longer time
      const int mult_out[10] = // per-symbol multipliers: arbitrary primes, except for
        {1, 937, 919, 907, 883, 877, 859, 853, 829, 823}; // the first symbol (1)
      // mult_out only has entries for up to 10 symbols.
      BOOST_STATIC_ASSERT((MC::TrainerNumSymbols <= 10));
      // Activity count per symbol. The extra last slot is zero-initialized and
      // (presumably never written) serves as the initial "no winner" candidate;
      // since comparisons are strict, a symbol must have a count > 0 to win,
      // and on ties the earlier symbol is kept.
      uint32_t freq[MC::TrainerNumSymbols + 1] = {};
      uint8_t maxIdx = MC::TrainerNumSymbols, sndIdx = MC::TrainerNumSymbols;
      // NOTE(review): if these constants are integers, FanIn / NumExcitatory
      // truncates (to 0 whenever FanIn < NumExcitatory) before multiplying by
      // FanOut; confirm whether FanIn * FanOut / NumExcitatory was intended.
      uint16_t overlap = uint16_t(MC::FanIn / MC::NumExcitatory * MC::FanOut);
      std::cerr << "READOUT";
      // NOTE(review): the loop header below appears damaged: the outer loop's
      // bound/increment and the inner loop's index type are missing, which also
      // leaves the braces of this case unbalanced. From the uses of o and i the
      // intent appears to be: for each symbol o, sum pc.call(PLA_Get ...) over
      // neurons mult_out[o] * i % NumExcitatory for i in [FanIn - overlap, 2 * FanIn),
      // then update the best (maxIdx) and second-best (sndIdx) symbol.
      for (int o=0; o::ptr_t i = MC::FanIn - overlap; i < MC::FanIn + MC::FanIn; i++) {
          Ptr src(uint16_t(mult_out[o] * i % MC::NumExcitatory));
          PLA_Get > pla_get(t, ContinuousContext(src));
          freq[o] += pc.call(pla_get);
        }
        if (freq[o] > freq[maxIdx]) {
          sndIdx = maxIdx;
          maxIdx = o;
        } else if (freq[o] > freq[sndIdx]) {
          sndIdx = o;
        }
        std::cerr << " " << freq[o];
      }
      bool correct = maxIdx == res.input;
      res.output = maxIdx;
      // +TrainerReward if correct, -TrainerPunish otherwise -- but only when the
      // winner's count exceeds (1 + TrainerWinAdv) times the runner-up's; the
      // bool factor makes the reward 0 otherwise.
      res.reward = (correct ? MC::TrainerReward : (-MC::TrainerPunish))
                   * ((1.0 + MC::TrainerWinAdv) * freq[sndIdx] < freq[maxIdx]);
      // Exponential moving average of correctness (new-sample weight 0.1).
      res.performance *= 0.9;
      res.performance += correct * 0.1;
      res.resetCounter = 1;
      std::cerr << "\nWINNER @ " << t << " = " << uint16_t(maxIdx)
                << " correct = " << uint16_t(res.input)
                << " reward = " << res.reward << "\n";
    } break;
    default:
      // States outside 1-7 are not expected here.
      assert(false);
    }

    // NOTE(review): out of bounds when res.state == 6 (see the delays table).
    res.delay = delays[res.state] + randDelays[res.state] * RNG::equi(res.rng);
    res.rng = RNG::next(res.rng);
    return res;
  }
};

/// dummy for sim_replay, must not be called
/// (a second Update definition taking Void index/queue arguments -- its
/// template arguments are missing in this copy; eval() is only declared and
/// marked DO_NOT_CALL).
template struct Update {
  static TrainerT eval(const TrainerT&, PC &, Void&, Void&, Time) DO_NOT_CALL;
};

} // NS

#endif // pAHg29dA1QrONep0XYIUhA0GC08