def training_step(self, batch, idx, val=False):
        """One geometric-matching step: warp the cloth and score it with L1.

        ``forward`` maps the concatenated person / cloth inputs to a sampling
        grid. That grid warps ``batch["cloth"]`` (border padding) and the
        ``batch["grid_vis"]`` image (zero padding). The loss is the L1 distance
        between the warped cloth and ``batch["im_cloth"]``.

        Args:
            batch: dict of tensors, first passed through
                ``maybe_combine_frames_and_channels``.
            idx: batch index (unused).
            val: True for validation. It returns an EvalResult that checkpoints
                on the loss and prefixes log keys with ``val_``. Otherwise it
                returns a TrainResult, and ``visualize`` runs whenever
                ``global_step`` is a multiple of ``hparams.display_count``.
        """
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        cloth = batch["cloth"]
        target_cloth = batch["im_cloth"]
        grid_vis = batch["grid_vis"]
        person_in = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_in = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        grid, _theta = self.forward(person_in, cloth_in)
        # Kept on self; presumably read by visualize() -- confirm.
        self.warped_cloth = F.grid_sample(cloth, grid, padding_mode="border")
        self.warped_grid = F.grid_sample(grid_vis, grid, padding_mode="zeros")
        loss = F.l1_loss(self.warped_cloth, target_cloth)

        if not val:
            if self.global_step % self.hparams.display_count == 0:
                self.visualize(batch)
            result = TrainResult(loss)
            prefix = ""
        else:
            result = EvalResult(checkpoint_on=loss)
            prefix = "val_"
        result.log(f"{prefix}loss/G", loss, prog_bar=True)
        return result
    def validation_step_result_only_epoch_metrics(self, batch, batch_idx):
        """
        Only track epoch level metrics.

        Logs four synthetic metrics (a fixed offset plus ``batch_idx``), one
        for each combination of ``prog_bar`` / ``logger``. No ``on_step`` /
        ``on_epoch`` flags are passed, so granularity follows the defaults of
        ``EvalResult.log``. NOTE(review): the epoch-only intent depends on
        those defaults; confirm against the installed version.
        """
        acc = self.step(batch, batch_idx)
        result = EvalResult(checkpoint_on=acc, early_stop_on=acc)

        # (metric name, offset added to batch_idx, prog_bar, logger)
        metric_specs = (
            ('no_val_no_pbar', 11, False, False),
            ('val_step_log_acc', 11, False, True),
            ('val_step_log_pbar_acc', 12, True, True),
            ('val_step_pbar_acc', 13, True, False),
        )
        for name, offset, show_in_pbar, send_to_logger in metric_specs:
            value = torch.tensor(offset + batch_idx).type_as(acc)
            result.log(name, value, prog_bar=show_in_pbar, logger=send_to_logger)

        self.validation_step_called = True
        return result
Example #3
0
 def test_step(self, batch, batch_idx) -> EvalResult:
     """Forward ``batch`` as keyword arguments and log the loss as ``test_loss``.

     ``batch`` must be a mapping whose keys match ``forward``'s parameters,
     and the returned object must expose a ``loss`` attribute. The same loss
     is also used for ``checkpoint_on`` and ``early_stop_on``.
     """
     loss = self.forward(**batch).loss
     result = EvalResult(checkpoint_on=loss, early_stop_on=loss)
     result.log("test_loss", loss, prog_bar=True)
     return result
Example #4
0
    def eval_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
        """
        Full-loop eval step (validation or test) using the result object, dp variant.

        Flattens the inputs to ``(batch, -1)`` and runs the model on
        ``self.device``. The sum of the outputs serves as a stand-in loss. The
        step name is "test" while ``self.trainer.testing`` is truthy and
        "validation" otherwise; it is used for the logged metric and for the
        ``<name>_step_called`` flag set on ``self``.
        """
        inputs, _targets = batch
        flat_inputs = inputs.view(inputs.size(0), -1).to(self.device)
        loss_val = self(flat_inputs).sum()
        result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)

        eval_name = 'test' if self.trainer.testing else 'validation'
        result.log(f'{eval_name}_step_metric', loss_val + 1, on_step=True)

        setattr(self, f'{eval_name}_step_called', True)
        return result
Example #5
0
    def generator_step(self, batch, val=False):
        """Compute the weighted generator loss and wrap it in a Lightning result.

        total = wt_l1 * L1 + wt_vgg * VGG + wt_multiscale * adv_multiscale
                + wt_temporal * adv_temporal

        Only the last frame (index ``[:, -1]``) of ``batch["image"]`` and of
        ``self.all_gen_frames`` feeds the L1 / VGG terms.

        Args:
            batch: dict containing at least "image", indexed as 5-D.
            val: if True, return an EvalResult that checkpoints on L1 + VGG and
                prefix log keys with ``val_``. Otherwise return a TrainResult
                on the total loss.
        """
        # The multiscale term must run first: it also runs the generator
        # forward pass (presumably what populates self.all_gen_frames below).
        adv_multiscale = self.multiscale_adversarial_loss(
            batch, for_discriminator=False) * self.wt_multiscale
        adv_temporal = self.temporal_adversarial_loss(
            batch, for_discriminator=False) * self.wt_temporal

        real_last = batch["image"][:, -1, :, :, :]
        fake_last = self.all_gen_frames[:, -1, :, :, :]
        l1 = self.criterion_l1(fake_last, real_last) * self.wt_l1
        vgg = self.criterion_VGG(fake_last, real_last) * self.wt_vgg
        reconstruction = l1 + vgg

        total = reconstruction + adv_multiscale + adv_temporal

        prefix = "val_" if val else ""
        if val:
            result = EvalResult(checkpoint_on=reconstruction)
        else:
            result = TrainResult(total)
        result.log(f"{prefix}loss", total)
        for key, value in (
            (f"{prefix}loss/G/adv_multiscale", adv_multiscale),
            (f"{prefix}loss/G/adv_temporal", adv_temporal),
        ):
            result.log(key, value, prog_bar=True)
        for key, value in (
            (f"{prefix}loss/G/l1+vgg", reconstruction),
            (f"{prefix}loss/G/l1", l1),
            (f"{prefix}loss/G/vgg", vgg),
        ):
            result.log(key, value)
        return result
Example #6
0
    def test_step_result_obj(self, batch, batch_idx, *args, **kwargs):
        """
        Default, baseline test_step that logs loss and accuracy via EvalResult.

        :param batch: ``(x, y)`` pair. ``x`` is flattened to ``(batch, -1)``
            before the forward pass. ``y`` is compared element-wise against
            ``argmax(y_hat, dim=1)``.
        :param batch_idx: index of the batch (unused).
        :return: an ``EvalResult`` with ``test_loss`` and ``test_acc`` logged.
        """
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_test = self.loss(y, y_hat)

        # acc: fraction of samples whose argmax prediction equals the label,
        # cast to the same tensor type as the (flattened) input
        labels_hat = torch.argmax(y_hat, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        test_acc = torch.tensor(test_acc).type_as(x)

        result = EvalResult()
        result.log_dict({'test_loss': loss_test, 'test_acc': test_acc})
        return result
    def test_step(self, batch, batch_idx):
        """Run all task heads on ``batch`` and hand the predictions to ``write_dict``.

        ``batch`` is a ``(data, labels_dict)`` pair. The verb / noun outputs of
        ``forward_tasks`` are passed to ``EvalResult.write_dict`` together with
        the narration and video ids. The target file is read from config key
        ``test.results_path`` and defaults to ``./predictions.pt``.
        """
        data, labels_dict = batch
        task_outputs = self.forward_tasks(data)

        result = EvalResult()
        out_path = self.cfg.get("test.results_path", "./predictions.pt")

        predictions = {
            "verb_output": task_outputs["verb"],
            "noun_output": task_outputs["noun"],
            "narration_id": labels_dict["narration_id"],
            "video_id": labels_dict["video_id"],
        }
        result.write_dict(predictions, filename=out_path)
        return result
    def validation_step_result_callbacks(self, batch, batch_idx):
        """Validation step with a scripted per-epoch loss curve for callback tests.

        The loss is ``self.step(...)`` plus a fixed per-epoch offset
        (20, 19, 20, 21, 22, 23), so it dips at epoch 1 and then rises. The
        same value drives both early stopping and checkpointing. The schedule
        has six entries, so epochs beyond 5 raise IndexError.
        """
        acc = self.step(batch, batch_idx)

        self.assert_backward = False
        epoch_offsets = (20, 19, 20, 21, 22, 23)
        loss = acc + epoch_offsets[self.current_epoch]
        result = EvalResult(early_stop_on=loss, checkpoint_on=loss)

        self.validation_step_called = True
        return result
Example #9
0
 def validation_epoch_end(self, outputs):
     """Average per-step ``val_loss`` / ``val_acc`` into epoch-level metrics.

     Each element of ``outputs`` must be indexable by "val_loss" and
     "val_acc". The mean loss drives checkpointing. Both means are logged
     with ``on_step=False, on_epoch=True``.
     """
     mean_loss = torch.stack([step["val_loss"] for step in outputs]).mean()
     mean_acc = torch.stack([step["val_acc"] for step in outputs]).mean()
     result = EvalResult(checkpoint_on=mean_loss)
     for key, value in (("val/loss", mean_loss), ("val/acc", mean_acc)):
         result.log(key, value, on_step=False, on_epoch=True)
     return result
    def validation_step_for_epoch_end_result(self, batch, batch_idx):
        """
        EvalResult flows to epoch end (without step_end).

        Logs ``batch_idx`` (cast like ``acc``) under two names, shown in the
        progress bar and sent to the logger.
        """
        acc = self.step(batch, batch_idx)
        result = EvalResult(checkpoint_on=acc, early_stop_on=acc)

        # epoch-level metrics only: on_step=False, on_epoch=True
        for metric_name in ('val_step_metric', 'batch_idx'):
            result.log(metric_name,
                       torch.tensor(batch_idx).type_as(acc),
                       prog_bar=True,
                       logger=True,
                       on_epoch=True,
                       on_step=False)

        self.validation_step_called = True
        return result
 def validation_step(self, batch, batch_idx):
     """Run the shared ``_step`` on ``batch`` and log its metrics under "val".

     ``_step`` must return a mapping with a "loss" entry, which is used for
     checkpointing. ``log_metrics`` then records the step's outputs on the
     result.
     """
     step_outputs = self._step(batch)
     val_result = EvalResult(checkpoint_on=step_outputs["loss"])
     self.log_metrics(val_result, step_outputs, "val")
     return val_result
Example #12
0
    def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
        """Test step that logs loss / accuracy and writes predictions.

        ``self.test_option`` (default 0, set by pytest parameterization) picks
        what gets written to ``self.prediction_file`` (default
        ``'predictions.pt'``):

        * 0 -- ids + predicted labels
        * 1 -- ids concatenated twice (length mismatch vs. predictions)
        * 2 -- ids + predicted labels + the flattened input
        * 3 / 4 / 5 / 6 -- ids + a random list of str / int / nested list / dict
        * 7 -- ids + predicted labels in one ``write_prediction_dict`` call

        Any other value writes nothing. The random lists are always drawn from
        the ``random`` module, whichever option is chosen.
        """
        inputs, targets = batch
        inputs = inputs.view(inputs.size(0), -1)
        logits = self(inputs)

        loss_test = self.loss(targets, logits)

        # accuracy of the argmax predictions, cast to the input's tensor type
        labels_hat = torch.argmax(logits, dim=1)
        n_correct = torch.sum(targets == labels_hat).item()
        test_acc = torch.tensor(n_correct / (len(targets) * 1.0)).type_as(inputs)

        result = EvalResult(checkpoint_on=loss_test)
        result.log('test_loss', loss_test)
        result.log('test_acc', test_acc)

        n_rows = inputs.size(0)
        # RNG draw order matters for reproducibility: all strings, then all ints.
        rand_strs = [random.choice(['dog', 'cat']) for _ in range(n_rows)]
        rand_ints = [random.randint(500, 1000) for _ in range(n_rows)]
        nested_ints = [[value] for value in rand_ints]
        str_int_dicts = [{key: value} for key, value in zip(rand_strs, rand_ints)]

        option = getattr(self, 'test_option', 0)
        prediction_file = getattr(self, 'prediction_file', 'predictions.pt')

        first_id = batch_idx * self.batch_size
        lazy_ids = torch.arange(first_id, first_id + inputs.size(0))

        if option == 7:
            self.write_prediction_dict({
                'idxs': lazy_ids,
                'preds': labels_hat
            }, prediction_file)
            return result

        if option == 1:
            # ids deliberately twice as long as the predictions
            planned_writes = [('idxs', torch.cat((lazy_ids, lazy_ids))),
                              ('preds', labels_hat)]
        else:
            planned_writes = {
                0: [('idxs', lazy_ids), ('preds', labels_hat)],
                2: [('idxs', lazy_ids), ('preds', labels_hat), ('x', inputs)],
                3: [('idxs', lazy_ids), ('vals', rand_strs)],
                4: [('idxs', lazy_ids), ('vals', rand_ints)],
                5: [('idxs', lazy_ids), ('vals', nested_ints)],
                6: [('idxs', lazy_ids), ('vals', str_int_dicts)],
            }.get(option, [])

        for name, value in planned_writes:
            self.write_prediction(name, value, prediction_file)
        return result
Example #13
0
    @staticmethod
    def _curr_prev_frame_loss(criterion, preds, targets, n_frames):
        """Apply ``criterion`` to the last frame and, if ``n_frames > 1``, the one before it.

        Returns ``(combined, curr, prev)``. With more than one frame,
        ``combined`` is ``0.5 * (curr + prev)``. With one frame, it is ``curr``
        and ``prev`` is ``torch.zeros_like(curr)``.
        """
        curr = criterion(preds[-1], targets[-1])
        if n_frames > 1:
            prev = criterion(preds[-2], targets[-2])
            return 0.5 * (curr + prev), curr, prev
        return curr, curr, torch.zeros_like(curr)

    def training_step(self, batch, batch_idx, val=False):
        """Run one try-on step over ``hparams.n_frames_total`` frames and log losses.

        ``forward`` returns renders, try-on masks, try-on images and optional
        flow masks, each concatenated along dim 1. They are split back into
        ``n_frames_total`` chunks and stored on ``self`` for visualization.
        Losses on the last frame are averaged with those on the previous
        frame when more than one frame is used:

        * L1 and VGG between the try-on image and ``batch["image"]``
        * L1 between the try-on mask and ``batch["cloth_mask"]``
        * ``hparams.pen_flow_mask`` times the sum of the last flow mask
          (zero when ``forward`` returns no flow masks)

        Args:
            batch: dict of tensors. It needs "image", "prev_image" and
                "cloth_mask", plus "flow" when ``hparams.flow_warp`` is set.
            batch_idx: index of the batch (unused).
            val: if True, return an EvalResult that checkpoints on the total
                loss and prefix log keys with ``val_``. Otherwise return a
                TrainResult and call ``visualize`` whenever ``global_step`` is
                a multiple of ``hparams.display_count``.
        """
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        n_frames = self.hparams.n_frames_total
        # unpack
        im = batch["image"]
        prev_im = batch["prev_image"]
        cm = batch["cloth_mask"]
        flow = batch["flow"] if self.hparams.flow_warp else None

        person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        # forward. save outputs to self for visualization
        (
            self.p_rendereds,
            self.tryon_masks,
            self.p_tryons,
            self.flow_masks,
        ) = self.forward(person_inputs, cloth_inputs, flow, prev_im)
        # split the frame-concatenated outputs back into per-frame tensors
        self.p_tryons = torch.chunk(self.p_tryons, n_frames, dim=1)
        self.p_rendereds = torch.chunk(self.p_rendereds, n_frames, dim=1)
        self.tryon_masks = torch.chunk(self.tryon_masks, n_frames, dim=1)
        self.flow_masks = (torch.chunk(self.flow_masks, n_frames, dim=1)
                           if self.flow_masks is not None else None)

        im = torch.chunk(im, n_frames, dim=1)
        cm = torch.chunk(cm, n_frames, dim=1)

        # loss: each term is (combined, current frame, previous frame)
        loss_image_l1, loss_image_l1_curr, loss_image_l1_prev = (
            self._curr_prev_frame_loss(F.l1_loss, self.p_tryons, im, n_frames))
        loss_image_vgg, loss_image_vgg_curr, loss_image_vgg_prev = (
            self._curr_prev_frame_loss(self.criterionVGG, self.p_tryons, im,
                                       n_frames))
        loss_tryon_mask_l1, loss_tryon_mask_curr, loss_tryon_mask_prev = (
            self._curr_prev_frame_loss(F.l1_loss, self.tryon_masks, cm,
                                       n_frames))

        # Penalty on the last flow mask's total (a sum, not a mean).
        flow_mask_total = (self.flow_masks[-1].sum()
                           if self.flow_masks is not None else
                           torch.zeros_like(loss_tryon_mask_curr))
        loss_flow_mask_l1 = flow_mask_total * self.hparams.pen_flow_mask

        loss = loss_image_l1 + loss_image_vgg + loss_tryon_mask_l1 + loss_flow_mask_l1

        # logging
        if not val and self.global_step % self.hparams.display_count == 0:
            self.visualize(batch)

        val_ = "val_" if val else ""
        result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
        result.log(f"{val_}loss/G", loss, prog_bar=True)

        result.log(f"{val_}loss/G/l1", loss_image_l1, prog_bar=True)
        result.log(f"{val_}loss/G/vgg", loss_image_vgg, prog_bar=True)
        result.log(f"{val_}loss/G/tryon_mask_l1",
                   loss_tryon_mask_l1,
                   prog_bar=True)
        result.log(f"{val_}loss/G/flow_mask_l1",
                   loss_flow_mask_l1,
                   prog_bar=True)

        if n_frames > 1:
            # previous-frame losses
            result.log(f"{val_}loss/G/l1_prev", loss_image_l1_prev)
            result.log(f"{val_}loss/G/vgg_prev", loss_image_vgg_prev)
            result.log(f"{val_}loss/G/tryon_mask_prev", loss_tryon_mask_prev)

            # current-frame losses
            result.log(f"{val_}loss/G/l1_curr", loss_image_l1_curr)
            result.log(f"{val_}loss/G/vgg_curr", loss_image_vgg_curr)
            result.log(f"{val_}loss/G/tryon_mask_curr", loss_tryon_mask_curr)
        return result