def after_train_iter(self, runner):
    if not self.every_n_iters(runner, self.interval):
        return

    model = runner.model.module if is_module_wrapper(
        runner.model) else runner.model

    # update momentum
    _interp_cfg = deepcopy(self.interp_cfg)
    if self.momentum_policy != 'fixed':
        _updated_args = self.momentum_updater(runner, **self.momentum_cfg)
        _interp_cfg.update(_updated_args)

    for key in self.module_keys:
        # get current ema states
        ema_net = getattr(model, key)
        states_ema = ema_net.state_dict(keep_vars=False)
        # get current original states
        net = getattr(model, key[:-4])
        states_orig = net.state_dict(keep_vars=True)

        for k, v in states_orig.items():
            if runner.iter < self.start_iter:
                states_ema[k].data.copy_(v.data)
            else:
                states_ema[k] = self.interp_func(
                    v,
                    states_ema[k],
                    trainable=v.requires_grad,
                    **_interp_cfg).detach()
        ema_net.load_state_dict(states_ema, strict=True)
def multipath_batch_processor(model, data, train_mode, **kwargs):
    """Batch processor for one multi-path training iteration."""
    if is_module_wrapper(model):
        model_ = model.module
    else:
        model_ = model

    model(**data, mode='target')

    all_loss = 0
    for idx, forward_singleop_online in enumerate(model_.forward_op_online):
        loss = model(
            **data,
            mode='train',
            idx=idx,
            forward_singleop_online=forward_singleop_online)
        all_loss += loss
        loss /= model_.update_interval
        if model_.use_fp16:
            with apex.amp.scale_loss(loss, model_.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    losses = dict(loss=(all_loss / len(model_.forward_op_online)))
    loss, log_vars = parse_losses(losses)
    outputs = dict(
        loss=loss,
        log_vars=log_vars,
        num_samples=len(data['img'].data))
    return outputs
def before_train_epoch(self, runner): """Close mosaic and mixup augmentation and switches to use L1 loss.""" epoch = runner.epoch train_loader = runner.data_loader model = runner.model if is_module_wrapper(model): model = model.module if (epoch + 1) == runner.max_epochs - self.num_last_epochs: runner.logger.info('No mosaic and mixup aug now!') # The dataset pipeline cannot be updated when persistent_workers # is True, so we need to force the dataloader's multi-process # restart. This is a very hacky approach. train_loader.dataset.update_skip_type_keys(self.skip_type_keys) if hasattr(train_loader, 'persistent_workers' ) and train_loader.persistent_workers is True: train_loader._DataLoader__initialized = False train_loader._iterator = None self._restart_dataloader = True runner.logger.info('Add additional L1 loss now!') if hasattr(model, 'detector'): model.detector.bbox_head.use_l1 = True else: model.bbox_head.use_l1 = True else: # Once the restart is complete, we need to restore # the initialization flag. if self._restart_dataloader: train_loader._DataLoader__initialized = True
def __init__(self,
             *args,
             is_dynamic_ddp=False,
             pass_training_status=False,
             fp16_loss_scaler=None,
             use_apex_amp=False,
             **kwargs):
    super().__init__(*args, **kwargs)
    if is_module_wrapper(self.model):
        _model = self.model.module
    else:
        _model = self.model

    self.is_dynamic_ddp = is_dynamic_ddp
    self.pass_training_status = pass_training_status

    # add a flag for checking if `self.optimizer` comes from `_model`
    self.optimizer_from_model = False
    # add support for optimizer is None.
    # sanity check for whether `_model` contains self-defined optimizer
    if hasattr(_model, 'optimizer'):
        assert self.optimizer is None, (
            'Runner and model cannot contain optimizer at the same time.')
        self.optimizer_from_model = True
        self.optimizer = _model.optimizer

    # add fp16 grad scaler, using pytorch official GradScaler
    self.with_fp16_grad_scaler = False
    if fp16_loss_scaler is not None:
        self.loss_scaler = GradScaler(**fp16_loss_scaler)
        self.with_fp16_grad_scaler = True
        mmcv.print_log('Use FP16 grad scaler in Training', 'mmgen')

    # flag to use amp in apex (NVIDIA)
    self.use_apex_amp = use_apex_amp
def after_train_epoch(self, runner):
    # fast version of eval
    if not self.every_n_epochs(runner, self.interval):
        return
    print('evaluation')
    dataloader = build_dataloader(
        dataset=self.dataset,
        workers_per_gpu=self.cfg.workers_per_gpu,
        batch_size=1,
        sampler=torch.utils.data.DistributedSampler(self.dataset),
        dist=True)
    if is_module_wrapper(runner.model):
        model = runner.model.module
    else:
        model = runner.model
    model.eval()
    results = []
    rank = runner.rank
    world_size = runner.world_size
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(self.dataset))
    for i, data in enumerate(dataloader):
        with torch.no_grad():
            result = model.val_step(data, None)
        results.append(result)
        if rank == 0:
            # the dataloader is built with batch_size=1 above, so the bar
            # is advanced once per sample on each rank
            batch_size = 1
            for _ in range(batch_size * world_size):
                prog_bar.update()
    model.train()

    # collect results from all ranks
    results = collect_results(
        results, len(self.dataset),
        os.path.join(runner.work_dir, 'temp/cycle_eval'))
    if runner.rank == 0:
        self.evaluate(runner, results)
def before_train_iter(self, runner):
    if self.every_n_iters(runner, self.update_interval):
        if is_module_wrapper(runner.model):
            model_ = runner.model.module
        else:
            model_ = runner.model
        model_.forward_op_online = model_.online_backbone.set_forward_cfg(
            method='fair')
        model_.forward_op_target = model_.online_backbone.set_forward_cfg(
            method='fair')
def before_train_epoch(self, runner):
    if is_module_wrapper(runner.model):
        model = runner.model.module
    else:
        model = runner.model
    model.set_epoch(runner.epoch)
    m = model.base_momentum * runner.epoch / runner.max_epochs
    model.momentum = m
def train(self, data_loader, **kwargs):
    if is_module_wrapper(self.model):
        _model = self.model.module
    else:
        _model = self.model

    self.model.train()
    self.mode = 'train'
    # check if self.optimizer from model and track it
    if self.optimizer_from_model:
        self.optimizer = _model.optimizer
    self.data_loader = data_loader
    self._epoch = data_loader.epoch
    self.call_hook('before_fetch_train_data')
    data_batch = next(self.data_loader)
    self.call_hook('before_train_iter')

    # prepare input args for train_step
    # running status
    if self.pass_training_status:
        running_status = dict(iteration=self.iter, epoch=self.epoch)
        kwargs['running_status'] = running_status
    # ddp reducer for tracking dynamic computational graph
    if self.is_dynamic_ddp:
        kwargs.update(dict(ddp_reducer=self.model.reducer))

    if self.with_fp16_grad_scaler:
        kwargs.update(dict(loss_scaler=self.loss_scaler))

    if self.use_apex_amp:
        kwargs.update(dict(use_apex_amp=True))

    outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)

    # the loss scaler should be updated after ``train_step``
    if self.with_fp16_grad_scaler:
        self.loss_scaler.update()

    # further check for the cases where the optimizer is built in
    # `train_step`.
    if self.optimizer is None:
        if hasattr(_model, 'optimizer'):
            self.optimizer_from_model = True
            self.optimizer = _model.optimizer

    # check if self.optimizer from model and track it
    if self.optimizer_from_model:
        self.optimizer = _model.optimizer

    if not isinstance(outputs, dict):
        raise TypeError('model.train_step() must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
    self.call_hook('after_train_iter')
    self._inner_iter += 1
    self._iter += 1
def _save_checkpoint(model,
                     filename,
                     optimizer_b=None,
                     optimizer_g=None,
                     optimizer_d=None,
                     meta=None):
    """Save checkpoint to file.

    The checkpoint will have the fields ``meta``, ``state_dict`` and, when
    the corresponding optimizers are provided, ``optimizer_b``,
    ``optimizer_g`` and ``optimizer_d``. By default ``meta`` will contain
    version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer_b (:obj:`Optimizer` | dict, optional): Optimizer saved
            under the ``optimizer_b`` key.
        optimizer_g (:obj:`Optimizer` | dict, optional): Optimizer saved
            under the ``optimizer_g`` key.
        optimizer_d (:obj:`Optimizer` | dict, optional): Optimizer saved
            under the ``optimizer_d`` key.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    mmcv.mkdir_or_exist(osp.dirname(filename))
    if is_module_wrapper(model):
        model = model.module

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(model.state_dict())
    }

    # save optimizer state dict in the checkpoint
    if isinstance(optimizer_b, Optimizer):
        checkpoint['optimizer_b'] = optimizer_b.state_dict()
    elif isinstance(optimizer_b, dict):
        checkpoint['optimizer_b'] = {}
        for name, optim in optimizer_b.items():
            checkpoint['optimizer_b'][name] = optim.state_dict()

    if isinstance(optimizer_g, Optimizer):
        checkpoint['optimizer_g'] = optimizer_g.state_dict()
    elif isinstance(optimizer_g, dict):
        checkpoint['optimizer_g'] = {}
        for name, optim in optimizer_g.items():
            checkpoint['optimizer_g'][name] = optim.state_dict()

    if isinstance(optimizer_d, Optimizer):
        checkpoint['optimizer_d'] = optimizer_d.state_dict()
    elif isinstance(optimizer_d, dict):
        checkpoint['optimizer_d'] = {}
        for name, optim in optimizer_d.items():
            checkpoint['optimizer_d'][name] = optim.state_dict()

    # immediately flush buffer
    with open(filename, 'wb') as f:
        torch.save(checkpoint, f)
        f.flush()
def extract_inception_features(dataloader, inception, num_samples, inception_style='pytorch'): """Extract inception features for FID metric. Args: dataloader (:obj:`DataLoader`): Dataloader for images. inception (nn.Module): Inception network. num_samples (int): The number of samples to be extracted. inception_style (str): The style of Inception network, "pytorch" or "stylegan". Defaults to "pytorch". Returns: torch.Tensor: Inception features. """ batch_size = dataloader.batch_size num_iters = num_samples // batch_size if num_iters * batch_size < num_samples: num_iters += 1 # define mmcv progress bar pbar = mmcv.ProgressBar(num_iters) feature_list = [] curr_iter = 1 for data in dataloader: img = data['real_img'] pbar.update() # the inception network is not wrapped with module wrapper. if not is_module_wrapper(inception): # put the img to the module device img = img.to(get_module_device(inception)) if inception_style == 'stylegan': img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) feature = inception(img, return_features=True) else: feature = inception(img)[0].view(img.shape[0], -1) feature_list.append(feature.to('cpu')) if curr_iter >= num_iters: break curr_iter += 1 # Attention: the number of features may be different as you want. features = torch.cat(feature_list, 0) assert features.shape[0] >= num_samples features = features[:num_samples] # to change the line after pbar sys.stdout.write('\n') return features
def before_train_epoch(self, runner): """Close mosaic and mixup augmentation and switches to use L1 loss.""" epoch = runner.epoch train_loader = runner.data_loader model = runner.model if is_module_wrapper(model): model = model.module if (epoch + 1) == runner.max_epochs - self.num_last_epochs: runner.logger.info('No mosaic and mixup aug now!') train_loader.dataset.update_skip_type_keys(self.skip_type_keys) runner.logger.info('Add additional L1 loss now!') model.bbox_head.use_l1 = True
def before_run(self, runner):
    model = runner.model.module if is_module_wrapper(
        runner.model) else runner.model

    # sanity check for ema model
    for k in self.module_keys:
        if not hasattr(model, k) and not hasattr(model, k[:-4]):
            raise RuntimeError(
                f'Cannot find {k[:-4]} or {k} network for the EMA hook.')
        if not hasattr(model, k) and hasattr(model, k[:-4]):
            setattr(model, k, deepcopy(getattr(model, k[:-4])))
            warnings.warn(
                f'We do not suggest constructing and initializing the EMA '
                f'model {k} in the hook. You may explicitly define it '
                'by yourself.')
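# A hedged config sketch of how an EMA hook following the `module_keys`
# convention above is typically registered. The hook type name
# 'ExponentialMovingAverageHook' and the key 'generator_ema' are assumptions
# modeled on common MMGeneration-style configs: each key ends with '_ema' and
# mirrors an existing submodule (here `generator`), which is what the
# `k[:-4]` slicing in `before_run` relies on.
custom_hooks = [
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interval=1,
        priority='VERY_HIGH')
]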
def load(module, prefix=''):
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module
    local_metadata = {} if metadata is None else metadata.get(
        prefix[:-1], {})
    module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                 all_missing_keys, unexpected_keys,
                                 err_msg)
    for name, child in module._modules.items():
        if child is not None:
            load(child, prefix + name + '.')
def before_fetch_train_data(self, runner):
    """The behavior before fetching train data.

    Args:
        runner (object): The runner.
    """
    if not self.every_n_iters(runner, self.interval):
        return

    _module = runner.model.module if is_module_wrapper(
        runner.model) else runner.model

    _next_scale_int = _module._next_scale_int
    if isinstance(_next_scale_int, torch.Tensor):
        _next_scale_int = _next_scale_int.item()
    runner.data_loader.update_dataloader(_next_scale_int)
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    # _verify_model_across_ranks is added in torch1.9.0,
    # _verify_params_across_processes is added in torch1.11.0,
    # so we should check whether _verify_model_across_ranks
    # and _verify_params_across_processes are members of
    # torch.distributed before mocking.
    # ``mock`` is the no-op stub defined at module level of the test file.
    if hasattr(torch.distributed, '_verify_model_across_ranks'):
        torch.distributed._verify_model_across_ranks = mock
    if hasattr(torch.distributed, '_verify_params_across_processes'):
        torch.distributed._verify_params_across_processes = mock

    model = Model()
    assert not is_module_wrapper(model)

    dp = DataParallel(model)
    assert is_module_wrapper(dp)

    mmdp = MMDataParallel(model)
    assert is_module_wrapper(mmdp)

    ddp = DistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(ddp)

    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(mmddp)

    deprecated_mmddp = DeprecatedMMDDP(model)
    assert is_module_wrapper(deprecated_mmddp)

    # test module wrapper registry
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)
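# For reference, a minimal sketch of the checker exercised by the test above.
# It assumes MMCV's MODULE_WRAPPERS registry exposes its registered classes
# via `module_dict`; consult the installed mmcv version for the exact
# implementation.
from mmcv.parallel import MODULE_WRAPPERS


def is_module_wrapper(module):
    """Check whether ``module`` is an instance of a registered wrapper.

    DataParallel, DistributedDataParallel and the MMCV wrappers are
    pre-registered; custom wrappers can be added with
    ``@MODULE_WRAPPERS.register_module()``, as the test above does.
    """
    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
    return isinstance(module, module_wrappers)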
def before_run(self, runner): """To resume model with it's ema parameters more friendly. Register ema parameter as ``named_buffer`` to model """ model = runner.model if is_module_wrapper(model): model = model.module self.param_ema_buffer = {} self.model_parameters = dict(model.named_parameters(recurse=True)) for name, value in self.model_parameters.items(): # "." is not allowed in module's buffer name buffer_name = f"ema_{name.replace('.', '_')}" self.param_ema_buffer[name] = buffer_name model.register_buffer(buffer_name, value.data.clone()) self.model_buffers = dict(model.named_buffers(recurse=True)) if self.checkpoint is not None: runner.resume(self.checkpoint)
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a
    complicated structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
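# A brief usage sketch for the function above: because the wrapper is
# unwrapped recursively, the returned keys match the inner module's and carry
# no 'module.' prefix, unlike DataParallel.state_dict(). `net` is a
# placeholder model.
import torch.nn as nn

net = nn.Linear(2, 2)
wrapped = nn.DataParallel(net)
assert all(k.startswith('module.') for k in wrapped.state_dict())
assert set(get_state_dict(wrapped)) == set(net.state_dict())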
def after_train_iter(self, runner):
    if not self.every_n_iters(runner, self.interval):
        return

    model = runner.model.module if is_module_wrapper(
        runner.model) else runner.model

    for key in self.module_keys:
        # get current ema states
        ema_net = getattr(model, key)
        states_ema = ema_net.state_dict(keep_vars=False)
        # get current original states
        net = getattr(model, key[:-4])
        states_orig = net.state_dict(keep_vars=True)

        for k, v in states_orig.items():
            states_ema[k] = self.interp_func(
                v, states_ema[k], trainable=v.requires_grad).detach()
        ema_net.load_state_dict(states_ema, strict=True)
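# The EMA hooks above delegate the actual blending to `self.interp_func`,
# which is not shown here. A linear interpolation along these lines is the
# usual choice; this is a sketch, and the default momentum values as well as
# the separate momentum for non-trainable entries are assumptions mirroring
# the `trainable=v.requires_grad` call sites above.
def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
    # keep a fraction `momentum` of the EMA weight `b` and move it slightly
    # toward the current online weight `a`
    m = momentum if trainable else momentum_nontrainable
    return a + (b - a) * m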
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    dp = DataParallel(model)
    assert is_module_wrapper(dp)

    mmdp = MMDataParallel(model)
    assert is_module_wrapper(mmdp)

    ddp = DistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(ddp)

    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(mmddp)

    deprecated_mmddp = DeprecatedMMDDP(model)
    assert is_module_wrapper(deprecated_mmddp)

    # test module wrapper registry
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)
def __init__(self,
             *args,
             is_dynamic_ddp=False,
             pass_training_status=False,
             **kwargs):
    super().__init__(*args, **kwargs)
    if is_module_wrapper(self.model):
        _model = self.model.module
    else:
        _model = self.model

    self.is_dynamic_ddp = is_dynamic_ddp
    self.pass_training_status = pass_training_status

    # add a flag for checking if `self.optimizer` comes from `_model`
    self.optimizer_from_model = False
    # add support for optimizer is None.
    # sanity check for whether `_model` contains self-defined optimizer
    if hasattr(_model, 'optimizer'):
        assert self.optimizer is None, (
            'Runner and model cannot contain optimizer at the same time.')
        self.optimizer_from_model = True
        self.optimizer = _model.optimizer
def _get_model(self, runner):
    if is_module_wrapper(runner.model):
        return runner.model.module
    else:
        return runner.model
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
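# A hypothetical usage sketch for the function above; `net`, `optim_g` and
# `optim_d` are placeholder names. A single Optimizer is stored under the
# 'optimizer' key, while a dict of optimizers is stored per name.
import torch.nn as nn
import torch.optim as optim

net = nn.Conv2d(3, 3, 1)
optim_g = optim.SGD(net.parameters(), lr=0.01)
optim_d = optim.SGD(net.parameters(), lr=0.01)

save_checkpoint(net, 'work_dir/latest.pth', optimizer=optim_g,
                meta=dict(epoch=12, iter=24000))
save_checkpoint(net, 'work_dir/latest_gan.pth',
                optimizer=dict(generator=optim_g, discriminator=optim_d))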
def after_train_iter(self, runner):
    if self.every_n_iters(runner, self.update_interval):
        if is_module_wrapper(runner.model):
            runner.model.module.momentum_update()
        else:
            runner.model.momentum_update()
def train_step(self, data_batch, optimizer):
    """Train step.

    Args:
        data_batch (dict): A batch of data.
        optimizer (obj): Optimizer.

    Returns:
        dict: Returned output.
    """
    # during initialization, load weights from the ema model
    if (self.step_counter == self.start_iter
            and self.generator_ema is not None):
        if is_module_wrapper(self.generator):
            self.generator.module.load_state_dict(
                self.generator_ema.module.state_dict())
        else:
            self.generator.load_state_dict(self.generator_ema.state_dict())

    # data
    lq = data_batch['lq']
    gt = data_batch['gt']

    gt_pixel, gt_percep, gt_gan = gt.clone(), gt.clone(), gt.clone()
    if self.is_use_sharpened_gt_in_pixel:
        gt_pixel = data_batch['gt_unsharp']
    if self.is_use_sharpened_gt_in_percep:
        gt_percep = data_batch['gt_unsharp']
    if self.is_use_sharpened_gt_in_gan:
        gt_gan = data_batch['gt_unsharp']

    # generator
    fake_g_output = self.generator(lq)

    losses = dict()
    log_vars = dict()

    # no updates to discriminator parameters.
    if self.gan_loss:
        set_requires_grad(self.discriminator, False)

    if (self.step_counter % self.disc_steps == 0
            and self.step_counter >= self.disc_init_steps):
        if self.pixel_loss:
            losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel)
        if self.perceptual_loss:
            loss_percep, loss_style = self.perceptual_loss(
                fake_g_output, gt_percep)
            if loss_percep is not None:
                losses['loss_perceptual'] = loss_percep
            if loss_style is not None:
                losses['loss_style'] = loss_style
        # gan loss for generator
        if self.gan_loss:
            fake_g_pred = self.discriminator(fake_g_output)
            losses['loss_gan'] = self.gan_loss(
                fake_g_pred, target_is_real=True, is_disc=False)

        # parse loss
        loss_g, log_vars_g = self.parse_losses(losses)
        log_vars.update(log_vars_g)

        # optimize
        optimizer['generator'].zero_grad()
        loss_g.backward()
        optimizer['generator'].step()

    # discriminator
    if self.gan_loss:
        set_requires_grad(self.discriminator, True)
        # real
        real_d_pred = self.discriminator(gt_gan)
        loss_d_real = self.gan_loss(
            real_d_pred, target_is_real=True, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(
            dict(loss_d_real=loss_d_real))
        optimizer['discriminator'].zero_grad()
        loss_d.backward()
        log_vars.update(log_vars_d)
        # fake
        fake_d_pred = self.discriminator(fake_g_output.detach())
        loss_d_fake = self.gan_loss(
            fake_d_pred, target_is_real=False, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(
            dict(loss_d_fake=loss_d_fake))
        loss_d.backward()
        log_vars.update(log_vars_d)

        optimizer['discriminator'].step()

    self.step_counter += 1

    log_vars.pop('loss')  # remove the unnecessary 'loss'
    outputs = dict(
        log_vars=log_vars,
        num_samples=len(gt.data),
        results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

    return outputs
def _dist_train(model, dataset, cfg, logger=None, timestamp=None, meta=None):
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=True,
            shuffle=True,
            replace=getattr(cfg.data, 'sampling_replace', False),
            seed=cfg.seed,
            drop_last=getattr(cfg.data, 'drop_last', False),
            prefetch=cfg.prefetch,
            img_norm_cfg=cfg.img_norm_cfg) for ds in dataset
    ]
    optimizer = build_optimizer(model, cfg.optimizer)
    if 'use_fp16' in cfg and cfg.use_fp16:
        model, optimizer = apex.amp.initialize(
            model.cuda(), optimizer, opt_level='O1')
        model.use_fp16 = True
        print_log('**** Initializing mixed precision done. ****')

    # put model on gpus
    model = MMDistributedDataParallel(
        model if next(model.parameters()).is_cuda else model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=True)

    # build runner
    runner = MultiStageRunner(
        model=model,
        batch_processor=multipath_batch_processor,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta,
        num_stages=model.module.num_block,
        max_epochs=cfg.total_epochs)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register custom hooks
    for hook in cfg.get('custom_hooks', ()):
        if hook.type == 'DeepClusterHook':
            common_params = dict(dist_mode=True, data_loaders=data_loaders)
        else:
            common_params = dict(dist_mode=True)
        runner.register_hook(build_hook(hook, common_params))

    if cfg.resume_from:
        runner.resume(
            cfg.resume_from,
            resume_optimizer='resume_optimizer' in cfg
            and cfg.resume_optimizer)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    if is_module_wrapper(model):
        model.module.optimizer = optimizer
    else:
        model.optimizer = optimizer

    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
def __init__(self,
             model,
             batch_processor=None,
             optimizer=None,
             work_dir=None,
             logger=None,
             meta=None,
             max_iters=None,
             max_epochs=None):
    if batch_processor is not None:
        if not callable(batch_processor):
            raise TypeError('batch_processor must be callable, '
                            f'but got {type(batch_processor)}')
        warnings.warn('batch_processor is deprecated, please implement '
                      'train_step() and val_step() in the model instead.')
        # raise an error if `batch_processor` is not None and
        # `model.train_step()` exists.
        if is_module_wrapper(model):
            _model = model.module
        else:
            _model = model
        if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
            raise RuntimeError(
                'batch_processor and model.train_step()/model.val_step() '
                'cannot be both available.')
    else:
        assert hasattr(model, 'train_step')

    # check the type of `optimizer`
    if isinstance(optimizer, dict):
        for name, optim in optimizer.items():
            if not isinstance(optim, Optimizer):
                raise TypeError(
                    f'optimizer must be a dict of torch.optim.Optimizers, '
                    f'but optimizer["{name}"] is a {type(optim)}')
    elif not isinstance(optimizer, Optimizer) and optimizer is not None:
        raise TypeError(
            f'optimizer must be a torch.optim.Optimizer object '
            f'or dict or None, but got {type(optimizer)}')

    # check the type of `logger`
    if not isinstance(logger, logging.Logger):
        raise TypeError(f'logger must be a logging.Logger object, '
                        f'but got {type(logger)}')

    # check the type of `meta`
    if meta is not None and not isinstance(meta, dict):
        raise TypeError(
            f'meta must be a dict or None, but got {type(meta)}')

    self.model = model
    self.batch_processor = batch_processor
    self.optimizer = optimizer
    self.logger = logger
    self.meta = meta

    # create work_dir
    if mmcv.is_str(work_dir):
        self.work_dir = osp.abspath(work_dir)
        mmcv.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')

    # get model name from the model class
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._rank, self._world_size = get_dist_info()
    self.timestamp = get_time_str()
    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0

    if max_epochs is not None and max_iters is not None:
        raise ValueError(
            'Only one of `max_epochs` or `max_iters` can be set.')

    self._max_epochs = max_epochs
    self._max_iters = max_iters
    # TODO: Redesign LogBuffer, it is not flexible and elegant enough
    self.log_buffer = LogBuffer()
def before_train_epoch(self, runner):
    epoch = runner.epoch
    model = runner.model
    if is_module_wrapper(model):
        model = model.module
    model.set_epoch(epoch)
def extract_inception_features(dataloader, inception, num_samples, inception_style='pytorch'): """Extract inception features for FID metric. Args: dataloader (:obj:`DataLoader`): Dataloader for images. inception (nn.Module): Inception network. num_samples (int): The number of samples to be extracted. inception_style (str): The style of Inception network, "pytorch" or "stylegan". Defaults to "pytorch". Returns: torch.Tensor: Inception features. """ batch_size = dataloader.batch_size num_iters = num_samples // batch_size if num_iters * batch_size < num_samples: num_iters += 1 # define mmcv progress bar pbar = mmcv.ProgressBar(num_iters) feature_list = [] curr_iter = 1 for data in dataloader: # a dirty walkround to support multiple datasets (mainly for the # unconditional dataset and conditional dataset). In our # implementation, unconditioanl dataset will return real images with # the key "real_img". However, the conditional dataset contains a key # "img" denoting the real images. if 'real_img' in data: # Mainly for the unconditional dataset in our MMGeneration img = data['real_img'] else: # Mainly for conditional dataset in MMClassification img = data['img'] pbar.update() # the inception network is not wrapped with module wrapper. if not is_module_wrapper(inception): # put the img to the module device img = img.to(get_module_device(inception)) if inception_style == 'stylegan': img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) feature = inception(img, return_features=True) else: feature = inception(img)[0].view(img.shape[0], -1) feature_list.append(feature.to('cpu')) if curr_iter >= num_iters: break curr_iter += 1 # Attention: the number of features may be different as you want. features = torch.cat(feature_list, 0) assert features.shape[0] >= num_samples features = features[:num_samples] # to change the line after pbar sys.stdout.write('\n') return features