Example #1
def demo_temporal_mnist(n_samples = None, smoothing_steps = 200):
    _, _, original_data, original_labels = get_mnist_dataset(n_training_samples=n_samples, n_test_samples=n_samples).xyxy
    _, _, temporal_data, temporal_labels = get_temporal_mnist_dataset(n_training_samples=n_samples, n_test_samples=n_samples, smoothing_steps=smoothing_steps).xyxy
    for ox, oy, tx, ty in zip(original_data, original_labels, temporal_data, temporal_labels):
        with hold_dbplots():
            dbplot(ox, 'sample', title = str(oy))
            dbplot(tx, 'smooth', title = str(ty))
Example #2
def demo_mnist_logreg(minibatch_size=20, learning_rate=0.01, max_training_samples = None, n_epochs=10, test_epoch_period=0.2):
    """
    Train a Logistic Regressor on the MNIST dataset, report training/test scores throughout training, and return the
    final scores.
    """

    x_train, y_train, x_test, y_test = get_mnist_dataset(flat=True, n_training_samples=max_training_samples).xyxy

    predictor = OnlineLogisticRegressor(n_in=784, n_out=10, learning_rate=learning_rate)

    # Train and periodically record scores.
    epoch_scores = []
    for ix, iteration_info in minibatch_index_info_generator(n_samples = len(x_train), minibatch_size=minibatch_size, n_epochs=n_epochs, test_epochs=('every', test_epoch_period)):
        if iteration_info.test_now:
            training_score = 100*(np.argmax(predictor.predict(x_train), axis=1)==y_train).mean()
            test_score = 100*(np.argmax(predictor.predict(x_test), axis=1)==y_test).mean()
            print('Epoch {epoch}: Test Score: {test}%, Training Score: {train}%'.format(epoch=iteration_info.epoch, test=test_score, train=training_score))
            epoch_scores.append((iteration_info.epoch, training_score, test_score))
        predictor.train(x_train[ix], y_train[ix])

    # Plot
    plt.figure()
    epochs, training_scores, test_scores = zip(*epoch_scores)
    plt.plot(epochs, np.array([training_scores, test_scores]).T)
    plt.xlabel('Epoch')
    plt.ylabel('% Correct')
    plt.legend(['Training Score', 'Test Score'])
    plt.title("Learning Curve")
    plt.ylim(80, 100)
    plt.ion()  # Don't hang on plot
    plt.show()

    return {'train': epoch_scores[-1][1], 'test': epoch_scores[-1][2]}  # Return final scores
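The demo above assumes an OnlineLogisticRegressor class with a predict/train interface, which is not shown on this page. As a rough illustration only (not the project's actual implementation), a minimal NumPy softmax-regression predictor with that interface could look like:

import numpy as np

class OnlineLogisticRegressor:
    """Illustrative stand-in for the predictor used above: softmax regression
    trained by minibatch gradient descent.  Not the original implementation."""

    def __init__(self, n_in, n_out, learning_rate=0.01):
        self.w = np.zeros((n_in, n_out))
        self.b = np.zeros(n_out)
        self.learning_rate = learning_rate

    def predict(self, x):
        # Row-wise softmax over the linear scores.
        z = x @ self.w + self.b
        z = z - z.max(axis=1, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=1, keepdims=True)

    def train(self, x, y):
        # One gradient step on the mean negative log-likelihood of the minibatch.
        p = self.predict(x)
        onehot = np.eye(self.w.shape[1])[y]
        grad_z = (p - onehot) / len(x)
        self.w -= self.learning_rate * (x.T @ grad_z)
        self.b -= self.learning_rate * grad_z.sum(axis=0)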
Example #4
def get_temporal_mnist_dataset(smoothing_steps=1000, **mnist_kwargs):

    tr_x, tr_y, ts_x, ts_y = get_mnist_dataset(**mnist_kwargs).xyxy
    tr_ixs = temporalize(tr_x, smoothing_steps=smoothing_steps)
    ts_ixs = temporalize(ts_x, smoothing_steps=smoothing_steps)
    return DataSet.from_xyxy(tr_x[tr_ixs], tr_y[tr_ixs], ts_x[ts_ixs],
                             ts_y[ts_ixs])
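The temporalize function is not shown on this page; from its use here it returns an index array that reorders the samples so that consecutive samples are similar, producing a slowly-varying "temporal" stream. The following is a purely hypothetical sketch of such a reordering (a greedy swap-based search run for smoothing_steps rounds), not the project's actual algorithm:

import numpy as np

def temporalize_sketch(x, smoothing_steps=1000, seed=1234):
    """Hypothetical stand-in for temporalize: return an ordering of the rows of x
    in which neighbouring samples are similar.  Not the original implementation."""
    rng = np.random.RandomState(seed)
    flat = x.reshape(len(x), -1)
    order = rng.permutation(len(x))

    def local_cost(pos):
        # Distance from the sample at `pos` to its current neighbours in the ordering.
        c = 0.0
        if pos > 0:
            c += np.linalg.norm(flat[order[pos]] - flat[order[pos - 1]])
        if pos < len(order) - 1:
            c += np.linalg.norm(flat[order[pos]] - flat[order[pos + 1]])
        return c

    for _ in range(smoothing_steps):
        i, j = rng.randint(len(order), size=2)
        cost_before = local_cost(i) + local_cost(j)
        order[i], order[j] = order[j], order[i]
        if local_cost(i) + local_cost(j) > cost_before:
            order[i], order[j] = order[j], order[i]  # Revert swaps that don't help
    return order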
Example #5
def demo_temporal_mnist(n_samples=None, smoothing_steps=200):
    _, _, original_data, original_labels = get_mnist_dataset(
        n_training_samples=n_samples, n_test_samples=n_samples).xyxy
    _, _, temporal_data, temporal_labels = get_temporal_mnist_dataset(
        n_training_samples=n_samples,
        n_test_samples=n_samples,
        smoothing_steps=smoothing_steps).xyxy
    for ox, oy, tx, ty in zip(original_data, original_labels, temporal_data,
                              temporal_labels):
        with hold_dbplots():
            dbplot(ox, 'sample', title=str(oy))
            dbplot(tx, 'smooth', title=str(ty))
Example #6
def train_conventional_mlp_on_mnist(hidden_sizes,
                                    n_epochs=50,
                                    w_init='xavier-both',
                                    minibatch_size=20,
                                    rng=1234,
                                    optimizer='sgd',
                                    hidden_activations='relu',
                                    output_activation='softmax',
                                    learning_rate=0.01,
                                    cost_function='nll',
                                    use_bias=True,
                                    l1_loss=0,
                                    l2_loss=0,
                                    test_on='training+test'):

    dataset = get_mnist_dataset(flat=True)

    if output_activation != 'softmax':
        dataset = dataset.to_onehot()

    all_layer_sizes = [dataset.input_size] + hidden_sizes + [dataset.n_categories]
    weights = initialize_network_params(layer_sizes=all_layer_sizes,
                                        mag=w_init,
                                        base_dist='normal',
                                        include_biases=False,
                                        rng=rng)
    net = MultiLayerPerceptron(weights=weights,
                               hidden_activation=hidden_activations,
                               output_activation=output_activation,
                               use_bias=use_bias)
    predictor = GradientBasedPredictor(
        function=net,
        cost_function=get_named_cost_function(cost_function),
        optimizer=get_named_optimizer(optimizer, learning_rate=learning_rate),
        regularization_cost=lambda params: sum(
            l1_loss * abs(p_).sum() + l2_loss * (p_ ** 2).sum() if p_.ndim == 2 else 0
            for p_ in params)).compile()
    assess_online_predictor(predictor=predictor,
                            dataset=dataset,
                            evaluation_function='percent_argmax_correct',
                            test_epochs=range(0, n_epochs, 1),
                            test_on=test_on,
                            minibatch_size=minibatch_size)
    ws = [p.get_value() for p in net.parameters]
    return ws
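For reference, demo_optimize_mnist_net (Example #13 below) consumes the flat parameter list returned by this function by splitting it into alternating weight and bias arrays:

# Usage as in Example #13 below: train the MLP, then split the returned
# parameter list into weights and biases.
params = train_conventional_mlp_on_mnist(hidden_sizes=[200, 200],
                                         hidden_activations='relu',
                                         output_activation='softmax',
                                         rng=1234)
weights, biases = params[::2], params[1::2]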
Example #7
def experiment_mnist_eqprop(
    layer_constructor,
    n_epochs=10,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=.5,
    random_flip_beta=True,
    learning_rate=.05,
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    online_checkpoints_period=None,
    epoch_checkpoint_period=.25,
    skip_zero_epoch_test=False,
    n_test_samples=None,
    prop_direction: Union[str, Tuple] = 'neutral',
    bidirectional=True,
    renew_activations=True,
    do_fast_forward_pass=False,
    rebuild_coders=True,
    l2_loss=None,
    splitstream=True,
    seed=1234,
):
    """
    Replicate the results of Scellier & Bengio:
        Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
        https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
        https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points over epochs.  So our results should only really match for the first epoch.
    - We evaluate the training score periodically, rather than as an online average (though you can still see the online score by setting online_checkpoints_period to a non-None value)
    """

    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in = 784
    n_out = 10
    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    rng = get_rng(rng)

    y_train = y_train.astype(np.float32)

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    is_online_checkpoint = Checkpoints(
        online_checkpoints_period, skip_first=skip_zero_epoch_test
    ) if online_checkpoints_period is not None else lambda: False
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    results = Duck()

    training_states = initialize_states(
        layer_constructor=layer_constructor,
        n_samples=minibatch_size,
        params=initialize_params(layer_sizes=layer_sizes,
                                 initial_weight_scale=initial_weight_scale,
                                 rng=rng))

    if isinstance(prop_direction, str):
        fwd_prop_direction, backward_prop_direction = prop_direction, prop_direction
    else:
        fwd_prop_direction, backward_prop_direction = prop_direction

    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0],
                                           minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(
                x_test)
            y_pred_test, y_pred_train = [
                run_inference(
                    x_data=x[:n_test_samples],
                    states=initialize_states(
                        layer_constructor=layer_constructor,
                        params=[s.params for s in training_states],
                        n_samples=min(len(x), n_test_samples)
                        if n_test_samples is not None else len(x)),
                    n_steps=n_negative_steps,
                    prop_direction=fwd_prop_direction,
                ) for x in (x_test, x_train)
            ]
            # y_pred_train = run_inference(x_data=x_train[:n_test_samples], states=initialize_states(params=[s.params for s in training_states], n_samples=min(len(x_train), n_test_samples) if n_test_samples is not None else len(x_train)))
            test_error = percent_argmax_incorrect(y_pred_test,
                                                  y_test[:n_test_samples])
            train_error = percent_argmax_incorrect(y_pred_train,
                                                   y_train[:n_test_samples])
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    train_error=train_error,
                                    test_error=test_error)
            yield results
            if epoch > 2 and train_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]
        training_states = run_eqprop_training_update(
            x_data=x_data_sample,
            y_data=y_data_sample,
            layer_states=training_states,
            beta=beta,
            random_flip_beta=random_flip_beta,
            learning_rate=learning_rate,
            layer_constructor=layer_constructor,
            bidirectional=bidirectional,
            l2_loss=l2_loss,
            renew_activations=renew_activations,
            n_negative_steps=n_negative_steps,
            n_positive_steps=n_positive_steps,
            prop_direction=prop_direction,
            splitstream=splitstream,
            rng=rng)
        this_train_score = ra(
            percent_argmax_correct(output_from_state(training_states),
                                   y_train[ixs]))
        if is_online_checkpoint():
            print(
                f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}'
            )
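Note that experiment_mnist_eqprop is a generator: it yields the accumulated results table at every epoch checkpoint. A hedged usage sketch follows (make_layer stands in for a project-specific layer_constructor, which is not defined on this page):

# Hypothetical driver for the generator above.
for results in experiment_mnist_eqprop(layer_constructor=make_layer,
                                       n_epochs=1,
                                       hidden_sizes=(500,),
                                       n_test_samples=1000):
    pass  # Each yield adds one checkpoint row with iter, epoch, train_error, test_error.
print(results)  # The full table of checkpoint scores.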
Example #8
def demo_herding_network(kp=.1,
                         kd=1.,
                         kp_back=None,
                         kd_back=None,
                         hidden_sizes=[
                             200,
                         ],
                         n_epochs=50,
                         onehot=False,
                         parallel=False,
                         learning_rate=0.01,
                         dataset='mnist',
                         hidden_activation='relu',
                         adaptive=True,
                         adaptation_rate=0.001,
                         output_activation='softmax',
                         loss='nll',
                         fwd_quantizer='herd',
                         back_quantizer='same',
                         minibatch_size=1,
                         swap_mlp=False,
                         plot=False,
                         test_period=.5,
                         grad_calc='true',
                         rng=1234):

    dataset = get_mnist_dataset(
        flat=True, join_train_and_val=True
    ) if dataset == 'mnist' else get_temporal_mnist_dataset(
        flat=True, join_train_and_val=True)
    if onehot:
        dataset = dataset.to_onehot()
    ws = initialize_network_params(layer_sizes=[28 * 28] + hidden_sizes + [10],
                                   mag='xavier-both',
                                   include_biases=False,
                                   rng=rng)

    if is_test_mode():
        dataset = dataset.shorten(500)
        n_epochs = 0.1
        test_period = 0.03

    if kp_back is None:
        kp_back = kp
    if kd_back is None:
        kd_back = kd
    if back_quantizer == 'same':
        back_quantizer = fwd_quantizer

    if adaptive:
        encdec = lambda: PDAdaptiveEncoderDecoder(kp=kp,
                                                  kd=kd,
                                                  adaptation_rate=adaptation_rate,
                                                  quantization=fwd_quantizer)
        encdec_back = lambda: PDAdaptiveEncoderDecoder(
            kp=kp_back,
            kd=kd_back,
            adaptation_rate=adaptation_rate,
            quantization=back_quantizer)
    else:
        encdec = PDEncoderDecoder(kp=kp, kd=kd, quantization=fwd_quantizer)
        encdec_back = PDEncoderDecoder(kp=kp_back,
                                       kd=kd_back,
                                       quantization=back_quantizer)

    if swap_mlp:
        if not parallel:
            assert minibatch_size == 1, "Unfair comparison otherwise, sorry buddy, can't let you do that."
        net = GradientBasedPredictor(
            function=MultiLayerPerceptron.from_weights(
                weights=ws,
                hidden_activations=hidden_activation,
                output_activation=output_activation,
            ),
            cost_function=loss,
            optimizer=GradientDescent(learning_rate),
        )
        prediction_funcs = net.predict.compile()
    else:
        net = PDHerdingNetwork(
            ws=ws,
            encdec=encdec,
            encdec_back=encdec_back,
            hidden_activation=hidden_activation,
            output_activation=output_activation,
            optimizer=GradientDescent(learning_rate),
            minibatch_size=minibatch_size if parallel else 1,
            grad_calc=grad_calc,
            loss=loss)
        noise_free_forward_pass = MultiLayerPerceptron.from_weights(
            weights=[layer.w for layer in net.layers],
            biases=[layer.b for layer in net.layers],
            hidden_activations=hidden_activation,
            output_activation=output_activation).compile()
        prediction_funcs = [('noise_free', noise_free_forward_pass),
                            ('herded', net.predict.compile())]

    op_count_info = []

    def test_callback(info, score):
        if plot:
            dbplot(net.layers[0].w.get_value().T.reshape(-1, 28, 28),
                   'w0',
                   cornertext='Epoch {}'.format(info.epoch))
        if swap_mlp:
            all_layer_sizes = [dataset.input_size] + hidden_sizes + [dataset.target_size]
            fwd_ops = [
                info.sample * d1 * d2
                for d1, d2 in zip(all_layer_sizes[:-1], all_layer_sizes[1:])
            ]
            back_ops = [
                info.sample * d1 * d2
                for d1, d2 in zip(all_layer_sizes[:-1], all_layer_sizes[1:])
            ]
            update_ops = [
                info.sample * d1 * d2
                for d1, d2 in zip(all_layer_sizes[:-1], all_layer_sizes[1:])
            ]
        else:
            fwd_ops = [
                layer_.fwd_op_count.get_value() for layer_ in net.layers
            ]
            back_ops = [
                layer_.back_op_count.get_value() for layer_ in net.layers
            ]
            update_ops = [
                layer_.update_op_count.get_value() for layer_ in net.layers
            ]
        if info.epoch != 0:
            with IndentPrint('Mean Ops by epoch {}'.format(info.epoch)):
                print('Fwd: {}'.format([
                    si_format(ops / info.epoch,
                              format_str='{value} {prefix}Ops')
                    for ops in fwd_ops
                ]))
                print('Back: {}'.format([
                    si_format(ops / info.epoch,
                              format_str='{value} {prefix}Ops')
                    for ops in back_ops
                ]))
                print('Update: {}'.format([
                    si_format(ops / info.epoch,
                              format_str='{value} {prefix}Ops')
                    for ops in update_ops
                ]))
        if info.epoch > max(
                0.5, 2 * test_period) and not swap_mlp and score.get_score(
                    'train', 'noise_free') < 20:
            raise Exception("This horse ain't goin' nowhere.")

        op_count_info.append((info, (fwd_ops, back_ops, update_ops)))

    info_score_pairs = train_and_test_online_predictor(
        dataset=dataset,
        train_fcn=net.train.compile(),
        predict_fcn=prediction_funcs,
        minibatch_size=minibatch_size,
        n_epochs=n_epochs,
        test_epochs=('every', test_period),
        score_measure='percent_argmax_correct',
        test_on='training+test',
        test_callback=test_callback)
    return info_score_pairs, op_count_info
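A minimal invocation of the demo above; it returns the (info, score) records from training along with the operation-count log collected by test_callback:

# Sketch: run a short herding-network training without live plots.
info_score_pairs, op_count_info = demo_herding_network(n_epochs=1,
                                                       plot=False,
                                                       test_period=0.5)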
Example #9
def get_mnist_results_with_parameters(weights,
                                      biases,
                                      scales=None,
                                      hidden_activations='relu',
                                      output_activation='softmax',
                                      n_samples=None,
                                      smoothing_steps=1000):
    """
    Return a data structure showing the error and computational cost of the original, rounding, and sigma-delta
    implementations of a network with the given parameters.

    :param weights:
    :param biases:
    :param scales:
    :param hidden_activations:
    :param output_activation:
    :param n_samples:
    :param smoothing_steps:
    :return: results: An OrderedDict
        whose keys are 3-tuples (dataset_name, subset, net_version), where:
            dataset_name: is 'mnist' or 'temp_mnist'
            subset: is 'train' or 'test'
            net_version: is 'td', 'round', or 'truth'
        and whose values are OrderedDicts with keys:
            'MFlops', 'l1_error', 'class_error'  ... for discrete nets, and
            'Dense MFlops', 'Sparse MFlops', 'class_error'  ... for "true" nets.
    """
    mnist = get_mnist_dataset(flat=True,
                              n_training_samples=n_samples,
                              n_test_samples=n_samples)
    temp_mnist = get_temporal_mnist_dataset(flat=True,
                                            smoothing_steps=smoothing_steps,
                                            n_training_samples=n_samples,
                                            n_test_samples=n_samples)
    results = OrderedDict()
    p = ProgressIndicator(2 * 3 * 2)
    for dataset_name, (tr_x, tr_y, ts_x, ts_y) in [('mnist', mnist.xyxy),
                                                   ('temp_mnist',
                                                    temp_mnist.xyxy)]:
        for subset, x, y in [('train', tr_x, tr_y), ('test', ts_x, ts_y)]:
            traditional_net_output, dense_flops, sparse_flops = forward_pass_and_cost(
                input_data=x,
                weights=weights,
                biases=biases,
                hidden_activations=hidden_activations,
                output_activations=output_activation)
            assert round(dense_flops) == dense_flops and round(
                sparse_flops) == sparse_flops, 'Flop counts must be int!'

            class_error = percent_argmax_incorrect(traditional_net_output, y)
            results[dataset_name, subset, 'truth'] = OrderedDict([
                ('Dense MFlops', dense_flops / (1e6 * len(x))),
                ('Sparse MFlops', sparse_flops / (1e6 * len(x))),
                ('class_error', class_error)
            ])
            for net_version in 'td', 'round':
                (comp_cost_adds, comp_cost_multiplyadds
                 ), output = tdnet_forward_pass_cost_and_output(
                     inputs=x,
                     weights=weights,
                     biases=biases,
                     scales=scales,
                     version=net_version,
                     hidden_activations=hidden_activations,
                     output_activations=output_activation,
                     quantization_method='herd',
                     computation_calc=('adds', 'multiplyadds'))
                l1_error = np.abs(output - traditional_net_output).sum(
                    axis=1).mean(axis=0)
                class_error = percent_argmax_incorrect(output, y)
                results[dataset_name, subset, net_version] = OrderedDict([
                    ('MFlops', comp_cost_adds / (1e6 * len(x))),
                    ('MFlops-multadd',
                     comp_cost_multiplyadds / (1e6 * len(x))),
                    ('l1_error', l1_error), ('class_error', class_error)
                ])
                p.print_update()
    return results
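The returned OrderedDict is keyed by (dataset_name, subset, net_version) tuples as described in the docstring. A small sketch of how it might be consumed, assuming weights and biases come from a previously trained network:

# Sketch: print the cost and error entries for every dataset/subset/version combination.
results = get_mnist_results_with_parameters(weights=weights, biases=biases,
                                            n_samples=1000, smoothing_steps=200)
for (dataset_name, subset, net_version), scores in results.items():
    print('{}/{}/{}: {}'.format(dataset_name, subset, net_version,
                                ', '.join('{}={:.3g}'.format(k, v) for k, v in scores.items())))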
Example #10
def demo_energy_based_initialize_eq_prop_fwd_energy(
    n_epochs=25,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=0.5,
    epsilon=0.5,
    learning_rate=(0.1, .05),
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    forward_deviation_cost=0.,
    zero_deviation_cost=1,
    epoch_checkpoint_period={
        0: .25,
        1: .5,
        5: 1,
        10: 2,
        50: 4
    },
    n_test_samples=10000,
    skip_zero_epoch_test=False,
    train_with_forward='contrast',
    forward_nonlinearity='rho(x)',
    local_loss=True,
    random_flip_beta=True,
    seed=1234,
):

    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes,
                                  initial_weight_scale=initial_weight_scale,
                                  rng=rng)
    forward_params = initialize_params(
        layer_sizes=layer_sizes,
        initial_weight_scale=initial_weight_scale,
        rng=rng)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_inference_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_positive_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_parameter_update = update_eq_params.partial(
        forward_deviation_cost=forward_deviation_cost,
        zero_deviation_cost=zero_deviation_cost).compile()
    f_forward_pass = forward_pass.partial(
        nonlinearity=forward_nonlinearity).compile()
    # f_forward_parameter_update = update_forward_params_with_energy.compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        states_ = forward_states_ = f_forward_pass(
            x=x, params=forward_params_
        ) if train_with_forward else initialize_states(
            n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_,
                                          states=states_,
                                          fwd_states=forward_states_,
                                          x=x,
                                          epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    # last_time, last_epoch = time(), -1
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0],
                                           minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        # print(f'Training Rate: {(time()-last_time)/(epoch-last_epoch):3g}s/ep')
        # last_time, last_epoch = time(), epoch

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(
                x_test)
            (test_init_error,
             test_neg_error), (train_init_error, train_neg_error) = [[
                 percent_argmax_incorrect(prediction, y[:n_test_samples])
                 for prediction in do_inference(forward_params_=forward_params,
                                                eq_params_=eq_params,
                                                x=x[:n_test_samples],
                                                n_steps=n_negative_steps)
             ] for x, y in [(x_test, y_test), (x_train, y_train)]]
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    test_init_error=test_init_error,
                                    test_neg_error=test_neg_error,
                                    train_init_error=train_init_error,
                                    train_neg_error=train_neg_error)
            yield results
            if epoch > 2 and train_neg_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = forward_states = f_forward_pass(
            x=x_data_sample, params=forward_params
        ) if train_with_forward else initialize_states(
            n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])
        for t in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        epsilon=epsilon,
                                        fwd_states=forward_states)
        negative_states = states
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        y=y_data_sample,
                                        y_pressure=this_beta,
                                        epsilon=epsilon,
                                        fwd_states=forward_states)
        positive_states = states
        eq_params = f_parameter_update(x=x_data_sample,
                                       params=eq_params,
                                       negative_states=negative_states,
                                       positive_states=positive_states,
                                       fwd_states=forward_states,
                                       learning_rates=learning_rate,
                                       beta=this_beta)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=negative_states,
                learning_rates=learning_rate)
            # forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params, eq_states=negative_states, learning_rates=[lr/10 for lr in learning_rate])
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=positive_states,
                learning_rates=learning_rate)
        # elif train_with_forward == 'energy':
        #     forward_params = f_forward_parameter_update(x=x_data_sample, forward_params = forward_params, eq_params=eq_params, learning_rates=learning_rate)
        else:
            assert train_with_forward is False
Example #11
def demo_energy_based_initialize_eq_prop_alignment(
    n_epochs=25,
    hidden_sizes=(500, ),
    minibatch_size=20,
    beta=0.5,
    epsilon=0.5,
    learning_rate=(0.1, .05),
    n_negative_steps=20,
    n_positive_steps=4,
    initial_weight_scale=1.,
    epoch_checkpoint_period={
        0: .25,
        1: .5,
        5: 1,
        10: 2,
        50: 4
    },
    n_test_samples=10000,
    skip_zero_epoch_test=False,
    train_with_forward='contrast',
    forward_nonlinearity='rho(x)',
    local_loss=True,
    random_flip_beta=True,
    seed=1234,
):

    print('Params:\n' +
          '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes,
                                  initial_weight_scale=initial_weight_scale,
                                  rng=rng)
    forward_params = initialize_params(
        layer_sizes=layer_sizes,
        initial_weight_scale=initial_weight_scale,
        rng=rng)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period,
                                      skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.compile()
    f_inference_eq_step = equilibriating_step.compile()
    f_positive_eq_step = equilibriating_step.compile()
    f_parameter_update = update_eq_params.compile()
    f_forward_pass = forward_pass.partial(
        nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_update = update_forward_params_with_energy.partial(
        disconnect_grads=local_loss,
        nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        disconnect_grads=local_loss,
        nonlinearity=forward_nonlinearity).compile()
    f_energy = energy.compile()
    f_grad_align = compute_gradient_alignment.partial(
        nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        states_ = forward_states_ = f_forward_pass(
            x=x, params=forward_params_
        ) if train_with_forward else initialize_states(
            n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_,
                                          states=states_,
                                          x=x,
                                          epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0],
                                           minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(
                x_test)
            (test_init_error,
             test_neg_error), (train_init_error, train_neg_error) = [[
                 percent_argmax_incorrect(prediction, y[:n_test_samples])
                 for prediction in do_inference(forward_params_=forward_params,
                                                eq_params_=eq_params,
                                                x=x[:n_test_samples],
                                                n_steps=n_negative_steps)
             ] for x, y in [(x_test, y_test), (x_train, y_train)]]

            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )

            # ===== Compute alignment
            alignment_neg_states = forward_states_ = f_forward_pass(
                x=x_train, params=forward_params
            ) if train_with_forward else initialize_states(
                n_samples=x_train.shape[0],
                noninput_layer_sizes=layer_sizes[1:])
            for _ in range(n_negative_steps):
                alignment_neg_states = f_inference_eq_step(
                    params=eq_params,
                    states=alignment_neg_states,
                    x=x_train,
                    epsilon=epsilon)

            grad_alignments = f_grad_align(forward_params=forward_params,
                                           eq_states=alignment_neg_states,
                                           x=x_train)
            dbplot(grad_alignments,
                   'alignments',
                   plot_type=DBPlotTypes.LINE_HISTORY)
            # --------------------

            # fwd_states, neg_states = do_inference(forward_params_=forward_params, eq_params_=eq_params, x=x[:n_test_samples], n_steps=n_negative_steps)

            results[next, :] = dict(iter=i,
                                    epoch=epoch,
                                    test_init_error=test_init_error,
                                    test_neg_error=test_neg_error,
                                    train_init_error=train_init_error,
                                    train_neg_error=train_neg_error,
                                    alignments=grad_alignments)
            yield results
            if epoch > 2 and train_neg_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = forward_states = f_forward_pass(
            x=x_data_sample, params=forward_params
        ) if train_with_forward else initialize_states(
            n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])
        for t in range(n_negative_steps):
            # if i % 200 == 0:
            #     with hold_dbplots():
            #         dbplot_collection(states, 'states', cornertext='NEG')
            #         dbplot(f_energy(params = eq_params, states=states, x=x_data_sample).mean(), 'energies', plot_type=DBPlotTypes.LINE_HISTORY_RESAMPLED)
            states = f_negative_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        epsilon=epsilon)
        negative_states = states
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            # if i % 200 == 0:
            #     with hold_dbplots():
            #         dbplot_collection(states, 'states', cornertext='')
            #         dbplot(f_energy(params = eq_params, states=states, x=x_data_sample).mean(), 'energies', plot_type=DBPlotTypes.LINE_HISTORY_RESAMPLED)
            states = f_positive_eq_step(params=eq_params,
                                        states=states,
                                        x=x_data_sample,
                                        y=y_data_sample,
                                        y_pressure=this_beta,
                                        epsilon=epsilon)
        positive_states = states
        eq_params = f_parameter_update(x=x_data_sample,
                                       params=eq_params,
                                       negative_states=negative_states,
                                       positive_states=positive_states,
                                       learning_rates=learning_rate,
                                       beta=this_beta)

        # with hold_dbplots(draw_every=50):
        #     dbplot_collection([forward_params[0][0][:, :16].T.reshape(-1, 28, 28)] + [w for w, b in forward_params[1:]], '$\phi$')
        #     dbplot_collection([eq_params[0][0][:, :16].T.reshape(-1, 28, 28)] + [w for w, b in eq_params[1:]], '$\\theta$')
        #     dbplot_collection(forward_states, 'forward states')
        #     dbplot_collection(negative_states, 'negative_states')
        #     dbplot(np.array([f_energy(params = eq_params, states=forward_states, x=x_data_sample).mean(), f_energy(params = eq_params, states=negative_states, x=x_data_sample).mean()]), 'energies', plot_type=DBPlotTypes.LINE_HISTORY_RESAMPLED)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=negative_states,
                learning_rates=learning_rate)
            # forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params, eq_states=negative_states, learning_rates=[lr/10 for lr in learning_rate])
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_states=positive_states,
                learning_rates=learning_rate)
        elif train_with_forward == 'energy':
            forward_params = f_forward_parameter_update(
                x=x_data_sample,
                forward_params=forward_params,
                eq_params=eq_params,
                learning_rates=learning_rate)
        else:
            assert train_with_forward is False
Example #12
def get_temporal_mnist_dataset(smoothing_steps=1000, **mnist_kwargs):

    tr_x, tr_y, ts_x, ts_y = get_mnist_dataset(**mnist_kwargs).xyxy
    tr_ixs = temporalize(tr_x, smoothing_steps=smoothing_steps)
    ts_ixs = temporalize(ts_x, smoothing_steps=smoothing_steps)
    return DataSet.from_xyxy(tr_x[tr_ixs], tr_y[tr_ixs], ts_x[ts_ixs], ts_y[ts_ixs])
Example #13
def demo_optimize_mnist_net(hidden_sizes=[200, 200],
                            learning_rate=0.01,
                            n_epochs=100,
                            minibatch_size=10,
                            parametrization='log',
                            computation_weights=np.logspace(-6, -3, 8),
                            layerwise_scales=True,
                            show_scales=True,
                            hidden_activations='relu',
                            test_every=0.5,
                            output_activation='softmax',
                            error_loss='L1',
                            comp_evaluation_calc='multiplyadds',
                            smoothing_steps=1000,
                            seed=1234):

    train_data, train_targets, test_data, test_targets = get_mnist_dataset(
        flat=True).to_onehot().xyxy

    params = train_conventional_mlp_on_mnist(
        hidden_sizes=hidden_sizes,
        hidden_activations=hidden_activations,
        output_activation=output_activation,
        rng=seed)
    weights, biases = params[::2], params[1::2]

    rng = get_rng(seed + 1)

    true_out = forward_pass(input_data=test_data,
                            weights=weights,
                            biases=biases,
                            hidden_activations=hidden_activations,
                            output_activation=output_activation)
    optimized_results = OrderedDict([])
    optimized_results['unoptimized'] = get_mnist_results_with_parameters(
        weights=weights,
        biases=biases,
        scales=None,
        hidden_activations=hidden_activations,
        output_activation=output_activation,
        smoothing_steps=smoothing_steps)

    set_dbplot_figure_size(15, 10)
    for comp_weight in computation_weights:
        net = CompErrorScaleOptimizer(ws=weights,
                                      bs=biases,
                                      optimizer=GradientDescent(learning_rate),
                                      comp_weight=comp_weight,
                                      layerwise_scales=layerwise_scales,
                                      hidden_activations=hidden_activations,
                                      output_activation=output_activation,
                                      parametrization=parametrization,
                                      rng=rng)
        f_train = net.train_scales.partial(error_loss=error_loss).compile()
        f_get_scales = net.get_scales.compile()
        for training_minibatch, iter_info in minibatch_iterate_info(
                train_data,
                minibatch_size=minibatch_size,
                n_epochs=n_epochs,
                test_epochs=np.arange(0, n_epochs, test_every)):
            if iter_info.test_now:  # Test the computation and all that
                ks = f_get_scales()
                print('Epoch %.3g' % (iter_info.epoch,))
                with hold_dbplots():
                    if show_scales:
                        if layerwise_scales:
                            dbplot(ks,
                                   '%s solution_scales' % (comp_weight, ),
                                   plot_type=lambda: LinePlot(
                                       plot_kwargs=dict(linewidth=3),
                                       make_legend=False,
                                       axes_update_mode='expand',
                                       y_bounds=(0, None)),
                                   axis='solution_scales',
                                   xlabel='layer',
                                   ylabel='scale')
                        else:
                            for i, k in enumerate(ks):
                                dbplot(k,
                                       '%s solution_scales' % (i, ),
                                       plot_type=lambda: LinePlot(
                                           plot_kwargs=dict(linewidth=3),
                                           make_legend=False,
                                           axes_update_mode='expand',
                                           y_bounds=(0, None)),
                                       axis='solution_scales',
                                       xlabel='layer',
                                       ylabel='scale')
                    current_flop_counts, current_outputs = quantized_forward_pass_cost_and_output(
                        test_data,
                        weights=weights,
                        scales=ks,
                        quantization_method='round',
                        hidden_activations=hidden_activations,
                        output_activation=output_activation,
                        computation_calc=comp_evaluation_calc,
                        seed=1234)
                    current_error = np.abs(current_outputs - true_out).mean() / np.abs(true_out).mean()
                    current_class_error = percent_argmax_incorrect(
                        current_outputs, test_targets)
                    if np.isnan(current_error):
                        print('ERROR IS NAN!!!')
                    dbplot((current_flop_counts / 1e6, current_error),
                           '%s error-curve' % (comp_weight, ),
                           axis='error-curve',
                           plot_type='trajectory+',
                           xlabel='MFlops',
                           ylabel='error')
                    dbplot((current_flop_counts / 1e6, current_class_error),
                           '%s class-curve' % (comp_weight, ),
                           axis='class-curve',
                           plot_type='trajectory+',
                           xlabel='MFlops',
                           ylabel='class-error')
            f_train(training_minibatch)
        optimized_results['lambda=%.3g' %
                          (comp_weight, )] = get_mnist_results_with_parameters(
                              weights=weights,
                              biases=biases,
                              scales=ks,
                              hidden_activations=hidden_activations,
                              output_activation=output_activation,
                              smoothing_steps=smoothing_steps)
    return optimized_results
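Finally, a sketch of how the optimization demo above might be driven; it returns an OrderedDict keyed by 'unoptimized' and 'lambda=...' settings, each value being the result table from get_mnist_results_with_parameters:

# Sketch: run a short scale optimization for two computation weights and list what came back.
opt_results = demo_optimize_mnist_net(n_epochs=1, computation_weights=[1e-5, 1e-4])
for setting, table in opt_results.items():
    print(setting, list(table.keys()))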