Example #1
0
 def setup_class(cls):
     """Build CPU pix2pix and CycleGAN models and locate a sample image.

     Both models are created from their facades configs without loading a
     checkpoint. ``cls.img_path`` points at an image under the test data
     directory that sits next to this file's parent directory.
     """
     # Going up three path components from this file gives the project
     # directory that holds ``configs/``.
     root_dir = os.path.abspath(os.path.join(__file__, '../../..'))

     def _load_cfg(rel_path):
         # Config paths are given relative to the project directory.
         return mmcv.Config.fromfile(os.path.join(root_dir, rel_path))

     cls.pix2pix = init_model(
         _load_cfg(
             'configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades.py'),
         checkpoint=None,
         device='cpu')
     cls.cyclegan = init_model(
         _load_cfg(
             'configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py'),
         checkpoint=None,
         device='cpu')
     cls.img_path = os.path.join(
         os.path.dirname(__file__), '..', 'data/unpaired/testA/5.jpg')
Example #2
0
 def setup_class(cls):
     """Build the DCGAN (CelebA-cropped, 64px) model on CPU, no checkpoint."""
     cfg_file = os.path.join(
         os.path.abspath(os.path.join(__file__, '../../..')),
         'configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py')
     cls.model = init_model(
         mmcv.Config.fromfile(cfg_file), checkpoint=None, device='cpu')
Example #3
0
 def setup_class(cls):
     """Build an improved-DDPM model on CPU with very few diffusion steps."""
     root_dir = os.path.abspath(os.path.join(__file__, '../../..'))
     cfg_name = ('ddpm_cosine_hybird_timestep-4k_drop0.3_'
                 'cifar10_32x32_b8x16_500k.py')
     ddpm_cfg = mmcv.Config.fromfile(
         os.path.join(root_dir, 'configs/improved_ddpm/', cfg_name))
     # Shrink the number of timesteps so the test runs quickly.
     ddpm_cfg.model['num_timesteps'] = 10
     cls.model = init_model(ddpm_cfg, checkpoint=None, device='cpu')
def main():
    """Sample images from a conditional model and save them as one grid.

    The labels to sample are chosen from the parsed arguments:

    * ``args.label`` is None and ``args.sample_all_classes`` is false:
      ``args.samples_per_classes`` images are sampled with ``label=None``
      (the sampler picks labels itself) and laid out with ``args.nrow``
      images per row.
    * ``args.sample_all_classes`` is true: ``args.samples_per_classes``
      images are sampled for every class in ``range(model.num_classes)``.
    * otherwise: ``args.samples_per_classes`` images are sampled for each
      unique value in ``args.label``, in sorted order.

    In the last two cases each grid row holds exactly one class.

    Raises:
        AttributeError: If all classes are requested but the model has no
            (or a None) ``num_classes`` attribute.
    """
    args = parse_args()
    model = init_model(args.config,
                       checkpoint=args.checkpoint,
                       device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    if args.label is None and not args.sample_all_classes:
        label = None
        num_samples, nrow = args.samples_per_classes, args.nrow
        mmcv.print_log(
            '`label` is not passed, code would randomly sample '
            f'`samples-per-classes` (={num_samples}) images.', 'mmgen')
    else:
        if args.sample_all_classes:
            mmcv.print_log(
                '`sample_all_classes` is set as True, `num-samples`, `label`, '
                'and `nrows` would be ignored.', 'mmgen')

            num_classes = getattr(model, 'num_classes', None)
            if num_classes is None:
                # Only the message is passed: a second positional argument
                # would turn the exception text into a tuple.
                raise AttributeError(
                    'Cannot get attribute `num_classes` from '
                    f'{type(model)}. Please check your config.')
            meta_labels = list(range(num_classes))
        else:
            # Deduplicate and sort so every class forms one contiguous run.
            meta_labels = sorted(set(args.label))

        # Repeat each class index `samples_per_classes` times, so that with
        # nrow == samples_per_classes every grid row is a single class.
        label = [
            idx for idx in meta_labels
            for _ in range(args.samples_per_classes)
        ]
        num_samples = len(label)
        nrow = args.samples_per_classes

        mmcv.print_log(
            'Set `nrows` as number of samples for each class '
            f'(={args.samples_per_classes}).', 'mmgen')

    results = sample_conditional_model(model, num_samples, args.num_batches,
                                       args.sample_model, label,
                                       **args.sample_cfg)
    # Reverse the channel order and map (x + 1) / 2; assumes the samples
    # lie in [-1, 1] so the result is in [0, 1] for saving.
    results = (results[:, [2, 1, 0]] + 1.) / 2.

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    utils.save_image(results, args.save_path, nrow=nrow, padding=args.padding)
Example #5
0
    def initialize(self, context):
        """Pick the device and load the model described by ``context``.

        ``context`` is presumably a TorchServe-style handler context (it
        provides ``system_properties`` and ``manifest``); the checkpoint is
        ``<model_dir>/<manifest.model.serializedFile>`` and the config is
        ``<model_dir>/config.py``.
        """
        properties = context.system_properties
        use_cuda = torch.cuda.is_available()
        self.map_location = 'cuda' if use_cuda else 'cpu'
        if use_cuda:
            # Bind to the GPU index assigned through the system properties.
            device_name = f"{self.map_location}:{properties.get('gpu_id')}"
        else:
            device_name = self.map_location
        self.device = torch.device(device_name)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        checkpoint = os.path.join(model_dir,
                                  self.manifest['model']['serializedFile'])
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_model(self.config_file, checkpoint, self.device)
        self.initialized = True
Example #6
0
def main():
    """Translate ``args.image_path`` to ``args.target_domain`` and save it."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = {}

    translated = sample_img2img_model(model, args.image_path,
                                      args.target_domain, **args.sample_cfg)
    # Reverse channel order, then map (x + 1) / 2 (assumes values in [-1, 1]).
    translated = translated[:, [2, 1, 0]]
    translated = (translated + 1.) / 2.

    save_dir = os.path.dirname(args.save_path)
    mmcv.mkdir_or_exist(save_dir)
    utils.save_image(translated, args.save_path)
Example #7
0
def main():
    """Sample from a DDPM model and save the result as an image or a GIF.

    If ``args.save_path`` ends in ``.gif``, ``save_intermedia`` is enabled in
    the sampling config, one image grid is built per entry of the returned
    mapping, and the grids are written as GIF frames via ``create_gif``.
    Otherwise the returned samples are saved as a single image grid.
    Channels are reversed unless ``args.is_rgb`` is set, and values are
    mapped with ``(x + 1) / 2`` (assumes samples lie in [-1, 1]).
    """
    args = parse_args()
    model = init_model(args.config,
                       checkpoint=args.checkpoint,
                       device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    suffix = osp.splitext(args.save_path)[-1]
    save_gif = suffix == '.gif'
    if save_gif:
        args.sample_cfg['save_intermedia'] = True

    results = sample_ddpm_model(model, args.num_samples, args.num_batches,
                                args.sample_model, args.same_noise,
                                **args.sample_cfg)

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    if save_gif:
        # One grid per timestep. Assuming `utils` is `torchvision.utils`,
        # `make_grid` returns a [3, H', W'] tensor. Padding is filled with -1
        # so it becomes 0 (black) after (x + 1) / 2, matching the still-image
        # branch where `save_image` pads the already-normalized tensor.
        frames = [
            utils.make_grid(results[t],
                            nrow=args.nrow,
                            padding=args.padding,
                            pad_value=-1)
            for t in results
        ]
        # Stack the grids to [n_timesteps, 3, H', W'].
        results_timestep = torch.stack(frames, dim=0)
        if not args.is_rgb:
            results_timestep = results_timestep[:, [2, 1, 0]]
        results_timestep = (results_timestep + 1.) / 2.
        create_gif(results_timestep, args.save_path, n_skip=args.n_skip)
    else:
        if not args.is_rgb:
            results = results[:, [2, 1, 0]]

        results = (results + 1.) / 2.
        utils.save_image(results,
                         args.save_path,
                         nrow=args.nrow,
                         padding=args.padding)
Example #8
0
def main():
    """Draw unconditional samples from a model and save them as a grid."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = {}

    samples = sample_uncoditional_model(
        model, args.num_samples, args.num_batches, args.sample_model,
        **args.sample_cfg)
    # Reverse channel order, then map (x + 1) / 2 (assumes values in [-1, 1]).
    samples = (samples[:, [2, 1, 0]] + 1.) / 2.

    out_dir = os.path.dirname(args.save_path)
    mmcv.mkdir_or_exist(out_dir)
    utils.save_image(
        samples, args.save_path, nrow=args.nrow, padding=args.padding)
Example #9
0
def main():
    """Optimize a generator latent so the image matches a text description.

    A CLIP loss between the generated image and ``args.description`` is
    minimized with Adam for ``args.step`` iterations. In ``edit`` mode the
    objective also includes an L2 penalty towards the initial latent and,
    when ``args.id_lambda > 0``, a face-identity loss against the initial
    image. Every ``args.save_interval`` steps an intermediate image is
    written to ``args.results_dir``. ``final_result.png`` contains the image
    of the optimized latent, preceded by the initial image in ``edit`` mode.

    Raises:
        ValueError: If ``args.proj_latent`` does not contain exactly one
            projected latent.
    """
    args = parse_args()
    # set cudnn_benchmark
    cfg = Config.fromfile(args.config)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    os.makedirs(args.results_dir, exist_ok=True)

    # Tensors and the identity loss follow the generator's device so that
    # `args.use_cpu` does not mix CPU and CUDA tensors.
    # NOTE(review): the device of the CLIP model inside CLIPLoss is not
    # visible here; confirm it matches `device` on the CPU path.
    device = 'cpu' if args.use_cpu else 'cuda'

    text_inputs = torch.cat([clip.tokenize(args.description)]).to(device)

    model = init_model(args.config, args.checkpoint, device='cpu')
    g_ema = model.generator_ema
    g_ema.eval()
    if not args.use_cpu:
        g_ema = g_ema.cuda()

    def render(latent_code):
        """Generate an image from `latent_code` (no grad), channels reversed.

        The reversal matches the order fed to the losses and written to disk.
        """
        with torch.no_grad():
            img = g_ema([latent_code],
                        input_is_latent=True,
                        randomize_noise=False)
        return img[:, [2, 1, 0], ...]

    mean_latent = g_ema.get_mean_latent()

    # if given proj_latent
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        # NOTE: torch.load unpickles the file; only load trusted files.
        proj_file = torch.load(args.proj_latent)
        if len(proj_file) != 1:
            raise ValueError(
                f'Expected exactly one projected latent in '
                f'{args.proj_latent}, got {len(proj_file)}.')
        noise_batch = [
            proj_file[img_path]['latent'].unsqueeze(0)
            for img_path in proj_file
        ]
        latent_code_init = torch.cat(noise_batch, dim=0).to(device)
    elif args.mode == 'edit':
        # Sample on the CPU generator and then move, as before, so seeded
        # runs draw the same noise on every device.
        latent_code_init_not_trunc = torch.randn(1, 512).to(device)
        with torch.no_grad():
            results = g_ema([latent_code_init_not_trunc],
                            return_latents=True,
                            truncation=args.truncation,
                            truncation_latent=mean_latent)
            latent_code_init = results['latent']
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    # Reversed once here so it matches `img_gen` in the identity loss as
    # well as in the final result.
    img_orig = render(latent_code_init)

    latent = latent_code_init.detach().clone()
    latent.requires_grad = True

    clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size))
    id_loss = FaceIdLoss(
        facenet=dict(type='ArcFace', ir_se50_weights=None, device=device))

    optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))
    mmcv.print_log(f'Description: {args.description}', 'mmgen')
    for i in pbar:
        t = i / args.step
        optimizer.param_groups[0]['lr'] = get_lr(t, args.lr)

        # Gradients must flow to `latent` here, so `render` is not used.
        img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False)
        img_gen = img_gen[:, [2, 1, 0], ...]

        # clip loss
        c_loss = clip_loss(image=img_gen, text=text_inputs)

        if args.mode == 'edit':
            l2_loss = ((latent_code_init - latent)**2).sum()
            if args.id_lambda > 0:
                i_loss = id_loss(pred=img_gen, gt=img_orig)[0]
            else:
                i_loss = 0
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f'loss: {loss.item():.4f};')
        if args.save_interval > 0 and i % args.save_interval == 0:
            torchvision.utils.save_image(render(latent),
                                         os.path.join(args.results_dir,
                                                      f'{i:05d}.png'),
                                         normalize=True,
                                         range=(-1, 1))

    # Render after the last optimizer step. The in-loop `img_gen` predates
    # that step, and does not exist at all when `args.step` is 0.
    img_final = render(latent)
    if args.mode == 'edit':
        final_result = torch.cat([img_orig, img_final])
    else:
        final_result = img_final

    torchvision.utils.save_image(final_result.detach().cpu(),
                                 os.path.join(args.results_dir,
                                              'final_result.png'),
                                 normalize=True,
                                 scale_each=True,
                                 range=(-1, 1))