summaryrefslogtreecommitdiff
path: root/core/trainer_impl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'core/trainer_impl.hpp')
-rw-r--r--core/trainer_impl.hpp162
1 files changed, 162 insertions, 0 deletions
diff --git a/core/trainer_impl.hpp b/core/trainer_impl.hpp
new file mode 100644
index 0000000..d685d66
--- /dev/null
+++ b/core/trainer_impl.hpp
@@ -0,0 +1,162 @@
+#ifndef pAHg29dA1QrONep0XYIUhA0GC08
+#define pAHg29dA1QrONep0XYIUhA0GC08
+
+#include <iostream>
+
+#include "trainer.hpp"
+#include "model.hpp"
+
+#include "index.hpp"
+#include "index_global.hpp"
+#include "index_randomspike.hpp"
+#include "index_spike.hpp"
+#include "index_spike_arrival.hpp"
+
+namespace TrainerImpl {
+
+namespace MC = ModelConsts;
+
+const Time::type delays[6] =
+ { Time::epsilon()(),
+ MC::TrainerInputWindow + MC::TrainerReadoutDelay,
+ MC::TrainerReadoutWindow,
+ Time::epsilon()(),
+ Time::epsilon()(),
+ MC::TrainerInterTrialDelay };
+
+const Time::type randDelays[6] =
+ { 0,
+ MC::TrainerReadoutRandDelay,
+ 0, 0, 0,
+ MC::TrainerInterTrialRandDelay };
+
+TrainerT::TrainerT(RNG::seed_t seed) :
+ rng(seed),
+ delay(delays[0]),
+ reward(0),
+ performance(1.0 / MC::TrainerNumSymbols),
+ generation(0),
+ input(0), output(2),
+ state(0),
+ resetCounter(1) {}
+
+template <typename PC, typename MI, typename MQ>
+struct Update {
+ static TrainerT eval(const TrainerT &old, PC &pc,
+ MI &indices, MQ &queues, Time t) {
+ TrainerT res = old;
+ res.generation++;
+ res.state++;
+ res.reward = 0;
+
+ auto sendSpikes = [&](int mult, double window, int fanIn, double freq) -> void {
+ Time ct = t + Time::epsilon();
+ while (ct < t + window) {
+ // determine dst
+ Ptr<Neuron> dst{uint16_t(RNG::integer(res.rng, fanIn)
+ * mult
+ % MC::NumExcitatory)};
+ Ptr<Neuron> src(dst() + ((Ptr<Neuron>::ptr_t) maxNeurons/2));
+ Ptr<RandomSpike> ptrNE(indices.template get<Index<RandomSpike>>().add(t, ct, src));
+ queues.insert(t, ct, Event<RandomSpike>(ct, src, ptrNE));
+ res.rng = RNG::next(res.rng);
+
+ // determine next spike
+ ct += RNG::expo(res.rng, 1.0 / freq / fanIn);
+ res.rng = RNG::next(res.rng);
+ }
+ };
+
+ switch (res.state) {
+ case 1: // pre: let the network settle
+ case 7: { // pre: give last reward
+ // select and give input
+ res.input = RNG::integer(res.rng, MC::TrainerNumSymbols);
+ res.rng = RNG::next(res.rng);
+ res.state = 1;
+ const int mult_in[10] = // some arbitrarily selected prime number except for
+ {1, 997, 1013, 1021, 1033, 1049, 1061, 1069, 1091, 1097}; // symbol one
+ BOOST_STATIC_ASSERT((MC::TrainerNumSymbols <= 10));
+ std::cerr << "INPUT " << t << std::endl;
+ sendSpikes(mult_in[res.input],
+ MC::TrainerInputWindow, MC::FanIn, MC::TrainerInputFreq);
+ } break;
+
+ case 2:
+ res.resetCounter = 0;
+ break;
+
+ case 4: case 6: // wait
+ break;
+
+ case 3: {
+ // send readout signal
+ std::cerr << "READOUT SIGNAL " << t << std::endl;
+ if (MC::TrainerReadoutFreq > 0)
+ sendSpikes(967, // arbitrary prime
+ MC::TrainerReadoutWindow, MC::FanIn, MC::TrainerReadoutFreq);
+ } break;
+
+
+ case 5: {
+ // evaluate which symbol won, reward (TODO), wait for a longer time
+ const int mult_out[10] = // some arbitrarily selected prime number except for
+ {1, 937, 919, 907, 883, 877, 859, 853, 829, 823}; // symbol one
+ BOOST_STATIC_ASSERT((MC::TrainerNumSymbols <= 10));
+ uint32_t freq[MC::TrainerNumSymbols + 1] = {};
+ uint8_t maxIdx = MC::TrainerNumSymbols,
+ sndIdx = MC::TrainerNumSymbols;
+ uint16_t overlap = uint16_t(MC::FanIn / MC::NumExcitatory * MC::FanOut);
+ std::cerr << "READOUT";
+ for (int o=0; o<MC::TrainerNumSymbols; o++) {
+ for (Ptr<Neuron>::ptr_t i = MC::FanIn - overlap;
+ i < MC::FanIn + MC::FanIn;
+ i++) {
+ Ptr<Neuron> src(uint16_t(mult_out[o] * i % MC::NumExcitatory));
+ PLA_Get<SpikeCounter, ContinuousContext<Neuron> >
+ pla_get(t, ContinuousContext<Neuron>(src));
+ freq[o] += pc.call(pla_get);
+ }
+ if (freq[o] > freq[maxIdx]) {
+ sndIdx = maxIdx;
+ maxIdx = o;
+ }else if (freq[o] > freq[sndIdx]) {
+ sndIdx = o;
+ }
+ std::cerr << " " << freq[o];
+ }
+ bool correct = maxIdx == res.input;
+ res.output = maxIdx;
+ res.reward = (correct ? MC::TrainerReward : (-MC::TrainerPunish))
+ * ((1.0 + MC::TrainerWinAdv) * freq[sndIdx] < freq[maxIdx]);
+ res.performance *= 0.9;
+ res.performance += correct * 0.1;
+ res.resetCounter = 1;
+ std::cerr << "\nWINNER @ " << t << " = " << uint16_t(maxIdx)
+ << " correct = " << uint16_t(res.input)
+ << " reward = " << res.reward
+ << "\n";
+ } break;
+
+ default:
+ assert(false);
+ }
+
+ res.delay = delays[res.state]
+ + randDelays[res.state] * RNG::equi(res.rng);
+ res.rng = RNG::next(res.rng);
+
+ return res;
+ }
+};
+
+/// dummy for sim_replay, must not be called
+template <typename PC>
+struct Update<PC, Void, Void> {
+ static TrainerT eval(const TrainerT&, PC &, Void&, Void&, Time) DO_NOT_CALL;
+};
+
+} // NS
+
+
+#endif // pAHg29dA1QrONep0XYIUhA0GC08
contact: Jan Huwald // Impressum