Example #1
def train_fixed_rate(rank, num_nodes, args):
    # Create network groups for communication
    all_nodes = dist.new_group([0, 1, 2],
                               timeout=datetime.timedelta(0, 360000))

    # Setup training parameters
    args.dataset = tables.open_file(args.dataset)

    train_data = args.dataset.root.train
    test_data = args.dataset.root.test

    args.S_prime = int(args.sample_length * 1000 / args.dt)
    S = args.num_samples_train * args.S_prime

    args, test_indices, save_dict, save_path = init_test(rank, args)

    for tau in args.tau_list:
        n_weights_to_send = int(tau * args.rate)

        for _ in range(args.num_ite):
            # Initialize main parameters for training
            network, indices_local, weights_list, eligibility_trace, et_temp, learning_signal, ls_temp = init_training(
                rank, num_nodes, all_nodes, args)

            # Gradients accumulator
            gradients_accum = torch.zeros(network.feedforward_weights.shape,
                                          dtype=torch.float)
            dist.barrier(all_nodes)

            for s in range(S):
                if rank != 0:
                    if s % args.S_prime == 0:  # Reset internal state for each example
                        refractory_period(network)
                        inputs, label = get_example(
                            train_data, s // args.S_prime, args.S_prime,
                            args.n_classes, args.input_shape, args.dt,
                            args.dataset.root.stats.train_data[1],
                            args.polarity)
                        inputs = inputs.to(network.device)
                        label = label.to(network.device)

                    # lr decay
                    # if (s + 1) % int(S / 4) == 0:
                    #     args.lr /= 2

                    # Feedforward sampling
                    log_proba, ls_temp, et_temp, gradients_accum = feedforward_sampling(
                        network, inputs[:, s % args.S_prime],
                        label[:, s % args.S_prime], ls_temp, et_temp, args,
                        gradients_accum)

                    # Local feedback and update
                    eligibility_trace, et_temp, learning_signal, ls_temp = local_feedback_and_update(
                        network, eligibility_trace, et_temp, learning_signal,
                        ls_temp, s, args)

                # Global update
                if (s + 1) % (tau * args.deltas) == 0:
                    dist.barrier(all_nodes)
                    global_update_subset(all_nodes, rank, network,
                                         weights_list, gradients_accum,
                                         n_weights_to_send)
                    gradients_accum = torch.zeros(
                        network.feedforward_weights.shape, dtype=torch.float)
                    dist.barrier(all_nodes)

            if rank == 0:
                global_acc, _ = get_acc_and_loss(
                    network, test_data, test_indices, args.S_prime,
                    args.n_classes, args.input_shape, args.dt,
                    args.dataset.root.stats.train_data[1], args.polarity)
                save_dict[tau].append(global_acc)
                save_results(save_dict, save_path)
                print('Tau: %d, final accuracy: %f' % (tau, global_acc))

    if rank == 0:
        save_results(save_dict, save_path)
        print('Training finished and accuracies saved to ' + save_path)
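
All of these examples assume the default torch.distributed process group is already initialized before dist.new_group is called. Below is a minimal launcher sketch, assuming the gloo backend on a single machine; the Namespace fields are placeholders for the repository's real argument parser:

import os
from argparse import Namespace

import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank, world_size, args):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    # The default group must exist before new_group() can carve out subgroups.
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    train_fixed_rate(rank, world_size, args)
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 3  # rank 0 coordinates and tests, ranks 1 and 2 train
    args = Namespace()  # placeholder: fill in dataset, dt, tau_list, ...
    mp.spawn(run_worker, args=(world_size, args), nprocs=world_size)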
Example #2
def train(rank, num_nodes, net_params, train_params):
    # Setup training parameters
    rate = train_params['rate']
    eta = train_params['eta']
    kappa = train_params['kappa']
    deltas = train_params['deltas']
    epochs = train_params['epochs']
    epochs_test = train_params['epochs_test']
    dataset = train_params['dataset']
    learning_rate = train_params['learning_rate']
    alpha = train_params['alpha']
    r = train_params['r']
    num_ite = train_params['num_ite']
    save_path = net_params['save_path']

    # Create network groups for communication
    all_nodes = dist.new_group([0, 1, 2],
                               timeout=datetime.timedelta(0, 360000))
    print("Training", flush=True)
    # Global update periods used for training
    tau_list = [int((2**i) / rate) for i in range(4)]
    # Used to store test accuracies
    test_accuracies = [[] for _ in range(len(tau_list))]
    test_indices = np.hstack((np.arange(900, 1000)[:epochs_test],
                              np.arange(1900, 2000)[:epochs_test]))

    for i, tau in enumerate(tau_list):
        n_weights_to_send = int(tau * rate)

        for _ in range(num_ite):
            # Initialize main parameters for training
            network, local_training_sequence, weights_list, S_prime, S, eligibility_trace, et_temp, learning_signal, ls_temp \
                = init_training(rank, num_nodes, all_nodes, dataset, eta, epochs, net_params)

            # Gradients accumulator
            gradients_accum = torch.zeros(network.feedforward_weights.shape,
                                          dtype=torch.float)
            dist.barrier(all_nodes)

            ### First local step
            for s in range(deltas):
                if rank != 0:
                    # Feedforward sampling step
                    log_proba, gradients_accum, learning_signal, eligibility_trace \
                        = feedforward_sampling_accum_gradients(network, local_training_sequence, eligibility_trace, learning_signal, gradients_accum, s, S_prime, alpha, r)

            if rank != 0:
                # First local update
                for parameter in eligibility_trace:
                    eligibility_trace[parameter][
                        network.hidden_neurons -
                        network.n_non_learnable_neurons] *= learning_signal
                    network.get_parameters()[parameter] += eligibility_trace[
                        parameter] * learning_rate

            # First global update; s keeps its final value (deltas - 1) from the loop above
            if (s + 1) % (tau * deltas) == 0:
                dist.barrier(all_nodes)
                global_update_subset(all_nodes, rank, network, weights_list,
                                     gradients_accum, n_weights_to_send)
                gradients_accum = torch.zeros(
                    network.feedforward_weights.shape, dtype=torch.float)

                dist.barrier(all_nodes)

            ### Remainder of the steps
            for s in range(deltas, S):
                if rank != 0:
                    if s % S_prime == 0:  # Reset internal state for each example
                        network.reset_internal_state()

                    # lr decay
                    if s % (S // 5) == 0 and learning_rate > 0.005:
                        learning_rate /= 2

                    # Feedforward sampling
                    log_proba, gradients_accum, ls_temp, et_temp \
                        = feedforward_sampling_accum_gradients(network, local_training_sequence, et_temp, ls_temp, gradients_accum, s, S_prime, alpha, r)

                    # Local feedback and global update
                    learning_signal, ls_temp, eligibility_trace, et_temp \
                        = local_feedback_and_update(network, eligibility_trace, learning_signal, et_temp, ls_temp, learning_rate, kappa, s, deltas)

                if (s + 1) % (tau * deltas) == 0:
                    dist.barrier(all_nodes)
                    # Global update
                    global_update_subset(all_nodes, rank, network,
                                         weights_list, gradients_accum,
                                         n_weights_to_send)
                    gradients_accum = torch.zeros(
                        network.feedforward_weights.shape, dtype=torch.float)

                    dist.barrier(all_nodes)

            if rank == 0:
                global_acc, test_loss = get_acc_and_loss(
                    network, dataset, test_indices)
                test_accuracies[i].append(global_acc)
                print('Tau: %d, final accuracy: %f' % (tau, global_acc))

    if rank == 0:
        if save_path is None:
            save_path = os.getcwd()
        np.save(save_path + r'/test_accuracies_r_%f.npy' % rate,
                arr=np.array(test_accuracies))
        print('Training finished and accuracies saved to ' + save_path +
              r'/test_accuracies_r_%f.npy' % rate)
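
global_update_subset itself does not appear on this page. The sketch below is only one plausible reading of its call sites: communicate sparsely by averaging just the n_weights_to_send feedforward weights whose accumulated gradients are largest. The real repository may implement this differently; rank 1's index selection is broadcast here so every node reduces the same coordinates, and weights_list is kept only for signature compatibility:

import torch
import torch.distributed as dist


def global_update_subset(group, rank, network, weights_list, gradients_accum,
                         n_weights_to_send):
    # Flat in-place view into the weight tensor (assumes contiguous storage).
    flat_weights = network.feedforward_weights.data.view(-1)
    # Select the coordinates with the largest accumulated gradient magnitude.
    _, idx = torch.topk(gradients_accum.flatten().abs(), n_weights_to_send)
    # All nodes must agree on which coordinates are averaged.
    dist.broadcast(idx, src=1, group=group)
    # Sum the selected weights across nodes, then average them.
    selected = flat_weights[idx].clone()
    dist.all_reduce(selected, op=dist.ReduceOp.SUM, group=group)
    flat_weights[idx] = selected / dist.get_world_size(group=group)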
Example #3
def train(rank, num_nodes, args):
    # Create network groups for communication
    all_nodes = dist.new_group([0, 1, 2],
                               timeout=datetime.timedelta(0, 360000))

    # Setup training parameters
    args.dataset = tables.open_file(args.dataset)

    args.n_classes = args.dataset.root.stats.test_label[1]
    train_data = args.dataset.root.train
    test_data = args.dataset.root.test

    args.S_prime = int(args.sample_length * 1000 / args.dt)
    S = args.num_samples_train * args.S_prime

    args, test_indices, save_dict_loss, save_dict_acc = init_test(args)

    for i in range(args.num_ite):
        # Initialize main parameters for training
        network, indices_local, weights_list, eligibility_trace, et_temp, learning_signal, ls_temp = init_training(
            rank, num_nodes, all_nodes, args)

        dist.barrier(all_nodes)

        # Test loss at beginning + selection of training indices
        # if rank != 0:
        #     print(rank, args.local_labels)
        #     acc, loss = get_acc_and_loss(network, test_data, find_indices_for_labels(test_data, args.labels), args.S_prime, args.n_classes, [1],
        #                                              args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
        #     save_dict_acc[0].append(acc)
        #     save_dict_loss[0].append(loss)
        #     network.train()
        # else:
        # test_acc, test_loss, spikes = get_acc_loss_and_spikes(network, test_data, test_indices, args.S_prime, args.n_classes, [1],
        #                                          args.input_shape,  args.dt, args.dataset.root.stats.train_data[1], args.polarity)
        # save_dict_acc[0].append(test_acc)
        # save_dict_loss[0].append(test_loss)
        # np.save(args.save_path + r'/spikes_test_s_%d.npy' % 0, spikes.numpy())
        # network.train()

        dist.barrier(all_nodes)

        for s in range(S):
            if rank == 0:
                if (s + 1) % args.test_interval == 0:
                    test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                        network, test_data,
                        find_indices_for_labels(test_data,
                                                args.labels), args.S_prime,
                        args.n_classes, [1], args.input_shape, args.dt,
                        args.dataset.root.stats.train_data[1], args.polarity)
                    save_dict_acc[s + 1].append(test_acc)
                    save_dict_loss[s + 1].append(test_loss)

                    network.train()
                    save_results(save_dict_acc,
                                 args.save_path + r'/test_acc.pkl')
                    save_results(save_dict_loss,
                                 args.save_path + r'/test_loss.pkl')

                    print('Acc at step %d : %f' % (s, test_acc))

            dist.barrier(all_nodes)

            if rank != 0:
                # if (s + 1) % args.test_interval == 0:
                #     acc, loss = get_acc_and_loss(network, test_data, find_indices_for_labels(test_data, args.labels), args.S_prime, args.n_classes, [1],
                #                                              args.input_shape, args.dt, args.dataset.root.stats.train_data[1], args.polarity)
                #     save_dict_acc[s + 1].append(acc)
                #     save_dict_loss[s + 1].append(loss)
                #
                #     save_results(save_dict_acc, args.save_path + r'/test_acc.pkl')
                #     save_results(save_dict_loss, args.save_path + r'/test_loss.pkl')
                #
                #     network.train()

                if s % args.S_prime == 0:  # at each example
                    refractory_period(network)
                    inputs, label = get_example(
                        train_data, indices_local[s // args.S_prime],
                        args.S_prime, args.n_classes, [1], args.input_shape,
                        args.dt, args.dataset.root.stats.train_data[1],
                        args.polarity)
                    inputs = inputs.to(network.device)
                    label = label.to(network.device)

                # lr decay
                # if s % S / 4 == 0:
                #     args.lr /= 2

                # Feedforward sampling
                log_proba, ls_temp, et_temp, _ = feedforward_sampling(
                    network, inputs[:, s % args.S_prime],
                    label[:, s % args.S_prime], ls_temp, et_temp, args)

                # Local feedback and update
                eligibility_trace, et_temp, learning_signal, ls_temp = local_feedback_and_update(
                    network, eligibility_trace, et_temp, learning_signal,
                    ls_temp, s, args)

            # Global update
            if (s + 1) % (args.tau * args.deltas) == 0:
                dist.barrier(all_nodes)
                global_update(all_nodes, rank, network, weights_list)
                dist.barrier(all_nodes)

        # Final global update
        dist.barrier(all_nodes)
        global_update(all_nodes, rank, network, weights_list)
        dist.barrier(all_nodes)

        if rank == 0:
            test_acc, test_loss, spikes = get_acc_loss_and_spikes(
                network, test_data,
                find_indices_for_labels(test_data, args.labels), args.S_prime,
                args.n_classes, [1], args.input_shape, args.dt,
                args.dataset.root.stats.train_data[1], args.polarity)
            save_dict_acc[S].append(test_acc)
            save_dict_loss[S].append(test_loss)

            network.train()

            save_results(save_dict_acc, args.save_path + r'/acc.pkl')
            save_results(save_dict_loss, args.save_path + r'/loss.pkl')
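
save_results and init_test are also external to this page. A minimal sketch of the persistence pattern the calls imply, assuming each save dictionary maps a key (a timestep, or a value of tau) to the list of measurements recorded for it:

import pickle


def save_results(results, save_path):
    # Overwrite the previous snapshot so a crash loses at most one interval.
    with open(save_path, 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_results(save_path):
    with open(save_path, 'rb') as f:
        return pickle.load(f)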
Example #4
def train(rank, num_nodes, net_params, train_params):
    # Setup training parameters
    dataset = train_params['dataset']
    epochs = train_params['epochs']
    epochs_test = train_params['epochs_test']
    deltas = train_params['deltas']
    num_ite = train_params['num_ite']
    save_path = net_params['save_path']
    tau = train_params['tau']

    learning_rate = train_params['learning_rate']
    alpha = train_params['alpha']
    eta = train_params['eta']
    kappa = train_params['kappa']
    r = train_params['r']

    # Create network groups for communication
    all_nodes = dist.new_group([0, 1, 2, 3],
                               timeout=datetime.timedelta(0, 360000))

    test_accuracies = []  # used to store test accuracies
    test_loss = [[] for _ in range(num_ite)]
    test_indices = np.arange(900, 1000)[:epochs_test]
    print('Training at node', rank)
    for i in range(num_ite):
        # Initialize main parameters for training
        network, local_training_sequence, weights_list, S_prime, S, eligibility_trace, et_temp, learning_signal, ls_temp \
            = init_training(rank, num_nodes, all_nodes, dataset, eta, epochs, net_params)

        dist.barrier(all_nodes)

        if rank != 0:
            _, loss = get_acc_and_loss(network, dataset, test_indices)
            test_loss[i].append((0, loss))
            network.set_mode('train')

        dist.barrier(all_nodes)

        ### First local step
        for s in range(deltas):
            if rank != 0:
                # Feedforward sampling step
                log_proba, learning_signal, eligibility_trace \
                    = feedforward_sampling(network, local_training_sequence, eligibility_trace, learning_signal, s, S_prime, alpha, r)

        if rank != 0:
            # First local update
            for parameter in eligibility_trace:
                eligibility_trace[parameter][
                    network.hidden_neurons -
                    network.n_non_learnable_neurons] *= learning_signal
                network.get_parameters()[parameter] += \
                    eligibility_trace[parameter] * learning_rate

        # First global update; s keeps its final value (deltas - 1) from the loop above
        if (s + 1) % (tau * deltas) == 0:
            dist.barrier(all_nodes)
            global_update(all_nodes, rank, network, weights_list)
            dist.barrier(all_nodes)

        ### Remainder of the steps
        for s in range(deltas, S):
            if rank != 0:
                if s % S_prime == 0:  # Reset internal state for each example
                    network.reset_internal_state()

                # lr decay
                if s % (S // 5) == 0 and learning_rate > 0.005:
                    learning_rate /= 2

                # Feedforward sampling
                log_proba, ls_temp, et_temp \
                    = feedforward_sampling(network, local_training_sequence, et_temp, ls_temp, s, S_prime, alpha, r)

                # Local feedback and global update
                learning_signal, ls_temp, eligibility_trace, et_temp \
                    = local_feedback_and_update(network, eligibility_trace, learning_signal, et_temp, ls_temp, learning_rate, kappa, s, deltas)

                ## Every few timesteps, record test losses
                if (s + 1) % 40 == 0:
                    _, loss = get_acc_and_loss(network, dataset, test_indices)
                    test_loss[i].append((s, loss))
                    network.set_mode('train')

            # Global update
            if (s + 1) % (tau * deltas) == 0:
                dist.barrier(all_nodes)
                global_update(all_nodes, rank, network, weights_list)
                dist.barrier(all_nodes)

        if rank == 0:
            global_acc, _ = get_acc_and_loss(network, dataset, test_indices)
            test_accuracies.append(global_acc)
            print('Iteration: %d, final accuracy: %f' % (i, global_acc))

    if rank == 0:
        if save_path is None:
            save_path = os.getcwd()
        np.save(save_path + r'/test_accuracies.npy',
                arr=np.array(test_accuracies))
        print('Training finished and accuracies saved to ' + save_path +
              r'/test_accuracies.npy')

    else:
        if save_path is None:
            save_path = os.getcwd()

        np.save(save_path + r'/test_loss_w%d.npy' % rank,
                arr=np.array(test_loss))
        print('Training finished and losses saved to ' + save_path +
              r'/test_loss_w%d.npy' % rank)
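
The arrays written at the end can be read back with numpy. A hypothetical analysis snippet; if iterations recorded different numbers of (step, loss) pairs, the loss file holds an object array, hence allow_pickle=True:

import numpy as np

# One final accuracy per iteration, written by rank 0.
acc = np.load('test_accuracies.npy')
print('mean final accuracy over %d runs: %f' % (len(acc), acc.mean()))

# A training worker's (step, loss) pairs, recorded every 40 timesteps.
loss_w1 = np.load('test_loss_w1.npy', allow_pickle=True)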