Example #1
def single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                 basic_table_info, batch_size, **kwargs):
    """Evaluate model with a single gpu in online mode.

    This method evaluate model with a single gpu and displays eval progress
    bar. Different form `single_gpu_evaluation`, this function will not save
    the images or read images from disks. Namely, there do not exist any IO
    operations in this function. Thus, in general, `online` mode will achieve a
    faster evaluation. However, this mode will take much more memory cost.
    Therefore this evaluation function is recommended to evaluate your model
    with a single metric.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): Logger used to record evaluation results.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including the training configuration and
            checkpoint.
        batch_size (int): Batch size of images fed into metrics.
        kwargs (dict): Other arguments.
    """
    # sample images
    max_num_images = 0 if len(metrics) == 0 else max(metric.num_fake_need
                                                     for metric in metrics)
    pbar = mmcv.ProgressBar(max_num_images)

    # select key to fetch images
    target_domain = basic_table_info['target_domain']
    source_domain = basic_table_info['source_domain']

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()

    # feed reals and fakes
    data_loader_iter = iter(data_loader)
    for begin in range(0, max_num_images, batch_size):
        end = min(begin + batch_size, max_num_images)
        # for translation models, feed images from the dataloader
        data_batch = next(data_loader_iter)
        output_dict = model(
            data_batch[f'img_{source_domain}'],
            test_mode=True,
            target_domain=target_domain)
        fakes = output_dict['target']
        reals = data_batch[f'img_{target_domain}']
        pbar.update(end - begin)
        for metric in metrics:
            metric.feed(reals, 'reals')
            metric.feed(fakes, 'fakes')

    for metric in metrics:
        metric.summary()

    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
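A minimal usage sketch for the online path: it assumes `model`, `data_loader`, `metrics`, and `logger` have already been built (Example #3 below shows one way), and every literal value is a hypothetical placeholder rather than something from the original code.

# Hedged sketch: all literals are hypothetical placeholders.
basic_table_info = dict(
    train_cfg='cyclegan_config.py',  # hypothetical config filename
    ckpt='iter_80000.pth',           # hypothetical checkpoint filename
    sample_model='ema',
    source_domain='mask',            # domains select the `img_...` keys the
    target_domain='photo')           # function reads from each data batch
single_gpu_online_evaluation(model, data_loader, metrics, logger,
                             basic_table_info, batch_size=4)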
Example #2
def single_gpu_evaluation(model,
                          data_loader,
                          metrics,
                          logger,
                          basic_table_info,
                          batch_size,
                          samples_path=None,
                          **kwargs):
    """Evaluate model with a single gpu.

    This method evaluate model with a single gpu and displays eval progress
        bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): Logger used to record evaluation results.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including the training configuration and
            checkpoint.
        batch_size (int): Batch size of images fed into metrics.
        samples_path (str): Directory used to save generated images. If it is
            None, a default directory is created and then deleted once the
            evaluation finishes. Defaults to None.
        kwargs (dict): Other arguments.
    """
    # decide samples path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
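    # count images already present in `samples_path` so that sampling can
    # resume from a previous run instead of regenerating everything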
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)
    # select key to fetch fake images
    target_domain = basic_table_info['target_domain']
    source_domain = basic_table_info['source_domain']
    # if no images exist yet, `num_exist` is zero
    data_loader_iter = iter(data_loader)
    for begin in range(num_exist, max_num_images, batch_size):
        end = min(begin + batch_size, max_num_images)
        # for translation models, feed images from the dataloader
        data_batch = next(data_loader_iter)
        output_dict = model(
            data_batch[f'img_{source_domain}'],
            test_mode=True,
            target_domain=target_domain)
        fakes = output_dict['target']
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            # map network outputs from [-1, 1] back to [0, 1]
            images = (images + 1) / 2
            # swap BGR -> RGB channel order
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, batch_size)

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data[f'img_{target_domain}']
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    if delete_samples_path:
        shutil.rmtree(samples_path)
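The offline variant also reads `num_samples` from `basic_table_info` and accepts a `samples_path` to keep the generated images. The same hedged assumptions apply as in the sketch above; all literal values are placeholders:

basic_table_info = dict(
    train_cfg='pix2pix_config.py',   # hypothetical config filename
    ckpt='latest.pth',               # hypothetical checkpoint filename
    sample_model='orig',
    num_samples=-1,                  # <= 0 falls back to metric.num_images
    source_domain='edges',           # hypothetical domain names
    target_domain='photo')
single_gpu_evaluation(model, data_loader, metrics, logger, basic_table_info,
                      batch_size=4, samples_path='./work_dirs/eval_samples')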
Example #3
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    assert isinstance(model, _supported_model)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                            ckpt=ckpt,
                            sample_model=args.sample_model)

    # build the dataloader
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(dataset,
                                       samples_per_gpu=args.batch_size,
                                       workers_per_gpu=cfg.data.get(
                                           'val_workers_per_gpu',
                                           cfg.data.workers_per_gpu),
                                       dist=False,
                                       shuffle=True)

    # decide samples path
    samples_path = args.samples_path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(samples_path,
                         suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)
    # select key to fetch fake images
    fake_key = 'fake_b'
    if isinstance(model.module, CycleGAN):
        fake_key = 'fake_b' if model.module.test_direction == 'a2b' else \
            'fake_a'
    # if no images exist yet, `num_exist` is zero
    data_loader_iter = iter(data_loader)
    for begin in range(num_exist, max_num_images, args.batch_size):
        end = min(begin + args.batch_size, max_num_images)
        # for translation models, feed images from the dataloader
        data_batch = next(data_loader_iter)
        output_dict = model(test_mode=True, **data_batch)
        fakes = output_dict[fake_key]
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            # map network outputs from [-1, 1] back to [0, 1]
            images = (images + 1) / 2
            # swap BGR -> RGB channel order
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, args.batch_size)
    # select key to fetch real images
    if isinstance(model.module, CycleGAN):
        real_key = 'img_b' if model.module.test_direction == 'a2b' else 'img_a'
        if model.module.direction == 'b2a':
            real_key = 'img_a' if real_key == 'img_b' else 'img_b'

    if isinstance(model.module, Pix2Pix):
        real_key = 'img_b' if model.module.direction == 'a2b' else 'img_a'

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data[real_key]
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    if delete_samples_path:
        shutil.rmtree(samples_path)
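Both saving loops above post-process network outputs the same way before writing them: rescale from [-1, 1] to [0, 1], swap the BGR channel order to RGB, then clamp. A self-contained demonstration of just that step on a synthetic tensor:

import torch
from torchvision.utils import save_image

fakes = torch.rand(4, 3, 256, 256) * 2 - 1  # synthetic BGR batch in [-1, 1]
image = (fakes[0:1] + 1) / 2                # rescale to [0, 1]
image = image[:, [2, 1, 0], ...]            # BGR -> RGB channel swap
image = image.clamp_(0, 1)
save_image(image, '0.png')                  # same call the loops use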