Example no. 1
    def after_train_iter(self, runner):
        if not self.every_n_iters(runner, self.interval):
            return

        model = runner.model.module if is_module_wrapper(
            runner.model) else runner.model

        # update momentum
        _interp_cfg = deepcopy(self.interp_cfg)
        if self.momentum_policy != 'fixed':
            _updated_args = self.momentum_updater(runner, **self.momentum_cfg)
            _interp_cfg.update(_updated_args)

        for key in self.module_keys:
            # get current ema states
            ema_net = getattr(model, key)
            states_ema = ema_net.state_dict(keep_vars=False)
            # get the current original states
            net = getattr(model, key[:-4])
            states_orig = net.state_dict(keep_vars=True)

            for k, v in states_orig.items():
                if runner.iter < self.start_iter:
                    states_ema[k].data.copy_(v.data)
                else:
                    states_ema[k] = self.interp_func(v,
                                                     states_ema[k],
                                                     trainable=v.requires_grad,
                                                     **_interp_cfg).detach()
            ema_net.load_state_dict(states_ema, strict=True)
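For reference, a minimal sketch of what an interpolation function like `interp_func` could look like; this is a plain exponential moving average with an assumed `momentum` argument, not the actual implementation behind the hook:

def lerp_ema(orig, ema, trainable=True, momentum=0.999):
    # Hypothetical EMA update: keep most of the EMA value and blend in a
    # small fraction of the current weights. Non-trainable entries
    # (e.g. buffers) are copied over directly.
    if not trainable:
        return orig.clone()
    return ema * momentum + orig * (1. - momentum)

With `momentum` close to 1 the EMA weights move slowly, which is why the hook above first synchronizes the EMA copy by direct assignment during the `start_iter` warm-up.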
Example no. 2
def multipath_batch_processor(model, data, train_mode, **kwargs):
    """ train_iter """
    if is_module_wrapper(model):
        model_ = model.module
    else:
        model_ = model

    model(**data, mode='target')

    all_loss = 0
    for idx, forward_singleop_online in enumerate(model_.forward_op_online):

        loss = model(**data,
                     mode='train',
                     idx=idx,
                     forward_singleop_online=forward_singleop_online)

        all_loss += loss
        loss /= model_.update_interval
        if model_.use_fp16:
            with apex.amp.scale_loss(loss, model_.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
    losses = dict(loss=(all_loss / (len(model_.forward_op_online))))
    loss, log_vars = parse_losses(losses)

    outputs = dict(loss=loss,
                   log_vars=log_vars,
                   num_samples=len(data['img'].data))

    return outputs
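The division by `model_.update_interval` follows the standard gradient-accumulation pattern: each backward pass contributes a scaled gradient, and the optimizer steps once per interval. A minimal self-contained sketch of that pattern, where `model`, `criterion`, `optimizer` and `loader` are placeholders:

update_interval = 4
optimizer.zero_grad()
for i, (x, y) in enumerate(loader):
    # scale so the accumulated gradient matches one full-batch step
    loss = criterion(model(x), y) / update_interval
    loss.backward()  # gradients accumulate across iterations
    if (i + 1) % update_interval == 0:
        optimizer.step()
        optimizer.zero_grad()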
Example no. 3
 def before_train_epoch(self, runner):
     """Close mosaic and mixup augmentation and switches to use L1 loss."""
     epoch = runner.epoch
     train_loader = runner.data_loader
     model = runner.model
     if is_module_wrapper(model):
         model = model.module
     if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
         runner.logger.info('No mosaic and mixup aug now!')
         # The dataset pipeline cannot be updated when persistent_workers
         # is True, so we need to force the dataloader's multi-process
         # restart. This is a very hacky approach.
         train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
         if hasattr(train_loader, 'persistent_workers'
                    ) and train_loader.persistent_workers is True:
             train_loader._DataLoader__initialized = False
             train_loader._iterator = None
             self._restart_dataloader = True
         runner.logger.info('Add additional L1 loss now!')
         if hasattr(model, 'detector'):
             model.detector.bbox_head.use_l1 = True
         else:
             model.bbox_head.use_l1 = True
     else:
         # Once the restart is complete, we need to restore
         # the initialization flag.
         if self._restart_dataloader:
             train_loader._DataLoader__initialized = True
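The private-flag dance above is needed because, with `persistent_workers=True`, worker processes hold their own copy of the dataset and never see an in-place pipeline change until the iterator is rebuilt. A stripped-down sketch of the same trick, where `my_dataset` is a placeholder and the private attributes are PyTorch internals that may change between versions:

from torch.utils.data import DataLoader

loader = DataLoader(my_dataset, num_workers=2, persistent_workers=True)
# ... mutate the dataset pipeline in place, then force a worker restart:
loader._DataLoader__initialized = False  # unlock attribute assignment
loader._iterator = None                  # drop the cached worker pool
loader._DataLoader__initialized = True   # restore the flag (the hook above defers this to the next epoch)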
Example no. 4
    def __init__(self,
                 *args,
                 is_dynamic_ddp=False,
                 pass_training_status=False,
                 fp16_loss_scaler=None,
                 use_apex_amp=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model

        self.is_dynamic_ddp = is_dynamic_ddp
        self.pass_training_status = pass_training_status

        # add a flag for checking if `self.optimizer` comes from `_model`
        self.optimizer_from_model = False
        # add support for the case where `optimizer` is None.
        # sanity check: whether `_model` contains a self-defined optimizer
        if hasattr(_model, 'optimizer'):
            assert self.optimizer is None, (
                'Runner and model cannot contain optimizer at the same time.')
            self.optimizer_from_model = True
            self.optimizer = _model.optimizer

        # add fp16 grad scaler, using pytorch official GradScaler
        self.with_fp16_grad_scaler = False
        if fp16_loss_scaler is not None:
            self.loss_scaler = GradScaler(**fp16_loss_scaler)
            self.with_fp16_grad_scaler = True
            mmcv.print_log('Use FP16 grad scaler in Training', 'mmgen')

        # flag to use amp in apex (NVIDIA)
        self.use_apex_amp = use_apex_amp
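For context, the canonical training-loop shape for PyTorch's official `GradScaler` that this runner wires in; a generic sketch where `model`, `criterion`, `optimizer` and `loader` are placeholders:

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for x, y in loader:
    optimizer.zero_grad()
    with autocast():
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()                # adjust the scale for the next iteration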
Example no. 5
    def after_train_epoch(self, runner):  # fast version of eval
        if not self.every_n_epochs(runner, self.interval):
            return
        print('evaluation')
        dataloader = build_dataloader(dataset=self.dataset,
                                      workers_per_gpu=self.cfg.workers_per_gpu,
                                      batch_size=1,
                                      sampler=torch.utils.data.DistributedSampler(self.dataset),
                                      dist=True)
        if is_module_wrapper(runner.model):
            model = runner.model.module
        else:
            model = runner.model
        model.eval()
        results = []
        rank = runner.rank
        world_size = runner.world_size
        if rank == 0:
            prog_bar = mmcv.ProgressBar(len(self.dataset))
        for i, data in enumerate(dataloader):
            with torch.no_grad():
                result = model.val_step(data, None)
            results.append(result)

            if rank == 0:
                batch_size = 1  # the dataloader above is built with batch_size=1
                for _ in range(batch_size * world_size):
                    prog_bar.update()
        model.train()

        # collect results from all ranks
        results = collect_results(results, len(self.dataset),
                                  os.path.join(runner.work_dir, 'temp/cycle_eval'))
        if runner.rank == 0:
            self.evaluate(runner, results)
Example no. 6
 def before_train_iter(self, runner):
     if self.every_n_iters(runner, self.update_interval):
         if is_module_wrapper(runner.model):
             model_ = runner.model.module
         else:
             model_ = runner.model
         model_.forward_op_online = model_.online_backbone.set_forward_cfg(method='fair')
         model_.forward_op_target = model_.online_backbone.set_forward_cfg(method='fair')
Example no. 7
    def before_train_epoch(self, runner):
        if is_module_wrapper(runner.model):
            model = runner.model.module
        else:
            model = runner.model

        model.set_epoch(runner.epoch)
        m = model.base_momentum * runner.epoch / runner.max_epochs
        model.momentum = m
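A quick worked illustration of the linear momentum ramp, assuming `base_momentum = 0.996` and 100 epochs (illustrative values only):

base_momentum, max_epochs = 0.996, 100
for epoch in (0, 50, 99):
    print(epoch, base_momentum * epoch / max_epochs)
# 0 0.0
# 50 0.498
# 99 0.98604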
Example no. 8
    def train(self, data_loader, **kwargs):
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model
        self.model.train()
        self.mode = 'train'
        # check if self.optimizer from model and track it
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer

        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        self.call_hook('before_fetch_train_data')
        data_batch = next(self.data_loader)
        self.call_hook('before_train_iter')

        # prepare input args for train_step
        # running status
        if self.pass_training_status:
            running_status = dict(iteration=self.iter, epoch=self.epoch)
            kwargs['running_status'] = running_status
        # ddp reducer for tracking dynamic computational graph
        if self.is_dynamic_ddp:
            kwargs.update(dict(ddp_reducer=self.model.reducer))

        if self.with_fp16_grad_scaler:
            kwargs.update(dict(loss_scaler=self.loss_scaler))

        if self.use_apex_amp:
            kwargs.update(dict(use_apex_amp=True))

        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)

        # the loss scaler should be updated after ``train_step``
        if self.with_fp16_grad_scaler:
            self.loss_scaler.update()

        # further check for the cases where the optimizer is built in
        # `train_step`.
        if self.optimizer is None:
            if hasattr(_model, 'optimizer'):
                self.optimizer_from_model = True
                self.optimizer = _model.optimizer

        # check if self.optimizer from model and track it
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1
Example no. 9
def _save_checkpoint(model,
                     filename,
                     optimizer_b=None,
                     optimizer_g=None,
                     optimizer_d=None,
                     meta=None):
    """Save checkpoint to file.

    The checkpoint will have the fields ``meta`` and ``state_dict``, plus
    one ``optimizer_*`` field for each optimizer that is passed in. By
    default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer_b (:obj:`Optimizer` | dict, optional): Optimizer to be
            saved under the ``optimizer_b`` field.
        optimizer_g (:obj:`Optimizer` | dict, optional): Optimizer to be
            saved under the ``optimizer_g`` field.
        optimizer_d (:obj:`Optimizer` | dict, optional): Optimizer to be
            saved under the ``optimizer_d`` field.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    mmcv.mkdir_or_exist(osp.dirname(filename))
    if is_module_wrapper(model):
        model = model.module

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(model.state_dict())
    }
    # save optimizer state dicts in the checkpoint
    for opt_name, opt in [('optimizer_b', optimizer_b),
                          ('optimizer_g', optimizer_g),
                          ('optimizer_d', optimizer_d)]:
        if isinstance(opt, Optimizer):
            checkpoint[opt_name] = opt.state_dict()
        elif isinstance(opt, dict):
            checkpoint[opt_name] = {
                name: optim.state_dict()
                for name, optim in opt.items()
            }
    # immediately flush buffer
    with open(filename, 'wb') as f:
        torch.save(checkpoint, f)
        f.flush()
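A hypothetical loading counterpart for this checkpoint layout, assuming the model and plain (non-dict) optimizers are constructed elsewhere:

import torch

checkpoint = torch.load('latest.pth', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
for key, optim in (('optimizer_b', optimizer_b),
                   ('optimizer_g', optimizer_g),
                   ('optimizer_d', optimizer_d)):
    if key in checkpoint:
        optim.load_state_dict(checkpoint[key])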
Example no. 10
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images.
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features.
    """
    batch_size = dataloader.batch_size
    num_iters = num_samples // batch_size
    if num_iters * batch_size < num_samples:
        num_iters += 1
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)

    feature_list = []
    curr_iter = 1
    for data in dataloader:
        img = data['real_img']
        pbar.update()

        # the inception network is not wrapped with module wrapper.
        if not is_module_wrapper(inception):
            # put the img to the module device
            img = img.to(get_module_device(inception))

        if inception_style == 'stylegan':
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

        if curr_iter >= num_iters:
            break
        curr_iter += 1

    # Note: more features than requested may have been collected,
    # so trim to ``num_samples`` below.
    features = torch.cat(feature_list, 0)

    assert features.shape[0] >= num_samples
    features = features[:num_samples]

    # print a newline after the progress bar
    sys.stdout.write('\n')
    return features
Example no. 11
 def before_train_epoch(self, runner):
     """Close mosaic and mixup augmentation and switch to use L1 loss."""
     epoch = runner.epoch
     train_loader = runner.data_loader
     model = runner.model
     if is_module_wrapper(model):
         model = model.module
     if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
         runner.logger.info('No mosaic and mixup aug now!')
         train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
         runner.logger.info('Add additional L1 loss now!')
         model.bbox_head.use_l1 = True
Example no. 12
 def before_run(self, runner):
     model = runner.model.module if is_module_wrapper(
         runner.model) else runner.model
     # sanity check for ema model
     for k in self.module_keys:
         if not hasattr(model, k) and not hasattr(model, k[:-4]):
             raise RuntimeError(
                 f'Cannot find both {k[:-4]} and {k} network for EMA hook.')
         if not hasattr(model, k) and hasattr(model, k[:-4]):
             setattr(model, k, deepcopy(getattr(model, k[:-4])))
             warnings.warn(
                 'We do not suggest constructing and initializing the EMA '
                 f'model {k} in the hook. You may explicitly define it '
                 'yourself.')
Example no. 13
 def load(module, prefix=''):
     # recursively check parallel module in case that the model has a
     # complicated structure, e.g., nn.Module(nn.Module(DDP))
     if is_module_wrapper(module):
         module = module.module
     local_metadata = {} if metadata is None else metadata.get(
         prefix[:-1], {})
     module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                  all_missing_keys, unexpected_keys,
                                  err_msg)
     for name, child in module._modules.items():
         if child is not None:
             load(child, prefix + name + '.')
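This closure is normally driven the way mmcv's `load_state_dict` drives it; a condensed sketch under that assumption, with `state_dict` and the target `model` in the enclosing scope:

metadata = getattr(state_dict, '_metadata', None)
all_missing_keys, unexpected_keys, err_msg = [], [], []
load(model)
load = None  # break the reference cycle held by the closure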
Example no. 14
    def before_fetch_train_data(self, runner):
        """Behavior before fetching the training data.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        _module = runner.model.module if is_module_wrapper(
            runner.model) else runner.model

        _next_scale_int = _module._next_scale_int
        if isinstance(_next_scale_int, torch.Tensor):
            _next_scale_int = _next_scale_int.item()
        runner.data_loader.update_dataloader(_next_scale_int)
Example no. 15
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    # _verify_model_across_ranks was added in torch 1.9.0 and
    # _verify_params_across_processes in torch 1.11.0, so check whether
    # they are members of torch.distributed before mocking.
    if hasattr(torch.distributed, '_verify_model_across_ranks'):
        torch.distributed._verify_model_across_ranks = mock
    if hasattr(torch.distributed, '_verify_params_across_processes'):
        torch.distributed._verify_params_across_processes = mock

    model = Model()
    assert not is_module_wrapper(model)

    dp = DataParallel(model)
    assert is_module_wrapper(dp)

    mmdp = MMDataParallel(model)
    assert is_module_wrapper(mmdp)

    ddp = DistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(ddp)

    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(mmddp)

    deprecated_mmddp = DeprecatedMMDDP(model)
    assert is_module_wrapper(deprecated_mmddp)

    # test module wrapper registry
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)
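As a reference point, a simplified rendition of what `is_module_wrapper` checks. The real mmcv implementation consults the `MODULE_WRAPPERS` registry (which is exactly what the registry test above exercises); this sketch only tests a fixed tuple of wrapper types:

from torch.nn.parallel import DataParallel, DistributedDataParallel

def is_module_wrapper_simplified(module):
    # True if `module` is one of the known parallel wrapper types.
    return isinstance(module, (DataParallel, DistributedDataParallel))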
Example no. 16
    def before_run(self, runner):
        """To resume model with it's ema parameters more friendly.

        Register ema parameter as ``named_buffer`` to model
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        self.model_parameters = dict(model.named_parameters(recurse=True))
        for name, value in self.model_parameters.items():
            # "." is not allowed in module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers(recurse=True))
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)
Example no. 17
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(child,
                           destination,
                           prefix + name + '.',
                           keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
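Hypothetical usage: because the helper unwraps parallel modules first, the returned keys carry no `module.` prefix regardless of wrapping:

import torch
import torch.nn as nn

net = nn.Linear(2, 2)
state = get_state_dict(nn.DataParallel(net))
assert all(not k.startswith('module.') for k in state)
torch.save(state, 'weights.pth')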
Example no. 18
    def after_train_iter(self, runner):
        if not self.every_n_iters(runner, self.interval):
            return

        model = runner.model.module if is_module_wrapper(
            runner.model) else runner.model

        for key in self.module_keys:
            # get current ema states
            ema_net = getattr(model, key)
            states_ema = ema_net.state_dict(keep_vars=False)
            # get the current original states
            net = getattr(model, key[:-4])
            states_orig = net.state_dict(keep_vars=True)

            for k, v in states_orig.items():
                states_ema[k] = self.interp_func(
                    v, states_ema[k], trainable=v.requires_grad).detach()
            ema_net.load_state_dict(states_ema, strict=True)
Example no. 19
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    dp = DataParallel(model)
    assert is_module_wrapper(dp)

    mmdp = MMDataParallel(model)
    assert is_module_wrapper(mmdp)

    ddp = DistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(ddp)

    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(mmddp)

    deprecated_mmddp = DeprecatedMMDDP(model)
    assert is_module_wrapper(deprecated_mmddp)

    # test module wrapper registry
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)
Example no. 20
    def __init__(self,
                 *args,
                 is_dynamic_ddp=False,
                 pass_training_status=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model

        self.is_dynamic_ddp = is_dynamic_ddp
        self.pass_training_status = pass_training_status

        # add a flag for checking if `self.optimizer` comes from `_model`
        self.optimizer_from_model = False
        # add support for the case where `optimizer` is None.
        # sanity check: whether `_model` contains a self-defined optimizer
        if hasattr(_model, 'optimizer'):
            assert self.optimizer is None, (
                'Runner and model cannot contain optimizer at the same time.')
            self.optimizer_from_model = True
            self.optimizer = _model.optimizer
Example no. 21
 def _get_model(self, runner):
     if is_module_wrapper(runner.model):
         return runner.model.module
     else:
         return runner.model
Example no. 22
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
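A hypothetical call: save a model with a single optimizer and custom metadata to a local path (the `pavi://` branch is only taken for modelcloud URLs):

save_checkpoint(model, 'work_dir/epoch_12.pth',
                optimizer=optimizer,
                meta=dict(epoch=12, iter=48000))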
Example no. 23
 def after_train_iter(self, runner):
     if self.every_n_iters(runner, self.update_interval):
         if is_module_wrapper(runner.model):
             runner.model.module.momentum_update()
         else:
             runner.model.momentum_update()
Example no. 24
    def train_step(self, data_batch, optimizer):
        """Train step.

        Args:
            data_batch (dict): A batch of data.
            optimizer (obj): Optimizer.

        Returns:
            dict: Returned output.
        """
        # during initialization, load weights from the ema model
        if (self.step_counter == self.start_iter
                and self.generator_ema is not None):
            if is_module_wrapper(self.generator):
                self.generator.module.load_state_dict(
                    self.generator_ema.module.state_dict())
            else:
                self.generator.load_state_dict(self.generator_ema.state_dict())

        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        gt_pixel, gt_percep, gt_gan = gt.clone(), gt.clone(), gt.clone()
        if self.is_use_sharpened_gt_in_pixel:
            gt_pixel = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_percep:
            gt_percep = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_gan:
            gt_gan = data_batch['gt_unsharp']

        # generator
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters.
        if self.gan_loss:
            set_requires_grad(self.discriminator, False)

        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel)
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt_percep)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # gan loss for generator
            if self.gan_loss:
                fake_g_pred = self.discriminator(fake_g_output)
                losses['loss_gan'] = self.gan_loss(fake_g_pred,
                                                   target_is_real=True,
                                                   is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        if self.gan_loss:
            set_requires_grad(self.discriminator, True)
            # real
            real_d_pred = self.discriminator(gt_gan)
            loss_d_real = self.gan_loss(real_d_pred,
                                        target_is_real=True,
                                        is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_real=loss_d_real))
            optimizer['discriminator'].zero_grad()
            loss_d.backward()
            log_vars.update(log_vars_d)
            # fake
            fake_d_pred = self.discriminator(fake_g_output.detach())
            loss_d_fake = self.gan_loss(fake_d_pred,
                                        target_is_real=False,
                                        is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_fake=loss_d_fake))
            loss_d.backward()
            log_vars.update(log_vars_d)

            optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(log_vars=log_vars,
                       num_samples=len(gt.data),
                       results=dict(lq=lq.cpu(),
                                    gt=gt.cpu(),
                                    output=fake_g_output.cpu()))

        return outputs
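The `set_requires_grad` helper is not shown in this snippet; a common implementation (the usual pix2pix-style pattern, assumed here rather than taken from this codebase) is:

def set_requires_grad(nets, requires_grad=False):
    # Enable or disable gradients for all parameters of the given
    # network or list of networks.
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad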
Example no. 25
def _dist_train(model, dataset, cfg, logger=None, timestamp=None, meta=None):
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(ds,
                         cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu,
                         dist=True,
                         shuffle=True,
                         replace=getattr(cfg.data, 'sampling_replace', False),
                         seed=cfg.seed,
                         drop_last=getattr(cfg.data, 'drop_last', False),
                         prefetch=cfg.prefetch,
                         img_norm_cfg=cfg.img_norm_cfg) for ds in dataset
    ]
    optimizer = build_optimizer(model, cfg.optimizer)
    if 'use_fp16' in cfg and cfg.use_fp16:
        model, optimizer = apex.amp.initialize(model.cuda(),
                                               optimizer,
                                               opt_level="O1")
        model.use_fp16 = True
        print_log('**** Initializing mixed precision done. ****')

    # put model on gpus
    model = MMDistributedDataParallel(
        model if next(model.parameters()).is_cuda else model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=True)

    # build runner
    runner = MultiStageRunner(model=model,
                              batch_processor=multipath_batch_processor,
                              optimizer=optimizer,
                              work_dir=cfg.work_dir,
                              logger=logger,
                              meta=meta,
                              num_stages=model.module.num_block,
                              max_epochs=cfg.total_epochs)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register custom hooks
    for hook in cfg.get('custom_hooks', ()):
        if hook.type == 'DeepClusterHook':
            common_params = dict(dist_mode=True, data_loaders=data_loaders)
        else:
            common_params = dict(dist_mode=True)
        runner.register_hook(build_hook(hook, common_params))

    if cfg.resume_from:
        runner.resume(cfg.resume_from,
                      resume_optimizer='resume_optimizer' in cfg
                      and cfg.resume_optimizer)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    if is_module_wrapper(model):
        model.module.optimizer = optimizer
    else:
        model.optimizer = optimizer
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
Example no. 26
    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
Example no. 27
 def before_train_epoch(self, runner):
     epoch = runner.epoch
     model = runner.model
     if is_module_wrapper(model):
         model = model.module
     model.set_epoch(epoch)
Example no. 28
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images.
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features.
    """
    batch_size = dataloader.batch_size
    num_iters = num_samples // batch_size
    if num_iters * batch_size < num_samples:
        num_iters += 1
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)

    feature_list = []
    curr_iter = 1
    for data in dataloader:
        # a dirty workaround to support multiple datasets (mainly the
        # unconditional dataset and the conditional dataset). In our
        # implementation, the unconditional dataset returns real images
        # under the key "real_img", while the conditional dataset uses
        # the key "img" for the real images.
        if 'real_img' in data:
            # Mainly for the unconditional dataset in our MMGeneration
            img = data['real_img']
        else:
            # Mainly for conditional dataset in MMClassification
            img = data['img']
        pbar.update()

        # the inception network is not wrapped with module wrapper.
        if not is_module_wrapper(inception):
            # put the img to the module device
            img = img.to(get_module_device(inception))

        if inception_style == 'stylegan':
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

        if curr_iter >= num_iters:
            break
        curr_iter += 1

    # Note: more features than requested may have been collected,
    # so trim to ``num_samples`` below.
    features = torch.cat(feature_list, 0)

    assert features.shape[0] >= num_samples
    features = features[:num_samples]

    # print a newline after the progress bar
    sys.stdout.write('\n')
    return features
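Hypothetical usage, with `dataloader` and `inception` as described in the docstring; 2048 is the typical dimensionality of the pool3 features used for FID:

feats = extract_inception_features(dataloader, inception, num_samples=5000)
print(feats.shape)  # e.g. torch.Size([5000, 2048])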