def val_step(self, *inputs, **kwargs):
    """val_step() API for module wrapped by DistributedDataParallel.

    This method is basically the same as
    ``DistributedDataParallel.forward()``, while replacing
    ``self.module.forward()`` with ``self.module.val_step()``.
    It is compatible with PyTorch 1.1 - 1.5.
    """
    if getattr(self, 'require_forward_param_sync', True):
        self._sync_params()
    if self.device_ids:
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            output = self.module.val_step(*inputs[0], **kwargs[0])
        else:
            outputs = self.parallel_apply(
                self._module_copies[:len(inputs)], inputs, kwargs)
            output = self.gather(outputs, self.output_device)
    else:
        output = self.module.val_step(*inputs, **kwargs)

    if torch.is_grad_enabled() and getattr(
            self, 'require_backward_grad_sync', True):
        if self.find_unused_parameters:
            self.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            self.reducer.prepare_for_backward([])
    else:
        # use digit_version for the comparison: a plain string compare
        # misorders versions like '1.10' < '1.2'
        if ('parrots' not in TORCH_VERSION
                and digit_version(TORCH_VERSION) > digit_version('1.2')):
            self.require_forward_param_sync = False
    return output
def forward(self, *inputs, **kwargs):
    """Modified to support not scattering the inputs when running on a
    single device."""
    if self.require_forward_param_sync:
        self._sync_params()
    if self.device_ids:
        if len(self.device_ids) == 1:
            output = self.module(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            outputs = self.parallel_apply(
                self._module_copies[:len(inputs)], inputs, kwargs)
            output = self.gather(outputs, self.output_device)
    else:
        output = self.module(*inputs, **kwargs)

    if torch.is_grad_enabled() and self.require_backward_grad_sync:
        self.require_forward_param_sync = True
        # We'll return the output object verbatim since it is a freeform
        # object. We need to find any tensors in this object, though,
        # because we need to figure out which parameters were used during
        # this forward pass, to ensure we short circuit reduction for any
        # unused parameters. Only if `find_unused_parameters` is set.
        if self.find_unused_parameters:
            self.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            self.reducer.prepare_for_backward([])
    else:
        self.require_forward_param_sync = False
    return output
def train_step(self, *inputs, **kwargs):
    """train_step() API for module wrapped by DistributedDataParallel.

    This method is basically the same as
    ``DistributedDataParallel.forward()``, while replacing
    ``self.module.forward()`` with ``self.module.train_step()``.
    It is compatible with PyTorch 1.1 - 1.5.
    """
    # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
    # end of backward to the beginning of forward.
    if ('parrots' not in TORCH_VERSION
            and digit_version(TORCH_VERSION) >= digit_version('1.7')
            and self.reducer._rebuild_buckets()):
        print_log(
            'Reducer buckets have been rebuilt in this iteration.',
            logger='mmcv')

    if ('parrots' not in TORCH_VERSION
            and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
        if self._check_sync_bufs_pre_fwd():
            self._sync_buffers()
    else:
        if getattr(self, 'require_forward_param_sync', False):
            self._sync_params()

    if self.device_ids:
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            output = self.module.train_step(*inputs[0], **kwargs[0])
        else:
            outputs = self.parallel_apply(
                self._module_copies[:len(inputs)], inputs, kwargs)
            output = self.gather(outputs, self.output_device)
    else:
        output = self.module.train_step(*inputs, **kwargs)

    if ('parrots' not in TORCH_VERSION
            and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
        if self._check_sync_bufs_post_fwd():
            self._sync_buffers()

    if (torch.is_grad_enabled()
            and getattr(self, 'require_backward_grad_sync', False)):
        if self.find_unused_parameters:
            self.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            self.reducer.prepare_for_backward([])
    else:
        if ('parrots' not in TORCH_VERSION
                and digit_version(TORCH_VERSION) > digit_version('1.2')):
            self.require_forward_param_sync = False
    return output
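# A minimal usage sketch for the wrapper methods above. It assumes the
# three overrides live on an ``MMDistributedDataParallel`` subclass of
# ``torch.nn.parallel.DistributedDataParallel``; ``build_model``,
# ``model_cfg``, ``data_batch`` and ``optimizer`` are hypothetical
# placeholders supplied by the surrounding training loop.
#
#   model = MMDistributedDataParallel(
#       build_model(model_cfg).cuda(),
#       device_ids=[torch.cuda.current_device()],
#       find_unused_parameters=True)
#   outputs = model.train_step(data_batch, optimizer)  # -> module.train_step
#   outputs = model.val_step(data_batch)               # -> module.val_step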
def train_step(self,
               data_batch,
               optimizer,
               ddp_reducer=None,
               running_status=None):
    """Train step function.

    This function implements the standard training iteration for
    asynchronous adversarial training. Namely, in each iteration, we
    first update discriminator and then compute loss for generator with
    the newly updated discriminator.

    As for distributed training, we use the ``reducer`` from ddp to
    synchronize the necessary params in current computational graph.

    Args:
        data_batch (dict): Input data from dataloader.
        optimizer (dict): Dict contains optimizer for generator and
            discriminator.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Contains 'log_vars', 'num_samples', and 'results'.
    """
    # get data from data_batch
    real_imgs = data_batch['real_img']
    # If you adopt ddp, this batch size is local batch size for each GPU.
    # If you adopt dp, this batch size is the global batch size as usual.
    batch_size = real_imgs.shape[0]

    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for not providing running status
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    if dist.is_initialized():
        # randomly sample a scale for current training iteration; note
        # that the probabilities must be passed via the ``p`` keyword --
        # the third positional argument of ``np.random.choice`` is
        # ``replace``, not ``p``
        chosen_scale = np.random.choice(
            self.multi_input_scales, 1, p=self.multi_scale_probability)[0]
        chosen_scale = torch.tensor(chosen_scale, dtype=torch.int).cuda()
        dist.broadcast(chosen_scale, 0)
        chosen_scale = int(chosen_scale.item())
    else:
        mmcv.print_log(
            'Distributed training has not been initialized. Degrade to '
            'the standard stylegan2', logger='mmgen', level=logging.WARN)
        chosen_scale = 0

    curr_size = (4 + chosen_scale) * (2**self.num_upblocks)
    # adjust the shape of images
    if real_imgs.shape[-2:] != (curr_size, curr_size):
        real_imgs = F.interpolate(
            real_imgs,
            size=(curr_size, curr_size),
            mode='bilinear',
            align_corners=True)

    # disc training
    set_requires_grad(self.discriminator, True)
    optimizer['discriminator'].zero_grad()
    # TODO: add noise sampler to customize noise sampling
    with torch.no_grad():
        fake_imgs = self.generator(
            None, num_batches=batch_size, chosen_scale=chosen_scale)

    # disc pred for fake imgs and real_imgs
    disc_pred_fake = self.discriminator(fake_imgs)
    disc_pred_real = self.discriminator(real_imgs)
    # get data dict to compute losses for disc
    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        disc_pred_fake=disc_pred_fake,
        disc_pred_real=disc_pred_real,
        fake_imgs=fake_imgs,
        real_imgs=real_imgs,
        iteration=curr_iter,
        batch_size=batch_size,
        gen_partial=partial(self.generator, chosen_scale=chosen_scale))

    loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

    loss_disc.backward()
    optimizer['discriminator'].step()

    # skip generator training if only train discriminator for current
    # iteration
    if (curr_iter + 1) % self.disc_steps != 0:
        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        log_vars_disc['curr_size'] = curr_size
        outputs = dict(
            log_vars=log_vars_disc, num_samples=batch_size, results=results)
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs

    # generator training
    set_requires_grad(self.discriminator, False)
    optimizer['generator'].zero_grad()

    # TODO: add noise sampler to customize noise sampling
    fake_imgs = self.generator(
        None, num_batches=batch_size, chosen_scale=chosen_scale)
    disc_pred_fake_g = self.discriminator(fake_imgs)

    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        fake_imgs=fake_imgs,
        disc_pred_fake_g=disc_pred_fake_g,
        iteration=curr_iter,
        batch_size=batch_size,
        gen_partial=partial(self.generator, chosen_scale=chosen_scale))
    loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

    loss_gen.backward()
    optimizer['generator'].step()

    log_vars = {}
    log_vars.update(log_vars_g)
    log_vars.update(log_vars_disc)
    log_vars['curr_size'] = curr_size

    results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
    outputs = dict(log_vars=log_vars, num_samples=batch_size, results=results)

    if hasattr(self, 'iteration'):
        self.iteration += 1
    return outputs
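# A hedged sketch of where ``ddp_reducer`` comes from: the caller (for
# example a DDP wrapper's own ``train_step``) can forward its ``reducer``
# so that the model's ``train_step`` above can register the tensors used
# in this iteration before each ``backward()``. The wrapper method below
# is illustrative, not the library's actual implementation.
#
#   def train_step(self, *inputs, **kwargs):  # on the DDP wrapper
#       # hand the reducer over only when gradient sync is required
#       if torch.is_grad_enabled() and self.require_backward_grad_sync:
#           kwargs['ddp_reducer'] = self.reducer
#       return self.module.train_step(*inputs, **kwargs)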
def train_step(self,
               data_batch,
               optimizer,
               ddp_reducer=None,
               loss_scaler=None,
               use_apex_amp=False,
               running_status=None):
    """Train step function.

    This function implements the standard training iteration for
    asynchronous adversarial training. Namely, in each iteration, we
    first update discriminator and then compute loss for generator with
    the newly updated discriminator.

    As for distributed training, we use the ``reducer`` from ddp to
    synchronize the necessary params in current computational graph.

    Args:
        data_batch (dict): Input data from dataloader.
        optimizer (dict): Dict contains optimizer for generator and
            discriminator.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
            The loss/gradient scaler used for auto mixed-precision
            training. Defaults to ``None``.
        use_apex_amp (bool, optional): Whether to use apex.amp. Defaults
            to ``False``.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Contains 'log_vars', 'num_samples', and 'results'.
    """
    # get data from data_batch
    real_imgs = data_batch[self.real_img_key]
    # If you adopt ddp, this batch size is local batch size for each GPU.
    # If you adopt dp, this batch size is the global batch size as usual.
    batch_size = real_imgs.shape[0]

    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for not providing running status
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    # disc training
    set_requires_grad(self.discriminator, True)
    optimizer['discriminator'].zero_grad()
    # TODO: add noise sampler to customize noise sampling
    with torch.no_grad():
        fake_imgs = self.generator(None, num_batches=batch_size)

    # disc pred for fake imgs and real_imgs
    disc_pred_fake = self.discriminator(fake_imgs)
    disc_pred_real = self.discriminator(real_imgs)
    # get data dict to compute losses for disc
    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        disc_pred_fake=disc_pred_fake,
        disc_pred_real=disc_pred_real,
        fake_imgs=fake_imgs,
        real_imgs=real_imgs,
        iteration=curr_iter,
        batch_size=batch_size,
        loss_scaler=loss_scaler)

    loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

    if loss_scaler:
        # add support for fp16
        loss_scaler.scale(loss_disc).backward()
    elif use_apex_amp:
        from apex import amp
        with amp.scale_loss(
                loss_disc, optimizer['discriminator'],
                loss_id=0) as scaled_loss_disc:
            scaled_loss_disc.backward()
    else:
        loss_disc.backward()

    if loss_scaler:
        loss_scaler.unscale_(optimizer['discriminator'])
        # note that we do not contain clip_grad procedure
        loss_scaler.step(optimizer['discriminator'])
        # loss_scaler.update will be called in runner.train()
    else:
        optimizer['discriminator'].step()

    # skip generator training if only train discriminator for current
    # iteration
    if (curr_iter + 1) % self.disc_steps != 0:
        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars_disc, num_samples=batch_size, results=results)
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs

    # generator training
    set_requires_grad(self.discriminator, False)
    optimizer['generator'].zero_grad()

    # TODO: add noise sampler to customize noise sampling
    fake_imgs = self.generator(None, num_batches=batch_size)
    disc_pred_fake_g = self.discriminator(fake_imgs)

    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        fake_imgs=fake_imgs,
        disc_pred_fake_g=disc_pred_fake_g,
        iteration=curr_iter,
        batch_size=batch_size,
        loss_scaler=loss_scaler)

    loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

    if loss_scaler:
        loss_scaler.scale(loss_gen).backward()
    elif use_apex_amp:
        from apex import amp
        with amp.scale_loss(
                loss_gen, optimizer['generator'],
                loss_id=1) as scaled_loss_gen:
            scaled_loss_gen.backward()
    else:
        loss_gen.backward()

    if loss_scaler:
        loss_scaler.unscale_(optimizer['generator'])
        # note that we do not contain clip_grad procedure
        loss_scaler.step(optimizer['generator'])
        # loss_scaler.update will be called in runner.train()
    else:
        optimizer['generator'].step()

    log_vars = {}
    log_vars.update(log_vars_g)
    log_vars.update(log_vars_disc)

    results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
    outputs = dict(log_vars=log_vars, num_samples=batch_size, results=results)

    if hasattr(self, 'iteration'):
        self.iteration += 1

    return outputs
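# The ``loss_scaler`` branches above follow the standard
# ``torch.cuda.amp.GradScaler`` protocol; only ``update()`` is deferred to
# the runner. A self-contained sketch of one fp16 step, with hypothetical
# ``model``, ``data`` and ``opt`` objects standing in for the real ones:
#
#   scaler = torch.cuda.amp.GradScaler()
#   opt.zero_grad()
#   with torch.cuda.amp.autocast():
#       loss = model(data)
#   scaler.scale(loss).backward()   # backward on the scaled loss
#   scaler.unscale_(opt)            # optional: enables grad clipping here
#   scaler.step(opt)                # skips the step if grads overflowed
#   scaler.update()                 # called by the runner in the code above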
def train_step(self,
               data_batch,
               optimizer,
               ddp_reducer=None,
               running_status=None):
    """Train step function.

    This function implements the standard training iteration for
    asynchronous adversarial training. Namely, in each iteration, we
    first update discriminator and then compute loss for generator with
    the newly updated discriminator.

    As for distributed training, we use the ``reducer`` from ddp to
    synchronize the necessary params in current computational graph.

    Args:
        data_batch (dict): Input data from dataloader.
        optimizer (dict): Dict contains optimizer for generator and
            discriminator.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Contains 'log_vars', 'num_samples', and 'results'.
    """
    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for not providing running status
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    # init each scale
    if curr_iter % self.train_cfg['iters_per_scale'] == 0:
        self.curr_stage += 1
        # load weights from prev scale
        self.get_module(self.generator,
                        'check_and_load_prev_weight')(self.curr_stage)
        self.get_module(self.discriminator,
                        'check_and_load_prev_weight')(self.curr_stage)
        # build optimizer for each scale
        g_module = self.get_module(self.generator, 'blocks')
        param_list = g_module[self.curr_stage].parameters()
        self.g_optim = torch.optim.Adam(
            param_list, lr=self.train_cfg['lr_g'], betas=(0.5, 0.999))
        d_module = self.get_module(self.discriminator, 'blocks')
        self.d_optim = torch.optim.Adam(
            d_module[self.curr_stage].parameters(),
            lr=self.train_cfg['lr_d'],
            betas=(0.5, 0.999))
        self.optimizer = dict(
            generator=self.g_optim, discriminator=self.d_optim)
        self.g_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=self.g_optim, **self.train_cfg['lr_scheduler_args'])
        self.d_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=self.d_optim, **self.train_cfg['lr_scheduler_args'])

    optimizer = self.optimizer

    # setup fixed noises and reals pyramid
    if curr_iter == 0 or len(self.reals) == 0:
        keys = [k for k in data_batch.keys() if 'real_scale' in k]
        scales = len(keys)
        self.reals = [data_batch[f'real_scale{s}'] for s in range(scales)]
        # here we do not pad fixed noises
        self.construct_fixed_noises()

    # disc training
    set_requires_grad(self.discriminator, True)
    for _ in range(self.train_cfg['disc_steps']):
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='rand',
                curr_scale=self.curr_stage)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs.detach(),
                                            self.curr_stage)
        disc_pred_real = self.discriminator(self.reals[self.curr_stage],
                                            self.curr_stage)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=self.reals[self.curr_stage],
            disc_partial=partial(
                self.discriminator, curr_scale=self.curr_stage))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # Prepare for backward in ddp. If you do not call this function
        # before back propagation, the ddp will not dynamically find the
        # used params in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        loss_disc.backward()
        optimizer['discriminator'].step()

    log_vars_disc.update(dict(curr_stage=self.curr_stage))

    # generator training
    set_requires_grad(self.discriminator, False)
    for _ in range(self.train_cfg['generator_steps']):
        optimizer['generator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(
            data_batch['input_sample'],
            self.fixed_noises,
            self.noise_weights,
            rand_mode='rand',
            curr_scale=self.curr_stage)
        disc_pred_fake_g = self.discriminator(
            fake_imgs, curr_scale=self.curr_stage)
        recon_imgs = self.generator(
            data_batch['input_sample'],
            self.fixed_noises,
            self.noise_weights,
            rand_mode='recon',
            curr_scale=self.curr_stage)

        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            recon_imgs=recon_imgs,
            real_imgs=self.reals[self.curr_stage],
            disc_pred_fake_g=disc_pred_fake_g)
        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # Prepare for backward in ddp. If you do not call this function
        # before back propagation, the ddp will not dynamically find the
        # used params in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

    # end of each scale: calculate noise weight for next scale
    if (curr_iter % self.train_cfg['iters_per_scale'] == 0) and (
            self.curr_stage < len(self.reals) - 1):
        with torch.no_grad():
            g_recon = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='recon',
                curr_scale=self.curr_stage)
            if isinstance(g_recon, dict):
                g_recon = g_recon['fake_img']
            g_recon = F.interpolate(
                g_recon, self.reals[self.curr_stage + 1].shape[-2:])

        mse = F.mse_loss(g_recon.detach(), self.reals[self.curr_stage + 1])
        rmse = torch.sqrt(mse)
        self.noise_weights.append(
            self.train_cfg.get('noise_weight_init', 0.1) * rmse.item())
        # try to release GPU memory
        torch.cuda.empty_cache()

    log_vars = {}
    log_vars.update(log_vars_g)
    log_vars.update(log_vars_disc)

    results = dict(
        fake_imgs=fake_imgs.cpu(),
        real_imgs=self.reals[self.curr_stage].cpu(),
        recon_imgs=recon_imgs.cpu(),
        curr_stage=self.curr_stage,
        fixed_noises=self.fixed_noises,
        noise_weights=self.noise_weights)
    outputs = dict(log_vars=log_vars, num_samples=1, results=results)

    # update lr scheduler
    self.d_scheduler.step()
    self.g_scheduler.step()

    if hasattr(self, 'iteration'):
        self.iteration += 1
    return outputs
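# The ``train_step`` above reads a handful of keys from ``self.train_cfg``.
# An illustrative config covering all of them; the values are hypothetical
# placeholders, not recommended settings:
#
#   train_cfg = dict(
#       iters_per_scale=2000,        # iterations before growing a scale
#       disc_steps=3,                # discriminator updates per iteration
#       generator_steps=3,           # generator updates per iteration
#       lr_g=5e-4,                   # Adam lr for the current gen block
#       lr_d=5e-4,                   # Adam lr for the current disc block
#       noise_weight_init=0.1,       # multiplier on the recon RMSE
#       lr_scheduler_args=dict(milestones=[1600], gamma=0.1))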
def train_step(self,
               data,
               optimizer,
               ddp_reducer=None,
               loss_scaler=None,
               use_apex_amp=False,
               running_status=None):
    """The iteration step during training.

    This method defines an iteration step during training. Different from
    other repos in the **MM** series, we allow the back propagation and
    optimizer updating to directly follow the iterative training schedule
    of DDPMs. You can also move the back propagation outside of this
    method and optimize the parameters in the optimizer hook, but this
    causes extra GPU memory cost as a result of retaining the
    computational graph. Otherwise, the training schedule should be
    modified in the detailed implementation.

    Args:
        data (dict): Input data from dataloader.
        optimizer (dict): Dict contains optimizer for denoising network.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
            The loss/gradient scaler used for auto mixed-precision
            training. Defaults to ``None``.
        use_apex_amp (bool, optional): Whether to use apex.amp. Defaults
            to ``False``.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Contains 'log_vars', 'num_samples', and 'results'.
    """
    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for not providing running status
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    real_imgs = data[self.real_img_key]

    # denoising training
    optimizer['denoising'].zero_grad()
    denoising_dict_ = self.reconstruction_step(
        data,
        timesteps=self.sampler,
        sample_model='orig',
        return_noise=True)
    denoising_dict_['iteration'] = curr_iter
    denoising_dict_['real_imgs'] = real_imgs
    denoising_dict_['loss_scaler'] = loss_scaler
    loss, log_vars = self._get_loss(denoising_dict_)

    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss))

    if loss_scaler:
        # add support for fp16
        loss_scaler.scale(loss).backward()
    elif use_apex_amp:
        from apex import amp
        with amp.scale_loss(
                loss, optimizer['denoising'], loss_id=0) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    if loss_scaler:
        loss_scaler.unscale_(optimizer['denoising'])
        # note that we do not contain clip_grad procedure
        loss_scaler.step(optimizer['denoising'])
        # loss_scaler.update will be called in runner.train()
    else:
        optimizer['denoising'].step()

    # images used for visualization
    results = dict(
        real_imgs=real_imgs,
        x_0_pred=denoising_dict_['x_0_pred'],
        x_t=denoising_dict_['diffusion_batches'],
        x_t_1=denoising_dict_['fake_img'])
    outputs = dict(
        log_vars=log_vars, num_samples=real_imgs.shape[0], results=results)
    if hasattr(self, 'iteration'):
        self.iteration += 1
    return outputs
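# The docstring above notes that ``backward()`` could instead be moved out
# of ``train_step`` and into an optimizer hook, at the cost of retaining
# the computational graph. A hedged sketch of that alternative; the hook
# API is simplified and the ``runner`` attributes are assumptions, not the
# repo's actual hook:
#
#   class DenoisingOptimizerHook(Hook):
#       def after_train_iter(self, runner):
#           loss = runner.outputs['loss']      # train_step returns the loss
#           runner.optimizer['denoising'].zero_grad()
#           loss.backward()                    # graph retained until here
#           runner.optimizer['denoising'].step()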
def train_step(self,
               data_batch,
               optimizer,
               ddp_reducer=None,
               running_status=None):
    """Training step function.

    Args:
        data_batch (dict): Dict of the input data batch.
        optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
            the generator and discriminator.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Dict of loss, information for logger, the number of
            samples and results for visualization.
    """
    # data
    target_domain = self._default_domain
    source_domain = self.get_other_domains(self._default_domain)[0]
    source_image = data_batch[f'img_{source_domain}']
    target_image = data_batch[f'img_{target_domain}']

    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for not providing running status
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    # forward generator
    outputs = dict()
    results = self(
        source_image, target_domain=self._default_domain, test_mode=False)
    outputs[f'real_{source_domain}'] = results['source']
    outputs[f'fake_{target_domain}'] = results['target']
    outputs[f'real_{target_domain}'] = target_image
    log_vars = dict()

    # discriminator
    set_requires_grad(self.discriminators, True)
    # optimize
    optimizer['discriminators'].zero_grad()
    loss_d, log_vars_d = self._get_disc_loss(outputs)
    log_vars.update(log_vars_d)
    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
    loss_d.backward()
    optimizer['discriminators'].step()

    # generator, no updates to discriminator parameters.
    if (curr_iter % self.disc_steps == 0
            and curr_iter >= self.disc_init_steps):
        set_requires_grad(self.discriminators, False)
        # optimize
        optimizer['generators'].zero_grad()
        loss_g, log_vars_g = self._get_gen_loss(outputs)
        log_vars.update(log_vars_g)
        # Prepare for backward in ddp. If you do not call this function
        # before back propagation, the ddp will not dynamically find the
        # used params in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
        loss_g.backward()
        optimizer['generators'].step()

    if hasattr(self, 'iteration'):
        self.iteration += 1

    image_results = dict()
    image_results[f'real_{source_domain}'] = outputs[
        f'real_{source_domain}'].cpu()
    image_results[f'fake_{target_domain}'] = outputs[
        f'fake_{target_domain}'].cpu()
    image_results[f'real_{target_domain}'] = outputs[
        f'real_{target_domain}'].cpu()

    results = dict(
        log_vars=log_vars,
        num_samples=len(outputs[f'real_{source_domain}']),
        results=image_results)

    return results
def train_step(self,
               data_batch,
               optimizer,
               ddp_reducer=None,
               running_status=None):
    """Train step function.

    This function implements the standard training iteration for
    asynchronous adversarial training. Namely, in each iteration, we
    first update discriminator and then compute loss for generator with
    the newly updated discriminator.

    As for distributed training, we use the ``reducer`` from ddp to
    synchronize the necessary params in current computational graph.

    Args:
        data_batch (dict): Input data from dataloader.
        optimizer (dict): Dict contains optimizer for generator and
            discriminator.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Contains 'log_vars', 'num_samples', and 'results'.
    """
    # get data from data_batch
    real_imgs = data_batch['real_img']
    # If you adopt ddp, this batch size is local batch size for each GPU.
    batch_size = real_imgs.shape[0]

    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for not providing running status
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    # check if the optimizer is built inside the model
    if hasattr(self, 'optimizer'):
        optimizer = self.optimizer

    # update current stage
    self.curr_stage = int(
        min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
            len(self.scales) - 1))
    self.curr_scale = self.scales[self.curr_stage]
    self._curr_scale_int = self._next_scale_int.clone()

    # add new scale and update training status
    if self.curr_stage != self.prev_stage:
        self.prev_stage = self.curr_stage
        self._actual_nkimgs.append(self.shown_nkimg.item())
        # reset optimizer
        if self.reset_optim_for_new_scale:
            optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
            optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                str(self.curr_scale[0]), self.g_lr_base)
            optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                str(self.curr_scale[0]), self.d_lr_base)
            self.optimizer = build_optimizers(self, optim_cfg)
            optimizer = self.optimizer
            mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

    # update training configs, like transition weight for torgb layers.
    # get current transition weight for interpolating two torgb layers
    if self.curr_stage == 0:
        transition_weight = 1.
    else:
        transition_weight = (
            self.shown_nkimg.item() -
            self._actual_nkimgs[-1]) / self.transition_kimgs
        # clip to [0, 1]
        transition_weight = min(max(transition_weight, 0.), 1.)
    self._curr_transition_weight = torch.tensor(transition_weight).to(
        self._curr_transition_weight)

    # resize real image to target scale
    if real_imgs.shape[2:] == self.curr_scale:
        pass
    elif (real_imgs.shape[2] >= self.curr_scale[0]
          and real_imgs.shape[3] >= self.curr_scale[1]):
        real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
    else:
        raise RuntimeError(
            f'The scale of real image {real_imgs.shape[2:]} is smaller '
            f'than current scale {self.curr_scale}.')

    # disc training
    set_requires_grad(self.discriminator, True)
    optimizer['discriminator'].zero_grad()
    # TODO: add noise sampler to customize noise sampling
    with torch.no_grad():
        fake_imgs = self.generator(
            None,
            num_batches=batch_size,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

    # disc pred for fake imgs and real_imgs
    disc_pred_fake = self.discriminator(
        fake_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)
    disc_pred_real = self.discriminator(
        real_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)
    # get data dict to compute losses for disc
    data_dict_ = dict(
        iteration=curr_iter,
        gen=self.generator,
        disc=self.discriminator,
        disc_pred_fake=disc_pred_fake,
        disc_pred_real=disc_pred_real,
        fake_imgs=fake_imgs,
        real_imgs=real_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight,
        gen_partial=partial(
            self.generator,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight),
        disc_partial=partial(
            self.discriminator,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight))

    loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

    loss_disc.backward()
    optimizer['discriminator'].step()

    # update training log status
    if dist.is_initialized():
        _batch_size = batch_size * dist.get_world_size()
    else:
        # guard against running_status being None as well, so the intended
        # RuntimeError is raised instead of a TypeError
        if running_status is None or 'batch_size' not in running_status:
            raise RuntimeError(
                'You should offer "batch_size" in running status for PGGAN')
        _batch_size = running_status['batch_size']
    self.shown_nkimg += (_batch_size / 1000.)

    log_vars_disc.update(
        dict(
            shown_nkimg=self.shown_nkimg.item(),
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight))

    # skip generator training if only train discriminator for current
    # iteration
    if (curr_iter + 1) % self.disc_steps != 0:
        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars_disc, num_samples=batch_size, results=results)
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs

    # generator training
    set_requires_grad(self.discriminator, False)
    optimizer['generator'].zero_grad()

    # TODO: add noise sampler to customize noise sampling
    fake_imgs = self.generator(
        None,
        num_batches=batch_size,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)
    disc_pred_fake_g = self.discriminator(
        fake_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)

    data_dict_ = dict(
        iteration=curr_iter,
        gen=self.generator,
        disc=self.discriminator,
        fake_imgs=fake_imgs,
        disc_pred_fake_g=disc_pred_fake_g)
    loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

    # Prepare for backward in ddp. If you do not call this function before
    # back propagation, the ddp will not dynamically find the used params
    # in current computation.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

    loss_gen.backward()
    optimizer['generator'].step()

    log_vars = {}
    log_vars.update(log_vars_g)
    log_vars.update(log_vars_disc)
    log_vars.update({'batch_size': batch_size})

    results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
    outputs = dict(log_vars=log_vars, num_samples=batch_size, results=results)

    if hasattr(self, 'iteration'):
        self.iteration += 1

    # check if a new scale will be added in the next iteration
    _curr_stage = int(
        min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
            len(self.scales) - 1))
    # in the next iteration, we will switch to a new scale
    if _curr_stage != self.curr_stage:
        # `self._next_scale_int` is updated at the end of `train_step`
        self._next_scale_int = self._next_scale_int * 2
    return outputs
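# A worked example of the transition weight computed above, with
# illustrative numbers only: suppose ``shown_nkimg`` is 650, the current
# stage was entered at ``_actual_nkimgs[-1] = 600``, and
# ``transition_kimgs = 600``. Then the raw weight is small and the new
# torgb branch is still mostly faded out; after 600 more kimgs the weight
# clips to 1.0 and the transition is complete.
#
#   w = (650.0 - 600.0) / 600.0   # 0.0833...
#   w = min(max(w, 0.), 1.)       # clipped into [0, 1]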
def train_step(self,
               data_batch,
               optimizer,
               ddp_reducer=None,
               running_status=None):
    """Training step function.

    Args:
        data_batch (dict): Dict of the input data batch.
        optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
            the generators and discriminators.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Dict of loss, information for logger, the number of
            samples and results for visualization.
    """
    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for not providing running status
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    # forward generators
    outputs = dict()
    for target_domain in self._reachable_domains:
        # fetch data by domain
        source_domain = self.get_other_domains(target_domain)[0]
        img = data_batch[f'img_{source_domain}']
        # translation process
        results = self(img, test_mode=False, target_domain=target_domain)
        outputs[f'real_{source_domain}'] = results['source']
        outputs[f'fake_{target_domain}'] = results['target']
        # cycle process
        results = self(
            results['target'],
            test_mode=False,
            target_domain=source_domain)
        outputs[f'cycle_{source_domain}'] = results['target']

    log_vars = dict()

    # discriminators
    set_requires_grad(self.discriminators, True)
    # optimize
    optimizer['discriminators'].zero_grad()
    loss_d, log_vars_d = self._get_disc_loss(outputs)
    log_vars.update(log_vars_d)
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
    loss_d.backward()
    optimizer['discriminators'].step()

    # generators, no updates to discriminator parameters.
    if (curr_iter % self.disc_steps == 0
            and curr_iter >= self.disc_init_steps):
        set_requires_grad(self.discriminators, False)
        # optimize
        optimizer['generators'].zero_grad()
        loss_g, log_vars_g = self._get_gen_loss(outputs)
        log_vars.update(log_vars_g)
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
        loss_g.backward()
        optimizer['generators'].step()

    if hasattr(self, 'iteration'):
        self.iteration += 1

    image_results = dict()
    for domain in self._reachable_domains:
        image_results[f'real_{domain}'] = outputs[f'real_{domain}'].cpu()
        image_results[f'fake_{domain}'] = outputs[f'fake_{domain}'].cpu()

    results = dict(
        log_vars=log_vars,
        num_samples=len(outputs[f'real_{domain}']),
        results=image_results)

    return results
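# ``set_requires_grad`` is used throughout these train steps to freeze the
# discriminator while the generator updates (and vice versa). A minimal
# sketch of the usual implementation, assuming the mmgen-style signature;
# consult the repo's own helper for the authoritative version:
#
#   def set_requires_grad(nets, requires_grad=False):
#       """Set ``requires_grad`` for all parameters of the given nets."""
#       if not isinstance(nets, list):
#           nets = [nets]
#       for net in nets:
#           if net is not None:
#               for param in net.parameters():
#                   param.requires_grad = requires_grad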