Ejemplo n.º 1
0
def main():
    """Entry point for evaluating an (unconditional) generative model.

    Loads the config and checkpoint, builds the model, metrics and
    dataloader, then runs online or offline single-GPU evaluation.
    Distributed (multi-GPU) evaluation is not implemented and raises.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # A remote (http) checkpoint has no local directory to put a log file in.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds (log only once on rank 0)
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    # sanity check for models without ema: fall back to the original weights
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()
    if not distributed:
        _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
        model = MMDataParallel(model, device_ids=[0])

        # build metrics; 'none' means sample images only, no metric computed
        if args.eval:
            if args.eval[0] == 'none':
                # only sample images
                metrics = []
                assert args.num_samples is not None and args.num_samples > 0
            else:
                metrics = [
                    build_metric(cfg.metrics[metric]) for metric in args.eval
                ]
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in cfg.metrics
            ]

        basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                                ckpt=ckpt,
                                sample_model=args.sample_model)

        if len(metrics) == 0:
            # sampling-only mode: no real data needed
            basic_table_info['num_samples'] = args.num_samples
            data_loader = None
        else:
            basic_table_info['num_samples'] = -1
            # build the dataloader, preferring test > val > train splits
            if cfg.data.get('test', None) and cfg.data.test.get(
                    'imgs_root', None):
                dataset = build_dataset(cfg.data.test)
            elif cfg.data.get('val', None) and cfg.data.val.get(
                    'imgs_root', None):
                dataset = build_dataset(cfg.data.val)
            elif cfg.data.get('train', None):
                # we assume that the train part should work well
                dataset = build_dataset(cfg.data.train)
            else:
                raise RuntimeError('There is no valid dataset config to run, '
                                   'please check your dataset configs.')
            data_loader = build_dataloader(dataset,
                                           samples_per_gpu=args.batch_size,
                                           workers_per_gpu=cfg.data.get(
                                               'val_workers_per_gpu',
                                               cfg.data.workers_per_gpu),
                                           dist=distributed,
                                           shuffle=True)

        if args.sample_cfg is None:
            args.sample_cfg = dict()

        # online mode will not save samples
        if args.online and len(metrics) > 0:
            single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                         basic_table_info, args.batch_size,
                                         **args.sample_cfg)
        else:
            single_gpu_evaluation(model, data_loader, metrics, logger,
                                  basic_table_info, args.batch_size,
                                  args.samples_path, **args.sample_cfg)
    else:
        raise NotImplementedError(
            'Multi-GPU evaluation has not been implemented yet.')
Ejemplo n.º 2
0
def main():
    """Evaluate an image-to-image translation model on a single GPU.

    Loads the config and checkpoint, resolves the source/target domains
    of the translation model, builds metrics and a dataloader, then runs
    online or offline single-GPU evaluation.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # A remote (http) checkpoint has no local directory to put a log file in.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    assert isinstance(model, BaseTranslationModel)
    # sanity check for models without ema: fall back to the original weights
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics; 'none' means sample images only, no metric computed
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # get source domain and target domain; default to the model's own
    # default domain when none is given on the command line
    target_domain = args.target_domain
    if target_domain is None:
        target_domain = model.module._default_domain
    source_domain = model.module.get_other_domains(target_domain)[0]

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model,
        source_domain=source_domain,
        target_domain=target_domain)

    # build the dataloader (only needed when metrics are computed)
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=False,
            shuffle=True)

    # online evaluation feeds samples straight into the metrics; offline
    # evaluation writes samples to args.samples_path first
    if args.online:
        single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                     basic_table_info, args.batch_size)
    else:
        single_gpu_evaluation(model, data_loader, metrics, logger,
                              basic_table_info, args.batch_size,
                              args.samples_path)
Ejemplo n.º 3
0
    parser.add_argument('--num-samples', type=int, default=2)
    parser.add_argument('--sample-path', type=str, default=None)
    parser.add_argument('--random-seed', type=int, default=2020)

    args = parser.parse_args()

    set_random_seed(args.random_seed)
    cfg = mmcv.Config.fromfile(args.config)

    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    mmcv.print_log('Building models and loading checkpoints', 'mmgen')
    # build model
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)

    model.eval()
    load_checkpoint(model, args.ckpt, map_location='cpu')

    # get generator
    if model.use_ema:
        generator = model.generator_ema
    else:
        generator = model.generator

    generator = generator.to(device)
    generator.eval()

    mmcv.print_log('Calculating or loading eigen vectors', 'mmgen')
Ejemplo n.º 4
0
def main():
    """Project real images into a StyleGAN-style generator's latent space.

    Optimizes per-image latent codes (and injected per-layer noises) so the
    generator reproduces the input images, using a perceptual (LPIPS) loss,
    a noise-regularization term and an MSE term. Saves the projected images
    and a ``project_result.pt`` file with latents and noises.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema: fall back to the original weights
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    generator.eval()
    device = 'cpu'
    if not args.use_cpu:
        generator = generator.cuda()
        device = 'cuda'

    # cap working resolution at 256 to keep the optimization tractable
    img_size = min(generator.out_size, 256)
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # read images (channel flip: RGB -> BGR, matching the generator's output)
    imgs = []
    for imgfile in args.files:
        img = Image.open(imgfile).convert('RGB')
        img = transform(img)
        img = img[[2, 1, 0], ...]
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    # get mean and standard deviation of style latents by sampling the
    # style-mapping network; the mean is the starting point of optimization
    with torch.no_grad():
        noise_sample = torch.randn(args.n_mean_latent,
                                   generator.style_channels,
                                   device=device)
        latent_out = generator.style_mapping(noise_sample)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      args.n_mean_latent)**0.5
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(
        imgs.shape[0], 1)
    # w_plus: optimize a separate latent per generator layer (W+ space)
    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, generator.num_latents, 1)
    latent_in.requires_grad = True

    # define lpips loss
    percept = PerceptualLoss(use_gpu=device.startswith('cuda'))

    # initialize layer noises (one trainable noise tensor per layer)
    noises_single = generator.make_injected_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
    pbar = tqdm(range(args.total_iters))
    # run optimization
    for i in pbar:
        t = i / args.total_iters
        # ramped learning-rate schedule
        lr = get_lr(t, args.lr, args.lr_rampdown, args.lr_rampup)
        optimizer.param_groups[0]['lr'] = lr
        # latent-space noise injection, decayed to zero over noise_ramp
        noise_strength = latent_std * args.noise * max(
            0, 1 - t / args.noise_ramp)**2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen = generator([latent_n],
                            input_is_latent=True,
                            injected_noise=noises)

        batch, channel, height, width = img_gen.shape

        # average-pool generated images down to <=256 before the loss
        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(batch, channel, height // factor, factor,
                                      width // factor, factor)
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # keep each noise map zero-mean / unit-variance
        noise_normalize_(noises)

        pbar.set_description(
            f' perceptual: {p_loss.item():.4f}, noise regularize:'
            f'{n_loss.item():.4f}, mse: {mse_loss.item():.4f}, lr: {lr:.4f}')

    # final render with the optimized latents (no added latent noise)
    results = generator([latent_in.detach().clone()],
                        input_is_latent=True,
                        injected_noise=noises)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    mmcv.mkdir_or_exist(args.results_path)
    # save projection results (per-image latent, noises and last rendered img)
    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i:i + 1])
        result_file[input_name] = {
            'img': img_gen[i],
            'latent': latent_in[i],
            'injected_noise': noise_single,
        }
        img_name = os.path.splitext(
            os.path.basename(input_name))[0] + '-project.png'
        save_image(results[i], os.path.join(args.results_path, img_name))

    torch.save(result_file, os.path.join(args.results_path,
                                         'project_result.pt'))
Ejemplo n.º 5
0
def test_pix2pix():
    """Integration test for the Pix2Pix model.

    Covers model construction, forward (dummy/test/train), train_step and
    val_step on CPU and (when available) GPU, disc_steps/disc_init_steps
    scheduling, training without pixel loss, b2a direction, and image
    saving with/without show_input.
    """

    model_cfg = dict(type='Pix2Pix',
                     generator=dict(type='UnetGenerator',
                                    in_channels=3,
                                    out_channels=3,
                                    num_down=8,
                                    base_channels=64,
                                    norm_cfg=dict(type='BN'),
                                    use_dropout=True,
                                    init_cfg=dict(type='normal', gain=0.02)),
                     discriminator=dict(type='PatchDiscriminator',
                                        in_channels=6,
                                        base_channels=64,
                                        num_conv=3,
                                        norm_cfg=dict(type='BN'),
                                        init_cfg=dict(type='normal',
                                                      gain=0.02)),
                     gan_loss=dict(type='GANLoss',
                                   gan_type='vanilla',
                                   real_label_val=1.0,
                                   fake_label_val=0,
                                   loss_weight=1.0),
                     pixel_loss=dict(type='L1Loss',
                                     loss_weight=100.0,
                                     reduction='mean'))

    train_cfg = None
    test_cfg = None

    # build synthesizer
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)

    # test checking gan loss cannot be None
    with pytest.raises(AssertionError):
        bad_model_cfg = copy.deepcopy(model_cfg)
        bad_model_cfg['gan_loss'] = None
        _ = build_model(bad_model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)

    # test attributes
    assert synthesizer.__class__.__name__ == 'Pix2Pix'
    assert isinstance(synthesizer.generator, UnetGenerator)
    assert isinstance(synthesizer.discriminator, PatchDiscriminator)
    assert isinstance(synthesizer.gan_loss, GANLoss)
    assert isinstance(synthesizer.pixel_loss, L1Loss)
    assert synthesizer.train_cfg is None
    assert synthesizer.test_cfg is None

    # prepare data
    inputs = torch.rand(1, 3, 256, 256)
    targets = torch.rand(1, 3, 256, 256)
    data_batch = {'img_a': inputs, 'img_b': targets}
    img_meta = {}
    img_meta['img_a_path'] = 'img_a_path'
    img_meta['img_b_path'] = 'img_b_path'
    data_batch['meta'] = [img_meta]

    # prepare optimizer (one Adam per sub-network, as train_step expects)
    optim_cfg = dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))
    optimizer = {
        'generator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generator').parameters())),
        'discriminator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminator').parameters()))
    }

    # test forward_dummy
    with torch.no_grad():
        output = synthesizer.forward_dummy(data_batch['img_a'])
    assert torch.is_tensor(output)
    assert output.size() == (1, 3, 256, 256)

    # test forward_test
    with torch.no_grad():
        outputs = synthesizer(inputs, targets, [img_meta], test_mode=True)
    assert torch.equal(outputs['real_a'], data_batch['img_a'])
    assert torch.equal(outputs['real_b'], data_batch['img_b'])
    assert torch.is_tensor(outputs['fake_b'])
    assert outputs['fake_b'].size() == (1, 3, 256, 256)

    # val_step
    with torch.no_grad():
        outputs = synthesizer.val_step(data_batch)
    assert torch.equal(outputs['real_a'], data_batch['img_a'])
    assert torch.equal(outputs['real_b'], data_batch['img_b'])
    assert torch.is_tensor(outputs['fake_b'])
    assert outputs['fake_b'].size() == (1, 3, 256, 256)

    # test forward_train
    outputs = synthesizer(inputs, targets, [img_meta], test_mode=False)
    assert torch.equal(outputs['real_a'], data_batch['img_a'])
    assert torch.equal(outputs['real_b'], data_batch['img_b'])
    assert torch.is_tensor(outputs['fake_b'])
    assert outputs['fake_b'].size() == (1, 3, 256, 256)

    # test train_step
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in [
            'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_pixel'
    ]:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_a'], data_batch['img_a'])
    assert torch.equal(outputs['results']['real_b'], data_batch['img_b'])
    assert torch.is_tensor(outputs['results']['fake_b'])
    assert outputs['results']['fake_b'].size() == (1, 3, 256, 256)

    # test train_step and forward_test (gpu)
    if torch.cuda.is_available():
        synthesizer = synthesizer.cuda()
        optimizer = {
            'generator':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(params=getattr(synthesizer, 'generator').parameters())),
            'discriminator':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(
                    params=getattr(synthesizer, 'discriminator').parameters()))
        }
        data_batch_cuda = copy.deepcopy(data_batch)
        data_batch_cuda['img_a'] = inputs.cuda()
        data_batch_cuda['img_b'] = targets.cuda()
        data_batch_cuda['meta'] = [DC(img_meta, cpu_only=True).data]

        # forward_test (outputs are expected to come back on the CPU)
        with torch.no_grad():
            outputs = synthesizer(data_batch_cuda['img_a'],
                                  data_batch_cuda['img_b'],
                                  data_batch_cuda['meta'],
                                  test_mode=True)
        assert torch.equal(outputs['real_a'], data_batch_cuda['img_a'].cpu())
        assert torch.equal(outputs['real_b'], data_batch_cuda['img_b'].cpu())
        assert torch.is_tensor(outputs['fake_b'])
        assert outputs['fake_b'].size() == (1, 3, 256, 256)

        # val_step
        with torch.no_grad():
            outputs = synthesizer.val_step(data_batch_cuda)
        assert torch.equal(outputs['real_a'], data_batch_cuda['img_a'].cpu())
        assert torch.equal(outputs['real_b'], data_batch_cuda['img_b'].cpu())
        assert torch.is_tensor(outputs['fake_b'])
        assert outputs['fake_b'].size() == (1, 3, 256, 256)

        # test forward_train (outputs stay on the GPU in train mode)
        outputs = synthesizer(data_batch_cuda['img_a'],
                              data_batch_cuda['img_b'],
                              data_batch_cuda['meta'],
                              test_mode=False)
        assert torch.equal(outputs['real_a'], data_batch_cuda['img_a'])
        assert torch.equal(outputs['real_b'], data_batch_cuda['img_b'])
        assert torch.is_tensor(outputs['fake_b'])
        assert outputs['fake_b'].size() == (1, 3, 256, 256)

        # train_step
        outputs = synthesizer.train_step(data_batch_cuda, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        for v in [
                'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g',
                'loss_pixel'
        ]:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_a'],
                           data_batch_cuda['img_a'].cpu())
        assert torch.equal(outputs['results']['real_b'],
                           data_batch_cuda['img_b'].cpu())
        assert torch.is_tensor(outputs['results']['fake_b'])
        assert outputs['results']['fake_b'].size() == (1, 3, 256, 256)

    # test disc_steps and disc_init_steps
    data_batch['img_a'] = inputs.cpu()
    data_batch['img_b'] = targets.cpu()
    train_cfg = dict(disc_steps=2, disc_init_steps=2)
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)
    optimizer = {
        'generator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generator').parameters())),
        'discriminator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminator').parameters()))
    }

    # iter 0, 1: within disc_init_steps, only the discriminator trains
    for i in range(2):
        assert synthesizer.step_counter == i
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        assert outputs['log_vars'].get('loss_gan_g') is None
        assert outputs['log_vars'].get('loss_pixel') is None
        for v in ['loss_gan_d_fake', 'loss_gan_d_real']:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_a'], data_batch['img_a'])
        assert torch.equal(outputs['results']['real_b'], data_batch['img_b'])
        assert torch.is_tensor(outputs['results']['fake_b'])
        assert outputs['results']['fake_b'].size() == (1, 3, 256, 256)
        assert synthesizer.step_counter == i + 1

    # iter 2, 3, 4, 5: generator updates only every disc_steps iterations
    for i in range(2, 6):
        assert synthesizer.step_counter == i
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        log_check_list = [
            'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_pixel'
        ]
        if i % 2 == 1:
            assert outputs['log_vars'].get('loss_gan_g') is None
            assert outputs['log_vars'].get('loss_pixel') is None
            log_check_list.remove('loss_gan_g')
            log_check_list.remove('loss_pixel')
        for v in log_check_list:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_a'], data_batch['img_a'])
        assert torch.equal(outputs['results']['real_b'], data_batch['img_b'])
        assert torch.is_tensor(outputs['results']['fake_b'])
        assert outputs['results']['fake_b'].size() == (1, 3, 256, 256)
        assert synthesizer.step_counter == i + 1

    # test without pixel loss
    model_cfg_ = copy.deepcopy(model_cfg)
    model_cfg_.pop('pixel_loss')
    synthesizer = build_model(model_cfg_, train_cfg=None, test_cfg=None)
    optimizer = {
        'generator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generator').parameters())),
        'discriminator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminator').parameters()))
    }
    data_batch['img_a'] = inputs.cpu()
    data_batch['img_b'] = targets.cpu()
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    assert outputs['log_vars'].get('loss_pixel') is None
    for v in ['loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g']:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_a'], data_batch['img_a'])
    assert torch.equal(outputs['results']['real_b'], data_batch['img_b'])
    assert torch.is_tensor(outputs['results']['fake_b'])
    assert outputs['results']['fake_b'].size() == (1, 3, 256, 256)

    # test b2a translation (real_a/real_b are swapped in the results)
    data_batch['img_a'] = inputs.cpu()
    data_batch['img_b'] = targets.cpu()
    train_cfg = dict(direction='b2a')
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)
    optimizer = {
        'generator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generator').parameters())),
        'discriminator':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminator').parameters()))
    }
    assert synthesizer.step_counter == 0
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in [
            'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_pixel'
    ]:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_a'], data_batch['img_b'])
    assert torch.equal(outputs['results']['real_b'], data_batch['img_a'])
    assert torch.is_tensor(outputs['results']['fake_b'])
    assert outputs['results']['fake_b'].size() == (1, 3, 256, 256)
    assert synthesizer.step_counter == 1

    # test save image
    # show input
    train_cfg = None
    test_cfg = dict(show_input=True)
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)
    with patch.object(mmcv, 'imwrite', return_value=True):
        # test save path not None Assertion
        with pytest.raises(AssertionError):
            with torch.no_grad():
                _ = synthesizer(inputs,
                                targets, [img_meta],
                                test_mode=True,
                                save_image=True)
        # iteration is None
        with torch.no_grad():
            outputs = synthesizer(inputs,
                                  targets, [img_meta],
                                  test_mode=True,
                                  save_image=True,
                                  save_path='save_path')
        assert torch.equal(outputs['real_a'], data_batch['img_a'])
        assert torch.equal(outputs['real_b'], data_batch['img_b'])
        assert torch.is_tensor(outputs['fake_b'])
        assert outputs['fake_b'].size() == (1, 3, 256, 256)
        assert outputs['saved_flag']
        # iteration is not None
        with torch.no_grad():
            outputs = synthesizer(inputs,
                                  targets, [img_meta],
                                  test_mode=True,
                                  save_image=True,
                                  save_path='save_path',
                                  iteration=1000)
        assert torch.equal(outputs['real_a'], data_batch['img_a'])
        assert torch.equal(outputs['real_b'], data_batch['img_b'])
        assert torch.is_tensor(outputs['fake_b'])
        assert outputs['fake_b'].size() == (1, 3, 256, 256)
        assert outputs['saved_flag']

    # not show input
    train_cfg = None
    test_cfg = dict(show_input=False)
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)
    with patch.object(mmcv, 'imwrite', return_value=True):
        # test save path not None Assertion
        with pytest.raises(AssertionError):
            with torch.no_grad():
                _ = synthesizer(inputs,
                                targets, [img_meta],
                                test_mode=True,
                                save_image=True)
        # iteration is None
        with torch.no_grad():
            outputs = synthesizer(inputs,
                                  targets, [img_meta],
                                  test_mode=True,
                                  save_image=True,
                                  save_path='save_path')
        assert torch.equal(outputs['real_a'], data_batch['img_a'])
        assert torch.equal(outputs['real_b'], data_batch['img_b'])
        assert torch.is_tensor(outputs['fake_b'])
        assert outputs['fake_b'].size() == (1, 3, 256, 256)
        assert outputs['saved_flag']
        # iteration is not None
        with torch.no_grad():
            outputs = synthesizer(inputs,
                                  targets, [img_meta],
                                  test_mode=True,
                                  save_image=True,
                                  save_path='save_path',
                                  iteration=1000)
        assert torch.equal(outputs['real_a'], data_batch['img_a'])
        assert torch.equal(outputs['real_b'], data_batch['img_b'])
        assert torch.is_tensor(outputs['fake_b'])
        assert outputs['fake_b'].size() == (1, 3, 256, 256)
        assert outputs['saved_flag']
Ejemplo n.º 6
0
def main():
    """Evaluate a generative model checkpoint.

    Supports single-GPU testing and distributed evaluation (online mode
    only); metrics come from ``--eval`` or, when absent, from
    ``cfg.metrics``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # `--gpu-ids` is deprecated: only its first entry is honored because
    # non-distributed testing runs on a single GPU (see warning below).
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        rank = 0
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        rank, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
        # multi-GPU runs must use online evaluation (enforced here)
        assert args.online or world_size == 1, (
            'We only support online mode for distrbuted evaluation.')

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # remote (URL) checkpoints get no eval-log file written next to them
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds (log only once, from rank 0)
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    # build metrics: `--eval none` means "only sample images" (then
    # --num-samples is required); no --eval falls back to cfg.metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # check metrics for dist evaluation
    if distributed and metrics:
        for metric in metrics:
            assert metric.name in _distributed_metrics, (
                f'We only support {_distributed_metrics} for multi gpu '
                f'evaluation, but receive {args.eval}.')

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # header info for the final metrics table
    basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                            ckpt=ckpt,
                            sample_model=args.sample_model)

    if len(metrics) == 0:
        # sampling-only mode: no real data needed
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        # build the dataloader, preferring the test > val > train split;
        # test/val are only used when they define an `imgs_root`
        if cfg.data.get('test', None) and cfg.data.test.get('imgs_root', None):
            dataset = build_dataset(cfg.data.test)
        elif cfg.data.get('val', None) and cfg.data.val.get('imgs_root', None):
            dataset = build_dataset(cfg.data.val)
        elif cfg.data.get('train', None):
            # we assume that the train part should work well
            dataset = build_dataset(cfg.data.train)
        else:
            raise RuntimeError('There is no valid dataset config to run, '
                               'please check your dataset configs.')

        # The default loader config
        # NOTE(review): shuffle=True for evaluation — presumably deliberate
        # so metric batches are not order-biased; confirm if exact
        # reproducibility of the sampled batches is required.
        loader_cfg = dict(samples_per_gpu=args.batch_size,
                          workers_per_gpu=cfg.data.get(
                              'val_workers_per_gpu', cfg.data.workers_per_gpu),
                          num_gpus=len(cfg.gpu_ids),
                          dist=distributed,
                          shuffle=True)
        # The overall dataloader settings
        loader_cfg.update({
            k: v
            for k, v in cfg.data.items() if k not in [
                'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
                'test_dataloader'
            ]
        })

        # specific config for test loader
        test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})}

        data_loader = build_dataloader(dataset, **test_loader_cfg)
    if args.sample_cfg is None:
        args.sample_cfg = dict()

    # wrap the model for (distributed) data-parallel inference
    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # online mode will not save samples
    if args.online and len(metrics) > 0:
        online_evaluation(model, data_loader, metrics, logger,
                          basic_table_info, args.batch_size, **args.sample_cfg)
    else:
        offline_evaluation(model, data_loader, metrics, logger,
                           basic_table_info, args.batch_size,
                           args.samples_path, **args.sample_cfg)
Ejemplo n.º 7
0
    def test_default_dcgan_model_cpu(self):
        """CPU smoke test for BasicConditionalGAN built from
        ``self.default_config``: forward (train/test), a train step, the
        ``disc_steps`` schedule, and direct construction with various
        auxiliary-loss configs.

        NOTE(review): despite the "dcgan" name, this exercises
        BasicConditionalGAN (variables named sngan/sagan) — confirm the
        test name is intentional.
        """
        sngan = build_model(self.default_config)
        assert isinstance(sngan, BasicConditionalGAN)
        assert not sngan.with_disc_auxiliary_loss
        assert sngan.with_disc

        # test forward train
        with pytest.raises(NotImplementedError):
            _ = sngan(None, return_loss=True)
        # test forward test
        imgs = sngan(None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 32, 32)

        # test train step
        data = torch.randn((2, 3, 32, 32))
        label = torch.randint(0, 10, (2, ))
        data_input = dict(img=data, gt_label=label)
        optimizer_g = torch.optim.SGD(sngan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(sngan.discriminator.parameters(),
                                      lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        model_outputs = sngan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # more tests for different configs with heavy computation
        # test disc_steps: with disc_steps=2 the generator is only updated
        # every second iteration, so at iteration 0 no generator loss shows
        config_ = deepcopy(self.default_config)
        config_['train_cfg'] = dict(disc_steps=2)
        sngan = build_model(config_)
        model_outputs = sngan.train_step(data_input, optim_dict)
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' not in model_outputs['log_vars']
        assert sngan.disc_steps == 2

        # at iteration=1 (odd) the generator step runs, so its loss appears
        model_outputs = sngan.train_step(data_input,
                                         optim_dict,
                                         running_status=dict(iteration=1))
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' in model_outputs['log_vars']

        # test customized config
        sagan = BasicConditionalGAN(
            self.generator_cfg,
            self.disc_cfg,
            self.gan_loss,
            self.disc_auxiliary_loss,
        )
        # test train step
        # NOTE(review): these optimizers are built from `sngan`'s
        # parameters but passed to `sagan.train_step` below — looks like a
        # copy-paste slip (should probably be sagan's parameters); confirm.
        data = torch.randn((2, 3, 32, 32))
        data_input = dict(img=data, gt_label=label)
        optimizer_g = torch.optim.SGD(sngan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(sngan.discriminator.parameters(),
                                      lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        model_outputs = sagan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # auxiliary losses given as single dicts are wrapped in ModuleLists
        sagan = BasicConditionalGAN(
            self.generator_cfg, self.disc_cfg, self.gan_loss,
            dict(type='DiscShiftLoss',
                 loss_weight=0.5,
                 data_info=dict(pred='disc_pred_fake')),
            dict(type='GeneratorPathRegularizer'))
        assert isinstance(sagan.disc_auxiliary_losses, nn.ModuleList)
        assert isinstance(sagan.gen_auxiliary_losses, nn.ModuleList)

        # auxiliary losses given as lists are accepted as well
        sagan = BasicConditionalGAN(
            self.generator_cfg, self.disc_cfg, self.gan_loss,
            dict(type='DiscShiftLoss',
                 loss_weight=0.5,
                 data_info=dict(pred='disc_pred_fake')),
            [dict(type='GeneratorPathRegularizer')])
        assert isinstance(sagan.disc_auxiliary_losses, nn.ModuleList)
        assert isinstance(sagan.gen_auxiliary_losses, nn.ModuleList)
Ejemplo n.º 8
0
def main():
    """Evaluate a translation-model checkpoint on a single device.

    Samples translated ("fake") images into a directory, then feeds real
    and fake images into each configured metric and logs a summary table.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # remote (URL) checkpoints get no eval-log file written next to them
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    # only translation models known to this script are allowed
    assert isinstance(model, _supported_model)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics: `--eval none` means "only sample images" (then
    # --num-samples is required); no --eval falls back to cfg.metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # header info for the final metrics table
    basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                            ckpt=ckpt,
                            sample_model=args.sample_model)

    # build the dataloader
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(dataset,
                                       samples_per_gpu=args.batch_size,
                                       workers_per_gpu=cfg.data.get(
                                           'val_workers_per_gpu',
                                           cfg.data.workers_per_gpu),
                                       dist=False,
                                       shuffle=True)

    # decide samples path: an explicit path is kept, otherwise a fresh
    # temp dir is created and deleted after evaluation
    samples_path = args.samples_path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images: reuse any images already in samples_path and only
    # generate the remainder
    num_exist = len(
        list(
            mmcv.scandir(samples_path,
                         suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        # each metric declares how many images it needs; take the max
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)
    # select key to fetch fake images
    fake_key = 'fake_b'
    if isinstance(model.module, CycleGAN):
        fake_key = 'fake_b' if model.module.test_direction == 'a2b' else \
            'fake_a'
    # if no images, `num_exist` should be zero
    # NOTE(review): the range stop is `num_needed` rather than
    # `max_num_images`, so the loop is only correct when num_exist == 0
    # (as the comment above assumes); `pbar` is likewise unbound unless
    # num_needed > 0, which holds because the body only runs in that case.
    for begin in range(num_exist, num_needed, args.batch_size):
        end = min(begin + args.batch_size, max_num_images)
        # for translation model, we feed them images from dataloader
        # NOTE(review): a fresh iterator is created each step, restarting
        # the (shuffled) dataloader — confirm re-sampling sources is OK.
        data_loader_iter = iter(data_loader)
        data_batch = next(data_loader_iter)
        output_dict = model(test_mode=True, **data_batch)
        fakes = output_dict[fake_key]
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            # map from [-1, 1] to [0, 1] and reorder the 3 channels
            # (BGR -> RGB presumably) before clamping and saving
            images = ((images + 1) / 2)
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    # finish the progress-bar line
    if num_needed > 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, args.batch_size)
    # select key to fetch real images
    # NOTE(review): `real_key` is only bound for CycleGAN/Pix2Pix models;
    # presumably `_supported_model` restricts to those — confirm.
    if isinstance(model.module, CycleGAN):
        real_key = 'img_b' if model.module.test_direction == 'a2b' else 'img_a'
        if model.module.direction == 'b2a':
            real_key = 'img_a' if real_key == 'img_b' else 'img_b'

    if isinstance(model.module, Pix2Pix):
        real_key = 'img_b' if model.module.direction == 'a2b' else 'img_a'

    # feed reals then fakes until each metric has seen enough images
    # (metric.feed returns how many more it still needs), then summarize
    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data[real_key]
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    # clean up the auto-created temp samples dir
    if delete_samples_path:
        shutil.rmtree(samples_path)
Ejemplo n.º 9
0
def test_pix2pix():
    """End-to-end test of the Pix2Pix translation model.

    Covers attribute wiring, forward test/train, a full train_step (CPU
    and, when available, CUDA), the disc_steps/disc_init_steps update
    schedule, and training without the pixel (L1) auxiliary loss.
    """
    # model settings: mask -> photo translation with an L1 pixel loss
    model_cfg = dict(
        type='Pix2Pix',
        generator=dict(
            type='UnetGenerator',
            in_channels=3,
            out_channels=3,
            num_down=8,
            base_channels=64,
            norm_cfg=dict(type='BN'),
            use_dropout=True,
            init_cfg=dict(type='normal', gain=0.02)),
        discriminator=dict(
            type='PatchDiscriminator',
            in_channels=6,
            base_channels=64,
            num_conv=3,
            norm_cfg=dict(type='BN'),
            init_cfg=dict(type='normal', gain=0.02)),
        gan_loss=dict(
            type='GANLoss',
            gan_type='vanilla',
            real_label_val=1.0,
            fake_label_val=0.0,
            loss_weight=1.0),
        default_domain='photo',
        reachable_domains=['photo'],
        related_domains=['photo', 'mask'],
        gen_auxiliary_loss=dict(
            type='L1Loss',
            loss_weight=100.0,
            data_info=dict(pred='fake_photo', target='real_photo'),
            reduction='mean'))

    train_cfg = None
    test_cfg = None

    # build synthesizer
    synthesizer = build_model(
        model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
    # test attributes
    assert synthesizer.__class__.__name__ == 'Pix2Pix'
    assert isinstance(synthesizer.generators['photo'], UnetGenerator)
    assert isinstance(synthesizer.discriminators['photo'], PatchDiscriminator)
    assert isinstance(synthesizer.gan_loss, GANLoss)
    assert isinstance(synthesizer.gen_auxiliary_losses[0], L1Loss)
    assert synthesizer.test_cfg is None

    # prepare data
    img_mask = torch.rand(1, 3, 256, 256)
    img_photo = torch.rand(1, 3, 256, 256)
    data_batch = {'img_mask': img_mask, 'img_photo': img_photo}

    # prepare optimizer: separate optimizers for generators/discriminators
    optim_cfg = dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }

    # test forward_test
    domain = 'photo'
    with torch.no_grad():
        outputs = synthesizer(img_mask, target_domain=domain, test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 256, 256)

    # test forward_train
    outputs = synthesizer(img_mask, target_domain=domain, test_mode=False)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 256, 256)

    # test train_step: all four loss terms must be logged as floats
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in ['loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1']:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1

    assert torch.equal(outputs['results']['real_mask'], data_batch['img_mask'])
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)

    # test cuda: repeat forward/train_step with GPU tensors; results are
    # expected back on CPU (compared against .cpu() copies below)
    if torch.cuda.is_available():
        synthesizer = synthesizer.cuda()
        optimizer = {
            'generators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(params=getattr(synthesizer, 'generators').parameters())),
            'discriminators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(
                    params=getattr(synthesizer,
                                   'discriminators').parameters()))
        }
        data_batch_cuda = copy.deepcopy(data_batch)
        data_batch_cuda['img_mask'] = img_mask.cuda()
        data_batch_cuda['img_photo'] = img_photo.cuda()

        # forward_test
        with torch.no_grad():
            outputs = synthesizer(
                data_batch_cuda['img_mask'],
                target_domain=domain,
                test_mode=True)
        assert torch.equal(outputs['source'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 256, 256)

        # test forward_train
        outputs = synthesizer(
            data_batch_cuda['img_mask'], target_domain=domain, test_mode=False)
        assert torch.equal(outputs['source'], data_batch_cuda['img_mask'])
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 256, 256)

        # train_step
        outputs = synthesizer.train_step(data_batch_cuda, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        for v in [
                'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1'
        ]:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch_cuda['img_photo'].cpu())
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)

    # test disc_steps and disc_init_steps: with disc_init_steps=2 the
    # generator is frozen for the first two iterations, then (disc_steps=2)
    # updated only every second iteration
    data_batch['img_mask'] = img_mask.cpu()
    data_batch['img_photo'] = img_photo.cpu()
    train_cfg = dict(disc_steps=2, disc_init_steps=2)
    synthesizer = build_model(
        model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }

    # iter 0, 1: init steps — discriminator losses only
    for i in range(2):
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        assert outputs['log_vars'].get('loss_gan_g') is None
        assert outputs['log_vars'].get('loss_l1') is None
        for v in ['loss_gan_d_fake', 'loss_gan_d_real']:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)
        assert synthesizer.iteration == i + 1

    # iter 2, 3, 4, 5: generator losses appear only on even iterations
    # (every disc_steps-th step after the init phase)
    for i in range(2, 6):
        assert synthesizer.iteration == i
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        log_check_list = [
            'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1'
        ]
        if i % 2 == 1:
            assert outputs['log_vars'].get('loss_gan_g') is None
            assert outputs['log_vars'].get('loss_pixel') is None
            log_check_list.remove('loss_gan_g')
            log_check_list.remove('loss_l1')
        for v in log_check_list:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)
        assert synthesizer.iteration == i + 1

    # test without pixel loss: dropping gen_auxiliary_loss removes loss_l1
    model_cfg_ = copy.deepcopy(model_cfg)
    model_cfg_.pop('gen_auxiliary_loss')
    synthesizer = build_model(model_cfg_, train_cfg=None, test_cfg=None)
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }
    data_batch['img_mask'] = img_mask.cpu()
    data_batch['img_photo'] = img_photo.cpu()
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    assert outputs['log_vars'].get('loss_pixel') is None
    for v in ['loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g']:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_mask'], data_batch['img_mask'])
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)
Ejemplo n.º 10
0
def main():
    """Entry point of the training script: resolve the config, set up the
    (optionally distributed) environment, logging and seeds, then build the
    model/datasets and launch training.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        # default to one GPU when neither --gpu-ids nor --gpus is given
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)

    # a 2-stage workflow additionally builds a validation dataset that
    # reuses the training pipeline
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmgen version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(mmgen_version=__version__ +
                                          get_git_hash()[:7])

    train_model(model,
                datasets,
                cfg,
                distributed=distributed,
                validate=(not args.no_validate),
                timestamp=timestamp,
                meta=meta)
Ejemplo n.º 11
0
    def test_ada_stylegan2_model_cpu(self):
        """CPU smoke test for a StyleGANv3 generator paired with an
        ADA-augmented StyleGAN2 discriminator: sampling works, loss mode
        raises, and one train step leaves the ADA probability tensor in
        float32.
        """
        synthesis_cfg = {
            'type': 'SynthesisNetwork',
            'channel_base': 1024,
            'channel_max': 16,
            'magnitude_ema_beta': 0.999
        }
        # enable every ADA augmentation with probability multiplier 1
        aug_kwargs = {
            'xflip': 1,
            'rotate90': 1,
            'xint': 1,
            'scale': 1,
            'rotate': 1,
            'aniso': 1,
            'xfrac': 1,
            'brightness': 1,
            'contrast': 1,
            'lumaflip': 1,
            'hue': 1,
            'saturation': 1
        }
        default_config = dict(type='StaticUnconditionalGAN',
                              generator=dict(type='StyleGANv3Generator',
                                             out_size=8,
                                             style_channels=8,
                                             img_channels=3,
                                             rgb2bgr=True,
                                             synthesis_cfg=synthesis_cfg),
                              discriminator=dict(
                                  type='ADAStyleGAN2Discriminator',
                                  in_size=8,
                                  input_bgr2rgb=True,
                                  data_aug=dict(type='ADAAug',
                                                update_interval=2,
                                                aug_pipeline=aug_kwargs,
                                                ada_kimg=100)),
                              gan_loss=dict(type='GANLoss',
                                            gan_type='wgan-logistic-ns'))

        s3gan = build_model(default_config)
        assert isinstance(s3gan, StaticUnconditionalGAN)
        assert not s3gan.with_disc_auxiliary_loss
        assert s3gan.with_disc

        # test forward train
        with pytest.raises(NotImplementedError):
            _ = s3gan(None, return_loss=True)
        # test forward test
        imgs = s3gan(None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 8, 8)

        # test train step
        data = torch.randn((2, 3, 8, 8))
        data_input = dict(real_img=data)
        optimizer_g = torch.optim.SGD(s3gan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(s3gan.discriminator.parameters(),
                                      lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        _ = s3gan.train_step(data_input,
                             optim_dict,
                             running_status=dict(iteration=1))
        # Fix: the original final line was a bare comparison
        # (`... == torch.float32`) whose boolean result was discarded, so
        # the dtype was never actually checked; assert it instead.
        assert s3gan.discriminator.ada_aug.aug_pipeline.p.dtype == \
            torch.float32
Ejemplo n.º 12
0
def test_cyclegan():
    """Integration test for the CycleGAN model wrapper.

    Covers: attribute wiring after ``build_model``, forward in test mode for
    both domains (CPU and, when available, GPU), ``train_step`` log keys and
    result tensors, the ``disc_steps``/``disc_init_steps`` scheduling options
    and a zero-size GAN image buffer.
    """

    # Translate between a 'photo' and a 'mask' domain. The four L1 auxiliary
    # losses are the two cycle-consistency terms (weight 10.0) and the two
    # identity terms (weight 0.5), one of each per domain.
    model_cfg = dict(type='CycleGAN',
                     default_domain='photo',
                     reachable_domains=['photo', 'mask'],
                     related_domains=['photo', 'mask'],
                     generator=dict(type='ResnetGenerator',
                                    in_channels=3,
                                    out_channels=3,
                                    base_channels=64,
                                    norm_cfg=dict(type='IN'),
                                    use_dropout=False,
                                    num_blocks=9,
                                    padding_mode='reflect',
                                    init_cfg=dict(type='normal', gain=0.02)),
                     discriminator=dict(type='PatchDiscriminator',
                                        in_channels=3,
                                        base_channels=64,
                                        num_conv=3,
                                        norm_cfg=dict(type='IN'),
                                        init_cfg=dict(type='normal',
                                                      gain=0.02)),
                     gan_loss=dict(type='GANLoss',
                                   gan_type='lsgan',
                                   real_label_val=1.0,
                                   fake_label_val=0.0,
                                   loss_weight=1.0),
                     gen_auxiliary_loss=[
                         dict(type='L1Loss',
                              loss_weight=10.0,
                              data_info=dict(pred='cycle_photo',
                                             target='real_photo'),
                              reduction='mean'),
                         dict(type='L1Loss',
                              loss_weight=10.0,
                              data_info=dict(
                                  pred='cycle_mask',
                                  target='real_mask',
                              ),
                              reduction='mean'),
                         dict(type='L1Loss',
                              loss_weight=0.5,
                              data_info=dict(pred='identity_photo',
                                             target='real_photo'),
                              reduction='mean'),
                         dict(type='L1Loss',
                              loss_weight=0.5,
                              data_info=dict(pred='identity_mask',
                                             target='real_mask'),
                              reduction='mean')
                     ])

    train_cfg = None
    test_cfg = None

    # build synthesizer
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)

    # test attributes: one generator and one discriminator per domain
    assert synthesizer.__class__.__name__ == 'CycleGAN'
    assert isinstance(synthesizer.generators['photo'], ResnetGenerator)
    assert isinstance(synthesizer.generators['mask'], ResnetGenerator)
    assert isinstance(synthesizer.discriminators['photo'], PatchDiscriminator)
    assert isinstance(synthesizer.discriminators['mask'], PatchDiscriminator)
    assert isinstance(synthesizer.gan_loss, GANLoss)
    for loss_module in synthesizer.gen_auxiliary_losses:
        assert isinstance(loss_module, L1Loss)

    # prepare data: `inputs` plays the mask domain, `targets` the photo domain
    inputs = torch.rand(1, 3, 64, 64)
    targets = torch.rand(1, 3, 64, 64)
    data_batch = {'img_mask': inputs, 'img_photo': targets}

    # prepare optimizer (one Adam per module group, as train_step expects)
    optim_cfg = dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }

    # test forward_test: translating to a domain echoes the source image and
    # produces a same-shaped target tensor
    with torch.no_grad():
        outputs = synthesizer(inputs, target_domain='photo', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    with torch.no_grad():
        outputs = synthesizer(targets, target_domain='mask', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_photo'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    # test forward_train
    # NOTE(review): this section still passes test_mode=True, so it merely
    # repeats the forward_test checks above; the GPU branch below does use
    # test_mode=False — confirm whether test_mode=False was intended here.
    with torch.no_grad():
        outputs = synthesizer(inputs, target_domain='photo', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    with torch.no_grad():
        outputs = synthesizer(targets, target_domain='mask', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_photo'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    # test train_step: both discriminator and generator losses are logged as
    # plain floats, and the result dict echoes real/fake tensors per domain
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in [
            'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
            'loss_gan_g_photo', 'loss_l1'
    ]:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.equal(outputs['results']['real_mask'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['results']['fake_mask'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
    assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)

    # test train_step and forward_test (gpu); mirrors the CPU checks above,
    # but note that test-mode outputs come back on CPU while train-mode
    # outputs stay on the GPU
    if torch.cuda.is_available():
        synthesizer = synthesizer.cuda()
        # rebuild the optimizers so they track the CUDA parameters
        optimizer = {
            'generators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(params=getattr(synthesizer, 'generators').parameters())),
            'discriminators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(params=getattr(synthesizer,
                                    'discriminators').parameters()))
        }
        data_batch_cuda = copy.deepcopy(data_batch)
        data_batch_cuda['img_mask'] = inputs.cuda()
        data_batch_cuda['img_photo'] = targets.cuda()

        # forward_test: source is returned on CPU even for CUDA inputs
        with torch.no_grad():
            outputs = synthesizer(data_batch_cuda['img_mask'],
                                  target_domain='photo',
                                  test_mode=True)
        assert torch.equal(outputs['source'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        with torch.no_grad():
            outputs = synthesizer(data_batch_cuda['img_photo'],
                                  target_domain='mask',
                                  test_mode=True)
        assert torch.equal(outputs['source'],
                           data_batch_cuda['img_photo'].cpu())
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        # test forward_train: source stays on the original (CUDA) device
        with torch.no_grad():
            outputs = synthesizer(data_batch_cuda['img_mask'],
                                  target_domain='photo',
                                  test_mode=False)
        assert torch.equal(outputs['source'], data_batch_cuda['img_mask'])
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        with torch.no_grad():
            outputs = synthesizer(data_batch_cuda['img_photo'],
                                  target_domain='mask',
                                  test_mode=False)
        assert torch.equal(outputs['source'], data_batch_cuda['img_photo'])
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        # train_step
        outputs = synthesizer.train_step(data_batch_cuda, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        print(outputs['log_vars'].keys())
        assert isinstance(outputs['results'], dict)
        for v in [
                'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
                'loss_gan_g_photo', 'loss_l1'
        ]:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch_cuda['img_photo'].cpu())
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.is_tensor(outputs['results']['fake_mask'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
        assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)

    # test disc_steps and disc_init_steps: rebuild the model so the
    # iteration counter starts at 0 again
    data_batch['img_mask'] = inputs.cpu()
    data_batch['img_photo'] = targets.cpu()
    train_cfg = dict(disc_steps=2, disc_init_steps=2)
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }

    # iter 0, 1: discriminator-only warm-up (disc_init_steps=2), so no
    # generator loss keys may appear in the logs yet
    for i in range(2):
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        for v in ['loss_gan_g_mask', 'loss_gan_g_photo', 'loss_l1']:
            assert outputs['log_vars'].get(v) is None
        assert isinstance(outputs['log_vars']['loss_gan_d_mask'], float)
        assert isinstance(outputs['log_vars']['loss_gan_d_photo'], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.is_tensor(outputs['results']['fake_mask'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
        assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)
        assert synthesizer.iteration == i + 1

    # iter 2, 3, 4, 5: past warm-up, with disc_steps=2 the generator is only
    # updated every other iteration, so its loss keys alternate in the logs
    for i in range(2, 6):
        assert synthesizer.iteration == i
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        log_check_list = [
            'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
            'loss_gan_g_photo', 'loss_l1'
        ]
        if i % 2 == 1:
            # odd iterations: discriminator-only step, generator keys absent
            log_None_list = ['loss_gan_g_mask', 'loss_gan_g_photo', 'loss_l1']
            for v in log_None_list:
                assert outputs['log_vars'].get(v) is None
                log_check_list.remove(v)
        for v in log_check_list:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.is_tensor(outputs['results']['fake_mask'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
        assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)
        assert synthesizer.iteration == i + 1

    # test GAN image buffer size = 0: train_step must behave exactly like the
    # default configuration for a single step
    data_batch['img_mask'] = inputs.cpu()
    data_batch['img_photo'] = targets.cpu()
    train_cfg = dict(buffer_size=0)
    synthesizer = build_model(model_cfg,
                              train_cfg=train_cfg,
                              test_cfg=test_cfg)
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in [
            'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
            'loss_gan_g_photo', 'loss_l1'
    ]:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_mask'], data_batch['img_mask'])
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.is_tensor(outputs['results']['fake_mask'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
    assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)
    assert synthesizer.iteration == 1
Ejemplo n.º 13
0
def main():
    """Interpolate between sampled latent endpoints and save the frames.

    Loads a GAN from the CLI-supplied config/checkpoint, draws
    ``args.endpoint`` latent codes, interpolates between them in z- or
    w-space, and writes the rendered images under ``args.samples_path``
    (numbered frames in 'sequence' mode, one grid image otherwise).
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # honour the config's cudnn benchmark switch
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # optional deterministic seeding
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and restore weights on CPU first
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # models without EMA can only sample from the original generator
    if not model.use_ema:
        args.sample_model = 'orig'
    generator = (model.generator_ema
                 if args.sample_model == 'ema' else model.generator)
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')

    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()

    # sequence mode walks consecutive endpoints; group mode pairs them up,
    # so it additionally needs an even endpoint count
    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)

    # draw one latent code per endpoint
    noise_batch = batch_inference(
        generator,
        None,
        num_batches=args.endpoint,
        dict_key='noise_batch' if args.space == 'z' else 'latent',
        return_noise=True,
        **kwargs)

    if args.space == 'w':
        # w-space interpolation feeds latents straight into synthesis
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True

    if args.show_mode == 'sequence':
        # consecutive endpoint pairs: (0,1), (1,2), ...
        starts, ends = noise_batch[:-1, ], noise_batch[1:, ]
    else:
        # disjoint endpoint pairs: (0,1), (2,3), ...
        starts, ends = noise_batch[::2, ], noise_batch[1::2, ]
    results = sample_from_path(generator, starts, ends, args.interval,
                               args.interp_mode, args.space, **kwargs)

    # stack to (paths, steps, C, H, W), then flatten into one image batch
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale [-1, 1] -> [0, 1], flip channel order (BGR -> RGB) and clamp
    results = ((results + 1) / 2)[:, [2, 1, 0], ...].clamp_(0, 1)

    # save images
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        for idx in range(results.shape[0]):
            frame = results[idx:idx + 1]
            save_image(
                frame,
                os.path.join(args.samples_path,
                             '{:0>5d}'.format(idx) + '.png'))
    else:
        save_image(
            results,
            os.path.join(args.samples_path, 'group.png'),
            nrow=args.interval)
Ejemplo n.º 14
0
def main():
    """Interpolate between latent endpoints and export images or a video.

    Like plain latent interpolation, but can additionally (a) start from
    latents projected by an inversion run (``--proj-latent``) instead of
    freshly sampled noise and (b) render the interpolation as an mp4
    (``--export-video``) instead of individual images / a grid.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')

    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()

    # if given proj_latent, use the projected latents as endpoints and
    # reset args.endpoint to match their count
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        proj_file = torch.load(args.proj_latent)
        proj_n = len(proj_file)
        setattr(args, 'endpoint', proj_n)
        assert args.space == 'w', 'Projected latent are w or w-plus latent.'
        noise_batch = []
        for img_path in proj_file:
            noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
        # BUGFIX: previously `.cuda()` was called unconditionally and the
        # tensor copied back to CPU afterwards when --use-cpu was set, which
        # crashes on machines without CUDA. Only move to GPU when requested.
        noise_batch = torch.cat(noise_batch, dim=0)
        if not args.use_cpu:
            noise_batch = noise_batch.cuda()

    # sequence mode needs >= 2 endpoints; group mode pairs endpoints, so it
    # additionally requires an even count
    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0,\
            '''We need paired images in group mode,
            so keep endpoint an even number'''

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)
    # remind users to fix injected noise (default was the truthy string
    # 'True'; a boolean keeps the same behavior with the intended type)
    if kwargs.get('randomize_noise', True):
        mmcv.print_log(
            '''Hint: For Style-Based GAN, you can add
            `--sample-cfg randomize_noise=False` to fix injected noises''',
            'mmgen')

    # get noises corresponding to each endpoint (unless projected latents
    # were loaded above); `is None` matches the check used earlier
    if args.proj_latent is None:
        noise_batch = batch_inference(
            generator,
            None,
            num_batches=args.endpoint,
            dict_key='noise_batch' if args.space == 'z' else 'latent',
            return_noise=True,
            **kwargs)

    if args.space == 'w':
        # w-space interpolation feeds latents straight into synthesis
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True

    if args.show_mode == 'sequence':
        # consecutive endpoint pairs: (0,1), (1,2), ...
        results = sample_from_path(generator, noise_batch[:-1, ],
                                   noise_batch[1:, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    else:
        # disjoint endpoint pairs: (0,1), (2,3), ...
        results = sample_from_path(generator, noise_batch[::2, ],
                                   noise_batch[1::2, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    # reorder results to (paths, steps, C, H, W) and flatten to a batch
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    # BGR -> RGB channel flip before saving
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)
    # save image
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        if args.export_video:
            # render video; layout_grid presumably writes the frames and
            # returns the writer — TODO confirm against its definition
            video_out = imageio.get_writer(os.path.join(
                args.samples_path, 'lerp.mp4'),
                                           mode='I',
                                           fps=60,
                                           codec='libx264',
                                           bitrate='12M')
            video_out = layout_grid(video_out, results)
            video_out.close()
        else:
            for i in range(results.shape[0]):
                image = results[i:i + 1]
                save_image(
                    image,
                    os.path.join(args.samples_path,
                                 '{:0>5d}'.format(i) + '.png'))
    else:
        if args.export_video:
            # render video, laying the interpolation pairs out on a grid
            video_out = imageio.get_writer(os.path.join(
                args.samples_path, 'lerp.mp4'),
                                           mode='I',
                                           fps=60,
                                           codec='libx264',
                                           bitrate='12M')
            n_pair = args.endpoint // 2
            grid_w, grid_h = crack_integer(n_pair)
            video_out = layout_grid(video_out,
                                    results,
                                    grid_h=grid_h,
                                    grid_w=grid_w)
            video_out.close()
        else:
            save_image(results,
                       os.path.join(args.samples_path, 'group.png'),
                       nrow=args.interval)