Example 1
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')

    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()
    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0, (
            'We need paired images in group mode, so keep endpoint an even '
            'number')

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)

    # get noises corresponding to each endpoint
    noise_batch = batch_inference(
        generator,
        None,
        num_batches=args.endpoint,
        dict_key='noise_batch' if args.space == 'z' else 'latent',
        return_noise=True,
        **kwargs)

    if args.space == 'w':
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True

    if args.show_mode == 'sequence':
        results = sample_from_path(generator, noise_batch[:-1, ],
                                   noise_batch[1:, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    else:
        results = sample_from_path(generator, noise_batch[::2, ],
                                   noise_batch[1::2, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    # reorder results
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    # convert from BGR to RGB channel order for saving
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)
    # save image
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        for i in range(results.shape[0]):
            image = results[i:i + 1]
            save_image(
                image,
                os.path.join(args.samples_path, '{:0>5d}'.format(i) + '.png'))
    else:
        save_image(
            results,
            os.path.join(args.samples_path, 'group.png'),
            nrow=args.interval)
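
Example 1 relies on a `batch_inference` helper that is not shown above. A minimal sketch of what such a helper could look like, inferred from the call site: sampling is split into chunks of at most `max_batch_size`, and mmgen-style generators return a dict when `return_noise=True`. The name and exact signature here are assumptions.

import torch


def batch_inference_sketch(generator, noise, num_batches, dict_key=None,
                           return_noise=False, max_batch_size=16, **kwargs):
    """Hypothetical helper: draw `num_batches` samples in chunks so the GPU
    never holds more than `max_batch_size` samples at once."""
    outputs = []
    remaining = num_batches
    while remaining > 0:
        batch = min(remaining, max_batch_size)
        with torch.no_grad():
            res = generator(noise, num_batches=batch,
                            return_noise=return_noise, **kwargs)
        # with `return_noise=True`, the generator returns a dict of tensors
        outputs.append(res[dict_key] if isinstance(res, dict) else res)
        remaining -= batch
    return torch.cat(outputs, dim=0)
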
Example 2
                        help='Other customized kwargs for sampling function')

    # system args
    parser.add_argument('--num-samples', type=int, default=2)
    parser.add_argument('--sample-path', type=str, default=None)
    parser.add_argument('--random-seed', type=int, default=2020)

    args = parser.parse_args()

    set_random_seed(args.random_seed)
    cfg = mmcv.Config.fromfile(args.config)

    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    mmcv.print_log('Building models and loading checkpoints', 'mmgen')
    # build model
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)

    model.eval()
    load_checkpoint(model, args.ckpt, map_location='cpu')

    # get generator
    if model.use_ema:
        generator = model.generator_ema
    else:
        generator = model.generator

    # `device` is not defined in this excerpt; pick a sensible default
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = generator.to(device)
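
`set_random_seed` is imported from mmgen in these scripts. For reference, a sketch of the usual implementation, which seeds Python, NumPy, and PyTorch and optionally trades cudnn autotuning for determinism; treat it as a paraphrase rather than the exact mmgen source.

import random

import numpy as np
import torch


def set_random_seed_sketch(seed, deterministic=False):
    """Seed all RNG sources; optionally force deterministic cudnn."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
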
Example 3
            dict(type='ImageToTensor', keys=['real_img'])
        ]
        # insert flip aug
        if args.flip:
            pipeline.insert(
                1, dict(type='Flip', keys=['real_img'],
                        direction='horizontal'))

    # build dataloader
    if args.imgsdir is not None:
        dataset = UnconditionalImageDataset(args.imgsdir, pipeline)
    elif args.data_cfg is not None:
        # Please make sure the dataset will sample images in `RGB` order.
        data_config = Config.fromfile(args.data_cfg)
        subset_config = data_config.data.get(args.subset, None)
        print_log(subset_config, 'mmgen')
        dataset = build_dataset(subset_config)
    else:
        raise RuntimeError('Please provide imgsdir or data_cfg')

    data_loader = build_dataloader(dataset,
                                   args.batch_size,
                                   4,  # workers_per_gpu
                                   dist=False,
                                   shuffle=(not args.no_shuffle))

    mmcv.mkdir_or_exist(args.pkl_dir)

    # build inception network
    if args.inception_style == 'stylegan':
        inception = torch.jit.load(args.inception_pth).eval().cuda()
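
Example 3 wraps an image directory and a transform pipeline in `UnconditionalImageDataset`. A hypothetical stand-in illustrating the same pattern (not the actual mmgen class; the `real_img_path` key is assumed):

import os

from torch.utils.data import Dataset


class FolderImageDatasetSketch(Dataset):
    """List image files in a directory and run each through a pipeline."""

    EXTS = ('.jpg', '.jpeg', '.png')

    def __init__(self, imgs_root, pipeline):
        self.paths = sorted(
            os.path.join(imgs_root, name) for name in os.listdir(imgs_root)
            if name.lower().endswith(self.EXTS))
        self.pipeline = pipeline  # assumed to be a composed callable

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # pipelines operate on result dicts rather than raw images
        results = dict(real_img_path=self.paths[idx])
        return self.pipeline(results)
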
Example 4
def main():
    args = parse_args()
    # set cudnn_benchmark
    cfg = Config.fromfile(args.config)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    os.makedirs(args.results_dir, exist_ok=True)

    text_inputs = clip.tokenize(args.description).cuda()

    model = init_model(args.config, args.checkpoint, device='cpu')
    g_ema = model.generator_ema
    g_ema.eval()
    if not args.use_cpu:
        g_ema = g_ema.cuda()

    mean_latent = g_ema.get_mean_latent()

    # if given proj_latent
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        proj_file = torch.load(args.proj_latent)
        proj_n = len(proj_file)
        assert proj_n == 1
        noise_batch = []
        for img_path in proj_file:
            noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
        latent_code_init = torch.cat(noise_batch, dim=0).cuda()
    elif args.mode == 'edit':
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            results = g_ema([latent_code_init_not_trunc],
                            return_latents=True,
                            truncation=args.truncation,
                            truncation_latent=mean_latent)
            latent_code_init = results['latent']
    else:
        # 18 style layers correspond to a 1024x1024 StyleGAN2 generator
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    with torch.no_grad():
        img_orig = g_ema([latent_code_init],
                         input_is_latent=True,
                         randomize_noise=False)

    latent = latent_code_init.detach().clone()
    latent.requires_grad = True

    clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size))
    id_loss = FaceIdLoss(
        facenet=dict(type='ArcFace', ir_se50_weights=None, device='cuda'))

    optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))
    mmcv.print_log(f'Description: {args.description}', 'mmgen')
    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]['lr'] = lr

        img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False)

        img_gen = img_gen[:, [2, 1, 0], ...]

        # clip loss
        c_loss = clip_loss(image=img_gen, text=text_inputs)

        if args.id_lambda > 0:
            i_loss = id_loss(pred=img_gen, gt=img_orig)[0]
        else:
            i_loss = 0

        if args.mode == 'edit':
            l2_loss = ((latent_code_init - latent)**2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f'loss: {loss.item():.4f}')
        if args.save_interval > 0 and (i % args.save_interval == 0):
            with torch.no_grad():
                img_gen = g_ema([latent],
                                input_is_latent=True,
                                randomize_noise=False)

            img_gen = img_gen[:, [2, 1, 0], ...]

            torchvision.utils.save_image(img_gen,
                                         os.path.join(
                                             args.results_dir,
                                             f'{str(i).zfill(5)}.png'),
                                         normalize=True,
                                         range=(-1, 1))

    if args.mode == 'edit':
        img_orig = img_orig[:, [2, 1, 0], ...]
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    torchvision.utils.save_image(final_result.detach().cpu(),
                                 os.path.join(args.results_dir,
                                              'final_result.png'),
                                 normalize=True,
                                 scale_each=True,
                                 range=(-1, 1))
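
Both this script and Example 11 call a `get_lr` schedule helper that is not shown. The StyleGAN2-projector-style cosine ramp is the usual choice; a sketch under that assumption, with illustrative default `rampdown`/`rampup` fractions:

import math


def get_lr_sketch(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Ramp the lr up over the first `rampup` fraction of steps and
    cosine-ramp it down over the last `rampdown` fraction; `t` is the
    progress in [0, 1]."""
    lr_ramp = min(1.0, (1.0 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1.0, t / rampup)
    return initial_lr * lr_ramp
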
Example 5
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we
        first update the discriminator and then compute the loss for the
        generator with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict containing optimizers for the generator
                and discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is local batch size for each GPU.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # check if optimizer from model
        if hasattr(self, 'optimizer'):
            optimizer = self.optimizer

        # update current stage
        self.curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        self.curr_scale = self.scales[self.curr_stage]
        self._curr_scale_int = self._next_scale_int.clone()
        # add new scale and update training status
        if self.curr_stage != self.prev_stage:
            self.prev_stage = self.curr_stage
            self._actual_nkimgs.append(self.shown_nkimg.item())
            # reset optimizer
            if self.reset_optim_for_new_scale:
                optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
                optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                    str(self.curr_scale[0]), self.g_lr_base)
                optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                    str(self.curr_scale[0]), self.d_lr_base)
                self.optimizer = build_optimizers(self, optim_cfg)
                optimizer = self.optimizer
                mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

        # update training configs, like transition weight for torgb layers.
        # get current transition weight for interpolating two torgb layers
        if self.curr_stage == 0:
            transition_weight = 1.
        else:
            transition_weight = (
                self.shown_nkimg.item() -
                self._actual_nkimgs[-1]) / self.transition_kimgs
            # clip to [0, 1]
            transition_weight = min(max(transition_weight, 0.), 1.)
        self._curr_transition_weight = torch.tensor(transition_weight).to(
            self._curr_transition_weight)

        # resize real image to target scale
        if real_imgs.shape[2:] == self.curr_scale:
            pass
        elif real_imgs.shape[2] >= self.curr_scale[0] and real_imgs.shape[
                3] >= self.curr_scale[1]:
            real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
        else:
            raise RuntimeError(
                f'The scale of real image {real_imgs.shape[2:]} is smaller '
                f'than current scale {self.curr_scale}.')

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(None,
                                       num_batches=batch_size,
                                       curr_scale=self.curr_scale[0],
                                       transition_weight=transition_weight)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_real = self.discriminator(
            real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight,
            gen_partial=partial(self.generator,
                                curr_scale=self.curr_scale[0],
                                transition_weight=transition_weight),
            disc_partial=partial(self.discriminator,
                                 curr_scale=self.curr_scale[0],
                                 transition_weight=transition_weight))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        loss_disc.backward()
        optimizer['discriminator'].step()

        # update training log status
        if dist.is_initialized():
            _batch_size = batch_size * dist.get_world_size()
        else:
            if running_status is None or \
                    'batch_size' not in running_status:
                raise RuntimeError(
                    'You should offer "batch_size" in running status for PGGAN'
                )
            _batch_size = running_status['batch_size']
        self.shown_nkimg += (_batch_size / 1000.)
        log_vars_disc.update(
            dict(shown_nkimg=self.shown_nkimg.item(),
                 curr_scale=self.curr_scale[0],
                 transition_weight=transition_weight))

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(fake_imgs=fake_imgs.cpu(),
                           real_imgs=real_imgs.cpu())
            outputs = dict(log_vars=log_vars_disc,
                           num_samples=batch_size,
                           results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(None,
                                   num_batches=batch_size,
                                   curr_scale=self.curr_scale[0],
                                   transition_weight=transition_weight)
        disc_pred_fake_g = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

        data_dict_ = dict(iteration=curr_iter,
                          gen=self.generator,
                          disc=self.discriminator,
                          fake_imgs=fake_imgs,
                          disc_pred_fake_g=disc_pred_fake_g)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars.update({'batch_size': batch_size})

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(log_vars=log_vars,
                       num_samples=batch_size,
                       results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1

        # check if a new scale will be added in the next iteration
        _curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        # in the next iteration, we will switch to a new scale
        if _curr_stage != self.curr_stage:
            # `self._next_scale_int` is updated at the end of `train_step`
            self._next_scale_int = self._next_scale_int * 2
        return outputs
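
The `train_step` above follows the asynchronous adversarial recipe described in its docstring: one discriminator update on fakes sampled without generator gradients, then one generator update with the discriminator frozen. Stripped of the progressive-growing bookkeeping, the core pattern looks roughly like this sketch; `generator.noise_size` and the `gan_loss(pred, target_is_real=...)` helper are assumptions for illustration:

import torch


def gan_train_step_sketch(generator, discriminator, g_opt, d_opt, real_imgs,
                          gan_loss):
    """Minimal sketch of one asynchronous adversarial iteration."""
    batch_size = real_imgs.shape[0]

    # discriminator step: sample fakes without tracking generator gradients
    for p in discriminator.parameters():
        p.requires_grad = True
    d_opt.zero_grad()
    with torch.no_grad():
        fake_imgs = generator(torch.randn(batch_size, generator.noise_size))
    d_loss = (gan_loss(discriminator(fake_imgs), target_is_real=False) +
              gan_loss(discriminator(real_imgs), target_is_real=True))
    d_loss.backward()
    d_opt.step()

    # generator step: freeze the discriminator so only G receives gradients
    for p in discriminator.parameters():
        p.requires_grad = False
    g_opt.zero_grad()
    fake_imgs = generator(torch.randn(batch_size, generator.noise_size))
    g_loss = gan_loss(discriminator(fake_imgs), target_is_real=True)
    g_loss.backward()
    g_opt.step()
    return dict(d_loss=d_loss.item(), g_loss=g_loss.item())
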
Example 6
def single_gpu_evaluation(model,
                          data_loader,
                          metrics,
                          logger,
                          basic_table_info,
                          batch_size,
                          samples_path=None,
                          **kwargs):
    """Evaluate model with a single gpu.

    This method evaluate model with a single gpu and displays eval progress
     bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including the training configuration and
            ckpt.
        batch_size (int): Batch size of images fed into metrics.
        samples_path (str): Used to save generated images. If it's None,
            we'll create a default directory and delete it after finishing
            the evaluation. Defaults to None.
        kwargs (dict): Other arguments.
    """
    # special metrics can only be evaluated online
    special_metric_name = ['PPL']
    for metric in metrics:
        assert metric.name not in special_metric_name, 'Please eval '\
             f'{metric.name} online'

    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(samples_path,
                         suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)

    # if no images exist yet, `num_exist` is zero and we sample from scratch
    for begin in range(num_exist, max_num_images, batch_size):
        end = min(begin + batch_size, max_num_images)
        fakes = model(None,
                      num_batches=end - begin,
                      return_loss=False,
                      sample_model=basic_table_info['sample_model'],
                      **kwargs)
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            images = ((images + 1) / 2)
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, batch_size)
    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data['real_img']
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    if delete_samples_path:
        shutil.rmtree(samples_path)
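
The evaluation loops above rely on a simple metric contract: `prepare()` resets state, `feed(batch, mode)` consumes a batch, and `summary()` computes the final score; the loops stop feeding once `feed` returns a value <= 0. A stub illustrating that contract, inferred from the call sites rather than copied from mmgen's metric base class:

class MetricStubSketch:
    """Hypothetical metric obeying the feed protocol used above."""

    name = 'Stub'

    def __init__(self, num_images):
        self.num_images = num_images
        self.prepare()

    def prepare(self):
        self._reals, self._fakes = [], []

    def feed(self, batch, mode):
        # return how many images were actually consumed; the eval loops
        # break once this drops to 0 (the metric is full)
        buf = self._reals if mode == 'reals' else self._fakes
        take = max(0, min(len(batch), self.num_images - len(buf)))
        buf.extend(batch[:take])
        return take

    def summary(self):
        # a real metric would compute e.g. FID from the two buffers here
        return len(self._reals), len(self._fakes)
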
Example 7
def main():

    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results) with the argument "--out", "--eval", "--format-only", '
         '"--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(args.config)
    if cfg.get('USE_MMDET', False):
        from mmdet.apis import multi_gpu_test, single_gpu_test
        from mmdet.datasets import build_dataloader
        from mmdet.models import build_detector as build_model
    else:
        from mmtrack.apis import multi_gpu_test, single_gpu_test
        from mmtrack.datasets import build_dataloader
        from mmtrack.models import build_model
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # cfg.model.pretrains = None
    if hasattr(cfg.model, 'detector'):
        cfg.model.detector.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset,
                                   samples_per_gpu=1,
                                   workers_per_gpu=cfg.data.workers_per_gpu,
                                   dist=distributed,
                                   shuffle=False)

    logger = get_logger('ParamsSearcher', log_file=args.log)
    # get all cases
    search_params = get_search_params(cfg.model.tracker, logger=logger)
    combinations = list(product(*search_params.values()))
    search_cfgs = []
    for c in combinations:
        search_cfg = dotty(cfg.model.tracker.copy())
        for i, k in enumerate(search_params.keys()):
            search_cfg[k] = c[i]
        search_cfgs.append(dict(search_cfg))
    print_log(f'In total, {len(search_cfgs)} cases.', logger)
    # init with the first one
    cfg.model.tracker = search_cfgs[0].copy()

    # build the model and load checkpoint
    if cfg.get('test_cfg', False):
        model = build_model(cfg.model,
                            train_cfg=cfg.train_cfg,
                            test_cfg=cfg.test_cfg)
    else:
        model = build_model(cfg.model)
    # We need to call `init_weights()` to load pretrained weights for MOT.
    model.init_weights()
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)

    if args.checkpoint is not None:
        checkpoint = load_checkpoint(model,
                                     args.checkpoint,
                                     map_location='cpu')
        if 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
    if not hasattr(model, 'CLASSES'):
        model.CLASSES = dataset.CLASSES

    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    print_log(f'Record {cfg.search_metrics}.', logger)
    for i, search_cfg in enumerate(search_cfgs):
        if not distributed:
            model.module.tracker = build_tracker(search_cfg)
            outputs = single_gpu_test(model, data_loader, args.show,
                                      args.show_dir)
        else:
            model.module.tracker = build_tracker(search_cfg)
            outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                     args.gpu_collect)
        rank, _ = get_dist_info()
        if rank == 0:
            if args.out:
                print(f'\nwriting results to {args.out}')
                mmcv.dump(outputs, args.out)
            kwargs = {} if args.eval_options is None else args.eval_options
            if args.format_only:
                dataset.format_results(outputs, **kwargs)
            if args.eval:
                eval_kwargs = cfg.get('evaluation', {}).copy()
                # hard-coded way to remove EvalHook args
                for key in ['interval', 'tmpdir', 'start', 'gpu_collect']:
                    eval_kwargs.pop(key, None)
                eval_kwargs.update(dict(metric=args.eval, **kwargs))
                results = dataset.evaluate(outputs, **eval_kwargs)
                _records = []
                for k in cfg.search_metrics:
                    if isinstance(results[k], float):
                        _records.append(f'{(results[k]):.3f}')
                    else:
                        _records.append(f'{(results[k])}')
                print_log(f'{combinations[i]}: {_records}', logger)
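
Example 7 enumerates tracker hyper-parameter cases with `itertools.product` and writes each combination back through dotted keys via `dotty`. A self-contained sketch of the same idea, with a tiny dotted-path setter standing in for `dotty` and a hypothetical search space:

from itertools import product


def set_by_path_sketch(cfg, dotted_key, value):
    """Tiny stand-in for `dotty`: set cfg['a']['b'] from the key 'a.b'."""
    keys = dotted_key.split('.')
    for k in keys[:-1]:
        cfg = cfg.setdefault(k, {})
    cfg[keys[-1]] = value


# hypothetical search space over two nested tracker params
search_params = {
    'match_iou_thr': [0.5, 0.7],
    'kf.center_only': [True, False],
}
search_cfgs = []
for combo in product(*search_params.values()):
    case = {'type': 'SomeTracker'}
    for key, value in zip(search_params, combo):
        set_by_path_sketch(case, key, value)
    search_cfgs.append(case)
print(f'In total, {len(search_cfgs)} cases.')  # -> 4
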
Example 8
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            persistent_workers=cfg.data.get('persistent_workers', False),
            seed=cfg.seed) for ds in dataset
    ]

    # dirty code for using apex.amp
    # apex.amp requires that models be on a cuda device before
    # initialization.
    if cfg.get('apex_amp', None):
        assert distributed, (
            'Currently, apex.amp is only supported with DDP training.')
        model = model.cuda()

    # build optimizer
    if cfg.optimizer:
        optimizer = build_optimizers(model, cfg.optimizer)
    # In GANs, we allow building the optimizer within the GAN model.
    else:
        optimizer = None

    _use_apex_amp = False
    if cfg.get('apex_amp', None):
        model, optimizer = apex_amp_initialize(model, optimizer,
                                               **cfg.apex_amp)
        _use_apex_amp = True

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        use_ddp_wrapper = cfg.get('use_ddp_wrapper', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_ddp_wrapper:
            mmcv.print_log('Use DDP Wrapper.', 'mmgen')
            model = DistributedDataParallelWrapper(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # allow users to define the runner
    if cfg.get('runner', None):
        runner = build_runner(
            cfg.runner,
            dict(
                model=model,
                optimizer=optimizer,
                work_dir=cfg.work_dir,
                logger=logger,
                use_apex_amp=_use_apex_amp,
                meta=meta))
    else:
        runner = IterBasedRunner(
            model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta)
        # set if use dynamic ddp in training
        # is_dynamic_ddp=cfg.get('is_dynamic_ddp', False))
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)

    # In GANs, we can directly optimize parameter in `train_step` function.
    if cfg.get('optimizer_cfg', None) is None:
        optimizer_config = None
    elif fp16_cfg is not None:
        raise NotImplementedError('Fp16 has not been supported.')
        # optimizer_config = Fp16OptimizerHook(
        #     **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    # default to use OptimizerHook
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # update `out_dir` in the ckpt hook
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config['out_dir'] = os.path.join(
            cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # # DistSamplerSeedHook should be used with EpochBasedRunner
    # if distributed:
    #     runner.register_hook(DistSamplerSeedHook())

    # In general, we do NOT adopt standard evaluation hooks in GAN training.
    # Thus, if you want an eval hook, you need to further define the
    # 'evaluation' key in the config.
    # register eval hooks
    if validate and cfg.get('evaluation', None) is not None:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # Support batch_size > 1 in validation
        val_loader_cfg = {
            'samples_per_gpu': 1,
            'shuffle': False,
            'workers_per_gpu': cfg.data.workers_per_gpu,
            **cfg.data.get('val_data_loader', {})
        }
        val_dataloader = build_dataloader(
            val_dataset, dist=distributed, **val_loader_cfg)
        eval_cfg = deepcopy(cfg.get('evaluation'))
        priority = eval_cfg.pop('priority', 'LOW')
        eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
        eval_hook = build_from_cfg(eval_cfg, HOOKS)
        runner.register_hook(eval_hook, priority=priority)

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_iters)
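
`train_model` above reads quite a few config keys. The snippet below collects the main ones in one place, mmcv-config style; the values are placeholders rather than a recommended setup:

# illustrative (partial) config consumed by `train_model`
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=2,
    persistent_workers=False,
    val_data_loader=dict(samples_per_gpu=2),  # optional override for val
)
optimizer = dict(  # one optimizer per submodule, GAN style
    generator=dict(type='Adam', lr=2e-4, betas=(0.0, 0.999)),
    discriminator=dict(type='Adam', lr=2e-4, betas=(0.0, 0.999)),
)
lr_config = None
checkpoint_config = dict(interval=10000, by_epoch=False)  # saved to 'ckpt/'
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
use_ddp_wrapper = True  # wrap the model with DistributedDataParallelWrapper
find_unused_parameters = False
total_iters = 800000
workflow = [('train', 10000)]
resume_from = None
load_from = None
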
Example 9
def single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                 basic_table_info, batch_size, **kwargs):
    """Evaluate model with a single gpu in online mode.

    This method evaluate model with a single gpu and displays eval progress
    bar. Different form `single_gpu_evaluation`, this function will not save
    the images or read images from disks. Namely, there do not exist any IO
    operations in this function. Thus, in general, `online` mode will achieve a
    faster evaluation. However, this mode will take much more memory cost.
    Therefore This evaluation function is recommended to evaluate your model
    with a single metric.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        batch_size (int): Batch size of images fed into metrics.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including the training configuration and
            ckpt.
        kwargs (dict): Other arguments.
    """
    # separate metrics into special metrics and vanilla metrics.
    # For vanilla metrics, images are generated in a random way, and are
    # shared by these metrics. For special metrics like 'PPL', images are
    # generated in a metric-special way and not shared between different
    # metrics.
    special_metrics = []
    vanilla_metrics = []
    special_metric_name = ['PPL']
    for metric in metrics:
        if metric.name in special_metric_name:
            special_metrics.append(metric)
        else:
            vanilla_metrics.append(metric)

    max_num_images = 0 if len(vanilla_metrics) == 0 else max(
        metric.num_images for metric in vanilla_metrics)
    for metric in vanilla_metrics:
        mmcv.print_log(f'Feed reals to {metric.name} metric.', 'mmgen')
        metric.prepare()
        pbar = mmcv.ProgressBar(metric.num_real_need)
        # feed in real images
        for data in data_loader:
            reals = data['real_img']
            num_feed = metric.feed(reals, 'reals')
            if num_feed <= 0:
                break

            pbar.update(num_feed)

        # finish the pbar stdout
        sys.stdout.write('\n')

    mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
                   'mmgen')
    # define mmcv progress bar
    max_num_images = 0 if len(vanilla_metrics) == 0 else max(
        metric.num_fake_need for metric in vanilla_metrics)
    pbar = mmcv.ProgressBar(max_num_images)
    # sampling fake images and directly send them to metrics
    for begin in range(0, max_num_images, batch_size):
        end = min(begin + batch_size, max_num_images)
        fakes = model(None,
                      num_batches=end - begin,
                      return_loss=False,
                      sample_model=basic_table_info['sample_model'],
                      **kwargs)
        pbar.update(end - begin)
        fakes = fakes[:end - begin]

        for metric in vanilla_metrics:
            # feed in fake images
            _ = metric.feed(fakes, 'fakes')

    # finish the pbar stdout
    sys.stdout.write('\n')

    # feed special metric
    for metric in special_metrics:
        metric.prepare()
        fakedata_iterator = iter(
            metric.get_sampler(model.module, batch_size,
                               basic_table_info['sample_model']))
        pbar = mmcv.ProgressBar(metric.num_images)
        for fakes in fakedata_iterator:
            num_left = metric.feed(fakes, 'fakes')
            pbar.update(fakes.shape[0])
            if num_left <= 0:
                break

        # finish the pbar stdout
        sys.stdout.write('\n')

    for metric in metrics:
        metric.summary()

    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
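
Special metrics such as PPL supply their own sampler via `get_sampler`, because their fakes must be produced in a metric-specific way (e.g., pairs of interpolated latents) rather than drawn independently. A rough sketch of that interface, assumed from the call site above rather than copied from mmgen's PPL implementation:

import torch


class SpecialMetricSketch:
    """Hypothetical special metric that owns its sampling procedure."""

    name = 'PPL-like'
    num_images = 128

    def prepare(self):
        self._left = self.num_images
        self._scores = []

    def get_sampler(self, model, batch_size, sample_model):
        # yield batches generated the metric-specific way; plain sampling
        # stands in here for paired/interpolated latents
        def _gen():
            remaining = self.num_images
            while remaining > 0:
                n = min(batch_size, remaining)
                with torch.no_grad():
                    fakes = model(None, num_batches=n, return_loss=False,
                                  sample_model=sample_model)
                yield fakes
                remaining -= n
        return _gen()

    def feed(self, fakes, mode):
        self._scores.append(float(fakes.float().std()))  # placeholder stat
        self._left -= fakes.shape[0]
        return self._left  # the loop above breaks when this hits <= 0

    def summary(self):
        return sum(self._scores) / max(len(self._scores), 1)
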
Example 10
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')

    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()

    # if given proj_latent, reset args.endpoint
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        proj_file = torch.load(args.proj_latent)
        proj_n = len(proj_file)
        args.endpoint = proj_n
        assert args.space == 'w', 'Projected latents are w or w-plus latents.'
        noise_batch = []
        for img_path in proj_file:
            noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
        noise_batch = torch.cat(noise_batch, dim=0)
        if not args.use_cpu:
            noise_batch = noise_batch.cuda()

    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0, (
            'We need paired images in group mode, so keep endpoint an even '
            'number')

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)
    # remind users to fix injected noise
    if kwargs.get('randomize_noise', True):
        mmcv.print_log(
            'Hint: For Style-Based GANs, you can add '
            '`--sample-cfg randomize_noise=False` to fix the injected '
            'noises', 'mmgen')

    # get noises corresponding to each endpoint
    if not args.proj_latent:
        noise_batch = batch_inference(
            generator,
            None,
            num_batches=args.endpoint,
            dict_key='noise_batch' if args.space == 'z' else 'latent',
            return_noise=True,
            **kwargs)

    if args.space == 'w':
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True

    if args.show_mode == 'sequence':
        results = sample_from_path(generator, noise_batch[:-1, ],
                                   noise_batch[1:, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    else:
        results = sample_from_path(generator, noise_batch[::2, ],
                                   noise_batch[1::2, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    # reorder results
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)
    # save image
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        if args.export_video:
            # render video.
            video_out = imageio.get_writer(os.path.join(
                args.samples_path, 'lerp.mp4'),
                                           mode='I',
                                           fps=60,
                                           codec='libx264',
                                           bitrate='12M')
            video_out = layout_grid(video_out, results)
            video_out.close()
        else:
            for i in range(results.shape[0]):
                image = results[i:i + 1]
                save_image(
                    image,
                    os.path.join(args.samples_path,
                                 '{:0>5d}'.format(i) + '.png'))
    else:
        if args.export_video:
            # render video.
            video_out = imageio.get_writer(os.path.join(
                args.samples_path, 'lerp.mp4'),
                                           mode='I',
                                           fps=60,
                                           codec='libx264',
                                           bitrate='12M')
            n_pair = args.endpoint // 2
            grid_w, grid_h = crack_integer(n_pair)
            video_out = layout_grid(video_out,
                                    results,
                                    grid_h=grid_h,
                                    grid_w=grid_w)
            video_out.close()
        else:
            save_image(results,
                       os.path.join(args.samples_path, 'group.png'),
                       nrow=args.interval)
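
`layout_grid` is not shown; from the call sites it receives the imageio writer plus the stacked results, writes one tiled frame per interpolation step, and returns the writer so it can be closed. A sketch under those assumptions:

import numpy as np


def layout_grid_sketch(writer, results, grid_h=1, grid_w=1):
    """Hypothetical stand-in for `layout_grid`.

    `results` is assumed to be (n_clips * n_frames, C, H, W) in [0, 1],
    ordered clip-major as produced by the reshape above.
    """
    if hasattr(results, 'cpu'):  # accept torch tensors as well
        results = results.cpu().numpy()
    c, h, w = results.shape[1:]
    n_frames = results.shape[0] // (grid_h * grid_w)
    clips = results.reshape(grid_h, grid_w, n_frames, c, h, w)
    for t in range(n_frames):
        rows = [
            np.concatenate([clips[i, j, t] for j in range(grid_w)], axis=2)
            for i in range(grid_h)
        ]
        frame = np.concatenate(rows, axis=1)  # (C, grid_h*H, grid_w*W)
        frame = (frame.transpose(1, 2, 0) * 255).astype(np.uint8)
        writer.append_data(frame)
    return writer  # so `video_out = layout_grid(...)` keeps working
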
Example 11
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    generator.eval()
    device = 'cpu'
    if not args.use_cpu:
        generator = generator.cuda()
        device = 'cuda'

    img_size = min(generator.out_size, 256)
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # read images
    imgs = []
    for imgfile in args.files:
        img = Image.open(imgfile).convert('RGB')
        img = transform(img)
        img = img[[2, 1, 0], ...]
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    # get mean and standard deviation of style latents
    with torch.no_grad():
        noise_sample = torch.randn(args.n_mean_latent,
                                   generator.style_channels,
                                   device=device)
        latent_out = generator.style_mapping(noise_sample)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      args.n_mean_latent)**0.5
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(
        imgs.shape[0], 1)
    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, generator.num_latents, 1)
    latent_in.requires_grad = True

    # define lpips loss
    percept = PerceptualLoss(use_gpu=device.startswith('cuda'))

    # initialize layer noises
    noises_single = generator.make_injected_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
    pbar = tqdm(range(args.total_iters))
    # run optimization
    for i in pbar:
        t = i / args.total_iters
        lr = get_lr(t, args.lr, args.lr_rampdown, args.lr_rampup)
        optimizer.param_groups[0]['lr'] = lr
        noise_strength = latent_std * args.noise * max(
            0, 1 - t / args.noise_ramp)**2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen = generator([latent_n],
                            input_is_latent=True,
                            injected_noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(batch, channel, height // factor, factor,
                                      width // factor, factor)
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        pbar.set_description(
            f' perceptual: {p_loss.item():.4f}, noise regularize:'
            f'{n_loss.item():.4f}, mse: {mse_loss.item():.4f}, lr: {lr:.4f}')

    results = generator([latent_in.detach().clone()],
                        input_is_latent=True,
                        injected_noise=noises)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    mmcv.mkdir_or_exist(args.results_path)
    # save projection results
    result_file = OrderedDict()
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i:i + 1])
        result_file[input_name] = {
            'img': img_gen[i],
            'latent': latent_in[i],
            'injected_noise': noise_single,
        }
        img_name = os.path.splitext(
            os.path.basename(input_name))[0] + '-project.png'
        save_image(results[i], os.path.join(args.results_path, img_name))

    torch.save(result_file, os.path.join(args.results_path,
                                         'project_result.pt'))
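
Example 11 uses three small projector helpers that are not shown: `latent_noise`, `noise_regularize`, and `noise_normalize_`. Sketches of the standard StyleGAN2-projector versions follow; they are presumably close to what is imported here, but treat them as reconstructions:

import torch


def latent_noise_sketch(latent, strength):
    # perturb the latent to keep early optimization exploratory
    return latent + torch.randn_like(latent) * strength


def noise_regularize_sketch(noises):
    # penalize spatial self-correlation in each injected-noise map at
    # multiple scales, so the optimized noise stays noise-like
    loss = 0
    for noise in noises:
        size = noise.shape[2]
        while True:
            loss = (loss +
                    (noise * torch.roll(noise, 1, dims=3)).mean().pow(2) +
                    (noise * torch.roll(noise, 1, dims=2)).mean().pow(2))
            if size <= 8:
                break
            noise = noise.reshape(-1, 1, size // 2, 2, size // 2, 2)
            noise = noise.mean([3, 5])
            size //= 2
    return loss


def noise_normalize_sketch(noises):
    # re-standardize each noise map in place after the gradient step
    for noise in noises:
        mean = noise.mean()
        std = noise.std()
        noise.data.add_(-mean).div_(std)
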
Example 12
def test_print_log_exception():
    with pytest.raises(TypeError):
        print_log('welcome', logger=0)
Example 13
def test_print_log_silent(capsys, caplog):
    print_log('welcome', logger='silent')
    out, _ = capsys.readouterr()
    assert out == ''
    assert len(caplog.records) == 0
Example 14
def test_print_log_print(capsys):
    print_log('welcome', logger=None)
    out, _ = capsys.readouterr()
    assert out == 'welcome\n'
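
Examples 12-14 pin down `print_log`'s dispatch behavior: `logger=None` prints, `'silent'` suppresses all output, and a non-Logger/non-str argument raises `TypeError`. A condensed sketch consistent with those tests (mmcv's real implementation also resolves named loggers and accepts a log level):

import logging


def print_log_sketch(msg, logger=None, level=logging.INFO):
    """Sketch matching the behavior asserted in the tests above."""
    if logger is None:
        print(msg)
    elif isinstance(logger, logging.Logger):
        logger.log(level, msg)
    elif logger == 'silent':
        pass
    elif isinstance(logger, str):
        # mmcv looks up (or creates) a named logger here
        logging.getLogger(logger).log(level, msg)
    else:
        raise TypeError(
            'logger should be either a logging.Logger object, str, '
            f'"silent" or None, but got {type(logger)}')
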
Example 15
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()
    if not distributed:
        _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
        model = MMDataParallel(model, device_ids=[0])

        # build metrics
        if args.eval:
            if args.eval[0] == 'none':
                # only sample images
                metrics = []
                assert args.num_samples is not None and args.num_samples > 0
            else:
                metrics = [
                    build_metric(cfg.metrics[metric]) for metric in args.eval
                ]
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in cfg.metrics
            ]

        basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                                ckpt=ckpt,
                                sample_model=args.sample_model)

        if len(metrics) == 0:
            basic_table_info['num_samples'] = args.num_samples
            data_loader = None
        else:
            basic_table_info['num_samples'] = -1
            # build the dataloader
            if cfg.data.get('test', None) and cfg.data.test.get(
                    'imgs_root', None):
                dataset = build_dataset(cfg.data.test)
            elif cfg.data.get('val', None) and cfg.data.val.get(
                    'imgs_root', None):
                dataset = build_dataset(cfg.data.val)
            elif cfg.data.get('train', None):
                # we assume that the train part should work well
                dataset = build_dataset(cfg.data.train)
            else:
                raise RuntimeError('There is no valid dataset config to run, '
                                   'please check your dataset configs.')
            data_loader = build_dataloader(dataset,
                                           samples_per_gpu=args.batch_size,
                                           workers_per_gpu=cfg.data.get(
                                               'val_workers_per_gpu',
                                               cfg.data.workers_per_gpu),
                                           dist=distributed,
                                           shuffle=True)

        if args.sample_cfg is None:
            args.sample_cfg = dict()

        # online mode will not save samples
        if args.online and len(metrics) > 0:
            single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                         basic_table_info, args.batch_size,
                                         **args.sample_cfg)
        else:
            single_gpu_evaluation(model, data_loader, metrics, logger,
                                  basic_table_info, args.batch_size,
                                  args.samples_path, **args.sample_cfg)
    else:
        raise NotImplementedError(
            'Multi-GPU evaluation has not been implemented yet.')
Example 16
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return

        runner.model.eval()

        # maximum number of fake images needed across all metrics
        max_num_images = max(metric.num_images for metric in self.metrics)
        for metric in self.metrics:
            if metric.num_real_feeded >= metric.num_real_need:
                continue
            mmcv.print_log(f'Feed reals to {metric.name} metric.', 'mmgen')
            # feed in real images
            for data in self.dataloader:
                reals = data['real_img']
                num_feed = metric.feed(reals, 'reals')
                if num_feed <= 0:
                    break

        mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
                       'mmgen')
        batch_size = self.dataloader.batch_size

        rank, ws = get_dist_info()
        total_batch_size = batch_size * ws

        # define mmcv progress bar
        if rank == 0:
            pbar = mmcv.ProgressBar(max_num_images)

        # sampling fake images and directly send them to metrics
        for _ in range(0, max_num_images, total_batch_size):

            with torch.no_grad():
                fakes = runner.model(
                    None,
                    num_batches=batch_size,
                    return_loss=False,
                    **self.sample_kwargs)

                for metric in self.metrics:
                    # feed in fake images
                    num_left = metric.feed(fakes, 'fakes')
                    if num_left <= 0:
                        break

            if rank == 0:
                pbar.update(total_batch_size)

        runner.log_buffer.clear()
        # a dirty workaround to change the line at the end of pbar
        if rank == 0:
            sys.stdout.write('\n')
            for metric in self.metrics:
                with torch.no_grad():
                    metric.summary()
                for name, val in metric._result_dict.items():
                    runner.log_buffer.output[name] = val

                    # record best metric and save the best ckpt
                    if self.save_best_ckpt and name == self.best_metric:
                        self._save_best_ckpt(runner, val)

            runner.log_buffer.ready = True
        runner.model.train()

        # clear all current states for next evaluation
        for metric in self.metrics:
            metric.clear()
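
The hook above is typically registered through the `evaluation` config key handled by `train_model` in Example 8. An illustrative entry is sketched below; the hook type name and metric fields are assumptions based on common mmgen configs:

# illustrative `evaluation` entry consumed by `train_model` (Example 8)
evaluation = dict(
    type='GenerativeEvalHook',
    interval=10000,            # matches `every_n_iters` in the hook above
    metrics=dict(type='FID', num_images=50000),
    sample_kwargs=dict(sample_model='ema'),
    save_best_ckpt=True,
    best_metric='fid',
    priority='LOW')            # popped off before the hook is built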