summary | refs | log | tree | commit | diff
path: root/code/trainer/mem1.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'code/trainer/mem1.cpp')
-rw-r--r-- code/trainer/mem1.cpp 412
1 files changed, 412 insertions, 0 deletions
diff --git a/code/trainer/mem1.cpp b/code/trainer/mem1.cpp
new file mode 100644
index 0000000..3b522b4
--- /dev/null
+++ b/code/trainer/mem1.cpp
@@ -0,0 +1,412 @@
+#include <stdlib.h>
+#include "fileutils.h"
+#include "math.h"
+
+#include "reinforce_synapse.h"
+#include "fileutils.cpp"
+#include "model_switch.h"
+
+using namespace std;
+
+/**
+ * Program entry point.
+ *
+ * Expects exactly six arguments after the program name (listed in the usage
+ * text below). On a wrong argument count the usage text is printed to stderr
+ * and -1 is returned; otherwise a Trainer is built and run.
+ */
+int main(int argc, char **argv) {
+    // program name + six file arguments
+    const int expectedArgc = 7;
+
+    // reject malformed command lines before anything is opened
+    if (argc != expectedArgc) {
+        fprintf(stderr, "Wrong argument count\n\n"
+                        "Call format:\n"
+                        "%s\n\t"
+                        "performance out\n\t"
+                        "trace cmd out\n\t"
+                        "global out\n\t"
+                        "global in\n\t"
+                        "spike out\n\t"
+                        "spike in\n\t"
+                        "\n"
+                        "Special names allowed:\n\t- (standart input)\n\t0 (/dev/null)\n", argv[0]);
+        return -1;
+    }
+
+    // The trainer is never deleted: the process ends right after run() returns.
+    Trainer *trainer = new Trainer(argc, argv);
+    trainer->run();
+    return 0;
+}
+
+//===== Initialisation =====================================
+
+/**
+ * Sets up a trainer by running the initialisation steps in a fixed order:
+ * configuration first (initState() sizes groupFreq by numSymbols), files
+ * before threads (the reader thread reads from fd_spike_in).
+ *
+ * NOTE(review): argc/argv are neither stored nor forwarded here, yet
+ * initFiles() reads argv -- confirm how the class makes argv available.
+ */
+Trainer::Trainer(int argc, char** argv) {
+    // plain values and counters
+    this->initConfig();
+    this->initState();
+
+    // determine input and output neurons
+    this->initGroups();
+
+    // external communication: files first, then the threads using them
+    this->initFiles();
+    this->initThreads();
+}
+
+/**
+ * Assigns the hard-coded experiment parameters.
+ *
+ * NOTE(review): run() reads entireDuration, which is not assigned here --
+ * confirm where it gets its value.
+ */
+void Trainer::initConfig() {
+    // network layout
+    neurons          = 1000;
+    numSymbols       = 2;
+    neuronsPerSymbol = 200;
+
+    // background noise
+    noiseFreq    = 10.0;  // [Hz]
+    noiseVoltage = 0.03;  // [V]
+
+    // amount added to the dopamine level per rewarded epoch
+    reward = 0.1;
+
+    // timing
+    epochDuration     = 1.0;  // [s]
+    numTrials         = 1000;
+    refractoryPeriods = 3;
+    //readoutDelay = 1;
+}
+
+/**
+ * Resets the run-time state: trial bookkeeping, dopamine level, the global
+ * message and the per-symbol spike counters.
+ */
+void Trainer::initState() {
+    // trial bookkeeping
+    state = 0;
+    currentTrial = 0;
+
+    // start without dopamine and mirror that into the global message
+    dopamin_level = 0.0;
+    msg_init(msg);
+    msg.dopamin_level = dopamin_level;
+
+    // one spike counter per symbol, all starting at zero
+    groupFreq.resize(numSymbols);
+    int symbol = 0;
+    while (symbol < numSymbols) {
+        groupFreq[symbol] = 0;
+        ++symbol;
+    }
+}
+
+/**
+ * Chooses, for every symbol, a random group of neuronsPerSymbol distinct
+ * neuron indices in [0, neurons) and stores it in ioNeurons[symbol].
+ *
+ * Distinctness is only enforced inside a group; two symbols may share
+ * neurons. Exits the program if a group cannot be filled because fewer
+ * neurons exist than one group needs.
+ */
+void Trainer::initGroups() {
+    // the rejection sampling below can only terminate if enough distinct
+    // neuron indices exist
+    if (neuronsPerSymbol > neurons) {
+        fprintf(stderr, "ERROR: neuronsPerSymbol (%d) exceeds the number of neurons (%d)\n",
+                (int) neuronsPerSymbol, (int) neurons);
+        exit(-1);
+    }
+
+    // NOTE(review): assumes ioNeurons is a resizable sequence like groupFreq
+    // -- confirm against the class declaration
+    ioNeurons.resize(numSymbols);
+
+    for (int i=0; i<numSymbols; i++) {
+        ioNeurons[i] = new set<int>();
+
+        // draw until the group is full; j only advances on a new neuron
+        for (int j=0; j<neuronsPerSymbol;) {
+            int n = rand() % neurons;
+            if (ioNeurons[i]->insert(n).second)
+                j++;
+        }
+    }
+}
+
+/**
+ * Opens the six communication files named on the command line.
+ *
+ * The order of the fd_magic() calls is complementary to the order the
+ * simulator uses, to avoid deadlocks; do not reorder them.
+ */
+void Trainer::initFiles() {
+    // positions of the file names on the command line (see usage text in main)
+    enum {
+        ARG_PERFORMANCE_OUT = 1,
+        ARG_TRACE_OUT       = 2,
+        ARG_GLOBAL_OUT      = 3,
+        ARG_GLOBAL_IN       = 4,
+        ARG_SPIKE_OUT       = 5,
+        ARG_SPIKE_IN        = 6
+    };
+
+    // second argument: false for the *_in files, true for the *_out files
+    // (presumably "open for writing" -- confirm against fd_magic)
+    fd_spike_in        = fd_magic(argv[ARG_SPIKE_IN], false);
+    fd_global_in       = fd_magic(argv[ARG_GLOBAL_IN], false);
+    fd_spike_out       = fd_magic(argv[ARG_SPIKE_OUT], true);
+    fd_global_out      = fd_magic(argv[ARG_GLOBAL_OUT], true);
+    fd_performance_out = fd_magic(argv[ARG_PERFORMANCE_OUT], true);
+    fd_trace_out       = fd_magic(argv[ARG_TRACE_OUT], true);
+}
+
+/**
+ * Initialises the synchronisation primitives and starts the spike reader
+ * and spike writer threads, each receiving this Trainer as its argument.
+ *
+ * Exits the program if a thread cannot be created.
+ */
+void Trainer::initThreads() {
+    // init locks; outgoingSpikeLock/outgoingSpikeCond are used by run() and
+    // must be initialised before their first use
+    pthread_mutex_init(&incomingSpikeLock, NULL);
+    pthread_mutex_init(&writerLock, NULL);
+    pthread_mutex_init(&outgoingSpikeLock, NULL);
+    pthread_cond_init(&outgoingSpikeCond, NULL);
+
+    // create read and write threads
+    int err = pthread_create(&thread_read, NULL, (void* (*)(void*)) &read_spikes, this);
+    if (err != 0) {
+        fprintf(stderr, "ERROR: could not start spike reading thread (error %d)\n", err);
+        exit(-1);
+    }
+
+    err = pthread_create(&thread_write, NULL, (void* (*)(void*)) &write_spikes, this);
+    if (err != 0) {
+        fprintf(stderr, "ERROR: could not start spike writing thread (error %d)\n", err);
+        exit(-1);
+    }
+}
+
+//===== Core trainer ====================================
+
+/**
+ * Writes one line "<time>, <global message>" to fd_global_out and flushes it.
+ *
+ * @param time  time stamp written in front of the message
+ */
+void Trainer::pushGlobal(double time) {
+    // time stamp, then the message itself
+    fprintf(fd_global_out, "%f, ", time);
+    msg_print(msg, fd_global_out);
+
+    // terminate the line and push it out immediately
+    fputs("\n", fd_global_out);
+    fflush(fd_global_out);
+}
+
+/**
+ * Writes the full trace command to fd_trace_out and flushes it.
+ *
+ * @param time  a time delta, not an absolute time
+ */
+void Trainer::pushTrace(double time) {
+    fprintf(fd_trace_out,
+            "%f; spikes (0; 1); global; neuron (0; 1); synapse (0; 1)\n",
+            time);
+    fflush(fd_trace_out);
+}
+
+/**
+ * Reads one line of global state from fd_global_in and parses it into msg.
+ *
+ * The line must look like "<time>, <message>"; the time value is parsed but
+ * not used.
+ *
+ * @return true on success; false (after an error message on stderr) if the
+ *         stream ended or the line could not be parsed
+ */
+bool Trainer::readGlobal() {
+    char line[128];
+    char payload[128];
+    double ignoredTime;
+    line[0] = 0;
+
+    // fetch a single line
+    if (fgets(line, 128, fd_global_in) == NULL) {
+        fprintf(stderr, "ERROR: global status file descriptor from simulator closed unexpectedly\n");
+        return false;
+    }
+
+    // split off the time stamp, then hand the rest to the message parser
+    const bool parsed =
+        (sscanf(line, "%lf, %[^\n]\n", &ignoredTime, payload) == 2)
+        && msg_parse(msg, payload);
+
+    if (parsed)
+        return true;
+
+    fprintf(stderr, "ERROR: reading global status from simulator failed\n\t\"%s\"\n", line);
+    return false;
+}
+
+/**
+ * Counts, per symbol group, the queued incoming spikes that belong to the
+ * epoch ending at currentEpoch * epochDuration.
+ *
+ * groupFreq is reset first. Spikes up to the end of that epoch are removed
+ * from incomingSpikes; spikes older than the start of the epoch are dropped
+ * with a warning. The queue is accessed under incomingSpikeLock.
+ */
+void Trainer::binIncomingSpikes() {
+    // reset bins
+    for (size_t i=0; i<groupFreq.size(); i++)
+        groupFreq[i] = 0;
+
+    // lock spike queue
+    // yielding first gives the spike reading thread a chance to finish; this
+    // is only a workaround, not real synchronisation
+    pthread_yield();
+    pthread_mutex_lock(&incomingSpikeLock);
+
+    // read all spikes in the correct time window
+    while ((!incomingSpikes.empty()) && (incomingSpikes.front().get<0>() <= currentEpoch * epochDuration)) {
+        // drop event out of queue
+        SpikeEvent se = incomingSpikes.front();
+        double time = se.get<0>();
+        int neuron = se.get<1>();
+        incomingSpikes.pop();
+
+        // check if it belongs to the previous bin (and ignore it if this is the case)
+        if (time < (currentEpoch - 1) * epochDuration) {
+            fprintf(stderr, "WARN: spike reading thread to slow; unprocessed spike of the past discovered\n%f\t%f\t%d\t%f\n",
+                    time, (double) (currentEpoch - 1) * epochDuration, currentEpoch, epochDuration);
+            continue;
+        }
+
+        // check membership in each group and increase group frequency
+        for (size_t i=0; i<ioNeurons.size(); i++)
+            if (ioNeurons[i]->count(neuron))
+                groupFreq[i]++;
+    }
+
+    pthread_mutex_unlock(&incomingSpikeLock);
+}
+
+void Trainer::addBaselineSpikes() {
+    // TODO: not implemented yet; run() calls this once per epoch while
+    // holding outgoingSpikeLock
+}
+
+void Trainer::addSymbolSpikes() {
+    // TODO: not implemented yet; run() calls this only in state 0, while
+    // holding outgoingSpikeLock
+}
+
+
+/**
+ * Compares the spike count of the current symbol's group with the strongest
+ * other group.
+ *
+ * The current symbol is symbolHist.front(); the program exits if symbolHist
+ * is empty.
+ *
+ * @return fs / fn, where fs is the current symbol's count and fn the largest
+ *         count among the other symbols. If fn is zero the result is
+ *         INFINITY when fs > 0 and 0.0 when fs is zero as well (no activity
+ *         at all is not a signal).
+ */
+double Trainer::calcSignalStrength() {
+    if (symbolHist.empty()) {
+        fprintf(stderr, "Writer thread is too slow; missed the current symbol\n");
+        exit(-1);
+    }
+
+    int fs = 0; // freq signal
+    int fn = 0; // freq noise (strongest wrong symbol)
+    for (int i=0; i<numSymbols; i++) {
+        if (i == symbolHist.front()) {
+            fs = groupFreq[i];
+        }else if (groupFreq[i] > fn) {
+            fn = groupFreq[i];
+        }
+    }
+
+    if (fn == 0) {
+        // avoid 0 * INFINITY (NaN) when neither group fired
+        return (fs > 0) ? INFINITY : 0.0;
+    }
+
+    return ((double) fs) / fn;
+}
+
+/**
+ * Main loop of the trainer. Per epoch it
+ *   - sends the trace command and the outgoing spikes,
+ *   - waits for the end of the epoch by reading the simulator's global state,
+ *   - bins the incoming spikes of the previous epoch,
+ *   - decides whether a reward is given and updates the dopamine level,
+ *   - writes a performance line and sends out the new global state.
+ *
+ * The loop ends when entireDuration is reached or readGlobal() fails.
+ * Afterwards the trace file is closed and both helper threads are cancelled.
+ */
+void Trainer::run() {
+    // send out the full trace command once (later it will be repeated by sending newline)
+    pushTrace(epochDuration);
+
+    // send the first two global states (at t=0 and t=1.5 [bintime]) to allow the simulation
+    // to be initialized (before the causality of the loop below is met)
+    pushGlobal(0.0);
+    msg_process(msg, 1.5 * epochDuration);
+    dopamin_level = msg.dopamin_level;
+    pushGlobal(1.5 * epochDuration);
+
+    // loop until the experiment is done
+    // (currentEpoch is set here because no init function visible in this file assigns it)
+    for (currentEpoch = 0; currentEpoch * epochDuration < entireDuration; currentEpoch++) {
+
+        // send a new trace command (do it as early as possible although it is
+        // only executed after the new global is sent out at the bottom of this loop)
+        if ((currentEpoch + 2) * epochDuration < entireDuration) {
+            // repeat the previous trace command
+            fprintf(fd_trace_out, "\n");
+            fflush(fd_trace_out);
+        }else{
+            pushTrace(entireDuration - (currentEpoch + 1) * epochDuration);
+        }
+
+        // send new spikes
+        pthread_mutex_lock(&outgoingSpikeLock);
+        addBaselineSpikes();
+        if (state == 0) addSymbolSpikes();
+        pthread_cond_signal(&outgoingSpikeCond);
+        pthread_mutex_unlock(&outgoingSpikeLock);
+
+        // wait for the end of the epoch (by reading the global state resulting from it)
+        if (!readGlobal())
+            break;
+
+        // process incoming spikes (binning) of the previous epoch
+        if (currentEpoch > 0)
+            binIncomingSpikes();
+
+        // advance the global state to keep it in sync with the simulator's global state;
+        // the local dopamine level is kept separately and aged only one epochDuration to
+        // avoid oscillation effects in the dopamine level
+        msg_process(msg, 1.5 * epochDuration);
+        dopamin_level *= exp( - epochDuration / msg.dopamin_tau );
+
+        // do various actions depending on state
+        switch (state) {
+            case 0: // a signal is sent
+                state++;
+                // intentional fall through: the epoch is evaluated like a waiting epoch
+
+            case 1: { // we are waiting for the signal to be reproduced
+                // braces keep ss local so that the jump to default does not
+                // cross its initialisation
+
+                // get fraction of the current symbol's freq compared to the strongest wrong symbol
+                double ss = calcSignalStrength();
+
+                // check if the reward condition is met
+                if (ss > 1) {
+                    dopamin_level += reward;
+                }else{
+                    state++; // lost signal -> next state (and finally a new trial)
+                    currentSymbol = rand() % numSymbols; // determine new symbol to display
+                }
+
+                break;
+            }
+
+            default: // the signal has been lost (in the last round); refractory time
+                state = (state + 1) % refractoryPeriods;
+        }
+
+        // performance and "debug" output
+        // NOTE(review): this used to print (*neuronFreq[0])[0] and (*neuronFreq[1])[1], which
+        // nothing in this file fills; the per-group counters of binIncomingSpikes() are
+        // printed instead -- confirm this is the intended output.
+        int freq0 = 0, freq1 = 0; // fake output while actual data is not available yet
+        if ((currentEpoch > 1) && (groupFreq.size() >= 2)) {
+            freq0 = groupFreq[0];
+            freq1 = groupFreq[1];
+        }
+        fprintf(fd_performance_out, "\t%f\t%d\t%d\n", dopamin_level, freq0, freq1);
+
+        // set the new DA level
+        msg.dopamin_level = dopamin_level;
+
+        // print new global state
+        // (do this even if there has been no evaluation of the performance yet,
+        // because it is necessary for the simulator to proceed)
+        pushGlobal(((double) currentEpoch + 2.5) * epochDuration);
+    }
+
+    fclose(fd_trace_out);
+
+    // terminate child threads
+    pthread_cancel(thread_read);
+    pthread_cancel(thread_write);
+}
+
+/**
+ * Thread body: reads spikes line by line from t->fd_spike_in and appends
+ * them to t->incomingSpikes (under t->incomingSpikeLock).
+ *
+ * Each line must have the form "<time>, <neuron>, <current>". A malformed
+ * line stops the thread; a spike older than its predecessor is skipped with
+ * a warning.
+ *
+ * @param t  the trainer whose spike queue is filled
+ * @return always NULL; returning at all means the stream ended or failed
+ */
+void *read_spikes(Trainer *t) {
+    double lastSpike = -INFINITY; // used to check if the spikes are coming in order
+    char buf[128];
+
+    // read one line at a time (blocking); stops on EOF and on read errors,
+    // so a failing stream cannot turn into a busy loop
+    while (fgets(buf, sizeof(buf), t->fd_spike_in) != NULL) {
+        // parse the input
+        double time, current;
+        int neuron;
+        if (sscanf(buf, "%lf, %d, %lf\n", &time, &neuron, &current) != 3) {
+            // format is wrong, stop
+            fprintf(stderr, "ERROR: malformatted incoming spike:\n\t%s\n", buf);
+            return NULL;
+        }
+
+        if (lastSpike > time) {
+            fprintf(stderr, "WARN: out of order spike detected (coming from simulator)\n\t%f\t%d\n", time, neuron);
+            continue;
+        }
+
+        lastSpike = time;
+
+        // add the spike to the queue of spikes
+        pthread_mutex_lock(&(t->incomingSpikeLock));
+        t->incomingSpikes.push(boost::make_tuple(time, neuron, current));
+        pthread_mutex_unlock(&(t->incomingSpikeLock));
+    }
+
+    // we shouldn't reach this point in a non-error case
+    if (ferror(t->fd_spike_in))
+        fprintf(stderr, "ERROR: reading incoming spike stream failed\n");
+    else
+        fprintf(stderr, "ERROR: EOF in incoming spike stream\n");
+    // TODO: kill entire program
+    return NULL;
+}
+
+/**
+ * Thread body: writes Poisson noise spikes with a stepwise increasing
+ * frequency to t->fd_spike_out until the spike time exceeds
+ * t->entireDuration, then closes the file.
+ *
+ * NOTE(review): this reads t->freq and t->voltage, while initConfig() sets
+ * noiseFreq and noiseVoltage -- confirm which members are intended.
+ *
+ * @param t  the trainer providing the configuration and the output file
+ * @return always NULL
+ */
+void *write_spikes(Trainer *t) {
+    // at the moment: generate noise until the file descriptor blocks
+    double time = 0.0;
+
+    // PAR HINT:
+    // loop until exactly one spike after the entire duration is sent out;
+    // this will block on a full buffer on the file descriptor and thus keep
+    // the thread busy early enough
+
+
+    /* // ---- send 100% dependent spike train ---
+    time = 0.005;
+    while (time <= t->entireDuration) {
+        fprintf(t->fd_spike_out, "%f, %d, %f\n", time, 0, 1.0);
+        time += 0.012;
+        fprintf(t->fd_spike_out, "%f, %d, %f\n", time, 1, 1.0);
+        time += 1.0;
+    }*/
+
+
+    /* // ---- send independent poisson noise ----
+    while (time <= t->entireDuration) {
+        // calc timing, intensity and destination of the spike
+        // HINT:
+        // * log(...) is negative
+        // * drand48() returns something in [0,1), to avoid log(0) we transform it to (0,1]
+        time -= log(1.0 - drand48()) / (t->freq * t->neurons);
+        int dst = rand() % t->neurons;
+        double current = t->voltage;
+
+        // send it to the simulator
+        fprintf(t->fd_spike_out, "%f, %d, %f\n", time, dst, current);
+    }*/
+
+    // ---- send independent poisson noise with increasing frequency ----
+    double stepStart = 0; // reference time of the current frequency step
+    t->freq = 1.0;
+    while (time <= t->entireDuration) {
+        if (time - stepStart > 100.0) {
+            stepStart += 200.0;
+            t->freq += 1.0;
+            time += 100.0; // time jump to let ET recover to zero
+        }
+        // calc timing, intensity and destination of the spike
+        // HINT:
+        // * log(...) is negative
+        // * drand48() returns something in [0,1), to avoid log(0) we transform it to (0,1]
+        time -= log(1.0 - drand48()) / (t->freq * t->neurons);
+        int dst = rand() % t->neurons;
+        double current = t->voltage;
+
+        // send it to the simulator
+        fprintf(t->fd_spike_out, "%f, %d, %f\n", time, dst, current);
+    }
+
+    // close fd because fscanf sucks
+    fclose(t->fd_spike_out);
+
+    // a thread function declared to return void* must return a value
+    return NULL;
+}
contact: Jan Huwald // Impressum