Code Example #1
File: demo_mnist_logreg.py  Project: QUVA-Lab/artemis
def demo_mnist_logreg(minibatch_size=20, learning_rate=0.01, max_training_samples = None, n_epochs=10, test_epoch_period=0.2):
    """
    Train a Logistic Regressor on the MNIST dataset, report training/test scores throughout training, and return the
    final scores.
    """

    x_train, y_train, x_test, y_test = get_mnist_dataset(flat=True, n_training_samples=max_training_samples).xyxy

    predictor = OnlineLogisticRegressor(n_in=784, n_out=10, learning_rate=learning_rate)

    # Train and periodically record scores.
    epoch_scores = []
    for ix, iteration_info in minibatch_index_info_generator(n_samples = len(x_train), minibatch_size=minibatch_size, n_epochs=n_epochs, test_epochs=('every', test_epoch_period)):
        if iteration_info.test_now:
            training_score = 100*(np.argmax(predictor.predict(x_train), axis=1)==y_train).mean()  # Percent correct on the training set
            test_score = 100*(np.argmax(predictor.predict(x_test), axis=1)==y_test).mean()  # Percent correct on the test set
            print('Epoch {epoch}: Test Score: {test}%, Training Score: {train}%'.format(epoch=iteration_info.epoch, test=test_score, train=training_score))
            epoch_scores.append((iteration_info.epoch, training_score, test_score))
        predictor.train(x_train[ix], y_train[ix])

    # Plot
    plt.figure()
    epochs, training_scores, test_scores = zip(*epoch_scores)
    plt.plot(epochs, np.array([training_scores, test_scores]).T)
    plt.xlabel('Epoch')
    plt.ylabel('% Correct')
    plt.legend(['Training Score', 'Test Score'])
    plt.title("Learning Curve")
    plt.ylim(80, 100)
    plt.ion()  # Don't hang on plot
    plt.show()

    return {'train': epoch_scores[-1][1], 'test': epoch_scores[-1][2]}  # Return final scores
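A minimal usage sketch for the demo above. It assumes the surrounding artemis environment (QUVA-Lab/artemis) provides get_mnist_dataset, OnlineLogisticRegressor and minibatch_index_info_generator; the argument values below are illustrative choices for a quick run, not the project's recommended settings.

if __name__ == '__main__':
    # Hypothetical quick run: a small MNIST subset for two epochs.
    final_scores = demo_mnist_logreg(
        minibatch_size=20,
        learning_rate=0.01,
        max_training_samples=5000,  # limit the training data for a fast smoke test
        n_epochs=2,
        test_epoch_period=0.5,
    )
    print('Final scores: train={train:.1f}%, test={test:.1f}%'.format(**final_scores))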
Code Example #2
def experiment_mnist_eqprop(
    layer_constructor,
    n_epochs=10,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=.5,
    random_flip_beta=True,
    learning_rate=.05,
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    online_checkpoints_period=None,
    epoch_checkpoint_period=.25,
    skip_zero_epoch_test=False,
    n_test_samples=None,
    prop_direction: Union[str, Tuple] = 'neutral',
    bidirectional=True,
    renew_activations=True,
    do_fast_forward_pass=False,
    rebuild_coders=True,
    l2_loss=None,
    splitstream=True,
    seed=1234,
):
    """
    Replicate the results of Scellier & Bengio:
        Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
        https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
        https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points across epochs, so our results should only really match theirs for the first epoch.
    - We evaluate the training score periodically rather than as an online average (you can still see an online score by setting online_checkpoints_period to a value other than None).
    """

    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in = 784
    n_out = 10
    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    rng = get_rng(rng)

    y_train = y_train.astype(np.float32)

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    is_online_checkpoint = Checkpoints(
        online_checkpoints_period, skip_first=skip_zero_epoch_test
    ) if online_checkpoints_period is not None else lambda: False
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    results = Duck()

    training_states = initialize_states(
        layer_constructor=layer_constructor,
        n_samples=minibatch_size,
        params=initialize_params(layer_sizes=layer_sizes,
                                 initial_weight_scale=initial_weight_scale,
                                 rng=rng))

    if isinstance(prop_direction, str):
        fwd_prop_direction, backward_prop_direction = prop_direction, prop_direction
    else:
        fwd_prop_direction, backward_prop_direction = prop_direction

    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0],
                                           minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(
                x_test)
            y_pred_test, y_pred_train = [
                run_inference(
                    x_data=x[:n_test_samples],
                    states=initialize_states(
                        layer_constructor=layer_constructor,
                        params=[s.params for s in training_states],
                        n_samples=min(len(x), n_test_samples)
                        if n_test_samples is not None else len(x)),
                    n_steps=n_negative_steps,
                    prop_direction=fwd_prop_direction,
                ) for x in (x_test, x_train)
            ]
            # y_pred_train = run_inference(x_data=x_train[:n_test_samples], states=initialize_states(params=[s.params for s in training_states], n_samples=min(len(x_train), n_test_samples) if n_test_samples is not None else len(x_train)))
            test_error = percent_argmax_incorrect(y_pred_test,
                                                  y_test[:n_test_samples])
            train_error = percent_argmax_incorrect(y_pred_train,
                                                   y_train[:n_test_samples])
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    train_error=train_error,
                                    test_error=test_error)
            yield results
            if epoch > 2 and train_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]
        training_states = run_eqprop_training_update(
            x_data=x_data_sample,
            y_data=y_data_sample,
            layer_states=training_states,
            beta=beta,
            random_flip_beta=random_flip_beta,
            learning_rate=learning_rate,
            layer_constructor=layer_constructor,
            bidirectional=bidirectional,
            l2_loss=l2_loss,
            renew_activations=renew_activations,
            n_negative_steps=n_negative_steps,
            n_positive_steps=n_positive_steps,
            prop_direction=prop_direction,
            splitstream=splitstream,
            rng=rng)
        this_train_score = ra(
            percent_argmax_correct(output_from_state(training_states),
                                   y_train[ixs]))
        if is_online_checkpoint():
            print(
                f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}'
            )
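The docstring above points to Scellier & Bengio's Equilibrium Propagation, which run_eqprop_training_update and run_inference implement for this experiment. As orientation only (not the project's actual code), the core two-phase contrastive weight update for a Hopfield-style energy can be sketched in a few lines of NumPy; eqprop_weight_update, rho and the state tuples below are illustrative names and assumptions.

import numpy as np

def eqprop_weight_update(s_free, s_clamped, beta, learning_rate, rho=lambda s: np.clip(s, 0., 1.)):
    """Illustrative EqProp update for one weight matrix linking a 'pre' and a 'post' layer.

    s_free / s_clamped are (pre_states, post_states) pairs of shape (n_samples, n_units),
    taken at the fixed points of the free (negative) and weakly clamped (positive) phases.
    """
    (pre_f, post_f), (pre_c, post_c) = s_free, s_clamped
    n_samples = pre_f.shape[0]
    # Contrastive Hebbian-style rule: dW is proportional to
    # (1/beta) * (rho(s+) rho(s+)^T - rho(s-) rho(s-)^T), averaged over the minibatch.
    dw = (rho(pre_c).T @ rho(post_c) - rho(pre_f).T @ rho(post_f)) / (beta * n_samples)
    return learning_rate * dw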
Code Example #3
def experiment_mnist_eqprop_torch(
    layer_constructor: Callable[[int, LayerParams], IDynamicLayer],
    n_epochs=10,
    hidden_sizes=(500, ),
    minibatch_size=10,  # update mini-batch size
    batch_size=500,  # total batch size
    beta=.5,
    random_flip_beta=True,
    learning_rate=.05,
    n_negative_steps=120,
    n_positive_steps=80,
    initial_weight_scale=1.,
    online_checkpoints_period=None,
    epoch_checkpoint_period=1.0,  #'100s', #{0: .25, 1: .5, 5: 1, 10: 2, 50: 4},
    skip_zero_epoch_test=False,
    n_test_samples=10000,
    prop_direction: Union[str, Tuple] = 'neutral',
    bidirectional=True,
    renew_activations=True,
    do_fast_forward_pass=False,
    rebuild_coders=True,
    l2_loss=None,
    splitstream=False,
    seed=1234,
    prediction_inp_size=17,  ## prediction input size
    delay=18,  ## delay size for the clamped phase
    pred=True,  ## if you want to use the prediction
    check_flg=False,
):
    """
    Replicate the results of Scellier & Bengio:
        Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
        https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
        https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points across epochs, so our results should only really match theirs for the first epoch.
    - We evaluate the training score periodically rather than as an online average (you can still see an online score by setting online_checkpoints_period to a value other than None).
    """
    torch.manual_seed(seed)
    device = 'cuda' if torch.cuda.is_available() and USE_CUDA_WHEN_AVAILABLE else 'cpu'
    if device == 'cuda':
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print(f'Using Device: {device}')

    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)

    x_train = torch.tensor(dataset.train.images, dtype=torch.float32).to(device)
    y_train = torch.tensor(dataset.train.labels, dtype=torch.float32).to(device)
    x_test = torch.tensor(dataset.test.images, dtype=torch.float32).to(device)  # Their 'validation set' is our 'test set'
    y_test = torch.tensor(dataset.test.labels, dtype=torch.float32).to(device)
    x_val = torch.tensor(dataset.validation.images, dtype=torch.float32).to(device)
    y_val = torch.tensor(dataset.validation.labels, dtype=torch.float32).to(device)

    if is_test_mode():
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:100], y_test[:100]
        x_val, y_val = x_val[:100], y_val[:100]
        n_epochs = 1
        n_negative_steps = 3
        n_positive_steps = 3

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    is_online_checkpoint = Checkpoints(
        online_checkpoints_period, skip_first=skip_zero_epoch_test
    ) if online_checkpoints_period is not None else lambda: False
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    training_states = initialize_states(
        layer_constructor=layer_constructor,
        #n_samples=minibatch_size,
        n_samples=batch_size,
        params=initialize_params(layer_sizes=layer_sizes,
                                 initial_weight_scale=initial_weight_scale,
                                 rng=rng))

    # dbplot(training_states[0].params.w_fore[:10, :10], str(rng.randint(265)))

    if isinstance(prop_direction, str):
        fwd_prop_direction, backward_prop_direction = prop_direction, prop_direction
    else:
        fwd_prop_direction, backward_prop_direction = prop_direction

    def do_test():
        # n_samples = n_test_samples if n_test_samples is not None else len(x_test)
        test_error, train_error, val_error = [
            percent_argmax_incorrect(
                run_inference(
                    x_data=x[:n_test_samples],
                    states=initialize_states(
                        layer_constructor=layer_constructor,
                        params=[s.params for s in training_states],
                        n_samples=n_samples),
                    n_steps=n_negative_steps,
                    prop_direction=fwd_prop_direction,
                ), y[:n_samples]).item()
            for x, y in [(x_test, y_test), (x_train, y_train), (x_val, y_val)]
            for n_samples in [
                min(len(x), n_test_samples
                    ) if n_test_samples is not None else len(x)
            ]
        ]  # Not an actual loop... just a hack to allow assignment inside the comprehension
        print(
            f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}%, Validation Error: {val_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
        )

        return dict(iter=i,
                    epoch=epoch,
                    train_error=train_error,
                    test_error=test_error,
                    val_error=val_error), train_error, test_error, val_error

    results = Duck()
    pi = ProgressIndicator(expected_iterations=n_epochs *
                           dataset.train.num_examples / minibatch_size,
                           update_every='10s')

    dy_squared = [None, None]
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.size()[0],
                                           minibatch_size=batch_size,
                                           n_epochs=n_epochs)):
        epoch = i * batch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            check_flg = False
            x_train, y_train = shuffle_data(x_train, y_train)
            with pi.pause_measurement():
                results[next, :], train_err, test_err, val_err = do_test()

                ## prepare for saving the parameters
                ws, bs = zip(*((s.params.w_aft, s.params.b)
                               for s in training_states[1:]))

                # `directory` is assumed to be defined at module level (output directory for logs and weights).
                os.makedirs(directory, exist_ok=True)
                with open(directory + '/log.txt', 'a') as f:
                    f.write("Epoch: " + str(epoch) + '\n')
                    f.write("error for training: " + str(train_err) + '\n')
                    f.write("error for testing: " + str(test_err) + '\n')
                    f.write("error for validation: " + str(val_err) + '\n')

                np.save(directory + '/w_epoch_' + str(epoch) + '.npy', ws)
                np.save(directory + '/b_epoch_' + str(epoch) + '.npy', bs)
                np.save(directory + '/dy_squared_epoch_' + str(epoch) + '.npy',
                        dy_squared)

                yield results
                if epoch > 100 and results[-1, 'train_error'] > 50:
                    return

        # The Original training loop, just taken out here:
        ixs = ixs.astype(np.int32)  # this is for python version 3.7

        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        training_states, dy_squared = run_eqprop_training_update(
            x_data=x_data_sample,
            y_data=y_data_sample,
            layer_states=training_states,
            beta=beta,
            random_flip_beta=random_flip_beta,
            learning_rate=learning_rate,
            layer_constructor=layer_constructor,
            bidirectional=bidirectional,
            l2_loss=l2_loss,
            renew_activations=renew_activations,
            n_negative_steps=n_negative_steps,
            n_positive_steps=n_positive_steps,
            prop_direction=prop_direction,
            splitstream=splitstream,
            rng=rng,
            prediction_inp_size=prediction_inp_size,
            delay=delay,
            device=device,
            epoch_check=check_flg,
            epoch=epoch,
            pred=pred,
            batch_size=batch_size,
            minibatch_size=minibatch_size,
            dy_squared=dy_squared)
        check_flg = False

        this_train_score = ra(
            percent_argmax_incorrect(output_from_state(training_states),
                                     y_train[ixs]))
        if is_online_checkpoint():
            print(
                f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}'
            )

        pi.print_update(info=f'Epoch: {epoch}')

    results[next, :], train_err, test_err, val_err = do_test()
    yield results
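Because experiment_mnist_eqprop_torch is a generator that yields the growing results table at every epoch checkpoint, a caller drives training simply by iterating over it. A minimal consumption sketch, assuming an artemis-style setup in which a suitable layer_constructor is available; the argument values and variable names below are illustrative.

# Hypothetical driver: pull checkpointed results out of the experiment generator.
for results_so_far in experiment_mnist_eqprop_torch(
        layer_constructor=layer_constructor,  # assumed to be supplied by the surrounding project
        n_epochs=2,
        epoch_checkpoint_period=0.5,
):
    latest_test_error = results_so_far[-1, 'test_error']  # same indexing the experiment itself uses
    print(f'Checkpoint: test error so far = {latest_test_error:.3g}%')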
Code Example #4
def demo_energy_based_initialize_eq_prop_fwd_energy(
    n_epochs=25,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=0.5,
    epsilon=0.5,
    learning_rate=(0.1, .05),
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    forward_deviation_cost=0.,
    zero_deviation_cost=1,
    epoch_checkpoint_period={
        0: .25,
        1: .5,
        5: 1,
        10: 2,
        50: 4
    },
    n_test_samples=10000,
    skip_zero_epoch_test=False,
    train_with_forward='contrast',
    forward_nonlinearity='rho(x)',
    local_loss=True,
    random_flip_beta=True,
    seed=1234,
):

    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes,
                                  initial_weight_scale=initial_weight_scale,
                                  rng=rng)
    forward_params = initialize_params(
        layer_sizes=layer_sizes,
        initial_weight_scale=initial_weight_scale,
        rng=rng)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_inference_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_positive_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_parameter_update = update_eq_params.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_forward_pass = forward_pass.partial(
        nonlinearity=forward_nonlinearity).compile()
    # f_forward_parameter_update = update_forward_params_with_energy.compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        states_ = forward_states_ = f_forward_pass(
            x=x, params=forward_params_
        ) if train_with_forward else initialize_states(
            n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_,
                                          states=states_,
                                          fwd_states=forward_states_,
                                          x=x,
                                          epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    # last_time, last_epoch = time(), -1
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0],
                                           minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        # print(f'Training Rate: {(time()-last_time)/(epoch-last_epoch):3g}s/ep')
        # last_time, last_epoch = time(), epoch

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(
                x_test)
            (test_init_error,
             test_neg_error), (train_init_error, train_neg_error) = [[
                 percent_argmax_incorrect(prediction, y[:n_test_samples])
                 for prediction in do_inference(forward_params_=forward_params,
                                                eq_params_=eq_params,
                                                x=x[:n_test_samples],
                                                n_steps=n_negative_steps)
             ] for x, y in [(x_test, y_test), (x_train, y_train)]]
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    test_init_error=test_init_error,
                                    test_neg_error=test_neg_error,
                                    train_init_error=train_init_error,
                                    train_neg_error=train_neg_error)
            yield results
            if epoch > 2 and train_neg_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = forward_states = f_forward_pass(
            x=x_data_sample, params=forward_params
        ) if train_with_forward else initialize_states(
            n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])
        for t in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        epsilon=epsilon,
                                        fwd_states=forward_states)
        negative_states = states
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        y=y_data_sample,
                                        y_pressure=this_beta,
                                        epsilon=epsilon,
                                        fwd_states=forward_states)
        positive_states = states
        eq_params = f_parameter_update(x=x_data_sample,
                                       params=eq_params,
                                       negative_states=negative_states,
                                       positive_states=positive_states,
                                       fwd_states=forward_states,
                                       learning_rates=learning_rate,
                                       beta=this_beta)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=negative_states,
                learning_rates=learning_rate)
            # forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params, eq_states=negative_states, learning_rates=[lr/10 for lr in learning_rate])
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=positive_states,
                learning_rates=learning_rate)
        # elif train_with_forward == 'energy':
        #     forward_params = f_forward_parameter_update(x=x_data_sample, forward_params = forward_params, eq_params=eq_params, learning_rates=learning_rate)
        else:
            assert train_with_forward is False
Code Example #5
def demo_energy_based_initialize_eq_prop_alignment(
    n_epochs=25,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=0.5,
    epsilon=0.5,
    learning_rate=(0.1, .05),
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    epoch_checkpoint_period={
        0: .25,
        1: .5,
        5: 1,
        10: 2,
        50: 4
    },
    n_test_samples=10000,
    skip_zero_epoch_test=False,
    train_with_forward='contrast',
    forward_nonlinearity='rho(x)',
    local_loss=True,
    random_flip_beta=True,
    seed=1234,
):

    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes,
                                  initial_weight_scale=initial_weight_scale,
                                  rng=rng)
    forward_params = initialize_params(
        layer_sizes=layer_sizes,
        initial_weight_scale=initial_weight_scale,
        rng=rng)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.compile()
    f_inference_eq_step = equilibriating_step.compile()
    f_positive_eq_step = equilibriating_step.compile()
    f_parameter_update = update_eq_params.compile()
    f_forward_pass = forward_pass.partial(
        nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_update = update_forward_params_with_energy.partial(
        disconnect_grads=local_loss,
        nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        disconnect_grads=local_loss,
        nonlinearity=forward_nonlinearity).compile()
    f_energy = energy.compile()
    f_grad_align = compute_gradient_alignment.partial(
        nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        states_ = forward_states_ = f_forward_pass(
            x=x, params=forward_params_
        ) if train_with_forward else initialize_states(
            n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_,
                                          states=states_,
                                          x=x,
                                          epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0],
                                           minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(
                x_test)
            (test_init_error,
             test_neg_error), (train_init_error, train_neg_error) = [[
                 percent_argmax_incorrect(prediction, y[:n_test_samples])
                 for prediction in do_inference(forward_params_=forward_params,
                                                eq_params_=eq_params,
                                                x=x[:n_test_samples],
                                                n_steps=n_negative_steps)
             ] for x, y in [(x_test, y_test), (x_train, y_train)]]

            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )

            # ===== Compute alignment
            alignment_neg_states = forward_states_ = f_forward_pass(
                x=x_train, params=forward_params
            ) if train_with_forward else initialize_states(
                n_samples=x_train.shape[0],
                noninput_layer_sizes=layer_sizes[1:])
            for _ in range(n_negative_steps):
                alignment_neg_states = f_inference_eq_step(
                    params=eq_params,
                    states=alignment_neg_states,
                    x=x_train,
                    epsilon=epsilon)

            grad_alignments = f_grad_align(forward_params=forward_params,
                                           eq_states=alignment_neg_states,
                                           x=x_train)
            dbplot(grad_alignments,
                   'alignments',
                   plot_type=DBPlotTypes.LINE_HISTORY)
            # --------------------

            # fwd_states, neg_states = do_inference(forward_params_=forward_params, eq_params_=eq_params, x=x[:n_test_samples], n_steps=n_negative_steps)

            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    test_init_error=test_init_error,
                                    test_neg_error=test_neg_error,
                                    train_init_error=train_init_error,
                                    train_neg_error=train_neg_error,
                                    alignments=grad_alignments)
            yield results
            if epoch > 2 and train_neg_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = forward_states = f_forward_pass(
            x=x_data_sample, params=forward_params
        ) if train_with_forward else initialize_states(
            n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])
        for t in range(n_negative_steps):
            # if i % 200 == 0:
            #     with hold_dbplots():
            #         dbplot_collection(states, 'states', cornertext='NEG')
            #         dbplot(f_energy(params = eq_params, states=states, x=x_data_sample).mean(), 'energies', plot_type=DBPlotTypes.LINE_HISTORY_RESAMPLED)
            states = f_negative_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        epsilon=epsilon)
        negative_states = states
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            # if i % 200 == 0:
            #     with hold_dbplots():
            #         dbplot_collection(states, 'states', cornertext='')
            #         dbplot(f_energy(params = eq_params, states=states, x=x_data_sample).mean(), 'energies', plot_type=DBPlotTypes.LINE_HISTORY_RESAMPLED)
            states = f_positive_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        y=y_data_sample,
                                        y_pressure=this_beta,
                                        epsilon=epsilon)
        positive_states = states
        eq_params = f_parameter_update(x=x_data_sample,
                                       params=eq_params,
                                       negative_states=negative_states,
                                       positive_states=positive_states,
                                       learning_rates=learning_rate,
                                       beta=this_beta)

        # with hold_dbplots(draw_every=50):
        #     dbplot_collection([forward_params[0][0][:, :16].T.reshape(-1, 28, 28)] + [w for w, b in forward_params[1:]], '$\phi$')
        #     dbplot_collection([eq_params[0][0][:, :16].T.reshape(-1, 28, 28)] + [w for w, b in eq_params[1:]], '$\\theta$')
        #     dbplot_collection(forward_states, 'forward states')
        #     dbplot_collection(negative_states, 'negative_states')
        #     dbplot(np.array([f_energy(params = eq_params, states=forward_states, x=x_data_sample).mean(), f_energy(params = eq_params, states=negative_states, x=x_data_sample).mean()]), 'energies', plot_type=DBPlotTypes.LINE_HISTORY_RESAMPLED)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=negative_states,
                learning_rates=learning_rate)
            # forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params, eq_states=negative_states, learning_rates=[lr/10 for lr in learning_rate])
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=positive_states,
                learning_rates=learning_rate)
        elif train_with_forward == 'energy':
            forward_params = f_forward_parameter_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_params=eq_params,
                learning_rates=learning_rate)
        else:
            assert train_with_forward is False