Example #1
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we
        first update the discriminator and then compute the loss for the
        generator with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in the current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict containing the optimizers for the
                generator and discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If ddp is adopted, this batch size is the local batch size on each
        # GPU.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for when running status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # prefer the optimizer built inside the model, if it exists
        if hasattr(self, 'optimizer'):
            optimizer = self.optimizer

        # update current stage: count how many cumulative nkimg milestones
        # have been passed, capped at the last available scale
        self.curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        self.curr_scale = self.scales[self.curr_stage]
        self._curr_scale_int = self._next_scale_int.clone()
        # add new scale and update training status
        if self.curr_stage != self.prev_stage:
            self.prev_stage = self.curr_stage
            self._actual_nkimgs.append(self.shown_nkimg.item())
            # reset optimizer
            if self.reset_optim_for_new_scale:
                optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
                optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                    str(self.curr_scale[0]), self.g_lr_base)
                optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                    str(self.curr_scale[0]), self.d_lr_base)
                self.optimizer = build_optimizers(self, optim_cfg)
                optimizer = self.optimizer
                mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

        # update training configs, like transition weight for torgb layers.
        # get current transition weight for interpolating two torgb layers
        if self.curr_stage == 0:
            transition_weight = 1.
        else:
            transition_weight = (
                self.shown_nkimg.item() -
                self._actual_nkimgs[-1]) / self.transition_kimgs
            # clip to [0, 1]
            transition_weight = min(max(transition_weight, 0.), 1.)
        self._curr_transition_weight = torch.tensor(transition_weight).to(
            self._curr_transition_weight)

        # resize real image to target scale
        if real_imgs.shape[2:] == self.curr_scale:
            pass
        elif real_imgs.shape[2] >= self.curr_scale[0] and real_imgs.shape[
                3] >= self.curr_scale[1]:
            real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
        else:
            raise RuntimeError(
                f'The scale of real image {real_imgs.shape[2:]} is smaller '
                f'than current scale {self.curr_scale}.')

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
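        # generate fake images without tracking gradients, so the generator's
        # graph is not built during the discriminator update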
        with torch.no_grad():
            fake_imgs = self.generator(None,
                                       num_batches=batch_size,
                                       curr_scale=self.curr_scale[0],
                                       transition_weight=transition_weight)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_real = self.discriminator(
            real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight,
            gen_partial=partial(self.generator,
                                curr_scale=self.curr_scale[0],
                                transition_weight=transition_weight),
            disc_partial=partial(self.discriminator,
                                 curr_scale=self.curr_scale[0],
                                 transition_weight=transition_weight))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, ddp will not dynamically find the params used in
        # the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        loss_disc.backward()
        optimizer['discriminator'].step()

        # update training log status
        if dist.is_initialized():
            _batch_size = batch_size * dist.get_world_size()
        else:
            if running_status is None or 'batch_size' not in running_status:
                raise RuntimeError(
                    'You should offer "batch_size" in running status for PGGAN'
                )
            _batch_size = running_status['batch_size']
        self.shown_nkimg += (_batch_size / 1000.)
        log_vars_disc.update(
            dict(shown_nkimg=self.shown_nkimg.item(),
                 curr_scale=self.curr_scale[0],
                 transition_weight=transition_weight))

        # skip generator training if we only train the discriminator in the
        # current iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(fake_imgs=fake_imgs.cpu(),
                           real_imgs=real_imgs.cpu())
            outputs = dict(log_vars=log_vars_disc,
                           num_samples=batch_size,
                           results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
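        # freeze discriminator parameters so this update only trains the
        # generator (gradients still flow through the discriminator outputs)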
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(None,
                                   num_batches=batch_size,
                                   curr_scale=self.curr_scale[0],
                                   transition_weight=transition_weight)
        disc_pred_fake_g = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

        data_dict_ = dict(iteration=curr_iter,
                          gen=self.generator,
                          disc=self.discriminator,
                          fake_imgs=fake_imgs,
                          disc_pred_fake_g=disc_pred_fake_g)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, ddp will not dynamically find the params used in
        # the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars.update({'batch_size': batch_size})

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(log_vars=log_vars,
                       num_samples=batch_size,
                       results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1

        # check if a new scale will be added in the next iteration
        _curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        # in the next iteration, we will switch to a new scale
        if _curr_stage != self.curr_stage:
            # `self._next_scale_int` is updated at the end of `train_step`
            self._next_scale_int = self._next_scale_int * 2
        return outputs
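The progressive-growing bookkeeping above is easier to follow in isolation. Below is a minimal standalone sketch of the same stage selection and torgb transition-weight computation; the schedule values and the helper name `stage_and_weight` are illustrative, not part of the mmgen API.

import numpy as np

# illustrative schedule: kimgs per scale and the corresponding resolutions
nkimgs = [600, 600, 600]
scales = [(4, 4), (8, 8), (16, 16)]
cum_nkimgs = np.cumsum(nkimgs)  # [600, 1200, 1800]
transition_kimgs = 600


def stage_and_weight(shown_nkimg, actual_nkimgs):
    # a stage is passed once its cumulative nkimg milestone is reached;
    # cap at the last available scale
    stage = int(min(np.sum(cum_nkimgs <= shown_nkimg), len(scales) - 1))
    if stage == 0:
        weight = 1.0
    else:
        # linearly fade in the new torgb layer over `transition_kimgs`
        weight = (shown_nkimg - actual_nkimgs[-1]) / transition_kimgs
        weight = min(max(weight, 0.0), 1.0)
    return stage, weight


# 700 kimgs shown, stage 1 entered at 600 kimgs -> weight ~ 0.167
print(stage_and_weight(700, [600]))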
Example #2
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        use_ddp_wrapper = cfg.get('use_ddp_wrapper', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_ddp_wrapper:
            mmcv.print_log('Use DDP Wrapper.', 'mmgen')
            model = DistributedDataParallelWrapper(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(model.cuda(cfg.gpu_ids[0]),
                               device_ids=cfg.gpu_ids)

    # build runner
    if cfg.optimizer:
        optimizer = build_optimizers(model, cfg.optimizer)
    # In GANs, we also allow building the optimizer inside the GAN model.
    else:
        optimizer = None

    # allow users to define the runner
    if cfg.get('runner', None):
        runner = build_runner(
            cfg.runner,
            dict(model=model,
                 optimizer=optimizer,
                 work_dir=cfg.work_dir,
                 logger=logger,
                 meta=meta))
    else:
        runner = IterBasedRunner(model,
                                 optimizer=optimizer,
                                 work_dir=cfg.work_dir,
                                 logger=logger,
                                 meta=meta)
        # whether to use dynamic ddp in training:
        # is_dynamic_ddp=cfg.get('is_dynamic_ddp', False))
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)

    # In GANs, we can directly optimize parameters in the `train_step`
    # function.
    if cfg.get('optimizer_cfg', None) is None:
        optimizer_config = None
    elif fp16_cfg is not None:
        raise NotImplementedError('Fp16 has not been supported.')
        # optimizer_config = Fp16OptimizerHook(
        #     **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    # default to use OptimizerHook
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # update `out_dir` in the checkpoint hook
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config['out_dir'] = os.path.join(
            cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # # DistSamplerSeedHook should be used with EpochBasedRunner
    # if distributed:
    #     runner.register_hook(DistSamplerSeedHook())

    # In general, we do NOT adopt a standard evaluation hook in GAN training.
    # Thus, if you want an eval hook, you need to further define the
    # 'evaluation' key in the config.
    # register eval hooks
    if validate and cfg.get('evaluation', None) is not None:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # Support batch_size > 1 in validation (override the defaults via
        # `cfg.data.val_data_loader`)
        val_loader_cfg = {
            'samples_per_gpu': 1,
            'shuffle': False,
            'workers_per_gpu': cfg.data.workers_per_gpu,
            **cfg.data.get('val_data_loader', {})
        }
        val_dataloader = build_dataloader(val_dataset,
                                          dist=distributed,
                                          **val_loader_cfg)
        eval_cfg = deepcopy(cfg.get('evaluation'))
        eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
        # pop `priority` before building so it is not passed to the hook
        # constructor
        priority = eval_cfg.pop('priority', 'NORMAL')
        eval_hook = build_from_cfg(eval_cfg, HOOKS)
        runner.register_hook(eval_hook, priority=priority)

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_iters)
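For orientation, here is a sketch of a minimal config carrying only the fields that `train_model` reads; all values are placeholders, and a real mmgen config would additionally define the model, the datasets, and any custom hooks.

from mmcv import Config

# illustrative minimal config; values are placeholders, not a working recipe
cfg = Config(
    dict(
        log_level='INFO',
        seed=2021,
        gpu_ids=[0],
        data=dict(samples_per_gpu=4, workers_per_gpu=2),
        # optimizers for both networks, built via `build_optimizers`
        optimizer=dict(
            generator=dict(type='Adam', lr=1e-3),
            discriminator=dict(type='Adam', lr=1e-3)),
        lr_config=None,
        checkpoint_config=dict(interval=10000),
        log_config=dict(interval=100, hooks=[dict(type='TextLoggerHook')]),
        work_dir='./work_dirs/demo',
        total_iters=100000,
        workflow=[('train', 10000)],
        resume_from=None,
        load_from=None,
    ))

# train_model(model, dataset, cfg) would then build the dataloaders, wrap the
# model, and launch the runner as shown above.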
Example #3
    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy so the EMA generator starts as an independent,
            # identical copy
            self.generator_ema = deepcopy(self.generator)

        # setup interpolation operation at the beginning of training iter
        interp_real_cfg = deepcopy(self.train_cfg.get('interp_real', None))
        if interp_real_cfg is None:
            interp_real_cfg = dict(mode='bilinear', align_corners=True)

        self.interp_real_to = partial(F.interpolate, **interp_real_cfg)
        # parse the training schedule: mapping of scale -> nkimgs
        assert isinstance(self.train_cfg['nkimgs_per_scale'],
                          dict), ('Please provide "nkimgs_per_'
                                  'scale" to schedule the training procedure.')
        nkimgs_per_scale = deepcopy(self.train_cfg['nkimgs_per_scale'])
        self.scales = []
        self.nkimgs = []
        for k, v in nkimgs_per_scale.items():
            # support for different data types
            if isinstance(k, str):
                k = (int(k), int(k))
            elif isinstance(k, int):
                k = (k, k)
            else:
                assert mmcv.is_tuple_of(k, int)

            # sanity check for the order of scales
            assert len(self.scales) == 0 or k[0] > self.scales[-1][0]
            self.scales.append(k)
            self.nkimgs.append(v)
        self.cum_nkimgs = np.cumsum(self.nkimgs)
        self.curr_stage = 0
        self.prev_stage = 0
        # actual nkimgs shown at the end of each training stage
        self._actual_nkimgs = []
        # At each scale, transition from the previous torgb layer to the new
        # torgb layer over `transition_kimgs` kimgs
        self.transition_kimgs = self.train_cfg.get('transition_kimgs', 600)

        # setup optimizer
        self.optimizer = build_optimizers(
            self, deepcopy(self.train_cfg['optimizer_cfg']))
        # get lr schedule
        self.g_lr_base = self.train_cfg['g_lr_base']
        self.d_lr_base = self.train_cfg['d_lr_base']
        # example for lr schedule: {'32': 0.001, '64': 0.0001}
        self.g_lr_schedule = self.train_cfg.get('g_lr_schedule', dict())
        self.d_lr_schedule = self.train_cfg.get('d_lr_schedule', dict())
        # reset the states for optimizers, e.g. momentum in Adam
        self.reset_optim_for_new_scale = self.train_cfg.get(
            'reset_optim_for_new_scale', True)
        # dirty workaround to avoid an optimizer bug when resuming
        self.prev_stage = self.train_cfg.get('prev_stage', self.prev_stage)
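As a quick illustration of the schedule parsing above: the keys of `nkimgs_per_scale` may be strings, ints, or tuples of ints, and all are normalized to 2-tuples registered in ascending order. A standalone sketch with illustrative values:

import numpy as np

# illustrative schedule: keys may be str, int, or tuple of ints
nkimgs_per_scale = {'4': 600, 8: 600, (16, 16): 600}

scales, nkimgs = [], []
for k, v in nkimgs_per_scale.items():
    if isinstance(k, str):
        k = (int(k), int(k))
    elif isinstance(k, int):
        k = (k, k)
    # scales must be registered in ascending order
    assert len(scales) == 0 or k[0] > scales[-1][0]
    scales.append(k)
    nkimgs.append(v)

print(scales)             # [(4, 4), (8, 8), (16, 16)]
print(np.cumsum(nkimgs))  # [ 600 1200 1800]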