Example #1
def train(rank, num_nodes, args):
    # Create network groups for communication
    all_nodes = dist.new_group([0, 1, 2],
                               timeout=datetime.timedelta(0, 360000))

    # Setup training parameters
    args.dataset = tables.open_file(args.dataset)

    args.n_classes = args.dataset.root.stats.test_label[1]
    train_data = args.dataset.root.train
    test_data = args.dataset.root.test

    args.S_prime = int(args.sample_length * 1000 / args.dt)
    S = args.num_samples_train * args.S_prime

    args, test_indices, save_dict_loss, save_dict_acc = init_test(args)

    for i in range(args.num_ite):
        # Initialize main parameters for training
        network, indices_local, weights_list, eligibility_trace, et_temp, learning_signal, ls_temp = init_training(
            rank, num_nodes, all_nodes, args)

        dist.barrier(all_nodes)

        # Test loss at the beginning + selection of training indices.
        # (The initial evaluation is currently disabled; re-enabling it
        # records accuracy and loss under key 0 of save_dict_acc and
        # save_dict_loss, mirroring the periodic evaluation below.)

        dist.barrier(all_nodes)

        for s in range(S):
            if rank == 0:
                if (s + 1) % args.test_interval == 0:
                    test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                        network, test_data,
                        find_indices_for_labels(test_data,
                                                args.labels), args.S_prime,
                        args.n_classes, [1], args.input_shape, args.dt,
                        args.dataset.root.stats.train_data[1], args.polarity)
                    save_dict_acc[s + 1].append(test_acc)
                    save_dict_loss[s + 1].append(test_loss)

                    network.train()
                    save_results(save_dict_acc,
                                 args.save_path + r'/test_acc.pkl')
                    save_results(save_dict_loss,
                                 args.save_path + r'/test_loss.pkl')

                    print('Acc at step %d: %f' % (s, test_acc))

            dist.barrier(all_nodes)

            if rank != 0:
                # (Per-worker test evaluation disabled; when enabled it
                # mirrors the rank-0 evaluation above via get_acc_and_loss.)

                if s % args.S_prime == 0:  # at each example
                    refractory_period(network)
                    inputs, label = get_example(
                        train_data, indices_local[s // args.S_prime],
                        args.S_prime, args.n_classes, [1], args.input_shape,
                        args.dt, args.dataset.root.stats.train_data[1],
                        args.polarity)
                    inputs = inputs.to(network.device)
                    label = label.to(network.device)

                # lr decay (disabled): halve args.lr every S // 4 steps
                # if s % (S // 4) == 0:
                #     args.lr /= 2

                # Feedforward sampling
                log_proba, ls_temp, et_temp, _ = feedforward_sampling(
                    network, inputs[:, s % args.S_prime],
                    label[:, s % args.S_prime], ls_temp, et_temp, args)

                # Local feedback and update
                eligibility_trace, et_temp, learning_signal, ls_temp = local_feedback_and_update(
                    network, eligibility_trace, et_temp, learning_signal,
                    ls_temp, s, args)

            # Global update
            if (s + 1) % (args.tau * args.deltas) == 0:
                dist.barrier(all_nodes)
                global_update(all_nodes, rank, network, weights_list)
                dist.barrier(all_nodes)

        # Final global update
        dist.barrier(all_nodes)
        global_update(all_nodes, rank, network, weights_list)
        dist.barrier(all_nodes)

        if rank == 0:
            test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                network, test_data,
                find_indices_for_labels(test_data, args.labels), args.S_prime,
                args.n_classes, [1], args.input_shape, args.dt,
                args.dataset.root.stats.train_data[1], args.polarity)
            save_dict_acc[S].append(test_acc)
            save_dict_loss[S].append(test_loss)

        network.train()

        save_results(save_dict_acc, args.save_path + r'/acc.pkl')
        save_results(save_dict_loss, args.save_path + r'/loss.pkl')
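
A detail neither example shows is how these training processes get launched. A minimal launcher sketch, assuming a single-machine run over the gloo backend; the address, port, and parse_args() helper are placeholders, not part of the source:

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def init_process(rank, num_nodes, args):
    # Join the default process group before train() builds its subgroup.
    os.environ['MASTER_ADDR'] = '127.0.0.1'  # placeholder address
    os.environ['MASTER_PORT'] = '29500'  # placeholder port
    dist.init_process_group('gloo', rank=rank, world_size=num_nodes)
    train(rank, num_nodes, args)


if __name__ == '__main__':
    num_nodes = 3  # matches the [0, 1, 2] group created inside train()
    args = parse_args()  # hypothetical helper returning the argparse namespace
    mp.spawn(init_process, args=(num_nodes, args), nprocs=num_nodes)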
Example #2
def run(rank, size):
    local_train_length = 300
    local_test_length = 100
    train_indices = torch.zeros([3, local_train_length], dtype=torch.long)
    test_indices = torch.zeros([3, local_test_length], dtype=torch.long)

    local_data_path = '/home/cream/Desktop/arafin_experiments/SOCC/FL-SNN/data/'
    save_path = os.getcwd() + r'/results'

    datasets = {'mnist_dvs_10': r'mnist_dvs_25ms_26pxl_10_digits.hdf5'}
    dataset = local_data_path + datasets['mnist_dvs_10']

    # Open the HDF5 file once instead of four times, and close it afterwards.
    data_file = tables.open_file(dataset)
    input_train = torch.FloatTensor(data_file.root.train.data[:])
    output_train = torch.FloatTensor(data_file.root.train.label[:])
    input_test = torch.FloatTensor(data_file.root.test.data[:])
    output_test = torch.FloatTensor(data_file.root.test.label[:])
    data_file.close()

    ### Network parameters
    n_input_neurons = input_train.shape[1]
    n_output_neurons = output_train.shape[1]
    n_hidden_neurons = 16
    epochs = local_train_length
    epochs_test = local_test_length

    learning_rate = 0.005 / n_hidden_neurons
    kappa = 0.2
    alpha = 1
    deltas = 1
    num_ite = 1
    r = 0.3
    weights_magnitude = 0.05
    task = 'supervised'
    mode = 'train'
    tau_ff = 10
    tau_fb = 10
    tau = 10
    mu = 1.5
    n_basis_feedforward = 8
    feedforward_filter = filters.raised_cosine_pillow_08
    feedback_filter = filters.raised_cosine_pillow_08
    n_basis_feedback = 1
    topology = torch.ones([
        n_hidden_neurons + n_output_neurons,
        n_input_neurons + n_hidden_neurons + n_output_neurons
    ],
                          dtype=torch.float)
    diag = torch.arange(n_output_neurons + n_hidden_neurons)
    topology[diag, diag + n_input_neurons] = 0  # no self-connections
    assert torch.sum(topology[:, :n_input_neurons]) == (
        n_input_neurons * (n_hidden_neurons + n_output_neurons))
    print(topology[:, n_input_neurons:])
    # Create the network
    network = SNNetwork(**utils.training_utils.make_network_parameters(
        n_input_neurons,
        n_output_neurons,
        n_hidden_neurons,
        topology_type='fully_connected'))

    # At the beginning, the master node:
    # - transmits its weights to the workers
    # - distributes the samples among workers
    if rank == 0:
        # Initializing an aggregation list for future weights collection
        weights_list = [
            [
                torch.zeros(network.feedforward_weights.shape,
                            dtype=torch.float) for _ in range(size)
            ],
            [
                torch.zeros(network.feedback_weights.shape, dtype=torch.float)
                for _ in range(size)
            ],
            [
                torch.zeros(network.bias.shape, dtype=torch.float)
                for _ in range(size)
            ], [torch.zeros(1, dtype=torch.float) for _ in range(size)]
        ]
    else:
        weights_list = []

    if rank == 0:
        # Draw the workers' train/test indices; each worker later slices
        # its own row out of the full tensor.
        train_indices_all = torch.tensor(np.random.choice(
            np.arange(input_train.shape[0]), [3, local_train_length],
            replace=False),
                                         dtype=torch.long)
        test_indices_all = torch.tensor(np.random.choice(
            np.arange(input_test.shape[0]), [3, local_test_length],
            replace=False),
                                        dtype=torch.long)
        for dst in range(1, size):
            dist.send(tensor=train_indices_all, dst=dst)
    else:
        dist.recv(tensor=train_indices, src=0)
    dist.barrier()

    if rank == 0:
        for dst in range(1, size):
            dist.send(tensor=test_indices_all, dst=dst)
    else:
        dist.recv(tensor=test_indices, src=0)
    dist.barrier()
    if rank != 0:
        training_data = input_train[train_indices[rank - 1, :]]
        training_label = output_train[train_indices[rank - 1, :]]
        test_data = input_test[test_indices[rank - 1, :]]
        test_label = output_test[test_indices[rank - 1, :]]

        indices = np.random.choice(np.arange(training_data.shape[0]),
                                   [training_data.shape[0]],
                                   replace=True)
        S_prime = training_data.shape[-1]
        S = epochs * S_prime
        print("S is", S)
    dist.barrier()

    group = dist.group.WORLD
    # Master node broadcasts its weights
    parameters = network.get_parameters()
    for parameter in parameters:
        dist.broadcast(parameters[parameter], 0)
    if rank == 0:
        print(
            'Node 0 has shared its model and training data is partitioned among workers'
        )
    # The nodes initialize their eligibility trace and learning signal
    eligibility_trace = {'ff_weights': 0, 'fb_weights': 0, 'bias': 0}
    et_temp = {'ff_weights': 0, 'fb_weights': 0, 'bias': 0}

    learning_signal = 0
    ls_temp = 0
    dist.barrier()

    test_accs = []
    test_losses = []  # (step, loss) pairs recorded during training
    if rank != 0:
        test_indx = np.random.choice(np.arange(test_data.shape[0]),
                                     [test_data.shape[0]],
                                     replace=False)
        np.random.shuffle(test_indx)

        _, loss = get_acc_and_loss(network, test_data[test_indx],
                                   test_label[test_indx])
        test_losses.append((0, loss))

        network.set_mode('train')
        local_training_sequence = torch.cat((training_data, training_label),
                                            dim=1)
    dist.barrier()
    ### First local step
    for i in range(num_ite):
        for s in range(deltas):
            if rank != 0:
                # Feedforward sampling step
                log_proba, learning_signal, eligibility_trace \
                    = feedforward_sampling(network, local_training_sequence[indices[0]], eligibility_trace, learning_signal, s, S_prime, alpha, r)

        if rank != 0:
            # First local update
            for parameter in eligibility_trace:
                eligibility_trace[parameter][
                    network.hidden_neurons -
                    network.n_non_learnable_neurons] *= learning_signal
                network.get_parameters(
                )[parameter] += eligibility_trace[parameter] * learning_rate

        # First global update
        if (s + 1) % (tau * deltas) == 0:
            dist.barrier()
            global_update(group, rank, network, weights_list)
            dist.barrier()

        S = input_train.shape[-1] * local_train_length
        ### Remainder of the steps
        for s in range(deltas, S):
            if rank != 0:
                if s % S_prime == 0:  # Reset internal state for each example
                    network.reset_internal_state()

                # lr decay: halve the rate every S // 5 steps while above the floor
                if s % (S // 5) == 0 and learning_rate > 0.005:
                    learning_rate /= 2

                # Feedforward sampling
                log_proba, ls_temp, et_temp \
                    = feedforward_sampling(network, local_training_sequence[indices[0]], et_temp, ls_temp, s, S_prime, alpha, r)

                # Local feedback and global update
                learning_signal, ls_temp, eligibility_trace, et_temp \
                    = local_feedback_and_update(network, eligibility_trace, learning_signal, et_temp, ls_temp, learning_rate, kappa, s, deltas)

                ## Every few timesteps, record test losses
                if (s + 1) % 40 == 0:
                    _, loss = get_acc_and_loss(network, test_data[test_indx],
                                               test_label[test_indx])
                    test_losses.append((s, loss))

                    network.set_mode('train')

            # Global update
            if (s + 1) % (tau * deltas) == 0:
                dist.barrier()
                global_update(group, rank, network, weights_list)
                dist.barrier()

        if rank == 0:
            global_test_indices = np.random.choice(np.arange(
                input_test.shape[0]), [epochs_test],
                                                   replace=False)
            np.random.shuffle(global_test_indices)
            print(global_test_indices)
            global_acc, _ = get_acc_and_loss(network,
                                             input_test[global_test_indices],
                                             output_test[global_test_indices])
            print('Final global test accuracy: %f' % global_acc)
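
# The send/recv handshake in run() above follows a common point-to-point
# pattern: rank 0 ships one identical index tensor to every worker, and each
# worker posts a blocking recv into a pre-allocated tensor of matching shape
# and dtype, then slices out its own row. A stripped-down sketch of that
# pattern, assuming one master plus (size - 1) workers; the helper name and
# shapes are illustrative, not from the source:
import torch
import torch.distributed as dist


def distribute_indices(rank, size, pool_size, per_worker):
    indices = torch.zeros([size - 1, per_worker], dtype=torch.long)
    if rank == 0:
        # Sample without replacement, one row per worker.
        flat = torch.randperm(pool_size)[:(size - 1) * per_worker]
        indices = flat.reshape(size - 1, per_worker)
        for dst in range(1, size):
            dist.send(tensor=indices, dst=dst)  # blocking point-to-point send
    else:
        dist.recv(tensor=indices, src=0)  # blocks until rank 0's copy arrives
    dist.barrier()
    return indices  # worker `rank` trains on indices[rank - 1]

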
def train(rank, num_nodes, net_params, train_params):
    # Setup training parameters
    dataset = train_params['dataset']
    epochs = train_params['epochs']
    epochs_test = train_params['epochs_test']
    deltas = train_params['deltas']
    num_ite = train_params['num_ite']
    save_path = net_params['save_path']
    tau = train_params['tau']

    learning_rate = train_params['learning_rate']
    alpha = train_params['alpha']
    eta = train_params['eta']
    kappa = train_params['kappa']
    r = train_params['r']

    # Create network groups for communication
    all_nodes = dist.new_group([0, 1, 2, 3],
                               timeout=datetime.timedelta(0, 360000))

    test_accuracies = []  # used to store test accuracies
    test_loss = [[] for _ in range(num_ite)]
    test_indices = np.arange(900, 1000)[:epochs_test]
    if rank == 1:
        print(test_indices)
    print('training at node', rank)
    for i in range(num_ite):
        # Initialize main parameters for training
        network, local_training_sequence, weights_list, S_prime, S, eligibility_trace, et_temp, learning_signal, ls_temp \
            = init_training(rank, num_nodes, all_nodes, dataset, eta, epochs, net_params)

        dist.barrier(all_nodes)

        if rank != 0:
            _, loss = get_acc_and_loss(network, dataset, test_indices)
            test_loss[i].append((0, loss))
            network.set_mode('train')

        dist.barrier(all_nodes)

        ### First local step
        for s in range(deltas):
            if rank != 0:
                print('local training sequence', local_training_sequence)
                # Feedforward sampling step
                log_proba, learning_signal, eligibility_trace \
                    = feedforward_sampling(network, local_training_sequence, eligibility_trace, learning_signal, s, S_prime, alpha, r)

        if rank != 0:
            # First local update
            for parameter in eligibility_trace:
                eligibility_trace[parameter][
                    network.hidden_neurons -
                    network.n_non_learnable_neurons] *= learning_signal
                network.get_parameters(
                )[parameter] += eligibility_trace[parameter] * learning_rate

        # First global update
        if (s + 1) % (tau * deltas) == 0:
            dist.barrier(all_nodes)
            global_update(all_nodes, rank, network, weights_list)
            dist.barrier(all_nodes)

        ### Remainder of the steps
        for s in range(deltas, S):
            if rank != 0:
                if s % S_prime == 0:  # Reset internal state for each example
                    network.reset_internal_state()

                # lr decay: halve the rate every S // 5 steps while above the floor
                if s % (S // 5) == 0 and learning_rate > 0.005:
                    learning_rate /= 2

                # Feedforward sampling
                log_proba, ls_temp, et_temp \
                    = feedforward_sampling(network, local_training_sequence, et_temp, ls_temp, s, S_prime, alpha, r)

                # Local feedback and global update
                learning_signal, ls_temp, eligibility_trace, et_temp \
                    = local_feedback_and_update(network, eligibility_trace, learning_signal, et_temp, ls_temp, learning_rate, kappa, s, deltas)

                ## Every few timesteps, record test losses
                if (s + 1) % 40 == 0:
                    _, loss = get_acc_and_loss(network, dataset, test_indices)
                    test_loss[i].append((s, loss))
                    network.set_mode('train')

            # Global update
            if (s + 1) % (tau * deltas) == 0:
                dist.barrier(all_nodes)
                global_update(all_nodes, rank, network, weights_list)
                dist.barrier(all_nodes)

        if rank == 0:
            global_acc, _ = get_acc_and_loss(network, dataset, test_indices)
            test_accuracies.append(global_acc)
            print('Iteration: %d, final accuracy: %f' % (i, global_acc))

    if rank == 0:
        if save_path is None:
            save_path = os.getcwd()
        np.save(save_path + r'/test_accuracies.npy',
                arr=np.array(test_accuracies))
        print('Training finished and accuracies saved to ' + save_path +
              r'/test_accuracies.npy')

    else:
        if save_path is None:
            save_path = os.getcwd()

        np.save(save_path + r'/test_loss_w%d.npy' % rank,
                arr=np.array(test_loss))
        print('Training finished and losses saved to ' + save_path +
              r'/test_loss_w%d.npy' % rank)
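
Both examples delegate the periodic parameter averaging to global_update, whose body is not shown here. A plausible sketch of that round, consistent with how weights_list is pre-allocated on rank 0 in run(): the master gathers every worker's copy of each parameter, averages them, and broadcasts the result back. The indexing of weights_list and the get_parameters() interface are assumptions inferred from the surrounding code, not the source implementation:

import torch
import torch.distributed as dist


def global_update_sketch(group, rank, network, weights_list):
    # Gather each parameter on rank 0, average the workers' copies,
    # then broadcast the averaged tensor back to every node.
    params = network.get_parameters()  # assumed: dict of live tensors
    for i, name in enumerate(params):
        # weights_list[i] is the per-rank buffer list allocated on rank 0.
        gather_list = weights_list[i] if rank == 0 else None
        dist.gather(params[name], gather_list=gather_list, dst=0, group=group)
        if rank == 0:
            # Skip rank 0's own (untrained) copy and average the workers'.
            params[name].copy_(torch.stack(weights_list[i][1:]).mean(dim=0))
        dist.broadcast(params[name], src=0, group=group)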