def demo_mnist_logreg(minibatch_size=20, learning_rate=0.01, max_training_samples=None, n_epochs=10, test_epoch_period=0.2):
    """
    Train a Logistic Regressor on the MNIST dataset, report training/test scores throughout training, and return the
    final scores.

    Scores are classification accuracies in percent: 100 * the fraction of samples whose argmax prediction equals the
    label.

    :param minibatch_size: Number of samples per online training step.
    :param learning_rate: Learning rate given to the OnlineLogisticRegressor.
    :param max_training_samples: If not None, passed to get_mnist_dataset as n_training_samples to limit the training set.
    :param n_epochs: Number of passes through the training data.
    :param test_epoch_period: Scores are evaluated ('every', test_epoch_period) epochs.
    :return: A dict {'train': final_training_score, 'test': final_test_score}, both in percent correct.
    """
    x_train, y_train, x_test, y_test = get_mnist_dataset(flat=True, n_training_samples=max_training_samples).xyxy
    predictor = OnlineLogisticRegressor(n_in=784, n_out=10, learning_rate=learning_rate)

    def percent_correct(x, y):
        # Labels are compared with argmax class indices, so y is expected to hold integer class labels.
        return 100 * (np.argmax(predictor.predict(x), axis=1) == y).mean()

    # Train and periodically record scores.
    epoch_scores = []
    for ix, iteration_info in minibatch_index_info_generator(
            n_samples=len(x_train),
            minibatch_size=minibatch_size,
            n_epochs=n_epochs,
            test_epochs=('every', test_epoch_period)):
        if iteration_info.test_now:
            training_score = percent_correct(x_train, y_train)
            test_score = percent_correct(x_test, y_test)
            print('Epoch {epoch}: Test Score: {test}%, Training Score: {train}%'.format(
                epoch=iteration_info.epoch, test=test_score, train=training_score))
            epoch_scores.append((iteration_info.epoch, training_score, test_score))
        predictor.train(x_train[ix], y_train[ix])

    # Plot the learning curve (accuracy, so the interesting range is near the top).
    plt.figure()
    epochs, training_scores, test_scores = zip(*epoch_scores)
    plt.plot(epochs, np.array([training_scores, test_scores]).T)
    plt.xlabel('Epoch')
    plt.ylabel('% Correct')
    plt.legend(['Training Score', 'Test Score'])
    plt.title("Learning Curve")
    plt.ylim(80, 100)
    plt.ion()  # Don't hang on plot
    plt.show()

    # Return the scores from the last evaluation.
    return {'train': epoch_scores[-1][1], 'test': epoch_scores[-1][2]}
def experiment_mnist_eqprop(
        layer_constructor,
        n_epochs=10,
        hidden_sizes=(500, ),
        minibatch_size=20,
        beta=.5,
        random_flip_beta=True,
        learning_rate=.05,
        n_negative_steps=20,
        n_positive_steps=4,
        initial_weight_scale=1.,
        online_checkpoints_period=None,
        epoch_checkpoint_period=.25,
        skip_zero_epoch_test=False,
        n_test_samples=None,
        prop_direction: Union[str, Tuple] = 'neutral',
        bidirectional=True,
        renew_activations=True,
        do_fast_forward_pass=False,
        rebuild_coders=True,
        l2_loss=None,
        splitstream=True,
        seed=1234,
        ):
    """
    Replicate the results of Scellier & Bengio:
        Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
        https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
        https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points over epochs.  So our results should only really
      match for the first epoch.
    - We evaluate training score periodically, rather than online average (however you can see online score by setting
      online_checkpoints_period to something that is not None)

    This function is a generator. At every epoch checkpoint it evaluates the test and training error on at most
    n_test_samples samples of each set. Error is the percentage of samples whose argmax prediction is wrong. It appends
    a row (iter, epoch, train_error, test_error) to the `results` Duck and yields it. If a checkpoint after epoch 2
    reports a training error above 50%, the generator stops.

    Notes:
    - do_fast_forward_pass and rebuild_coders are only printed with the other params; they are not used otherwise.
    - Inference uses the forward component of prop_direction (the first element if it is a tuple), whereas the
      training update receives prop_direction unchanged.
    """
    print('Params:\n' + '\n'.join(list(f' {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in = 784
    n_out = 10
    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    # NOTE(review): re-wrapping an existing rng looks redundant; kept because get_rng's behaviour on an RNG argument
    # is not visible here and removing it could change the random stream.
    rng = get_rng(rng)

    y_train = y_train.astype(np.float32)

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    is_online_checkpoint = Checkpoints(
        online_checkpoints_period, skip_first=skip_zero_epoch_test
    ) if online_checkpoints_period is not None else lambda: False
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period, skip_first=skip_zero_epoch_test)

    results = Duck()

    training_states = initialize_states(
        layer_constructor=layer_constructor,
        n_samples=minibatch_size,
        params=initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng))

    if isinstance(prop_direction, str):
        fwd_prop_direction, backward_prop_direction = prop_direction, prop_direction
    else:
        fwd_prop_direction, backward_prop_direction = prop_direction

    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0], minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            # Evaluate with freshly initialized states that share the current parameters, on (a prefix of) each set.
            y_pred_test, y_pred_train = [
                run_inference(
                    x_data=x[:n_test_samples],
                    states=initialize_states(
                        layer_constructor=layer_constructor,
                        params=[s.params for s in training_states],
                        n_samples=min(len(x), n_test_samples) if n_test_samples is not None else len(x)),
                    n_steps=n_negative_steps,
                    prop_direction=fwd_prop_direction,
                ) for x in (x_test, x_train)
            ]
            test_error = percent_argmax_incorrect(y_pred_test, y_test[:n_test_samples])
            train_error = percent_argmax_incorrect(y_pred_train, y_train[:n_test_samples])
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i, epoch=epoch, train_error=train_error, test_error=test_error)
            yield results
            # Give up on runs that are clearly not learning.
            if epoch > 2 and train_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]
        training_states = run_eqprop_training_update(
            x_data=x_data_sample,
            y_data=y_data_sample,
            layer_states=training_states,
            beta=beta,
            random_flip_beta=random_flip_beta,
            learning_rate=learning_rate,
            layer_constructor=layer_constructor,
            bidirectional=bidirectional,
            l2_loss=l2_loss,
            renew_activations=renew_activations,
            n_negative_steps=n_negative_steps,
            n_positive_steps=n_positive_steps,
            prop_direction=prop_direction,
            splitstream=splitstream,
            rng=rng)
        # Running average of the percent-correct score on the minibatches just trained on.
        this_train_score = ra(percent_argmax_correct(output_from_state(training_states), y_train[ixs]))
        if is_online_checkpoint():
            print(
                f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}'
            )
def experiment_mnist_eqprop_torch(
        layer_constructor: Callable[[int, LayerParams], IDynamicLayer],
        n_epochs=10,
        hidden_sizes=(500, ),
        minibatch_size=10,  # update mini-batch size
        batch_size=500,  # total batch size
        beta=.5,
        random_flip_beta=True,
        learning_rate=.05,
        n_negative_steps=120,
        n_positive_steps=80,
        initial_weight_scale=1.,
        online_checkpoints_period=None,
        epoch_checkpoint_period=1.0,  # other values used: '100s', {0: .25, 1: .5, 5: 1, 10: 2, 50: 4}
        skip_zero_epoch_test=False,
        n_test_samples=10000,
        prop_direction: Union[str, Tuple] = 'neutral',
        bidirectional=True,
        renew_activations=True,
        do_fast_forward_pass=False,
        rebuild_coders=True,
        l2_loss=None,
        splitstream=False,
        seed=1234,
        prediction_inp_size=17,  # prediction input size
        delay=18,  # delay size for the clamped phase
        pred=True,  # if you want to use the prediction
        check_flg=False,
        ):
    """
    Replicate the results of Scellier & Bengio:
        Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation
        https://www.frontiersin.org/articles/10.3389/fncom.2017.00024/full

    Specifically, the train_model demo here:
        https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

    Differences between our code and theirs:
    - We do not keep persistent layer activations tied to data points over epochs.  So our results should only really
      match for the first epoch.
    - We evaluate training score periodically, rather than online average (however you can see online score by setting
      online_checkpoints_period to something that is not None)

    PyTorch variant. Each training iteration consumes `batch_size` samples, which are passed to
    run_eqprop_training_update along with `minibatch_size`.

    This function is a generator. At every epoch checkpoint it:
    - reshuffles the training data;
    - evaluates test/train/validation error (percent of argmax predictions that are wrong, on at most n_test_samples
      samples of each set) and appends it to `results`;
    - appends the errors to `<directory>/log.txt`;
    - saves the weights, biases and `dy_squared` as .npy files;
    - yields `results`.

    A final evaluation is appended and yielded after training ends. The run stops early if a checkpoint after epoch
    100 reports a training error above 50%.
    """
    torch.manual_seed(seed)
    device = 'cuda' if torch.cuda.is_available() and USE_CUDA_WHEN_AVAILABLE else 'cpu'
    if device == 'cuda':
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print(f'Using Device: {device}')
    print('Params:\n' + '\n'.join(list(f' {k} = {v}' for k, v in locals().items())))

    rng = get_rng(seed)
    n_in = 784
    n_out = 10

    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)

    x_train = torch.tensor(dataset.train.images, dtype=torch.float32).to(device)
    y_train = torch.tensor(dataset.train.labels, dtype=torch.float32).to(device)
    # Their 'validation set' is our 'test set'
    x_test = torch.tensor(dataset.test.images, dtype=torch.float32).to(device)
    y_test = torch.tensor(dataset.test.labels, dtype=torch.float32).to(device)
    x_val = torch.tensor(dataset.validation.images, dtype=torch.float32).to(device)
    y_val = torch.tensor(dataset.validation.labels, dtype=torch.float32).to(device)

    if is_test_mode():
        x_train, y_train, x_test, y_test, x_val, y_val = \
            x_train[:100], y_train[:100], x_test[:100], y_test[:100], x_val[:100], y_val[:100]
        n_epochs = 1
        n_negative_steps = 3
        n_positive_steps = 3

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    ra = RunningAverage()
    sp = Speedometer(mode='last')
    is_online_checkpoint = Checkpoints(
        online_checkpoints_period, skip_first=skip_zero_epoch_test
    ) if online_checkpoints_period is not None else lambda: False
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period, skip_first=skip_zero_epoch_test)

    training_states = initialize_states(
        layer_constructor=layer_constructor,
        n_samples=batch_size,  # states are sized for the whole batch, not for minibatch_size
        params=initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng))

    if isinstance(prop_direction, str):
        fwd_prop_direction, backward_prop_direction = prop_direction, prop_direction
    else:
        fwd_prop_direction, backward_prop_direction = prop_direction

    def do_test():
        """
        Evaluate test/train/validation error with fresh states sharing the current parameters.

        Reads `epoch` and `i` from the enclosing loop. Returns (results_row, train_error, test_error, val_error).
        """
        test_error, train_error, val_error = [
            percent_argmax_incorrect(
                run_inference(
                    x_data=x[:n_test_samples],
                    states=initialize_states(
                        layer_constructor=layer_constructor,
                        params=[s.params for s in training_states],
                        n_samples=n_samples),
                    n_steps=n_negative_steps,
                    prop_direction=fwd_prop_direction,
                ),
                y[:n_samples]).item()
            for x, y in [(x_test, y_test), (x_train, y_train), (x_val, y_val)]
            # Single-element loop: binds n_samples inside the comprehension.
            for n_samples in [min(len(x), n_test_samples) if n_test_samples is not None else len(x)]
        ]
        print(
            f'Epoch: {epoch:.3g}, Iter: {i}, Test Error: {test_error:.3g}%, Train Error: {train_error:.3g}%, Validation Error: {val_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
        )
        return dict(iter=i, epoch=epoch, train_error=train_error, test_error=test_error,
                    val_error=val_error), train_error, test_error, val_error

    results = Duck()
    # The loop below takes one step per `batch_size` samples of the (possibly test-mode truncated) training set.
    pi = ProgressIndicator(expected_iterations=n_epochs * x_train.shape[0] / batch_size, update_every='10s')

    dy_squared = [None, None]
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.size()[0], minibatch_size=batch_size,
                                           n_epochs=n_epochs)):
        epoch = i * batch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            # NOTE(review): check_flg is reset to False here and again after every update, so the value passed as
            # epoch_check below is always False. If the intent was to flag the first update after each checkpoint,
            # this should probably be True -- confirm before changing.
            check_flg = False
            x_train, y_train = shuffle_data(x_train, y_train)
            with pi.pause_measurement():
                results[next, :], train_err, test_err, val_err = do_test()

                # Prepare the parameters for saving (all layers after index 0).
                ws, bs = zip(*((s.params.w_aft, s.params.b) for s in training_states[1:]))

                # NOTE(review): `directory` is not defined in this function; presumably a module-level setting.
                # exist_ok=True also handles an existing directory that has no log.txt yet; append mode creates
                # the file when missing.
                os.makedirs(directory, exist_ok=True)
                with open(os.path.join(directory, 'log.txt'), 'a') as f:
                    # NOTE: these values are percent *incorrect*; the "accuracy" wording is kept unchanged in case
                    # anything parses this log.
                    f.write("Epoch: " + str(epoch) + '\n')
                    f.write("accuracy for training: " + str(train_err) + '\n')
                    f.write("accuracy for testing: " + str(test_err) + '\n')
                    f.write("accuracy for validation: " + str(val_err) + '\n')
                np.save(os.path.join(directory, 'w_epoch_' + str(epoch) + '.npy'), ws)
                np.save(os.path.join(directory, 'b_epoch_' + str(epoch) + '.npy'), bs)
                np.save(os.path.join(directory, 'dy_squared_epoch_' + str(epoch) + '.npy'), dy_squared)

                yield results
                # Give up on runs that are clearly not learning.
                if epoch > 100 and results[-1, 'train_error'] > 50:
                    return

        # The Original training loop, just taken out here:
        ixs = ixs.astype(np.int32)  # this is for python version 3.7
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]
        training_states, dy_squared = run_eqprop_training_update(
            x_data=x_data_sample,
            y_data=y_data_sample,
            layer_states=training_states,
            beta=beta,
            random_flip_beta=random_flip_beta,
            learning_rate=learning_rate,
            layer_constructor=layer_constructor,
            bidirectional=bidirectional,
            l2_loss=l2_loss,
            renew_activations=renew_activations,
            n_negative_steps=n_negative_steps,
            n_positive_steps=n_positive_steps,
            prop_direction=prop_direction,
            splitstream=splitstream,
            rng=rng,
            prediction_inp_size=prediction_inp_size,
            delay=delay,
            device=device,
            epoch_check=check_flg,
            epoch=epoch,
            pred=pred,
            batch_size=batch_size,
            minibatch_size=minibatch_size,
            dy_squared=dy_squared)
        check_flg = False
        # NOTE: this running average tracks percent *incorrect*, although the online print calls it "Score".
        this_train_score = ra(percent_argmax_incorrect(output_from_state(training_states), y_train[ixs]))
        if is_online_checkpoint():
            print(
                f'Epoch {epoch:.3g}: Iter {i}: Score {this_train_score:.3g}%: Mean Rate: {sp(i):.2g}'
            )
        pi.print_update(info=f'Epoch: {epoch}')

    # Final evaluation after the last training iteration.
    results[next, :], train_err, test_err, val_err = do_test()
    yield results
def demo_energy_based_initialize_eq_prop_fwd_energy(
        n_epochs=25,
        hidden_sizes=(500, ),
        minibatch_size=20,
        beta=0.5,
        epsilon=0.5,
        learning_rate=(0.1, .05),
        n_negative_steps=20,
        n_positive_steps=4,
        initial_weight_scale=1.,
        forward_deviation_cost=0.,
        zero_deviation_cost=1,
        epoch_checkpoint_period={
            0: .25,
            1: .5,
            5: 1,
            10: 2,
            50: 4
        },
        n_test_samples=10000,
        skip_zero_epoch_test=False,
        train_with_forward='contrast',
        forward_nonlinearity='rho(x)',
        local_loss=True,
        random_flip_beta=True,
        seed=1234,
        ):
    """
    Train an equilibrium-propagation network on MNIST whose equilibrium states start from a separately parameterized
    forward pass.

    If train_with_forward is falsy, the states instead start from initialize_states. The forward states are also
    passed as fwd_states to every equilibration step and to the eq-parameter update. forward_deviation_cost and
    zero_deviation_cost are bound into those compiled functions.

    train_with_forward selects how the forward parameters are trained:
    - 'contrast': update_forward_params_with_contrast against the negative-phase states;
    - 'contrast+': the same update against the positive-phase states;
    - False: forward parameters are not trained.
    'energy' is not implemented in this variant (its update is not compiled), so it is rejected up front.

    This function is a generator. At every epoch checkpoint it reports, for the test and training sets, the error
    (percent argmax-incorrect) of two outputs:
    - "init": the last forward-pass layer;
    - "neg": the last layer after n_negative_steps equilibration steps.
    It appends these to `results` and yields it. If a checkpoint after epoch 2 reports a training neg-error above
    50%, the generator stops.

    local_loss is accepted but unused in this variant.
    """
    print('Params:\n' + '\n'.join(list(f' {k} = {v}' for k, v in locals().items())))

    # The 'energy' forward update is not wired up in this function; fail fast rather than at the first training step.
    assert train_with_forward in ('contrast', 'contrast+', False), \
        f"train_with_forward must be 'contrast', 'contrast+' or False here, got {train_with_forward!r}"

    rng = get_rng(seed)
    n_in = 784
    n_out = 10
    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)
    forward_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period, skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_inference_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_positive_eq_step = equilibriating_step.partial(
        forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_parameter_update = update_eq_params.partial(
        forward_deviation_cost=forward_deviation_cost, zero_deviation_cost=zero_deviation_cost).compile()
    f_forward_pass = forward_pass.partial(nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        # Returns (last layer of the initial/forward states, last layer after n_steps equilibration steps).
        states_ = forward_states_ = f_forward_pass(
            x=x, params=forward_params_
        ) if train_with_forward else initialize_states(
            n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_, states=states_, fwd_states=forward_states_, x=x,
                                          epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0], minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            (test_init_error, test_neg_error), (train_init_error, train_neg_error) = [[
                percent_argmax_incorrect(prediction, y[:n_test_samples])
                for prediction in do_inference(forward_params_=forward_params, eq_params_=eq_params,
                                               x=x[:n_test_samples], n_steps=n_negative_steps)
            ] for x, y in [(x_test, y_test), (x_train, y_train)]]
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )
            results[next, :] = dict(iter=i, epoch=epoch, test_init_error=test_init_error,
                                    test_neg_error=test_neg_error, train_init_error=train_init_error,
                                    train_neg_error=train_neg_error)
            yield results
            # Give up on runs that are clearly not learning.
            if epoch > 2 and train_neg_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = forward_states = f_forward_pass(
            x=x_data_sample, params=forward_params
        ) if train_with_forward else initialize_states(
            n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])

        # Negative (free) phase.
        for t in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params, states=states, x=x_data_sample, epsilon=epsilon,
                                        fwd_states=forward_states)
        negative_states = states

        # Positive (nudged) phase, optionally with a randomly signed beta.
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params, states=states, x=x_data_sample, y=y_data_sample,
                                        y_pressure=this_beta, epsilon=epsilon, fwd_states=forward_states)
        positive_states = states

        eq_params = f_parameter_update(x=x_data_sample, params=eq_params, negative_states=negative_states,
                                       positive_states=positive_states, fwd_states=forward_states,
                                       learning_rates=learning_rate, beta=this_beta)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample, forward_params=forward_params, eq_states=negative_states,
                learning_rates=learning_rate)
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample, forward_params=forward_params, eq_states=positive_states,
                learning_rates=learning_rate)
        else:
            assert train_with_forward is False
def demo_energy_based_initialize_eq_prop_alignment(
        n_epochs=25,
        hidden_sizes=(500, ),
        minibatch_size=20,
        beta=0.5,
        epsilon=0.5,
        learning_rate=(0.1, .05),
        n_negative_steps=20,
        n_positive_steps=4,
        initial_weight_scale=1.,
        epoch_checkpoint_period={
            0: .25,
            1: .5,
            5: 1,
            10: 2,
            50: 4
        },
        n_test_samples=10000,
        skip_zero_epoch_test=False,
        train_with_forward='contrast',
        forward_nonlinearity='rho(x)',
        local_loss=True,
        random_flip_beta=True,
        seed=1234,
        ):
    """
    Train an equilibrium-propagation network on MNIST whose equilibrium states start from a separately parameterized
    forward pass, and track gradient alignment.

    If train_with_forward is falsy, the states instead start from initialize_states. Unlike the fwd-energy variant,
    the forward states are only used for initialization; they are not passed to the equilibration steps or the
    eq-parameter update.

    train_with_forward selects how the forward parameters are trained:
    - 'contrast': update_forward_params_with_contrast against the negative-phase states;
    - 'contrast+': the same update against the positive-phase states;
    - 'energy': update_forward_params_with_energy using the current eq parameters;
    - False: forward parameters are not trained.
    local_loss is passed as disconnect_grads to both forward updates.

    This function is a generator. At every epoch checkpoint it:
    - reports, for the test and training sets, the error (percent argmax-incorrect) of the forward-pass output
      ("init") and of the output after n_negative_steps equilibration steps ("neg");
    - computes compute_gradient_alignment on the full training set and plots it with dbplot;
    - appends everything to `results` and yields it.
    If a checkpoint after epoch 2 reports a training neg-error above 50%, the generator stops.
    """
    print('Params:\n' + '\n'.join(list(f' {k} = {v}' for k, v in locals().items())))

    assert train_with_forward in ('contrast', 'contrast+', 'energy', False)

    rng = get_rng(seed)
    n_in = 784
    n_out = 10
    dataset = get_mnist_dataset(flat=True, n_test_samples=None).to_onehot()
    x_train, y_train = dataset.training_set.xy
    x_test, y_test = dataset.test_set.xy  # Their 'validation set' is our 'test set'

    if is_test_mode():
        x_train, y_train, x_test, y_test = x_train[:100], y_train[:100], x_test[:100], y_test[:100]
        n_epochs = 1

    layer_sizes = [n_in] + list(hidden_sizes) + [n_out]

    eq_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)
    forward_params = initialize_params(layer_sizes=layer_sizes, initial_weight_scale=initial_weight_scale, rng=rng)

    y_train = y_train.astype(np.float32)

    sp = Speedometer(mode='last')
    is_epoch_checkpoint = Checkpoints(epoch_checkpoint_period, skip_first=skip_zero_epoch_test)

    f_negative_eq_step = equilibriating_step.compile()
    f_inference_eq_step = equilibriating_step.compile()
    f_positive_eq_step = equilibriating_step.compile()
    f_parameter_update = update_eq_params.compile()
    f_forward_pass = forward_pass.partial(nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_update = update_forward_params_with_energy.partial(
        disconnect_grads=local_loss, nonlinearity=forward_nonlinearity).compile()
    f_forward_parameter_contrast_update = update_forward_params_with_contrast.partial(
        disconnect_grads=local_loss, nonlinearity=forward_nonlinearity).compile()
    f_grad_align = compute_gradient_alignment.partial(nonlinearity=forward_nonlinearity).compile()

    def do_inference(forward_params_, eq_params_, x, n_steps):
        # Returns (last layer of the initial/forward states, last layer after n_steps equilibration steps).
        states_ = forward_states_ = f_forward_pass(
            x=x, params=forward_params_
        ) if train_with_forward else initialize_states(
            n_samples=x.shape[0], noninput_layer_sizes=layer_sizes[1:])
        for _ in range(n_steps):
            states_ = f_inference_eq_step(params=eq_params_, states=states_, x=x, epsilon=epsilon)
        return forward_states_[-1], states_[-1]

    results = Duck()
    for i, (ixs, info) in enumerate(
            minibatch_index_info_generator(n_samples=x_train.shape[0], minibatch_size=minibatch_size,
                                           n_epochs=n_epochs)):
        epoch = i * minibatch_size / x_train.shape[0]

        if is_epoch_checkpoint(epoch):
            (test_init_error, test_neg_error), (train_init_error, train_neg_error) = [[
                percent_argmax_incorrect(prediction, y[:n_test_samples])
                for prediction in do_inference(forward_params_=forward_params, eq_params_=eq_params,
                                               x=x[:n_test_samples], n_steps=n_negative_steps)
            ] for x, y in [(x_test, y_test), (x_train, y_train)]]
            print(
                f'Epoch: {epoch:.3g}, Iter: {i}, Test Init Error: {test_init_error:.3g}%, Test Neg Error: {test_neg_error:.3g}%, Train Init Error: {train_init_error:.3g}%, Train Neg Error: {train_neg_error:.3g}%, Mean Rate: {sp(i):.3g}iter/s'
            )

            # Gradient alignment between the forward network and the negative-phase equilibrium, computed on the
            # whole training set (not limited by n_test_samples).
            alignment_neg_states = f_forward_pass(
                x=x_train, params=forward_params
            ) if train_with_forward else initialize_states(
                n_samples=x_train.shape[0], noninput_layer_sizes=layer_sizes[1:])
            for _ in range(n_negative_steps):
                alignment_neg_states = f_inference_eq_step(params=eq_params, states=alignment_neg_states,
                                                           x=x_train, epsilon=epsilon)
            grad_alignments = f_grad_align(forward_params=forward_params, eq_states=alignment_neg_states,
                                           x=x_train)
            dbplot(grad_alignments, 'alignments', plot_type=DBPlotTypes.LINE_HISTORY)

            results[next, :] = dict(iter=i, epoch=epoch, test_init_error=test_init_error,
                                    test_neg_error=test_neg_error, train_init_error=train_init_error,
                                    train_neg_error=train_neg_error, alignments=grad_alignments)
            yield results
            # Give up on runs that are clearly not learning.
            if epoch > 2 and train_neg_error > 50:
                return

        # The Original training loop, just taken out here:
        x_data_sample, y_data_sample = x_train[ixs], y_train[ixs]

        states = f_forward_pass(
            x=x_data_sample, params=forward_params
        ) if train_with_forward else initialize_states(
            n_samples=minibatch_size, noninput_layer_sizes=layer_sizes[1:])

        # Negative (free) phase.
        for t in range(n_negative_steps):
            states = f_negative_eq_step(params=eq_params, states=states, x=x_data_sample, epsilon=epsilon)
        negative_states = states

        # Positive (nudged) phase, optionally with a randomly signed beta.
        this_beta = rng.choice([-beta, beta]) if random_flip_beta else beta
        for t in range(n_positive_steps):
            states = f_positive_eq_step(params=eq_params, states=states, x=x_data_sample, y=y_data_sample,
                                        y_pressure=this_beta, epsilon=epsilon)
        positive_states = states

        eq_params = f_parameter_update(x=x_data_sample, params=eq_params, negative_states=negative_states,
                                       positive_states=positive_states, learning_rates=learning_rate,
                                       beta=this_beta)

        if train_with_forward == 'contrast':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample, forward_params=forward_params, eq_states=negative_states,
                learning_rates=learning_rate)
        elif train_with_forward == 'contrast+':
            forward_params = f_forward_parameter_contrast_update(
                x=x_data_sample, forward_params=forward_params, eq_states=positive_states,
                learning_rates=learning_rate)
        elif train_with_forward == 'energy':
            forward_params = f_forward_parameter_update(
                x=x_data_sample, forward_params=forward_params, eq_params=eq_params, learning_rates=learning_rate)
        else:
            assert train_with_forward is False