def app(opt):
    """Train and validate a Bohte2002 network on IRIS (earlier entry-point variant).

    Builds the IRIS data loader, a 12-neuron population encoder over the
    per-feature min/max from ``summarize()``, a label encoder, the model and
    the Bohte optimizer. It then alternates one training pass and one
    validation pass per epoch, printing the results.

    NOTE(review): this file defines ``app`` and ``train`` a second time further
    down. At module level the later definitions rebind the names, so this
    function is unreachable as ``app``. If the later 12-parameter ``train`` is
    the one bound at call time, the 8-argument ``train`` call below will not
    match its signature. Confirm whether this variant should be deleted.

    Args:
        opt: Options object. Reads ``max_firing_time``, ``not_to_fire``,
            ``dt``, ``num_classes`` and ``num_epochs`` here; ``train`` and
            ``validate`` read ``num_steps``.
    """
    import numpy as np

    # Print tensors in full rather than summarised with "...".
    np.set_printoptions(threshold=np.inf)
    print(opt)

    data_loader = n3ml.data.IRISDataLoader()
    data = data_loader.run()
    summary = data_loader.summarize()

    data_encoder = n3ml.encoder.Population(
        neurons=12,
        minimum=summary['min'],
        maximum=summary['max'],
        max_firing_time=opt.max_firing_time,
        not_to_fire=opt.not_to_fire,
        dt=opt.dt,
    )
    label_encoder = LabelEncoder(opt.num_classes)

    model = n3ml.model.Bohte2002()
    model.initialize()
    optimizer = n3ml.optimizer.Bohte()

    for epoch in range(opt.num_epochs):
        acc = train(data, model, data_encoder, label_encoder, optimizer,
                    rmse, epoch, opt)
        print("epoch: {} - tr. acc: {}".format(epoch, acc))
        # validate() takes the loss function and returns (avg_loss, avg_acc).
        # The old call omitted `loss` (TypeError) and printed the tuple as acc.
        loss, acc = validate(data, model, data_encoder, label_encoder,
                             rmse, opt)
        print("epoch: {} - val. loss: {} - val. acc: {}".format(
            epoch, loss, acc))
def train(data, model, data_encoder, label_encoder, optimizer, loss, epoch, opt):
    """Run one training pass over ``data['train.data']`` and return its accuracy.

    For each sample:
    - re-initialise the model with ``delay=False``;
    - encode the input and the label;
    - drive the model for ``opt.num_steps`` time steps;
    - read ``model.fc2.s`` as the prediction;
    - apply ``optimizer.step``.

    Args:
        data: Mapping with ``'train.data'`` (indexable, has ``.size(0)``) and
            ``'train.target'``.
        model: Network exposing ``initialize``, ``fc1.s`` and ``fc2.s``.
        data_encoder: Encoder whose ``run(sample)`` returns the input encoding.
        label_encoder: Encoder whose ``run(label)`` returns the target encoding.
        optimizer: Object whose ``step(model, input, target, epoch)`` updates
            the model.
        loss: Loss function. Currently unused in this pass; kept so callers
            keep the same signature.
        epoch: Current epoch index, forwarded to ``optimizer.step``.
        opt: Options object; ``opt.num_steps`` sets the simulation length.

    Returns:
        A float32 tensor: number of correct predictions / number of samples.

    Raises:
        ValueError: if ``data['train.data']`` contains no samples.
    """
    train_data = data['train.data']
    train_target = data['train.target']
    total_data = train_data.size(0)
    if total_data == 0:
        # Without this guard, `corrects` stays int 0 and the accuracy
        # computation would crash with AttributeError / ZeroDivisionError.
        raise ValueError("train(): data['train.data'] contains no samples")

    corrects = 0
    for i in range(total_data):
        model.initialize(delay=False)
        sample = train_data[i]
        label = train_target[i]

        spiked_input = data_encoder.run(sample)
        # Flatten the encoding and append two zero-valued entries.
        # NOTE(review): presumably extra inputs expected by Bohte2002 — confirm.
        spiked_input = torch.cat((spiked_input.view(-1), torch.zeros(2)))
        spiked_label = label_encoder.run(label)

        for t in range(opt.num_steps):
            model(torch.tensor(t).float(), spiked_input)
        o = model.fc2.s

        # Per-sample debug output.
        print(model.fc1.s)
        print(model.fc2.s)
        print("pred: {} - target: {}".format(o, spiked_label))

        optimizer.step(model, spiked_input, spiked_label, epoch)
        corrects += do_correct(o, spiked_label)

    # as_tensor accepts both a tensor and a plain int/bool count. For a
    # tensor it yields the same float32 value that `.float()` did.
    return torch.as_tensor(corrects, dtype=torch.float32) / total_data
def validate(data, model, data_encoder, label_encoder, loss, opt):
    """Evaluate the model on ``data['test.data']``; return ``(avg_loss, avg_acc)``.

    For each sample:
    - re-initialise the model with ``delay=False``;
    - encode the input and the label;
    - drive the model for ``opt.num_steps`` time steps;
    - score ``model.fc2.s`` against the encoded label with ``do_correct``
      and ``loss``.

    No optimizer step is taken.

    Args:
        data: Mapping with ``'test.data'`` (indexable, has ``.size(0)``) and
            ``'test.target'``.
        model: Network exposing ``initialize`` and ``fc2.s``.
        data_encoder: Encoder whose ``run(sample)`` returns the input encoding.
        label_encoder: Encoder whose ``run(label)`` returns the target encoding.
        loss: Callable ``loss(prediction, target)`` summed over the test set.
        opt: Options object; ``opt.num_steps`` sets the simulation length.

    Returns:
        ``(avg_loss, avg_acc)``:
        - ``avg_loss`` is the summed loss / number of samples.
        - ``avg_acc`` is a float32 tensor: correct count / number of samples.

    Raises:
        ValueError: if ``data['test.data']`` contains no samples.
    """
    test_data = data['test.data']
    test_target = data['test.target']
    total_data = test_data.size(0)
    if total_data == 0:
        # Without this guard, `corrects` stays int 0 and the accuracy
        # computation would crash with AttributeError / ZeroDivisionError.
        raise ValueError("validate(): data['test.data'] contains no samples")

    corrects = 0
    total_loss = 0
    for i in range(total_data):
        model.initialize(delay=False)
        sample = test_data[i]
        label = test_target[i]

        spiked_input = data_encoder.run(sample)
        # Flatten the encoding and append two zero-valued entries, as in training.
        spiked_input = torch.cat((spiked_input.view(-1), torch.zeros(2)))
        spiked_label = label_encoder.run(label)

        for t in range(opt.num_steps):
            model(torch.tensor(t).float(), spiked_input)
        o = model.fc2.s

        corrects += do_correct(o, spiked_label)
        total_loss += loss(o, spiked_label)

    # as_tensor accepts both a tensor and a plain int/bool count. For a
    # tensor it yields the same float32 value that `.float()` did.
    avg_acc = torch.as_tensor(corrects, dtype=torch.float32) / total_data
    avg_loss = total_loss / total_data
    return avg_loss, avg_acc
def app(opt):
    """Train Bohte2002 on IRIS with live plotting, validating after every epoch.

    Uses an 80/20 ``IRISDataLoader`` split and a 12-neuron population encoder
    over the per-feature min/max from ``summarize()``. The Bohte optimizer
    drives training. Running accuracy/loss statistics are kept in ``meter`` and
    the two history lists, which ``train`` updates and feeds to ``Plot``.

    Args:
        opt: Options object. Reads ``max_firing_time``, ``not_to_fire``,
            ``dt``, ``num_classes`` and ``num_epochs`` here; it is also
            forwarded to ``train`` and ``validate``.
    """
    # Print tensors in full rather than summarised with "...".
    np.set_printoptions(threshold=np.inf)
    print(opt)

    loader = n3ml.data.IRISDataLoader(ratio=0.8)
    data = loader.run()
    stats = loader.summarize()

    input_encoder = n3ml.encoder.Population(
        neurons=12,
        minimum=stats['min'],
        maximum=stats['max'],
        max_firing_time=opt.max_firing_time,
        not_to_fire=opt.not_to_fire,
        dt=opt.dt,
    )
    target_encoder = LabelEncoder(opt.num_classes)

    model = n3ml.model.Bohte2002()
    model.initialize()
    optimizer = n3ml.optimizer.Bohte()

    # Plotting state. `meter` is created once and never reset, so the
    # statistics train() derives from it are running values since the start
    # of training, not per-epoch values.
    plotter = Plot()
    meter = {'total_losses': 0.0, 'num_corrects': 0, 'num_images': 0}
    acc_history = []
    loss_history = []

    for epoch in range(opt.num_epochs):
        train(data, model, input_encoder, target_encoder, optimizer, rmse,
              epoch, meter, acc_history, loss_history, plotter, opt)

        val_loss, val_acc = validate(data, model, input_encoder,
                                     target_encoder, rmse, opt)
        print("epoch: {} - val. loss: {} - val. acc: {}".format(
            epoch, val_loss, val_acc))

        # The dataset is reloaded after every epoch.
        # NOTE(review): if run() re-shuffles and re-splits with ratio=0.8,
        # samples validated in one epoch may be trained on in the next —
        # confirm this is intended.
        data = loader.run()
def train(data, model, data_encoder, label_encoder, optimizer, loss, epoch,
          meter, acc_buffer, loss_buffer, plotter, opt):
    """Run one training pass over ``data['train.data']``, updating running stats.

    For each sample:
    - re-initialise the model with ``delay=False``;
    - encode the input and the label;
    - drive the model for ``opt.num_steps`` time steps;
    - read ``model.fc2.s`` as the prediction;
    - apply ``optimizer.step``;
    - add the sample's correctness and loss into ``meter``.

    Every 30th sample it also:
    - prints a log line;
    - appends the running accuracy and loss (totals in ``meter`` divided by
      ``meter['num_images']``) to ``acc_buffer`` / ``loss_buffer``;
    - refreshes ``plotter``.

    Args:
        data: Mapping with ``'train.data'`` (indexable, has ``.size(0)``) and
            ``'train.target'``.
        model: Network exposing ``initialize`` and ``fc2.s``.
        data_encoder: Encoder whose ``run(sample)`` returns the input encoding.
        label_encoder: Encoder whose ``run(label)`` returns the target encoding.
        optimizer: Object whose ``step(model, input, target, epoch)`` updates
            the model.
        loss: Callable ``loss(prediction, target)`` accumulated into ``meter``.
        epoch: Current epoch index, forwarded to ``optimizer.step``.
        meter: Dict with ``'total_losses'``, ``'num_corrects'`` and
            ``'num_images'``; mutated in place.
        acc_buffer: List of running accuracies; appended to in place.
        loss_buffer: List of running losses; appended to in place.
        plotter: Object whose ``update(y1=..., y2=...)`` receives the histories
            as numpy arrays.
        opt: Options object; ``opt.num_steps`` sets the simulation length.

    Returns:
        None. All results are reported through ``meter``, the buffers,
        ``plotter`` and stdout.
    """
    for i in range(data['train.data'].size(0)):
        model.initialize(delay=False)
        sample = data['train.data'][i]
        label = data['train.target'][i]

        spiked_input = data_encoder.run(sample)
        # Flatten the encoding and append two zero-valued entries.
        # NOTE(review): presumably extra inputs expected by Bohte2002 — confirm.
        spiked_input = torch.cat((spiked_input.view(-1), torch.zeros(2)))
        spiked_label = label_encoder.run(label)

        for t in range(opt.num_steps):
            model(torch.tensor(t).float(), spiked_input)
        o = model.fc2.s

        optimizer.step(model, spiked_input, spiked_label, epoch)

        # Evaluate correctness once and reuse it for both the meter and the
        # log line. It was previously recomputed on every logging step.
        correct = do_correct(o, spiked_label)
        meter['num_images'] += 1
        meter['num_corrects'] += correct
        meter['total_losses'] += loss(o, spiked_label)

        if (i + 1) % 30 == 0:
            print("label: {} - target: {} - pred: {} - result: {}".format(
                label, spiked_label, o, correct))
            acc_buffer.append(1.0 * meter['num_corrects'] / meter['num_images'])
            loss_buffer.append(meter['total_losses'] / meter['num_images'])
            plotter.update(y1=np.array(acc_buffer), y2=np.array(loss_buffer))