def train_step(self,
               data_batch,
               optimizer,
               ddp_reducer=None,
               running_status=None):
    """Train step function.

    This function implements the standard training iteration for
    asynchronous adversarial training. Namely, in each iteration, we first
    update the discriminator and then compute the loss for the generator
    with the newly updated discriminator.

    As for distributed training, we use the ``reducer`` from ddp to
    synchronize the necessary params in the current computational graph.

    Args:
        data_batch (dict): Input data from the dataloader.
        optimizer (dict): Dict containing the optimizers for the generator
            and the discriminator.
        ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
            It is used to prepare for ``backward()`` in ddp. Defaults to
            None.
        running_status (dict | None, optional): Contains necessary basic
            information for training, e.g., iteration number. Defaults to
            None.

    Returns:
        dict: Contains 'log_vars', 'num_samples', and 'results'.
    """
    # get data from data_batch
    real_imgs = data_batch['real_img']
    # If you adopt ddp, this batch size is the local batch size for each GPU.
    batch_size = real_imgs.shape[0]

    # get running status
    if running_status is not None:
        curr_iter = running_status['iteration']
    else:
        # dirty workaround for the case where running status is not provided
        if not hasattr(self, 'iteration'):
            self.iteration = 0
        curr_iter = self.iteration

    # use the optimizer built inside the model if it exists
    if hasattr(self, 'optimizer'):
        optimizer = self.optimizer

    # update current stage
    self.curr_stage = int(
        min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
            len(self.scales) - 1))
    self.curr_scale = self.scales[self.curr_stage]
    self._curr_scale_int = self._next_scale_int.clone()

    # add a new scale and update the training status
    if self.curr_stage != self.prev_stage:
        self.prev_stage = self.curr_stage
        self._actual_nkimgs.append(self.shown_nkimg.item())
        # reset optimizer
        if self.reset_optim_for_new_scale:
            optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
            optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                str(self.curr_scale[0]), self.g_lr_base)
            optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                str(self.curr_scale[0]), self.d_lr_base)
            self.optimizer = build_optimizers(self, optim_cfg)
            optimizer = self.optimizer
            mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

    # update training configs, like the transition weight for torgb layers.
    # get the current transition weight for interpolating two torgb layers
    if self.curr_stage == 0:
        transition_weight = 1.
    else:
        transition_weight = (
            self.shown_nkimg.item() -
            self._actual_nkimgs[-1]) / self.transition_kimgs
        # clip to [0, 1]
        transition_weight = min(max(transition_weight, 0.), 1.)
    self._curr_transition_weight = torch.tensor(transition_weight).to(
        self._curr_transition_weight)

    # resize real images to the target scale
    if real_imgs.shape[2:] == self.curr_scale:
        pass
    elif real_imgs.shape[2] >= self.curr_scale[0] and real_imgs.shape[
            3] >= self.curr_scale[1]:
        real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
    else:
        raise RuntimeError(
            f'The scale of real image {real_imgs.shape[2:]} is smaller '
            f'than current scale {self.curr_scale}.')

    # disc training
    set_requires_grad(self.discriminator, True)
    optimizer['discriminator'].zero_grad()
    # TODO: add noise sampler to customize noise sampling
    with torch.no_grad():
        fake_imgs = self.generator(
            None,
            num_batches=batch_size,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

    # disc pred for fake imgs and real imgs
    disc_pred_fake = self.discriminator(
        fake_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)
    disc_pred_real = self.discriminator(
        real_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)

    # get data dict to compute losses for disc
    data_dict_ = dict(
        iteration=curr_iter,
        gen=self.generator,
        disc=self.discriminator,
        disc_pred_fake=disc_pred_fake,
        disc_pred_real=disc_pred_real,
        fake_imgs=fake_imgs,
        real_imgs=real_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight,
        gen_partial=partial(
            self.generator,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight),
        disc_partial=partial(
            self.discriminator,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight))

    loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

    # prepare for backward in ddp. If you do not call this function before
    # back propagation, ddp will not dynamically find the used params in
    # the current computation graph.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

    loss_disc.backward()
    optimizer['discriminator'].step()

    # update training log status
    if dist.is_initialized():
        _batch_size = batch_size * dist.get_world_size()
    else:
        if 'batch_size' not in running_status:
            raise RuntimeError(
                'You should offer "batch_size" in running status for PGGAN')
        _batch_size = running_status['batch_size']
    self.shown_nkimg += (_batch_size / 1000.)

    log_vars_disc.update(
        dict(
            shown_nkimg=self.shown_nkimg.item(),
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight))

    # skip generator training if only the discriminator is trained in the
    # current iteration
    if (curr_iter + 1) % self.disc_steps != 0:
        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars_disc, num_samples=batch_size, results=results)
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs

    # generator training
    set_requires_grad(self.discriminator, False)
    optimizer['generator'].zero_grad()

    # TODO: add noise sampler to customize noise sampling
    fake_imgs = self.generator(
        None,
        num_batches=batch_size,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)
    disc_pred_fake_g = self.discriminator(
        fake_imgs,
        curr_scale=self.curr_scale[0],
        transition_weight=transition_weight)

    data_dict_ = dict(
        iteration=curr_iter,
        gen=self.generator,
        disc=self.discriminator,
        fake_imgs=fake_imgs,
        disc_pred_fake_g=disc_pred_fake_g)

    loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

    # prepare for backward in ddp. If you do not call this function before
    # back propagation, ddp will not dynamically find the used params in
    # the current computation graph.
    if ddp_reducer is not None:
        ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

    loss_gen.backward()
    optimizer['generator'].step()

    log_vars = {}
    log_vars.update(log_vars_g)
    log_vars.update(log_vars_disc)
    log_vars.update({'batch_size': batch_size})

    results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
    outputs = dict(
        log_vars=log_vars, num_samples=batch_size, results=results)

    if hasattr(self, 'iteration'):
        self.iteration += 1

    # check if a new scale will be added in the next iteration
    _curr_stage = int(
        min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
            len(self.scales) - 1))
    # in the next iteration, we will switch to a new scale
    if _curr_stage != self.curr_stage:
        # `self._next_scale_int` is updated at the end of `train_step`
        self._next_scale_int = self._next_scale_int * 2

    return outputs
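
# A minimal, hedged sketch of how `train_step` above can be driven without a
# runner, e.g. for quick debugging. The names `pggan`, `data_loader` and
# `samples_per_gpu` are placeholders of this sketch; in normal training an
# `IterBasedRunner` calls `train_step` and supplies `running_status`.
def _example_manual_pggan_loop(pggan, data_loader, samples_per_gpu):
    """Illustrative only; not part of the library API."""
    # PGGAN builds its own optimizer in `_parse_train_cfg`, so the
    # `optimizer` argument can be None here (it is overridden internally).
    running_status = dict(iteration=0, batch_size=samples_per_gpu)
    for data_batch in data_loader:
        outputs = pggan.train_step(
            data_batch, optimizer=None, running_status=running_status)
        running_status['iteration'] += 1
    return outputs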

def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        use_ddp_wrapper = cfg.get('use_ddp_wrapper', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_ddp_wrapper:
            mmcv.print_log('Use DDP Wrapper.', 'mmgen')
            model = DistributedDataParallelWrapper(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    if cfg.optimizer:
        optimizer = build_optimizers(model, cfg.optimizer)
    # In GANs, we also allow building the optimizer inside the GAN model.
    else:
        optimizer = None

    # allow users to define the runner
    if cfg.get('runner', None):
        runner = build_runner(
            cfg.runner,
            dict(
                model=model,
                optimizer=optimizer,
                work_dir=cfg.work_dir,
                logger=logger,
                meta=meta))
    else:
        runner = IterBasedRunner(
            model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta)
        # set if use dynamic ddp in training
        # is_dynamic_ddp=cfg.get('is_dynamic_ddp', False))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)

    # In GANs, we can directly optimize the parameters in the `train_step`
    # function.
    if cfg.get('optimizer_cfg', None) is None:
        optimizer_config = None
    elif fp16_cfg is not None:
        raise NotImplementedError('Fp16 has not been supported.')
        # optimizer_config = Fp16OptimizerHook(
        #     **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    # default to use OptimizerHook
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # update `out_dir` in ckpt hook
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config['out_dir'] = os.path.join(
            cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # # DistSamplerSeedHook should be used with EpochBasedRunner
    # if distributed:
    #     runner.register_hook(DistSamplerSeedHook())

    # In general, we do NOT adopt the standard evaluation hook in GAN
    # training. Thus, if you want an eval hook, you need to further define
    # the 'evaluation' key in the config.
    # register eval hooks
    if validate and cfg.get('evaluation', None) is not None:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # Support batch_size > 1 in validation
        val_loader_cfg = {
            'samples_per_gpu': 1,
            'shuffle': False,
            'workers_per_gpu': cfg.data.workers_per_gpu,
            **cfg.data.get('val_data_loader', {})
        }
        val_dataloader = build_dataloader(
            val_dataset, dist=distributed, **val_loader_cfg)
        eval_cfg = deepcopy(cfg.get('evaluation'))
        # pop 'priority' before building the hook so that it is not passed
        # to the hook's constructor
        priority = eval_cfg.pop('priority', 'NORMAL')
        eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
        eval_hook = build_from_cfg(eval_cfg, HOOKS)
        runner.register_hook(eval_hook, priority=priority)

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_iters)
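
# A hedged sketch of the usual entry point that drives `train_model` above,
# mirroring what a `tools/train.py` script typically does. The import
# location of `build_model` and the assumption that `cfg` already carries
# fields such as `gpu_ids`, `seed`, `workflow` and `total_iters` are
# assumptions of this sketch, not guarantees of the API.
def _example_train_entry(config_path, work_dir):
    """Illustrative only; not a verified training script."""
    from mmgen.models import build_model  # assumed import location

    cfg = mmcv.Config.fromfile(config_path)
    cfg.work_dir = work_dir
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    datasets = [build_dataset(cfg.data.train)]
    train_model(model, datasets, cfg, distributed=False, validate=False)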

def _parse_train_cfg(self):
    """Parse the train config and set some attributes for training."""
    if self.train_cfg is None:
        self.train_cfg = dict()
    # control the workflow in the train step
    self.disc_steps = self.train_cfg.get('disc_steps', 1)

    # whether to use exponential moving average for training
    self.use_ema = self.train_cfg.get('use_ema', False)
    if self.use_ema:
        # use deepcopy to guarantee the consistency
        self.generator_ema = deepcopy(self.generator)

    # setup the interpolation operation at the beginning of the training iter
    interp_real_cfg = deepcopy(self.train_cfg.get('interp_real', None))
    if interp_real_cfg is None:
        interp_real_cfg = dict(mode='bilinear', align_corners=True)

    self.interp_real_to = partial(F.interpolate, **interp_real_cfg)

    # parse the training schedule: scales -> kimgs
    assert isinstance(self.train_cfg['nkimgs_per_scale'], dict), (
        'Please provide "nkimgs_per_scale" to schedule the training '
        'procedure.')

    nkimgs_per_scale = deepcopy(self.train_cfg['nkimgs_per_scale'])
    self.scales = []
    self.nkimgs = []
    for k, v in nkimgs_per_scale.items():
        # support for different data types
        if isinstance(k, str):
            k = (int(k), int(k))
        elif isinstance(k, int):
            k = (k, k)
        else:
            assert mmcv.is_tuple_of(k, int)

        # sanity check for the order of scales
        assert len(self.scales) == 0 or k[0] > self.scales[-1][0]
        self.scales.append(k)
        self.nkimgs.append(v)
    self.cum_nkimgs = np.cumsum(self.nkimgs)
    self.curr_stage = 0
    self.prev_stage = 0
    # actual nkimgs shown at the end of each training stage
    self._actual_nkimgs = []
    # At each scale, transition from the previous torgb layer to the new
    # torgb layer over `transition_kimgs` kimgs
    self.transition_kimgs = self.train_cfg.get('transition_kimgs', 600)

    # setup optimizer
    self.optimizer = build_optimizers(
        self, deepcopy(self.train_cfg['optimizer_cfg']))

    # get lr schedule
    self.g_lr_base = self.train_cfg['g_lr_base']
    self.d_lr_base = self.train_cfg['d_lr_base']
    # example for lr schedule: {'32': 0.001, '64': 0.0001}
    self.g_lr_schedule = self.train_cfg.get('g_lr_schedule', dict())
    self.d_lr_schedule = self.train_cfg.get('d_lr_schedule', dict())

    # whether to reset the optimizer states (e.g. momentum in Adam) when a
    # new scale is added
    self.reset_optim_for_new_scale = self.train_cfg.get(
        'reset_optim_for_new_scale', True)

    # dirty workaround to avoid an optimizer bug when resuming
    self.prev_stage = self.train_cfg.get('prev_stage', self.prev_stage)
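
# An illustrative `train_cfg` for the parser above. The keys mirror exactly
# what `_parse_train_cfg` reads; the concrete scales, kimgs, optimizer
# settings and learning rates are placeholder values of this sketch, not
# recommended defaults.
def _example_pggan_train_cfg():
    """Illustrative only; adjust the values to the actual experiment."""
    return dict(
        use_ema=True,
        disc_steps=1,
        # scale (as str/int/tuple) -> number of kimgs trained at that scale
        nkimgs_per_scale={'4': 600, '8': 1200, '16': 1200, '32': 1200},
        transition_kimgs=600,
        optimizer_cfg=dict(
            generator=dict(type='Adam', lr=0.001, betas=(0., 0.99)),
            discriminator=dict(type='Adam', lr=0.001, betas=(0., 0.99))),
        g_lr_base=0.001,
        d_lr_base=0.001,
        # per-scale lr overrides, e.g. {'32': 0.0015}
        g_lr_schedule=dict(),
        d_lr_schedule=dict(),
        reset_optim_for_new_scale=True)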