Beispiel #1
0
 def validation_epoch_end(self, outputs):
     """Reduce per-batch validation outputs to epoch-level metrics.

     Every element of ``outputs`` is indexed with the keys ``"val_loss"`` and
     ``"val_acc"``. The collected values are stacked with ``torch.stack`` and
     averaged (values assumed to be equally-shaped tensors — TODO confirm).
     The epoch-mean loss becomes the checkpoint metric, and both means are
     logged at epoch level only.
     """
     def _epoch_mean(key):
         # Stack the per-batch values for ``key`` and average them.
         return torch.stack([step_out[key] for step_out in outputs]).mean()

     epoch_loss = _epoch_mean("val_loss")
     epoch_acc = _epoch_mean("val_acc")

     summary = EvalResult(checkpoint_on=epoch_loss)
     for tag, value in (("val/loss", epoch_loss), ("val/acc", epoch_acc)):
         summary.log(tag, value, on_step=False, on_epoch=True)
     return summary
    def training_step(self, batch, idx, val=False):
        """Warp the cloth with the predicted grid and score it with an L1 loss.

        Args:
            batch: Raw batch. It is first passed through
                ``maybe_combine_frames_and_channels``, then read for the
                ``"cloth"``, ``"im_cloth"`` and ``"grid_vis"`` entries. It is
                also handed to ``get_and_cat_inputs`` to build the person and
                cloth network inputs.
            idx: Batch index (not used here).
            val: When True the step acts as a validation step. Visualization
                is skipped, an ``EvalResult`` checkpointed on the loss is
                returned, and the logged key is prefixed with ``"val_"``.

        Returns:
            ``EvalResult(checkpoint_on=loss)`` if ``val`` else
            ``TrainResult(loss)``, with the loss logged to the progress bar.
        """
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        cloth = batch["cloth"]
        target_cloth = batch["im_cloth"]
        grid_vis = batch["grid_vis"]
        person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        # forward() returns a pair; only the sampling grid is needed here.
        grid, _theta = self.forward(person_inputs, cloth_inputs)

        # Warped tensors are kept on self (presumably read by visualize()).
        self.warped_cloth = F.grid_sample(cloth, grid, padding_mode="border")
        self.warped_grid = F.grid_sample(grid_vis, grid, padding_mode="zeros")
        loss = F.l1_loss(self.warped_cloth, target_cloth)

        if val:
            prefix = "val_"
            result = EvalResult(checkpoint_on=loss)
        else:
            # The modulo is only evaluated for training steps.
            if self.global_step % self.hparams.display_count == 0:
                self.visualize(batch)
            prefix = ""
            result = TrainResult(loss)

        result.log(f"{prefix}loss/G", loss, prog_bar=True)
        return result
Beispiel #3
0
 def test_step(self, batch, batch_idx) -> EvalResult:
     """Evaluate one test batch and report its loss.

     ``batch`` is expanded as keyword arguments into ``self.forward``. The
     ``loss`` attribute of the returned object drives both checkpointing and
     early stopping, and is logged as ``"test_loss"`` on the progress bar.
     """
     test_loss = self.forward(**batch).loss
     result = EvalResult(checkpoint_on=test_loss, early_stop_on=test_loss)
     result.log("test_loss", test_loss, prog_bar=True)
     return result
    def validation_step_result_only_epoch_metrics(self, batch, batch_idx):
        """Validation step that is meant to track only epoch-level metrics.

        Logs four synthetic metrics whose value is ``batch_idx`` plus a fixed
        offset, cast to the dtype/device of ``acc``. Together they cover every
        combination of the ``prog_bar`` and ``logger`` flags. Sets
        ``self.validation_step_called`` so callers can verify the hook ran.
        """
        acc = self.step(batch, batch_idx)
        result = EvalResult(checkpoint_on=acc, early_stop_on=acc)

        # (name, offset added to batch_idx, prog_bar, logger).
        # No on_step/on_epoch flags are passed, so EvalResult's own defaults
        # decide how these are aggregated.
        metric_specs = (
            ('no_val_no_pbar', 11, False, False),
            ('val_step_log_acc', 11, False, True),
            ('val_step_log_pbar_acc', 12, True, True),
            ('val_step_pbar_acc', 13, True, False),
        )
        for name, offset, show_in_pbar, send_to_logger in metric_specs:
            result.log(name,
                       torch.tensor(offset + batch_idx).type_as(acc),
                       prog_bar=show_in_pbar,
                       logger=send_to_logger)

        self.validation_step_called = True
        return result
Beispiel #5
0
    def eval_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
        """Shared validation/test step for the full-loop (result obj + dp) flow.

        The inputs of ``batch`` are flattened per sample, moved to
        ``self.device`` and run through the model. The summed output serves as
        the loss for checkpointing and early stopping. The stage name is
        ``'test'`` while ``self.trainer.testing`` is truthy, else
        ``'validation'``. ``<stage>_step_metric`` (loss + 1) is logged with
        ``on_step=True``, and ``self.<stage>_step_called`` is set to True.
        """
        inputs, _targets = batch
        flat_inputs = inputs.view(inputs.size(0), -1)
        loss_val = self(flat_inputs.to(self.device)).sum()
        result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)

        stage = 'test' if self.trainer.testing else 'validation'
        result.log(f'{stage}_step_metric', loss_val + 1, on_step=True)
        setattr(self, f'{stage}_step_called', True)
        return result
    def validation_step_for_epoch_end_result(self, batch, batch_idx):
        """Validation step whose EvalResult flows to epoch end (no step_end).

        Logs two metrics, ``val_step_metric`` and ``batch_idx``, both equal to
        ``batch_idx`` cast to the dtype/device of ``acc``. Both are shown on
        the progress bar and sent to the logger. Sets
        ``self.validation_step_called``.
        """
        acc = self.step(batch, batch_idx)
        result = EvalResult(checkpoint_on=acc, early_stop_on=acc)

        # Epoch-only metrics: on_epoch=True, on_step=False.
        # A fresh tensor is built for each metric (no shared object).
        for metric_name in ('val_step_metric', 'batch_idx'):
            result.log(metric_name,
                       torch.tensor(batch_idx).type_as(acc),
                       prog_bar=True,
                       logger=True,
                       on_epoch=True,
                       on_step=False)

        self.validation_step_called = True
        return result
Beispiel #7
0
    def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
        """Test step that logs loss/accuracy and writes predictions.

        The payload written is selected by ``self.test_option`` (default 0;
        set via pytest parameterization). It goes to ``self.prediction_file``
        (default ``'predictions.pt'``):

        * 0 -- ``idxs`` and ``preds``
        * 1 -- ``idxs`` duplicated (twice as long as ``preds``) and ``preds``
        * 2 -- ``idxs``, ``preds`` and the flattened input under ``x``
        * 3/4/5/6 -- ``idxs`` and ``vals`` as a list of str / int /
          single-element lists / single-key dicts
        * 7 -- ``idxs`` and ``preds`` via ``write_prediction_dict``

        Any other option writes nothing.

        Returns:
            EvalResult checkpointed on the test loss, with ``test_loss`` and
            ``test_acc`` logged.
        """
        inputs, targets = batch
        flat_x = inputs.view(inputs.size(0), -1)
        logits = self(flat_x)

        loss_test = self.loss(targets, logits)

        # Accuracy as a float tensor matching the input's dtype/device.
        labels_hat = torch.argmax(logits, dim=1)
        n_correct = torch.sum(targets == labels_hat).item()
        test_acc = torch.tensor(n_correct / (len(targets) * 1.0)).type_as(flat_x)

        result = EvalResult(checkpoint_on=loss_test)
        result.log('test_loss', loss_test)
        result.log('test_acc', test_acc)

        # Random per-sample payloads. They are always generated, and always
        # in this order, regardless of the selected option.
        batch_size = flat_x.size(0)
        random_strs = [random.choice(['dog', 'cat']) for _ in range(batch_size)]
        random_ints = [random.randint(500, 1000) for _ in range(batch_size)]
        nested_ints = [[value] for value in random_ints]
        str_int_dicts = [{key: value} for key, value in zip(random_strs, random_ints)]

        option = getattr(self, 'test_option', 0)
        prediction_file = getattr(self, 'prediction_file', 'predictions.pt')

        lazy_ids = torch.arange(batch_idx * self.batch_size,
                                batch_idx * self.batch_size + batch_size)

        # Options 0-6 write 'idxs' first, then the listed (name, value) pairs.
        payloads_by_option = {
            0: [('preds', labels_hat)],
            1: [('preds', labels_hat)],
            2: [('preds', labels_hat), ('x', flat_x)],
            3: [('vals', random_strs)],
            4: [('vals', random_ints)],
            5: [('vals', nested_ints)],
            6: [('vals', str_int_dicts)],
        }

        if option in payloads_by_option:
            # Option 1 deliberately writes mismatched lengths (ids doubled).
            ids_to_write = torch.cat((lazy_ids, lazy_ids)) if option == 1 else lazy_ids
            self.write_prediction('idxs', ids_to_write, prediction_file)
            for name, payload in payloads_by_option[option]:
                self.write_prediction(name, payload, prediction_file)
        elif option == 7:
            self.write_prediction_dict({
                'idxs': lazy_ids,
                'preds': labels_hat
            }, prediction_file)

        return result
Beispiel #8
0
    def training_step(self, batch, batch_idx, val=False):
        """Run one try-on step over a clip of frames and compute the generator loss.

        The network runs once on the concatenated person/cloth inputs. Its
        outputs are split along dim 1 into ``hparams.n_frames_total`` chunks
        and stored on ``self`` (``p_rendereds``, ``tryon_masks``,
        ``p_tryons``, ``flow_masks``), presumably for ``visualize``.

        Total loss = image L1 + image VGG + try-on-mask L1 + flow-mask penalty.
        With more than one frame, each of the first three terms is the mean of
        the last frame ("curr") and the frame before it ("prev"). With a single
        frame only the last frame is used and the "prev" terms are zeros. The
        flow-mask penalty is the sum of the last flow-mask chunk scaled by
        ``hparams.pen_flow_mask``; it is zero when the model returns no flow
        masks.

        Args:
            batch: Raw batch. After ``maybe_combine_frames_and_channels`` it
                must provide "image", "prev_image", "cloth_mask", and "flow"
                when ``hparams.flow_warp`` is set.
            batch_idx: Batch index (unused).
            val: If True, treat the step as validation. Visualization is
                skipped, an ``EvalResult`` checkpointed on the loss is
                returned, and every logged key is prefixed with "val_".

        Returns:
            ``EvalResult(checkpoint_on=loss)`` if ``val`` else
            ``TrainResult(loss)``, with the total loss and its components logged.
        """
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        # unpack
        im = batch["image"]
        prev_im = batch["prev_image"]
        cm = batch["cloth_mask"]
        flow = batch["flow"] if self.hparams.flow_warp else None

        person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        # forward. save outputs to self for visualization
        (
            self.p_rendereds,
            self.tryon_masks,
            self.p_tryons,
            self.flow_masks,
        ) = self.forward(person_inputs, cloth_inputs, flow, prev_im)

        n_frames = self.hparams.n_frames_total
        multi_frame = n_frames > 1

        def split_frames(tensor):
            # Split a frame-stacked tensor into per-frame chunks along dim 1.
            return torch.chunk(tensor, n_frames, dim=1)

        def curr_prev_loss(criterion, preds, targets):
            # Return (combined, curr, prev) for the last frame and the one
            # before it. For a single frame, prev is zeros and combined == curr.
            curr = criterion(preds[-1], targets[-1])
            if not multi_frame:
                return curr, curr, torch.zeros_like(curr)
            prev = criterion(preds[-2], targets[-2])
            return 0.5 * (curr + prev), curr, prev

        self.p_tryons = split_frames(self.p_tryons)
        self.p_rendereds = split_frames(self.p_rendereds)
        self.tryon_masks = split_frames(self.tryon_masks)
        if self.flow_masks is not None:
            self.flow_masks = split_frames(self.flow_masks)

        im = split_frames(im)
        cm = split_frames(cm)

        # loss (evaluated in the same order as before: l1, vgg, tryon mask)
        loss_image_l1, loss_image_l1_curr, loss_image_l1_prev = curr_prev_loss(
            F.l1_loss, self.p_tryons, im)
        loss_image_vgg, loss_image_vgg_curr, loss_image_vgg_prev = curr_prev_loss(
            self.criterionVGG, self.p_tryons, im)
        loss_tryon_mask_l1, loss_tryon_mask_curr, loss_tryon_mask_prev = curr_prev_loss(
            F.l1_loss, self.tryon_masks, cm)

        if self.flow_masks is not None:
            flow_mask_penalty = self.flow_masks[-1].sum()
        else:
            flow_mask_penalty = torch.zeros_like(loss_tryon_mask_curr)
        loss_flow_mask_l1 = flow_mask_penalty * self.hparams.pen_flow_mask

        loss = loss_image_l1 + loss_image_vgg + loss_tryon_mask_l1 + loss_flow_mask_l1

        # logging
        if not val and self.global_step % self.hparams.display_count == 0:
            self.visualize(batch)

        val_ = "val_" if val else ""
        result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
        result.log(f"{val_}loss/G", loss, prog_bar=True)

        for suffix, value in (
            ("l1", loss_image_l1),
            ("vgg", loss_image_vgg),
            ("tryon_mask_l1", loss_tryon_mask_l1),
            ("flow_mask_l1", loss_flow_mask_l1),
        ):
            result.log(f"{val_}loss/G/{suffix}", value, prog_bar=True)

        if multi_frame:
            # per-frame breakdown: previous frame first, then current frame
            for suffix, value in (
                ("l1_prev", loss_image_l1_prev),
                ("vgg_prev", loss_image_vgg_prev),
                ("tryon_mask_prev", loss_tryon_mask_prev),
                ("l1_curr", loss_image_l1_curr),
                ("vgg_curr", loss_image_vgg_curr),
                ("tryon_mask_curr", loss_tryon_mask_curr),
            ):
                result.log(f"{val_}loss/G/{suffix}", value)

        return result