Code example #1
    def __init__(self, config_en, is_testing=False):
        super(EnhancementBlueprint, self).__init__()

        self.net = EnhancementNetwork(config_en)
        self.net = self.net.to(pe.DEVICE)
        self.config_en = config_en

        self.clf = None
        self.qstrategy = QStrategy.MIN

        self.losses = EnhancementLosses(config_en, is_testing)

        global_config.assert_only_one('cinorm', 'cin_eb', 'cgdn')

        self.cin_style = None
        if global_config.get('cinorm', False):
            self.cin_style = 'cinorm'
        elif global_config.get('cin_eb', False):
            self.cin_style = 'evenbins'
            self.cin_q = cin_bins.Quantizer(global_config['cin_eb'])
        elif global_config.get('cgdn', False):
            self.cin_style = 'cgdn'
            self.cin_q = cin_bins.Quantizer(global_config['cgdn'])
        print('EB: self.cin_style =', self.cin_style)

        self.padding_fac = self.get_padding_fac()
        print('***' * 10)
        print('*** Padding by a factor', self.padding_fac)
        print('***' * 10)
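A minimal, self-contained sketch of the mutually-exclusive flag pattern used above; the plain dict cfg and its values are hypothetical stand-ins for global_config, whose assert_only_one helper is not reproduced here:

cfg = {'cin_eb': 'auto10'}  # hypothetical flag values

# At most one of the three conditioning flags may be set.
set_flags = [k for k in ('cinorm', 'cin_eb', 'cgdn') if cfg.get(k)]
assert len(set_flags) <= 1, f'conflicting flags: {set_flags}'

styles = {'cinorm': 'cinorm', 'cin_eb': 'evenbins', 'cgdn': 'cgdn'}
cin_style = styles[set_flags[0]] if set_flags else None
print('cin_style =', cin_style)  # -> evenbins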
Code example #2
    def __init__(self, config_ms, scale, C=3, atrous_rates_str='1,2,4'):
        # NOTE: this classifier variant is disabled; the raise below makes the rest of __init__ unreachable.
        raise NotImplementedError

        super(AtrousProbabilityClassifier, self).__init__()

        K = config_ms.prob.K
        Kp = non_shared_get_Kp(K, C)

        self.atrous = StackedAtrousConvs(atrous_rates_str, config_ms.Cf, Kp,
                                         kernel_size=config_ms.kernel_size,
                                         name=str(scale))
        self._repr = f'C={C}; K={K}; Kp={Kp}; rates={atrous_rates_str}'

        if global_config.get('usefinal1', False):
            print('*** Using Final')
            self.final = Final(config_ms, C, 1)
        elif global_config.get('usefinal', False):
            print('*** Using Final')
            self.final = Final(config_ms, C, 3)
        else:
            self.final = lambda x: x

        if global_config.get('initbias', False):
            K = config_ms.prob.K
            self.atrous.lin.bias = nn.Parameter(_init_bias(self.atrous.lin.bias.detach(), C, K))
            print(self.atrous.lin.bias.requires_grad)
            print('Updated bias:', self.atrous.lin.bias.reshape(-1, C, K))

            if global_config.get('usefinal1', False):
                self.final.scales_conv.bias = nn.Parameter(_init_bias(self.final.scales_conv.bias.detach(), C, K))
Code example #3
def read_evenly_spaced_bins(config_dl):
    flag = global_config.get('cin_eb', None) or global_config.get('cgdn', None)
    if flag and flag.startswith('auto'):
        flag_name = 'cin_eb' if global_config.get('cin_eb', None) else 'cgdn'
        nb = int(flag.replace('auto', ''))
        # creates the bin pickle if needed
        pkl_p = cin_bins.make_bin_pkl(config_dl.imgs_dir_train['compressed'], nb)
        print(f'Setting {flag_name} = {pkl_p}')
        global_config[flag_name] = pkl_p
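In isolation, the 'auto' prefix handling above reduces to the following sketch; the flag value 'auto10' is made up, and the repo-specific make_bin_pkl step is omitted:

flag = 'auto10'  # hypothetical value of cin_eb / cgdn
if flag.startswith('auto'):
    nb = int(flag.replace('auto', ''))  # number of evenly spaced bins
    print('bins:', nb)                  # -> bins: 10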
Code example #4
    def __init__(self,
                 optims,
                 initial,
                 decay_fac,
                 decay_interval_itr=None,
                 decay_interval_epoch=None,
                 epoch_len=None,
                 warm_restart=None,
                 warm_restart_schedule=None):
        super(ExponentialDecayLRSchedule, self).__init__(optims)
        assert_exc((decay_interval_itr is not None) ^
                   (decay_interval_epoch is not None),
                   'Need either iter or epoch')
        if decay_interval_epoch:
            assert epoch_len is not None
            decay_interval_itr = int(decay_interval_epoch * epoch_len)
            if warm_restart:
                warm_restart = int(warm_restart * epoch_len)
        self.initial = initial
        self.decay_fac = decay_fac
        self.decay_every_itr = decay_interval_itr

        self.warm_restart_itr = warm_restart
        self.warm_restart_schedule = warm_restart_schedule

        self.last_warm_restart = 0

        self.exp_min = global_config.get('exp_min', None)
        if self.exp_min:
            print('*** Has minimal LR =', self.exp_min)
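For orientation, the step decay implied by the fields above can be written as a small helper; this is only a sketch under the assumption that the schedule multiplies by decay_fac every decay_every_itr iterations and clamps at exp_min when that key is set (the class's actual update method is not part of the example):

def lr_at(itr, initial=1e-4, decay_fac=0.75, decay_every_itr=1000, exp_min=None):
    lr = initial * decay_fac ** (itr // decay_every_itr)
    return max(lr, exp_min) if exp_min is not None else lr

print(lr_at(0), lr_at(2500), lr_at(20000, exp_min=1e-6))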
Code example #5
    def get_log_dir(log_dir_root,
                    rel_paths,
                    restorer,
                    strip_ext='.cf',
                    global_config_values=None):
        if not restorer or not restorer.restore_continue:
            log_dir = logdir_helpers.create_unique_log_dir(
                rel_paths,
                log_dir_root,
                strip_ext=strip_ext,
                postfix=global_config_values)
            print('Created {}...'.format(log_dir))
            return log_dir

        previous_log_dir = restorer.get_log_dir()

        if restorer.restore_continue:
            theoretical_log_dir = logdir_helpers.get_log_dir_name(
                rel_paths, strip_ext=strip_ext, postfix=global_config_values)
            previous_log_dir_name = logdir_helpers.log_name_from_log_dir(
                previous_log_dir)
            if theoretical_log_dir != previous_log_dir_name and not global_config.get(
                    'force_overwrite', False):
                raise ValueError(
                    '--restore_continue given, but previous log_dir != current:\n'
                    + f'   {previous_log_dir_name}\n!= {theoretical_log_dir}')

        print('Using {}...'.format(previous_log_dir))
        return previous_log_dir
Code example #6
    def __init__(self, config_ms):
        super(MultiscaleNetwork, self).__init__()

        # Set for the RGB baselines
        self._rgb = config_ms.rgb_bicubic_baseline  # if set, make sure no backprop through sub_mean

        # True for L3C and RGB, not for RGB Shared
        self._fuse_feat = config_ms.dec.skip

        self._show_input = global_config.get('showinp', False)

        # For the first scale, where input is RGB with C=3
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_rgb_mean = edsr.MeanShift(255., rgb_mean,
                                           rgb_std)  # to interval -128, 128

        self.scales = config_ms.num_scales
        self.config_ms = config_ms

        # NOTES about naming: See README

        if not config_ms.rgb_bicubic_baseline:
            # Heads are used to make the code work for L3C as well as the RGB baselines.
            # For RGB, each encoder gets a bicubically downsampled RGB image as input, with 3 channels.
            # Otherwise, the encoder gets the final feature before the quantizer, with Cf channels.
            # The Heads map either of these to Cf channels, such that encoders always get a feature map with Cf
            # channels.
            heads = ([RGBHead(config_ms)] + [
                Head(config_ms, Cin=self.get_Cin_for_scale(scale))
                for scale in range(self.scales - 1)
            ])
            nets = [Net(config_ms, scale) for scale in range(self.scales)]
            prob_clfs = ([AtrousProbabilityClassifier(config_ms, C=3)] + [
                AtrousProbabilityClassifier(config_ms, config_ms.q.C)
                for _ in range(self.scales - 1)
            ])
        else:
            print('*** Multiscale RGB Pyramid')
            # For RGB Baselines, we feed subsampled version of RGB directly to the next subsampler
            # (see Fig A2, A3 in appendix of paper). Thus, the heads are just identity.
            heads = [
                pe.LambdaModule(lambda x: x, name='ID')
                for _ in range(self.scales)
            ]
            nets = [Net(config_ms, scale) for scale in range(self.scales)]
            prob_clfs = [
                AtrousProbabilityClassifier(config_ms, C=3)
                for _ in range(self.scales)
            ]

        self.heads = nn.ModuleList(heads)
        self.nets = nn.ModuleList(nets)
        self.prob_clfs = nn.ModuleList(prob_clfs)  # len == #scales

        self.extra_repr_str = 'scales={} / {} nets / {} ps'.format(
            self.scales, len(self.nets), len(self.prob_clfs))
Code example #7
    def __init__(self, config_ms, C=3, filter_size=3):
        super(Final, self).__init__()

        # NOTE: this module is disabled; the raise below makes the rest of __init__ unreachable.
        raise NotImplementedError

        self.C = C
        self.K = config_ms.prob.K

        self.scales_conv = nn.Conv2d(self.C * self.K, self.C * self.K, filter_size,
                                     padding=filter_size//2, bias=global_config.get('initbias', False))
        print('C', self.C, 'K', self.K, self)
Code example #8
def _maybe_auto_reg(l: NetworkOutput):
    if not global_config.get('s_autoreg', False):
        return l.means

    coeffs = torch.tanh(l.lambdas)  # NCKHW, basically coeffs_g_r, coeffs_b_r, coeffs_b_g
    means_r, means_g, means_b = l.means[:, 0, ...], l.means[:, 1, ...], l.means[:, 2, ...]  # each NKHW
    coeffs_g_r, coeffs_b_r, coeffs_b_g = coeffs[:, 0, ...], coeffs[:, 1, ...], coeffs[:, 2, ...]  # each NKHW
    x_reg = l.means
    return torch.stack(
            (means_r,
             means_g + coeffs_g_r * x_reg[:, 0, ...],
             means_b + coeffs_b_r * x_reg[:, 0, ...] + coeffs_b_g * x_reg[:, 1, ...]), dim=1)  # NCKHW again
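A shape-only sketch of the channel-autoregressive update above, with random tensors in the NCKHW layout the comments describe; PyTorch is assumed and the NetworkOutput container is not reproduced:

import torch

N, C, K, H, W = 2, 3, 10, 8, 8
means = torch.randn(N, C, K, H, W)
coeffs = torch.tanh(torch.randn(N, C, K, H, W))  # coeffs_g_r, coeffs_b_r, coeffs_b_g

means_r, means_g, means_b = means[:, 0], means[:, 1], means[:, 2]  # each NKHW
c_gr, c_br, c_bg = coeffs[:, 0], coeffs[:, 1], coeffs[:, 2]        # each NKHW

out = torch.stack((means_r,
                   means_g + c_gr * means[:, 0],
                   means_b + c_br * means[:, 0] + c_bg * means[:, 1]), dim=1)
print(out.shape)  # torch.Size([2, 3, 10, 8, 8]) -> NCKHW again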
Code example #9
File: shared.py | Project: CrhistyanSilva/RC-PyTorch
def get_test_dataset_transform(crop):
    img_to_tensor_t = [
        images_loader.IndexImagesDataset.to_tensor_uint8_transform()
    ]
    if global_config.get('ycbcr', False):
        print('Adding ->YCbCr to Testset')
        t = transforms.Lambda(lambda pil_img: pil_img.convert('YCbCr'))
        img_to_tensor_t.insert(0, t)
    if crop:
        print(f'Cropping Testset: {crop}')
        img_to_tensor_t.insert(0, transforms.CenterCrop(crop))
    return transforms.Compose(img_to_tensor_t)
Code example #10
    def get_ds_train(self):
        """
        Dataset must return dicts with {'idx', 'raw'}, where 'raw' is 3HW uint8
        """
        if self.config_dl.is_residual_dataset:
            return get_residual_dataset(
                self.config_dl.imgs_dir_train,
                random_transforms=True,
                random_scale=self.config_dl.random_scale,
                crop_size=self.config_dl.crop_size,
                mode='both' if self.style == 'enhancement' else 'diff',
                discard_shitty=self.config_dl.discard_shitty_train,
                filter_min_size=self.config_dl.train_filter_min_size,
                top_only=global_config.get('top_only', None),
                is_training=True)
        else:
            assert self.style != 'enhancement', 'style == enhancement -> expected residual dataset'

        to_tensor_transform = transforms.Compose([
            transforms.RandomCrop(self.config_dl.crop_size),
            transforms.RandomHorizontalFlip(),
            images_loader.IndexImagesDataset.to_tensor_uint8_transform()
        ])

        if global_config.get('ycbcr', False):
            print('Adding ->YCbCr')
            t = transforms.Lambda(lambda pil_img: pil_img.convert('YCbCr'))
            to_tensor_transform.transforms.insert(-2, t)

        ds_syn = global_config.get('ds_syn', None)
        if ds_syn:
            ds_train = self._get_syn(ds_syn, 30 * 10000)
        else:
            ds_train = images_loader.IndexImagesDataset(
                images=cached_listdir_imgs(
                    self.config_dl.imgs_dir_train,
                    min_size=self.config_dl.crop_size,
                    discard_shitty=self.config_dl.discard_shitty_train),
                to_tensor_transform=to_tensor_transform)
        return ds_train
Code example #11
    def _get_jpg(self, im):
        tmp_dir = os.path.join('/dev/shm', str(os.getpid()))
        os.makedirs(tmp_dir, exist_ok=True)

        img_p = os.path.join(tmp_dir, 'img.jpg')
        quality = self.quality
        if global_config.get('rand_quality'):
            quality = random.randint(95, 99)
        im.save(img_p, quality=quality)

        bpp = os.path.getsize(img_p) * 8 / np.prod(im.size)

        return Image.open(img_p).convert('RGB'), bpp
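The bits-per-pixel computation at the end can be checked on its own; this sketch assumes Pillow and NumPy and substitutes a random image and a fixed quality for self.quality:

import os, tempfile
import numpy as np
from PIL import Image

im = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
img_p = os.path.join(tempfile.gettempdir(), 'img.jpg')
im.save(img_p, quality=97)

bpp = os.path.getsize(img_p) * 8 / np.prod(im.size)  # im.size is (W, H)
print(f'{bpp:.3f} bpp')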
Code example #12
File: head.py | Project: CrhistyanSilva/RC-PyTorch
    def __init__(self, config_ms):
        super(RGBHead, self).__init__()
        assert 'Subsampling' not in config_ms.enc.cls, 'For Subsampling encoders, head should be ID'
        head = [Head(config_ms, Cin=3)]
        if global_config.get('gdn', False):
            print('*** Adding GDN')
            head.append(GDN(config_ms.Cf))
        self.head = nn.Sequential(
            # Note, this actually shifts data to ~[-1, 1]
            # edsr.MeanShift(0, (0., 0., 0.), (128., 128., 128.)),
            *head)
        self._repr = 'MeanShift//Head(C=3)'
Code example #13
    def __init__(self, rgb_scale: bool, L):
        """
        :param rgb_scale: Whether this is the loss for the RGB scale. In that case,
            use_coeffs=True
            _num_params=_NUM_PARAMS_RGB == 4, since we predict coefficients lambda. See note above.
        :param L: number of symbols
        """
        super(DiscretizedMixLogisticLoss, self).__init__()

        self.rgb_scale = rgb_scale
        self.L = L

        self._means_oracle = global_config.get('means_oracle', None)
        if self._means_oracle:
            print('*** Means oracle,', self._means_oracle)

        self._self_auto_reg = global_config.get('s_autoreg', False)

        # Adapted bounds for our case.
        self.bin_width = 2 / (L - 1)

        # Lp = L+1
        self.targets = torch.linspace(-1 - self.bin_width / 2,
                                      1 + self.bin_width / 2,
                                      self.L + 1,
                                      dtype=torch.float32,
                                      device=pe.DEVICE)

        self.min_sigma = global_config.get('minsigma', -9.)

        self._extra_repr = (
            f'DMLL:'
            f'L={self.L}, '
            f'bin_width={self.bin_width}, min_sigma={self.min_sigma}')

        self._alpha = 1
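The binning constants above follow directly from L; a short check in plain PyTorch (L=25 is an arbitrary choice):

import torch

L = 25
bin_width = 2 / (L - 1)                       # symbols live in [-1, 1]
targets = torch.linspace(-1 - bin_width / 2,  # L+1 bin edges, half a bin beyond each end
                         1 + bin_width / 2,
                         L + 1)
print(bin_width, targets.shape, targets[0].item(), targets[-1].item())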
Code example #14
def _tail(fw_=3):
    # Nested helper from EnhancementNetwork.__init__ (see example #24); Cf and config_en
    # come from the enclosing scope.
    if global_config.get('atrous', None):
        print('Atrous Tail')
        assert 'long_sigma' in global_config
        assert 'long_means' in global_config
        return [
            prob_clf.StackedAtrousConvs(
                atrous_rates_str='1,2,4',
                Cin=Cf, Cout=prob_clf.ProbClfTail.get_cout(config_en), Catrous=Cf // 2,
                bias=False, activation=nn.LeakyReLU(inplace=True))]
    else:  # default so far
        return [
            pe.default_conv(Cf, Cf, fw_),
            nn.LeakyReLU(inplace=True),
            pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
        ]
Code example #15
    def iterator(self, epoch):
        """ :returns an iterator over tuples (itr, batch) """
        skip_epochs, skip_batches = self.epochs_to_skip()
        if epoch < skip_epochs:
            print('Skipping epoch {}'.format(epoch))
            return []  # nothing to iterate
        if epoch > skip_epochs or (epoch == skip_epochs and skip_batches == 0):
            # iterate like normal
            return enumerate(self.dl_train, epoch * len(self.dl_train))
        # If we get here, we are in the first epoch that should not be skipped entirely,
        # so skip the first `skip_batches` batches.
        it = iter(self.dl_train)
        for i in range(skip_batches):
            print('\rDropping batch {: 10d}...'.format(i), end='')
            if not global_config.get('drop_batches', False):
                # It would be nice to not load the images, but that is hard to do as the
                # DataLoader caches the Dataset's responses; they might even be immutable.
                next(it)  # drop batch
        print(' -- dropped {} batches'.format(skip_batches))
        return enumerate(it, epoch * len(self.dl_train) + skip_batches)
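The resume behaviour can be mimicked with a plain iterable; a sketch where a list stands in for the DataLoader, showing how the global batch index stays aligned after skipping:

dl_train = list(range(10))  # stand-in for the DataLoader
epoch, skip_batches = 3, 4

it = iter(dl_train)
for _ in range(skip_batches):
    next(it)                # drop already-seen batches

for itr, batch in enumerate(it, epoch * len(dl_train) + skip_batches):
    print(itr, batch)       # itr starts at 34, batch starts at 4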
Code example #16
    def __init__(self, config_ms, scale, C=3):
        super(DeepProbabilityClassifier, self).__init__()

        Cf = config_ms.Cf
        kernel_size = 3

        m_body = [
            edsr.ResBlock(conv, Cf, kernel_size, act=act.make(Cf, inverse=True),
                          res_scale=global_config.get('res_scale', 1))
            for _ in range(3)
        ]
        m_body.append(conv(Cf, Cf, kernel_size))

        self.body = nn.Sequential(*m_body)

        K = config_ms.prob.K

        # For RGB, generate the outputs specified by config_ms.prob.rgb_outputs
        # otherwise, generate means, sigmas, pis
        tail_outputs = (_parse_outputs_flag(config_ms.prob.rgb_outputs) if scale == 0
                        else RequiredOutputs(True, True, True, lambdas=False))

        self.tail = ProbClfTail(Cf, C, K, outputs=tail_outputs)
Code example #17
File: act.py | Project: CrhistyanSilva/RC-PyTorch
def make(C, inverse):
    return {
        'relu': lambda: nn.ReLU(True),
        'lrelu': lambda: nn.LeakyReLU(inplace=True),
        'GDN': lambda: gdn.GDN(C, inverse=inverse)
    }[global_config.get('act', 'relu')]()
Code example #18
    def __init__(self, config_clf):
        super(ClassifierNetwork, self).__init__()
        self.config_clf = config_clf
        Cf = config_clf.Cf
        num_classes = config_clf.num_classes
        head = config_clf.head
        nB = config_clf.n_resblock
        norm = {
            'bn': nn.BatchNorm2d,
            'gdn': gdn.GDN,
            'identity': lambda _: pe.IdentityModule()
        }[config_clf.norm]

        if head == 'down3':
            head = [
                pe.default_conv(3, Cf // 4, 5, stride=2),
                norm(Cf // 4),
                nn.ReLU(inplace=True),
                pe.default_conv(Cf // 4, Cf // 2, 5, stride=2),
                norm(Cf // 2),
                nn.ReLU(inplace=True),
                pe.default_conv(Cf // 2, Cf, 5, stride=2),
                norm(Cf),
                nn.ReLU(inplace=True),
            ]
        elif head == 'down2':
            head = [
                pe.default_conv(3, Cf // 2, 5, stride=2),
                norm(Cf // 2),  # matches the Cf // 2 channels produced by the conv above
                nn.ReLU(inplace=True),
                pe.default_conv(Cf // 2, Cf, 5, stride=2),
                norm(Cf),
                nn.ReLU(inplace=True),
            ]

        self.head = nn.Sequential(*head)
        norm_cls = lambda: norm(Cf)

        model = [
            ResBlock(pe.default_conv,
                     Cf,
                     kernel_size=3,
                     act=nn.ReLU(inplace=True),
                     norm_cls=norm_cls) for _ in range(nB)
        ]

        final_Cf = Cf
        if config_clf.num_res_down == 2:
            model.append(pe.default_conv(Cf, 2 * Cf, 5, stride=2))
            norm_cls = lambda: norm(2 * Cf)
            model += [
                ResBlock(pe.default_conv,
                         2 * Cf,
                         kernel_size=3,
                         act=nn.ReLU(inplace=True),
                         norm_cls=norm_cls) for _ in range(nB)
            ]
            final_Cf = 2 * Cf

        if global_config.get('final_conv', False):
            model += [
                pe.default_conv(final_Cf, final_Cf, 3),
                nn.LeakyReLU(inplace=True)
            ]

        self.model = nn.Sequential(
            *model,
            ChannelAverage(),
        )

        if config_clf.deep_tail:
            tail = [
                nn.Linear(final_Cf, 2 * final_Cf),
                nn.LeakyReLU(inplace=True),
                nn.Linear(2 * final_Cf, num_classes)  # input dim matches the 2 * final_Cf features above
            ]
        else:
            tail = [nn.Linear(final_Cf, num_classes)]
        self.tail = nn.Sequential(*tail)
Code example #19
def get_residual_dataset(imgs_dir, random_transforms: bool,
                         random_scale, crop_size: int, mode: str, max_imgs=None,
                         discard_shitty=True, filter_min_size=None, top_only=None,
                         is_training=False, sort=False):
    if top_only:
        assert top_only < 1
    multiple_ds = False
    if isinstance(imgs_dir, dict):
        assert 'raw' in imgs_dir and 'compressed' in imgs_dir, imgs_dir.keys()
        raw_p, compressed_p = imgs_dir['raw'], imgs_dir['compressed']
        multiple_ds = isinstance(imgs_dir['raw'], list)
        if multiple_ds:
            assert len(raw_p) == len(compressed_p)
    elif ';' in imgs_dir:
        raw_p, compressed_p = imgs_dir.split(';')
    else:
        raise ValueError('Invalid imgs_dir, should be dict or string with ;, got', imgs_dir)

    # works fine if p_ is a list
    get_imgs = lambda p_: cached_listdir_imgs(
            p_, min_size=filter_min_size or crop_size, discard_shitty=discard_shitty)

    if compressed_p == 'JPG':
        print('*** Using JPG...')
        imgs = get_imgs(raw_p)
        return JPGDataset(imgs,
                          random_crops=crop_size if random_transforms else None,
                          random_flips=random_transforms,
                          random_scale=random_scale,
                          center_crops=crop_size if not random_transforms else None,
                          max_imgs=max_imgs)

    if is_training and global_config.get('filter_imgs', False):
        assert not multiple_ds
        print('*** filtering', imgs_dir)
        filter_imgs = global_config['filter_imgs']
        if not isinstance(filter_imgs, int):
            filter_imgs = 680
        print(filter_imgs)
        get_imgs = lambda p_: cached_listdir_imgs_max(p_, max_size=filter_imgs, discard_shitty=True)

    raw, compressed = map(get_imgs, (raw_p, compressed_p))

    if top_only:
        sorted_imgs = sorted((Compressor.bpp_from_compressed_file(p), p) for p in compressed.ps)
        top_only_imgs = sorted_imgs[-int(top_only * len(sorted_imgs)):]
        top_only_ps = [p for _, p in top_only_imgs]
        compressed = Images(top_only_ps, compressed.id + f'_top{top_only:.2f}')
        print(f'*** Using {len(top_only_ps)} of {len(sorted_imgs)} images only')

    if sort:
        print('Sorting...')
        raw = raw.sort()
        compressed = compressed.sort()

    return ResidualDataset(compressed, raw,
                           mode=mode,
                           random_crops=crop_size if random_transforms else None,
                           random_flips=random_transforms,
                           random_scale=random_scale,
                           center_crops=crop_size if not random_transforms else None,
                           max_imgs=max_imgs)
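The top_only selection keeps the highest-bpp fraction of the compressed images; a sketch with made-up (bpp, path) pairs, using plain tuples in place of the repo's Compressor and Images classes:

bpp_and_path = [(0.9, 'a.png'), (2.1, 'b.png'), (1.4, 'c.png'), (3.0, 'd.png')]
top_only = 0.5

sorted_imgs = sorted(bpp_and_path)
top_only_ps = [p for _, p in sorted_imgs[-int(top_only * len(sorted_imgs)):]]
print(top_only_ps)  # ['b.png', 'd.png'] -> the 50% with the highest bpp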
Code example #20
    def get_padding_fac(self):
        return 2 if global_config.get('down_up', None) else 0
Code example #21
    def __init__(self,
                 loss_dmol_rgb: DiscretizedMixLogisticLoss):
        self.loss_dmol_rgb = loss_dmol_rgb

        self._plot_loss = global_config.get('test.plot_loss', False)

        self._info = global_config.get('test.info', '')
        self._mode = global_config.get('test.mode', 'direct')  # direct, channel

        self._rand_init = global_config.get('test.rand_init', False)
        self._grid = global_config.get('test.grid', 1)
        self._full = global_config.get('test.full', False)
        self._num_iter = global_config.get('test.num_iter', 50)
        self._early_stop = global_config.get('test.early_stop', False)
        self._ignore_overhead = global_config.get('test.ignore_overhead', False)
        self._subsampling = global_config.get('test.subsampling', 4)
        self._optim_cls = global_config.get('test.optim', 'SGD')
        self._lr = global_config.get('test.lr', 9e-2)
        self._optim_params = global_config.get_as_dict(
            'test.optim_params', 'dict(momentum=0.9)' if self._optim_cls != 'Adam' else 'dict()')

        self._summary = _Summary()
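Downstream, the test.optim flags read above would presumably be turned into a torch.optim instance roughly as follows; this is a guess at the usage, not code from the example, and the single throwaway tensor only exists to give the optimizer something to hold:

import torch
from torch import optim

optim_cls_name, lr = 'SGD', 9e-2
optim_params = dict(momentum=0.9) if optim_cls_name != 'Adam' else dict()

x = torch.zeros(4, requires_grad=True)
optimizer = getattr(optim, optim_cls_name)([x], lr=lr, **optim_params)
print(optimizer)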
Code example #22
def make_res_block(_act, _use_norm=True):
    # Nested helper from EnhancementNetwork.__init__ (see example #24); Cf, kernel_size and
    # norm_cls come from the enclosing scope.
    return edsr.ResBlock(
        pe.default_conv, Cf, kernel_size, act=_act,
        norm_cls=norm_cls if _use_norm else None,
        res_scale=global_config.get('res_scale', 0.1))
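The res_scale value read in make_res_block scales the residual branch before the skip connection is added; a minimal residual block in plain PyTorch (not the repo's edsr.ResBlock) makes the effect concrete:

import torch
import torch.nn as nn

class TinyResBlock(nn.Module):
    def __init__(self, Cf, res_scale=0.1):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(Cf, Cf, 3, padding=1),
                                  nn.LeakyReLU(inplace=True),
                                  nn.Conv2d(Cf, Cf, 3, padding=1))
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.res_scale * self.body(x)

print(TinyResBlock(8)(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])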
Code example #23
    def __init__(self,
                 config_p,
                 dl_config_p,
                 log_dir_root,
                 log_config: LogConfig,
                 num_workers,
                 saver: Saver,
                 restorer: TrainRestorer = None,
                 sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
        """
        :param config_p: Path to the network config file, see README
        :param dl_config_p: Path to the dataloader config file, see README
        :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
        :param log_config: Instance of train.trainer.LogConfig, contains intervals.
        :param num_workers: Number of workers to use for DataLoading, see train.py
        :param saver: Saver instance to use.
        :param restorer: Instance of TrainRestorer, if we need to restore
        """
        self.style = MultiscaleTrainer.get_style_from_config(config_p)
        self.blueprint_cls = {
            'enhancement': EnhancementBlueprint,
            'classifier': ClassifierBlueprint
        }[self.style]

        global_config.declare_used('filter_imgs')

        # Read configs
        # config = config for the network
        # config_dl = config for data loading
        (self.config, self.config_dl), rel_paths = ft.unzip(
            map(config_parser.parse, [config_p, dl_config_p]))
        # TODO only read by enhancement classes
        self.config.is_residual = self.config_dl.is_residual_dataset

        # Update global_config given config.global_config
        global_config_config_keys = global_config.add_from_str_without_overwriting(
            self.config.global_config)
        # Update config_ms depending on global_config
        global_config.update_config(self.config)

        if self.style == 'enhancement':
            EnhancementBlueprint.read_evenly_spaced_bins(self.config_dl)

        self._custom_init()

        # Create data loaders
        dl_train, self.ds_val, self.fixed_first_val = self._get_dataloaders(
            num_workers)
        # Create blueprint. A blueprint collects the network as well as the losses in one class, for easy reuse
        # during testing.
        self.blueprint = self.blueprint_cls(self.config)
        print('Network:', self.blueprint.net)
        # Setup optimizer
        optim_cls = {
            'RMSprop': optim.RMSprop,
            'Adam': optim.Adam,
            'SGD': optim.SGD,
        }[self.config.optim]
        net = self.blueprint.net
        self.optim = optim_cls(net.parameters(),
                               self.config.lr.initial,
                               weight_decay=self.config.weight_decay)
        # Calculate a rough estimate for time per batch (does not take into account that CUDA is async,
        # but good enough to get a feeling during training).
        self.time_accumulator = timer.TimeAccumulator()
        # Restore network if requested
        skip_to_itr = self.maybe_restore(restorer)
        if skip_to_itr is not None:  # i.e., we have a restorer
            print('Skipping to {}...'.format(skip_to_itr))
        # Create LR schedule to update parameters
        self.lr_schedule = lr_schedule.from_spec(self.config.lr.schedule,
                                                 self.config.lr.initial,
                                                 [self.optim],
                                                 epoch_len=len(dl_train))

        # --- All nn.Modules are setup ---
        print('-' * 80)

        # create log dir and summary writer
        self.log_dir_root = log_dir_root
        global_config_values = global_config.values(
            ignore=global_config_config_keys)
        self.log_dir = Trainer.get_log_dir(
            log_dir_root,
            rel_paths,
            restorer,
            global_config_values=global_config_values)
        self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
        self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
        print(f'Checkpoints will be saved to {self.ckpt_dir}')
        saver.set_out_dir(self.ckpt_dir)

        if global_config.get('ds_syn', None):
            underlying = dl_train.dataset
            while not isinstance(underlying, _CheckerboardDataset):
                underlying = underlying.ds
            underlying.save_all(self.log_dir)

        # Create summary writer
        sw = sw_cls(self.log_dir)
        self.summarizer = vis.summarizable_module.Summarizer(sw)
        net.register_summarizer(self.summarizer)
        self.blueprint.register_summarizer(self.summarizer)

        # Try to write filenames somewhere
        try:
            dl_train.dataset.write_file_names_to_txt(self.log_dir)
        except AttributeError:
            raise AttributeError(
                f'dl_train.dataset of type {type(dl_train.dataset)} does not support '
                f'write_file_names_to_txt(log_dir)!')

        # superclass setup
        super(MultiscaleTrainer,
              self).__init__(dl_train, [self.optim],
                             net,
                             sw,
                             max_epochs=self.config_dl.max_epochs,
                             log_config=log_config,
                             saver=saver,
                             skip_to_itr=skip_to_itr)
Code example #24
    def __init__(self, config_en):
        super(EnhancementNetwork, self).__init__()
        self.config_en = config_en

        Cf = config_en.Cf
        kernel_size = config_en.kernel_size
        n_resblock = config_en.n_resblock

        more_gdn = global_config.get('more_gdn', False)
        more_act = global_config.get('more_act', False)

        act_body = act.make(Cf, inverse=False)
        act_tail = act.make(Cf, inverse=False)

        self.head = pe.default_conv(3, Cf, 3)
        if more_act:
            self.head = nn.Sequential(self.head, act_body)

        self._down_up = global_config.get('down_up', None)
        if self._down_up:
            self.down = pe.default_conv(Cf, Cf, global_config.get('fw_du', 3), stride=2)
            if more_gdn:
                self.down = nn.Sequential(self.down, GDN(Cf))
            if more_act:
                self.down = nn.Sequential(self.down, act_body)
        else:
            self.down = lambda x: x

        self.unet_skip_conv = None
        if global_config.get('unet_skip', None):
            self.unet_skip_conv = nn.Sequential(
                    pe.default_conv(2*Cf, Cf, 3),
                    nn.ReLU(inplace=True))

        assert not global_config.get('learned_skip', False)

        # Cf_resnet = global_config.get('Cf_resnet', Cf)
        # print('*** Cf_resnet ==', Cf_resnet5)

        norm_cls = None
        if global_config.get('inorm', False):
            print('***Using Instance Norm!')
            norm_cls = lambda: nn.InstanceNorm2d(Cf, affine=True)

        if global_config.get('gdn', False):
            norm_cls = lambda: GDN(Cf)

        use_norm_for_long = not global_config.get('no_norm_final', False)
        if not use_norm_for_long:
            print('*** no norm for final')

        def make_res_block(_act, _use_norm=True):
            return edsr.ResBlock(
                pe.default_conv, Cf, kernel_size, act=_act,
                norm_cls=norm_cls if _use_norm else None,
                res_scale=global_config.get('res_scale', 0.1))

        norm_in_body = True

        if global_config.get('gdn_as_nl', False):
            print('*** GDN as non linearity!')
            norm_cls = None
            act_body = GDN(Cf)
            if not global_config.get('gdnfreetail', False):
                act_tail = GDN(Cf)
            norm_in_body = False
            use_norm_for_long = False

        m_body = [
            make_res_block(act_body, norm_in_body)
            for _ in range(n_resblock)
        ]
        m_body.append(pe.default_conv(Cf, Cf, kernel_size))
        self.body = nn.Sequential(*m_body)

        if self._down_up:
            if self._down_up == 'deconv':
                up = ups.DeconvUp(config_en)
            elif self._down_up == 'nn':
                up = ups.ResizeConvUp(config_en)
            else:
                up = edsr.Upsampler(pe.default_conv, 2, Cf, act=False)

            if more_gdn:
                up = nn.Sequential(up, GDN(Cf, inverse=True))
            if more_act:
                up = nn.Sequential(up, act_body)

            print('*** DownUp, adding', up)
            self.after_skip = up
        else:
            self.after_skip = lambda x: x

        tail_networks = {}
        if global_config.get('deeptails', False):
            raise NotImplementedError
            # num_blocks = global_config['deeptails']
            # for name in ('sigmas', 'means'):
            #     tail_networks[name] = lambda: SequentialWithSkip(
            #             body=nn.Sequential(*[make_res_block(nn.LeakyReLU(inplace=True))
            #                                  for _ in range(num_blocks)]),
            #             final=pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1))

        def _tail(fw_=3):
            if global_config.get('atrous', None):
                print('Atrous Tail')
                assert 'long_sigma' in global_config
                assert 'long_means' in global_config
                return [
                    prob_clf.StackedAtrousConvs(
                            atrous_rates_str='1,2,4',
                            Cin=Cf, Cout=prob_clf.ProbClfTail.get_cout(config_en), Catrous=Cf//2,
                            bias=False, activation=nn.LeakyReLU(inplace=True))]
            else:  # default so far
                return [
                    pe.default_conv(Cf, Cf, fw_),
                    nn.LeakyReLU(inplace=True),
                    pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
                ]

        if global_config.get('long_sigma', False):
            fw_sigma = global_config.get('fw_s', 5)
            print('filter_width for sigma =', fw_sigma)
            modules = [make_res_block(act_tail, use_norm_for_long),
                       *_tail(fw_sigma)]
            if global_config.get('fc2', False):
                print('Adding another 1x1 conv!')
                modules.insert(-1, pe.default_conv(Cf, Cf, 1))
            tail_networks['sigmas'] = lambda: pe.FeatureMapSaverSequential(
                    *modules,
                    saver=None  # pe.FeatureMapSaver()
            )
            print('Did set tail_networks.sigmas')
        if global_config.get('long_means', False):
            tail_networks['means'] = lambda: pe.FeatureMapSaverSequential(
                    make_res_block(act_tail, use_norm_for_long),
                    *_tail()
                    # saver=self.savers['final_sigmas'], idx=-2
            )
            print('Did set tail_networks.means')
        if global_config.get('long_pis', False):
            tail_networks['pis'] = lambda: pe.FeatureMapSaverSequential(
                    make_res_block(act_tail, use_norm_for_long),
                    pe.default_conv(Cf, Cf, 3),  # no crazy smoothing
                    nn.LeakyReLU(inplace=True),  # a non linearity
                    pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
                    saver=None
            )
            print('Did set tail_networks.pis')
        if global_config.get('long_lambdas', False):
            tail_networks['lambdas'] = lambda: nn.Sequential(
                    make_res_block(act_tail, use_norm_for_long),
                    pe.default_conv(Cf, Cf, 3),  # no crazy smoothing
                    nn.LeakyReLU(inplace=True),  # a non linearity
                    pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
            )
            print('Did set tail_networks.lambdas')

        if global_config.get('longer_lambda', False):
            tail_networks['lambdas'] = lambda: nn.Sequential(
                    pe.default_conv(Cf, Cf, 3),
                    nn.LeakyReLU(inplace=True),
                    pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
            )
            print('Did set tail_networks.lambdas')

        self.side_information_mode = False
        if global_config.get('side_information', False):
            print('*** Using side_information!')
            self.side_information_mode = True
            Ccond = global_config['side_information']
            self.side_information_net = SideInformationNetwork(Ccond)
            self.side_information_conv = ConditionalConvolution(Cf, Cf, Ccond, kernel_size,
                                                                activation=nn.LeakyReLU(inplace=True))

        print('Setting tail_networks[', tail_networks.keys(), ']')
        self.tail = prob_clf.ProbClfTail.from_config(config_en, tail_networks=tail_networks)