def loss_lookahead_diff(model: NeuralTeleportationModel, data: Tensor, target: Tensor,
                        metrics: TrainingMetrics, config: OptimalTeleportationTrainingConfig,
                        **kwargs) -> Number:
    """Return how much the loss drops after one lookahead optimizer step.

    The function evaluates the mean loss at the model's current weights, takes
    a single optimizer step along the gradient of that loss, evaluates the mean
    loss again at the updated weights, and finally puts the original weights
    back. The value returned is ``loss_before - loss_after`` (positive when the
    step reduced the loss).

    Args:
        model: Model to evaluate. Its weights are restored before returning.
        data: Iterated in lockstep with ``target``; each item is fed to
            ``model`` as one batch.
        target: Iterated in lockstep with ``data``; each item is passed to
            ``metrics.criterion`` together with the model output.
        metrics: Provides ``criterion``, the loss function.
            NOTE(review): assumes ``criterion`` returns a scalar tensor so the
            per-batch losses can be stacked and averaged -- confirm.
        config: Passed to ``get_optimizer_from_model_and_config`` to build the
            throwaway optimizer used for the lookahead step.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        The difference between the loss before and after the lookahead step,
        as a Python number.
    """
    import copy  # Local import: only needed for the weight snapshot below.

    def _mean_loss() -> Tensor:
        # Average of the criterion over every (data, target) pair.
        return torch.stack([metrics.criterion(model(data_batch), target_batch)
                            for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Snapshot the weights prior to the lookahead. ``state_dict()`` hands back
    # tensors that share storage with the live parameters, and the optimizer
    # updates parameters in place, so the snapshot must be a deep copy or the
    # restore at the end would be a no-op.
    initial_state = copy.deepcopy(model.state_dict())

    try:
        # Initialize a new optimizer dedicated to the lookahead step
        optimizer = get_optimizer_from_model_and_config(model, config)
        optimizer.zero_grad()

        # Compute loss at the current (teleported) point
        loss = _mean_loss()

        # Take a step using the gradient at the teleported point
        loss.backward()
        optimizer.step()

        # Compute loss after the optimizer step; no gradient is needed for it
        with torch.no_grad():
            lookahead_loss = _mean_loss()
    finally:
        # Restore the weights even if the lookahead raised part-way through
        model.load_state_dict(initial_state)

    # Difference between the original loss and the lookahead loss
    return (loss - lookahead_loss).item()
else: net2 = MLPCOB(input_shape=(1, 28, 28), num_classes=10, hidden_layers=hidden_layers).to(device) model1 = NeuralTeleportationModel(network=net1, input_shape=sample_input_shape) if args.weights1 is not None: model1.load_state_dict(torch.load(args.weights1)) config.batch_size = 8 # Change batch size to train to different minima train(model1, train_dataset=mnist_train, metrics=metrics, config=config, val_dataset=mnist_test) torch.save(model1.state_dict(), pjoin(save_path, 'model1.pt')) print("Model 1 test results: ", test(model1, mnist_test, metrics, config)) model2 = NeuralTeleportationModel(network=net2, input_shape=sample_input_shape) if args.weights2 is not None: model2.load_state_dict(torch.load(args.weights2)) config.batch_size = 512 # Change batch size to train to different minima train(model2, train_dataset=mnist_train, metrics=metrics, config=config, val_dataset=mnist_test) torch.save(model2.state_dict(), pjoin(save_path, 'model2.pt')) print("Model 2 test results: ", test(model2, mnist_test, metrics, config))