import copy

import numpy as np
import torch
from torch.utils.data import DataLoader

# Import paths assume the cgnet package layout, in which these utilities
# are re-exported from cgnet.network
from cgnet.network import CGnet, ForceLoss, dataset_loss, lipschitz_projection

# The fixtures arch, dataset, batch_size, and lipschitz_strength are assumed
# to be defined at module level elsewhere in this test file


def test_dataset_loss_with_optimizer_and_regularization():
    # Test manual batch processing vs. dataset_loss during regularized training
    # Make a simple model and test that a manual on-the-fly loss calculation
    # approximately matches the one from dataset_loss when given an optimizer
    # and regularization function

    # Set up the number of training epochs
    num_epochs = 5

    # Empty lists to be compared after training
    epochal_train_losses_manual = []
    epochal_train_losses_dataset = []

    # We require two models and two optimizers to keep things separate
    # The architectures MUST be deep copied or else they are tethered
    # to each other
    model_manual = CGnet(copy.deepcopy(arch), ForceLoss()).float()
    model_dataset = CGnet(copy.deepcopy(arch), ForceLoss()).float()

    optimizer_manual = torch.optim.Adam(model_manual.parameters(), lr=1e-5)
    optimizer_dataset = torch.optim.Adam(model_dataset.parameters(), lr=1e-5)

    # We want a nonrandom loader so we can compare the losses at the end
    nonrandom_loader = DataLoader(dataset, batch_size=batch_size)

    for epoch in range(1, num_epochs + 1):
        train_loss_manual = 0.0

        # This is the manual part
        effective_batch_num = 0

        for batch_num, batch_data in enumerate(nonrandom_loader):
            optimizer_manual.zero_grad()
            coord, force, embedding_property = batch_data

            # Weight each batch by its size relative to the first batch so
            # that a smaller final batch does not skew the epochal average
            if batch_num == 0:
                ref_batch_size = coord.numel()
            batch_weight = coord.numel() / ref_batch_size

            energy, pred_force = model_manual(coord, embedding_property)
            batch_loss = model_manual.criterion(pred_force, force)
            batch_loss.backward()
            optimizer_manual.step()

            # Regularize after each optimizer step, mirroring what
            # dataset_loss does with its regularization function
            lipschitz_projection(model_manual, strength=lipschitz_strength)

            train_loss_manual += batch_loss.detach().cpu() * batch_weight
            effective_batch_num += batch_weight

        train_loss_manual = train_loss_manual / effective_batch_num
        epochal_train_losses_manual.append(train_loss_manual.numpy())

        # This is the dataset_loss part
        train_loss_dataset = dataset_loss(model_dataset, nonrandom_loader,
                                          optimizer_dataset,
                                          _regularization_function)
        epochal_train_losses_dataset.append(train_loss_dataset)

    np.testing.assert_allclose(epochal_train_losses_manual,
                               epochal_train_losses_dataset,
                               rtol=1e-4)
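

# For reference, a minimal sketch of the _regularization_function fixture
# used above. dataset_loss is expected to call this on the model after each
# optimizer step, so it must apply the same Lipschitz projection (with the
# same strength) as the manual loop for the two loss traces to match. The
# signature and default strength here are assumptions for illustration, not
# necessarily the module's actual fixture.
def _regularization_function(model, strength=lipschitz_strength):
    lipschitz_projection(model, strength=strength)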