def setup_class(cls):
    """Build un-trained pix2pix and cyclegan models on CPU for the tests."""
    project_dir = os.path.abspath(os.path.join(__file__, '../../..'))

    pix2pix_cfg = mmcv.Config.fromfile(
        os.path.join(
            project_dir,
            'configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades.py'))
    cls.pix2pix = init_model(pix2pix_cfg, checkpoint=None, device='cpu')

    cyclegan_cfg = mmcv.Config.fromfile(
        os.path.join(
            project_dir,
            'configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py'))
    cls.cyclegan = init_model(cyclegan_cfg, checkpoint=None, device='cpu')

    # Sample unpaired test image fed to the translation tests below.
    cls.img_path = os.path.join(
        os.path.dirname(__file__), '..', 'data/unpaired/testA/5.jpg')
def setup_class(cls):
    """Build an un-trained DCGAN model on CPU for the tests below."""
    project_dir = os.path.abspath(os.path.join(__file__, '../../..'))
    cfg_path = os.path.join(
        project_dir, 'configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py')
    cls.model = init_model(
        mmcv.Config.fromfile(cfg_path), checkpoint=None, device='cpu')
def setup_class(cls):
    """Build an improved-DDPM model on CPU with a shortened sampling chain."""
    project_dir = os.path.abspath(os.path.join(__file__, '../../..'))
    # NOTE: 'hybird' is a typo in the upstream config filename; it must be
    # reproduced exactly or the config cannot be found.
    cfg = mmcv.Config.fromfile(
        os.path.join(
            project_dir, 'configs/improved_ddpm/'
            'ddpm_cosine_hybird_timestep-4k_drop0.3_'
            'cifar10_32x32_b8x16_500k.py'))
    # Shrink the diffusion chain so the test finishes quickly.
    cfg.model['num_timesteps'] = 10
    cls.model = init_model(cfg, checkpoint=None, device='cpu')
def main():
    """Sample class-conditional images and save them as one image grid.

    Label selection has three modes:
      * neither ``--label`` nor ``--sample-all-classes``: labels are drawn
        randomly by the sampler (``label=None``);
      * ``--sample-all-classes``: every class the model knows about;
      * explicit ``--label`` list: the given classes, deduplicated and sorted.

    Raises:
        AttributeError: if ``--sample-all-classes`` is set but the model does
            not expose a usable ``num_classes`` attribute.
    """
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    if args.label is None and not args.sample_all_classes:
        # No labels requested: let the sampler draw random labels.
        label = None
        num_samples, nrow = args.samples_per_classes, args.nrow
        mmcv.print_log(
            '`label` is not passed, code would randomly sample '
            f'`samples-per-classes` (={num_samples}) images.', 'mmgen')
    else:
        if args.sample_all_classes:
            mmcv.print_log(
                '`sample_all_classes` is set as True, `num-samples`, `label`, '
                'and `nrows` would be ignored.', 'mmgen')

            # get num_classes from the model itself
            if hasattr(model, 'num_classes') and model.num_classes is not None:
                num_classes = model.num_classes
            else:
                # FIX: the original passed a stray 'mmgen' second argument
                # (copy-pasted from the `print_log` call signature), which
                # turned the exception args into a tuple and garbled the
                # message; the message is the only intended argument.
                raise AttributeError(
                    'Cannot get attribute `num_classes` from '
                    f'{type(model)}. Please check your config.')
            # one meta label per class
            meta_labels = list(range(num_classes))
        else:
            # deduplicate the user-provided labels and keep a stable order
            meta_labels = sorted(set(args.label))

        # repeat each class label `samples_per_classes` times
        label = []
        for idx in meta_labels:
            label += [idx] * args.samples_per_classes
        num_samples = len(label)
        nrow = args.samples_per_classes
        mmcv.print_log(
            'Set `nrows` as number of samples for each class '
            f'(={args.samples_per_classes}).', 'mmgen')

    results = sample_conditional_model(model, num_samples, args.num_batches,
                                       args.sample_model, label,
                                       **args.sample_cfg)
    # Reorder channels (BGR<->RGB swap) and rescale [-1, 1] -> [0, 1].
    results = (results[:, [2, 1, 0]] + 1.) / 2.

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    utils.save_image(results, args.save_path, nrow=nrow, padding=args.padding)
def initialize(self, context):
    """Load the model described by the torchserve context and mark ready.

    Reads the model directory and serialized checkpoint from the context's
    system properties / manifest, picks CUDA when available, and builds the
    model via ``init_model``.
    """
    props = context.system_properties
    self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        # Pin to the GPU torchserve assigned to this worker.
        self.device = torch.device(
            self.map_location + ':' + str(props.get('gpu_id')))
    else:
        self.device = torch.device(self.map_location)
    self.manifest = context.manifest

    model_dir = props.get('model_dir')
    serialized_file = self.manifest['model']['serializedFile']
    checkpoint = os.path.join(model_dir, serialized_file)
    self.config_file = os.path.join(model_dir, 'config.py')

    self.model = init_model(self.config_file, checkpoint, self.device)
    self.initialized = True
def main():
    """Translate an input image to the target domain and save the result."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    sample_cfg = dict() if args.sample_cfg is None else args.sample_cfg
    results = sample_img2img_model(model, args.image_path, args.target_domain,
                                   **sample_cfg)
    # Reorder channels (BGR<->RGB swap) and rescale [-1, 1] -> [0, 1].
    results = (results[:, [2, 1, 0]] + 1.) / 2.

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    utils.save_image(results, args.save_path)
def main():
    """Sample from a DDPM and save either an image grid or a denoising GIF.

    When ``--save-path`` ends with ``.gif``, per-timestep intermediate
    results are kept and rendered into one animation frame each.
    """
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    save_as_gif = osp.splitext(args.save_path)[-1] == '.gif'
    if save_as_gif:
        # Keep every timestep's output so the denoising chain can be animated.
        args.sample_cfg['save_intermedia'] = True
    results = sample_ddpm_model(model, args.num_samples, args.num_batches,
                                args.sample_model, args.same_noise,
                                **args.sample_cfg)

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    if save_as_gif:
        # Build one grid per timestep, then stack them along a new frame axis.
        frame_list = []
        for timestep in results.keys():
            grid = utils.make_grid(
                results[timestep], nrow=args.nrow, padding=args.padding)
            # add a leading frame dimension before concatenation
            frame_list.append(grid[None, ...])
        frames = torch.cat(frame_list, dim=0)
        if not args.is_rgb:
            frames = frames[:, [2, 1, 0]]
        frames = (frames + 1.) / 2.
        create_gif(frames, args.save_path, n_skip=args.n_skip)
    else:
        if not args.is_rgb:
            results = results[:, [2, 1, 0]]
        results = (results + 1.) / 2.
        utils.save_image(
            results, args.save_path, nrow=args.nrow, padding=args.padding)
def main():
    """Sample images from an unconditional GAN and save them as a grid."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    sample_cfg = dict() if args.sample_cfg is None else args.sample_cfg
    # NOTE: `sample_uncoditional_model` is the (misspelled) upstream API name
    # defined elsewhere in the package; it cannot be renamed here.
    results = sample_uncoditional_model(model, args.num_samples,
                                        args.num_batches, args.sample_model,
                                        **sample_cfg)
    # Reorder channels (BGR<->RGB swap) and rescale [-1, 1] -> [0, 1].
    results = (results[:, [2, 1, 0]] + 1.) / 2.

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    utils.save_image(
        results, args.save_path, nrow=args.nrow, padding=args.padding)
def main(): args = parse_args() # set cudnn_benchmark cfg = Config.fromfile(args.config) if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True # set random seeds if args.seed is not None: print('set random seed to', args.seed) set_random_seed(args.seed, deterministic=args.deterministic) os.makedirs(args.results_dir, exist_ok=True) text_inputs = torch.cat([clip.tokenize(args.description)]).cuda() model = init_model(args.config, args.checkpoint, device='cpu') g_ema = model.generator_ema g_ema.eval() if not args.use_cpu: g_ema = g_ema.cuda() mean_latent = g_ema.get_mean_latent() # if given proj_latent if args.proj_latent is not None: mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen') proj_file = torch.load(args.proj_latent) proj_n = len(proj_file) assert proj_n == 1 noise_batch = [] for img_path in proj_file: noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0)) latent_code_init = torch.cat(noise_batch, dim=0).cuda() elif args.mode == 'edit': latent_code_init_not_trunc = torch.randn(1, 512).cuda() with torch.no_grad(): results = g_ema([latent_code_init_not_trunc], return_latents=True, truncation=args.truncation, truncation_latent=mean_latent) latent_code_init = results['latent'] else: latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1) with torch.no_grad(): img_orig = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False) latent = latent_code_init.detach().clone() latent.requires_grad = True clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size)) id_loss = FaceIdLoss( facenet=dict(type='ArcFace', ir_se50_weights=None, device='cuda')) optimizer = optim.Adam([latent], lr=args.lr) pbar = tqdm(range(args.step)) mmcv.print_log(f'Description: {args.description}') for i in pbar: t = i / args.step lr = get_lr(t, args.lr) optimizer.param_groups[0]['lr'] = lr img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False) img_gen = img_gen[:, [2, 1, 0], ...] 
# clip loss c_loss = clip_loss(image=img_gen, text=text_inputs) if args.id_lambda > 0: i_loss = id_loss(pred=img_gen, gt=img_orig)[0] else: i_loss = 0 if args.mode == 'edit': l2_loss = ((latent_code_init - latent)**2).sum() loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss else: loss = c_loss optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_description((f'loss: {loss.item():.4f};')) if args.save_interval > 0 and (i % args.save_interval == 0): with torch.no_grad(): img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False) img_gen = img_gen[:, [2, 1, 0], ...] torchvision.utils.save_image(img_gen, os.path.join( args.results_dir, f'{str(i).zfill(5)}.png'), normalize=True, range=(-1, 1)) if args.mode == 'edit': img_orig = img_orig[:, [2, 1, 0], ...] final_result = torch.cat([img_orig, img_gen]) else: final_result = img_gen torchvision.utils.save_image(final_result.detach().cpu(), os.path.join(args.results_dir, 'final_result.png'), normalize=True, scale_each=True, range=(-1, 1))