Пример #1
0
def main():
    """Sample random images from a trained model and save them as PNG files.

    Builds the model from ``args.config``, loads ``args.checkpoint`` and
    runs ``args.num_samples`` forward passes (one sample each). Each sample
    is written to ``args.samples_path/rand_sample_{idx}.png``. When
    ``args.save_prev_res`` is set, the model is also asked for its
    previous-stage results. Every entry of ``outputs['prev_res_list']``,
    followed by ``outputs['fake_img']``, is then written to
    ``args.samples_path/stage{k}/rand_sample_{idx}.png``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Let cuDNN benchmark convolution algorithms when the config asks for it.
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # Seeding only happens when a seed is given on the command line.
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    model.eval()

    mmcv.print_log(f'Loading ckpt from {args.checkpoint}', 'mmgen')
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # Data-parallel wrapper restricted to device 0.
    model = MMDataParallel(model, device_ids=[0])

    progress = mmcv.ProgressBar(args.num_samples)
    for idx in range(args.num_samples):
        outputs = model(None, num_batches=1, get_prev_res=args.save_prev_res)
        file_name = f'rand_sample_{idx}.png'

        if not args.save_prev_res:
            # In this mode ``outputs`` is the final image itself.
            mmcv.imwrite(_tensor2img(outputs),
                         os.path.join(args.samples_path, file_name))
        else:
            final_img = outputs['fake_img']
            stage_imgs = outputs['prev_res_list']
            # The final result is stored as the last stage.
            stage_imgs.append(final_img)
            for stage, stage_img in enumerate(stage_imgs):
                mmcv.imwrite(
                    _tensor2img(stage_img),
                    os.path.join(args.samples_path, f'stage{stage}',
                                 file_name))

        progress.update()

    # Move to a fresh line after the progress bar output.
    sys.stdout.write('\n')
Пример #2
0
def main():
    """Evaluate a trained generative model on a single GPU.

    Metrics are taken from ``--eval`` (names looked up in ``cfg.metrics``)
    or, when ``--eval`` is not given, from every entry of ``cfg.metrics``.
    ``--eval none`` skips metrics and only samples ``args.num_samples``
    images. For a local checkpoint, the log is appended to
    ``<checkpoint name up to the first '.'>_eval_log.txt`` next to the
    checkpoint. Checkpoint paths containing ``'http'`` get no log file.

    Raises:
        NotImplementedError: If a distributed launcher is requested.
            Multi-GPU evaluation is unsupported, so this is checked before
            any process-group, logging or model setup is performed.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # Fail fast. Multi-GPU evaluation is not implemented, so raise before
    # initialising the process group, appending to the eval log, or
    # building and loading the model.
    if args.launcher != 'none':
        raise NotImplementedError("We hasn't implemented multi gpu eval yet.")

    rank, _ = get_dist_info()

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # Remote checkpoints have no local directory to put the log in.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model)

    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        # build the dataloader; prefer test, then val, then train split
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        elif cfg.data.get('val', None):
            dataset = build_dataset(cfg.data.val)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=False,
            shuffle=True)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    # online mode will not save samples; it also needs metrics (and thus a
    # dataloader), otherwise fall back to offline sampling
    if args.online and len(metrics) > 0:
        single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                     basic_table_info, args.batch_size,
                                     **args.sample_cfg)
    else:
        single_gpu_evaluation(model, data_loader, metrics, logger,
                              basic_table_info, args.batch_size,
                              args.samples_path, **args.sample_cfg)
Пример #3
0
    # NOTE(review): this is a fragment of a larger function. The `def` line,
    # the creation of `parser` and the earlier arguments it reads later
    # (e.g. `config`, `ckpt`) are outside this excerpt.

    # Sampling options. Their exact meaning depends on code after this
    # excerpt: presumably the number of latents averaged for the truncation
    # mean, the noise channel count and an input scale. TODO confirm.
    parser.add_argument('--truncation-mean', type=int, default=4096)
    parser.add_argument('--noise-channels', type=int, default=512)
    parser.add_argument('--input-scale', type=int, default=4)
    # Free-form key=value pairs parsed by DictAction into a dict of extra
    # kwargs for the sampling function (per the help text).
    parser.add_argument('--sample-cfg',
                        nargs='+',
                        action=DictAction,
                        help='Other customized kwargs for sampling function')

    # system args
    parser.add_argument('--num-samples', type=int, default=2)
    parser.add_argument('--sample-path', type=str, default=None)
    parser.add_argument('--random-seed', type=int, default=2020)

    args = parser.parse_args()

    # The seed is always applied (default 2020), before the config is read.
    set_random_seed(args.random_seed)
    cfg = mmcv.Config.fromfile(args.config)

    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    mmcv.print_log('Building models and loading checkpoints', 'mmgen')
    # build model
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)

    model.eval()
    # Weights are loaded onto the CPU; any move to GPU happens after this
    # excerpt (not visible here).
    load_checkpoint(model, args.ckpt, map_location='cpu')

    # get generator
Пример #4
0
def main():
    """Evaluate an image-to-image translation model on a single GPU.

    The built model must be a ``BaseTranslationModel``. The target domain
    is ``args.target_domain`` or, if unset, the model's
    ``_default_domain``. The source domain is the first entry returned by
    ``get_other_domains(target_domain)``. Metrics come from ``--eval`` or
    from every entry of ``cfg.metrics``. ``--eval none`` only samples
    ``args.num_samples`` images, and no dataloader is built in that case.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # Remote checkpoints have no local directory to put the log in.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    assert isinstance(model, BaseTranslationModel)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # get source domain and target domain
    target_domain = args.target_domain
    if target_domain is None:
        target_domain = model.module._default_domain
    source_domain = model.module.get_other_domains(target_domain)[0]

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model,
        source_domain=source_domain,
        target_domain=target_domain)

    # build the dataloader
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=False,
            shuffle=True)

    # Online evaluation only makes sense with metrics. Without them
    # (``--eval none``) ``data_loader`` is None, so fall back to offline
    # sampling instead of handing None to the online evaluator.
    if args.online and len(metrics) > 0:
        single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                     basic_table_info, args.batch_size)
    else:
        single_gpu_evaluation(model, data_loader, metrics, logger,
                              basic_table_info, args.batch_size,
                              args.samples_path)
Пример #5
0
def main():
    """Project input images into the latent space of a StyleGAN generator.

    For every file in ``args.files``, the image is processed as follows:
      * resized so its shorter side is ``min(generator.out_size, 256)``;
      * center-cropped to a square of that size;
      * normalised with mean/std 0.5;
      * given a reversed channel order.

    A latent and per-layer injected noises are then optimised with Adam
    for ``args.total_iters`` steps. The latent starts at the mean of
    ``args.n_mean_latent`` mapped noise samples and is expanded to
    ``num_latents`` copies when ``args.w_plus`` is set. The loss combines
    a perceptual term, ``args.mse`` times an MSE term and
    ``args.noise_regularize`` times a noise regulariser.

    Outputs written to ``args.results_path``:
        * ``<input stem>-project.png``: the image rendered from the final
          latent and noises, rescaled to [0, 1] with channels reversed back.
        * ``project_result.pt``: a dict keyed by input path. Each entry holds
          ``img`` (the generator output of the last optimisation step, after
          the same downsampling used for the loss), ``latent`` and
          ``injected_noise``. All tensors are detached from autograd.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    generator.eval()
    device = 'cpu'
    if not args.use_cpu:
        generator = generator.cuda()
        device = 'cuda'

    img_size = min(generator.out_size, 256)
    # ``Resize(int)`` only fixes the shorter side. A non-square input would
    # keep its aspect ratio, could not be stacked with other inputs, and
    # would not match the generator output in the losses. CenterCrop makes
    # every input img_size x img_size and is a no-op for square inputs.
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # read images
    imgs = []
    for imgfile in args.files:
        img = Image.open(imgfile).convert('RGB')
        img = transform(img)
        # RGB -> BGR (presumably the generator's channel order; results are
        # reversed back before saving)
        img = img[[2, 1, 0], ...]
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    # get mean and standard deviation of style latents
    with torch.no_grad():
        noise_sample = torch.randn(args.n_mean_latent,
                                   generator.style_channels,
                                   device=device)
        latent_out = generator.style_mapping(noise_sample)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      args.n_mean_latent)**0.5
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(
        imgs.shape[0], 1)
    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, generator.num_latents, 1)
    latent_in.requires_grad = True

    # define lpips loss
    percept = PerceptualLoss(use_gpu=device.startswith('cuda'))

    # initialize layer noises
    noises_single = generator.make_injected_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
    pbar = tqdm(range(args.total_iters))
    # run optimization
    for i in pbar:
        t = i / args.total_iters
        lr = get_lr(t, args.lr, args.lr_rampdown, args.lr_rampup)
        optimizer.param_groups[0]['lr'] = lr
        # latent perturbation whose strength decays to 0 at t = noise_ramp
        noise_strength = latent_std * args.noise * max(
            0, 1 - t / args.noise_ramp)**2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen = generator([latent_n],
                            input_is_latent=True,
                            injected_noise=noises)

        batch, channel, height, width = img_gen.shape

        # average-pool large outputs down to about 256 px to match `imgs`
        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(batch, channel, height // factor, factor,
                                      width // factor, factor)
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        pbar.set_description(
            f' perceptual: {p_loss.item():.4f}, noise regularize:'
            f'{n_loss.item():.4f}, mse: {mse_loss.item():.4f}, lr: {lr:.4f}')

    # Final render from the optimised latent; no gradients are needed.
    with torch.no_grad():
        results = generator([latent_in.detach().clone()],
                            input_is_latent=True,
                            injected_noise=noises)
    # rescale value range to [0, 1] and reverse channels back (BGR -> RGB)
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    mmcv.mkdir_or_exist(args.results_path)
    # Save projection results. Tensors are detached because latent_in and
    # noises require grad and img_gen carries an autograd graph.
    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = [noise[i:i + 1].detach() for noise in noises]
        result_file[input_name] = {
            'img': img_gen[i].detach(),
            'latent': latent_in[i].detach(),
            'injected_noise': noise_single,
        }
        img_name = os.path.splitext(
            os.path.basename(input_name))[0] + '-project.png'
        save_image(results[i], os.path.join(args.results_path, img_name))

    torch.save(result_file, os.path.join(args.results_path,
                                         'project_result.pt'))
Пример #6
0
def main():
    """Evaluate a generative model, optionally with distributed launching.

    GPU selection:
      * ``--gpu-ids`` is deprecated; only its first id is used.
      * Otherwise ``--gpu-id`` selects the GPU.
      * Under a distributed launcher, one GPU per rank is used. Unless
        there is a single process, only online evaluation is allowed, and
        only metrics listed in ``_distributed_metrics``.

    Metrics come from ``--eval`` or from every entry of ``cfg.metrics``.
    ``--eval none`` only samples ``args.num_samples`` images.

    Dataset selection: the test split is used if it defines ``imgs_root``,
    then the val split if it does, otherwise the train split.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        rank = 0
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        rank, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
        assert args.online or world_size == 1, (
            'We only support online mode for distrbuted evaluation.')

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # Remote checkpoints have no local directory to put the log in.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # Check metrics for dist evaluation. Report the offending metric's own
    # name: `args.eval` is None when metrics come from `cfg.metrics`, so it
    # cannot identify which metric is unsupported.
    if distributed and metrics:
        for metric in metrics:
            assert metric.name in _distributed_metrics, (
                f'We only support {_distributed_metrics} for multi gpu '
                f'evaluation, but receive {metric.name}.')

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                            ckpt=ckpt,
                            sample_model=args.sample_model)

    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        # build the dataloader
        if cfg.data.get('test', None) and cfg.data.test.get('imgs_root', None):
            dataset = build_dataset(cfg.data.test)
        elif cfg.data.get('val', None) and cfg.data.val.get('imgs_root', None):
            dataset = build_dataset(cfg.data.val)
        elif cfg.data.get('train', None):
            # we assume that the train part should work well
            dataset = build_dataset(cfg.data.train)
        else:
            raise RuntimeError('There is no valid dataset config to run, '
                               'please check your dataset configs.')

        # The default loader config
        loader_cfg = dict(samples_per_gpu=args.batch_size,
                          workers_per_gpu=cfg.data.get(
                              'val_workers_per_gpu', cfg.data.workers_per_gpu),
                          num_gpus=len(cfg.gpu_ids),
                          dist=distributed,
                          shuffle=True)
        # The overall dataloader settings: every top-level `cfg.data` key
        # except the split configs and the per-split loader overrides.
        loader_cfg.update({
            k: v
            for k, v in cfg.data.items() if k not in [
                'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
                'test_dataloader'
            ]
        })

        # specific config for test loader (takes precedence over the above)
        test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})}

        data_loader = build_dataloader(dataset, **test_loader_cfg)
    if args.sample_cfg is None:
        args.sample_cfg = dict()

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # online mode will not save samples
    if args.online and len(metrics) > 0:
        online_evaluation(model, data_loader, metrics, logger,
                          basic_table_info, args.batch_size, **args.sample_cfg)
    else:
        offline_evaluation(model, data_loader, metrics, logger,
                           basic_table_info, args.batch_size,
                           args.samples_path, **args.sample_cfg)
Пример #7
0
def main():
    """Evaluate an image-to-image translation model (``_supported_model``).

    Steps:
      1. Build the model, load ``args.checkpoint`` and wrap it for device 0.
      2. Build metrics from ``--eval`` (``--eval none`` only samples
         ``args.num_samples`` images) or from every entry of ``cfg.metrics``.
      3. Translate dataloader batches and save the fakes as ``<index>.png``
         in ``args.samples_path``. If that is not given, a fresh
         ``./work_dirs/temp_samples[_k]`` directory is used. Images already
         in the directory count toward the target, and numbering continues
         after them.
      4. Feed real images from the dataloader and the saved fakes to each
         metric, then log a results table. A temporary samples directory
         is removed afterwards. It is kept when no metrics are computed,
         because the function returns early in that case.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # Remote checkpoints have no local directory to put the log in.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    assert isinstance(model, _supported_model)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                            ckpt=ckpt,
                            sample_model=args.sample_model)
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
    else:
        basic_table_info['num_samples'] = -1

    # Build the dataloader. A translation model needs source images even
    # when only sampling (``--eval none``), so the loader is always built.
    if cfg.data.get('test', None):
        dataset = build_dataset(cfg.data.test)
    else:
        dataset = build_dataset(cfg.data.train)
    data_loader = build_dataloader(dataset,
                                   samples_per_gpu=args.batch_size,
                                   workers_per_gpu=cfg.data.get(
                                       'val_workers_per_gpu',
                                       cfg.data.workers_per_gpu),
                                   dist=False,
                                   shuffle=True)

    # decide samples path
    samples_path = args.samples_path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(samples_path,
                         suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    # select key to fetch fake images
    fake_key = 'fake_b'
    if isinstance(model.module, CycleGAN):
        fake_key = 'fake_b' if model.module.test_direction == 'a2b' else \
            'fake_a'

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)
        # Reuse one iterator across batches and restart it when an epoch
        # ends. Creating a new iterator per batch re-spawns workers and
        # resamples from scratch every time.
        data_loader_iter = iter(data_loader)
        # Indices continue after the existing files, up to max_num_images.
        begin = num_exist
        while begin < max_num_images:
            try:
                data_batch = next(data_loader_iter)
            except StopIteration:
                data_loader_iter = iter(data_loader)
                data_batch = next(data_loader_iter)
            # inference only; no autograd graph is needed
            with torch.no_grad():
                output_dict = model(test_mode=True, **data_batch)
            fakes = output_dict[fake_key]
            # The last batch of an epoch can be smaller than batch_size.
            num_batch = min(fakes.shape[0], max_num_images - begin)
            for i in range(num_batch):
                images = fakes[i:i + 1]
                # rescale value range to [0, 1] and reverse channel order
                images = ((images + 1) / 2)
                images = images[:, [2, 1, 0], ...]
                images = images.clamp_(0, 1)
                image_name = str(begin + i) + '.png'
                save_image(images, os.path.join(samples_path, image_name))
            pbar.update(num_batch)
            begin += num_batch
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, args.batch_size)
    # select key to fetch real images
    if isinstance(model.module, CycleGAN):
        real_key = 'img_b' if model.module.test_direction == 'a2b' else 'img_a'
        if model.module.direction == 'b2a':
            real_key = 'img_a' if real_key == 'img_b' else 'img_b'

    if isinstance(model.module, Pix2Pix):
        real_key = 'img_b' if model.module.direction == 'a2b' else 'img_a'

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data[real_key]
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    if delete_samples_path:
        shutil.rmtree(samples_path)
Пример #8
0
def main():
    """Train a model described by a config file.

    Resolution rules:
      * ``work_dir``: ``--work-dir`` first, then the config's own
        ``work_dir``, then ``./work_dirs/<config stem>``.
      * ``gpu_ids``: ``--gpu-ids``, else ``range(--gpus)`` (default one
        GPU), overridden by the world size under a distributed launcher.

    The merged config is dumped into the work dir. Logs go to
    ``<work_dir>/<timestamp>.log``. Environment info, config text, seed
    and experiment name are collected into ``meta`` and handed to
    ``train_model``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # Optional user modules named in the config are imported by string.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # Priority for work_dir: CLI flag > value in config > config filename.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        config_stem = osp.splitext(osp.basename(args.config))[0]
        cfg.work_dir = osp.join('./work_dirs', config_stem)

    if args.resume_from is not None:
        cfg.resume_from = args.resume_from

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    elif args.gpus is None:
        cfg.gpu_ids = range(1)
    else:
        cfg.gpu_ids = range(args.gpus)

    # Distributed setup comes before the logger, which needs the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)
        # In distributed mode every rank in the world gets a GPU id.
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(
        log_file=osp.join(cfg.work_dir, f'{timestamp}.log'),
        log_level=cfg.log_level)

    # Environment info is both logged and stored in the run metadata.
    env_info = '\n'.join(
        f'{key}: {value}' for key, value in collect_env().items())
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta = dict(env_info=env_info, config=cfg.pretty_text)

    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    seed = args.seed
    if seed is not None:
        logger.info(f'Set random seed to {seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # The validation dataset is built with the training pipeline.
        val_cfg = copy.deepcopy(cfg.data.val)
        val_cfg.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_cfg))

    # Record the version (+ short git hash) in saved checkpoints.
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config.meta = dict(
            mmgen_version=__version__ + get_git_hash()[:7])

    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=not args.no_validate,
        timestamp=timestamp,
        meta=meta)
Пример #9
0
def main():
    """Render interpolations between random endpoints of a generator.

    ``args.endpoint`` endpoints are drawn in z space or as latents
    (``args.space``) and interpolated with ``args.interval`` steps.

    Show modes:
      * ``sequence``: consecutive endpoints are chained, and every frame
        is saved as ``NNNNN.png``.
      * otherwise: endpoints are paired (0-1, 2-3, ...), and all frames are
        saved as one grid ``group.png`` with ``args.interval`` images per
        row.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # Models trained without EMA only have the original generator.
    if not model.use_ema:
        args.sample_model = 'orig'
    generator = (model.generator_ema
                 if args.sample_model == 'ema' else model.generator)

    for message in (f'Sampling model: {args.sample_model}',
                    f'Show mode: {args.show_mode}',
                    f'Samples path: {args.samples_path}'):
        mmcv.print_log(message, 'mmgen')

    generator.eval()
    if not args.use_cpu:
        generator = generator.cuda()

    is_sequence = args.show_mode == 'sequence'
    # Pairing mode needs an even number of endpoints.
    if is_sequence:
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0

    sample_kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    sample_kwargs.update(args.sample_cfg)

    # One noise/latent code per endpoint.
    noise_batch = batch_inference(
        generator,
        None,
        num_batches=args.endpoint,
        dict_key='noise_batch' if args.space == 'z' else 'latent',
        return_noise=True,
        **sample_kwargs)

    if args.space == 'w':
        sample_kwargs['truncation_latent'] = generator.get_mean_latent()
        sample_kwargs['input_is_latent'] = True

    if is_sequence:
        starts, ends = noise_batch[:-1], noise_batch[1:]
    else:
        starts, ends = noise_batch[::2], noise_batch[1::2]
    results = sample_from_path(generator, starts, ends, args.interval,
                               args.interp_mode, args.space, **sample_kwargs)

    # Stack the per-step outputs, swap the first two axes and flatten them
    # into a single batch of frames.
    frames = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, channels, height, width = frames.shape
    frames = frames.reshape(-1, channels, height, width)
    # (x + 1) / 2, reverse channel order, clamp to [0, 1].
    frames = ((frames + 1) / 2)[:, [2, 1, 0], ...].clamp_(0, 1)

    mmcv.mkdir_or_exist(args.samples_path)
    if is_sequence:
        for idx, frame in enumerate(frames.split(1)):
            save_image(frame,
                       os.path.join(args.samples_path, f'{idx:0>5d}.png'))
    else:
        save_image(frames,
                   os.path.join(args.samples_path, 'group.png'),
                   nrow=args.interval)
Пример #10
0
def main():
    """Edit or generate a StyleGAN image by optimizing its latent with CLIP.

    Pipeline, as implemented below:

    1. Parse CLI args, honor ``cudnn_benchmark`` from the config and
       optionally fix the random seed.
    2. Load the EMA generator (``model.generator_ema``) from
       ``args.config`` / ``args.checkpoint``.
    3. Pick the initial latent from one of three sources:

       * a projected-latent file (``args.proj_latent``, exactly one entry);
       * a truncated random sample (``args.mode == 'edit'``);
       * the mean latent tiled with ``repeat(1, 18, 1)`` (any other mode).
    4. Run ``args.step`` Adam steps on the latent, minimizing a CLIP loss
       against ``args.description``. In ``'edit'`` mode two terms are added:

       * an L2 penalty to the initial latent (weight ``args.l2_lambda``);
       * a face-identity loss against the initial image (weight
         ``args.id_lambda``).
    5. Save an image every ``args.save_interval`` steps (if > 0). Always
       save ``final_result.png`` into ``args.results_dir``; in ``'edit'``
       mode it holds the original image followed by the edited one.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    os.makedirs(args.results_dir, exist_ok=True)

    # Single target device for everything created below. Hard-coded
    # `.cuda()` calls used to feed CUDA tensors to a CPU generator when
    # `args.use_cpu` was set.
    device = 'cpu' if args.use_cpu else 'cuda'

    text_inputs = torch.cat([clip.tokenize(args.description)]).to(device)

    model = init_model(args.config, args.checkpoint, device='cpu')
    g_ema = model.generator_ema
    g_ema.eval()
    if not args.use_cpu:
        g_ema = g_ema.cuda()

    mean_latent = g_ema.get_mean_latent()

    # if given proj_latent
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        # map_location='cpu' lets CUDA-saved latents load on CPU-only hosts;
        # the result is moved to `device` right below.
        proj_file = torch.load(args.proj_latent, map_location='cpu')
        proj_n = len(proj_file)
        assert proj_n == 1
        noise_batch = [
            proj_file[img_path]['latent'].unsqueeze(0)
            for img_path in proj_file
        ]
        latent_code_init = torch.cat(noise_batch, dim=0).to(device)
    elif args.mode == 'edit':
        # Draw from the default (CPU) RNG and then move, exactly as before,
        # so a fixed seed still yields the same starting code.
        latent_code_init_not_trunc = torch.randn(1, 512).to(device)
        with torch.no_grad():
            results = g_ema([latent_code_init_not_trunc],
                            return_latents=True,
                            truncation=args.truncation,
                            truncation_latent=mean_latent)
            latent_code_init = results['latent']
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    with torch.no_grad():
        img_orig = g_ema([latent_code_init],
                         input_is_latent=True,
                         randomize_noise=False)
    # Apply the same [2, 1, 0] channel flip used on every generated image,
    # so the identity loss compares images in a consistent channel order.
    img_orig = img_orig[:, [2, 1, 0], ...]

    latent = latent_code_init.detach().clone()
    latent.requires_grad = True

    clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size))
    # NOTE(review): assumes FaceIdLoss/ArcFace accepts 'cpu' as well as
    # 'cuda' for `device` -- confirm.
    id_loss = FaceIdLoss(
        facenet=dict(type='ArcFace', ir_se50_weights=None, device=device))
    # The identity loss only enters the objective in 'edit' mode; skip its
    # forward pass otherwise.
    use_id_loss = args.mode == 'edit' and args.id_lambda > 0

    optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))
    mmcv.print_log(f'Description: {args.description}')
    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]['lr'] = lr

        img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False)
        img_gen = img_gen[:, [2, 1, 0], ...]

        # clip loss
        c_loss = clip_loss(image=img_gen, text=text_inputs)

        if args.mode == 'edit':
            if use_id_loss:
                i_loss = id_loss(pred=img_gen, gt=img_orig)[0]
            else:
                i_loss = 0
            l2_loss = ((latent_code_init - latent)**2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f'loss: {loss.item():.4f};')
        if args.save_interval > 0 and i % args.save_interval == 0:
            with torch.no_grad():
                img_save = g_ema([latent],
                                 input_is_latent=True,
                                 randomize_noise=False)
            img_save = img_save[:, [2, 1, 0], ...]

            torchvision.utils.save_image(img_save,
                                         os.path.join(
                                             args.results_dir,
                                             f'{str(i).zfill(5)}.png'),
                                         normalize=True,
                                         range=(-1, 1))

    # Render the *optimized* latent. The last in-loop `img_gen` was produced
    # before the final optimizer step, and is undefined if args.step == 0.
    with torch.no_grad():
        img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False)
    img_gen = img_gen[:, [2, 1, 0], ...]

    if args.mode == 'edit':
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    torchvision.utils.save_image(final_result.detach().cpu(),
                                 os.path.join(args.results_dir,
                                              'final_result.png'),
                                 normalize=True,
                                 scale_each=True,
                                 range=(-1, 1))
Пример #11
0
def _export_video(results, samples_path, **grid_kwargs):
    """Render ``results`` into ``<samples_path>/lerp.mp4`` via ``layout_grid``.

    Args:
        results (torch.Tensor): Frames as prepared by ``main`` (reshaped to
            ``(N, C, H, W)`` and clamped to ``[0, 1]``).
        samples_path (str): Output directory; must already exist.
        **grid_kwargs: Forwarded to ``layout_grid`` (e.g. ``grid_h`` and
            ``grid_w``). When omitted, ``layout_grid`` uses its own defaults.
    """
    video_out = imageio.get_writer(
        os.path.join(samples_path, 'lerp.mp4'),
        mode='I',
        fps=60,
        codec='libx264',
        bitrate='12M')
    try:
        video_out = layout_grid(video_out, results, **grid_kwargs)
    finally:
        # release the writer even if rendering fails
        video_out.close()


def main():
    """Sample images interpolated between latent endpoints.

    Builds the model from ``args.config`` and loads ``args.checkpoint``. It
    samples with the EMA generator unless ``args.sample_model`` says
    otherwise or the model has no EMA.

    Endpoints come from one of two sources:

    * ``args.endpoint`` codes drawn via ``batch_inference``;
    * the entries of ``args.proj_latent`` (w space only), which also reset
      ``args.endpoint``.

    ``sample_from_path`` interpolates between them using ``args.interval``
    and ``args.interp_mode``:

    * ``'sequence'`` mode uses consecutive endpoints and saves one PNG per
      frame, or ``lerp.mp4`` with ``args.export_video``.
    * Any other mode uses disjoint pairs (0-1, 2-3, ...) and saves
      ``group.png`` with ``nrow=args.interval``, or a grid video with
      ``args.export_video``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')

    generator.eval()

    device = 'cpu' if args.use_cpu else 'cuda'
    if not args.use_cpu:
        generator = generator.cuda()

    # if given proj_latent, reset args.endpoint
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        # Load onto the CPU first. The old `.cuda()` followed by
        # `.to('cpu')` crashed on machines without CUDA.
        proj_file = torch.load(args.proj_latent, map_location='cpu')
        args.endpoint = len(proj_file)
        assert args.space == 'w', 'Projected latent are w or w-plus latent.'
        noise_batch = torch.cat(
            [proj_file[img_path]['latent'].unsqueeze(0)
             for img_path in proj_file],
            dim=0).to(device)

    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0,\
            '''We need paired images in group mode,
            so keep endpoint an even number'''

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)
    # remind users to fixed injected noise
    if kwargs.get('randomize_noise', True):
        mmcv.print_log(
            '''Hint: For Style-Based GAN, you can add
            `--sample-cfg randomize_noise=False` to fix injected noises''',
            'mmgen')

    # get noises corresponding to each endpoint (same None-check as above)
    if args.proj_latent is None:
        noise_batch = batch_inference(
            generator,
            None,
            num_batches=args.endpoint,
            dict_key='noise_batch' if args.space == 'z' else 'latent',
            return_noise=True,
            **kwargs)

    if args.space == 'w':
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True

    if args.show_mode == 'sequence':
        # interpolate between consecutive endpoints: (0,1), (1,2), ...
        starts, ends = noise_batch[:-1, ], noise_batch[1:, ]
    else:
        # interpolate within disjoint pairs: (0,1), (2,3), ...
        starts, ends = noise_batch[::2, ], noise_batch[1::2, ]
    results = sample_from_path(generator, starts, ends, args.interval,
                               args.interp_mode, args.space, **kwargs)
    # reorder results
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)
    # save image
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        if args.export_video:
            _export_video(results, args.samples_path)
        else:
            for i in range(results.shape[0]):
                image = results[i:i + 1]
                save_image(
                    image,
                    os.path.join(args.samples_path,
                                 '{:0>5d}'.format(i) + '.png'))
    else:
        if args.export_video:
            n_pair = args.endpoint // 2
            grid_w, grid_h = crack_integer(n_pair)
            _export_video(results,
                          args.samples_path,
                          grid_h=grid_h,
                          grid_w=grid_w)
        else:
            save_image(results,
                       os.path.join(args.samples_path, 'group.png'),
                       nrow=args.interval)