def test(args, device, net, test_loader, loss_function):
    """
    Compute the test loss and accuracy on the test dataset.

    Args:
        args: command line inputs
        device: The PyTorch device to be used
        net: network
        test_loader (DataLoader): dataloader object with the test dataset
        loss_function: loss function used to compute the test loss

    Returns:
        Tuple containing:
        - Test accuracy
        - Test loss
    """
    loss = 0
    if args.classification:
        accuracy = 0
    nb_batches = len(test_loader)
    with torch.no_grad():
        for inputs, targets in test_loader:
            if args.double_precision:
                inputs, targets = inputs.double().to(device), targets.to(
                    device)
            else:
                inputs, targets = inputs.to(device), targets.to(device)
            if not args.network_type == 'DDTPConv':
                inputs = inputs.flatten(1, -1)
            if args.classification and \
                    args.output_activation == 'sigmoid':
                # convert targets to one hot vectors for MSE loss:
                targets = utils.int_to_one_hot(targets, 10, device,
                                               soft_target=args.soft_target)
            predictions = net.forward(inputs)
            loss += loss_function(predictions, targets).item()
            if args.classification:
                if args.output_activation == 'sigmoid':
                    accuracy += utils.accuracy(predictions,
                                               utils.one_hot_to_int(targets))
                else:  # softmax
                    accuracy += utils.accuracy(predictions, targets)
    loss /= nb_batches
    if args.classification:
        accuracy /= nb_batches
    else:
        accuracy = None
    return accuracy, loss
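# Illustrative usage sketch (hypothetical helper, not part of the original
# pipeline): shows how `test` is typically invoked. The loss selection
# mirrors the rule used in `train_bp` and `train` below; `args`, `device`,
# `net` and `test_loader` are assumed to come from the surrounding setup code.
def example_evaluate(args, device, net, test_loader):
    if args.classification and args.output_activation == 'softmax':
        loss_function = nn.CrossEntropyLoss()
    else:
        # sigmoid classification and regression both use an MSE loss above
        loss_function = nn.MSELoss()
    accuracy, loss = test(args, device, net, test_loader, loss_function)
    if accuracy is not None:
        print('test acc = {}%, test loss = {}'.format(accuracy * 100, loss))
    else:
        print('test loss = {}'.format(loss))
    return accuracy, loss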
def train_bp(args, device, train_loader, net, writer, test_loader, summary,
             val_loader):
    print('Training network ...')
    net.train()
    forward_optimizer = utils.OptimizerList(args, net)

    nb_batches = len(train_loader)

    if args.classification:
        if args.output_activation == 'softmax':
            loss_function = nn.CrossEntropyLoss()
        elif args.output_activation == 'sigmoid':
            loss_function = nn.MSELoss()
        else:
            raise ValueError('The mnist dataset can only be combined with a '
                             'sigmoid or softmax output activation.')
    elif args.regression:
        loss_function = nn.MSELoss()
    else:
        raise ValueError('The provided dataset {} is not supported.'.format(
            args.dataset))

    epoch_losses = np.array([])
    epoch_reconstruction_losses = np.array([])
    epoch_reconstruction_losses_var = np.array([])
    test_losses = np.array([])
    val_losses = np.array([])

    val_loss = None
    val_accuracy = None

    if args.classification:
        epoch_accuracies = np.array([])
        test_accuracies = np.array([])
        val_accuracies = np.array([])

    if args.output_space_plot:
        forward_optimizer.zero_grad()
        val_loader_iter = iter(val_loader)
        (inputs, targets) = next(val_loader_iter)
        inputs, targets = inputs.to(device), targets.to(device)
        if args.classification:
            if args.output_activation == 'sigmoid':
                targets = utils.int_to_one_hot(targets, 10, device,
                                               soft_target=1.)
            else:
                raise utils.NetworkError(
                    "output space plot for classification "
                    "tasks is only possible with sigmoid "
                    "output layer.")
        utils.make_plot_output_space_bp(args, net,
                                        args.output_space_plot_layer_idx,
                                        loss_function,
                                        targets,
                                        inputs,
                                        steps=20)
        return summary

    for e in range(args.epochs):
        if args.classification:
            running_accuracy = 0
        else:
            running_accuracy = None
        running_loss = 0

        for i, (inputs, targets) in enumerate(train_loader):
            if args.double_precision:
                inputs, targets = inputs.double().to(device), targets.to(
                    device)
            else:
                inputs, targets = inputs.to(device), targets.to(device)
            if not args.network_type == 'BPConv':
                inputs = inputs.flatten(1, -1)
            if args.classification and \
                    args.output_activation == 'sigmoid':
                # convert targets to one hot vectors for MSE loss:
                targets = utils.int_to_one_hot(targets, 10, device,
                                               soft_target=args.soft_target)

            forward_optimizer.zero_grad()
            predictions = net(inputs)
            loss = loss_function(predictions, targets)
            loss.backward()
            forward_optimizer.step()

            running_loss += loss.item()
            if args.classification:
                if args.output_activation == 'sigmoid':
                    running_accuracy += utils.accuracy(
                        predictions, utils.one_hot_to_int(targets))
                else:  # softmax
                    running_accuracy += utils.accuracy(predictions, targets)

        test_accuracy, test_loss = test_bp(args, device, net, test_loader,
                                           loss_function)
        if not args.no_val_set:
            val_accuracy, val_loss = test_bp(args, device, net, val_loader,
                                             loss_function)

        epoch_loss = running_loss / nb_batches
        if args.classification:
            epoch_accuracy = running_accuracy / nb_batches
        else:
            epoch_accuracy = None

        print('Epoch {} -- training loss = {}.'.format(e + 1, epoch_loss))
        if not args.no_val_set:
            print('Epoch {} -- val loss = {}.'.format(e + 1, val_loss))
        print('Epoch {} -- test loss = {}.'.format(e + 1, test_loss))
        if args.classification:
            print('Epoch {} -- training acc = {}%'.format(
                e + 1, epoch_accuracy * 100))
            if not args.no_val_set:
                print('Epoch {} -- val acc = {}%'.format(
                    e + 1, val_accuracy * 100))
            print('Epoch {} -- test acc = {}%'.format(e + 1,
                                                      test_accuracy * 100))

        if args.save_logs:
            utils.save_logs(writer, step=e + 1, net=net,
                            loss=epoch_loss, accuracy=epoch_accuracy,
                            test_loss=test_loss, test_accuracy=test_accuracy,
                            val_loss=val_loss, val_accuracy=val_accuracy)

        epoch_losses = np.append(epoch_losses, epoch_loss)
        test_losses = np.append(test_losses, test_loss)
        if not args.no_val_set:
            val_losses = np.append(val_losses, val_loss)

        if args.classification:
            epoch_accuracies = np.append(epoch_accuracies, epoch_accuracy)
            test_accuracies = np.append(test_accuracies, test_accuracy)
            if not args.no_val_set:
                val_accuracies = np.append(val_accuracies, val_accuracy)

        utils.save_summary_dict(args, summary)

        if e > 4:
            # stop unpromising runs
            if args.dataset in ['mnist', 'fashion_mnist']:
                if epoch_accuracy < 0.4:
                    # error code to indicate pruned run
                    print('writing error code -1')
                    summary['finished'] = -1
                    break
            if args.dataset in ['cifar10']:
                if epoch_accuracy < 0.25:
                    # error code to indicate pruned run
                    print('writing error code -1')
                    summary['finished'] = -1
                    break

    if not args.epochs == 0:
        # save training summary results in summary dict
        summary['loss_train_last'] = epoch_loss
        summary['loss_test_last'] = test_loss
        summary['loss_train_best'] = epoch_losses.min()
        summary['loss_test_best'] = test_losses.min()
        summary['loss_train'] = epoch_losses
        summary['loss_test'] = test_losses
        if not args.no_val_set:
            summary['loss_val_last'] = val_loss
            summary['loss_val_best'] = val_losses.min()
            summary['loss_val'] = val_losses
            # pick the epoch with best validation loss and save the
            # corresponding test loss
            best_epoch = val_losses.argmin()
            summary['epoch_best_loss'] = best_epoch
            summary['loss_test_val_best'] = test_losses[best_epoch]
            summary['loss_train_val_best'] = epoch_losses[best_epoch]

        if args.classification:
            summary['acc_train_last'] = epoch_accuracy
            summary['acc_test_last'] = test_accuracy
            summary['acc_train_best'] = epoch_accuracies.max()
            summary['acc_test_best'] = test_accuracies.max()
            summary['acc_train'] = epoch_accuracies
            summary['acc_test'] = test_accuracies
            if not args.no_val_set:
                summary['acc_val'] = val_accuracies
                summary['acc_val_last'] = val_accuracy
                summary['acc_val_best'] = val_accuracies.max()
                # pick the epoch with best validation acc and save the
                # corresponding test acc
                best_epoch = val_accuracies.argmax()
                summary['epoch_best_acc'] = best_epoch
                summary['acc_test_val_best'] = test_accuracies[best_epoch]
                summary['acc_train_val_best'] = epoch_accuracies[best_epoch]

    utils.save_summary_dict(args, summary)

    print('Training network ... Done')
    return summary
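# Minimal driver sketch (all helper names here are hypothetical, for
# illustration only): how `train_bp` could be wired up from a main script.
# `build_network` and `build_loaders` stand in for whatever model and data
# construction the surrounding codebase provides; `SummaryWriter` is the
# TensorboardX writer the docstrings below refer to, assumed imported at the
# top of this file.
def example_run_bp(args, build_network, build_loaders):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = build_network(args).to(device)
    train_loader, val_loader, test_loader = build_loaders(args)
    writer = SummaryWriter(log_dir=args.out_dir) if args.save_logs else None
    summary = {'finished': 0}
    return train_bp(args, device, train_loader, net, writer, test_loader,
                    summary, val_loader)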
def train_separate(args, train_var, device, train_loader, net, writer):
    """
    Train the given network on the given training dataset with DTP. For each
    epoch, the feedback weights are first trained on the whole epoch, after
    which the forward weights are trained on the same epoch (similar to
    Lee et al., 2015).

    Args:
        args (Namespace): The command-line arguments.
        train_var (Namespace): Structure containing training variables
        device: The PyTorch device to be used
        train_loader (torch.utils.data.DataLoader): The data handler for
            training data
        net (LeeDTPNetwork): The neural network
        writer (SummaryWriter): TensorboardX summary writer to save logs
    """
    # Train the feedback parameters on the whole training epoch.
    if not args.freeze_fb_weights:
        for i, (inputs, targets) in enumerate(train_loader):
            if args.double_precision:
                inputs, targets = inputs.double().to(device), targets.to(
                    device)
            else:
                inputs, targets = inputs.to(device), targets.to(device)
            if not args.network_type == 'DDTPConv':
                inputs = inputs.flatten(1, -1)

            predictions = net.forward(inputs)

            train_feedback_parameters(args, net, train_var.feedback_optimizer)

            train_var.reconstruction_losses = np.append(
                train_var.reconstruction_losses,
                net.get_av_reconstruction_loss())

            if args.save_logs and i % args.log_interval == 0:
                utils.save_feedback_batch_logs(args, writer,
                                               train_var.batch_idx_fb, net)
                train_var.batch_idx_fb += 1

    # Train the forward parameters on the whole training epoch.
    for i, (inputs, targets) in enumerate(train_loader):
        if args.double_precision:
            inputs, targets = inputs.double().to(device), targets.to(device)
        else:
            inputs, targets = inputs.to(device), targets.to(device)
        if not args.network_type == 'DDTPConv':
            inputs = inputs.flatten(1, -1)
        if args.classification and \
                args.output_activation == 'sigmoid':
            # convert targets to one hot vectors for MSE loss:
            targets = utils.int_to_one_hot(targets, 10, device,
                                           soft_target=args.soft_target)

        predictions = net.forward(inputs)

        train_var.batch_accuracy, train_var.batch_loss = \
            train_forward_parameters(args, net, predictions, targets,
                                     train_var.loss_function,
                                     train_var.forward_optimizer)

        if args.classification:
            train_var.accuracies = np.append(train_var.accuracies,
                                             train_var.batch_accuracy)
        train_var.losses = np.append(train_var.losses,
                                     train_var.batch_loss.item())

        for l, layer in enumerate(net.layers):
            loss_rec = layer.reconstruction_loss
            if loss_rec is not None and args.plots is not None:
                net.reconstruction_loss.at[train_var.epochs, l] = loss_rec

        if args.save_logs and i % args.log_interval == 0:
            utils.save_forward_batch_logs(args, writer, train_var.batch_idx,
                                          net, train_var.batch_loss,
                                          predictions)
            train_var.batch_idx += 1

        # update the forward parameters
        if not args.freeze_forward_weights:
            if args.train_randomized:
                # FIXME: correct if-else statement to include randomized
                # updates.
                raise NotImplementedError(
                    'The randomized version of the algorithm '
                    'is not yet implemented. Select the '
                    'correct layer to optimize with '
                    'forward_optimizer.step(i).')
            train_var.forward_optimizer.step()
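# Sketch of the `train_var` struct that `train_separate` (and its parallel
# counterpart) expects, distilled from how `train` below constructs it; handy
# when exercising this function in isolation, e.g. in a unit test. The
# function name is hypothetical; every field shown is read or written above.
def example_make_train_var(args, net, summary):
    train_var = Namespace()
    train_var.summary = summary
    train_var.forward_optimizer, train_var.feedback_optimizer = \
        utils.choose_optimizer(args, net)
    train_var.loss_function = (nn.CrossEntropyLoss()
                               if args.output_activation == 'softmax'
                               else nn.MSELoss())
    train_var.batch_idx = 1      # log step for forward-parameter training
    train_var.batch_idx_fb = 1   # log step for feedback-parameter training
    train_var.epochs = 0         # current epoch index
    train_var.accuracies = np.array([])
    train_var.losses = np.array([])
    train_var.reconstruction_losses = np.array([])
    return train_var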
def train(args, device, train_loader, net, writer, test_loader, summary,
          val_loader):
    """
    Train the given network on the given training dataset with DTP.

    Args:
        args (Namespace): The command-line arguments.
        device: The PyTorch device to be used
        train_loader (torch.utils.data.DataLoader): The data handler for
            training data
        net (DTPNetwork): The neural network
        writer (SummaryWriter): TensorboardX summary writer to save logs
        test_loader (DataLoader): The data handler for the test data
        summary (dict): summary dictionary with the performance measures of
            the training and testing
        val_loader (torch.utils.data.DataLoader): The data handler for the
            validation data
    """
    print('Training network ...')
    net.train()
    if args.save_weights:
        forward_parameters = net.get_forward_parameter_list()
        filename = os.path.join(args.out_dir, 'weights.pickle')
        with open(filename, 'wb') as f:
            pickle.dump(forward_parameters, f)
    if args.load_weights:
        filename = os.path.join(args.out_dir, 'weights.pickle')
        with open(filename, 'rb') as f:
            forward_parameters_loaded = pickle.load(f)
        for i in range(len(forward_parameters_loaded)):
            net.layers[i]._weights = forward_parameters_loaded[i]

    # Simple struct that contains the relevant training variables.
    train_var = Namespace()
    train_var.summary = summary
    train_var.forward_optimizer, train_var.feedback_optimizer = \
        utils.choose_optimizer(args, net)

    if args.classification:
        if args.output_activation == 'softmax':
            train_var.loss_function = nn.CrossEntropyLoss()
        elif args.output_activation == 'sigmoid':
            train_var.loss_function = nn.MSELoss()
        else:
            raise ValueError('The mnist dataset can only be combined with a '
                             'sigmoid or softmax output activation.')
    elif args.regression:
        train_var.loss_function = nn.MSELoss()
    else:
        raise ValueError('The provided dataset {} is not supported.'.format(
            args.dataset))

    train_var.batch_idx = 1
    train_var.batch_idx_fb = 1
    train_var.init_idx = 1

    train_var.epoch_losses = np.array([])
    train_var.epoch_reconstruction_losses = np.array([])
    train_var.epoch_reconstruction_losses_var = np.array([])
    train_var.test_losses = np.array([])
    train_var.val_losses = np.array([])

    train_var.val_loss = None
    train_var.val_accuracy = None

    if args.classification:
        train_var.epoch_accuracies = np.array([])
        train_var.test_accuracies = np.array([])
        train_var.val_accuracies = np.array([])

    if args.epochs_fb == 0 or args.freeze_fb_weights:
        print('No initial training of feedback weights.')
    else:
        print('Training the feedback weights ...')

    av_reconstruction_loss_init = -1
    train_var.summary['rec_loss_first'] = -1
    train_var.epochs_init = 0
    for e_fb in range(args.epochs_fb):
        # Train the feedback weights before starting the real training, such
        # that they are aligned with the pseudoinverse of the forward weights.
        train_var.epochs_init = e_fb
        train_var.reconstruction_losses_init = np.array([])

        train_only_feedback_parameters(args, train_var, device, train_loader,
                                       net, writer)
        av_reconstruction_loss_init = np.mean(
            train_var.reconstruction_losses_init)
        if e_fb == 0:
            train_var.summary['rec_loss_first'] = av_reconstruction_loss_init
        print('init epoch {}, reconstruction loss: {}'.format(
            e_fb + 1, av_reconstruction_loss_init))

    # save the final reconstruction loss of the initialization process
    print(f'Initialization feedback weights done after {args.epochs_fb} '
          f'epochs.')
    print(f'Reconstruction loss: {av_reconstruction_loss_init}')
    train_var.summary['rec_loss_init'] = av_reconstruction_loss_init
    train_var.summary['rec_loss_init_combined'] = \
        0.5 * (train_var.summary['rec_loss_init'] +
               train_var.summary['rec_loss_first'])

    if args.train_only_feedback_parameters:
        print('Terminating training')
        return train_var.summary

    if args.output_space_plot:
        train_var.forward_optimizer.zero_grad()
        val_loader_iter = iter(val_loader)
        (inputs, targets) = next(val_loader_iter)
        inputs, targets = inputs.to(device), targets.to(device)
        if args.classification:
            if args.output_activation == 'sigmoid':
                targets = utils.int_to_one_hot(targets, 10, device,
                                               soft_target=1.)
            else:
                raise utils.NetworkError(
                    "output space plot for classification "
                    "tasks is only possible with sigmoid "
                    "output layer.")
        utils.make_plot_output_space(args, net,
                                     args.output_space_plot_layer_idx,
                                     train_var.loss_function,
                                     targets,
                                     inputs,
                                     steps=20)
        return train_var.summary

    train_var.epochs = 0
    for e in range(args.epochs):
        train_var.epochs = e
        if args.classification:
            train_var.accuracies = np.array([])
        train_var.losses = np.array([])
        train_var.reconstruction_losses = np.array([])

        if not args.train_separate:
            train_parallel(args, train_var, device, train_loader, net, writer)
        else:
            train_separate(args, train_var, device, train_loader, net, writer)

        if not args.freeze_fb_weights:
            for extra_e in range(args.extra_fb_epochs):
                train_only_feedback_parameters(args, train_var, device,
                                               train_loader, net, writer,
                                               log=False)

        train_var.test_accuracy, train_var.test_loss = \
            test(args, device, net, test_loader, train_var.loss_function)
        if not args.no_val_set:
            train_var.val_accuracy, train_var.val_loss = \
                test(args, device, net, val_loader, train_var.loss_function)

        # print intermediate results
        train_var.epoch_loss = np.mean(train_var.losses)
        print('Epoch {} -- training loss = {}.'.format(
            e + 1, train_var.epoch_loss))
        if not args.no_val_set:
            print('Epoch {} -- val loss = {}.'.format(
                e + 1, train_var.val_loss))
        print('Epoch {} -- test loss = {}.'.format(
            e + 1, train_var.test_loss))
        if args.classification:
            train_var.epoch_accuracy = np.mean(train_var.accuracies)
            print('Epoch {} -- training acc = {}%'.format(
                e + 1, train_var.epoch_accuracy * 100))
            if not args.no_val_set:
                print('Epoch {} -- val acc = {}%'.format(
                    e + 1, train_var.val_accuracy * 100))
            print('Epoch {} -- test acc = {}%'.format(
                e + 1, train_var.test_accuracy * 100))
        else:
            train_var.epoch_accuracy = None

        if args.save_logs:
            utils.save_logs(writer, step=e + 1, net=net,
                            loss=train_var.epoch_loss,
                            accuracy=train_var.epoch_accuracy,
                            test_loss=train_var.test_loss,
                            val_loss=train_var.val_loss,
                            test_accuracy=train_var.test_accuracy,
                            val_accuracy=train_var.val_accuracy)

        # save epoch results in summary dict
        train_var.epoch_losses = np.append(train_var.epoch_losses,
                                           train_var.epoch_loss)
        train_var.test_losses = np.append(train_var.test_losses,
                                          train_var.test_loss)
        if not args.no_val_set:
            train_var.val_losses = np.append(train_var.val_losses,
                                             train_var.val_loss)
        if not args.freeze_fb_weights:
            av_epoch_reconstruction_loss = np.mean(
                train_var.reconstruction_losses)
            var_epoch_reconstruction_loss = np.var(
                train_var.reconstruction_losses)
            train_var.epoch_reconstruction_losses = np.append(
                train_var.epoch_reconstruction_losses,
                av_epoch_reconstruction_loss)
            train_var.epoch_reconstruction_losses_var = np.append(
                train_var.epoch_reconstruction_losses_var,
                var_epoch_reconstruction_loss)

        if args.classification:
            train_var.epoch_accuracies = np.append(
                train_var.epoch_accuracies, train_var.epoch_accuracy)
            train_var.test_accuracies = np.append(train_var.test_accuracies,
                                                  train_var.test_accuracy)
            if not args.no_val_set:
                train_var.val_accuracies = np.append(
                    train_var.val_accuracies, train_var.val_accuracy)

        utils.save_summary_dict(args, train_var.summary)

        if e > 4 and (not args.evaluate):
            # stop unpromising runs
            if args.dataset in ['mnist', 'fashion_mnist']:
                if train_var.epoch_accuracy < 0.4:
                    # error code to indicate pruned run
                    print('writing error code -1')
                    train_var.summary['finished'] = -1
                    break
            if args.dataset in ['cifar10']:
                if train_var.epoch_accuracy < 0.25:
                    # error code to indicate pruned run
                    print('writing error code -1')
                    train_var.summary['finished'] = -1
                    break

        # do a small gridsearch to find the damping constant for GNT angles
        if e == 2:
            if args.gn_damping_hpsearch:
                print('Doing hpsearch for finding ideal GN damping constant '
                      'for computing the angle with GNT updates')
                gn_damping = gn_damping_hpsearch(args, train_var, device,
                                                 train_loader, net, writer)
                args.gn_damping = gn_damping
                print('Damping constants GNT angles: {}'.format(gn_damping))
                train_var.summary['gn_damping_values'] = gn_damping
                return train_var.summary

    if not args.epochs == 0:
        # save training summary results in summary dict
        train_var.summary['loss_train_last'] = train_var.epoch_loss
        train_var.summary['loss_test_last'] = train_var.test_loss
        train_var.summary['loss_train_best'] = train_var.epoch_losses.min()
        train_var.summary['loss_test_best'] = train_var.test_losses.min()
        train_var.summary['loss_train'] = train_var.epoch_losses
        train_var.summary['loss_test'] = train_var.test_losses
        train_var.summary['rec_loss'] = train_var.reconstruction_losses
        if not args.no_val_set:
            train_var.summary['loss_val_last'] = train_var.val_loss
            train_var.summary['loss_val_best'] = train_var.val_losses.min()
            train_var.summary['loss_val'] = train_var.val_losses

        # pick the epoch with best validation loss and save the corresponding
        # test loss
        if not args.no_val_set:
            best_epoch = train_var.val_losses.argmin()
            train_var.summary['epoch_best_loss'] = best_epoch
            train_var.summary['loss_test_val_best'] = \
                train_var.test_losses[best_epoch]
            train_var.summary['loss_train_val_best'] = \
                train_var.epoch_losses[best_epoch]

        if not args.freeze_fb_weights:
            train_var.summary['rec_loss_last'] = av_epoch_reconstruction_loss
            train_var.summary['rec_loss_best'] = \
                train_var.epoch_reconstruction_losses.min()
            train_var.summary['rec_loss_var_av'] = np.mean(
                train_var.epoch_reconstruction_losses_var)

        if args.classification:
            train_var.summary['acc_train_last'] = train_var.epoch_accuracy
            train_var.summary['acc_test_last'] = train_var.test_accuracy
            train_var.summary['acc_train_best'] = \
                train_var.epoch_accuracies.max()
            train_var.summary['acc_test_best'] = \
                train_var.test_accuracies.max()
            train_var.summary['acc_train_growth'] = \
                train_var.epoch_accuracies[-1] - train_var.epoch_accuracies[0]
            train_var.summary['acc_test_growth'] = \
                train_var.test_accuracies[-1] - train_var.test_accuracies[0]
            train_var.summary['acc_train'] = train_var.epoch_accuracies
            train_var.summary['acc_test'] = train_var.test_accuracies
            if not args.no_val_set:
                train_var.summary['acc_val'] = train_var.val_accuracies
                train_var.summary['acc_val_last'] = train_var.val_accuracy
                train_var.summary['acc_val_best'] = \
                    train_var.val_accuracies.max()
                # pick the epoch with best validation acc and save the
                # corresponding test acc
                best_epoch = train_var.val_accuracies.argmax()
                train_var.summary['epoch_best_acc'] = best_epoch
                train_var.summary['acc_test_val_best'] = \
                    train_var.test_accuracies[best_epoch]
                train_var.summary['acc_train_val_best'] = \
                    train_var.epoch_accuracies[best_epoch]

    utils.save_summary_dict(args, train_var.summary)

    print('Training network ... Done')
    return train_var.summary
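# Illustrative post-processing sketch (hypothetical, not part of the repo):
# how the summary dict returned by `train` or `train_bp` might be inspected
# after a run. The keys match the ones written above; the val-based keys only
# exist when a validation set was used (not args.no_val_set).
def example_report(summary):
    print('best test loss: {}'.format(summary['loss_test_best']))
    if 'loss_test_val_best' in summary:
        print('test loss at best val epoch: {}'.format(
            summary['loss_test_val_best']))
    if 'acc_test_best' in summary:
        print('best test acc: {}%'.format(summary['acc_test_best'] * 100))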