예제 #1
0
def loss_lookahead_diff(model: NeuralTeleportationModel, data: Tensor, target: Tensor,
                        metrics: TrainingMetrics, config: OptimalTeleportationTrainingConfig, **kwargs) -> Number:
    """Return the loss decrease obtained by one lookahead optimizer step.

    The loss is evaluated at the model's current weights, a single optimizer
    step is taken along the resulting gradient, and the loss is evaluated
    again on the same data. The model's weights are restored before returning.

    Args:
        model: Model to evaluate. Its weights are modified temporarily and
            restored from a snapshot before the function returns.
        data: Iterable of input batches; iterated in lockstep with ``target``.
        target: Iterable of target batches matching ``data``.
        metrics: Object whose ``criterion(output, target)`` computes the loss
            for one batch.
        config: Training configuration handed to
            ``get_optimizer_from_model_and_config`` to build the throwaway
            lookahead optimizer.
        **kwargs: Accepted and ignored.

    Returns:
        ``loss - lookahead_loss`` as a Python number; positive when the
        lookahead step lowered the loss.
    """
    # Local import so the block does not rely on a file-level import of `copy`.
    from copy import deepcopy

    def _mean_loss() -> Tensor:
        # Average of the per-batch criterion values over all (data, target) pairs.
        return torch.stack([metrics.criterion(model(data_batch), target_batch)
                            for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Snapshot the weights prior to the lookahead. `state_dict()` hands back
    # references to the live tensors by default, and the optimizer updates
    # them in place, so a deep copy is needed for the restore below to
    # actually undo the step.
    # NOTE(review): if NeuralTeleportationModel.state_dict() already returns
    # copies, this deepcopy is merely redundant.
    state_dict = deepcopy(model.state_dict())

    # Initialize a new optimizer to perform lookahead
    optimizer = get_optimizer_from_model_and_config(model, config)
    optimizer.zero_grad()

    # Compute loss at the teleported point
    loss = _mean_loss()

    try:
        # Take a step using the gradient at the teleported point
        loss.backward()
        optimizer.step()

        # Compute loss after the optimizer step; no graph is needed since this
        # value is only read, never back-propagated.
        with torch.no_grad():
            lookahead_loss = _mean_loss()
    finally:
        # Restore the state of the model prior to the lookahead, even if the
        # step or the second forward pass raised.
        model.load_state_dict(state_dict)

    # Compute the difference between the lookahead loss and the original loss
    return (loss - lookahead_loss).item()
예제 #2
0
    hidden_layers = (128, 10)

    net1 = MLPCOB(input_shape=(1, 28, 28),
                  num_classes=10,
                  hidden_layers=hidden_layers).to(device)
    if args.same_init:
        net2 = deepcopy(net1)
    else:
        net2 = MLPCOB(input_shape=(1, 28, 28),
                      num_classes=10,
                      hidden_layers=hidden_layers).to(device)

    model1 = NeuralTeleportationModel(network=net1,
                                      input_shape=sample_input_shape)
    if args.weights1 is not None:
        model1.load_state_dict(torch.load(args.weights1))
    config.batch_size = 8  # Change batch size to train to different minima
    train(model1,
          train_dataset=mnist_train,
          metrics=metrics,
          config=config,
          val_dataset=mnist_test)
    torch.save(model1.state_dict(), pjoin(save_path, 'model1.pt'))
    print("Model 1 test results: ", test(model1, mnist_test, metrics, config))

    model2 = NeuralTeleportationModel(network=net2,
                                      input_shape=sample_input_shape)
    if args.weights2 is not None:
        model2.load_state_dict(torch.load(args.weights2))
    config.batch_size = 512  # Change batch size to train to different minima
    train(model2,