Example #1
0
def train(network, dataset, sample_length, dt, input_shape, polarity, indices,
          test_indices, lr, n_classes, r, beta, gamma, kappa, start_idx,
          test_accs, save_path):
    """
    Train an SNN on the examples selected by `indices`, periodically
    measuring test accuracy and checkpointing weights.

    Args:
        network: SNN exposing .train(), .save(), .reset_internal_state(),
            .device (consumed by the project helpers called below).
        dataset: open tables file with .root.train / .root.test /
            .root.stats.train_data.
        sample_length: example duration in seconds.
        dt: simulation timestep in milliseconds.
        input_shape: shape of the network input, forwarded to get_example.
        polarity: event-polarity flag, forwarded to the data helpers.
        indices: training example indices; iteration resumes at start_idx.
        test_indices: example indices used for periodic evaluation.
        lr: initial learning rate (halved every 5 passes over the data).
        n_classes: number of output classes.
        r, beta, gamma, kappa: hyperparameters forwarded to train_on_example.
        start_idx: iteration index to resume from.
        test_accs: dict mapping {iteration: list of accuracies}; its keys
            define the iterations at which the network is evaluated.
        save_path: checkpoint directory, or None to disable all saving.

    Returns:
        test_accs, updated in place with the measured accuracies.
    """

    eligibility_trace_output, eligibility_trace_hidden, \
        learning_signal, baseline_num, baseline_den = init_training(network)

    train_data = dataset.root.train
    test_data = dataset.root.test
    # Timesteps per example: sample_length [s] * 1000 / dt [ms]
    T = int(sample_length * 1000 / dt)

    for j, idx in enumerate(indices[start_idx:]):
        j += start_idx
        # Halve the learning rate every 5 passes over the training set
        # (stats.train_data[0] is the number of training examples).
        if (j + 1) % (5 * (dataset.root.stats.train_data[0])) == 0:
            lr /= 2

        # Regularly test the accuracy at the iterations listed in test_accs
        if test_accs:
            if (j + 1) in test_accs:
                acc, loss = get_acc_and_loss(network, test_data, test_indices,
                                             T, n_classes, input_shape, dt,
                                             dataset.root.stats.train_data[1],
                                             polarity)
                test_accs[int(j + 1)].append(acc)
                print('test accuracy at ite %d: %f' % (int(j + 1), acc))

                if save_path is not None:
                    with open(save_path + '/test_accs.pkl', 'wb') as f:
                        pickle.dump(test_accs, f, pickle.HIGHEST_PROTOCOL)
                    network.save(save_path + '/network_weights.hdf5')

                # Evaluation put the network in eval mode; restore training
                # mode and clear membrane/internal state before continuing.
                network.train()
                network.reset_internal_state()

        refractory_period(network)

        inputs, label = get_example(train_data, idx, T, n_classes, input_shape,
                                    dt, dataset.root.stats.train_data[1],
                                    polarity)
        example = torch.cat((inputs, label), dim=0).to(network.device)

        log_proba, eligibility_trace_hidden, eligibility_trace_output, \
            learning_signal, baseline_num, baseline_den = \
            train_on_example(network, T, example, gamma, r,
                             eligibility_trace_hidden,
                             eligibility_trace_output, learning_signal,
                             baseline_num, baseline_den, lr, beta, kappa)

        if j % max(1, int(len(indices) / 5)) == 0:
            print('Step %d out of %d' % (j, len(indices)))

    # At the end of training, save final weights if none exist or if this ite
    # was better than all the others. Guarded on save_path: the original
    # dereferenced it unconditionally here despite allowing None above.
    if save_path is not None:
        final_weights_path = save_path + '/network_weights_final.hdf5'
        if not os.path.exists(final_weights_path):
            network.save(final_weights_path)
        elif test_accs:
            last_accs = test_accs[list(test_accs.keys())[-1]]
            # With a single recorded accuracy there is no history to beat
            # (max() over the empty slice would raise), so save it.
            if last_accs and (len(last_accs) == 1
                              or last_accs[-1] >= max(last_accs[:-1])):
                network.save(final_weights_path)

    return test_accs
    for ite in range(params['n_examples_train']):
        if (ite+1) % params['test_period'] == 0:
            print('Ite %d: ' % (ite+1))
            acc_layered = get_acc_layered(network, test_dl, len(iter(test_dl)), T)
            print('Acc: %f' % acc_layered)

            test_accs[int(ite + 1)].append(acc_layered)

            with open(args.save_path + '/test_accs.pkl', 'wb') as f:
                pickle.dump(test_accs, f, pickle.HIGHEST_PROTOCOL)

            torch.save(network.state_dict(), args.save_path + '/encoding_network_trial_%d.pt' % trial)

        network.train(args.save_path)

        refractory_period(network)

        try:
            inputs, targets = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_dl)
            inputs, targets = next(train_iterator)

        inputs = inputs[0].to(network.device)
        targets = targets[0].to(network.device)

        for t in range(T):
            ### LayeredSNN
            net_probas, net_outputs, probas_hidden, outputs_hidden = network(inputs[:t].T, targets[:, t], n_samples=params['n_samples'])

            # Generate gradients and KL regularization for hidden neurons
Example #3
0
def train_fixed_rate(rank, num_nodes, args):
    """
    Run one process of a distributed SNN training job where, every
    tau * args.deltas timesteps, a subset of int(tau * args.rate) weights is
    exchanged between nodes — so the average communication rate is fixed
    across the values in args.tau_list.

    Rank 0 performs no local training (only the global updates and the final
    evaluation); the other ranks train locally on their data between global
    updates.

    Args:
        rank: this process's rank in the distributed group.
        num_nodes: total number of participating nodes.
        args: namespace carrying the dataset path, timing parameters
            (sample_length, dt), and the training/communication
            hyperparameters (tau_list, num_ite, deltas, rate, n_classes,
            input_shape, polarity, num_samples_train, ...).
    """
    # Create network groups for communication
    # NOTE(review): the group is hard-coded to ranks [0, 1, 2] instead of
    # being built from num_nodes — confirm this matches the deployment.
    all_nodes = dist.new_group([0, 1, 2],
                               timeout=datetime.timedelta(0, 360000))

    # Setup training parameters
    args.dataset = tables.open_file(args.dataset)

    train_data = args.dataset.root.train
    test_data = args.dataset.root.test

    # S_prime: timesteps per example (sample_length [s] * 1000 / dt [ms]);
    # S: total number of training timesteps over all examples.
    args.S_prime = int(args.sample_length * 1000 / args.dt)
    S = args.num_samples_train * args.S_prime

    args, test_indices, save_dict, save_path = init_test(rank, args)

    for tau in args.tau_list:
        # Weights exchanged per global update scale with tau so that the
        # average number of weights sent per timestep stays constant.
        n_weights_to_send = int(tau * args.rate)

        for _ in range(args.num_ite):
            # Initialize main parameters for training
            network, indices_local, weights_list, eligibility_trace, et_temp, learning_signal, ls_temp = init_training(
                rank, num_nodes, all_nodes, args)

            # Gradients accumulator, used to choose which weights to send at
            # each global update (reset after every exchange).
            gradients_accum = torch.zeros(network.feedforward_weights.shape,
                                          dtype=torch.float)
            dist.barrier(all_nodes)

            for s in range(S):
                # Only worker ranks train locally; rank 0 joins global updates.
                if rank != 0:
                    if s % args.S_prime == 0:  # Reset internal state for each example
                        refractory_period(network)
                        inputs, label = get_example(
                            train_data, s // args.S_prime, args.S_prime,
                            args.n_classes, args.input_shape, args.dt,
                            args.dataset.root.stats.train_data[1],
                            args.polarity)
                        inputs = inputs.to(network.device)
                        label = label.to(network.device)

                    # lr decay
                    # if (s + 1) % int(S / 4) == 0:
                    #     args.lr /= 2

                    # Feedforward sampling: one timestep of the current
                    # example, also accumulating gradient magnitudes.
                    log_proba, ls_temp, et_temp, gradients_accum = feedforward_sampling(
                        network, inputs[:, s % args.S_prime],
                        label[:, s % args.S_prime], ls_temp, et_temp, args,
                        gradients_accum)

                    # Local feedback and update
                    eligibility_trace, et_temp, learning_signal, ls_temp = local_feedback_and_update(
                        network, eligibility_trace, et_temp, learning_signal,
                        ls_temp, s, args)

                # Global update: every tau * deltas steps, exchange a subset
                # of weights chosen from gradients_accum, then reset the
                # accumulator. Barriers keep all ranks in lockstep.
                if (s + 1) % (tau * args.deltas) == 0:
                    dist.barrier(all_nodes)
                    global_update_subset(all_nodes, rank, network,
                                         weights_list, gradients_accum,
                                         n_weights_to_send)
                    gradients_accum = torch.zeros(
                        network.feedforward_weights.shape, dtype=torch.float)
                    dist.barrier(all_nodes)

            if rank == 0:
                # Rank 0 evaluates the aggregated network after each run and
                # persists results per tau.
                global_acc, _ = get_acc_and_loss(
                    network, test_data, test_indices, args.S_prime,
                    args.n_classes, args.input_shape, args.dt,
                    args.dataset.root.stats.train_data[1], args.polarity)
                save_dict[tau].append(global_acc)
                save_results(save_dict, save_path)
                print('Tau: %d, final accuracy: %f' % (tau, global_acc))

    if rank == 0:
        save_results(save_dict, save_path)
        print('Training finished and accuracies saved to ' + save_path)
Example #4
0
def train(rank, num_nodes, args):
    """
    Run one process of a distributed SNN training job with full weight
    synchronization every args.tau * args.deltas timesteps.

    Rank 0 evaluates test accuracy/loss every args.test_interval steps and
    after the final global update of each run; the other ranks train locally
    on their assigned example indices between global updates.

    Args:
        rank: this process's rank in the distributed group.
        num_nodes: total number of participating nodes.
        args: namespace carrying the dataset path, timing parameters
            (sample_length, dt), and training/communication hyperparameters
            (tau, deltas, test_interval, num_ite, labels, input_shape,
            polarity, save_path, ...).
    """
    # Create network groups for communication
    # NOTE(review): the group is hard-coded to ranks [0, 1, 2] instead of
    # being built from num_nodes — confirm this matches the deployment.
    all_nodes = dist.new_group([0, 1, 2],
                               timeout=datetime.timedelta(0, 360000))

    # Setup training parameters
    args.dataset = tables.open_file(args.dataset)

    # Number of classes is read off the test-label stats of the dataset.
    args.n_classes = args.dataset.root.stats.test_label[1]
    train_data = args.dataset.root.train
    test_data = args.dataset.root.test

    # S_prime: timesteps per example (sample_length [s] * 1000 / dt [ms]);
    # S: total number of training timesteps over all examples.
    args.S_prime = int(args.sample_length * 1000 / args.dt)
    S = args.num_samples_train * args.S_prime

    args, test_indices, save_dict_loss, save_dict_acc = init_test(args)

    for i in range(args.num_ite):
        # Initialize main parameters for training
        network, indices_local, weights_list, eligibility_trace, et_temp, learning_signal, ls_temp = init_training(
            rank, num_nodes, all_nodes, args)

        dist.barrier(all_nodes)

        # Test loss at beginning + selection of training indices
        # if rank != 0:
        #     print(rank, args.local_labels)
        #     acc, loss = get_acc_and_loss(network, test_data, find_indices_for_labels(test_data, args.labels), args.S_prime, args.n_classes, [1],
        #                                              args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
        #     save_dict_acc[0].append(acc)
        #     save_dict_loss[0].append(loss)
        #     network.train()
        # else:
        # test_acc, test_loss, spikes = get_acc_loss_and_spikes(network, test_data, test_indices, args.S_prime, args.n_classes, [1],
        #                                          args.input_shape,  args.dt, args.dataset.root.stats.train_data[1], args.polarity)
        # save_dict_acc[0].append(test_acc)
        # save_dict_loss[0].append(test_loss)
        # np.save(args.save_path + r'/spikes_test_s_%d.npy' % 0, spikes.numpy())
        # network.train()

        dist.barrier(all_nodes)

        for s in range(S):
            # Rank 0 periodically evaluates the (globally updated) network.
            if rank == 0:
                if (s + 1) % args.test_interval == 0:
                    test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                        network, test_data,
                        find_indices_for_labels(test_data,
                                                args.labels), args.S_prime,
                        args.n_classes, [1], args.input_shape, args.dt,
                        args.dataset.root.stats.train_data[1], args.polarity)
                    save_dict_acc[s + 1].append(test_acc)
                    save_dict_loss[s + 1].append(test_loss)

                    # Evaluation switched to eval mode; restore training
                    # mode and persist intermediate results.
                    network.train()
                    save_results(save_dict_acc,
                                 args.save_path + r'/test_acc.pkl')
                    save_results(save_dict_loss,
                                 args.save_path + r'/test_loss.pkl')

                    print('Acc at step %d : %f' % (s, test_acc))

            dist.barrier(all_nodes)

            # Worker ranks train locally on their assigned indices.
            if rank != 0:
                # if (s + 1) % args.test_interval == 0:
                #     acc, loss = get_acc_and_loss(network, test_data, find_indices_for_labels(test_data, args.labels), args.S_prime, args.n_classes, [1],
                #                                              args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
                #     save_dict_acc[s + 1].append(acc)
                #     save_dict_loss[s + 1].append(loss)
                #
                #     save_results(save_dict_acc, args.save_path + r'/test_acc.pkl')
                #     save_results(save_dict_loss, args.save_path + r'/test_loss.pkl')
                #
                #     network.train()

                if s % args.S_prime == 0:  # at each example
                    refractory_period(network)
                    inputs, label = get_example(
                        train_data, indices_local[s // args.S_prime],
                        args.S_prime, args.n_classes, [1], args.input_shape,
                        args.dt, args.dataset.root.stats.train_data[1],
                        args.polarity)
                    inputs = inputs.to(network.device)
                    label = label.to(network.device)

                # lr decay
                # if s % S / 4 == 0:
                #     args.lr /= 2

                # Feedforward sampling: one timestep of the current example
                log_proba, ls_temp, et_temp, _ = feedforward_sampling(
                    network, inputs[:, s % args.S_prime],
                    label[:, s % args.S_prime], ls_temp, et_temp, args)

                # Local feedback and update
                eligibility_trace, et_temp, learning_signal, ls_temp = local_feedback_and_update(
                    network, eligibility_trace, et_temp, learning_signal,
                    ls_temp, s, args)

            # Global update: full weight synchronization every
            # tau * deltas steps; barriers keep all ranks in lockstep.
            if (s + 1) % (args.tau * args.deltas) == 0:
                dist.barrier(all_nodes)
                global_update(all_nodes, rank, network, weights_list)
                dist.barrier(all_nodes)

        # Final global update
        dist.barrier(all_nodes)
        global_update(all_nodes, rank, network, weights_list)
        dist.barrier(all_nodes)

        if rank == 0:
            # Final evaluation of the fully synchronized network.
            test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                network, test_data,
                find_indices_for_labels(test_data, args.labels), args.S_prime,
                args.n_classes, [1], args.input_shape, args.dt,
                args.dataset.root.stats.train_data[1], args.polarity)
            save_dict_acc[S].append(test_acc)
            save_dict_loss[S].append(test_loss)

        # NOTE(review): this tail runs on every rank (not only rank 0), so
        # each rank rewrites acc.pkl/loss.pkl — confirm this is intended.
        network.train()

        save_results(save_dict_acc, args.save_path + r'/acc.pkl')
        save_results(save_dict_loss, args.save_path + r'/loss.pkl')