# Example #1 ("Пример #1" — scraped example separator; stray "0" vote-count artifact removed)
def figure_local_vs_full():
    """Overlay test-error curves for local-loss vs global-loss training.

    Fetches (running if necessary) the latest record of each variant of the
    'local_loss_baseline' experiment and plots four curves: the initialized
    and negative-phase test errors for both the local and global settings.
    """
    from init_eqprop.demo_energy_eq_prop_A_initialized import demo_energy_based_initialize_eq_prop

    ex_local = demo_energy_based_initialize_eq_prop.get_variant(
        'local_loss_baseline').get_variant(local_loss=True)
    ex_global = demo_energy_based_initialize_eq_prop.get_variant(
        'local_loss_baseline').get_variant(local_loss=False)

    plt.figure(figsize=(6, 3))
    set_figure_border_size(border=0.1, bottom=0.15)

    # One (experiment, keyword-args) pair per curve; global-loss curves are
    # dashed to distinguish them from the local-loss ones.
    # NOTE(review): label 'Global: s^i' looks like a typo for 's^-' in the
    # original — preserved as-is.
    curve_specs = [
        (ex_local, dict(field='test_init_error', color='C2', label='Local: s^f')),
        (ex_local, dict(field='test_neg_error', color='C1', label='Local: s^-')),
        (ex_global, dict(field='test_init_error', color='C2', linestyle='--', label='Global: s^f')),
        (ex_global, dict(field='test_neg_error', color='C1', linestyle='--', label='Global: s^i')),
    ]
    for experiment, plot_kwargs in curve_specs:
        plot_experiment_result(experiment.get_latest_record(if_none='run'), **plot_kwargs)

    plt.xlabel('Epoch')
    plt.ylabel('Classification Test Error')
    plt.legend()
    plt.ylim(0, 10)
    plt.grid()
    plt.show()
# Example #2 ("Пример #2" — scraped example separator; stray "0" vote-count artifact removed)
def figure_compare_lambdas():
    """Plot final test errors as a function of the forward-deviation cost (lambda).

    Builds a grid of subplots — rows are network sizes ('small'/'large'),
    columns are the distinct n_negative_steps values found for that size —
    and in each subplot plots the final 'test_init_error' and 'test_neg_error'
    against the swept forward_deviation_cost values.
    """
    from init_eqprop.demo_energy_eq_prop_AE_initialized_fwd_energy import baseline_final_final

    # Keep only variants that completed and actually produced a record.
    records = [
        r for r in baseline_final_final.get_variant_records(
            only_last=True,
            only_completed=True,
        ).values() if r is not None
    ]

    duck = make_record_comparison_duck(records)

    plt.figure(figsize=(8, 5))

    set_figure_border_size(hspace=0.4, top=0.1, bottom=0.1, left=0.1)
    for row, netsize in enumerate(('small', 'large')):
        # One column per distinct n_negative_steps seen for this net size
        # (size is identified by substring-matching the experiment id).
        for col, n_negative_steps in enumerate(
                sorted(
                    set(duck[duck[:, 'exp_id'].map(lambda x: netsize in x),
                             'args', 'n_negative_steps']))):
            data = []
            for forward_deviation_cost in sorted(
                    set(duck[:, 'args', 'forward_deviation_cost'])):

                # Select the single record matching this cell's parameters.
                filter_ixs = \
                    (duck[:, 'args', 'n_negative_steps'].each_eq(n_negative_steps)) \
                    & (duck[:, 'args', 'forward_deviation_cost'].each_eq(forward_deviation_cost)) \
                    & (duck[:, 'args', 'train_with_forward'].each_eq('contrast')) \
                    & (duck[:, 'exp_id'].map(lambda x: netsize in x))

                if not any(filter_ixs):
                    continue

                this_result = duck[filter_ixs].only()['result']
                # Record the *final* (last-row) errors for this lambda value.
                data.append(
                    (forward_deviation_cost, this_result[-1,
                                                         'test_init_error'],
                     this_result[-1, 'test_neg_error']))

            if len(data) == 0:
                continue
            fwd_cost, test_init_error, test_neg_error = zip(*data)

            subplot_at(row, col)
            plt.title(f'{netsize}: $T^-={n_negative_steps}$')
            plt.plot(fwd_cost, test_init_error, label='Init Eq. Prop: $s^f$')
            plt.plot(fwd_cost, test_neg_error, label='Init Eq. Prop: $s^-$')
            plt.ylim(0, 10)
            plt.grid()
            # Raw string: '\l' is an invalid escape sequence (SyntaxWarning
            # in Python 3.12+); the rendered label text is unchanged.
            plt.xlabel(r'$\lambda$')
            plt.ylabel('Test Score')
            plt.legend()

    plt.show()
def create_mnist_figure():
    """Compare continuous vs binary (quantized) Eq-Prop learning curves on MNIST.

    Fetches (running if necessary) the latest record of each experiment
    variant, prints each record's score, then plots both learning curves on a
    single figure via the experiment's compare() method.
    """
    continuous_exp = experiment_mnist_eqprop.get_variant('vanilla').get_variant('one_hid_swapless')

    quantized_base = experiment_mnist_eqprop.get_variant('quantized').get_variant('one_hid_swapless')
    binary_exp = quantized_base.get_variant(epsilons='0.843/t**0.0923', lambdas='0.832/t**0.584', quantizer='sigma_delta')

    continuous_rec = continuous_exp.get_latest_record(if_none='run')
    binary_rec = binary_exp.get_latest_record(if_none='run')

    print(f"Continuous: {report_score_from_result(continuous_rec.get_result())}")
    print(f"Binary: {report_score_from_result(binary_rec.get_result())}")

    plt.figure(figsize=(4.5, 3))
    set_figure_border_size(bottom=0.15, left=0.15)
    # show_now=False so we can attach the legend before showing.
    binary_exp.compare([continuous_rec, binary_rec], show_now=False)

    plt.legend(['Continuous Eq-Prop: Train', 'Continuous Eq-Prop: Test', 'Binary Eq-Prop: Train', 'Binary Eq-Prop: Test'])
    plt.show()
# Example #4 ("Пример #4" — scraped example separator; stray "0" vote-count artifact removed)
def create_gp_convergence_search_figure():
    """Plot hyperparameter-search histories for the quantized-convergence demos.

    One subplot per search experiment, each showing the GP parameter-search
    record with human-readable LaTeX labels for the hyperparameters.
    """
    # (plot title, experiment id) pairs.  Raw strings keep LaTeX commands like
    # \epsilon from being read as invalid Python escape sequences
    # (SyntaxWarning in 3.12+); rendered text is unchanged.
    searches = [
        # (r'$\epsilon = 1/\sqrt{t}$', 'demo_quantized_convergence_perturbed.nonadaptive.epsilons=1_SLASH_sqrt(t)'),
        (r'$\epsilon = \epsilon_0/t^{\eta_\epsilon}, \lambda = \lambda_0/t^{\eta_\lambda}$', 'demo_quantized_convergence_perturbed.nonadaptive.poly_schedule.parameter_search'),
        (r'$\epsilon$ = OSA($\bar\nu$)', 'demo_quantized_convergence_perturbed.adaptive.optimal.parameter_search'),
        (r'$\epsilon$ = OSA($\bar\nu$), $\lambda$', 'demo_quantized_convergence_perturbed.adaptive.optimal.parameter_search_predictive'),
        (r'$\epsilon$ = OSA($\bar\nu$), $\lambda = \lambda_0/t^{\eta_\lambda}$', 'demo_quantized_convergence_perturbed.adaptive.optimal.OSA-lambda_sched.parameter_search_predictive'),
    ]

    # Map raw hyperparameter names to LaTeX axis labels.
    # NOTE(review): the leading spaces in the 'eps_exp'/'lambda_exp' labels are
    # in the original — preserved byte-for-byte.
    relabels = {'eps_init': r'$\epsilon_0$', 'eps_exp': r' $\eta_\epsilon$', 'lambda_init': r'$\lambda_0$', 'lambda_exp': r' $\eta_\lambda$',
                'error_stepsize_target': r'$\bar\nu$', 'lambdas': r'$\lambda$'}
    plt.figure(figsize=(7, 7))
    set_figure_border_size(hspace=0.25, border=0.05, bottom=.1, right=0.1)

    for search_name, search_exp_id in searches:
        add_subplot()
        exp = load_experiment(search_exp_id)
        rec = exp.get_latest_record()
        plot_hyperparameter_search(rec, relabel=relabels, assert_all_relabels_used=False, score_name='Error')
        plt.title(search_name)
    plt.show()
# Example #5 ("Пример #5" — scraped example separator; stray "0" vote-count artifact removed)
def create_gp_mnist_search_figure():
    """Plot GP hyperparameter-search results for the MNIST Eq-Prop experiments.

    One subplot per search.  Each entry may be either a stored experiment id
    (str) or an already-loaded Experiment object.
    """
    # TODO: Replace the first one with the search with LONGER when it's done.
    # Raw strings keep LaTeX commands (\epsilon, \lambda, ...) from being read
    # as invalid Python escape sequences (SyntaxWarning in 3.12+).
    searches = [
        # (r'$\epsilon = \epsilon_0/t^{\eta_\epsilon}, \lambda = \lambda_0/t^{\eta_\lambda}$', 'experiment_mnist_eqprop_torch.1_hid.quantized.poly_schedule.epoch_checkpoint_period=None,n_epochs=1,quantizer=sigma_delta.parameter_search'),
        (r'$\epsilon = \epsilon_0/t^{\eta_\epsilon}, \lambda = \lambda_0/t^{\eta_\lambda}$', X_polyscheduled_longer_psearch),
        # (r'$\epsilon$ = OSA($\bar\nu$), $\lambda$', 'experiment_mnist_eqprop_torch.1_hid.adaptive_quantized.optimal.n_negative_steps=100,n_positive_steps=50.epoch_checkpoint_period=None,n_epochs=1.parameter_search')
        (r'$\epsilon$ = OSA($\bar\nu$), $\lambda$', X_osa_longer_1hid_psearch)
    ]

    # Map raw hyperparameter names to LaTeX axis labels (leading spaces in the
    # exponent labels preserved from the original).
    relabels = {'eps_init': r'$\epsilon_0$', 'eps_exp': r' $\eta_\epsilon$', 'lambda_init': r'$\lambda_0$', 'lambda_exp': r' $\eta_\lambda$',
                'error_stepsize_target': r'$\bar\nu$', 'lambdas': r'$\lambda$'}

    plt.figure(figsize=(7, 4))
    set_figure_border_size(hspace=0.25, border=0.05, bottom=.1, right=0.1)
    for search_name, search_exp_id in searches:
        add_subplot()
        # Accept either an experiment id string or an Experiment object.
        exp = load_experiment(search_exp_id) if isinstance(search_exp_id, str) else search_exp_id  # type: Experiment
        rec = exp.get_latest_record()
        plot_hyperparameter_search(rec, relabel=relabels, assert_all_relabels_used=False, score_name='Val. Error\n after 1 epoch')
        plt.title(search_name)

    plt.show()
def demo_create_signal_figure(
    w=(-.7, .8, .5),  # tuple, not list: avoids a mutable default argument
    eps=0.2,
    lambda_schedule='1/t**.75',
    eps_schedule='1/t**.4',
    n_samples=200,
    seed=1247,
    input_convergence_speed=3,
    scale=0.3,
):
    """Illustrate predictive sigma-delta encoding/decoding of a slowly-varying input.

    Generates a random low-pass input signal, feeds it through a
    PredictiveEncoder / PredictiveDecoder pair with scheduled step sizes, and
    stacks four panels: the input channels with their spike events, the
    decoded drive u_j, the epsilon/lambda schedules, and the real vs binary
    state trajectories.

    :param w: Input weights; its length sets the number of input channels.
    :param eps: Fixed step size used for the real-valued reference state.
    :param lambda_schedule: Schedule expression for the encoder/decoder lambda.
    :param eps_schedule: Schedule expression for the adaptive step size.
    :param n_samples: Number of time steps to simulate.
    :param seed: RNG seed, for reproducibility.
    :param input_convergence_speed: Exponent controlling how fast the input settles.
    :param scale: Amplitude of the varying input component.
    """
    rng = np.random.RandomState(seed)

    varying_sig = lowpass_random(n_samples=n_samples,
                                 cutoff=0.03,
                                 n_dim=len(w),
                                 normalize=(-scale, scale),
                                 rng=rng)
    # Blend from a varying signal toward a fixed random level as t grows.
    frac = 1 - np.linspace(1, 0, n_samples)[:, None]**input_convergence_speed
    # len(w) (was hard-coded 3) so non-default channel counts broadcast
    # correctly; identical draws for the default length-3 w.
    x = np.clip((1 - frac) * varying_sig + frac * scale * rng.rand(len(w)), 0, 1)
    # Comprehension-as-scan: the singleton lists thread the running state s
    # through the iteration over x.
    true_z = [
        s for s in [0] for xt in x
        for s in [np.clip((1 - eps) * s + eps * xt.dot(w), 0, 1)]
    ]

    # Alternative try 2
    eps_stepper = create_step_sizer(eps_schedule)  # type: IStepSizer
    lambda_stepper = create_step_sizer(lambda_schedule)  # type: IStepSizer
    encoder = PredictiveEncoder(lambda_stepper=lambda_stepper,
                                quantizer=SigmaDeltaQuantizer())
    decoder = PredictiveDecoder(lambda_stepper=lambda_stepper)
    # Same scan trick: enc carries the updated encoder state at each step.
    q = np.array(
        [qt for enc in [encoder] for xt in x for enc, qt in [enc(xt)]])
    inputs = [qt.dot(w) for qt in q]
    # Decode the quantized stream, tracking the step sizes actually used.
    sig, epsilons, lambdaas = unzip([
        (s, eps, dec.lambda_stepper(inp)[1])
        for s, dec, eps_func in [[0, decoder, eps_stepper]] for inp in inputs
        for dec, decoded_input in [dec(inp)]
        for eps_func, eps in [eps_func(decoded_input)]
        for s in [np.clip((1 - eps) * s + eps * decoded_input, 0, 1)]
    ])

    fig = plt.figure(figsize=(3, 4.5))
    set_figure_border_size(0.02, bottom=0.1)

    with vstack_plots(spacing=0.1, xlabel='t', bottom_pad=0.1):

        ax = add_subplot()

        sep = np.max(x) * 1.1
        plot_stacked_signals(x, sep=sep, labels=False)
        # Reset the color cycle so the raster colors match the signals.
        plt.gca().set_prop_cycle(None)

        event_raster_plot(
            events=[np.nonzero(q[:, i])[0] for i in range(len(w))],
            sep=sep,
            s=100)
        ax.legend(labels=[f'$s_{i}$' for i in range(1,
                                                    len(w) + 1)],
                  loc='lower left')

        ax = add_subplot()
        stemlight(inputs, ax=ax, label='$u_j$', color='k')
        # plt.plot(inputs, label='$u_j$', color='k')
        ax.axhline(0, color='k')
        # NOTE(review): labelleft expects a bool in modern matplotlib; 'off'
        # preserved from the original — confirm against the pinned version.
        ax.tick_params(axis='y', labelleft='off')
        ax.legend(loc='upper left')
        # plt.grid()

        ax = add_subplot()
        # ax.plot([eps_func(t) for t in range(n_samples)], label=r'$\epsilon$', color='k')
        # Raw strings: \e and \l are invalid escape sequences otherwise.
        ax.plot(epsilons, label=r'$\epsilon$', color='k')
        ax.plot(lambdaas, label=r'$\lambda$', color='b')
        ax.axhline(0, color='k')
        ax.legend(loc='upper right')
        ax.tick_params(axis='y', labelleft='off')

        ax = add_subplot()
        ax.plot(true_z, label='$s_j$ (real)', color='r')
        ax.plot(sig, label='$s_j$ (binary)', color='k')
        ax.legend(loc='lower right')
        ax.tick_params(axis='y', labelleft='off')
        ax.axhline(0, color='k')

    plt.show()
# Example #7 ("Пример #7" — scraped example separator; stray "0" vote-count artifact removed)
def figure_mnist_eqprop_multi_size():
    """Compare plain Eq-Prop with forward-initialized Eq-Prop across net sizes.

    Builds a 2x2 subplot grid: columns are network sizes (small/large), rows
    are negative-phase step counts.  Each panel shows the test-error curves
    for plain Eq-Prop and the initialized variant, with the final scores
    annotated in matching colors.
    """
    forward_deviation_cost = 0.1
    from init_eqprop.demo_energy_eq_prop_AE_initialized_fwd_energy import baseline_final_small, baseline_final_large

    plt.figure(figsize=(7, 4))
    set_figure_border_size(wspace=0.05,
                           hspace=0.2,
                           bottom=0.11,
                           top=0.07,
                           left=0.1)
    for j, size_baseline in enumerate(
        (baseline_final_small, baseline_final_large)):
        # for i, n_negative_steps in enumerate((10, 20) if size_baseline==baseline_final_small else (20, 50)):
        for i, n_negative_steps in enumerate((
                4, 20) if size_baseline == baseline_final_small else (20, 50)):
            ax = subplot_at(i, j)
            # Plain Eq-Prop baseline (no forward pass, zero deviation cost).
            h_eqprop, result_eqprop = plot_experiment_result(
                size_baseline.get_variant(
                    train_with_forward=False,
                    n_negative_steps=n_negative_steps,
                    forward_deviation_cost=0).get_latest_record(if_none='run'),
                'test_neg_error',
                label='Eq Prop: $s^-$')
            # Initialized Eq-Prop: forward-pass error ...
            h_init, result_init = plot_experiment_result(
                size_baseline.get_variant(
                    train_with_forward='contrast',
                    n_negative_steps=n_negative_steps,
                    forward_deviation_cost=forward_deviation_cost).
                get_latest_record(if_none='run'),
                'test_init_error',
                label='Init Eq Prop: $s^f$',
                color='C2')
            # ... and negative-phase error from the same variant.
            h_neg, _ = plot_experiment_result(size_baseline.get_variant(
                train_with_forward='contrast',
                n_negative_steps=n_negative_steps,
                forward_deviation_cost=forward_deviation_cost).
                                              get_latest_record(if_none='run'),
                                              'test_neg_error',
                                              label='Init Eq Prop: $s^-$',
                                              color='C1')

            print(
                f'Final Scores for {size_baseline.name[size_baseline.name.rfind(".")+1:]}, T={n_negative_steps}: Eqprop {result_eqprop[-1, "test_neg_error"]:.3g}, {result_init[-1, "test_init_error"]:.3g}, {result_init[-1, "test_neg_error"]:.3g}'
            )

            plt.ylim(0, 10 if i > 0 else 30)

            # Annotate the three final scores.  The top row (i == 0) uses
            # lower y-fractions because its y-axis spans 0-30, not 0-10.
            # (This replaces six copy-pasted plt.text calls; the literal
            # y-fractions are preserved exactly.)
            x_text = result_init[-1, 'epoch'] * 4 / 5
            y_fracs = (3.3, 2.3, 1.3) if i == 0 else (5.3, 4.3, 3.3)
            final_scores = [
                (f'{result_eqprop[-1, "test_neg_error"]:.3g}', h_eqprop),
                (f'{result_init[-1, "test_init_error"]:.3g}', h_init),
                (f'{result_init[-1, "test_neg_error"]:.3g}', h_neg),
            ]
            for y_frac, (score_text, handle) in zip(y_fracs, final_scores):
                plt.text(
                    x=x_text,
                    y=ax.get_ylim()[1] * y_frac / 10,
                    s=score_text,
                    color=handle.get_color(),
                )

            plt.grid()
            plt.title(
                f'[784-{"500" if  size_baseline==baseline_final_small else "500-500-500"}-10], {n_negative_steps}-step'
            )
            # Only the bottom row gets x tick-labels / label.
            if i < 1:
                ax.xaxis.set_ticklabels([])
            else:
                plt.xlabel('Epoch')

            # Only the left column gets y tick-labels / label.
            if j > 0:
                ax.yaxis.set_ticklabels([])
            else:
                plt.ylabel('Classification Test Error')

            if i == j == 0:
                plt.legend()

    plt.show()