Пример #1
0
    def __init__(self, opt):
        """Wrap a dataset and build a BYOL-style augmentation pipeline.

        Config keys read from `opt`: 'dataset', 'crop_size',
        'key1'/'key2' (default 'hq'/'lq'), 'for_sr', 'normalize'.
        """
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])
        self.cropped_img_size = opt['crop_size']
        self.key1 = opt_get(opt, ['key1'], 'hq')
        self.key2 = opt_get(opt, ['key2'], 'lq')
        for_sr = opt_get(
            opt, ['for_sr'],
            False)  # When set, color alterations and blurs are disabled.

        # (Removed a redundant backslash continuation — it is unnecessary
        # inside brackets.)
        augmentations = [
            augs.RandomHorizontalFlip(),
            augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
        if not for_sr:
            augmentations.extend([
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
                augs.RandomGrayscale(p=0.2),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
            ])
        if opt['normalize']:
            # The paper calls for normalization. Most datasets/models in this repo don't use this.
            # Recommend setting true if you want to train exactly like the paper.
            augmentations.append(
                augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                               std=torch.tensor([0.229, 0.224, 0.225])))
        self.aug = nn.Sequential(*augmentations)
Пример #2
0
 def __init__(self, opt):
     """Build a torchvision-style dataset selected by opt['dataset'].

     Uses a random-resized-crop pipeline when opt['random_crop'] is set,
     otherwise a deterministic resize + center-crop; both end with flip,
     tensor conversion and ImageNet normalization.
     """
     dataset_classes = {
         "mnist": datasets.MNIST,
         "fmnist": datasets.FashionMNIST,
         "cifar10": CIFAR10,
         "cifar100": CIFAR100,
         "imagenet": datasets.ImageNet,
         "imagefolder": datasets.ImageFolder
     }
     # ImageNet normalization constants.
     normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     if opt_get(opt, ['random_crop'], False):
         crop_ops = [T.RandomResizedCrop(opt['image_size'])]
     else:
         crop_ops = [T.Resize(opt['image_size']), T.CenterCrop(opt['image_size'])]
     pipeline = T.Compose(crop_ops + [T.RandomHorizontalFlip(), T.ToTensor(), normalize])
     self.dataset = dataset_classes[opt['dataset']](transform=pipeline, **opt['kwargs'])
     # Optionally expose a fixed length and starting offset over the dataset.
     self.len = opt_get(opt, ['fixed_len'], len(self.dataset))
     self.offset = opt_get(opt, ['offset'], 0)
Пример #3
0
    def forward(self, x, get_steps=False):
        """RRDB trunk forward pass producing the SR output and multi-scale features.

        Args:
            x: input LR image tensor.
            get_steps: when True, return a dict of intermediate features
                (plus any requested RRDB block outputs); otherwise return
                only the final SR tensor.
        """
        fea = self.conv_first(x)

        # Indices of RRDB blocks whose outputs should be captured for stacking.
        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        block_results = {}

        for idx, m in enumerate(self.RRDB_trunk.children()):
            fea = m(fea)
            # Membership test replaces the old inner scan over block_idxs.
            if idx in block_idxs:
                block_results["block_{}".format(idx)] = fea

        trunk = self.trunk_conv(fea)

        # Global residual connection over the whole RRDB trunk.
        last_lr_fea = fea + trunk

        fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up2)

        fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up4)

        # Higher upsampling stages only exist for larger scale factors.
        fea_up8 = None
        fea_up16 = None
        fea_up32 = None

        if self.scale >= 8:
            fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up8)
        if self.scale >= 16:
            fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up16)
        if self.scale >= 32:
            fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up32)

        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        results = {'last_lr_fea': last_lr_fea,
                   'fea_up1': last_lr_fea,
                   'fea_up2': fea_up2,
                   'fea_up4': fea_up4,
                   'fea_up8': fea_up8,
                   'fea_up16': fea_up16,
                   'fea_up32': fea_up32,
                   'out': out}

        # Optional downsampled variants of the LR feature map.
        fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
        if fea_up0_en:
            results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
        fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
        if fea_upn1_en:
            results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)

        if get_steps:
            for k, v in block_results.items():
                results[k] = v
            return results
        else:
            return out
Пример #4
0
    def rrdbPreprocessing(self, lr):
        """Run the RRDB encoder on the LR image and optionally stack selected
        intermediate block outputs onto every conditioning feature map.

        Returns the dict produced by self.RRDB(lr, get_steps=True), possibly
        with extra channels concatenated onto each feature entry.
        """
        rrdbResults = self.RRDB(lr, get_steps=True)
        # Which intermediate RRDB blocks to stack; empty list disables stacking.
        block_idxs = opt_get(
            self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        if len(block_idxs) > 0:
            # Channel-concatenate the requested intermediate block outputs.
            concat = torch.cat(
                [rrdbResults["block_{}".format(idx)] for idx in block_idxs],
                dim=1)

            if opt_get(self.opt,
                       ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
                keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
                if 'fea_up0' in rrdbResults.keys():
                    keys.append('fea_up0')
                if 'fea_up-1' in rrdbResults.keys():
                    keys.append('fea_up-1')
                if 'fea_up-2' in rrdbResults.keys():
                    keys.append('fea_up-2')
                if self.opt['scale'] >= 8:
                    keys.append('fea_up8')
                # NOTE(review): 'fea_up16' is gated on scale == 16 while
                # 'fea_up8' uses >= 8 — confirm the asymmetry is intended.
                if self.opt['scale'] == 16:
                    keys.append('fea_up16')
                for k in keys:
                    h = rrdbResults[k].shape[2]
                    w = rrdbResults[k].shape[3]
                    # Resize the stacked blocks to this feature's resolution and
                    # append along the channel dimension.
                    rrdbResults[k] = torch.cat(
                        [rrdbResults[k],
                         F.interpolate(concat, (h, w))], dim=1)
        return rrdbResults
    def __init__(self, in_channels, opt):
        """Conditional affine coupling conditioned on stacked RRDB features.

        Args:
            in_channels: number of flow channels entering the coupling.
            opt: global config; reads hidden_channels and eps from
                network_G.flow.CondAffineSeparatedAndCond.
        """
        super().__init__()
        self.need_features = True
        self.in_channels = in_channels
        self.in_channels_rrdb = 320  # channel count of the RRDB conditioning features
        self.kernel_hidden = 1
        self.n_hidden_layers = 1
        hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
        self.hidden_channels = 64 if hidden_channels is None else hidden_channels

        # Single authoritative assignment (a redundant 0.0001 pre-assignment
        # was removed; this opt_get already defaults to 0.0001).
        self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'],  0.0001)

        # First half of the channels drives the affine network; the remainder
        # is the part that gets transformed.
        self.channels_for_nn = self.in_channels // 2
        self.channels_for_co = self.in_channels - self.channels_for_nn
        # (Removed an unreachable `if self.channels_for_nn is None` re-check:
        # integer division never yields None.)

        self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                              out_channels=self.channels_for_co * 2,
                              hidden_channels=self.hidden_channels,
                              kernel_hidden=self.kernel_hidden,
                              n_hidden_layers=self.n_hidden_layers)

        self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                out_channels=self.in_channels * 2,
                                hidden_channels=self.hidden_channels,
                                kernel_hidden=self.kernel_hidden,
                                n_hidden_layers=self.n_hidden_layers)
Пример #6
0
    def __init__(self,
                 in_nc,
                 out_nc,
                 nf,
                 nb,
                 gc=32,
                 scale=4,
                 K=None,
                 opt=None,
                 step=None):
        """SRFlow network: RRDB LR encoder plus a normalizing-flow upsampler.

        Args mirror RRDBNet (in_nc/out_nc/nf/nb/gc/scale); K is the number of
        flow steps; opt is the global config dict; step is accepted for
        interface compatibility.
        """
        super(SRFlowNet, self).__init__()

        self.opt = opt
        # Quantization level of input images; defaults to 8-bit (255).
        quant = opt_get(opt, ['datasets', 'train', 'quant'])
        self.quant = 255 if quant is None else quant
        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
        hidden_channels = opt_get(opt,
                                  ['network_G', 'flow', 'hidden_channels'])
        hidden_channels = hidden_channels or 64
        self.RRDB_training = True  # Default is true

        # NOTE(review): RRDB training is toggled later by the training loop via
        # 'train_RRDB_delay'; the dead `set_RRDB_to_train = False` block and the
        # unused train_RRDB_delay lookup that used to live here were removed.
        self.hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'])
        self.flowUpsamplerNet = \
            FlowUpsamplerNet((self.hr_size, self.hr_size, in_nc), hidden_channels, K,
                             flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
        self.i = 0
Пример #7
0
    def get_z(self,
              heat,
              seed=None,
              batch_size=1,
              lr_shape=None,
              y_label=None):
        """Sample a latent tensor z for the reverse flow pass.

        Args:
            heat: sampling temperature; 0 gives an all-zeros latent.
            seed: optional RNG seed. NOTE(review): seed=0 is falsy and is
                ignored — confirm that is intended.
            batch_size: number of latents to draw.
            lr_shape: shape of the LR input (used in the split-enabled path).
            y_label: unused; kept for interface compatibility.
        """
        # (Removed a dead `if y_label is None: pass` statement.)
        if seed: torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)

            size = (batch_size, C, H, W)
            if heat == 0:
                z = torch.zeros(size)
            else:
                z = torch.normal(mean=0, std=heat, size=size)
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2 ** (L - 3)
            # Reuse fac rather than recomputing 2**(L-3).
            z_size = int(self.lr_size // fac)
            z = torch.normal(mean=0,
                             std=heat,
                             size=(batch_size, 3 * 8 * 8 * fac * fac, z_size,
                                   z_size))
        return z.to(self.device)
Пример #8
0
    def get_random_z(self,
                     heat,
                     seed=None,
                     batch_size=1,
                     lr_shape=None,
                     device='cuda'):
        """Sample a latent tensor for reverse-flow image generation.

        Args:
            heat: sampling temperature; 0 gives an all-zeros latent.
            seed: optional RNG seed. NOTE(review): seed=0 is falsy and is
                ignored — confirm that is intended.
            batch_size: number of latents to draw.
            lr_shape: (H, W) of the LR input (used in the split-enabled path).
            device: device to move the latent to.
        """
        if seed: torch.manual_seed(seed)
        if opt_get(self.opt,
                   ['networks', 'generator', 'flow', 'split', 'enable']):
            C = self.flowUpsamplerNet.C
            # The flow may operate at a different scale than the RRDB; the
            # flow_scale/RRDB.scale ratio corrects for that.
            H = int(self.flow_scale * lr_shape[0] //
                    (self.flowUpsamplerNet.scaleH * self.flow_scale /
                     self.RRDB.scale))
            W = int(self.flow_scale * lr_shape[1] //
                    (self.flowUpsamplerNet.scaleW * self.flow_scale /
                     self.RRDB.scale))

            size = (batch_size, C, H, W)
            if heat == 0:
                z = torch.zeros(size)
            else:
                z = torch.normal(mean=0, std=heat, size=size)
        else:
            L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3
            fac = 2 ** (L - 3)
            # Reuse fac rather than recomputing 2**(L-3).
            z_size = int(self.lr_size // fac)
            z = torch.normal(mean=0,
                             std=heat,
                             size=(batch_size, 3 * 8 * 8 * fac * fac, z_size,
                                   z_size))
        return z.to(device)
Пример #9
0
def forward_pass(model, data, output_dir, opt):
    """Run the model on one batch, save SR outputs, and accumulate PSNR.

    Returns (fea_loss, psnr_loss); fea_loss is currently always 0.
    NOTE(review): relies on a module-level `need_GT` flag — confirm the
    calling script defines it before this is invoked.
    """
    alteration_suffix = util.opt_get(opt, ['name'], '')
    denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
    with torch.no_grad():
        model.feed_data(data, 0, need_GT=need_GT)
        model.test()

    visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
    # Undo normalization so outputs land in [0, 1] before image conversion.
    visuals = (visuals - denorm_range[0]) / (denorm_range[1]-denorm_range[0])
    fea_loss = 0
    psnr_loss = 0
    for i in range(visuals.shape[0]):
        img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i]
        img_name = osp.splitext(osp.basename(img_path))[0]

        sr_img = util.tensor2img(visuals[i])  # uint8

        # save images
        suffix = alteration_suffix
        if suffix:
            save_img_path = osp.join(output_dir, img_name + suffix + '.png')
        else:
            save_img_path = osp.join(output_dir, img_name + '.png')

        if need_GT:
            # Reuse sr_img instead of converting the same tensor twice.
            psnr_gt = util.tensor2img(data['hq'][i])
            psnr_loss += util.calculate_psnr(sr_img, psnr_gt)

        util.save_img(sr_img, save_img_path)
    return fea_loss, psnr_loss
Пример #10
0
 def __init__(self, opt, path):
     """Index the image tiles found under `path.path`.

     Honors 'strict'/'needs_metadata' (metadata requirement), 'need_ref',
     and 'ignore_first' (drops the first N tiles) from opt.
     """
     self.path = path.path
     self.tiles, _ = util.get_image_paths('img', self.path)
     # Metadata is required when either flag is set.
     self.need_metadata = (opt_get(opt, ['strict'], False)
                           or opt_get(opt, ['needs_metadata'], False))
     self.need_ref = opt_get(opt, ['need_ref'], False)
     if 'ignore_first' in opt:
         self.tiles = self.tiles[opt['ignore_first']:]
    def __init__(self, in_channels, opt):
        """Conditional affine coupling for the SRFlow-DA model family.

        Depth and the single/triple-network layout depend on opt['model']:
        'SRFlow-DA-D' uses 36 hidden layers, others use 4; 'SRFlow-DA' and
        'SRFlow-DA-S' build one affine/feature network pair, other variants
        expect self.F to return three networks each.
        """
        super().__init__()
        self.need_features = True
        self.in_channels = in_channels
        self.in_channels_rrdb = 320  # channel count of the RRDB conditioning features
        self.kernel_hidden = 3
        self.model_name = opt_get(opt, ['model'])
        # The -D variant is much deeper.
        if self.model_name == "SRFlow-DA-D":
            self.n_hidden_layers = 36
        else:
            self.n_hidden_layers = 4
        hidden_channels = opt_get(opt, [
            'network_G', 'flow', 'CondAffineSeparatedAndCond',
            'hidden_channels'
        ])
        self.hidden_channels = 64 if hidden_channels is None else hidden_channels

        # Single authoritative assignment (a redundant 0.0001 pre-assignment
        # was removed; this opt_get already defaults to 0.0001).
        self.affine_eps = opt_get(
            opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'],
            0.0001)

        # First half of the channels drives the affine network; the remainder
        # is the part that gets transformed.
        self.channels_for_nn = self.in_channels // 2
        self.channels_for_co = self.in_channels - self.channels_for_nn
        # (Removed an unreachable `if self.channels_for_nn is None` re-check:
        # integer division never yields None.)

        if self.model_name == "SRFlow-DA" or self.model_name == "SRFlow-DA-S":
            self.fAffine = self.F(in_channels=self.channels_for_nn +
                                  self.in_channels_rrdb,
                                  out_channels=self.channels_for_co * 2,
                                  hidden_channels=self.hidden_channels,
                                  kernel_hidden=self.kernel_hidden,
                                  n_hidden_layers=self.n_hidden_layers)

            self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                    out_channels=self.in_channels * 2,
                                    hidden_channels=self.hidden_channels,
                                    kernel_hidden=self.kernel_hidden,
                                    n_hidden_layers=self.n_hidden_layers)

        else:
            # These variants expect self.F to return three networks.
            self.fAffine1, self.fAffine2, self.fAffine3 = self.F(
                in_channels=self.channels_for_nn + self.in_channels_rrdb,
                out_channels=self.channels_for_co * 2,
                hidden_channels=self.hidden_channels,
                kernel_hidden=self.kernel_hidden,
                n_hidden_layers=self.n_hidden_layers)

            self.fFeatures1, self.fFeatures2, self.fFeatures3 = self.F(
                in_channels=self.in_channels_rrdb,
                out_channels=self.in_channels * 2,
                hidden_channels=self.hidden_channels,
                kernel_hidden=self.kernel_hidden,
                n_hidden_layers=self.n_hidden_layers)
Пример #12
0
 def fix_image(img):
     """Prepare an image batch for the logger.

     Optionally adds a channel dim + whitens mel spectrograms, truncates to
     3 channels, and undoes [-1,1] or ImageNet normalization per config.
     """
     if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
         img = img.unsqueeze(dim=1)
         # Standardize so the spectrogram is easier to view.
         img = (img - img.mean()) / img.std()
     # Loggers can only display up to 3 channels.
     img = img[:, :3, :, :] if img.shape[1] > 3 else img
     if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
         img = (img + 1) / 2
     if opt_get(self.opt, ['logger', 'reverse_imagenet_norm'], False):
         img = denormalize(img)
     return img
Пример #13
0
    def optimize_parameters(self, step):
        """Run one optimization step (NLL flow loss + optional L1 loss).

        Returns the scalar total loss. Commented-out debug code was removed.
        """
        # Enable RRDB training once the configured warm-up fraction elapses.
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and step > int(
                train_RRDB_delay * self.opt['train']['niter']):
            # NOTE(review): set_rrdb_training(True) presumably returns True only
            # on the first state change so the optimizer is added once — the
            # sibling implementation also guards on `not RRDB_training`; confirm.
            if self.netG.module.set_rrdb_training(True):
                self.add_optimizer_and_scheduler_RRDB(self.opt['train'])
                print("set RRDB trainable")

        # add GT noise
        add_gt_noise = opt_get(self.opt, ['train', 'add_gt_noise'], True)

        self.netG.train()
        self.log_dict = OrderedDict()
        self.optimizer_G.zero_grad()

        losses = {}
        # Negative log-likelihood (flow) loss; weight defaults to 1.
        weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
        weight_fl = 1 if weight_fl is None else weight_fl
        if weight_fl > 0:
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False,
                                         add_gt_noise=add_gt_noise)
            nll_loss = torch.mean(nll)
            losses['nll_loss'] = nll_loss * weight_fl

        # Optional L1 reconstruction loss on a zero-heat (deterministic) sample.
        weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
        if weight_l1 > 0:
            z = self.get_z(heat=0,
                           seed=None,
                           batch_size=self.var_L.shape[0],
                           lr_shape=self.var_L.shape)
            sr, logdet = self.netG(lr=self.var_L,
                                   z=z,
                                   eps_std=0,
                                   reverse=True,
                                   reverse_with_grad=True)
            l1_loss = (sr - self.real_H).abs().mean()
            losses['l1_loss'] = l1_loss * weight_l1

        total_loss = sum(losses.values())
        total_loss.backward()
        self.optimizer_G.step()

        return total_loss.item()
Пример #14
0
    def optimize_parameters(self, step):
        """Run one optimization step (NLL flow loss + optional L1 loss).

        Returns the scalar total loss. Stale commented-out debug prints were
        removed; a NaN total_loss observed here historically surfaced as
        "svd_cuda: ... did not converge" during backward.
        """
        # Enable RRDB training once the configured warm-up fraction elapses;
        # the RRDB_training guard makes this a one-shot transition.
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
                and not self.netG.module.RRDB_training:
            if self.netG.module.set_rrdb_training(True):
                self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        self.netG.train()
        self.log_dict = OrderedDict()
        self.optimizer_G.zero_grad()

        losses = {}
        # Negative log-likelihood (flow) loss; weight defaults to 1.
        weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
        weight_fl = 1 if weight_fl is None else weight_fl
        if weight_fl > 0:
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False)
            nll_loss = torch.mean(nll)
            losses['nll_loss'] = nll_loss * weight_fl

        # Optional L1 reconstruction loss on a zero-heat (deterministic) sample.
        weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
        if weight_l1 > 0:
            z = self.get_z(heat=0,
                           seed=None,
                           batch_size=self.var_L.shape[0],
                           lr_shape=self.var_L.shape)
            sr, logdet = self.netG(lr=self.var_L,
                                   z=z,
                                   eps_std=0,
                                   reverse=True,
                                   reverse_with_grad=True)
            l1_loss = (sr - self.real_H).abs().mean()
            losses['l1_loss'] = l1_loss * weight_l1

        total_loss = sum(losses.values())
        total_loss.backward()
        self.optimizer_G.step()

        return total_loss.item()
Пример #15
0
def register_stylegan2_discriminator(opt_net, opt):
    """Factory: build a StyleGAN2 discriminator wrapped in its augmentor."""
    # dict.get covers both the missing and present-key cases in one call.
    attn = opt_net.get('attn_layers', [])
    disc = StyleGan2Discriminator(image_size=opt_net['image_size'],
                                  input_filters=opt_net['in_nc'],
                                  attn_layers=attn,
                                  do_checkpointing=opt_get(
                                      opt_net, ['do_checkpointing'], False),
                                  quantize=opt_get(opt_net, ['quantize'],
                                                   False))
    return StyleGan2Augmentor(disc,
                              opt_net['image_size'],
                              types=opt_net['augmentation_types'],
                              prob=opt_net['augmentation_probability'])
Пример #16
0
 def __init__(self, model, opt_eval, env):
     """Configure an FID-style evaluator from the opt_eval dict."""
     super().__init__(model, opt_eval, env, uses_all_ddp=False)
     self.batches_per_eval = opt_eval['batches_per_eval']
     self.batch_sz = opt_eval['batch_size']
     self.im_sz = opt_eval['image_size']
     self.fid_real_samples = opt_eval['real_fid_path']
     # Which generator output tensor to evaluate (default: the first).
     self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval else 0
     self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
     # Only consulted when the noise type is not 'imgnoise'.
     self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512)
     self.image_norm_range = tuple(opt_get(env['opt'], ['image_normalization_range'], [0, 1]))
Пример #17
0
    def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
        """Forward (image -> latent) flow pass; returns (z-or-epses, nll, logdet).

        nll is the bits-per-pixel negative log-likelihood of gt given lr.
        """
        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        # Per-sample log-determinant accumulator (one scalar per batch item).
        logdet = torch.zeros_like(gt[:, 0, 0, 0])
        pixels = thops.pixels(gt)

        z = gt

        if add_gt_noise:
            # Setup
            noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
            if noiseQuant:
                # Dequantization noise: uniform in +/- half a quantization step.
                z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
            # Quantization contributes -log(quant) per pixel to the density.
            logdet = logdet + float(-np.log(self.quant) * pixels)

        # Encode
        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
                                              y_onehot=y_onehot)

        objective = logdet.clone()

        # The flow may return a list of intermediate epses; the final latent is last.
        if isinstance(epses, (list, tuple)):
            z = epses[-1]
        else:
            z = epses

        # Standard-normal prior log-probability of the latent.
        objective = objective + flow.GaussianDiag.logp(None, None, z)

        # Convert to bits per pixel.
        nll = (-objective) / float(np.log(2.) * pixels)

        if isinstance(epses, list):
            return epses, nll, logdet
        return z, nll, logdet
Пример #18
0
    def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
        """Reverse (latent -> image) flow pass; returns (sr, logdet).

        Walks the flow layers in reverse order, feeding each the RRDB
        conditioning features for its resolution level.
        """
        # When epses is a list the final latent was stored last during encoding.
        z = epses.pop() if isinstance(epses, list) else z

        fl_fea = z
        # debug.imwrite("fl_fea", fl_fea)
        bypasses = {}
        level_conditionals = {}
        # Precompute per-level conditioning unless levelConditional mode is on.
        if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
            for level in range(self.L + 1):
                level_conditionals[level] = rrdbResults[self.levelToName[level]]

        for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
            # Derive the pyramid level from the layer's output resolution.
            size = shape[2]
            level = int(np.log(self.hr_size / size) / np.log(2))
            # size = fl_fea.shape[2]
            # level = int(np.log(160 / size) / np.log(2))

            if isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                              rrdbResults[self.levelToName[level]], logdet=logdet,
                                                              y_onehot=y_onehot)
            elif isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
            else:
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)
            
        # print(rrdbResults['fea_up1'].shape)
        sr = fl_fea 
        # filters = torch.randn(1,320,3,3, requires_grad=True).cuda()
        # sr = fl_fea + F.conv2d(rrdbResults['fea_up1'], filters, padding=1)

        # NOTE(review): asserts a single-channel output — this variant appears
        # specialized for 1-channel images; confirm against the model config.
        assert sr.shape[1] == 1
        return sr, logdet
Пример #19
0
def create_dataloader(dataset,
                      dataset_opt,
                      opt=None,
                      sampler=None,
                      collate_fn=None):
    """Build a torch DataLoader for the given dataset.

    Training phase: shuffled, multi-worker, drop_last; under distributed
    training the batch size is divided across the world and shuffling is
    delegated to the sampler. Any other phase: serial, deterministic loading.
    """
    if dataset_opt['phase'] != 'train':
        # Validation/test: single-process, unshuffled.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=dataset_opt['batch_size'] or 1,
                                           shuffle=False,
                                           num_workers=0,
                                           pin_memory=True,
                                           collate_fn=collate_fn)

    num_workers = dataset_opt['n_workers']
    batch_size = dataset_opt['batch_size']
    shuffle = True
    if opt_get(opt, ['dist'], False):
        world_size = torch.distributed.get_world_size()
        # The global batch must divide evenly across ranks.
        assert batch_size % world_size == 0
        batch_size = batch_size // world_size
        shuffle = False
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=num_workers,
                                       sampler=sampler,
                                       drop_last=True,
                                       pin_memory=True,
                                       collate_fn=collate_fn)
Пример #20
0
def register_u_resnet50_2(opt_net, opt):
    """Factory: build a UResNet50_2, optionally seeded with ImageNet weights."""
    model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
    if not opt_get(opt_net, ['use_pretrained_base'], False):
        return model
    # Partially load torchvision's ResNet-50 weights into the base.
    pretrained = load_state_dict_from_url(
        'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        progress=True)
    model.load_state_dict(pretrained, strict=False)
    return model
Пример #21
0
    def arch_split(self, H, W, L, levels, opt, opt_get):
        """Optionally append a Split2d layer at flow level L.

        NOTE(review): the `opt_get` parameter shadows the module-level helper
        of the same name — callers must pass it in explicitly.
        """
        correct_splits = opt_get(
            opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
        correction = 0 if correct_splits else 1
        if opt_get(opt, ['network_G', 'flow', 'split', 'enable'
                         ]) and L < levels - correction:
            logs_eps = opt_get(opt,
                               ['network_G', 'flow', 'split', 'logs_eps']) or 0
            consume_ratio = opt_get(
                opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
            position_name = get_position_name(H, self.opt['scale'])
            # Conditional splits receive a position identifier; others get None.
            position = position_name if opt_get(
                opt, ['network_G', 'flow', 'split', 'conditional']) else None
            cond_channels = opt_get(
                opt, ['network_G', 'flow', 'split', 'cond_channels'])
            cond_channels = 0 if cond_channels is None else cond_channels

            t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')

            # NOTE(review): any type other than 'Split2d' leaves `split`
            # unbound and the append below raises NameError — confirm
            # 'Split2d' is the only supported value.
            if t == 'Split2d':
                split = models.modules.Split.Split2d(
                    num_channels=self.C,
                    logs_eps=logs_eps,
                    position=position,
                    cond_channels=cond_channels,
                    consume_ratio=consume_ratio,
                    opt=opt)
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            # Only the pass-through channels continue up the flow.
            self.C = split.num_channels_pass
Пример #22
0
def register_styledsr_discriminator(opt_net, opt):
    """Factory: build a StyleSR discriminator wrapped in its augmentor."""
    attention_layers = opt_net['attn_layers'] if 'attn_layers' in opt_net else []
    disc = StyleSrDiscriminator(
        image_size=opt_net['image_size'],
        input_filters=opt_net['in_nc'],
        attn_layers=attention_layers,
        do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
        quantize=opt_get(opt_net, ['quantize'], False),
        mlp=opt_get(opt_net, ['mlp_head'], True),
        transfer_mode=opt_get(opt_net, ['transfer_mode'], False))
    # Optionally freeze/bypass parts of the network for transfer learning.
    if 'use_partial_pretrained' in opt_net:
        disc.configure_partial_training(
            opt_net['bypass_blocks'],
            opt_net['partial_training_blocks'],
            opt_net['intermediate_blocks_frozen_until'])
    return DiscAugmentor(disc,
                         opt_net['image_size'],
                         types=opt_net['augmentation_types'],
                         prob=opt_net['augmentation_probability'])
Пример #23
0
def load_model(conf_path):
    """Parse an inference config, build the model, and load its weights.

    Returns (model, opt).
    """
    opt = option.parse(conf_path, is_train=False)
    opt['gpu_ids'] = None  # inference config: no explicit GPU pinning
    opt = option.dict_to_nonedict(opt)
    model = create_model(opt)
    model.load_network(load_path=opt_get(opt, ['model_path'], None),
                       network=model.netG)
    return model, opt
Пример #24
0
    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
        """Split layer: passes part of the channels onward and factors the rest
        out against a learned Gaussian.
        """
        super().__init__()

        # How many channels this split consumes vs. passes through.
        consumed = int(round(num_channels * consume_ratio))
        self.num_channels_consume = consumed
        self.num_channels_pass = num_channels - consumed

        # Predicts mean/log-sigma for the consumed channels from the passed
        # channels plus any conditioning channels.
        self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                                out_channels=consumed * 2)
        self.logs_eps = logs_eps
        self.position = position
        self.gaussian_nll_weight = opt_get(opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
Пример #25
0
 def __init__(self, opt, env):
     """Configure a discriminator loss with gradient penalty from opt."""
     super().__init__(opt, env)
     self.real = opt['real']
     self.fake = opt['fake']
     self.discriminator = opt['discriminator']
     self.for_gen = opt['gen_loss']
     self.gp_frequency = opt['gradient_penalty_frequency']
     self.noise = opt.get('noise', 0)
     # Applies a logistic curve to the output logits, as in StyleGAN2.
     self.logistic = opt_get(opt, ['logistic'], False)
Пример #26
0
 def __init__(self, opt):
     """Wrap a dataset with BYOL-style photometric augmentations and a
     shared-region random crop.
     """
     super().__init__()
     self.wrapped_dataset = create_dataset(opt['dataset'])
     # Photometric augmentations applied before the shared crop.
     self.aug = nn.Sequential(
         RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
         augs.RandomGrayscale(p=0.2),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1))
     self.rrc = RandomSharedRegionCrop(opt['latent_multiple'],
                                       opt_get(opt, ['jitter_range'], 0))
 def __init__(self, opt, env):
     """Injector that samples images from a (possibly respaced) diffusion model.

     Note: mutates opt['diffusion_args'] in place to install betas and the
     respaced timestep set before constructing SpacedDiffusion.
     """
     super().__init__(opt, env)
     use_ddim = opt_get(opt, ['use_ddim'], False)
     self.generator = opt['generator']
     self.output_batch_size = opt['output_batch_size']
     self.output_scale_factor = opt['output_scale_factor']
     self.undo_n1_to_1 = opt_get(
         opt, ['undo_n1_to_1'], False
     )  # Explanation: when specified, will shift the output of this injector from [-1,1] to [0,1]
     opt['diffusion_args']['betas'] = get_named_beta_schedule(
         **opt['beta_schedule'])
     # DDIM uses the "ddimN" string spacing form; otherwise a single-section
     # list defaulting to the full number of diffusion timesteps.
     if use_ddim:
         spacing = "ddim" + str(opt['respaced_timestep_spacing'])
     else:
         spacing = [
             opt_get(opt, ['respaced_timestep_spacing'],
                     opt['beta_schedule']['num_diffusion_timesteps'])
         ]
     opt['diffusion_args']['use_timesteps'] = space_timesteps(
         opt['beta_schedule']['num_diffusion_timesteps'], spacing)
     self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
     # Sampling entry point depends on the DDIM switch.
     self.sampling_fn = self.diffusion.ddim_sample_loop if use_ddim else self.diffusion.p_sample_loop
     self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
     self.use_ema_model = opt_get(opt, ['use_ema'], False)
     self.noise_style = opt_get(opt, ['noise_type'],
                                'random')  # 'zero', 'fixed' or 'random'
Пример #28
0
 def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
     """Sample a latent tensor z for the reverse flow pass.

     heat is the sampling temperature; heat == 0 yields an all-zeros latent
     in the split-enabled path.
     """
     if seed:
         torch.manual_seed(seed)
     if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
         C = self.netG.module.flowUpsamplerNet.C
         # Latent resolution is the LR resolution reduced by 2**L.
         scale = opt_get(self.opt, ['network_G', 'flow', 'L'])
         factor = 2 ** int(scale)
         H = lr_shape[2] // factor
         W = lr_shape[3] // factor
         if heat > 0:
             z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W))
         else:
             z = torch.zeros((batch_size, C, H, W))
     else:
         L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
         fac = 2 ** (L - 3)
         z_size = int(self.lr_size // (2 ** (L - 3)))
         z = torch.normal(mean=0,
                          std=heat,
                          size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
     return z
Пример #29
0
    def __init__(self, opt, step):
        """Set up the SRFlow model: network, (D)DP wrapping, weight loading,
        and — in training mode — optimizer/scheduler state.
        """
        super(SRFlowModel, self).__init__(opt)
        self.opt = opt

        self.heats = opt['val']['heats']
        self.n_sample = opt['val']['n_sample']
        # HR training crop; defaults to 160 when not configured.
        self.hr_size = opt_get(opt,
                               ['datasets', 'train', 'center_crop_hr_size'])
        self.hr_size = 160 if self.hr_size is None else self.hr_size
        self.lr_size = self.hr_size // opt['scale']

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_Flow(opt, step).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()

        # The default of 1 means "load" unless resume_state is explicitly
        # set to None in the config (opt_get only falls back on a missing key).
        if opt_get(opt, ['path', 'resume_state'], 1) is not None:
            self.load()
        else:
            print(
                "WARNING: skipping initial loading, due to resume_state None")

        if self.is_train:
            self.netG.train()

            self.init_optimizer_and_scheduler(train_opt)
            self.log_dict = OrderedDict()
 def __init__(self, opt, env):
     """Configure a diffusion training loss using the full timestep schedule.

     Note: mutates opt['diffusion_args'] in place before constructing
     SpacedDiffusion.
     """
     super().__init__(opt, env)
     self.generator = opt['generator']
     self.output_variational_bounds_key = opt['out_key_vb_loss']
     self.output_x_start_key = opt['out_key_x_start']
     num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
     opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
     # Training uses every timestep (a single section spanning the schedule).
     opt['diffusion_args']['use_timesteps'] = space_timesteps(num_timesteps, [num_timesteps])
     self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
     self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
     self.model_input_keys = opt_get(opt, ['model_input_keys'], [])