Exemple #1
0
 def training_step(self, batch, batch_idx) -> TrainResult:
     """Run one optimization step: forward the batch and minimize its loss."""
     model_output = self.forward(**batch)
     loss = model_output.loss
     result = TrainResult(minimize=loss)
     # Detach before logging so the logged value carries no autograd graph.
     result.log("train_loss", loss.detach(), prog_bar=True, on_epoch=True)
     return result
Exemple #2
0
 def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
     """Full loop flow train step (result obj + dp)."""
     inputs, _ = batch
     flattened = inputs.view(inputs.size(0), -1)
     predictions = self(flattened.to(self.device))
     # Sum of outputs serves as a stand-in loss for this test step.
     loss_val = predictions.sum()
     result = TrainResult(minimize=loss_val)
     result.log('train_step_metric', loss_val + 1)
     self.training_step_called = True
     return result
    def training_step(self, batch, batch_idx):
        """Forward the batch, compute cross-entropy, and return a TrainResult."""
        inputs, targets = batch
        logits, _ = self(inputs, batch_idx)
        return TrainResult(F.cross_entropy(logits, targets))
Exemple #4
0
    def generator_step(self, batch, val=False):
        """Compute the weighted generator losses and log them.

        In validation mode returns an EvalResult checkpointing on the
        reconstruction (l1 + vgg) loss; otherwise a TrainResult minimizing
        the total generator loss. Metric keys gain a ``val_`` prefix when
        validating.
        """
        # Adversarial terms; each loss call also runs the generator forward.
        adv_multiscale = (
            self.multiscale_adversarial_loss(batch, for_discriminator=False)
            * self.wt_multiscale
        )
        adv_temporal = (
            self.temporal_adversarial_loss(batch, for_discriminator=False)
            * self.wt_temporal
        )

        # Reconstruction terms compare only the last frame.
        target_frame = batch["image"][:, -1, :, :, :]
        generated_frame = self.all_gen_frames[:, -1, :, :, :]
        l1 = self.criterion_l1(generated_frame, target_frame) * self.wt_l1
        vgg = self.criterion_VGG(generated_frame, target_frame) * self.wt_vgg

        total = l1 + vgg + adv_multiscale + adv_temporal

        # Log
        prefix = "val_" if val else ""
        if val:
            result = EvalResult(checkpoint_on=l1 + vgg)
        else:
            result = TrainResult(total)
        result.log(f"{prefix}loss", total)
        result.log(f"{prefix}loss/G/adv_multiscale", adv_multiscale, prog_bar=True)
        result.log(f"{prefix}loss/G/adv_temporal", adv_temporal, prog_bar=True)
        result.log(f"{prefix}loss/G/l1+vgg", l1 + vgg)
        result.log(f"{prefix}loss/G/l1", l1)
        result.log(f"{prefix}loss/G/vgg", vgg)
        return result
    def training_step(self, batch, idx, val=False):
        """Warp the cloth onto the person and optimize an L1 loss on the warp."""
        batch = maybe_combine_frames_and_channels(self.hparams, batch)

        # Unpack inputs.
        cloth = batch["cloth"]
        cloth_target = batch["im_cloth"]
        grid_vis = batch["grid_vis"]
        person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        # Forward pass; warped outputs are kept on self for visualization.
        grid, theta = self.forward(person_inputs, cloth_inputs)
        self.warped_cloth = F.grid_sample(cloth, grid, padding_mode="border")
        self.warped_grid = F.grid_sample(grid_vis, grid, padding_mode="zeros")

        loss = F.l1_loss(self.warped_cloth, cloth_target)

        # Periodically dump visualizations while training.
        if not val and self.global_step % self.hparams.display_count == 0:
            self.visualize(batch)

        prefix = "val_" if val else ""
        result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
        result.log(f"{prefix}loss/G", loss, prog_bar=True)

        return result
Exemple #6
0
    def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None):
        """Lightning calls this inside the training loop"""
        # Forward pass on flattened inputs.
        inputs, targets = batch
        flat = inputs.view(inputs.size(0), -1)
        preds = self(flat)

        # Loss and metric (metric is computed on the flattened inputs).
        loss_val = self.loss(targets, preds)
        metric_val = self.metric(flat, targets)

        result = TrainResult(minimize=loss_val)
        result.log('metric_val', metric_val)
        return result
    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternating GAN step: optimizer 0 trains the generator, 1 the discriminator."""
        real_imgs, _ = batch

        # Sample latent noise on the same device/dtype as the images.
        noise = torch.randn(real_imgs.shape[0], self.latent_dim).type_as(real_imgs)

        # --- Generator step ---
        if optimizer_idx == 0:
            self.generated_imgs = self(noise)

            # Log a small grid of sampled images.
            sample_grid = torchvision.utils.make_grid(self.generated_imgs[:6])
            self.logger.experiment.add_image('generated_images', sample_grid, 0)

            # The generator wants the discriminator to label its fakes as real;
            # labels are created via type_as so they land on the right device.
            real_labels = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)

            # Adversarial loss is binary cross-entropy.
            g_loss = self.adversarial_loss(self.discriminator(self(noise)), real_labels)
            result = TrainResult(minimize=g_loss, checkpoint_on=True)
            result.log_dict({'g_loss': g_loss})

            return result

        # --- Discriminator step ---
        if optimizer_idx == 1:
            # How well does it label real images as real?
            real_labels = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)
            real_loss = self.adversarial_loss(self.discriminator(real_imgs), real_labels)

            # How well does it label generated images as fake? (detach: no
            # gradient flows back into the generator here.)
            fake_labels = torch.zeros(real_imgs.size(0), 1).type_as(real_imgs)
            fake_loss = self.adversarial_loss(
                self.discriminator(self(noise).detach()), fake_labels)

            # Discriminator loss is the average of the two.
            d_loss = (real_loss + fake_loss) / 2
            result = TrainResult(minimize=d_loss, checkpoint_on=True)
            result.log_dict({'d_loss': d_loss})

            return result
Exemple #8
0
 def training_step_no_default_callbacks_for_train_loop(self, batch, batch_idx):
     """
     Early stop and checkpoint only on these values
     """
     loss = self.step(batch, batch_idx)
     result = TrainResult(minimize=loss)
     # Default TrainResult wires up checkpointing but not early stopping.
     assert 'early_step_on' not in result
     assert 'checkpoint_on' in result
     return result
 def training_step_no_callbacks_result_obj(self, batch, batch_idx):
     """
     Early stop and checkpoint only on these values
     """
     loss = self.step(batch, batch_idx)
     # checkpoint_on=False disables the default checkpoint signal entirely.
     result = TrainResult(minimize=loss, checkpoint_on=False)
     assert 'early_step_on' not in result
     assert 'checkpoint_on' not in result
     return result
Exemple #10
0
    def training_step_result_log_step_only(self, batch, batch_idx):
        """Log step-level metrics only; keys embed the batch index."""
        loss = self.step(batch, batch_idx)
        result = TrainResult(minimize=loss)

        def _const(value):
            # Constant tensor on the same device/dtype as the loss.
            return torch.tensor(value).type_as(loss)

        # Step-only metrics: logger+pbar, logger-only, pbar-only.
        result.log(f'step_log_and_pbar_acc1_b{batch_idx}', _const(11), prog_bar=True)
        result.log(f'step_log_acc2_b{batch_idx}', _const(12))
        result.log(f'step_pbar_acc3_b{batch_idx}', _const(13), logger=False, prog_bar=True)

        self.training_step_called = True
        return result
Exemple #11
0
    def training_step_result_log_epoch_only(self, batch, batch_idx):
        """Log epoch-level metrics only (on_step disabled); keys embed the epoch."""
        loss = self.step(batch, batch_idx)
        result = TrainResult(minimize=loss)

        epoch = self.current_epoch
        # Epoch-only metrics: logger+pbar, logger-only, pbar-only.
        result.log(f'epoch_log_and_pbar_acc1_e{epoch}', torch.tensor(14).type_as(loss),
                   on_epoch=True, prog_bar=True, on_step=False)
        result.log(f'epoch_log_acc2_e{epoch}', torch.tensor(15).type_as(loss),
                   on_epoch=True, on_step=False)
        result.log(f'epoch_pbar_acc3_e{epoch}', torch.tensor(16).type_as(loss),
                   on_epoch=True, logger=False, prog_bar=True, on_step=False)

        self.training_step_called = True
        return result
Exemple #12
0
    def training_step_result_log_epoch_and_step_for_callbacks(self, batch, batch_idx):
        """
        Early stop and checkpoint only on these values
        """
        base = self.step(batch, batch_idx)

        self.assert_backward = False
        # Scripted per-epoch offsets drive the early-stop/checkpoint signals.
        offsets = [20, 19, 18, 10, 15, 14, 9, 11, 11, 20]
        loss = base + offsets[self.current_epoch]
        return TrainResult(minimize=loss, early_stop_on=loss, checkpoint_on=loss)
Exemple #13
0
    def training_step_result_log_epoch_and_step(self, batch, batch_idx):
        """Log metrics at both step and epoch granularity.

        The three logged values scale with the batch index and (1-based)
        epoch so aggregation behavior can be checked: one goes to logger and
        progress bar, one to logger only, one to progress bar only.
        """
        acc = self.step(batch, batch_idx)
        result = TrainResult(minimize=acc)

        # Hoist the common epoch scale factor instead of repeating it.
        scale = self.current_epoch + 1
        val_1 = (5 + batch_idx) * scale
        val_2 = (6 + batch_idx) * scale
        val_3 = (7 + batch_idx) * scale
        # Keys are plain strings: the original f-strings had no placeholders.
        result.log('step_epoch_log_and_pbar_acc1', torch.tensor(val_1).type_as(acc),
                   on_epoch=True, prog_bar=True)
        result.log('step_epoch_log_acc2', torch.tensor(val_2).type_as(acc),
                   on_epoch=True)
        result.log('step_epoch_pbar_acc3', torch.tensor(val_3).type_as(acc),
                   on_epoch=True, logger=False, prog_bar=True)

        self.training_step_called = True
        return result
Exemple #14
0
    def multiscale_discriminator_step(self, batch):
        """One multiscale-discriminator update; logs total, fake, and real losses."""
        total, real, fake = self.multiscale_adversarial_loss(
            batch, for_discriminator=True)

        result = TrainResult(total)
        result.log("loss/D/multi", total, prog_bar=True)
        result.log("loss/D/multi_fake", fake)
        result.log("loss/D/multi_real", real)
        return result
Exemple #15
0
    def temporal_discriminator_step(self, batch):
        """One temporal-discriminator update; logs total, fake, and real losses."""
        total, real, fake = self.temporal_adversarial_loss(
            batch, for_discriminator=True)

        result = TrainResult(total)
        result.log("loss/D/temporal", total, prog_bar=True)
        result.log("loss/D/temporal_fake", fake)
        result.log("loss/D/temporal_real", real)
        return result
Exemple #16
0
 def training_step(self, batch, batch_idx):
     """Supervised step: compute loss and accuracy, log both per epoch."""
     images, targets = batch
     logits = self.forward(images)
     loss = self._criterion(logits, targets)
     # Fraction of predictions matching the targets.
     correct = (targets == torch.argmax(logits, dim=1)).float()
     acc = torch.sum(correct) / len(targets)
     result = TrainResult(loss)
     result.log("train/loss", loss, on_step=False, on_epoch=True)
     result.log("train/acc", acc, on_step=False, on_epoch=True)
     return result
Exemple #17
0
    def training_step_result_obj(self, batch, batch_idx, optimizer_idx=None):
        """Train step that alternates logging tensors and Python scalars."""
        # Forward pass on flattened inputs.
        inputs, targets = batch
        flat = inputs.view(inputs.size(0), -1)
        preds = self(flat)

        loss_val = self.loss(targets, preds)

        # Even batches log a Python float, odd batches log the tensor, to
        # alternate the types fed to "log" and "progress_bar".
        log_val = loss_val.item() if batch_idx % 2 == 0 else loss_val

        result = TrainResult(loss_val)
        result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
        result.log('train_some_val', log_val * log_val)
        return result
Exemple #18
0
    def training_step(self, batch, batch_idx, val=False):
        """Multi-frame try-on step: forward, per-frame losses, and logging.

        When ``val`` is True, returns an EvalResult checkpointing on the
        total loss and prefixes metric keys with ``val_``; otherwise returns
        a TrainResult minimizing the total loss.
        """
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        # unpack
        im = batch["image"]
        prev_im = batch["prev_image"]
        cm = batch["cloth_mask"]
        # Optical flow is only used when flow warping is enabled.
        flow = batch["flow"] if self.hparams.flow_warp else None

        person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        # forward. save outputs to self for visualization
        (
            self.p_rendereds,
            self.tryon_masks,
            self.p_tryons,
            self.flow_masks,
        ) = self.forward(person_inputs, cloth_inputs, flow, prev_im)
        # Split the frame-concatenated outputs into n_frames_total chunks
        # along dim 1 so losses can address frames via [-1] / [-2].
        self.p_tryons = torch.chunk(self.p_tryons,
                                    self.hparams.n_frames_total,
                                    dim=1)
        self.p_rendereds = torch.chunk(self.p_rendereds,
                                       self.hparams.n_frames_total,
                                       dim=1)
        self.tryon_masks = torch.chunk(self.tryon_masks,
                                       self.hparams.n_frames_total,
                                       dim=1)

        # flow_masks may be absent (forward can return None for it).
        self.flow_masks = (torch.chunk(
            self.flow_masks, self.hparams.n_frames_total, dim=1)
                           if self.flow_masks is not None else None)

        im = torch.chunk(im, self.hparams.n_frames_total, dim=1)
        cm = torch.chunk(cm, self.hparams.n_frames_total, dim=1)

        # loss: each term pairs a current-frame ([-1]) loss with a previous-
        # frame ([-2]) loss; the prev term is a zero tensor when only a
        # single frame is modeled, and the pair is averaged otherwise.
        loss_image_l1_curr = F.l1_loss(self.p_tryons[-1], im[-1])
        loss_image_l1_prev = F.l1_loss(
            self.p_tryons[-2],
            im[-2]) if self.hparams.n_frames_total > 1 else torch.zeros_like(
                loss_image_l1_curr)
        loss_image_l1 = 0.5 * (
            loss_image_l1_curr + loss_image_l1_prev
        ) if self.hparams.n_frames_total > 1 else loss_image_l1_curr

        loss_image_vgg_curr = self.criterionVGG(self.p_tryons[-1], im[-1])
        loss_image_vgg_prev = self.criterionVGG(
            self.p_tryons[-2],
            im[-2]) if self.hparams.n_frames_total > 1 else torch.zeros_like(
                loss_image_vgg_curr)
        loss_image_vgg = 0.5 * (
            loss_image_vgg_curr + loss_image_vgg_prev
        ) if self.hparams.n_frames_total > 1 else loss_image_vgg_curr

        loss_tryon_mask_curr = F.l1_loss(self.tryon_masks[-1], cm[-1])
        loss_tryon_mask_prev = F.l1_loss(
            self.tryon_masks[-2],
            cm[-2]) if self.hparams.n_frames_total > 1 else torch.zeros_like(
                loss_tryon_mask_curr)
        loss_tryon_mask_l1 = 0.5 * (
            loss_tryon_mask_curr + loss_tryon_mask_prev
        ) if self.hparams.n_frames_total > 1 else loss_tryon_mask_curr

        # Penalize the magnitude of the last flow mask (zero when no flow
        # masks were produced), scaled by the pen_flow_mask hyperparameter.
        loss_flow_mask_l1 = (
            self.flow_masks[-1].sum() if self.flow_masks is not None else torch
            .zeros_like(loss_tryon_mask_curr)) * self.hparams.pen_flow_mask

        loss = loss_image_l1 + loss_image_vgg + loss_tryon_mask_l1 + loss_flow_mask_l1

        # logging: periodically dump visualizations during training only.
        if not val and self.global_step % self.hparams.display_count == 0:
            self.visualize(batch)

        val_ = "val_" if val else ""
        result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
        result.log(f"{val_}loss/G", loss, prog_bar=True)

        result.log(f"{val_}loss/G/l1", loss_image_l1, prog_bar=True)
        result.log(f"{val_}loss/G/vgg", loss_image_vgg, prog_bar=True)
        result.log(f"{val_}loss/G/tryon_mask_l1",
                   loss_tryon_mask_l1,
                   prog_bar=True)
        result.log(f"{val_}loss/G/flow_mask_l1",
                   loss_flow_mask_l1,
                   prog_bar=True)

        if self.hparams.n_frames_total > 1:
            ## visualize prev frames losses
            result.log(f"{val_}loss/G/l1_prev", loss_image_l1_prev)
            result.log(f"{val_}loss/G/vgg_prev", loss_image_vgg_prev)
            result.log(f"{val_}loss/G/tryon_mask_prev", loss_tryon_mask_prev)

            ## visualize curr frames losses
            result.log(f"{val_}loss/G/l1_curr", loss_image_l1_curr)
            result.log(f"{val_}loss/G/vgg_curr", loss_image_vgg_curr)
            result.log(f"{val_}loss/G/tryon_mask_curr", loss_tryon_mask_curr)
        return result
    def training_step(self, batch, batch_idx):
        """Delegate to the shared _step and log its metrics under 'train'."""
        outputs = self._step(batch)
        result = TrainResult(outputs["loss"])
        self.log_metrics(result, outputs, "train")
        return result