Example #1
    # Training Loop
    model.train()
    for iteration, batch in enumerate(tqdm(train_dataloader)):
        # Reset gradients back to zero for this iteration
        optimizer.zero_grad()
        # Move batch to device
        _, batch = batch  # Returns key, value for each Pokemon
        batch = batch.to(device)
        # Run our model & get outputs
        reconstructed = model(batch)
        # Calculate reconstruction loss
        batch_loss, loss_dict = loss.mse_ssim_loss(
            batch,
            reconstructed,
            use_sum=False,
            ssim_module=ssim_module,
            mse_weight=mse_weight,
            ssim_weight=ssim_weight,
        )
        # Backprop
        batch_loss.backward()
        # Update our optimizer parameters
        optimizer.step()
        # Add the batch's loss to the total loss for the epoch
        train_loss += loss_dict["MSE"] + loss_dict["SSIM"]

    # Validation Loop
    model.eval()
    with torch.no_grad():
        for iteration, batch in enumerate(tqdm(val_dataloader)):
            # Move batch to device
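All of the examples on this page call loss.mse_ssim_loss but never show its body. A minimal sketch of what such a helper could look like, assuming a pytorch_msssim-style ssim_module and the mse_weight / ssim_weight arguments seen above (illustrative only, not the project's actual implementation):

    import torch
    import torch.nn.functional as F

    def mse_ssim_loss(output, target, use_sum=False, ssim_module=None,
                      mse_weight=1.0, ssim_weight=1.0):
        # Weighted MSE term; "sum" mirrors the use_sum flag in the loops above
        reduction = "sum" if use_sum else "mean"
        mse = mse_weight * F.mse_loss(output, target, reduction=reduction)
        # SSIM returns similarity in [0, 1]; 1 - SSIM turns it into a loss
        ssim = torch.zeros((), device=output.device)
        if ssim_module is not None:
            ssim = ssim_weight * (1.0 - ssim_module(output, target))
        total = mse + ssim
        return total, {"MSE": mse.item(), "SSIM": ssim.item()}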
Example #2
                # Backprop
                batch_loss.backward()
                # Add the batch's loss to the total loss for the epoch
                train_fusion_loss += batch_loss.item()
                train_fusion_recon_loss += loss_dict["MSE"] + loss_dict["SSIM"]
                train_fusion_kl_d += loss_dict["KL Divergence"]

            if fusion_mode == "decoder" or fusion_mode == "both":
                # Run our model & get outputs
                fusion_output = model.decoder(midpoint_embedding)
                # Calculate reconstruction loss:
                # Midpoint Embedding Output vs Original Fusion
                batch_loss, loss_dict = loss.mse_ssim_loss(
                    fusion_output,
                    fusion,
                    use_sum=fusion_use_sum,
                    ssim_module=ssim_module,
                    mse_weight=mse_weight,
                    ssim_weight=ssim_weight,
                )
                # For multiple MSE
                # For every MSE, we halve the image size
                # And take the MSE between the resulting images
                for i in range(num_fusion_mse):
                    new_size = image_size // pow(2, i + 1)
                    with torch.no_grad():
                        resized_batch = nn.functional.interpolate(
                            fusion, size=new_size, mode="bilinear")
                    resized_output = nn.functional.interpolate(fusion_output,
                                                               size=new_size,
                                                               mode="bilinear")
                    mse = loss.mse_loss(resized_output, resized_batch, use_sum)
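Several of the snippets append the same coarse-to-fine MSE term to the reconstruction loss. Pulled out into a standalone helper (the name multi_scale_mse and its signature are assumptions, not code from the page), the idea is roughly:

    import torch
    import torch.nn.functional as F

    def multi_scale_mse(output, target, image_size, num_levels, use_sum=False):
        # Accumulate MSE over progressively halved resolutions
        reduction = "sum" if use_sum else "mean"
        total = output.new_zeros(())
        for i in range(num_levels):
            new_size = image_size // 2 ** (i + 1)
            # Resize the target without tracking gradients
            with torch.no_grad():
                resized_target = F.interpolate(target, size=new_size, mode="bilinear")
            resized_output = F.interpolate(output, size=new_size, mode="bilinear")
            total = total + F.mse_loss(resized_output, resized_target, reduction=reduction)
        return total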
Example #3
                                                          mode="bilinear")
            resized_output = nn.functional.interpolate(reconstructed,
                                                       size=new_size,
                                                       mode="bilinear")
            mse = loss.mse_loss(resized_output, resized_batch, use_sum)
            batch_loss += mse
            loss_dict["MSE"] += mse.item()

        with torch.no_grad():
            _, teacher_logits = teacher_model(batch, return_logits=True)

        # Calculate teacher loss
        teacher_loss, teacher_loss_dict = loss.mse_ssim_loss(
            student_logits,
            teacher_logits,
            use_sum=teacher_use_sum,
            ssim_module=None,
            mse_weight=mse_weight,
            ssim_weight=ssim_weight,
        )

        # Compute Overall Loss
        batch_loss = batch_loss + (teacher_loss * temperature)
        loss_dict["MSE"] += teacher_loss_dict["MSE"]
        loss_dict["SSIM"] += teacher_loss_dict["SSIM"]
        # Backprop
        batch_loss.backward()

        # Update our optimizer parameters
        optimizer.step()

        # Add the batch's loss to the total loss for the epoch
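Example #3 distils the student against logits from a teacher network queried under torch.no_grad(). How teacher_model is prepared is not shown; a common, hedged setup (everything here except the teacher_model name is an assumption) is to freeze it once before the loop:

    def prepare_teacher(teacher_model, device):
        # The teacher only provides targets, so it never needs gradients
        teacher_model.to(device)
        teacher_model.eval()  # disable dropout / batch-norm updates
        for param in teacher_model.parameters():
            param.requires_grad_(False)
        return teacher_model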
Example #4
    # Training Loop - Standard
    model.train()
    for iteration, batch in enumerate(tqdm(train_dataloader)):
        # Reset gradients back to zero for this iteration
        optimizer.zero_grad()
        # Move batch to device
        _, batch = batch  # Returns key, value for each Pokemon
        batch = batch.to(device)
        # Run our model & get outputs
        reconstructed = model(batch)
        # Calculate reconstruction loss
        batch_loss, _ = loss.mse_ssim_loss(
            reconstructed,
            batch,
            use_sum=False,
            ssim_module=ssim_module,
            mse_weight=mse_weight,
            ssim_weight=ssim_weight,
        )
        # Backprop
        batch_loss.backward()
        # Update our optimizer parameters
        optimizer.step()
        # Add the batch's loss to the total loss for the epoch
        train_loss += batch_loss.item()

    if freeze_conv_for_fusions:
        models.toggle_layer_freezing(freezable_layers, trainable=False)

    if learning_rate != fusion_learning_rate:
        optimizer = models.set_learning_rate(optimizer, fusion_learning_rate)
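Example #4 switches to a fusion-specific configuration via models.toggle_layer_freezing and models.set_learning_rate. Those helpers belong to the project and are not reproduced on this page; plausible sketches using only standard PyTorch idioms (requires_grad and optimizer.param_groups) could be:

    def toggle_layer_freezing(layers, trainable):
        # Freeze (trainable=False) or unfreeze (trainable=True) the given layers
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = trainable

    def set_learning_rate(optimizer, new_lr):
        # Apply the new learning rate to every parameter group in-place
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr
        return optimizer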
Example #5
        optimizer.zero_grad()

        # Move batch to device
        _, features, labels = batch  # (names), (input images), (target images)
        features = features.to(device)
        labels = labels.to(device)
        current_batch_size = features.shape[0]

        # Run Model & Get Output
        predictions = model(features)

        # Calculate Loss
        batch_loss, loss_dict = loss.mse_ssim_loss(
            predictions,
            labels,
            use_sum=False,
            ssim_module=ssim_module,
            mse_weight=mse_weight,
            ssim_weight=ssim_weight,
        )

        # Backprop
        batch_loss.backward()

        # Update our optimizer parameters
        optimizer.step()

        # Add the batch's loss to the total loss for the epoch
        train_loss += loss_dict["MSE"] + loss_dict["SSIM"]

    # Validation Loop
    model.eval()
Example #6
    model.train()
    for iteration, batch in enumerate(tqdm(train_dataloader)):
        # Reset gradients back to zero for this iteration
        optimizer.zero_grad()
        # Move batch to device
        _, (base, fusee, fusion) = batch  # (names), (images)
        base = base.to(device)
        fusee = fusee.to(device)
        fusion = fusion.to(device)
        # Run our model & get outputs
        reconstructed = model(base, fusee)
        # Calculate reconstruction loss
        batch_loss, loss_dict = loss.mse_ssim_loss(
            reconstructed,
            fusion,
            use_sum=use_sum,
            ssim_module=ssim_module,
            mse_weight=mse_weight,
            ssim_weight=ssim_weight,
        )
        # For multiple MSE
        # For every MSE, we halve the image size
        # And take the MSE between the resulting images
        for i in range(num_mse):
            new_size = image_size // pow(2, i + 1)
            with torch.no_grad():
                resized_batch = nn.functional.interpolate(
                    fusion, size=new_size, mode="bilinear"
                )
            resized_output = nn.functional.interpolate(
                reconstructed, size=new_size, mode="bilinear"
            )
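Every loop above receives a prebuilt ssim_module together with mse_weight and ssim_weight. Assuming the widely used pytorch_msssim package (the page itself does not name the library), a minimal setup could look like:

    from pytorch_msssim import SSIM

    # Images are assumed to be 3-channel tensors scaled to [0, 1]
    ssim_module = SSIM(data_range=1.0, channel=3)

    # Relative weights of the two reconstruction terms (illustrative values)
    mse_weight = 1.0
    ssim_weight = 1.0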