def __init__(self, opt):
    """Wraps an underlying dataset with BYOL-style image augmentations.

    Opt keys:
        dataset: config for the wrapped dataset (passed to create_dataset).
        crop_size: output square crop size.
        key1/key2: names of the two image keys to augment (default 'hq'/'lq').
        for_sr: when True, color alterations and blurs are disabled.
        normalize: when True, appends ImageNet normalization (paper setting).
    """
    super().__init__()
    self.wrapped_dataset = create_dataset(opt['dataset'])
    self.cropped_img_size = opt['crop_size']
    self.key1 = opt_get(opt, ['key1'], 'hq')
    self.key2 = opt_get(opt, ['key2'], 'lq')
    # When set, color alterations and blurs are disabled.
    for_sr = opt_get(opt, ['for_sr'], False)
    augmentations = [
        augs.RandomHorizontalFlip(),
        augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
    if not for_sr:
        augmentations.extend([
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
        ])
    # The paper calls for normalization. Most datasets/models in this repo don't use this.
    # Recommend setting true if you want to train exactly like the paper.
    # Consistency fix: use opt_get (like every other optional key above) so a
    # missing 'normalize' key defaults to False instead of raising KeyError.
    if opt_get(opt, ['normalize'], False):
        augmentations.append(
            augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                           std=torch.tensor([0.229, 0.224, 0.225])))
    self.aug = nn.Sequential(*augmentations)
def __init__(self, opt):
    """Builds one of several torchvision datasets with standard ImageNet-style transforms.

    Opt keys: dataset (name), image_size, kwargs (forwarded to the dataset
    constructor), optional random_crop / fixed_len / offset.
    """
    dataset_builders = {
        "mnist": datasets.MNIST,
        "fmnist": datasets.FashionMNIST,
        "cifar10": CIFAR10,
        "cifar100": CIFAR100,
        "imagenet": datasets.ImageNet,
        "imagefolder": datasets.ImageFolder,
    }
    norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    size = opt['image_size']
    if opt_get(opt, ['random_crop'], False):
        # Training-style pipeline: random crop + flip.
        pipeline = [
            T.RandomResizedCrop(size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            norm,
        ]
    else:
        # Deterministic resize/center-crop (flip retained, matching original behavior).
        pipeline = [
            T.Resize(size),
            T.CenterCrop(size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            norm,
        ]
    self.dataset = dataset_builders[opt['dataset']](transform=T.Compose(pipeline),
                                                    **opt['kwargs'])
    # Length may be pinned via fixed_len; otherwise use the dataset's own size.
    self.len = opt_get(opt, ['fixed_len'], len(self.dataset))
    self.offset = opt_get(opt, ['offset'], 0)
def forward(self, x, get_steps=False):
    """RRDB super-resolution trunk forward pass.

    Args:
        x: LR input image tensor.
        get_steps: when True, return a dict of intermediate conditioning
            features (including any stacked RRDB block outputs); otherwise
            return only the final SR output tensor.
    """
    fea = self.conv_first(x)
    block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
    block_results = {}
    for idx, m in enumerate(self.RRDB_trunk.children()):
        fea = m(fea)
        # Idiom fix: direct membership test replaces the original inner scan loop.
        if idx in block_idxs:
            block_results["block_{}".format(idx)] = fea
    trunk = self.trunk_conv(fea)
    last_lr_fea = fea + trunk
    # Progressive x2 nearest-neighbor upsampling stages.
    fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
    fea = self.lrelu(fea_up2)
    fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
    fea = self.lrelu(fea_up4)
    fea_up8 = None
    fea_up16 = None
    fea_up32 = None
    if self.scale >= 8:
        fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up8)
    if self.scale >= 16:
        fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up16)
    if self.scale >= 32:
        fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up32)
    out = self.conv_last(self.lrelu(self.HRconv(fea)))
    results = {'last_lr_fea': last_lr_fea,
               'fea_up1': last_lr_fea,
               'fea_up2': fea_up2,
               'fea_up4': fea_up4,
               'fea_up8': fea_up8,
               'fea_up16': fea_up16,
               'fea_up32': fea_up32,
               'out': out}
    # Optional downsampled conditioning features for the flow.
    fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
    if fea_up0_en:
        results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear',
                                           align_corners=False, recompute_scale_factor=True)
    fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
    if fea_upn1_en:
        results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear',
                                            align_corners=False, recompute_scale_factor=True)
    if get_steps:
        for k, v in block_results.items():
            results[k] = v
        return results
    else:
        return out
def rrdbPreprocessing(self, lr):
    """Encodes the LR image with the RRDB trunk and optionally stacks selected
    intermediate block features onto each conditioning feature map."""
    rrdbResults = self.RRDB(lr, get_steps=True)
    block_idxs = opt_get(
        self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
    if len(block_idxs) > 0:
        # Channel-concatenate the selected RRDB block outputs.
        concat = torch.cat(
            [rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
        if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
            # Conditioning maps that will receive the stacked block features.
            keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
            if 'fea_up0' in rrdbResults.keys():
                keys.append('fea_up0')
            if 'fea_up-1' in rrdbResults.keys():
                keys.append('fea_up-1')
            if 'fea_up-2' in rrdbResults.keys():
                keys.append('fea_up-2')
            if self.opt['scale'] >= 8:
                keys.append('fea_up8')
            if self.opt['scale'] == 16:
                keys.append('fea_up16')
            for k in keys:
                h = rrdbResults[k].shape[2]
                w = rrdbResults[k].shape[3]
                # Resize the stacked features to each map's spatial size, then append.
                rrdbResults[k] = torch.cat(
                    [rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
    return rrdbResults
def __init__(self, in_channels, opt):
    """Conditional affine coupling (CondAffineSeparatedAndCond) setup.

    Builds two conv networks: fAffine (affine params for the coupled half,
    conditioned on the NN half + RRDB features) and fFeatures (affine params
    for all channels from the RRDB features alone).
    """
    super().__init__()
    self.need_features = True
    self.in_channels = in_channels
    self.in_channels_rrdb = 320  # channel count of the RRDB conditioning features
    self.kernel_hidden = 1
    self.n_hidden_layers = 1
    hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
    self.hidden_channels = 64 if hidden_channels is None else hidden_channels
    # Fix: the original assigned affine_eps twice (a dead 0.0001 literal
    # immediately overwritten by this opt_get carrying the same default).
    self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
    # Split channels: the first half drives the coupling; the rest is transformed.
    self.channels_for_nn = self.in_channels // 2
    self.channels_for_co = self.in_channels - self.channels_for_nn
    # (Removed dead `if self.channels_for_nn is None` branch: the attribute is
    # always an int at this point.)
    self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                          out_channels=self.channels_for_co * 2,
                          hidden_channels=self.hidden_channels,
                          kernel_hidden=self.kernel_hidden,
                          n_hidden_layers=self.n_hidden_layers)
    self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                            out_channels=self.in_channels * 2,
                            hidden_channels=self.hidden_channels,
                            kernel_hidden=self.kernel_hidden,
                            n_hidden_layers=self.n_hidden_layers)
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
    """SRFlow network: RRDB LR encoder + conditional flow upsampler.

    Args mirror RRDBNet (in_nc/out_nc/nf/nb/gc/scale) plus K (flow steps per
    level), the global opt dict, and the (unused here) training step.
    """
    super(SRFlowNet, self).__init__()
    self.opt = opt
    # Quantization level of input images; defaults to 255 (8-bit) when unset.
    # Fix: hoisted the duplicated opt_get call out of the conditional expression.
    quant = opt_get(opt, ['datasets', 'train', 'quant'])
    self.quant = 255 if quant is None else quant
    self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
    hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
    hidden_channels = hidden_channels or 64
    self.RRDB_training = True  # Default is true
    # train_RRDB_delay is read but acted on elsewhere (optimize_parameters);
    # the hard-coded flag below keeps the immediate enable path disabled.
    train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
    set_RRDB_to_train = False
    if set_RRDB_to_train:
        self.set_rrdb_training(True)
    self.hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'])
    self.flowUpsamplerNet = FlowUpsamplerNet(
        (self.hr_size, self.hr_size, in_nc), hidden_channels, K,
        flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
    self.i = 0
def get_z(self, heat, seed=None, batch_size=1, lr_shape=None, y_label=None):
    """Draws a latent z for reverse sampling at temperature `heat`.

    heat == 0 yields a zero latent (deterministic sample); otherwise z is
    drawn from N(0, heat). Shape depends on whether flow splits are enabled.
    Returns the latent moved to self.device.
    """
    # Fix: removed the original's no-op `if y_label is None: pass` guard;
    # y_label is unused in this method.
    if seed:
        torch.manual_seed(seed)
    if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
        C = self.netG.module.flowUpsamplerNet.C
        H = int(self.opt['scale'] * lr_shape[2] // self.netG.module.flowUpsamplerNet.scaleH)
        W = int(self.opt['scale'] * lr_shape[3] // self.netG.module.flowUpsamplerNet.scaleW)
        size = (batch_size, C, H, W)
        z = torch.zeros(size) if heat == 0 else torch.normal(mean=0, std=heat, size=size)
    else:
        L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
        fac = 2 ** (L - 3)
        z_size = int(self.lr_size // (2 ** (L - 3)))
        z = torch.normal(mean=0, std=heat,
                         size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
    return z.to(self.device)
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
    """Samples a flow latent at temperature `heat` and moves it to `device`.

    heat == 0 produces a deterministic zero latent in the split-enabled path;
    otherwise the latent is drawn from N(0, heat).
    """
    if seed:
        torch.manual_seed(seed)
    split_enabled = opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable'])
    if split_enabled:
        net = self.flowUpsamplerNet
        # Ratio between the flow's scale and the RRDB encoder's scale.
        scale_ratio = self.flow_scale / self.RRDB.scale
        C = net.C
        H = int(self.flow_scale * lr_shape[0] // (net.scaleH * scale_ratio))
        W = int(self.flow_scale * lr_shape[1] // (net.scaleW * scale_ratio))
        shape = (batch_size, C, H, W)
        z = torch.zeros(shape) if heat == 0 else torch.normal(mean=0, std=heat, size=shape)
    else:
        levels = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3
        factor = 2 ** (levels - 3)
        spatial = int(self.lr_size // (2 ** (levels - 3)))
        z = torch.normal(mean=0, std=heat,
                         size=(batch_size, 3 * 8 * 8 * factor * factor, spatial, spatial))
    return z.to(device)
def forward_pass(model, data, output_dir, opt):
    """Runs the model on one batch, saves SR images, and accumulates PSNR.

    Returns (fea_loss, psnr_loss) summed over the batch. fea_loss is always 0
    here (placeholder kept for interface compatibility).
    NOTE(review): relies on a `need_GT` name from the enclosing scope — confirm.
    """
    alteration_suffix = util.opt_get(opt, ['name'], '')
    denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
    with torch.no_grad():
        model.feed_data(data, 0, need_GT=need_GT)
        model.test()
        visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
        # Denormalize back into [0, 1].
        visuals = (visuals - denorm_range[0]) / (denorm_range[1] - denorm_range[0])
        fea_loss = 0
        psnr_loss = 0
        for i in range(visuals.shape[0]):
            img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
            img_name = osp.splitext(osp.basename(img_path))[0]
            sr_img = util.tensor2img(visuals[i])  # uint8
            # save images
            suffix = alteration_suffix
            if suffix:
                save_img_path = osp.join(output_dir, img_name + suffix + '.png')
            else:
                save_img_path = osp.join(output_dir, img_name + '.png')
            if need_GT:
                # Fix: reuse sr_img instead of recomputing tensor2img on the
                # same tensor (the original computed psnr_sr identically).
                psnr_gt = util.tensor2img(data['hq'][i])
                psnr_loss += util.calculate_psnr(sr_img, psnr_gt)
            util.save_img(sr_img, save_img_path)
    return fea_loss, psnr_loss
def __init__(self, opt, path):
    """Indexes image tiles under `path` and records option-driven flags."""
    self.path = path.path
    self.tiles, _ = util.get_image_paths('img', self.path)
    # Metadata is needed either in strict mode or when explicitly requested.
    strict_mode = opt_get(opt, ['strict'], False)
    wants_metadata = opt_get(opt, ['needs_metadata'], False)
    self.need_metadata = strict_mode or wants_metadata
    self.need_ref = opt_get(opt, ['need_ref'], False)
    if 'ignore_first' in opt.keys():
        # Drop the first N tiles when configured.
        self.tiles = self.tiles[opt['ignore_first']:]
def __init__(self, in_channels, opt):
    """SRFlow-DA conditional affine coupling setup.

    Depending on opt['model'], builds either one fAffine/fFeatures pair
    (SRFlow-DA / SRFlow-DA-S) or three of each (other variants, where self.F
    returns a 3-tuple). SRFlow-DA-D uses a much deeper hidden stack.
    """
    super().__init__()
    self.need_features = True
    self.in_channels = in_channels
    self.in_channels_rrdb = 320  # channel count of the RRDB conditioning features
    self.kernel_hidden = 3
    self.model_name = opt_get(opt, ['model'])
    # The deep variant uses 36 hidden layers; all others use 4.
    if self.model_name == "SRFlow-DA-D":
        self.n_hidden_layers = 36
    else:
        self.n_hidden_layers = 4
    hidden_channels = opt_get(opt, [
        'network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'
    ])
    self.hidden_channels = 64 if hidden_channels is None else hidden_channels
    # Fix: the original assigned affine_eps twice (a dead 0.0001 literal
    # immediately overwritten by this opt_get carrying the same default).
    self.affine_eps = opt_get(
        opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
    self.channels_for_nn = self.in_channels // 2
    self.channels_for_co = self.in_channels - self.channels_for_nn
    # (Removed dead `if self.channels_for_nn is None` branch: always an int here.)
    if self.model_name == "SRFlow-DA" or self.model_name == "SRFlow-DA-S":
        self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                              out_channels=self.channels_for_co * 2,
                              hidden_channels=self.hidden_channels,
                              kernel_hidden=self.kernel_hidden,
                              n_hidden_layers=self.n_hidden_layers)
        self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                out_channels=self.in_channels * 2,
                                hidden_channels=self.hidden_channels,
                                kernel_hidden=self.kernel_hidden,
                                n_hidden_layers=self.n_hidden_layers)
    else:
        # These variants expect self.F to return three parallel networks.
        self.fAffine1, self.fAffine2, self.fAffine3 = self.F(
            in_channels=self.channels_for_nn + self.in_channels_rrdb,
            out_channels=self.channels_for_co * 2,
            hidden_channels=self.hidden_channels,
            kernel_hidden=self.kernel_hidden,
            n_hidden_layers=self.n_hidden_layers)
        self.fFeatures1, self.fFeatures2, self.fFeatures3 = self.F(
            in_channels=self.in_channels_rrdb,
            out_channels=self.in_channels * 2,
            hidden_channels=self.hidden_channels,
            kernel_hidden=self.kernel_hidden,
            n_hidden_layers=self.n_hidden_layers)
def fix_image(img):
    """Prepares an image batch for logging: mel handling, channel trim, denorm."""
    if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
        img = img.unsqueeze(dim=1)
        # Standardize so the spectrogram is easier to view.
        img = (img - img.mean()) / img.std()
    if img.shape[1] > 3:
        # Only the first three channels are displayable.
        img = img[:, :3, :, :]
    if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
        # Map [-1, 1] back into [0, 1].
        img = (img + 1) / 2
    if opt_get(self.opt, ['logger', 'reverse_imagenet_norm'], False):
        img = denormalize(img)
    return img
def optimize_parameters(self, step): train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) # if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \ # and not self.netG.module.RRDB_training: if train_RRDB_delay is not None and step > int( train_RRDB_delay * self.opt['train']['niter']): if self.netG.module.set_rrdb_training(True): self.add_optimizer_and_scheduler_RRDB(self.opt['train']) # if step % 100 == 0: print("set RRDB trainable") # self.print_rrdb_state() # add GT noise add_gt_noise = opt_get(self.opt, ['train', 'add_gt_noise'], True) self.netG.train() self.log_dict = OrderedDict() self.optimizer_G.zero_grad() losses = {} weight_fl = opt_get(self.opt, ['train', 'weight_fl']) weight_fl = 1 if weight_fl is None else weight_fl if weight_fl > 0: z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False, add_gt_noise=add_gt_noise) nll_loss = torch.mean(nll) losses['nll_loss'] = nll_loss * weight_fl weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0 if weight_l1 > 0: z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True) l1_loss = (sr - self.real_H).abs().mean() losses['l1_loss'] = l1_loss * weight_l1 total_loss = sum(losses.values()) total_loss.backward() self.optimizer_G.step() mean = total_loss.item() return mean
def optimize_parameters(self, step):
    """One SRFlow training step: flow NLL loss plus an optional L1 loss.

    Returns the scalar total loss value.
    """
    train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
    # Enable RRDB training once the configured fraction of total iterations
    # has elapsed, and only if it was not already trainable.
    if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
            and not self.netG.module.RRDB_training:
        # set_rrdb_training returns truthy only when the state changed, so
        # the RRDB optimizer/scheduler is added exactly once.
        if self.netG.module.set_rrdb_training(True):
            self.add_optimizer_and_scheduler_RRDB(self.opt['train'])
    # self.print_rrdb_state()
    self.netG.train()
    self.log_dict = OrderedDict()
    self.optimizer_G.zero_grad()
    losses = {}
    # Flow negative-log-likelihood loss; weight defaults to 1.
    weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
    weight_fl = 1 if weight_fl is None else weight_fl
    if weight_fl > 0:
        z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        nll_loss = torch.mean(nll)
        losses['nll_loss'] = nll_loss * weight_fl
    # Optional L1 loss on a deterministic (heat=0) reverse sample.
    weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
    if weight_l1 > 0:
        z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0],
                       lr_shape=self.var_L.shape)
        sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True,
                               reverse_with_grad=True)
        l1_loss = (sr - self.real_H).abs().mean()
        losses['l1_loss'] = l1_loss * weight_l1
    total_loss = sum(losses.values())
    # NOTE: original debug notes record NaN total_loss and a
    # "svd_cuda: ... did not converge (error: 11)" RuntimeError occurring here.
    total_loss.backward()
    self.optimizer_G.step()
    mean = total_loss.item()
    return mean
def register_stylegan2_discriminator(opt_net, opt):
    """Builds a StyleGAN2 discriminator wrapped in a differentiable augmentor.

    Required opt_net keys: image_size, in_nc, augmentation_types,
    augmentation_probability. Optional: attn_layers, do_checkpointing, quantize.
    """
    # Consistency fix: use opt_get (as the other optional keys below do)
    # instead of a manual `in opt_net.keys()` membership check.
    attn = opt_get(opt_net, ['attn_layers'], [])
    disc = StyleGan2Discriminator(image_size=opt_net['image_size'],
                                  input_filters=opt_net['in_nc'],
                                  attn_layers=attn,
                                  do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
                                  quantize=opt_get(opt_net, ['quantize'], False))
    return StyleGan2Augmentor(disc, opt_net['image_size'],
                              types=opt_net['augmentation_types'],
                              prob=opt_net['augmentation_probability'])
def __init__(self, model, opt_eval, env):
    """Evaluator setup for FID-style generator evaluation.

    Required opt_eval keys: batches_per_eval, batch_size, image_size,
    real_fid_path. Optional: gen_index, noise_type, latent_dim.
    """
    super().__init__(model, opt_eval, env, uses_all_ddp=False)
    self.batches_per_eval = opt_eval['batches_per_eval']
    self.batch_sz = opt_eval['batch_size']
    self.im_sz = opt_eval['image_size']
    self.fid_real_samples = opt_eval['real_fid_path']
    # Consistency fix: opt_get instead of a manual key-membership ternary.
    self.gen_output_index = opt_get(opt_eval, ['gen_index'], 0)
    self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
    self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512)  # Not needed if using 'imgnoise' input.
    self.image_norm_range = tuple(
        opt_get(env['opt'], ['image_normalization_range'], [0, 1]))
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
    """Forward (normalizing) direction of the flow: encodes GT into latents.

    Returns (z-or-epses, nll, logdet), where nll is the bits-per-pixel
    negative log-likelihood of gt under the LR-conditioned flow.
    """
    if lr_enc is None:
        lr_enc = self.rrdbPreprocessing(lr)
    logdet = torch.zeros_like(gt[:, 0, 0, 0])
    pixels = thops.pixels(gt)
    z = gt
    if add_gt_noise:
        # Setup
        noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
        if noiseQuant:
            # Dequantization noise: uniform in [-0.5/quant, 0.5/quant).
            z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
        # Account for the quantization bin volume in the log-determinant.
        logdet = logdet + float(-np.log(self.quant) * pixels)
    # Encode
    epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet,
                                          reverse=False, epses=epses, y_onehot=y_onehot)
    objective = logdet.clone()
    # When epses is a list/tuple, the final entry is the top-level latent.
    if isinstance(epses, (list, tuple)):
        z = epses[-1]
    else:
        z = epses
    objective = objective + flow.GaussianDiag.logp(None, None, z)
    # Convert the objective to bits per pixel.
    nll = (-objective) / float(np.log(2.) * pixels)
    if isinstance(epses, list):
        return epses, nll, logdet
    return z, nll, logdet
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
    """Reverse (generative) direction: maps latent z back to an image.

    Iterates the flow layers in reverse, feeding each level's RRDB
    conditioning features. Returns (sr, logdet).
    """
    z = epses.pop() if isinstance(epses, list) else z
    fl_fea = z
    bypasses = {}
    level_conditionals = {}
    # Pre-gather per-level conditioning unless level-conditional mode is on.
    if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
        for level in range(self.L + 1):
            level_conditionals[level] = rrdbResults[self.levelToName[level]]
    for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
        size = shape[2]
        # Derive the pyramid level from the spatial size of this layer's output.
        level = int(np.log(self.hr_size / size) / np.log(2))
        if isinstance(layer, Split2d):
            fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                          rrdbResults[self.levelToName[level]],
                                                          logdet=logdet, y_onehot=y_onehot)
        elif isinstance(layer, FlowStep):
            fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
                                   rrdbResults=level_conditionals[level])
        else:
            fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)
    sr = fl_fea
    # NOTE(review): this variant asserts a single-channel output — confirm
    # this is intentional for the model configured here.
    assert sr.shape[1] == 1
    return sr, logdet
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None):
    """Creates a DataLoader for the given phase.

    Train: shuffled (unless distributed, where the sampler handles sharding)
    with the batch size divided across ranks when opt['dist'] is set.
    Other phases: single-process, unshuffled, batch size defaulting to 1.
    """
    phase = dataset_opt['phase']
    if phase == 'train':
        # Hoisted: n_workers was duplicated identically in both branches.
        num_workers = dataset_opt['n_workers']
        batch_size = dataset_opt['batch_size']
        if opt_get(opt, ['dist'], False):
            world_size = torch.distributed.get_world_size()
            # The global batch size must divide evenly across ranks.
            assert batch_size % world_size == 0
            batch_size = batch_size // world_size
            shuffle = False
        else:
            shuffle = True
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler,
                                           drop_last=True, pin_memory=True,
                                           collate_fn=collate_fn)
    else:
        batch_size = dataset_opt['batch_size'] or 1
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=0, pin_memory=True,
                                           collate_fn=collate_fn)
def register_u_resnet50_2(opt_net, opt):
    """Builds a UResNet50_2, optionally seeding shared layers with ImageNet weights."""
    net = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
    if opt_get(opt_net, ['use_pretrained_base'], False):
        # strict=False: only layers shared with torchvision's resnet50 are loaded.
        pretrained = load_state_dict_from_url(
            'https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
        net.load_state_dict(pretrained, strict=False)
    return net
def arch_split(self, H, W, L, levels, opt, opt_get):
    """Optionally appends a Split2d layer at flow level L.

    A split is added when splitting is enabled and L is below the (possibly
    corrected) top level; it diverts part of the channels into a latent and
    shrinks self.C to the pass-through channel count.
    """
    correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
    correction = 0 if correct_splits else 1
    split_enabled = opt_get(opt, ['network_G', 'flow', 'split', 'enable'])
    if split_enabled and L < levels - correction:
        logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
        consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
        position_name = get_position_name(H, self.opt['scale'])
        # Positional conditioning is only passed through when enabled.
        conditional = opt_get(opt, ['network_G', 'flow', 'split', 'conditional'])
        position = position_name if conditional else None
        cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
        cond_channels = 0 if cond_channels is None else cond_channels
        split_type = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
        if split_type == 'Split2d':
            split = models.modules.Split.Split2d(
                num_channels=self.C, logs_eps=logs_eps, position=position,
                cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            self.C = split.num_channels_pass
def register_styledsr_discriminator(opt_net, opt):
    """Builds a StyleSR discriminator wrapped in a differentiable augmentor.

    Required opt_net keys: image_size, in_nc, augmentation_types,
    augmentation_probability. Optional: attn_layers, do_checkpointing,
    quantize, mlp_head, transfer_mode, use_partial_pretrained (+ the three
    partial-training keys it implies).
    """
    # Consistency fix: opt_get (as the other optional keys below use) instead
    # of a manual `in opt_net.keys()` membership ternary.
    attn = opt_get(opt_net, ['attn_layers'], [])
    disc = StyleSrDiscriminator(
        image_size=opt_net['image_size'],
        input_filters=opt_net['in_nc'],
        attn_layers=attn,
        do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
        quantize=opt_get(opt_net, ['quantize'], False),
        mlp=opt_get(opt_net, ['mlp_head'], True),
        transfer_mode=opt_get(opt_net, ['transfer_mode'], False))
    # Presence-only flag (value unused), so a membership test is appropriate;
    # idiom fix: dicts support `in` directly without .keys().
    if 'use_partial_pretrained' in opt_net:
        disc.configure_partial_training(
            opt_net['bypass_blocks'], opt_net['partial_training_blocks'],
            opt_net['intermediate_blocks_frozen_until'])
    return DiscAugmentor(disc, opt_net['image_size'],
                         types=opt_net['augmentation_types'],
                         prob=opt_net['augmentation_probability'])
def load_model(conf_path):
    """Parses a config file, builds the model, and loads its generator weights.

    Returns (model, opt).
    """
    opt = option.parse(conf_path, is_train=False)
    opt['gpu_ids'] = None
    opt = option.dict_to_nonedict(opt)
    model = create_model(opt)
    weights_path = opt_get(opt, ['model_path'], None)
    model.load_network(load_path=weights_path, network=model.netG)
    return model, opt
def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
    """Split2d layer setup: diverts a fraction of channels into a latent.

    consume_ratio controls how many channels are gaussianized and split off;
    the remainder passes through to deeper flow levels.
    """
    super().__init__()
    consumed = int(round(num_channels * consume_ratio))
    self.num_channels_consume = consumed
    self.num_channels_pass = num_channels - consumed
    # Zero-initialized conv predicts mean/log-sigma for the consumed channels
    # from the pass-through channels (plus any conditioning channels).
    self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                            out_channels=consumed * 2)
    self.logs_eps = logs_eps
    self.position = position
    self.gaussian_nll_weight = opt_get(
        opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
def __init__(self, opt, env):
    """GAN discriminator-loss configuration.

    Required opt keys: real, fake, discriminator, gen_loss,
    gradient_penalty_frequency. Optional: noise, logistic.
    """
    super().__init__(opt, env)
    self.real = opt['real']
    self.fake = opt['fake']
    self.discriminator = opt['discriminator']
    self.for_gen = opt['gen_loss']
    self.gp_frequency = opt['gradient_penalty_frequency']
    # Consistency fix: opt_get (as used for 'logistic') instead of the manual
    # `'noise' in opt.keys()` membership ternary.
    self.noise = opt_get(opt, ['noise'], 0)
    # Applies a logistic curve to the output logits, which is what the
    # StyleGAN2 authors used.
    self.logistic = opt_get(opt, ['logistic'], False)
def __init__(self, opt):
    """Wraps a dataset with color/blur augmentations and a shared-region crop."""
    super().__init__()
    self.wrapped_dataset = create_dataset(opt['dataset'])
    # Stochastic photometric augmentations applied before cropping.
    aug_stack = (
        RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        augs.RandomGrayscale(p=0.2),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
    )
    self.aug = nn.Sequential(*aug_stack)
    self.rrc = RandomSharedRegionCrop(opt['latent_multiple'],
                                      opt_get(opt, ['jitter_range'], 0))
def __init__(self, opt, env):
    """Injector that samples from a (possibly respaced) gaussian diffusion model."""
    super().__init__(opt, env)
    use_ddim = opt_get(opt, ['use_ddim'], False)
    self.generator = opt['generator']
    self.output_batch_size = opt['output_batch_size']
    self.output_scale_factor = opt['output_scale_factor']
    # When specified, shifts the output of this injector from [-1,1] to [0,1].
    self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False)
    opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
    if use_ddim:
        spacing = "ddim" + str(opt['respaced_timestep_spacing'])
    else:
        spacing = [
            opt_get(opt, ['respaced_timestep_spacing'],
                    opt['beta_schedule']['num_diffusion_timesteps'])
        ]
    opt['diffusion_args']['use_timesteps'] = space_timesteps(
        opt['beta_schedule']['num_diffusion_timesteps'], spacing)
    self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
    # DDIM sampling when requested, otherwise ancestral p-sampling.
    if use_ddim:
        self.sampling_fn = self.diffusion.ddim_sample_loop
    else:
        self.sampling_fn = self.diffusion.p_sample_loop
    self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
    self.use_ema_model = opt_get(opt, ['use_ema'], False)
    self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
    """Draws a sampling latent at temperature `heat` (zero latent when heat == 0
    in the split-enabled path). Returned on the default (CPU) device.
    """
    if seed:
        torch.manual_seed(seed)
    if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
        C = self.netG.module.flowUpsamplerNet.C
        # Spatial size derived from the number of flow levels L.
        scale = opt_get(self.opt, ['network_G', 'flow', 'L'])
        downsample = 2 ** int(scale)
        H = lr_shape[2] // downsample
        W = lr_shape[3] // downsample
        if heat > 0:
            z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W))
        else:
            z = torch.zeros((batch_size, C, H, W))
    else:
        L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
        fac = 2 ** (L - 3)
        z_size = int(self.lr_size // (2 ** (L - 3)))
        z = torch.normal(mean=0, std=heat,
                         size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
    return z
def __init__(self, opt, step):
    """SRFlow model setup: network creation, (optional) distributed wrapping,
    checkpoint loading, and training optimizer/scheduler initialization."""
    super(SRFlowModel, self).__init__(opt)
    self.opt = opt
    self.heats = opt['val']['heats']
    self.n_sample = opt['val']['n_sample']
    # HR patch size; falls back to 160 when not configured.
    self.hr_size = opt_get(opt, ['datasets', 'train', 'center_crop_hr_size'])
    self.hr_size = 160 if self.hr_size is None else self.hr_size
    self.lr_size = self.hr_size // opt['scale']
    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    train_opt = opt['train']
    # define network and load pretrained models
    self.netG = networks.define_Flow(opt, step).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(
            self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    # print network
    self.print_network()
    # The truthy default (1) means: load unless resume_state is explicitly None.
    if opt_get(opt, ['path', 'resume_state'], 1) is not None:
        self.load()
    else:
        print(
            "WARNING: skipping initial loading, due to resume_state None")
    if self.is_train:
        self.netG.train()
        self.init_optimizer_and_scheduler(train_opt)
        self.log_dict = OrderedDict()
def __init__(self, opt, env):
    """Injector computing diffusion training losses (variational bound + x_start)."""
    super().__init__(opt, env)
    self.generator = opt['generator']
    self.output_variational_bounds_key = opt['out_key_vb_loss']
    self.output_x_start_key = opt['out_key_x_start']
    num_steps = opt['beta_schedule']['num_diffusion_timesteps']
    diffusion_args = opt['diffusion_args']
    diffusion_args['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
    # No respacing at training time: keep every diffusion timestep.
    diffusion_args['use_timesteps'] = space_timesteps(num_steps, [num_steps])
    self.diffusion = SpacedDiffusion(**diffusion_args)
    self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'],
                                                          self.diffusion)
    self.model_input_keys = opt_get(opt, ['model_input_keys'], [])