Пример #1
0
def test_dataset_loss_with_optimizer_and_regularization():
    """Check dataset_loss against a hand-written regularized training loop.

    Two models built the same way are trained for a few epochs on the same
    unshuffled loader: one through an explicit batch loop written out here,
    the other through dataset_loss given an optimizer and a regularization
    function. The per-epoch training losses of both routes must agree to
    within a relative tolerance.
    """
    n_epochs = 5
    learning_rate = 1e-5

    # Per-epoch losses from each route; compared once training is finished
    manual_history = []
    dataset_history = []

    # Each route needs its own model and its own optimizer. The architecture
    # is deep copied per model; otherwise the two models would be tied to
    # the same underlying layers.
    model_manual = CGnet(copy.deepcopy(arch), ForceLoss()).float()
    model_dataset = CGnet(copy.deepcopy(arch), ForceLoss()).float()

    optimizer_manual = torch.optim.Adam(model_manual.parameters(),
                                        lr=learning_rate)
    optimizer_dataset = torch.optim.Adam(model_dataset.parameters(),
                                         lr=learning_rate)

    # No shuffling: both routes must see the batches in the same order for
    # the losses to be comparable
    ordered_loader = DataLoader(dataset, batch_size=batch_size)

    for _ in range(n_epochs):
        # --- Route 1: explicit loop over the batches ---
        weighted_loss_sum = 0.0
        total_weight = 0
        first_batch_numel = None

        for coords, forces, embeddings in ordered_loader:
            optimizer_manual.zero_grad()

            # The first batch of the epoch defines the reference size; every
            # batch is weighted by its size relative to that one, so that a
            # smaller trailing batch contributes proportionally less
            if first_batch_numel is None:
                first_batch_numel = coords.numel()
            weight = coords.numel() / first_batch_numel

            # Only the predicted forces enter the loss; the energy is unused
            _, predicted_forces = model_manual.forward(coords, embeddings)

            loss = model_manual.criterion(predicted_forces, forces)
            loss.backward()
            optimizer_manual.step()

            # Regularize after the optimizer step (presumably the same
            # operation that _regularization_function applies below)
            lipschitz_projection(model_manual, strength=lipschitz_strength)

            weighted_loss_sum = weighted_loss_sum + loss.detach().cpu() * weight
            total_weight = total_weight + weight

        epoch_loss_manual = weighted_loss_sum / total_weight
        manual_history.append(epoch_loss_manual.numpy())

        # --- Route 2: let dataset_loss run the same epoch ---
        epoch_loss_dataset = dataset_loss(model_dataset, ordered_loader,
                                          optimizer_dataset,
                                          _regularization_function)
        dataset_history.append(epoch_loss_dataset)

    np.testing.assert_allclose(manual_history,
                               dataset_history,
                               rtol=1e-4)