def train(network, dataset, sample_length, dt, input_shape, polarity, indices, test_indices, lr, n_classes, r, beta, gamma, kappa, start_idx, test_accs, save_path):
    """Train an SNN.

    Args:
        network: the SNN to train; must expose train(), save(),
            reset_internal_state() and a .device attribute.
        dataset: open tables file with .root.train / .root.test nodes and
            .root.stats.train_data statistics.
        sample_length: duration of one example in seconds.
        dt: simulation time step in milliseconds.
        input_shape: shape of the network input, forwarded to the readers.
        polarity: polarity flag forwarded to the example readers.
        indices: training example indices to iterate over.
        test_indices: example indices used when measuring test accuracy.
        lr: initial learning rate; halved every
            5 * dataset.root.stats.train_data[0] examples.
        n_classes: number of output classes.
        r, beta, gamma, kappa: learning-rule hyperparameters forwarded to
            train_on_example.
        start_idx: position in `indices` at which to (re)start training.
        test_accs: dict mapping iteration number -> list of accuracies;
            its keys determine at which iterations testing happens.
        save_path: directory for checkpoints/results, or None to disable
            all saving.

    Returns:
        The updated test_accs dict.
    """
    eligibility_trace_output, eligibility_trace_hidden, \
        learning_signal, baseline_num, baseline_den = init_training(network)

    train_data = dataset.root.train
    test_data = dataset.root.test

    # Number of simulation time steps per example.
    T = int(sample_length * 1000 / dt)

    for j, idx in enumerate(indices[start_idx:]):
        j += start_idx

        # Learning-rate decay schedule.
        if (j + 1) % (5 * (dataset.root.stats.train_data[0])) == 0:
            lr /= 2

        # Regularly test the accuracy
        if test_accs:
            if (j + 1) in test_accs:
                acc, loss = get_acc_and_loss(network, test_data, test_indices, T, n_classes,
                                             input_shape, dt, dataset.root.stats.train_data[1], polarity)
                test_accs[int(j + 1)].append(acc)
                print('test accuracy at ite %d: %f' % (int(j + 1), acc))

                if save_path is not None:
                    with open(save_path + '/test_accs.pkl', 'wb') as f:
                        pickle.dump(test_accs, f, pickle.HIGHEST_PROTOCOL)
                    network.save(save_path + '/network_weights.hdf5')

                # Put the network back in training mode after evaluation.
                network.train()

        network.reset_internal_state()
        refractory_period(network)

        inputs, label = get_example(train_data, idx, T, n_classes, input_shape, dt,
                                    dataset.root.stats.train_data[1], polarity)
        example = torch.cat((inputs, label), dim=0).to(network.device)

        log_proba, eligibility_trace_hidden, eligibility_trace_output, learning_signal, baseline_num, baseline_den = \
            train_on_example(network, T, example, gamma, r, eligibility_trace_hidden, eligibility_trace_output,
                             learning_signal, baseline_num, baseline_den, lr, beta, kappa)

        if j % max(1, int(len(indices) / 5)) == 0:
            print('Step %d out of %d' % (j, len(indices)))

    # At the end of training, save final weights if none exist or if this ite
    # was better than all the others.
    # Fixed: this block previously ran even when save_path was None (crashing
    # on the string concatenation, inconsistent with the in-loop guard), and
    # max() raised ValueError when only a single accuracy had been recorded
    # (IndexError when test_accs was empty).
    if save_path is not None:
        final_weights = save_path + '/network_weights_final.hdf5'
        if not os.path.exists(final_weights):
            network.save(final_weights)
        elif test_accs:
            last_accs = test_accs[list(test_accs.keys())[-1]]
            if last_accs and (len(last_accs) == 1 or last_accs[-1] >= max(last_accs[:-1])):
                network.save(final_weights)

    return test_accs
for ite in range(params['n_examples_train']): if (ite+1) % params['test_period'] == 0: print('Ite %d: ' % (ite+1)) acc_layered = get_acc_layered(network, test_dl, len(iter(test_dl)), T) print('Acc: %f' % acc_layered) test_accs[int(ite + 1)].append(acc_layered) with open(args.save_path + '/test_accs.pkl', 'wb') as f: pickle.dump(test_accs, f, pickle.HIGHEST_PROTOCOL) torch.save(network.state_dict(), args.save_path + '/encoding_network_trial_%d.pt' % trial) network.train(args.save_path) refractory_period(network) try: inputs, targets = next(train_iterator) except StopIteration: train_iterator = iter(train_dl) inputs, targets = next(train_iterator) inputs = inputs[0].to(network.device) targets = targets[0].to(network.device) for t in range(T): ### LayeredSNN net_probas, net_outputs, probas_hidden, outputs_hidden = network(inputs[:t].T, targets[:, t], n_samples=params['n_samples']) # Generate gradients and KL regularization for hidden neurons
def train_fixed_rate(rank, num_nodes, args):
    """Distributed SNN training sweep over communication periods `tau`.

    For each `tau` in args.tau_list, runs args.num_ite training runs in
    which workers (ranks != 0) train locally with eligibility traces, and
    every `tau * args.deltas` steps all nodes exchange a subset of
    `int(tau * args.rate)` weights via global_update_subset. Rank 0 only
    takes part in the global updates and evaluates the final network.

    Args:
        rank: this process's rank in the distributed group.
        num_nodes: total number of nodes, forwarded to init_training.
        args: argument namespace; mutated in place (args.dataset is
            replaced by the open tables file, args.S_prime is set here).
    """
    # Create network groups for communication
    # NOTE(review): ranks are hardcoded to [0, 1, 2] rather than derived
    # from num_nodes -- confirm this matches the launcher configuration.
    all_nodes = dist.new_group([0, 1, 2], timeout=datetime.timedelta(0, 360000))

    # Setup training parameters
    args.dataset = tables.open_file(args.dataset)  # path string replaced by the open file handle
    train_data = args.dataset.root.train
    test_data = args.dataset.root.test

    # Time steps per example, and total number of steps over the run.
    args.S_prime = int(args.sample_length * 1000 / args.dt)
    S = args.num_samples_train * args.S_prime

    args, test_indices, save_dict, save_path = init_test(rank, args)

    for tau in args.tau_list:
        # Size of the weight subset exchanged at each global update.
        n_weights_to_send = int(tau * args.rate)

        for _ in range(args.num_ite):
            # Initialize main parameters for training
            network, indices_local, weights_list, eligibility_trace, et_temp, learning_signal, ls_temp = init_training(
                rank, num_nodes, all_nodes, args)

            # Gradients accumulator (presumably used by global_update_subset
            # to pick which weights to send -- see that helper).
            gradients_accum = torch.zeros(network.feedforward_weights.shape, dtype=torch.float)

            dist.barrier(all_nodes)

            for s in range(S):
                # Workers train locally; rank 0 only joins the global updates.
                if rank != 0:
                    if s % args.S_prime == 0:
                        # Reset internal state for each example
                        refractory_period(network)
                        inputs, label = get_example(
                            train_data, s // args.S_prime, args.S_prime, args.n_classes,
                            args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
                        inputs = inputs.to(network.device)
                        label = label.to(network.device)

                    # lr decay
                    # if (s + 1) % int(S / 4) == 0:
                    #     args.lr /= 2

                    # Feedforward sampling
                    log_proba, ls_temp, et_temp, gradients_accum = feedforward_sampling(
                        network, inputs[:, s % args.S_prime], label[:, s % args.S_prime],
                        ls_temp, et_temp, args, gradients_accum)

                    # Local feedback and update
                    eligibility_trace, et_temp, learning_signal, ls_temp = local_feedback_and_update(
                        network, eligibility_trace, et_temp, learning_signal, ls_temp, s, args)

                # Global update: every tau * deltas steps, all nodes exchange a
                # weight subset, then the gradient accumulator is reset.
                if (s + 1) % (tau * args.deltas) == 0:
                    dist.barrier(all_nodes)
                    global_update_subset(all_nodes, rank, network, weights_list,
                                         gradients_accum, n_weights_to_send)
                    gradients_accum = torch.zeros(
                        network.feedforward_weights.shape, dtype=torch.float)
                    dist.barrier(all_nodes)

            if rank == 0:
                # Evaluate the aggregated network after this run and record it.
                global_acc, _ = get_acc_and_loss(
                    network, test_data, test_indices, args.S_prime, args.n_classes,
                    args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
                save_dict[tau].append(global_acc)
                save_results(save_dict, save_path)
                print('Tau: %d, final accuracy: %f' % (tau, global_acc))

    if rank == 0:
        save_results(save_dict, save_path)
        print('Training finished and accuracies saved to ' + save_path)
def train(rank, num_nodes, args):
    """Distributed SNN training with periodic global weight updates.

    Workers (ranks != 0) train locally on their assigned example indices;
    every `args.tau * args.deltas` steps all nodes run global_update.
    Rank 0 holds the aggregated network and periodically evaluates it on
    the test set, saving accuracy/loss histories to disk.

    Args:
        rank: this process's rank in the distributed group.
        num_nodes: total number of nodes, forwarded to init_training.
        args: argument namespace; mutated in place (args.dataset replaced
            by the open tables file, args.n_classes and args.S_prime set).
    """
    # Create network groups for communication
    # NOTE(review): ranks hardcoded to [0, 1, 2]; confirm against num_nodes.
    all_nodes = dist.new_group([0, 1, 2], timeout=datetime.timedelta(0, 360000))

    # Setup training parameters
    args.dataset = tables.open_file(args.dataset)  # path string replaced by the open file handle
    args.n_classes = args.dataset.root.stats.test_label[1]
    train_data = args.dataset.root.train
    test_data = args.dataset.root.test

    # Time steps per example, and total number of steps over the run.
    args.S_prime = int(args.sample_length * 1000 / args.dt)
    S = args.num_samples_train * args.S_prime

    args, test_indices, save_dict_loss, save_dict_acc = init_test(args)

    for i in range(args.num_ite):
        # Initialize main parameters for training
        network, indices_local, weights_list, eligibility_trace, et_temp, learning_signal, ls_temp = init_training(
            rank, num_nodes, all_nodes, args)

        dist.barrier(all_nodes)

        # Test loss at beginning + selection of training indices
        # if rank != 0:
        #     print(rank, args.local_labels)
        #     acc, loss = get_acc_and_loss(network, test_data, find_indices_for_labels(test_data, args.labels), args.S_prime, args.n_classes, [1],
        #                                  args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
        #     save_dict_acc[0].append(acc)
        #     save_dict_loss[0].append(loss)
        #     network.train()
        # else:
        #     test_acc, test_loss, spikes = get_acc_loss_and_spikes(network, test_data, test_indices, args.S_prime, args.n_classes, [1],
        #                                                           args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
        #     save_dict_acc[0].append(test_acc)
        #     save_dict_loss[0].append(test_loss)
        #     np.save(args.save_path + r'/spikes_test_s_%d.npy' % 0, spikes.numpy())
        #     network.train()
        dist.barrier(all_nodes)

        for s in range(S):
            if rank == 0:
                # Periodic evaluation of the aggregated network on rank 0.
                if (s + 1) % args.test_interval == 0:
                    test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                        network, test_data, find_indices_for_labels(test_data, args.labels),
                        args.S_prime, args.n_classes, [1], args.input_shape, args.dt,
                        args.dataset.root.stats.train_data[1], args.polarity)
                    save_dict_acc[s + 1].append(test_acc)
                    save_dict_loss[s + 1].append(test_loss)
                    # Back to training mode after evaluation.
                    network.train()
                    save_results(save_dict_acc, args.save_path + r'/test_acc.pkl')
                    save_results(save_dict_loss, args.save_path + r'/test_loss.pkl')
                    print('Acc at step %d : %f' % (s, test_acc))

            # NOTE(review): per-step barrier executed by all ranks -- indentation
            # reconstructed; placing it inside the test_interval branch would
            # deadlock the workers, so loop level is the consistent reading.
            dist.barrier(all_nodes)

            if rank != 0:
                # if (s + 1) % args.test_interval == 0:
                #     acc, loss = get_acc_and_loss(network, test_data, find_indices_for_labels(test_data, args.labels), args.S_prime, args.n_classes, [1],
                #                                  args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
                #     save_dict_acc[s + 1].append(acc)
                #     save_dict_loss[s + 1].append(loss)
                #
                #     save_results(save_dict_acc, args.save_path + r'/test_acc.pkl')
                #     save_results(save_dict_loss, args.save_path + r'/test_loss.pkl')
                #
                #     network.train()
                if s % args.S_prime == 0:  # at each example
                    refractory_period(network)
                    inputs, label = get_example(
                        train_data, indices_local[s // args.S_prime], args.S_prime,
                        args.n_classes, [1], args.input_shape, args.dt,
                        args.dataset.root.stats.train_data[1], args.polarity)
                    inputs = inputs.to(network.device)
                    label = label.to(network.device)

                # lr decay
                # if s % S / 4 == 0:
                #     args.lr /= 2

                # Feedforward sampling
                log_proba, ls_temp, et_temp, _ = feedforward_sampling(
                    network, inputs[:, s % args.S_prime], label[:, s % args.S_prime],
                    ls_temp, et_temp, args)

                # Local feedback and update
                eligibility_trace, et_temp, learning_signal, ls_temp = local_feedback_and_update(
                    network, eligibility_trace, et_temp, learning_signal, ls_temp, s, args)

            # Global update
            if (s + 1) % (args.tau * args.deltas) == 0:
                dist.barrier(all_nodes)
                global_update(all_nodes, rank, network, weights_list)
                dist.barrier(all_nodes)

        # Final global update
        dist.barrier(all_nodes)
        global_update(all_nodes, rank, network, weights_list)
        dist.barrier(all_nodes)

        if rank == 0:
            # Record final test metrics for this run.
            test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                network, test_data, find_indices_for_labels(test_data, args.labels),
                args.S_prime, args.n_classes, [1], args.input_shape, args.dt,
                args.dataset.root.stats.train_data[1], args.polarity)
            save_dict_acc[S].append(test_acc)
            save_dict_loss[S].append(test_loss)
            network.train()
            save_results(save_dict_acc, args.save_path + r'/acc.pkl')
            save_results(save_dict_loss, args.save_path + r'/loss.pkl')