    def training_step(self, batch, batch_idx, val=False):
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        # unpack
        c = batch["cloth"]
        im_c = batch["im_cloth"]
        im_g = batch["grid_vis"]
        person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        # forward
        grid, theta = self.forward(person_inputs, cloth_inputs)
        self.warped_cloth = F.grid_sample(c, grid, padding_mode="border")
        self.warped_grid = F.grid_sample(im_g, grid, padding_mode="zeros")
        # loss
        loss = F.l1_loss(self.warped_cloth, im_c)

        # Logging
        if not val and self.global_step % self.hparams.display_count == 0:
            self.visualize(batch)

        val_ = "val_" if val else ""
        result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
        result.log(f"{val_}loss/G", loss, prog_bar=True)

        return result
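# Illustrative sketch (not part of the model): how F.grid_sample warps the
# cloth with the predicted grid. grid_sample expects a sampling grid of shape
# (N, H, W, 2) with x/y coordinates in [-1, 1]; an identity grid returns the
# input unchanged. align_corners is set explicitly here; the code above relies
# on the PyTorch default, which changed from True to False in PyTorch 1.3.
import torch
import torch.nn.functional as F

c = torch.randn(1, 3, 256, 192)  # (N, C, H, W) cloth image
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 256),
                        torch.linspace(-1, 1, 192), indexing="ij")
identity_grid = torch.stack([xs, ys], dim=-1).unsqueeze(0)  # (1, 256, 192, 2)
warped = F.grid_sample(c, identity_grid, padding_mode="border",
                       align_corners=True)
assert torch.allclose(warped, c, atol=1e-5)  # identity grid reproduces input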
    def test_step(self, batch, batch_idx):
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        dataset_names = batch["dataset_name"]
        im_names = batch["image_name"]
        if self.hparams.n_frames_total > 1:
            dataset_names = get_last_item_per_batch(dataset_names)
            im_names = get_last_item_per_batch(im_names)

        task = "tryon" if self.hparams.tryon_list else "reconstruction"
        try_on_dirs = [
            osp.join(self.test_results_dir, dname, task)
            for dname in dataset_names
        ]

        # if we already did a forward-pass on this batch, skip it
        save_paths = get_save_paths(try_on_dirs, im_names)
        if all(osp.exists(s) for s in save_paths):
            progress_bar = {"file": f"Skipping {im_names[0]}"}
        else:
            progress_bar = {"file": f"{im_names[0]}"}

            person_inputs = get_and_cat_inputs(batch,
                                               self.hparams.person_inputs)
            cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

            _, _, self.p_tryon, _ = self.forward(person_inputs, cloth_inputs)

            # TODO CLEANUP: we get the last frame here by picking the last RGB channels;
            #  this is different from how it's done in training_step, which uses
            #  chunking and -1 indexing. We should choose one method for consistency.
            save_images(self.p_tryon[:, -TryonDataset.RGB_CHANNELS:, :, :],
                        im_names, try_on_dirs)
        result = {"progress_bar": progress_bar}
        return result
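# Illustrative sketch: the skip logic above only needs get_save_paths to pair
# each output directory with its image name. A plausible minimal version
# (hypothetical; the real helper is defined elsewhere in this repo):
import os.path as osp

def get_save_paths_sketch(dirs, names):
    # one output path per (directory, image name) pair
    return [osp.join(d, n) for d, n in zip(dirs, names)]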
 def visualize(self, b, tag="train"):
     if tag == "validation":
         b = maybe_combine_frames_and_channels(self.hparams, b)
     person_visuals = self.fetch_person_visuals(b)
     visuals = [
         person_visuals,
         [
             # extract only the latest frame (for --n_frames_total)
             b["cloth"][:, -TryonDataset.CLOTH_CHANNELS:, :, :],
             b["cloth_mask"][:, -TryonDataset.CLOTH_MASK_CHANNELS:, :, :] *
             2 - 1,
             self.tryon_masks[-TryonDataset.MASK_CHANNELS] * 2 - 1,
         ],
         [
             self.p_rendereds[-1],
             self.p_tryons[-1],
             b["image"][:, -TryonDataset.RGB_CHANNELS:, :, :],
             b["prev_image"][:, -TryonDataset.RGB_CHANNELS:, :, :],
         ],
     ]
     for list_i in range(len(visuals)):
         for list_j in range(len(visuals[list_i])):
             tensor = visuals[list_i][list_j]
             if tensor.dim() == 5:
                 tensor = torch.squeeze(tensor, 1)
                 visuals[list_i][list_j] = tensor
     tensor = tensor_list_for_board(visuals)
     # add to experiment
     for i, img in enumerate(tensor):
         self.logger.experiment.add_image(f"{tag}/{i:03d}", img,
                                          self.global_step)
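# Illustrative sketch: the "* 2 - 1" applied to masks above maps values from
# [0, 1] into [-1, 1] so they share the value range of the normalized RGB
# tensors placed on the same board.
import torch
mask = torch.tensor([0.0, 0.5, 1.0])
assert torch.equal(mask * 2 - 1, torch.tensor([-1.0, 0.0, 1.0]))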
    def visualize(self, b, tag="train"):
        if tag == "validation":
            b = maybe_combine_frames_and_channels(self.hparams, b)
        person_visuals = self.fetch_person_visuals(b)

        visuals = [
            person_visuals,
            [b["cloth"], self.warped_cloth, b["im_cloth"]],
            [
                self.warped_grid, (self.warped_cloth + b["image"]) * 0.5,
                b["image"]
            ],
        ]
        tensor = tensor_list_for_board(visuals)
        # add to experiment
        for i, img in enumerate(tensor):
            self.logger.experiment.add_image(f"{tag}/{i:03d}", img,
                                             self.global_step)
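# Illustrative sketch: (warped_cloth + image) * 0.5 in the board above is a
# 50/50 overlay of warp and ground truth, useful for eyeballing alignment;
# averaging keeps values inside the tensors' shared range.
import torch
warped, image = torch.tensor([-1.0]), torch.tensor([1.0])
assert torch.equal((warped + image) * 0.5, torch.tensor([0.0]))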
    def __getitem__(self, index):
        if index < len(self.viton_dataset):
            item = self.viton_dataset[index]
            return item
        index -= len(self.viton_dataset)

        if index < len(self.vvt_dataset):
            item = self.vvt_dataset[index]
            if self.opt.model == "warp":
                assert self.opt.n_frames_total == 1, (
                    f"{self.opt.n_frames_total=}; "
                    f"warp model shouldn't be using n_frames_total > 1")
                item = maybe_combine_frames_and_channels(self.opt,
                                                         item,
                                                         has_batch_dim=False)
            return item
        index -= len(self.vvt_dataset)

        item = self.mpv_dataset[index]
        return item
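# Illustrative sketch: the index arithmetic above is hand-rolled dataset
# concatenation, the same idea as torch.utils.data.ConcatDataset. An index
# past the end of one dataset is shifted into the next:
a, b = ["a0", "a1"], ["b0"]

def concat_getitem(i):
    if i < len(a):
        return a[i]
    i -= len(a)  # shift the index into the next dataset's range
    return b[i]

assert concat_getitem(2) == "b0"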
    def test_step(self, batch, batch_idx):
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        dataset_names = batch["dataset_name"]
        # produce subfolders for each subdataset
        warp_cloth_dirs = [
            osp.join(self.test_results_dir, dname, "warp-cloth")
            for dname in dataset_names
        ]
        warp_mask_dirs = [
            osp.join(self.test_results_dir, dname, "warp-mask")
            for dname in dataset_names
        ]
        c_names = batch["cloth_name"]
        # if we already did a forward-pass on this batch, skip it
        save_paths = get_save_paths(warp_cloth_dirs, c_names)
        if all(osp.exists(s) for s in save_paths):
            progress_bar = {"file": f"Skipping {c_names[0]}"}
        else:
            progress_bar = {"file": c_names[0]}
            # unpack the data
            c = batch["cloth"]
            cm = batch["cloth_mask"]
            im_g = batch["grid_vis"]
            person_inputs = get_and_cat_inputs(batch,
                                               self.hparams.person_inputs)
            cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

            # forward pass
            grid, theta = self.forward(person_inputs, cloth_inputs)
            self.warped_cloth = F.grid_sample(c, grid, padding_mode="border")
            warped_mask = F.grid_sample(cm, grid, padding_mode="zeros")
            self.warped_grid = F.grid_sample(im_g, grid, padding_mode="zeros")

            # save images
            save_images(self.warped_cloth, c_names, warp_cloth_dirs)
            save_images(warped_mask * 2 - 1, c_names, warp_mask_dirs)

        result = {"progress_bar": progress_bar}
        return result
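# Illustrative sketch: why the warp step above mixes padding modes. "border"
# clamps out-of-range samples to the cloth's edge pixels (no dark fringe on
# the garment), while "zeros" lets the mask and grid visualization fall to
# zero outside the source image.
import torch
import torch.nn.functional as F

src = torch.ones(1, 1, 2, 2)
outside = torch.full((1, 2, 2, 2), 3.0)  # grid entirely outside [-1, 1]
border = F.grid_sample(src, outside, padding_mode="border", align_corners=True)
zeros = F.grid_sample(src, outside, padding_mode="zeros", align_corners=True)
assert border.eq(1).all() and zeros.eq(0).all()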
    def training_step(self, batch, batch_idx, val=False):
        batch = maybe_combine_frames_and_channels(self.hparams, batch)
        # unpack
        im = batch["image"]
        prev_im = batch["prev_image"]
        cm = batch["cloth_mask"]
        flow = batch["flow"] if self.hparams.flow_warp else None

        person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
        cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

        # forward. save outputs to self for visualization
        (
            self.p_rendereds,
            self.tryon_masks,
            self.p_tryons,
            self.flow_masks,
        ) = self.forward(person_inputs, cloth_inputs, flow, prev_im)
        self.p_tryons = torch.chunk(self.p_tryons,
                                    self.hparams.n_frames_total,
                                    dim=1)
        self.p_rendereds = torch.chunk(self.p_rendereds,
                                       self.hparams.n_frames_total,
                                       dim=1)
        self.tryon_masks = torch.chunk(self.tryon_masks,
                                       self.hparams.n_frames_total,
                                       dim=1)

        self.flow_masks = (
            torch.chunk(self.flow_masks, self.hparams.n_frames_total, dim=1)
            if self.flow_masks is not None else None
        )

        im = torch.chunk(im, self.hparams.n_frames_total, dim=1)
        cm = torch.chunk(cm, self.hparams.n_frames_total, dim=1)

        # loss
        loss_image_l1_curr = F.l1_loss(self.p_tryons[-1], im[-1])
        loss_image_l1_prev = (
            F.l1_loss(self.p_tryons[-2], im[-2])
            if self.hparams.n_frames_total > 1
            else torch.zeros_like(loss_image_l1_curr)
        )
        loss_image_l1 = (
            0.5 * (loss_image_l1_curr + loss_image_l1_prev)
            if self.hparams.n_frames_total > 1
            else loss_image_l1_curr
        )

        loss_image_vgg_curr = self.criterionVGG(self.p_tryons[-1], im[-1])
        loss_image_vgg_prev = (
            self.criterionVGG(self.p_tryons[-2], im[-2])
            if self.hparams.n_frames_total > 1
            else torch.zeros_like(loss_image_vgg_curr)
        )
        loss_image_vgg = (
            0.5 * (loss_image_vgg_curr + loss_image_vgg_prev)
            if self.hparams.n_frames_total > 1
            else loss_image_vgg_curr
        )

        loss_tryon_mask_curr = F.l1_loss(self.tryon_masks[-1], cm[-1])
        loss_tryon_mask_prev = (
            F.l1_loss(self.tryon_masks[-2], cm[-2])
            if self.hparams.n_frames_total > 1
            else torch.zeros_like(loss_tryon_mask_curr)
        )
        loss_tryon_mask_l1 = (
            0.5 * (loss_tryon_mask_curr + loss_tryon_mask_prev)
            if self.hparams.n_frames_total > 1
            else loss_tryon_mask_curr
        )

        # penalize total flow-mask activation (scaled by hparams.pen_flow_mask)
        loss_flow_mask_l1 = (
            self.flow_masks[-1].sum() if self.flow_masks is not None
            else torch.zeros_like(loss_tryon_mask_curr)
        ) * self.hparams.pen_flow_mask

        loss = loss_image_l1 + loss_image_vgg + loss_tryon_mask_l1 + loss_flow_mask_l1

        # logging
        if not val and self.global_step % self.hparams.display_count == 0:
            self.visualize(batch)

        val_ = "val_" if val else ""
        result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
        result.log(f"{val_}loss/G", loss, prog_bar=True)

        result.log(f"{val_}loss/G/l1", loss_image_l1, prog_bar=True)
        result.log(f"{val_}loss/G/vgg", loss_image_vgg, prog_bar=True)
        result.log(f"{val_}loss/G/tryon_mask_l1",
                   loss_tryon_mask_l1,
                   prog_bar=True)
        result.log(f"{val_}loss/G/flow_mask_l1",
                   loss_flow_mask_l1,
                   prog_bar=True)

        if self.hparams.n_frames_total > 1:
            # log previous-frame losses
            result.log(f"{val_}loss/G/l1_prev", loss_image_l1_prev)
            result.log(f"{val_}loss/G/vgg_prev", loss_image_vgg_prev)
            result.log(f"{val_}loss/G/tryon_mask_prev", loss_tryon_mask_prev)

            # log current-frame losses
            result.log(f"{val_}loss/G/l1_curr", loss_image_l1_curr)
            result.log(f"{val_}loss/G/vgg_curr", loss_image_vgg_curr)
            result.log(f"{val_}loss/G/tryon_mask_curr", loss_tryon_mask_curr)
        return result
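# Illustrative sketch: frames are stacked along the channel axis, and
# torch.chunk above slices them back into per-frame tensors, so [-1] is the
# current frame and [-2] the previous one.
import torch
n_frames, rgb = 3, 3
x = torch.randn(2, n_frames * rgb, 64, 64)  # (N, frames * C, H, W)
frames = torch.chunk(x, n_frames, dim=1)
assert len(frames) == 3 and frames[-1].shape == (2, 3, 64, 64)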