def demo_settling_dynamics(symmetric=False,
                           n_hidden=50,
                           n_out=3,
                           input_influence=0.01,
                           learning_rate=0.0001,
                           cut_time=None,
                           minibatch_size=1,
                           decay=0.05,
                           scale=.4,
                           hidden_act='tanh',
                           output_act='lin',
                           draw_every=10,
                           n_steps=10000,
                           seed=124):
    """
    Step two networks side by side and plot their states as they evolve.

    The first network (``net_d``) is updated with no input.  The second
    (``net_l``) is built with the same settings plus ``input_influence`` and
    ``learning_rate``, and on every step is given the first network's freshly
    updated ``x`` state as ``inp``.

    :param symmetric, n_hidden, n_out, scale, decay: Forwarded to ``Network.from_init`` for both networks.
    :param hidden_act: Forwarded to ``Network.from_init`` as ``fh``.
    :param output_act: Forwarded to ``Network.from_init`` as ``fx``.
    :param input_influence, learning_rate: Forwarded to ``Network.from_init`` for the second network only.
    :param cut_time: If not None, the second network receives ``inp=None`` from step ``cut_time`` onwards.
    :param minibatch_size: Forwarded to ``init_state`` of both networks.
    :param draw_every: Forwarded to ``hold_dbplots``.
    :param n_steps: Number of update steps to run.
    :param seed: Passed to ``get_rng``; the resulting object is handed to both networks.
    """

    rng = get_rng(seed)

    # Settings common to both networks.  The very same ``rng`` object goes to
    # both constructors, so the construction order below must not change.
    common = dict(symmetric=symmetric,
                  n_hidden=n_hidden,
                  n_out=n_out,
                  scale=scale,
                  fh=hidden_act,
                  fx=output_act,
                  decay=decay,
                  rng=rng)

    net_d = Network.from_init(**common)
    state_d = net_d.init_state(minibatch_size=minibatch_size)

    net_l = Network.from_init(input_influence=input_influence,
                              learning_rate=learning_rate,
                              **common)
    state_l = net_l.init_state(minibatch_size=minibatch_size)

    sp = Speedometer()
    for t in range(n_steps):

        # Signed mean difference between the two networks' outputs (first sample only).
        error = (state_d.x[0] - state_l.x[0]).mean()

        with hold_dbplots(draw_every=draw_every):
            for data, plot_name in (
                (state_d.h[0], 'hd'),
                (state_d.x[0], 'xd'),
                (state_l.h[0], 'hl'),
                (state_l.x[0], 'xl'),
                (np.array([abs(net_l.w_hx).mean()]), 'wmag'),
                (error, 'error'),
            ):
                dbplot(data, plot_name)

        # The first network is advanced first, so the second one sees its *new* x state.
        state_d = net_d.update(state_d)
        still_driven = cut_time is None or t < cut_time
        state_l = net_l.update(state_l, inp=state_d.x if still_driven else None)

        if not t % 100:
            print(f'Rate: {sp(t+1)} iter/s')
def demo_settling_dynamics(symmetric=False,
                           n_hidden=50,
                           n_out=3,
                           minibatch_size=1,
                           decay=0.05,
                           scale=.4,
                           hidden_act='tanh',
                           output_act='lin',
                           draw_every=10,
                           n_steps=10000,
                           seed=124):
    """
    Repeatedly update a single network with no input and plot its ``h`` and ``x`` states.

    :param symmetric, n_hidden, n_out, scale, decay: Forwarded to ``Network.from_init``.
    :param hidden_act: Forwarded to ``Network.from_init`` as ``fh``.
    :param output_act: Forwarded to ``Network.from_init`` as ``fx``.
    :param minibatch_size: Forwarded to ``net.init_state``.
    :param draw_every: Forwarded to ``hold_dbplots``.
    :param n_steps: Number of update steps to run.
    :param seed: Forwarded to ``Network.from_init`` as ``rng``.
    """

    init_kwargs = dict(symmetric=symmetric,
                       n_hidden=n_hidden,
                       n_out=n_out,
                       scale=scale,
                       fh=hidden_act,
                       fx=output_act,
                       decay=decay,
                       rng=seed)
    net = Network.from_init(**init_kwargs)
    state = net.init_state(minibatch_size=minibatch_size)

    sp = Speedometer()
    for t in range(n_steps):
        state = net.update(state)

        if not t % 100:
            print(f'Rate: {sp(t+1)} iter/s')

        with hold_dbplots(draw_every=draw_every):
            # Only the first sample of the minibatch is plotted.
            dbplot(state.h[0],
                   'Hidden Units',
                   title='Hidden Units (b={})'.format(net.b_h))
            # NOTE(review): this title also shows net.b_h -- confirm whether an
            # output-unit bias was intended here instead.
            dbplot(state.x[0],
                   'Y Units',
                   title='Y Units (b={})'.format(net.b_h))
def experiment_mnist_eqprop(
    layer_constructor,
    n_epochs=10,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=.5,
    random_flip_beta=True,
    learning_rate=.05,
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    online_checkpoints_period=None,
    epoch_checkpoint_period=.25,
    skip_zero_epoch_test=False,
    n_test_samples=None,
    prop_direction: Union[str, Tuple] = 'neutral',
    bidirectional=True,
    renew_activations=True,
    do_fast_forward_pass=False,
    rebuild_coders=True,
    l2_loss=None,
    splitstream=True,
    seed=1234,
):
    """
    Replicate the results of Scellier & Bengio:
        Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
        https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
        https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points over epochs.  So our results should only really match for the first epoch.
    - We evaluate training score periodically, rather than online average (however you can see online score by setting online_checkpoints_period to something that is not None)

    This is a generator.  At every epoch checkpoint it evaluates test and train error,
    appends a row (iter, epoch, train_error, test_error) to a ``Duck`` and yields that
    ``Duck``.  It returns early once ``epoch > 2`` while the train error is above 50.

    :param layer_constructor: Forwarded to ``initialize_states`` and ``run_eqprop_training_update``.
    :param hidden_sizes: Hidden layer sizes; the full layout is ``[784] + hidden_sizes + [10]``.
    :param prop_direction: Either one string used for both directions, or a 2-tuple
        ``(forward, backward)``.  Only the forward entry is used for inference here; the
        original ``prop_direction`` value is what gets forwarded to the training update.
    :param n_test_samples: If not None, caps the number of samples used at each evaluation.
    :param online_checkpoints_period: If not None, period for printing the running online score.
    :param do_fast_forward_pass, rebuild_coders: Accepted but not used in this function body.
    :param seed: Passed to ``get_rng``.
    Remaining parameters are forwarded unchanged to ``run_eqprop_training_update`` /
    ``initialize_params`` / ``Checkpoints`` under the same names.
    """

    # Must remain the first statement: locals() is expected to hold only the arguments here.
    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in, n_out = 784, 10
    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        # Shrink everything so a test run finishes quickly.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in, *hidden_sizes, n_out]

    rng = get_rng(rng)

    y_train = y_train.astype(np.float32)

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    if online_checkpoints_period is not None:
        is_online_checkpoint = Checkpoints(online_checkpoints_period,
                                           skip_first=skip_zero_epoch_test)
    else:

        def is_online_checkpoint():
            return False

    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    results = Duck()

    initial_params = initialize_params(
        layer_sizes=layer_sizes,
        initial_weight_scale=initial_weight_scale,
        rng=rng)
    training_states = initialize_states(layer_constructor=layer_constructor,
                                        n_samples=minibatch_size,
                                        params=initial_params)

    if isinstance(prop_direction, str):
        fwd_prop_direction = backward_prop_direction = prop_direction
    else:
        # The backward entry is not used below; the unpacking still enforces a 2-element value.
        fwd_prop_direction, backward_prop_direction = prop_direction

    n_train = x_train.shape[0]
    minibatches = minibatch_index_info_generator(n_samples=n_train,
                                                 minibatch_size=minibatch_size,
                                                 n_epochs=n_epochs)
    for i, (ixs, info) in enumerate(minibatches):
        epoch = i * minibatch_size / n_train

        if is_epoch_checkpoint(epoch):
            # Evaluate on fresh states that carry the current training parameters.
            predictions = []
            for x_eval in (x_test, x_train):
                if n_test_samples is None:
                    n_eval = len(x_eval)
                else:
                    n_eval = min(len(x_eval), n_test_samples)
                predictions.append(
                    run_inference(
                        x_data=x_eval[:n_test_samples],
                        states=initialize_states(
                            layer_constructor=layer_constructor,
                            params=[s.params for s in training_states],
                            n_samples=n_eval),
                        n_steps=n_negative_steps,
                        prop_direction=fwd_prop_direction,
                    ))
            y_pred_test, y_pred_train = predictions

            test_error = percent_argmax_incorrect(y_pred_test,
                                                  y_test[:n_test_samples])
            train_error = percent_argmax_incorrect(y_pred_train,
                                                   y_train[:n_test_samples])
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}, Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    train_error=train_error,
                                    test_error=test_error)
            yield results
            if epoch > 2 and train_error > 50:
                # Train error still above 50 after two epochs: stop the run.
                return

        # The Original training loop, just taken out here:
        x_data_sample = x_train[ixs]
        y_data_sample = y_train[ixs]
        training_states = run_eqprop_training_update(
            x_data=x_data_sample,
            y_data=y_data_sample,
            layer_states=training_states,
            beta=beta,
            random_flip_beta=random_flip_beta,
            learning_rate=learning_rate,
            layer_constructor=layer_constructor,
            bidirectional=bidirectional,
            l2_loss=l2_loss,
            renew_activations=renew_activations,
            n_negative_steps=n_negative_steps,
            n_positive_steps=n_positive_steps,
            prop_direction=prop_direction,
            splitstream=splitstream,
            rng=rng)

        # Running average of the per-minibatch training score.
        minibatch_score = percent_argmax_correct(
            output_from_state(training_states), y_train[ixs])
        this_train_score = ra(minibatch_score)
        if is_online_checkpoint():
            print(
                f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}'
            )
# Example no. 4
# 0
def experiment_mnist_eqprop_torch(
    layer_constructor: Callable[[int, LayerParams], IDynamicLayer],
    n_epochs=10,
    hidden_sizes=(500, ),
    minibatch_size=10,  # update mini-batch size
    batch_size=500,  # total batch size
    beta=.5,
    random_flip_beta=True,
    learning_rate=.05,
    n_negative_steps=120,
    n_positive_steps=80,
    initial_weight_scale=1.,
    online_checkpoints_period=None,
    epoch_checkpoint_period=1.0,  #'100s', #{0: .25, 1: .5, 5: 1, 10: 2, 50: 4},
    skip_zero_epoch_test=False,
    n_test_samples=10000,
    prop_direction: Union[str, Tuple] = 'neutral',
    bidirectional=True,
    renew_activations=True,
    do_fast_forward_pass=False,
    rebuild_coders=True,
    l2_loss=None,
    splitstream=False,
    seed=1234,
    prediction_inp_size=17,  ## prediction input size
    delay=18,  ## delay size for the clamped phase
    pred=True,  ## if you want to use the prediction
    check_flg=False,
):
    """
    Torch variant of the MNIST equilibrium-propagation experiment, based on Scellier & Bengio:
        Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
        https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
        https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points over epochs.  So our results should only really match for the first epoch.
    - We evaluate training score periodically, rather than online average (however you can see online score by setting online_checkpoints_period to something that is not None)

    This is a generator.  At every epoch checkpoint it shuffles the training data,
    evaluates test/train/validation error, appends them to ``<directory>/log.txt``,
    saves the weights, biases and ``dy_squared`` as ``.npy`` files in ``directory``,
    and yields the accumulated results ``Duck``.  It returns early once ``epoch > 100``
    while the last train error is above 50.  After the training loop it evaluates once
    more and yields a final time.

    NOTE: ``directory`` is not a parameter; it must be defined in the enclosing module.

    :param layer_constructor: Forwarded to ``initialize_states`` and ``run_eqprop_training_update``.
    :param hidden_sizes: Hidden layer sizes; the full layout is ``[784] + hidden_sizes + [10]``.
    :param batch_size: Number of samples drawn per loop iteration, and the ``n_samples``
        the training states are initialised with.
    :param minibatch_size: Forwarded to ``run_eqprop_training_update`` alongside ``batch_size``.
    :param n_test_samples: If not None, caps the number of samples used at each evaluation.
    :param prop_direction: Either one string used for both directions, or a 2-tuple
        ``(forward, backward)``.  Only the forward entry is used for inference here.
    :param online_checkpoints_period: If not None, period for printing the running online score.
    :param check_flg: Forwarded as ``epoch_check``.  It is reset to False at every epoch
        checkpoint and after every update.
    :param do_fast_forward_pass, rebuild_coders: Accepted but not used in this function body.
    :param seed: Used for ``torch.manual_seed`` and ``get_rng``.
    Remaining parameters are forwarded unchanged to ``run_eqprop_training_update`` /
    ``initialize_params`` / ``Checkpoints`` under the same names.
    """
    torch.manual_seed(seed)
    device = 'cuda' if torch.cuda.is_available(
    ) and USE_CUDA_WHEN_AVAILABLE else 'cpu'
    if device == 'cuda':
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print(f'Using Device: {device}')

    # At this point locals() holds the arguments plus ``device``.
    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)

    x_train, y_train = torch.tensor(
        dataset.train.images, dtype=torch.float32
    ).to(device), torch.tensor(dataset.train.labels, dtype=torch.float32).to(
        device
    )  #(torch.tensor(a.astype(np.float32)).to(device) for a in dataset.mnist.train.images.xy)
    x_test, y_test = torch.tensor(
        dataset.test.images, dtype=torch.float32).to(device), torch.tensor(
            dataset.test.labels, dtype=torch.float32).to(
                device)  # Their 'validation set' is our 'test set'
    x_val, y_val = torch.tensor(
        dataset.validation.images,
        dtype=torch.float32).to(device), torch.tensor(
            dataset.validation.labels, dtype=torch.float32).to(
                device)  # Separate validation split from the dataset

    if is_test_mode():
        # Shrink everything so a test run finishes quickly.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:100], y_test[:100]
        x_val, y_val = x_val[:100], y_val[:100]
        n_epochs = 1
        n_negative_steps = 3
        n_positive_steps = 3

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    is_online_checkpoint = Checkpoints(
        online_checkpoints_period, skip_first=skip_zero_epoch_test
    ) if online_checkpoints_period is not None else lambda: False
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    training_states = initialize_states(
        layer_constructor=layer_constructor,
        #n_samples=minibatch_size,
        n_samples=batch_size,
        params=initialize_params(layer_sizes=layer_sizes,
                                 initial_weight_scale=initial_weight_scale,
                                 rng=rng))

    # dbplot(training_states[0].params.w_fore[:10, :10], str(rng.randint(265)))

    if isinstance(prop_direction, str):
        fwd_prop_direction, backward_prop_direction = prop_direction, prop_direction
    else:
        fwd_prop_direction, backward_prop_direction = prop_direction

    def do_test():
        # Reads ``epoch``, ``i`` and ``training_states`` from the enclosing scope,
        # so it may only be called once the training loop has run at least once.
        test_error, train_error, val_error = [
            percent_argmax_incorrect(
                run_inference(
                    x_data=x[:n_test_samples],
                    states=initialize_states(
                        layer_constructor=layer_constructor,
                        params=[s.params for s in training_states],
                        n_samples=n_samples),
                    n_steps=n_negative_steps,
                    prop_direction=fwd_prop_direction,
                ), y[:n_samples]).item()
            for x, y in [(x_test, y_test), (x_train, y_train), (x_val, y_val)]
            for n_samples in [
                min(len(x), n_test_samples
                    ) if n_test_samples is not None else len(x)
            ]
        ]  # Not an actual loop... just hack for assignment in comprehensions
        print(
            f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}, Validation Error: {val_error:.3g}, Mean Rate: {sp(i):.3g}iter/s'
        )

        return dict(iter=i,
                    epoch=epoch,
                    train_error=train_error,
                    test_error=test_error,
                    val_error=val_error), train_error, test_error, val_error

    results = Duck()
    # NOTE(review): expected_iterations is computed with minibatch_size, while the
    # loop below advances by batch_size per iteration -- confirm which is intended.
    pi = ProgressIndicator(expected_iterations=n_epochs *
                           dataset.train.num_examples / minibatch_size,
                           update_every='10s')

    dy_squared = []
    dy_squared.append(None)
    dy_squared.append(None)
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.size()[0],
                                           minibatch_size=batch_size,
                                           n_epochs=n_epochs)):
        epoch = i * batch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            # NOTE(review): check_flg is set to False here and again after every
            # update, so epoch_check is never True unless the caller passes it
            # and the first iteration is not a checkpoint -- confirm intent.
            check_flg = False
            x_train, y_train = shuffle_data(x_train, y_train)
            with pi.pause_measurement():
                results[next, :], train_err, test_err, val_err = do_test()

                ## prepare for saving the parameters
                ws, bs = zip(*((s.params.w_aft, s.params.b)
                               for s in training_states[1:]))

                # Make sure the output directory exists.  The previous code called
                # os.mkdir whenever log.txt was missing, which raised
                # FileExistsError if the directory was already there.
                os.makedirs(directory, exist_ok=True)

                # Append mode creates the file on first use; the context manager
                # closes the handle even if a write fails.
                # NOTE(review): the values written below are error percentages
                # although the labels say "accuracy"; labels kept unchanged so
                # existing logs stay consistent.
                with open(directory + '/log.txt', 'a') as f:
                    f.write("Epoch: " + str(epoch) + '\n')
                    f.write("accuracy for training: " + str(train_err) + '\n')
                    f.write("accuracy for testing: " + str(test_err) + '\n')
                    f.write("accuracy for validation: " + str(val_err) + '\n')

                np.save(directory + '/w_epoch_' + str(epoch) + '.npy', ws)
                np.save(directory + '/b_epoch_' + str(epoch) + '.npy', bs)
                np.save(directory + '/dy_squared_epoch_' + str(epoch) + '.npy',
                        dy_squared)

                yield results
                if epoch > 100 and results[-1, 'train_error'] > 50:
                    return

        # The Original training loop, just taken out here:
        ixs = ixs.astype(np.int32)  # this is for python version 3.7

        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        training_states, dy_squared = run_eqprop_training_update(
            x_data=x_data_sample,
            y_data=y_data_sample,
            layer_states=training_states,
            beta=beta,
            random_flip_beta=random_flip_beta,
            learning_rate=learning_rate,
            layer_constructor=layer_constructor,
            bidirectional=bidirectional,
            l2_loss=l2_loss,
            renew_activations=renew_activations,
            n_negative_steps=n_negative_steps,
            n_positive_steps=n_positive_steps,
            prop_direction=prop_direction,
            splitstream=splitstream,
            rng=rng,
            prediction_inp_size=prediction_inp_size,
            delay=delay,
            device=device,
            epoch_check=check_flg,
            epoch=epoch,
            pred=pred,
            batch_size=batch_size,
            minibatch_size=minibatch_size,
            dy_squared=dy_squared)
        check_flg = False

        # Running average of the per-batch percentage of incorrect predictions.
        this_train_score = ra(
            percent_argmax_incorrect(output_from_state(training_states),
                                     y_train[ixs]))
        if is_online_checkpoint():
            print(
                f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}'
            )

        pi.print_update(info=f'Epoch: {epoch}')

    # Final evaluation after the last training iteration.
    results[next, :], train_err, test_err, val_err = do_test()
    yield results
def demo_energy_based_initialize_eq_prop_fwd_energy(
    n_epochs=25,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=0.5,
    epsilon=0.5,
    learning_rate=(0.1, .05),
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    forward_deviation_cost=0.,
    zero_deviation_cost=1,
    epoch_checkpoint_period={
        0: .25,
        1: .5,
        5: 1,
        10: 2,
        50: 4
    },
    n_test_samples=10000,
    skip_zero_epoch_test=False,
    train_with_forward='contrast',
    forward_nonlinearity='rho(x)',
    local_loss=True,
    random_flip_beta=True,
    seed=1234,
):
    """
    Train an equilibrium-propagation network on MNIST alongside a feed-forward network
    whose output (when ``train_with_forward`` is truthy) is used as the starting state
    for the settling phases.

    This is a generator.  At every epoch checkpoint it measures, for the test and
    training sets, the error of the forward pass ("init") and of the settled state
    ("neg"), appends them to a ``Duck`` and yields that ``Duck``.  It returns early
    once ``epoch > 2`` while the train "neg" error is above 50.

    :param hidden_sizes: Hidden layer sizes; the full layout is ``[784] + hidden_sizes + [10]``.
    :param beta: Passed as ``y_pressure`` to the positive phase and as ``beta`` to the
        parameter update; its sign is chosen at random per minibatch when ``random_flip_beta``.
    :param epsilon: Forwarded to every equilibrating step.
    :param learning_rate: Forwarded as ``learning_rates`` to both parameter updates.
    :param forward_deviation_cost, zero_deviation_cost: Bound into the equilibrating-step
        and parameter-update functions.
    :param epoch_checkpoint_period: Forwarded to ``Checkpoints``.
    :param n_test_samples: Caps the number of samples used at each evaluation.
    :param train_with_forward: One of 'contrast', 'contrast+', 'energy', False.
        'contrast' updates the forward parameters with the negative-phase states as
        ``eq_states``, 'contrast+' with the positive-phase states.  'energy' passes the
        initial check but has no update branch, so it trips the final assertion on the
        first training iteration.
    :param forward_nonlinearity: Bound as ``nonlinearity`` into the forward pass and forward update.
    :param local_loss: Accepted but not used in this function body.
    :param seed: Passed to ``get_rng``.
    """

    # Must remain the first statement: locals() is expected to hold only the arguments here.
    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)

    rng = get_rng(seed)
    n_in, n_out = 784, 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        # Shrink everything so a test run finishes quickly.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in, *hidden_sizes, n_out]

    # Both parameter sets are drawn from the same rng, equilibrium network first.
    init_kwargs = dict(layer_sizes=layer_sizes,
                       initial_weight_scale=initial_weight_scale,
                       rng=rng)
    eq_params = initialize_params(**init_kwargs)
    forward_params = initialize_params(**init_kwargs)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    # Three separately compiled copies of the same equilibrating step, one per use.
    deviation_costs = dict(forward_deviation_cost=forward_deviation_cost,
                           zero_deviation_cost=zero_deviation_cost)
    f_negative_eq_step = equilibriating_step.partial(**deviation_costs).compile()
    f_inference_eq_step = equilibriating_step.partial(**deviation_costs).compile()
    f_positive_eq_step = equilibriating_step.partial(**deviation_costs).compile()
    f_parameter_update = update_eq_params.partial(**deviation_costs).compile()
    f_forward_pass = forward_pass.partial(
        nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        # Returns (last-layer initial state, last-layer state after n_steps of settling).
        if train_with_forward:
            forward_states_ = f_forward_pass(x=x, params=forward_params_)
        else:
            forward_states_ = initialize_states(
                n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        states_ = forward_states_
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_,
                                          states=states_,
                                          fwd_states=forward_states_,
                                          x=x,
                                          epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    n_train = x_train.shape[0]
    minibatches = minibatch_index_info_generator(n_samples=n_train,
                                                 minibatch_size=minibatch_size,
                                                 n_epochs=n_epochs)
    for i, (ixs, info) in enumerate(minibatches):
        epoch = i * minibatch_size / n_train

        if is_epoch_checkpoint(epoch):
            # One (init_error, neg_error) pair per evaluated data set.
            error_pairs = []
            for x_eval, y_eval in ((x_test, y_test), (x_train, y_train)):
                init_prediction, neg_prediction = do_inference(
                    forward_params_=forward_params,
                    eq_params_=eq_params,
                    x=x_eval[:n_test_samples],
                    n_steps=n_negative_steps)
                error_pairs.append([
                    percent_argmax_incorrect(init_prediction,
                                             y_eval[:n_test_samples]),
                    percent_argmax_incorrect(neg_prediction,
                                             y_eval[:n_test_samples]),
                ])
            (test_init_error,
             test_neg_error), (train_init_error,
                               train_neg_error) = error_pairs
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, , Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    test_init_error=test_init_error,
                                    test_neg_error=test_neg_error,
                                    train_init_error=train_init_error,
                                    train_neg_error=train_neg_error)
            yield results
            if epoch > 2 and train_neg_error > 50:
                # Train error still above 50 after two epochs: stop the run.
                return

        # The Original training loop, just taken out here:
        x_data_sample = x_train[ixs]
        y_data_sample = y_train[ixs]

        if train_with_forward:
            forward_states = f_forward_pass(x=x_data_sample,
                                            params=forward_params)
        else:
            forward_states = initialize_states(
                n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])
        states = forward_states

        # Negative phase: settle without any target.
        for _ in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        epsilon=epsilon,
                                        fwd_states=forward_states)
        negative_states = states

        if random_flip_beta:
            this_beta = rng.choice([-beta, beta])
        else:
            this_beta = beta

        # Positive phase: continue from the negative-phase state with the target applied.
        for _ in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        y=y_data_sample,
                                        y_pressure=this_beta,
                                        epsilon=epsilon,
                                        fwd_states=forward_states)
        positive_states = states

        eq_params = f_parameter_update(x=x_data_sample,
                                       params=eq_params,
                                       negative_states=negative_states,
                                       positive_states=positive_states,
                                       fwd_states=forward_states,
                                       learning_rates=learning_rate,
                                       beta=this_beta)

        if train_with_forward in ('contrast', 'contrast+'):
            if train_with_forward == 'contrast':
                target_states = negative_states
            else:
                target_states = positive_states
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=target_states,
                learning_rates=learning_rate)
        else:
            # 'energy' has no update implemented and therefore fails here.
            assert train_with_forward is False
# Example no. 6
# 0
def demo_energy_based_initialize_eq_prop_alignment(
    n_epochs=25,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=0.5,
    epsilon=0.5,
    learning_rate=(0.1, .05),
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    epoch_checkpoint_period={
        0: .25,
        1: .5,
        5: 1,
        10: 2,
        50: 4
    },
    n_test_samples=10000,
    skip_zero_epoch_test=False,
    train_with_forward='contrast',
    forward_nonlinearity='rho(x)',
    local_loss=True,
    random_flip_beta=True,
    seed=1234,
):
    """
    Train an equilibrating network on MNIST together with a feedforward
    "initializer" network, and track error rates and gradient alignment.

    This is a generator.  At every epoch checkpoint it evaluates the error of
    both the forward-pass prediction ("init") and the prediction after
    ``n_negative_steps`` equilibrating steps ("neg") on test and training data,
    computes ``compute_gradient_alignment`` on the training set, records
    everything in a ``Duck`` and yields that (same, growing) ``Duck``.

    :param n_epochs: Number of epochs handed to the minibatch index generator
        (forced to 1 in test mode).
    :param hidden_sizes: Sizes of the layers between the 784-unit input and
        the 10-unit output.
    :param minibatch_size: Minibatch size handed to the index generator.
    :param beta: Passed as ``y_pressure`` to the positive-phase steps and as
        ``beta`` to the parameter update.
    :param epsilon: Passed to every equilibrating step.
    :param learning_rate: Passed as ``learning_rates`` to both the
        equilibrating-network update and the forward-network update.
    :param n_negative_steps: Number of equilibrating steps run without the
        target (also used for inference at checkpoints).
    :param n_positive_steps: Number of equilibrating steps run with the target.
    :param initial_weight_scale: Passed to ``initialize_params`` for both
        parameter sets.
    :param epoch_checkpoint_period: Passed to ``Checkpoints`` to decide at
        which (fractional) epochs to evaluate.  NOTE(review): this is a shared
        mutable default; this function never mutates it, but confirm that
        ``Checkpoints`` does not either.
    :param n_test_samples: Evaluate on at most this many leading samples of
        the test set and of the training set (None means all of them).
    :param skip_zero_epoch_test: Passed as ``skip_first`` to ``Checkpoints``.
    :param train_with_forward: One of 'contrast', 'contrast+', 'energy' or
        False.  Selects how the forward network is trained: against the
        negative states, against the positive states, with the energy-based
        update, or not at all (states then start from ``initialize_states``).
    :param forward_nonlinearity: Nonlinearity spec given to the forward pass,
        the forward-parameter updates and the alignment computation.
    :param local_loss: Passed as ``disconnect_grads`` to the forward-parameter
        updates.
    :param random_flip_beta: If True, the sign of ``beta`` is drawn at random
        for every minibatch.
    :param seed: Seed for the random number generator.
    :return: A generator yielding the ``Duck`` of results after each
        checkpoint.  It stops early once ``epoch > 2`` and the training "neg"
        error is still above 50%.
    """
    # Must stay the first statement: at this point locals() holds exactly the
    # call arguments, which is what we want to log.
    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        # Keep test runs quick: tiny data, a single epoch.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    # Both parameter sets draw from the same rng, equilibrating network first.
    eq_params = initialize_params(layer_sizes=layer_sizes,
                                  initial_weight_scale=initial_weight_scale,
                                  rng=rng)
    forward_params = initialize_params(
        layer_sizes=layer_sizes,
        initial_weight_scale=initial_weight_scale,
        rng=rng)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.compile()
    f_inference_eq_step = equilibriating_step.compile()
    f_positive_eq_step = equilibriating_step.compile()
    f_parameter_update = update_eq_params.compile()
    f_forward_pass = forward_pass.partial(
        nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_update = update_forward_params_with_energy.partial(
        disconnect_grads=local_loss,
        nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        disconnect_grads=local_loss,
        nonlinearity=forward_nonlinearity).compile()
    # Not used by the live code below; handy for ad-hoc energy plots while
    # debugging (f_energy(params=eq_params, states=..., x=...)).
    f_energy = energy.compile()
    f_grad_align = compute_gradient_alignment.partial(
        nonlinearity=forward_nonlinearity).compile()

    def _initial_states(forward_params_, x):
        """Starting states for x: the forward pass if a forward net is in use, else fresh states."""
        if train_with_forward:
            return f_forward_pass(x=x, params=forward_params_)
        return initialize_states(n_samples=x.shape[0],
                                 noninput_layer_sizes=layer_sizes[1:])

    def do_inference(forward_params_, eq_params_, x, n_steps):
        """Return (output layer of the initial states, output layer after n_steps equilibrating steps)."""
        states_ = forward_states_ = _initial_states(forward_params_, x)
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_,
                                          states=states_,
                                          x=x,
                                          epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    def _error_rates(forward_params_, eq_params_, x, y):
        """Return [init error, neg error] on the first n_test_samples of (x, y)."""
        predictions = do_inference(forward_params_=forward_params_,
                                   eq_params_=eq_params_,
                                   x=x[:n_test_samples],
                                   n_steps=n_negative_steps)
        return [
            percent_argmax_incorrect(prediction, y[:n_test_samples])
            for prediction in predictions
        ]

    results = Duck()
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0],
                                           minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            # ===== Evaluate (test set first, then training set)
            test_init_error, test_neg_error = _error_rates(
                forward_params, eq_params, x_test, y_test)
            train_init_error, train_neg_error = _error_rates(
                forward_params, eq_params, x_train, y_train)

            print(f'Epoch: {epoch:.3g}, Iter: {i}, '
                  f'Test Init Error: {test_init_error:.3g}%, '
                  f'Test Neg Error: {test_neg_error:.3g}%, '
                  f'Train Init Error: {train_init_error:.3g}%, '
                  f'Train Neg Error: {train_neg_error:.3g}%, '
                  f'Mean Rate: {sp(i):.3g}iter/s')

            # ===== Compute alignment (on the whole training set)
            alignment_neg_states = _initial_states(forward_params, x_train)
            for _ in range(n_negative_steps):
                alignment_neg_states = f_inference_eq_step(
                    params=eq_params,
                    states=alignment_neg_states,
                    x=x_train,
                    epsilon=epsilon)

            grad_alignments = f_grad_align(forward_params=forward_params,
                                           eq_states=alignment_neg_states,
                                           x=x_train)
            dbplot(grad_alignments,
                   'alignments',
                   plot_type=DBPlotTypes.LINE_HISTORY)

            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    test_init_error=test_init_error,
                                    test_neg_error=test_neg_error,
                                    train_init_error=train_init_error,
                                    train_neg_error=train_neg_error,
                                    alignments=grad_alignments)
            yield results
            # Give up on runs that are clearly not learning.
            if epoch > 2 and train_neg_error > 50:
                return

        # ===== One training step on the current minibatch
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        # Negative phase: settle without the target.  Fresh states are sized by
        # the actual batch, so a short minibatch (should the index generator
        # ever produce one) still gets matching shapes.
        states = _initial_states(forward_params, x_data_sample)
        for _ in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        epsilon=epsilon)
        negative_states = states

        # Positive phase: continue from the negative states with the target
        # applied, using a (possibly sign-flipped) beta.
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for _ in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        y=y_data_sample,
                                        y_pressure=this_beta,
                                        epsilon=epsilon)
        positive_states = states

        eq_params = f_parameter_update(x=x_data_sample,
                                       params=eq_params,
                                       negative_states=negative_states,
                                       positive_states=positive_states,
                                       learning_rates=learning_rate,
                                       beta=this_beta)

        # Train the forward network (after eq_params has been updated, which
        # matters for the 'energy' variant since it reads eq_params).
        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=negative_states,
                learning_rates=learning_rate)
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=positive_states,
                learning_rates=learning_rate)
        elif train_with_forward == 'energy':
            forward_params = f_forward_parameter_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_params=eq_params,
                learning_rates=learning_rate)
        else:
            assert train_with_forward is False