def demo_mnist_logreg(minibatch_size=20, learning_rate=0.01, max_training_samples=None, n_epochs=10, test_epoch_period=0.2):
    """
    Train a Logistic Regressor on the MNIST dataset, report training/test scores
    throughout training, and return the final scores.
    """
    x_train, y_train, x_test, y_test = get_mnist_dataset(flat=True, n_training_samples=max_training_samples).xyxy

    predictor = OnlineLogisticRegressor(n_in=784, n_out=10, learning_rate=learning_rate)

    # Train and periodically record scores.
    epoch_scores = []
    for ix, iteration_info in minibatch_index_info_generator(n_samples=len(x_train), minibatch_size=minibatch_size, n_epochs=n_epochs, test_epochs=('every', test_epoch_period)):
        if iteration_info.test_now:
            # Note: (argmax == label).mean() is the fraction *correct*, so these are accuracy scores, not errors.
            training_score = 100 * (np.argmax(predictor.predict(x_train), axis=1) == y_train).mean()
            test_score = 100 * (np.argmax(predictor.predict(x_test), axis=1) == y_test).mean()
            print('Epoch {epoch}: Test Score: {test}%, Training Score: {train}%'.format(epoch=iteration_info.epoch, test=test_score, train=training_score))
            epoch_scores.append((iteration_info.epoch, training_score, test_score))
        predictor.train(x_train[ix], y_train[ix])

    # Plot the learning curves.
    plt.figure()
    epochs, training_scores, test_scores = zip(*epoch_scores)
    plt.plot(epochs, np.array([training_scores, test_scores]).T)
    plt.xlabel('Epoch')
    plt.ylabel('% Correct')
    plt.legend(['Training Score', 'Test Score'])
    plt.title("Learning Curve")
    plt.ylim(80, 100)
    plt.ion()  # Don't hang on plot
    plt.show()

    return {'train': epoch_scores[-1][1], 'test': epoch_scores[-1][2]}  # Return final scores
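# For reference, a minimal sketch of the kind of update OnlineLogisticRegressor presumably
# performs (plain softmax regression trained by SGD). The class itself is defined elsewhere
# in this codebase; this numpy version is an illustrative assumption, not its actual internals.
import numpy as np

class _SketchOnlineLogisticRegressor:

    def __init__(self, n_in, n_out, learning_rate=0.01):
        self.w = np.zeros((n_in, n_out))
        self.b = np.zeros(n_out)
        self.learning_rate = learning_rate

    def predict(self, x):  # x: (n_samples, n_in) -> (n_samples, n_out) class probabilities
        u = x @ self.w + self.b
        eu = np.exp(u - u.max(axis=1, keepdims=True))  # numerically stabilized softmax
        return eu / eu.sum(axis=1, keepdims=True)

    def train(self, x, y):  # y: (n_samples, ) integer labels
        p = self.predict(x)
        p[np.arange(len(y)), y] -= 1  # dL/du for the mean negative log-likelihood
        self.w -= self.learning_rate * x.T @ p / len(y)
        self.b -= self.learning_rate * p.mean(axis=0)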
def get_temporal_mnist_dataset(smoothing_steps=1000, **mnist_kwargs):
    """Return an MNIST dataset whose samples are reordered so that consecutive samples tend to be similar."""
    tr_x, tr_y, ts_x, ts_y = get_mnist_dataset(**mnist_kwargs).xyxy
    tr_ixs = temporalize(tr_x, smoothing_steps=smoothing_steps)
    ts_ixs = temporalize(ts_x, smoothing_steps=smoothing_steps)
    return DataSet.from_xyxy(tr_x[tr_ixs], tr_y[tr_ixs], ts_x[ts_ixs], ts_y[ts_ixs])
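# temporalize is defined elsewhere in this codebase. As a rough illustration of the idea
# (an assumption about its internals, not the actual algorithm), one simple way to get such
# an ordering is to repeatedly propose swapping two positions and keep the swap whenever it
# reduces the total distance between consecutive frames:
import numpy as np

def _sketch_temporalize(x, smoothing_steps, rng=np.random.RandomState(0)):
    flat = x.reshape(len(x), -1)
    ixs = np.arange(len(x))
    step_cost = lambda order: np.sum((flat[order[1:]] - flat[order[:-1]]) ** 2)
    cost = step_cost(ixs)
    for _ in range(smoothing_steps):
        i, j = rng.randint(len(x), size=2)
        ixs[i], ixs[j] = ixs[j], ixs[i]
        new_cost = step_cost(ixs)  # (a real implementation would update the cost locally)
        if new_cost < cost:
            cost = new_cost
        else:
            ixs[i], ixs[j] = ixs[j], ixs[i]  # revert the swap
    return ixs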
def demo_temporal_mnist(n_samples=None, smoothing_steps=200):
    _, _, original_data, original_labels = get_mnist_dataset(n_training_samples=n_samples, n_test_samples=n_samples).xyxy
    _, _, temporal_data, temporal_labels = get_temporal_mnist_dataset(n_training_samples=n_samples, n_test_samples=n_samples, smoothing_steps=smoothing_steps).xyxy
    for ox, oy, tx, ty in zip(original_data, original_labels, temporal_data, temporal_labels):
        with hold_dbplots():
            dbplot(ox, 'sample', title=str(oy))
            dbplot(tx, 'smooth', title=str(ty))
def train_conventional_mlp_on_mnist(hidden_sizes, n_epochs=50, w_init='xavier-both', minibatch_size=20, rng=1234,
        optimizer='sgd', hidden_activations='relu', output_activation='softmax', learning_rate=0.01,
        cost_function='nll', use_bias=True, l1_loss=0, l2_loss=0, test_on='training+test'):
    dataset = get_mnist_dataset(flat=True)
    if output_activation != 'softmax':
        dataset = dataset.to_onehot()
    all_layer_sizes = [dataset.input_size] + hidden_sizes + [dataset.n_categories]
    weights = initialize_network_params(layer_sizes=all_layer_sizes, mag=w_init, base_dist='normal', include_biases=False, rng=rng)
    net = MultiLayerPerceptron(weights=weights, hidden_activation=hidden_activations, output_activation=output_activation, use_bias=use_bias)
    predictor = GradientBasedPredictor(
        function=net,
        cost_function=get_named_cost_function(cost_function),
        optimizer=get_named_optimizer(optimizer, learning_rate=learning_rate),
        # Regularize weight matrices (ndim==2) only; biases are left unregularized.
        regularization_cost=lambda params: sum(l1_loss * abs(p_).sum() + l2_loss * (p_ ** 2).sum() if p_.ndim == 2 else 0 for p_ in params)
        ).compile()
    assess_online_predictor(predictor=predictor, dataset=dataset, evaluation_function='percent_argmax_correct',
        test_epochs=range(0, n_epochs, 1), test_on=test_on, minibatch_size=minibatch_size)
    ws = [p.get_value() for p in net.parameters]
    return ws
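# Example usage (hyperparameters here are illustrative, not recommendations): train a
# 2-hidden-layer MLP with a small L2 penalty and keep the learned parameter arrays.
# params = train_conventional_mlp_on_mnist(hidden_sizes=[200, 200], n_epochs=20, l2_loss=1e-5)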
def experiment_mnist_eqprop(
        layer_constructor,
        n_epochs=10,
        hidden_sizes=(500, ),
        minibatch_size=20,
        beta=.5,
        random_flip_beta=True,
        learning_rate=.05,
        n_negative_steps=20,
        n_positive_steps=4,
        initial_weight_scale=1.,
        online_checkpoints_period=None,
        epoch_checkpoint_period=.25,
        skip_zero_epoch_test=False,
        n_test_samples=None,
        prop_direction: Union[str, Tuple] = 'neutral',
        bidirectional=True,
        renew_activations=True,
        do_fast_forward_pass=False,
        rebuild_coders=True,
        l2_loss=None,
        splitstream=True,
        seed=1234,
        ):
    """
    Replicate the results of Scellier & Bengio:
    Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
    https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
    https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points over epochs,
      so our results should only really match for the first epoch.
    - We evaluate the training score periodically, rather than as an online average
      (however you can see the online score by setting online_checkpoints_period
      to something that is not None).
    """
    print('Params:\n' + '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in = 784
    n_out = 10
    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]
    rng = get_rng(rng)
    y_train = y_train.astype(np.float32)

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    is_online_checkpoint = Checkpoints(online_checkpoints_period, skip_first=skip_zero_epoch_test) if online_checkpoints_period is not None else lambda: False
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period, skip_first=skip_zero_epoch_test)

    results = Duck()

    training_states = initialize_states(
        layer_constructor=layer_constructor,
        n_samples=minibatch_size,
        params=initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng))

    if isinstance(prop_direction, str):
        fwd_prop_direction, backward_prop_direction = prop_direction, prop_direction
    else:
        fwd_prop_direction, backward_prop_direction = prop_direction

    for i, (ixs, info) in enumerate(minibatch_index_info_generator(n_samples=x_train.shape[0], minibatch_size=minibatch_size, n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(x_test)
            y_pred_test, y_pred_train = [
                run_inference(
                    x_data=x[:n_test_samples],
                    states=initialize_states(
                        layer_constructor=layer_constructor,
                        params=[s.params for s in training_states],
                        n_samples=min(len(x), n_test_samples) if n_test_samples is not None else len(x)),
                    n_steps=n_negative_steps,
                    prop_direction=fwd_prop_direction,
                    )
                for x in (x_test, x_train)]
            test_error = percent_argmax_incorrect(y_pred_test, y_test[:n_test_samples])
            train_error = percent_argmax_incorrect(y_pred_train, y_train[:n_test_samples])
            print(f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s')
            results[next, :] = dict(iter=i, epoch=epoch, train_error=train_error, test_error=test_error)
            yield results
            if epoch > 2 and train_error > 50:
                return

        # The original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]
        training_states = run_eqprop_training_update(
            x_data=x_data_sample, y_data=y_data_sample, layer_states=training_states, beta=beta,
            random_flip_beta=random_flip_beta, learning_rate=learning_rate, layer_constructor=layer_constructor,
            bidirectional=bidirectional, l2_loss=l2_loss, renew_activations=renew_activations,
            n_negative_steps=n_negative_steps, n_positive_steps=n_positive_steps, prop_direction=prop_direction,
            splitstream=splitstream, rng=rng)
        this_train_score = ra(percent_argmax_correct(output_from_state(training_states), y_train[ixs]))
        if is_online_checkpoint():
            print(f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}')
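# For orientation, the core weight update of Scellier & Bengio's Equilibrium Propagation,
# which run_eqprop_training_update implements in a far more general form, contrasts the
# fixed-point activities of a "negative" (free) phase and a "positive" (weakly clamped) phase:
#
#     dW_ij = (lr / beta) * (rho(s_i^+) * rho(s_j^+) - rho(s_i^-) * rho(s_j^-))
#
# A minimal numpy sketch for one fully-connected layer (illustrative only; the layer/state
# objects in this file carry much more structure):
import numpy as np

def _sketch_eqprop_weight_update(w, rho_pre_neg, rho_post_neg, rho_pre_pos, rho_post_pos, beta, learning_rate):
    """rho_* are (n_samples, n_units) firing rates at the negative/positive fixed points."""
    n = rho_pre_neg.shape[0]
    return w + (learning_rate / beta) * (rho_pre_pos.T @ rho_post_pos - rho_pre_neg.T @ rho_post_neg) / n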
def demo_herding_network(kp=.1, kd=1., kp_back=None, kd_back=None, hidden_sizes=[200], n_epochs=50, onehot=False,
        parallel=False, learning_rate=0.01, dataset='mnist', hidden_activation='relu', adaptive=True,
        adaptation_rate=0.001, output_activation='softmax', loss='nll', fwd_quantizer='herd', back_quantizer='same',
        minibatch_size=1, swap_mlp=False, plot=False, test_period=.5, grad_calc='true', rng=1234):

    dataset = get_mnist_dataset(flat=True, join_train_and_val=True) if dataset == 'mnist' \
        else get_temporal_mnist_dataset(flat=True, join_train_and_val=True)
    if onehot:
        dataset = dataset.to_onehot()

    ws = initialize_network_params(layer_sizes=[28 * 28] + hidden_sizes + [10], mag='xavier-both', include_biases=False, rng=rng)

    if is_test_mode():
        dataset = dataset.shorten(500)
        n_epochs = 0.1
        test_period = 0.03

    if kp_back is None:
        kp_back = kp
    if kd_back is None:
        kd_back = kd
    if back_quantizer == 'same':
        back_quantizer = fwd_quantizer

    if adaptive:
        encdec = lambda: PDAdaptiveEncoderDecoder(kp=kp, kd=kd, adaptation_rate=adaptation_rate, quantization=fwd_quantizer)
        encdec_back = lambda: PDAdaptiveEncoderDecoder(kp=kp_back, kd=kd_back, adaptation_rate=adaptation_rate, quantization=back_quantizer)
    else:
        encdec = PDEncoderDecoder(kp=kp, kd=kd, quantization=fwd_quantizer)
        encdec_back = PDEncoderDecoder(kp=kp_back, kd=kd_back, quantization=back_quantizer)

    if swap_mlp:
        if not parallel:
            assert minibatch_size == 1, "Unfair comparison otherwise, sorry buddy, can't let you do that."
        net = GradientBasedPredictor(
            function=MultiLayerPerceptron.from_weights(
                weights=ws,
                hidden_activations=hidden_activation,
                output_activation=output_activation,
                ),
            cost_function=loss,
            optimizer=GradientDescent(learning_rate),
            )
        prediction_funcs = net.predict.compile()
    else:
        net = PDHerdingNetwork(
            ws=ws,
            encdec=encdec,
            encdec_back=encdec_back,
            hidden_activation=hidden_activation,
            output_activation=output_activation,
            optimizer=GradientDescent(learning_rate),
            minibatch_size=minibatch_size if parallel else 1,
            grad_calc=grad_calc,
            loss=loss)
        noise_free_forward_pass = MultiLayerPerceptron.from_weights(
            weights=[layer.w for layer in net.layers],
            biases=[layer.b for layer in net.layers],
            hidden_activations=hidden_activation,
            output_activation=output_activation).compile()
        prediction_funcs = [('noise_free', noise_free_forward_pass), ('herded', net.predict.compile())]

    op_count_info = []

    def test_callback(info, score):
        if plot:
            dbplot(net.layers[0].w.get_value().T.reshape(-1, 28, 28), 'w0', cornertext='Epoch {}'.format(info.epoch))
        if swap_mlp:
            all_layer_sizes = [dataset.input_size] + hidden_sizes + [dataset.target_size]
            fwd_ops = [info.sample * d1 * d2 for d1, d2 in zip(all_layer_sizes[:-1], all_layer_sizes[1:])]
            back_ops = [info.sample * d1 * d2 for d1, d2 in zip(all_layer_sizes[:-1], all_layer_sizes[1:])]
            update_ops = [info.sample * d1 * d2 for d1, d2 in zip(all_layer_sizes[:-1], all_layer_sizes[1:])]
        else:
            fwd_ops = [layer_.fwd_op_count.get_value() for layer_ in net.layers]
            back_ops = [layer_.back_op_count.get_value() for layer_ in net.layers]
            update_ops = [layer_.update_op_count.get_value() for layer_ in net.layers]
        if info.epoch != 0:
            with IndentPrint('Mean Ops by epoch {}'.format(info.epoch)):
                print('Fwd: {}'.format([si_format(ops / info.epoch, format_str='{value} {prefix}Ops') for ops in fwd_ops]))
                print('Back: {}'.format([si_format(ops / info.epoch, format_str='{value} {prefix}Ops') for ops in back_ops]))
                print('Update: {}'.format([si_format(ops / info.epoch, format_str='{value} {prefix}Ops') for ops in update_ops]))
        if info.epoch > max(0.5, 2 * test_period) and not swap_mlp and score.get_score('train', 'noise_free') < 20:
            raise Exception("This horse ain't goin' nowhere.")
        op_count_info.append((info, (fwd_ops, back_ops, update_ops)))

    info_score_pairs = train_and_test_online_predictor(
        dataset=dataset,
        train_fcn=net.train.compile(),
        predict_fcn=prediction_funcs,
        minibatch_size=minibatch_size,
        n_epochs=n_epochs,
        test_epochs=('every', test_period),
        score_measure='percent_argmax_correct',
        test_on='training+test',
        test_callback=test_callback)
    return info_score_pairs, op_count_info
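# A rough sketch of the proportional-derivative (PD) encoding idea behind PDEncoderDecoder
# (an assumption about its internals for illustration, not the actual class): signals are
# encoded as a = kp*x + kd*(x - x_prev) and quantized to integer "spike counts" by herding,
# i.e. deterministic rounding with a carried residual.
import numpy as np

class _SketchPDEncoderHerder:

    def __init__(self, kp, kd):
        self.kp, self.kd = kp, kd
        self.x_prev = 0.
        self.residual = 0.  # herding carry-over

    def encode(self, x):
        a = self.kp * x + self.kd * (x - self.x_prev)  # proportional-derivative encoding
        self.x_prev = x
        self.residual += a
        s = np.round(self.residual)  # emit integer counts
        self.residual -= s           # carry the quantization error forward
        return s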
def get_mnist_results_with_parameters(weights, biases, scales=None, hidden_activations='relu',
        output_activation='softmax', n_samples=None, smoothing_steps=1000):
    """
    Return a data structure showing the error and computation required by the original,
    rounding, and sigma-delta implementations of a network with the given parameters.

    :param weights:
    :param biases:
    :param scales:
    :param hidden_activations:
    :param output_activation:
    :param n_samples:
    :param smoothing_steps:
    :return: results: An OrderedDict whose keys are 3-tuples: (dataset_name, subset, net_version), where:
            dataset_name: is 'mnist' or 'temp_mnist'
            subset: is 'train' or 'test'
            net_version: is 'td' or 'round' or 'truth'
        And whose values are another OrderedDict, with keys:
            'MFlops', 'l1_error', 'class_error' ... for discrete nets, and
            'Dense MFlops', 'Sparse MFlops', 'class_error' for "true" nets.
    """
    mnist = get_mnist_dataset(flat=True, n_training_samples=n_samples, n_test_samples=n_samples)
    temp_mnist = get_temporal_mnist_dataset(flat=True, smoothing_steps=smoothing_steps, n_training_samples=n_samples, n_test_samples=n_samples)

    results = OrderedDict()
    p = ProgressIndicator(2 * 3 * 2)
    for dataset_name, (tr_x, tr_y, ts_x, ts_y) in [('mnist', mnist.xyxy), ('temp_mnist', temp_mnist.xyxy)]:
        for subset, x, y in [('train', tr_x, tr_y), ('test', ts_x, ts_y)]:
            traditional_net_output, dense_flops, sparse_flops = forward_pass_and_cost(
                input_data=x, weights=weights, biases=biases,
                hidden_activations=hidden_activations, output_activations=output_activation)
            assert round(dense_flops) == dense_flops and round(sparse_flops) == sparse_flops, 'Flop counts must be int!'
            class_error = percent_argmax_incorrect(traditional_net_output, y)
            results[dataset_name, subset, 'truth'] = OrderedDict([
                ('Dense MFlops', dense_flops / (1e6 * len(x))),
                ('Sparse MFlops', sparse_flops / (1e6 * len(x))),
                ('class_error', class_error)])
            for net_version in 'td', 'round':
                (comp_cost_adds, comp_cost_multiplyadds), output = tdnet_forward_pass_cost_and_output(
                    inputs=x, weights=weights, biases=biases, scales=scales, version=net_version,
                    hidden_activations=hidden_activations, output_activations=output_activation,
                    quantization_method='herd', computation_calc=('adds', 'multiplyadds'))
                l1_error = np.abs(output - traditional_net_output).sum(axis=1).mean(axis=0)
                class_error = percent_argmax_incorrect(output, y)
                results[dataset_name, subset, net_version] = OrderedDict([
                    ('MFlops', comp_cost_adds / (1e6 * len(x))),
                    ('MFlops-multadd', comp_cost_multiplyadds / (1e6 * len(x))),
                    ('l1_error', l1_error),
                    ('class_error', class_error)])
                p.print_update()
    return results
def demo_energy_based_initialize_eq_prop_fwd_energy(
        n_epochs=25,
        hidden_sizes=(500, ),
        minibatch_size=20,
        beta=0.5,
        epsilon=0.5,
        learning_rate=(0.1, .05),
        n_negative_steps=20,
        n_positive_steps=4,
        initial_weight_scale=1.,
        forward_deviation_cost=0.,
        zero_deviation_cost=1,
        epoch_checkpoint_period={0: .25, 1: .5, 5: 1, 10: 2, 50: 4},
        n_test_samples=10000,
        skip_zero_epoch_test=False,
        train_with_forward='contrast',
        forward_nonlinearity='rho(x)',
        local_loss=True,
        random_flip_beta=True,
        seed=1234,
        ):
    print('Params:\n' + '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))
    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)
    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'
    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)
    forward_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)

    y_train = y_train.astype(np.float32)
    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period, skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.partial(forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_inference_eq_step = equilibriating_step.partial(forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_positive_eq_step = equilibriating_step.partial(forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_parameter_update = update_eq_params.partial(forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_forward_pass = forward_pass.partial(nonlinearity=forward_nonlinearity).compile()
    # f_forward_parameter_update = update_forward_params_with_energy.compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        states_ = forward_states_ = f_forward_pass(x=x, params=forward_params_) if train_with_forward \
            else initialize_states(n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_, states=states_, fwd_states=forward_states_, x=x, epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    for i, (ixs, info) in enumerate(minibatch_index_info_generator(n_samples=x_train.shape[0], minibatch_size=minibatch_size, n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(x_test)
            (test_init_error, test_neg_error), (train_init_error, train_neg_error) = [
                [percent_argmax_incorrect(prediction, y[:n_test_samples])
                 for prediction in do_inference(forward_params_=forward_params, eq_params_=eq_params, x=x[:n_test_samples], n_steps=n_negative_steps)]
                for x, y in [(x_test, y_test), (x_train, y_train)]]
            print(f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, '
                  f'Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s')
            results[next, :] = dict(iter=i, epoch=epoch, test_init_error=test_init_error, test_neg_error=test_neg_error,
                                    train_init_error=train_init_error, train_neg_error=train_neg_error)
            yield results
            if epoch > 2 and train_neg_error > 50:
                return

        # The original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = forward_states = f_forward_pass(x=x_data_sample, params=forward_params) if train_with_forward \
            else initialize_states(n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])
        for t in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params, states=states, x=x_data_sample, epsilon=epsilon, fwd_states=forward_states)
        negative_states = states
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params, states=states, x=x_data_sample, y=y_data_sample,
                                        y_pressure=this_beta, epsilon=epsilon, fwd_states=forward_states)
        positive_states = states

        eq_params = f_parameter_update(x=x_data_sample, params=eq_params, negative_states=negative_states,
                                       positive_states=positive_states, fwd_states=forward_states,
                                       learning_rates=learning_rate, beta=this_beta)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params,
                                                                 eq_states=negative_states, learning_rates=learning_rate)
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params,
                                                                 eq_states=positive_states, learning_rates=learning_rate)
        # Note: 'energy' passes the assert above but its update is currently disabled:
        # elif train_with_forward == 'energy':
        #     forward_params = f_forward_parameter_update(x=x_data_sample, forward_params=forward_params, eq_params=eq_params, learning_rates=learning_rate)
        else:
            assert train_with_forward is False
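# A minimal sketch of the idea behind update_forward_params_with_contrast (an assumption
# for illustration, not the actual function): the forward network is trained so that its
# layer activations regress onto the equilibrium ("negative phase") states, giving a cheap
# feedforward initialization of the settling process. Layer-locally:
#
#     loss_l = || rho(pre_l @ w_l + b_l) - s_l ||^2,   with gradients blocked between layers.
import numpy as np

def _sketch_forward_contrast_update(forward_params, x, eq_states, learning_rate):
    rho = lambda u: np.clip(u, 0, 1)  # the 'rho(x)' nonlinearity of the eq-prop papers
    new_params = []
    pre = x
    for (w, b), target in zip(forward_params, eq_states):
        u = pre @ w + b
        dl_du = (rho(u) - target) * ((u > 0) & (u < 1))  # grad of 0.5*||rho(u) - target||^2
        new_params.append((w - learning_rate * pre.T @ dl_du / len(x),
                           b - learning_rate * dl_du.mean(axis=0)))
        pre = target  # feed the equilibrium state forward: gradients stay layer-local
    return new_params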
def demo_energy_based_initialize_eq_prop_alignment(
        n_epochs=25,
        hidden_sizes=(500, ),
        minibatch_size=20,
        beta=0.5,
        epsilon=0.5,
        learning_rate=(0.1, .05),
        n_negative_steps=20,
        n_positive_steps=4,
        initial_weight_scale=1.,
        epoch_checkpoint_period={0: .25, 1: .5, 5: 1, 10: 2, 50: 4},
        n_test_samples=10000,
        skip_zero_epoch_test=False,
        train_with_forward='contrast',
        forward_nonlinearity='rho(x)',
        local_loss=True,
        random_flip_beta=True,
        seed=1234,
        ):
    print('Params:\n' + '\n'.join(list(f'  {k} = {v}' for k, v in locals().items())))
    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)
    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'
    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)
    forward_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)

    y_train = y_train.astype(np.float32)
    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period, skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.compile()
    f_inference_eq_step = equilibriating_step.compile()
    f_positive_eq_step = equilibriating_step.compile()
    f_parameter_update = update_eq_params.compile()
    f_forward_pass = forward_pass.partial(nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_update = update_forward_params_with_energy.partial(disconnect_grads=local_loss, nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(disconnect_grads=local_loss, nonlinearity=forward_nonlinearity).compile()
    f_energy = energy.compile()
    f_grad_align = compute_gradient_alignment.partial(nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        states_ = forward_states_ = f_forward_pass(x=x, params=forward_params_) if train_with_forward \
            else initialize_states(n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_, states=states_, x=x, epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    for i, (ixs, info) in enumerate(minibatch_index_info_generator(n_samples=x_train.shape[0], minibatch_size=minibatch_size, n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            n_samples = n_test_samples if n_test_samples is not None else len(x_test)
            (test_init_error, test_neg_error), (train_init_error, train_neg_error) = [
                [percent_argmax_incorrect(prediction, y[:n_test_samples])
                 for prediction in do_inference(forward_params_=forward_params, eq_params_=eq_params, x=x[:n_test_samples], n_steps=n_negative_steps)]
                for x, y in [(x_test, y_test), (x_train, y_train)]]
            print(f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, '
                  f'Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s')

            # ===== Compute alignment between the forward pass and the equilibrium states
            alignment_neg_states = forward_states_ = f_forward_pass(x=x_train, params=forward_params) if train_with_forward \
                else initialize_states(n_samples=x_train.shape[0], noninput_layer_sizes=layer_sizes[1:])
            for _ in range(n_negative_steps):
                alignment_neg_states = f_inference_eq_step(params=eq_params, states=alignment_neg_states, x=x_train, epsilon=epsilon)
            grad_alignments = f_grad_align(forward_params=forward_params, eq_states=alignment_neg_states, x=x_train)
            dbplot(grad_alignments, 'alignments', plot_type=DBPlotTypes.LINE_HISTORY)
            # --------------------

            results[next, :] = dict(iter=i, epoch=epoch, test_init_error=test_init_error, test_neg_error=test_neg_error,
                                    train_init_error=train_init_error, train_neg_error=train_neg_error, alignments=grad_alignments)
            yield results
            if epoch > 2 and train_neg_error > 50:
                return

        # The original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = forward_states = f_forward_pass(x=x_data_sample, params=forward_params) if train_with_forward \
            else initialize_states(n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])
        for t in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params, states=states, x=x_data_sample, epsilon=epsilon)
        negative_states = states
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params, states=states, x=x_data_sample, y=y_data_sample,
                                        y_pressure=this_beta, epsilon=epsilon)
        positive_states = states

        eq_params = f_parameter_update(x=x_data_sample, params=eq_params, negative_states=negative_states,
                                       positive_states=positive_states, learning_rates=learning_rate, beta=this_beta)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params,
                                                                 eq_states=negative_states, learning_rates=learning_rate)
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(x=x_data_sample, forward_params=forward_params,
                                                                 eq_states=positive_states, learning_rates=learning_rate)
        elif train_with_forward == 'energy':
            forward_params = f_forward_parameter_update(x=x_data_sample, forward_params=forward_params,
                                                        eq_params=eq_params, learning_rates=learning_rate)
        else:
            assert train_with_forward is False
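# A plausible reading of compute_gradient_alignment (a hypothetical sketch; the real
# function is defined elsewhere in this codebase): per-layer cosine similarity between
# the update direction the forward network would take and the eq-prop update direction.
import numpy as np

def _sketch_alignment(forward_grads, eqprop_updates):
    """Both arguments: lists of same-shaped arrays, one per layer."""
    return [float(np.sum(g * u) / (np.linalg.norm(g) * np.linalg.norm(u) + 1e-12))
            for g, u in zip(forward_grads, eqprop_updates)]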
def demo_optimize_mnist_net(hidden_sizes=[200, 200], learning_rate=0.01, n_epochs=100, minibatch_size=10,
        parametrization='log', computation_weights=np.logspace(-6, -3, 8), layerwise_scales=True, show_scales=True,
        hidden_activations='relu', test_every=0.5, output_activation='softmax', error_loss='L1',
        comp_evaluation_calc='multiplyadds', smoothing_steps=1000, seed=1234):

    train_data, train_targets, test_data, test_targets = get_mnist_dataset(flat=True).to_onehot().xyxy

    params = train_conventional_mlp_on_mnist(hidden_sizes=hidden_sizes, hidden_activations=hidden_activations,
                                             output_activation=output_activation, rng=seed)
    weights, biases = params[::2], params[1::2]
    rng = get_rng(seed + 1)

    true_out = forward_pass(input_data=test_data, weights=weights, biases=biases,
                            hidden_activations=hidden_activations, output_activation=output_activation)

    optimized_results = OrderedDict([])
    optimized_results['unoptimized'] = get_mnist_results_with_parameters(
        weights=weights, biases=biases, scales=None, hidden_activations=hidden_activations,
        output_activation=output_activation, smoothing_steps=smoothing_steps)

    set_dbplot_figure_size(15, 10)
    for comp_weight in computation_weights:
        net = CompErrorScaleOptimizer(ws=weights, bs=biases, optimizer=GradientDescent(learning_rate),
                                      comp_weight=comp_weight, layerwise_scales=layerwise_scales,
                                      hidden_activations=hidden_activations, output_activation=output_activation,
                                      parametrization=parametrization, rng=rng)
        f_train = net.train_scales.partial(error_loss=error_loss).compile()
        f_get_scales = net.get_scales.compile()
        for training_minibatch, iter_info in minibatch_iterate_info(train_data, minibatch_size=minibatch_size,
                n_epochs=n_epochs, test_epochs=np.arange(0, n_epochs, test_every)):
            if iter_info.test_now:  # Evaluate the current scales: computation cost and output error.
                ks = f_get_scales()
                print('Epoch %.3g' % (iter_info.epoch, ))
                with hold_dbplots():
                    if show_scales:
                        if layerwise_scales:
                            dbplot(ks, '%s solution_scales' % (comp_weight, ),
                                   plot_type=lambda: LinePlot(plot_kwargs=dict(linewidth=3), make_legend=False,
                                                              axes_update_mode='expand', y_bounds=(0, None)),
                                   axis='solution_scales', xlabel='layer', ylabel='scale')
                        else:
                            for i, k in enumerate(ks):
                                dbplot(k, '%s solution_scales' % (i, ),
                                       plot_type=lambda: LinePlot(plot_kwargs=dict(linewidth=3), make_legend=False,
                                                                  axes_update_mode='expand', y_bounds=(0, None)),
                                       axis='solution_scales', xlabel='layer', ylabel='scale')
                    current_flop_counts, current_outputs = quantized_forward_pass_cost_and_output(
                        test_data, weights=weights, scales=ks, quantization_method='round',
                        hidden_activations=hidden_activations, output_activation=output_activation,
                        computation_calc=comp_evaluation_calc, seed=1234)
                    current_error = np.abs(current_outputs - true_out).mean() / np.abs(true_out).mean()
                    current_class_error = percent_argmax_incorrect(current_outputs, test_targets)
                    if np.isnan(current_error):
                        print('ERROR IS NAN!!!')
                    dbplot((current_flop_counts / 1e6, current_error), '%s error-curve' % (comp_weight, ),
                           axis='error-curve', plot_type='trajectory+', xlabel='MFlops', ylabel='error')
                    dbplot((current_flop_counts / 1e6, current_class_error), '%s class-curve' % (comp_weight, ),
                           axis='class-curve', plot_type='trajectory+', xlabel='MFlops', ylabel='class-error')
            f_train(training_minibatch)
        optimized_results['lambda=%.3g' % (comp_weight, )] = get_mnist_results_with_parameters(
            weights=weights, biases=biases, scales=ks, hidden_activations=hidden_activations,
            output_activation=output_activation, smoothing_steps=smoothing_steps)
    return optimized_results
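# For orientation, the objective that CompErrorScaleOptimizer presumably descends for each
# computation weight lambda (an assumption based on its arguments above, not the actual class):
# a weighted sum of output error and computational cost as a function of the layer scales k,
#
#     J(k) = error_loss(net_k(x), net(x)) + lambda * computation(k)
#
# with k optimized in log-space when parametrization='log', so scales stay positive:
import numpy as np

def _sketch_scale_objective(log_ks, error_fcn, comp_fcn, comp_weight):
    ks = np.exp(log_ks)  # 'log' parametrization: gradient steps on log_ks keep ks > 0
    return error_fcn(ks) + comp_weight * comp_fcn(ks)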