def training_step(self, batch, batch_idx) -> TrainResult:
    output = self.forward(**batch)
    result = TrainResult(minimize=output.loss)
    result.log("train_loss", output.loss.detach(), prog_bar=True, on_epoch=True)
    return result

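# A minimal sketch of the same step on PyTorch Lightning >= 1.0, where the
# TrainResult/EvalResult objects were removed: training_step returns the loss
# tensor and logging goes through self.log. `output.loss` mirrors the snippet
# above; everything else is the stock LightningModule API.
def training_step(self, batch, batch_idx):
    output = self.forward(**batch)
    # self.log replaces result.log; on_epoch=True also logs the epoch mean
    self.log("train_loss", output.loss.detach(), prog_bar=True, on_epoch=True)
    # returning the loss tensor replaces TrainResult(minimize=...)
    return output.loss
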
def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
    """Full loop flow train step (result obj + dp)."""
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x.to(self.device))
    loss_val = y_hat.sum()
    result = TrainResult(minimize=loss_val)
    result.log('train_step_metric', loss_val + 1)
    self.training_step_called = True
    return result

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat, _ = self(x, batch_idx)
    loss = F.cross_entropy(y_hat, y)
    return TrainResult(loss)

def generator_step(self, batch, val=False):
    loss_G_adv_multiscale = (  # also calls generator forward
        self.multiscale_adversarial_loss(batch, for_discriminator=False)
        * self.wt_multiscale)
    loss_G_adv_temporal = (
        self.temporal_adversarial_loss(batch, for_discriminator=False)
        * self.wt_temporal)

    ground_truth = batch["image"][:, -1, :, :, :]
    fake_frame = self.all_gen_frames[:, -1, :, :, :]
    loss_G_l1 = self.criterion_l1(fake_frame, ground_truth) * self.wt_l1
    loss_G_vgg = self.criterion_VGG(fake_frame, ground_truth) * self.wt_vgg

    loss_G = loss_G_l1 + loss_G_vgg + loss_G_adv_multiscale + loss_G_adv_temporal

    # log
    val_ = "val_" if val else ""
    result = (EvalResult(checkpoint_on=loss_G_l1 + loss_G_vgg)
              if val else TrainResult(loss_G))
    result.log(f"{val_}loss", loss_G)
    result.log(f"{val_}loss/G/adv_multiscale", loss_G_adv_multiscale, prog_bar=True)
    result.log(f"{val_}loss/G/adv_temporal", loss_G_adv_temporal, prog_bar=True)
    result.log(f"{val_}loss/G/l1+vgg", loss_G_l1 + loss_G_vgg)
    result.log(f"{val_}loss/G/l1", loss_G_l1)
    result.log(f"{val_}loss/G/vgg", loss_G_vgg)
    return result

def training_step(self, batch, idx, val=False):
    batch = maybe_combine_frames_and_channels(self.hparams, batch)

    # unpack
    c = batch["cloth"]
    im_c = batch["im_cloth"]
    im_g = batch["grid_vis"]
    person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
    cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

    # forward
    grid, theta = self.forward(person_inputs, cloth_inputs)
    self.warped_cloth = F.grid_sample(c, grid, padding_mode="border")
    self.warped_grid = F.grid_sample(im_g, grid, padding_mode="zeros")

    # loss
    loss = F.l1_loss(self.warped_cloth, im_c)

    # logging
    if not val and self.global_step % self.hparams.display_count == 0:
        self.visualize(batch)

    val_ = "val_" if val else ""
    result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
    result.log(f"{val_}loss/G", loss, prog_bar=True)
    return result

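# F.grid_sample, used above to warp the cloth, samples its input at the (x, y)
# locations in `grid`, with coordinates normalized to [-1, 1]. A self-contained
# sketch with an identity warp, independent of the model above:
import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 8, 8)                          # N, C, H, W
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])  # identity affine transform
grid = F.affine_grid(theta, img.size(), align_corners=False)
out = F.grid_sample(img, grid, padding_mode="border", align_corners=False)
assert torch.allclose(out, img, atol=1e-5)  # the identity grid reproduces the image
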
def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None):
    """Lightning calls this inside the training loop."""
    # forward pass
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x)

    # calculate loss
    loss_val = self.loss(y, y_hat)

    # call metric
    val = self.metric(x, y)

    result = TrainResult(minimize=loss_val)
    result.log('metric_val', val)
    return result

def training_step(self, batch, batch_idx, optimizer_idx):
    imgs, _ = batch

    # sample noise
    z = torch.randn(imgs.shape[0], self.latent_dim)
    z = z.type_as(imgs)

    # train generator
    if optimizer_idx == 0:
        # generate images
        self.generated_imgs = self(z)

        # log sampled images
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('generated_images', grid, 0)

        # ground truth result (i.e. all fake);
        # put on GPU because we created this tensor inside the training loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        tqdm_dict = {'g_loss': g_loss}
        result = TrainResult(minimize=g_loss, checkpoint_on=True)
        result.log_dict(tqdm_dict)
        return result

    # train discriminator
    if optimizer_idx == 1:
        # measure the discriminator's ability to classify real from generated samples

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)
        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        tqdm_dict = {'d_loss': d_loss}
        result = TrainResult(minimize=d_loss, checkpoint_on=True)
        result.log_dict(tqdm_dict)
        return result

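# For the GAN step above to receive optimizer_idx, configure_optimizers must
# return two optimizers. A minimal sketch, assuming hypothetical
# `self.generator`, `self.discriminator`, and `self.lr` attributes that are
# not shown in the snippet above:
def configure_optimizers(self):
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)
    # Lightning alternates between the optimizers, calling training_step with
    # optimizer_idx=0 for opt_g and optimizer_idx=1 for opt_d
    return [opt_g, opt_d]
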
def training_step_no_default_callbacks_for_train_loop(self, batch, batch_idx):
    """Early stop and checkpoint only on these values."""
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)
    assert 'early_stop_on' not in result
    assert 'checkpoint_on' in result
    return result

def training_step_no_callbacks_result_obj(self, batch, batch_idx):
    """Early stop and checkpoint only on these values."""
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc, checkpoint_on=False)
    assert 'early_stop_on' not in result
    assert 'checkpoint_on' not in result
    return result

def training_step_result_log_step_only(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)

    # step-only metrics
    result.log(f'step_log_and_pbar_acc1_b{batch_idx}',
               torch.tensor(11).type_as(acc),
               prog_bar=True)
    result.log(f'step_log_acc2_b{batch_idx}', torch.tensor(12).type_as(acc))
    result.log(f'step_pbar_acc3_b{batch_idx}',
               torch.tensor(13).type_as(acc),
               logger=False,
               prog_bar=True)

    self.training_step_called = True
    return result

def training_step_result_log_epoch_only(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)

    result.log(f'epoch_log_and_pbar_acc1_e{self.current_epoch}',
               torch.tensor(14).type_as(acc),
               on_epoch=True,
               prog_bar=True,
               on_step=False)
    result.log(f'epoch_log_acc2_e{self.current_epoch}',
               torch.tensor(15).type_as(acc),
               on_epoch=True,
               on_step=False)
    result.log(f'epoch_pbar_acc3_e{self.current_epoch}',
               torch.tensor(16).type_as(acc),
               on_epoch=True,
               logger=False,
               prog_bar=True,
               on_step=False)

    self.training_step_called = True
    return result

def training_step_result_log_epoch_and_step_for_callbacks(self, batch, batch_idx):
    """Early stop and checkpoint only on these values."""
    acc = self.step(batch, batch_idx)

    self.assert_backward = False
    losses = [20, 19, 18, 10, 15, 14, 9, 11, 11, 20]
    idx = self.current_epoch
    loss = acc + losses[idx]
    result = TrainResult(minimize=loss, early_stop_on=loss, checkpoint_on=loss)
    return result

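# With the Result-object API used throughout this file (Lightning 0.9.x), the
# tensors passed as early_stop_on/checkpoint_on are what the built-in
# EarlyStopping and ModelCheckpoint callbacks monitor, so no `monitor` key is
# configured by hand. A sketch under that assumption; `MyModel` is a
# hypothetical module wrapping the step above, and the 0.9-era Trainer flags
# shown here may differ in other versions:
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=10, early_stop_callback=True, checkpoint_callback=True)
trainer.fit(MyModel())
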
def training_step_result_log_epoch_and_step(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)

    val_1 = (5 + batch_idx) * (self.current_epoch + 1)
    val_2 = (6 + batch_idx) * (self.current_epoch + 1)
    val_3 = (7 + batch_idx) * (self.current_epoch + 1)
    result.log('step_epoch_log_and_pbar_acc1',
               torch.tensor(val_1).type_as(acc),
               on_epoch=True,
               prog_bar=True)
    result.log('step_epoch_log_acc2',
               torch.tensor(val_2).type_as(acc),
               on_epoch=True)
    result.log('step_epoch_pbar_acc3',
               torch.tensor(val_3).type_as(acc),
               on_epoch=True,
               logger=False,
               prog_bar=True)

    self.training_step_called = True
    return result

def multiscale_discriminator_step(self, batch):
    loss_D, loss_D_real, loss_D_fake = self.multiscale_adversarial_loss(
        batch, for_discriminator=True)
    result = TrainResult(loss_D)
    result.log("loss/D/multi", loss_D, prog_bar=True)
    result.log("loss/D/multi_fake", loss_D_fake)
    result.log("loss/D/multi_real", loss_D_real)
    return result

def temporal_discriminator_step(self, batch):
    loss_D, loss_D_real, loss_D_fake = self.temporal_adversarial_loss(
        batch, for_discriminator=True)
    result = TrainResult(loss_D)
    result.log("loss/D/temporal", loss_D, prog_bar=True)
    result.log("loss/D/temporal_fake", loss_D_fake)
    result.log("loss/D/temporal_real", loss_D_real)
    return result

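# generator_step, multiscale_discriminator_step, and temporal_discriminator_step
# above are presumably dispatched from one training_step keyed on optimizer_idx.
# A minimal sketch of such a dispatcher; the 0/1/2 ordering is an assumption,
# not confirmed by these snippets:
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        return self.generator_step(batch)
    if optimizer_idx == 1:
        return self.multiscale_discriminator_step(batch)
    return self.temporal_discriminator_step(batch)
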
def training_step(self, batch, batch_idx):
    images, targets = batch
    outputs = self.forward(images)
    loss = self._criterion(outputs, targets)
    acc = torch.sum(targets == torch.argmax(outputs, dim=1)).float() / len(targets)
    result = TrainResult(loss)
    result.log("train/loss", loss, on_step=False, on_epoch=True)
    result.log("train/acc", acc, on_step=False, on_epoch=True)
    return result

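# The accuracy above is the fraction of correct argmax predictions. An
# equivalent standalone helper (hypothetical name, not part of the module):
import torch

def top1_accuracy(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean of (argmax prediction == target) over the batch."""
    return (outputs.argmax(dim=1) == targets).float().mean()
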
def training_step_result_obj(self, batch, batch_idx, optimizer_idx=None):
    # forward pass
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x)

    # calculate loss
    loss_val = self.loss(y, y_hat)
    log_val = loss_val

    # alternate between tensors and scalars for "log" and "progress_bar"
    if batch_idx % 2 == 0:
        log_val = log_val.item()

    result = TrainResult(loss_val)
    result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
    result.log('train_some_val', log_val * log_val)
    return result

def training_step(self, batch, batch_idx, val=False):
    batch = maybe_combine_frames_and_channels(self.hparams, batch)

    # unpack
    im = batch["image"]
    prev_im = batch["prev_image"]
    cm = batch["cloth_mask"]
    flow = batch["flow"] if self.hparams.flow_warp else None
    person_inputs = get_and_cat_inputs(batch, self.hparams.person_inputs)
    cloth_inputs = get_and_cat_inputs(batch, self.hparams.cloth_inputs)

    # forward; save outputs to self for visualization
    (
        self.p_rendereds,
        self.tryon_masks,
        self.p_tryons,
        self.flow_masks,
    ) = self.forward(person_inputs, cloth_inputs, flow, prev_im)

    self.p_tryons = torch.chunk(self.p_tryons, self.hparams.n_frames_total, dim=1)
    self.p_rendereds = torch.chunk(self.p_rendereds, self.hparams.n_frames_total, dim=1)
    self.tryon_masks = torch.chunk(self.tryon_masks, self.hparams.n_frames_total, dim=1)
    self.flow_masks = (
        torch.chunk(self.flow_masks, self.hparams.n_frames_total, dim=1)
        if self.flow_masks is not None else None)

    im = torch.chunk(im, self.hparams.n_frames_total, dim=1)
    cm = torch.chunk(cm, self.hparams.n_frames_total, dim=1)

    # loss; "prev" terms only exist when training on more than one frame
    multiframe = self.hparams.n_frames_total > 1

    loss_image_l1_curr = F.l1_loss(self.p_tryons[-1], im[-1])
    loss_image_l1_prev = (F.l1_loss(self.p_tryons[-2], im[-2])
                          if multiframe else torch.zeros_like(loss_image_l1_curr))
    loss_image_l1 = (0.5 * (loss_image_l1_curr + loss_image_l1_prev)
                     if multiframe else loss_image_l1_curr)

    loss_image_vgg_curr = self.criterionVGG(self.p_tryons[-1], im[-1])
    loss_image_vgg_prev = (self.criterionVGG(self.p_tryons[-2], im[-2])
                           if multiframe else torch.zeros_like(loss_image_vgg_curr))
    loss_image_vgg = (0.5 * (loss_image_vgg_curr + loss_image_vgg_prev)
                      if multiframe else loss_image_vgg_curr)

    loss_tryon_mask_curr = F.l1_loss(self.tryon_masks[-1], cm[-1])
    loss_tryon_mask_prev = (F.l1_loss(self.tryon_masks[-2], cm[-2])
                            if multiframe else torch.zeros_like(loss_tryon_mask_curr))
    loss_tryon_mask_l1 = (0.5 * (loss_tryon_mask_curr + loss_tryon_mask_prev)
                          if multiframe else loss_tryon_mask_curr)

    loss_flow_mask_l1 = (
        self.flow_masks[-1].sum() if self.flow_masks is not None
        else torch.zeros_like(loss_tryon_mask_curr)) * self.hparams.pen_flow_mask

    loss = loss_image_l1 + loss_image_vgg + loss_tryon_mask_l1 + loss_flow_mask_l1

    # logging
    if not val and self.global_step % self.hparams.display_count == 0:
        self.visualize(batch)

    val_ = "val_" if val else ""
    result = EvalResult(checkpoint_on=loss) if val else TrainResult(loss)
    result.log(f"{val_}loss/G", loss, prog_bar=True)
    result.log(f"{val_}loss/G/l1", loss_image_l1, prog_bar=True)
    result.log(f"{val_}loss/G/vgg", loss_image_vgg, prog_bar=True)
    result.log(f"{val_}loss/G/tryon_mask_l1", loss_tryon_mask_l1, prog_bar=True)
    result.log(f"{val_}loss/G/flow_mask_l1", loss_flow_mask_l1, prog_bar=True)

    if multiframe:
        # previous-frame losses
        result.log(f"{val_}loss/G/l1_prev", loss_image_l1_prev)
        result.log(f"{val_}loss/G/vgg_prev", loss_image_vgg_prev)
        result.log(f"{val_}loss/G/tryon_mask_prev", loss_tryon_mask_prev)
        # current-frame losses
        result.log(f"{val_}loss/G/l1_curr", loss_image_l1_curr)
        result.log(f"{val_}loss/G/vgg_curr", loss_image_vgg_curr)
        result.log(f"{val_}loss/G/tryon_mask_curr", loss_tryon_mask_curr)

    return result

def training_step(self, batch, batch_idx):
    step_results = self._step(batch)
    result = TrainResult(step_results["loss"])
    self.log_metrics(result, step_results, "train")
    return result

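# log_metrics above is a user-defined helper, not a Lightning API. A plausible
# sketch, assuming step_results maps metric names to scalar tensors (the real
# helper is not shown here):
def log_metrics(self, result, step_results, prefix):
    for name, value in step_results.items():
        result.log(f"{prefix}/{name}", value)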