def training_step(self, batch, idx, val=False):
    batch = maybe_combine_frames_and_channels(self.hparams, batch)

    # unpack
    c = batch["cloth"]
    im_c = batch["im_cloth"]
    im_g = batch["grid_vis"]
    person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
    cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

    # forward
    grid, theta = self.forward(person_inputs, cloth_inputs)
    self.warped_cloth = F.grid_sample(c, grid, padding_mode="border")
    self.warped_grid = F.grid_sample(im_g, grid, padding_mode="zeros")

    # loss
    loss = F.l1_loss(self.warped_cloth, im_c)

    # logging
    if not val and self.global_step % self.hparams.display_count == 0:
        self.visualize(batch)

    val_ = "val_" if val else ""
    result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
    result.log(f"{val_}loss/G", loss, prog_bar=True)
    return result
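# The `val` flag above suggests the same logic doubles as the validation step.
# A minimal, hypothetical sketch of that reuse (not part of the original):
def validation_step(self, batch, idx):
    # delegate to training_step in eval mode so it builds an EvalResult
    return self.training_step(batch, idx, val=True)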
def validation_step_result_only_epoch_metrics(self, batch, batch_idx):
    """
    Only track epoch level metrics
    """
    acc = self.step(batch, batch_idx)
    result = EvalResult(checkpoint_on=acc, early_stop_on=acc)

    # step only metrics
    result.log('no_val_no_pbar', torch.tensor(11 + batch_idx).type_as(acc), prog_bar=False, logger=False)
    result.log('val_step_log_acc', torch.tensor(11 + batch_idx).type_as(acc), prog_bar=False, logger=True)
    result.log('val_step_log_pbar_acc', torch.tensor(12 + batch_idx).type_as(acc), prog_bar=True, logger=True)
    result.log('val_step_pbar_acc', torch.tensor(13 + batch_idx).type_as(acc), prog_bar=True, logger=False)

    self.validation_step_called = True
    return result
def test_step(self, batch, batch_idx) -> EvalResult:
    output = self.forward(**batch)
    result = EvalResult(checkpoint_on=output.loss, early_stop_on=output.loss)
    result.log(
        "test_loss",
        output.loss,
        prog_bar=True,
    )
    return result
def eval_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
    """
    Full loop flow eval step (result obj + dp)
    """
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x.to(self.device))
    loss_val = y_hat.sum()
    result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)

    eval_name = 'validation' if not self.trainer.testing else 'test'
    result.log(f'{eval_name}_step_metric', loss_val + 1, on_step=True)

    setattr(self, f'{eval_name}_step_called', True)
    return result
def generator_step(self, batch, val=False):
    loss_G_adv_multiscale = (  # also calls generator forward
        self.multiscale_adversarial_loss(batch, for_discriminator=False)
        * self.wt_multiscale
    )
    loss_G_adv_temporal = (
        self.temporal_adversarial_loss(batch, for_discriminator=False)
        * self.wt_temporal
    )

    ground_truth = batch["image"][:, -1, :, :, :]
    fake_frame = self.all_gen_frames[:, -1, :, :, :]
    loss_G_l1 = self.criterion_l1(fake_frame, ground_truth) * self.wt_l1
    loss_G_vgg = self.criterion_VGG(fake_frame, ground_truth) * self.wt_vgg

    loss_G = loss_G_l1 + loss_G_vgg + loss_G_adv_multiscale + loss_G_adv_temporal

    # log
    val_ = "val_" if val else ""
    result = (
        EvalResult(checkpoint_on=loss_G_l1 + loss_G_vgg)
        if val
        else TrainResult(loss_G)
    )
    result.log(f"{val_}loss", loss_G)
    result.log(f"{val_}loss/G/adv_multiscale", loss_G_adv_multiscale, prog_bar=True)
    result.log(f"{val_}loss/G/adv_temporal", loss_G_adv_temporal, prog_bar=True)
    result.log(f"{val_}loss/G/l1+vgg", loss_G_l1 + loss_G_vgg)
    result.log(f"{val_}loss/G/l1", loss_G_l1)
    result.log(f"{val_}loss/G/vgg", loss_G_vgg)
    return result
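# generator_step is one half of a GAN update. A hypothetical training_step
# that dispatches between it and a matching discriminator_step (the
# discriminator half is assumed, not shown in this section):
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        return self.generator_step(batch)
    return self.discriminator_step(batch)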
def test_step_result_obj(self, batch, batch_idx, *args, **kwargs):
    """
    Default, baseline test_step
    :param batch:
    :return:
    """
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x)

    loss_test = self.loss(y, y_hat)

    # acc
    labels_hat = torch.argmax(y_hat, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
    test_acc = torch.tensor(test_acc)
    test_acc = test_acc.type_as(x)

    result = EvalResult()

    # alternate possible outputs to test
    # NOTE: batch_idx % 1 == 0 is always true, so this first branch always
    # returns and the two branches below are unreachable as written
    if batch_idx % 1 == 0:
        result.log_dict({'test_loss': loss_test, 'test_acc': test_acc})
        return result
    if batch_idx % 2 == 0:
        return test_acc

    if batch_idx % 3 == 0:
        result.log_dict({'test_loss': loss_test, 'test_acc': test_acc})
        result.test_dic = {'test_loss_a': loss_test}
        return result
def test_step(self, batch, batch_idx):
    data, labels_dict = batch
    outputs = self.forward_tasks(data)
    result = EvalResult()
    filename = self.cfg.get("test.results_path", "./predictions.pt")
    result.write_dict(
        {
            "verb_output": outputs["verb"],
            "noun_output": outputs["noun"],
            "narration_id": labels_dict["narration_id"],
            "video_id": labels_dict["video_id"],
        },
        filename=filename,
    )
    return result
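# Once trainer.test(...) finishes, the file written by write_dict should be
# loadable with torch.load. A sketch, assuming the default path above and
# that the trainer persists the accumulated predictions via torch.save:
preds = torch.load("./predictions.pt")
print(preds.keys())  # expected: verb_output, noun_output, narration_id, video_id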
def validation_step_result_callbacks(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)

    self.assert_backward = False
    losses = [20, 19, 20, 21, 22, 23]
    idx = self.current_epoch
    loss = acc + losses[idx]
    result = EvalResult(early_stop_on=loss, checkpoint_on=loss)

    self.validation_step_called = True
    return result
def validation_epoch_end(self, outputs):
    loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    acc = torch.stack([x["val_acc"] for x in outputs]).mean()
    result = EvalResult(checkpoint_on=loss)
    result.log("val/loss", loss, on_step=False, on_epoch=True)
    result.log("val/acc", acc, on_step=False, on_epoch=True)
    return result
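# validation_epoch_end above expects each step output to carry "val_loss" and
# "val_acc" keys. A minimal, hypothetical validation_step producing them:
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    acc = (y_hat.argmax(dim=1) == y).float().mean()
    return {"val_loss": loss, "val_acc": acc}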
def validation_step_for_epoch_end_result(self, batch, batch_idx):
    """
    EvalResult flows to epoch end (without step_end)
    """
    acc = self.step(batch, batch_idx)
    result = EvalResult(checkpoint_on=acc, early_stop_on=acc)

    # epoch-only metrics (on_epoch=True, on_step=False)
    result.log('val_step_metric', torch.tensor(batch_idx).type_as(acc),
               prog_bar=True, logger=True, on_epoch=True, on_step=False)
    result.log('batch_idx', torch.tensor(batch_idx).type_as(acc),
               prog_bar=True, logger=True, on_epoch=True, on_step=False)

    self.validation_step_called = True
    return result
def validation_step(self, batch, batch_idx):
    step_results = self._step(batch)
    result = EvalResult(checkpoint_on=step_results["loss"])
    self.log_metrics(result, step_results, "val")
    return result
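# log_metrics is a project-local helper, not a Lightning API. A plausible
# sketch, with name and signature taken only from the call above:
def log_metrics(self, result, step_results, prefix):
    for name, value in step_results.items():
        # prefix every metric, e.g. "val_loss", and show the loss in the bar
        result.log(f"{prefix}_{name}", value, prog_bar=(name == "loss"))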
def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x)

    loss_test = self.loss(y, y_hat)

    # acc
    labels_hat = torch.argmax(y_hat, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
    test_acc = torch.tensor(test_acc)
    test_acc = test_acc.type_as(x)

    # do regular EvalResult logging
    result = EvalResult(checkpoint_on=loss_test)
    result.log('test_loss', loss_test)
    result.log('test_acc', test_acc)

    batch_size = x.size(0)
    lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
    lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
    lst_of_lst = [[x] for x in lst_of_int]
    lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]

    # this is passed in from pytest via parameterization
    option = getattr(self, 'test_option', 0)
    prediction_file = getattr(self, 'prediction_file', 'predictions.pt')

    lazy_ids = torch.arange(batch_idx * self.batch_size,
                            batch_idx * self.batch_size + x.size(0))

    # base
    if option == 0:
        self.write_prediction('idxs', lazy_ids, prediction_file)
        self.write_prediction('preds', labels_hat, prediction_file)

    # check mismatching tensor len
    elif option == 1:
        self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
        self.write_prediction('preds', labels_hat, prediction_file)

    # write multi-dimension
    elif option == 2:
        self.write_prediction('idxs', lazy_ids, prediction_file)
        self.write_prediction('preds', labels_hat, prediction_file)
        self.write_prediction('x', x, prediction_file)

    # write str list
    elif option == 3:
        self.write_prediction('idxs', lazy_ids, prediction_file)
        self.write_prediction('vals', lst_of_str, prediction_file)

    # write int list
    elif option == 4:
        self.write_prediction('idxs', lazy_ids, prediction_file)
        self.write_prediction('vals', lst_of_int, prediction_file)

    # write nested list
    elif option == 5:
        self.write_prediction('idxs', lazy_ids, prediction_file)
        self.write_prediction('vals', lst_of_lst, prediction_file)

    # write dict list
    elif option == 6:
        self.write_prediction('idxs', lazy_ids, prediction_file)
        self.write_prediction('vals', lst_of_dict, prediction_file)

    # write a whole dict of predictions in one call
    elif option == 7:
        self.write_prediction_dict({'idxs': lazy_ids, 'preds': labels_hat},
                                   prediction_file)

    return result
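# test_option and prediction_file are injected by the test harness. A
# hypothetical pytest parameterization that would exercise every branch
# above (the model class and Trainer flags are assumptions):
import pytest

@pytest.mark.parametrize("option", range(8))
def test_prediction_writing(tmpdir, option):
    model = MyTestModule()  # hypothetical module defining test_step_result_preds
    model.test_option = option
    model.prediction_file = str(tmpdir / "predictions.pt")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model)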
def training_step(self, batch, batch_idx, val=False):
    batch = maybe_combine_frames_and_channels(self.hparams, batch)

    # unpack
    im = batch["image"]
    prev_im = batch["prev_image"]
    cm = batch["cloth_mask"]
    flow = batch["flow"] if self.hparams.flow_warp else None
    person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
    cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

    # forward. save outputs to self for visualization
    (
        self.p_rendereds,
        self.tryon_masks,
        self.p_tryons,
        self.flow_masks,
    ) = self.forward(person_inputs, cloth_inputs, flow, prev_im)

    self.p_tryons = torch.chunk(self.p_tryons, self.hparams.n_frames_total, dim=1)
    self.p_rendereds = torch.chunk(self.p_rendereds, self.hparams.n_frames_total, dim=1)
    self.tryon_masks = torch.chunk(self.tryon_masks, self.hparams.n_frames_total, dim=1)
    self.flow_masks = (
        torch.chunk(self.flow_masks, self.hparams.n_frames_total, dim=1)
        if self.flow_masks is not None
        else None
    )

    im = torch.chunk(im, self.hparams.n_frames_total, dim=1)
    cm = torch.chunk(cm, self.hparams.n_frames_total, dim=1)

    # loss; prev-frame terms only apply when training on multiple frames
    multi_frame = self.hparams.n_frames_total > 1

    loss_image_l1_curr = F.l1_loss(self.p_tryons[-1], im[-1])
    loss_image_l1_prev = (
        F.l1_loss(self.p_tryons[-2], im[-2])
        if multi_frame
        else torch.zeros_like(loss_image_l1_curr)
    )
    loss_image_l1 = (
        0.5 * (loss_image_l1_curr + loss_image_l1_prev)
        if multi_frame
        else loss_image_l1_curr
    )

    loss_image_vgg_curr = self.criterionVGG(self.p_tryons[-1], im[-1])
    loss_image_vgg_prev = (
        self.criterionVGG(self.p_tryons[-2], im[-2])
        if multi_frame
        else torch.zeros_like(loss_image_vgg_curr)
    )
    loss_image_vgg = (
        0.5 * (loss_image_vgg_curr + loss_image_vgg_prev)
        if multi_frame
        else loss_image_vgg_curr
    )

    loss_tryon_mask_curr = F.l1_loss(self.tryon_masks[-1], cm[-1])
    loss_tryon_mask_prev = (
        F.l1_loss(self.tryon_masks[-2], cm[-2])
        if multi_frame
        else torch.zeros_like(loss_tryon_mask_curr)
    )
    loss_tryon_mask_l1 = (
        0.5 * (loss_tryon_mask_curr + loss_tryon_mask_prev)
        if multi_frame
        else loss_tryon_mask_curr
    )

    loss_flow_mask_l1 = (
        self.flow_masks[-1].sum()
        if self.flow_masks is not None
        else torch.zeros_like(loss_tryon_mask_curr)
    ) * self.hparams.pen_flow_mask

    loss = loss_image_l1 + loss_image_vgg + loss_tryon_mask_l1 + loss_flow_mask_l1

    # logging
    if not val and self.global_step % self.hparams.display_count == 0:
        self.visualize(batch)

    val_ = "val_" if val else ""
    result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
    result.log(f"{val_}loss/G", loss, prog_bar=True)
    result.log(f"{val_}loss/G/l1", loss_image_l1, prog_bar=True)
    result.log(f"{val_}loss/G/vgg", loss_image_vgg, prog_bar=True)
    result.log(f"{val_}loss/G/tryon_mask_l1", loss_tryon_mask_l1, prog_bar=True)
    result.log(f"{val_}loss/G/flow_mask_l1", loss_flow_mask_l1, prog_bar=True)

    if multi_frame:
        # visualize prev frames losses
        result.log(f"{val_}loss/G/l1_prev", loss_image_l1_prev)
        result.log(f"{val_}loss/G/vgg_prev", loss_image_vgg_prev)
        result.log(f"{val_}loss/G/tryon_mask_prev", loss_tryon_mask_prev)
        # visualize curr frames losses
        result.log(f"{val_}loss/G/l1_curr", loss_image_l1_curr)
        result.log(f"{val_}loss/G/vgg_curr", loss_image_vgg_curr)
        result.log(f"{val_}loss/G/tryon_mask_curr", loss_tryon_mask_curr)

    return result