diff options
Diffstat (limited to 'code/trainer/mem1.cpp')
-rw-r--r-- | code/trainer/mem1.cpp | 412 |
1 files changed, 412 insertions, 0 deletions
diff --git a/code/trainer/mem1.cpp b/code/trainer/mem1.cpp new file mode 100644 index 0000000..3b522b4 --- /dev/null +++ b/code/trainer/mem1.cpp @@ -0,0 +1,412 @@ +#include <stdlib.h> +#include "fileutils.h" +#include "math.h" + +#include "reinforce_synapse.h" +#include "fileutils.cpp" +#include "model_switch.h" + +using namespace std; + +int main(int argc, char **argv) { + // check cmd line sanity + if (argc != 7) { + fprintf(stderr, "Wrong argument count\n\n" + "Call format:\n" + "%s\n\t" + "performance out\n\t" + "trace cmd out\n\t" + "global out\n\t" + "global in\n\t" + "spike out\n\t" + "spike in\n\t" + "\n" + "Special names allowed:\n\t- (standart input)\n\t0 (/dev/null)\n", argv[0]); + return -1; + } + + Trainer *t = new Trainer(argc, argv); + t->run(); +} + +//===== Initialisation ===================================== + +Trainer::Trainer(int argc, char** argv) { + initConfig(); + initState(); + initGroups(); // determine input and output neurons + + initFiles(); + initThreads(); +} + +void Trainer::initConfig() { + neurons = 1000; + neuronsPerSymbol = 200; + noiseFreq = 10.0; // [Hz] + noiseVoltage = 0.03; // [V] + reward = 0.1; + + epochDuration = 1.0; // [s] + numTrials = 1000; + numSymbols = 2; + //readoutDelay = 1; + refractoryPeriods = 3; +} + +void Trainer::initState() { + dopamin_level = 0.0; + currentTrial = 0; + state = 0; + + msg_init(msg); + msg.dopamin_level = dopamin_level; + + groupFreq.resize(numSymbols); + for (int i=0; i<numSymbols; i++) { + groupFreq[i] = 0; + } +} + +void Trainer::initGroups() { + ioNeurons[i] = new set<int>(); + for (int j=0; j<neuronsPerSymbol;) { + int n = rand() % numNeurons; + if (!isNeurons[i]->count(n)) { + ioNeurons[i]->insert(n); + j++; + } + } +} + +void Trainer::initFiles() { + // open all file descriptors in an order complementary to the simulators one + // to avoid deadlocks + fd_spike_in = fd_magic(argv[6], false); + fd_global_in = fd_magic(argv[4], false); + fd_spike_out = fd_magic(argv[5], true); + fd_global_out = fd_magic(argv[3], true); + fd_performance_out = fd_magic(argv[1], true); + fd_trace_out = fd_magic(argv[2], true); +} + +void Trainer::initThreads() { + // init locks + pthread_mutex_init(&incomingSpikeLock, NULL); + pthread_mutex_init(&writerLock, NULL); + + // create read and write threads + pthread_create(&thread_read, NULL, (void* (*)(void*)) &read_spikes, this); + pthread_create(&thread_write, NULL, (void* (*)(void*)) &write_spikes, this); +} + +//===== Core trainer ==================================== + +void Trainer::pushGlobal(double time) { + fprintf(fd_global_out, "%f, ", time); + msg_print(msg, fd_global_out); + fprintf(fd_global_out, "\n"); + fflush(fd_global_out); +} + +// hint time is delta time! +void Trainer::pushTrace(double time) { + const char *str_trace = "%f; spikes (0; 1); global; neuron (0; 1); synapse (0; 1)\n"; + fprintf(fd_trace_out, str_trace, time); + fflush(fd_trace_out); +} + +bool Trainer::readGlobal() { + double _foo_dbl; + char str_raw[128], + str_msg[128]; + str_raw[0] = 0; + + // read a single line + if (fgets((char*) str_raw, 128, fd_global_in) == NULL) { + fprintf(stderr, "ERROR: global status file descriptor from simulator closed unexpectedly\n"); + return false; + } + + // parse it + if ((sscanf((char*) str_raw, "%lf, %[^\n]\n", &_foo_dbl, (char*) str_msg) != 2) + || (!msg_parse(msg, (char*) str_msg))) { + fprintf(stderr, "ERROR: reading global status from simulator failed\n\t\"%s\"\n", (char*) str_raw); + return false; + } + + return true; +} + +void binIncomingSpikes() { + // reset bins + for (int i=0; i<groupFreq.size(); i++) + groupFreq[i] = 0; + + // lock spike queue + pthread_yield(); // give the spike reading thread chance to finish ... this is not more than ugly semifix wrong par! + pthread_mutex_lock(&incomingSpikeLock); + + // read all spikes in the correct time window + while ((!incomingSpikes.empty()) && (incomingSpikes.front().get<0>() <= currentEpoch * epochDuration)) { + // drop event out of queue + SpikeEvent se = incomingSpikes.front(); + double time = se.get<0>(); + int neuron = se.get<1>(); + incomingSpikes.pop(); + + // check if it belongs to the previous bin (and ignore it if this is the case) + if (time < (currentEpoch - 1) * epochDuration) { + fprintf(stderr, "WARN: spike reading thread to slow; unprocessed spike of the past discovered\n%f\t%f\t%d\t%f\n", + time, (double) (currentEpoch - 1) * epochDuration, currentEpoch, epochDuration); + continue; + } + + // check membership in each group and increase group frequency + for (int i=0; i < ioNeurons.size(); i++) + if (ioNeuros[i]->count(neuron)) + groupFreq[i]++; + } + + pthread_mutex_unlock(&incomingSpikeLock); +} + +void Trainer::addBaselineSpikes() { +} + +void Trainer::addSymbolSpikes() { + +} + + +double Trainer::calcSignalStrength() { + if (symbolHist.empty()) { + fprintf(stderr, "Writer thread is too slow; missed the current symbol\n"); + exit(-1); + } + + int fs, fn = 0; // freq signal, freq noise + for (int i=0; i<numSymbols; i++) { + if (i == symbolHist.front()) { + fs = groupFreq[i]; + }else{ + fm = fmax(fm, groupFreq[i]); + } + } + + if (fn == 0) { + return fs * INFINITY; + }else{ + return ((double) fs) / fn; + } +} + +void Trainer::run() { + // rough description of this function + // . start an epoch + // . wait for it's end + // . process incomig spikes (binning) + // . select if a reward takes place + // . print reward value + // . send out the reward signal + + // send out the full trace command once (later it will be repeated by sending newline) + pushTrace(epochDuration); + + // send the first two global states (at t=0 and t=1.5 [bintime] to allow the simulation to + // be initialized (before the causality of the loop below is met) + pushGlobal(0.0); + msg_process(msg, 1.5 * epochDuration); + dopamin_level = msg.dopamin_level; + pushGlobal(1.5 * epochDuration); + + // loop until the experiment is done + for (; currentEpoch * epochDuration < entireDuration; currentEpoch++) { + + // send a new trace command (do it as early as possible although it is + // only executed after the new global is send out at the bottom of this loop) + if ((currentEpoch + 2) * epochDuration < entireDuration) { + // repeat the previous trace command + fprintf(fd_trace_out, "\n"); + fflush(fd_trace_out); + }else{ + pushTrace(entireDuration - (currentEpoch + 1) * epochDuration); + } + + // send new spikes + pthread_mutex_lock(&outgoingSpikeLock); + addBaselineSpikes(); + if (state == 0) addSymbolSpikes(); + pthread_cond_signal(&outgoingSpikeCond); + pthread_mutex_unlock(&outgoingSpikeLock); + + // wait for the end of the epoch (by reading the global state resulting from it) + if (!readGlobal()) + break; + + // process incomig spikes (binning) of the previous epoch + if (currentEpoch > 0) + binIncomingSpikes(); + + // proceed the global state to keep it in sync with the simulator's global state + // the local dopamin level is kept seperately and aged only one epochDuration to + // avoid oscillation effects in dopamin level + msg_process(msg, 1.5 * epochDuration); + dopamin_level *= exp( - epochDuration / msg.dopamin_tau ); + + // do various actions depeding on state (thus lock mutex of the writer thread) + + + switch (state) { + case 0: // a signal is sent + state++; + + case 1: // we are waiting for the signal to be reproduce + // get fraction of the current symbol's freq compared to the strongest wrong symbol + double ss = calcSignalStrength(); + + // check if the reward condition is met + if (ss > 1) { + dopmain_level += reward; + }else{ + state++; // lost signal -> next state (and finally a new trial) + currentSymbol = rand() % numSymbols; // determine new symbol to display + } + + break; + + default: // the signal has been lost (in the last round); refractory time + ++state %= refractoryPeriods; + } + + /*if ((currentEpoch > 1) && ((*neuronFreq[0])[0] > 0) && ((*neuronFreq[1])[1] > 0)) { + dopamin_level += da_single_reward; + fprintf(fd_performance_out, "+"); + }else{ + fprintf(fd_performance_out, "-"); + }*/ + + // performance and "debug" output + if (currentEpoch > 1) { + //fprintf(fd_performance_out, "\n"); + fprintf(fd_performance_out, "\t%f\t%d\t%d\n", dopamin_level, (*neuronFreq[0])[0], (*neuronFreq[1])[1]); + }else{ + // fake output as acutal data is not available, yet + fprintf(fd_performance_out, "\t%f\t%d\t%d\n", dopamin_level, (int) 0, (int) 0); + } + + // set the new DA level + msg.dopamin_level = dopamin_level; + + // print new global state + // (do this even if there has been no evaluation of the performance yet, + // because it is neccessary for the simulator to proceed) + pushGlobal(((double) currentEpoch + 2.5) * epochDuration); + } + + fclose(fd_trace_out); + + // terminate child threads + pthread_cancel(thread_read); + pthread_cancel(thread_write); +} + +void *read_spikes(Trainer *t) { + double lastSpike = -INFINITY; // used to check if the spikes are coming in order + + // read spikes until eternity + while (!feof(t->fd_spike_in)) { + // read one line from stdin (blocking) + char buf[128]; + if (fgets((char*) buf, 128, t->fd_spike_in) == NULL) continue; // this should stop the loop because of EOF + + // parse the input + double time, current; + int neuron; + switch (sscanf((char*) buf, "%lf, %d, %lf\n", &time, &neuron, ¤t)) { + case 3: + // format is ok, continue + break; + default: + // format is wrong, stop + fprintf(stderr, "ERROR: malformatted incoming spike:\n\t%s\n", &buf); + return NULL; + } + + if (lastSpike > time) { + fprintf(stderr, "WARN: out of order spike detected (coming from simulator)\n\t%f\t%d\n", time, neuron); + continue; + } + + lastSpike = time; + + // add the spike to the queue of spikes + pthread_mutex_lock(&(t->incomingSpikeLock)); + t->incomingSpikes.push(boost::make_tuple(time, neuron, current)); + pthread_mutex_unlock(&(t->incomingSpikeLock)); + } + + // we shouldn't reach this point in a non-error case + fprintf(stderr, "ERROR: EOF in incoming spike stream\n"); + // TODO: kill entire programm + return NULL; +} + +void *write_spikes(Trainer *t) { + // at the moment: generate noise until the file descriptor blocks + double time = 0.0; + + // PAR HINT: + // loop until exactly one spike after the entire duration is send out + // this will block on full buffer on the file descriptor and thus keep + // the thread busy early enough + + + /* // ---- send 100% dependent spike train --- + time = 0.005; + while (time <= t->entireDuration) { + fprintf(t->fd_spike_out, "%f, %d, %f\n", time, 0, 1.0); + time += 0.012; + fprintf(t->fd_spike_out, "%f, %d, %f\n", time, 1, 1.0); + time += 1.0; + }*/ + + + /* // ---- send indepenent poisson noise ---- + while (time <= t->entireDuration) { + // calc timing, intensity and destination of the spike + // HINT: + // * log(...) is negative + // * drand48() returns something in [0,1), to avoid log(0) we transform it to (0,1] + time -= log(1.0 - drand48()) / (t->freq * t->neurons); + int dst = rand() % t->neurons; + double current = t->voltage; + + // send it to the simulator + fprintf(t->fd_spike_out, "%f, %d, %f\n", time, dst, current); + }*/ + + // ---- send indepenent poisson noise w7 increasing fequency---- + double blafoo = 0; + t->freq = 1.0; + while (time <= t->entireDuration) { + if (time - blafoo > 100.0) { + blafoo += 200.0; + t->freq += 1.0; + time += 100.0; // time jump to let ET recover to zero + } + // calc timing, intensity and destination of the spike + // HINT: + // * log(...) is negative + // * drand48() returns something in [0,1), to avoid log(0) we transform it to (0,1] + time -= log(1.0 - drand48()) / (t->freq * t->neurons); + int dst = rand() % t->neurons; + double current = t->voltage; + + // send it to the simulator + fprintf(t->fd_spike_out, "%f, %d, %f\n", time, dst, current); + } + + // close fd because fscanf sucks + fclose(t->fd_spike_out); +} |