def __init__(self, config_en, is_testing=False):
    super(EnhancementBlueprint, self).__init__()
    self.net = EnhancementNetwork(config_en)
    self.net = self.net.to(pe.DEVICE)
    self.config_en = config_en
    self.clf = None
    self.qstrategy = QStrategy.MIN
    self.losses = EnhancementLosses(config_en, is_testing)

    # The conditional-normalization styles are mutually exclusive.
    global_config.assert_only_one('cinorm', 'cin_eb', 'cgdn')
    self.cin_style = None
    if global_config.get('cinorm', False):
        self.cin_style = 'cinorm'
    elif global_config.get('cin_eb', False):
        self.cin_style = 'evenbins'
        self.cin_q = cin_bins.Quantizer(global_config['cin_eb'])
    elif global_config.get('cgdn', False):
        self.cin_style = 'cgdn'
        self.cin_q = cin_bins.Quantizer(global_config['cgdn'])
    print('EB: self.cin_style =', self.cin_style)

    self.padding_fac = self.get_padding_fac()
    print('***' * 10)
    print('*** Padding by a factor', self.padding_fac)
    print('***' * 10)
def __init__(self, config_ms, scale, C=3, atrous_rates_str='1,2,4'):
    # Disabled: everything below the raise is unreachable and kept for reference only.
    raise NotImplementedError
    super(AtrousProbabilityClassifier, self).__init__()
    K = config_ms.prob.K
    Kp = non_shared_get_Kp(K, C)
    self.atrous = StackedAtrousConvs(atrous_rates_str, config_ms.Cf, Kp,
                                     kernel_size=config_ms.kernel_size, name=str(scale))
    self._repr = f'C={C}; K={K}; Kp={Kp}; rates={atrous_rates_str}'
    if global_config.get('usefinal1', False):
        print('*** Using Final')
        self.final = Final(config_ms, C, 1)
    elif global_config.get('usefinal', False):
        print('*** Using Final')
        self.final = Final(config_ms, C, 3)
    else:
        self.final = lambda x: x
    if global_config.get('initbias', False):
        K = config_ms.prob.K
        self.atrous.lin.bias = nn.Parameter(_init_bias(self.atrous.lin.bias.detach(), C, K))
        print(self.atrous.lin.bias.requires_grad)
        print('Updated bias:', self.atrous.lin.bias.reshape(-1, C, K))
        if global_config.get('usefinal1', False):
            self.final.scales_conv.bias = nn.Parameter(
                    _init_bias(self.final.scales_conv.bias.detach(), C, K))
def read_evenly_spaced_bins(config_dl):
    flag = global_config.get('cin_eb', None) or global_config.get('cgdn', None)
    if flag and flag.startswith('auto'):
        flag_name = 'cin_eb' if global_config.get('cin_eb', None) else 'cgdn'
        nb = int(flag.replace('auto', ''))
        # Creates the bin pickle if needed.
        pkl_p = cin_bins.make_bin_pkl(config_dl.imgs_dir_train['compressed'], nb)
        print(f'Setting {flag_name} = {pkl_p}')
        global_config[flag_name] = pkl_p
def __init__(self, optims, initial, decay_fac, decay_interval_itr=None, decay_interval_epoch=None,
             epoch_len=None, warm_restart=None, warm_restart_schedule=None):
    super(ExponentialDecayLRSchedule, self).__init__(optims)
    assert_exc((decay_interval_itr is not None) ^ (decay_interval_epoch is not None),
               'Need either iter or epoch')
    if decay_interval_epoch:
        assert epoch_len is not None
        decay_interval_itr = int(decay_interval_epoch * epoch_len)
        if warm_restart:
            warm_restart = int(warm_restart * epoch_len)
    self.initial = initial
    self.decay_fac = decay_fac
    self.decay_every_itr = decay_interval_itr
    self.warm_restart_itr = warm_restart
    self.warm_restart_schedule = warm_restart_schedule
    self.last_warm_restart = 0
    self.exp_min = global_config.get('exp_min', None)
    if self.exp_min:
        print('*** Has minimal LR =', self.exp_min)
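# Illustrative sketch only (not part of the codebase): how an exponential-decay schedule
# configured as above typically computes the LR at iteration t. The names `initial`,
# `decay_fac`, `decay_every_itr` and `exp_min` mirror the constructor fields; the actual
# update rule lives elsewhere in the repo, so this is an assumption.
def _sketch_exponential_lr(itr, initial=1e-4, decay_fac=0.75, decay_every_itr=10000, exp_min=None):
    lr = initial * (decay_fac ** (itr // decay_every_itr))
    if exp_min is not None:
        lr = max(lr, exp_min)  # clamp to the optional minimal LR
    return lr

# e.g. _sketch_exponential_lr(25000) == 1e-4 * 0.75**2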
def get_log_dir(log_dir_root, rel_paths, restorer, strip_ext='.cf', global_config_values=None):
    if not restorer or not restorer.restore_continue:
        log_dir = logdir_helpers.create_unique_log_dir(
                rel_paths, log_dir_root, strip_ext=strip_ext, postfix=global_config_values)
        print('Created {}...'.format(log_dir))
        return log_dir
    previous_log_dir = restorer.get_log_dir()
    if restorer.restore_continue:
        theoretical_log_dir = logdir_helpers.get_log_dir_name(
                rel_paths, strip_ext=strip_ext, postfix=global_config_values)
        previous_log_dir_name = logdir_helpers.log_name_from_log_dir(previous_log_dir)
        if theoretical_log_dir != previous_log_dir_name and not global_config.get('force_overwrite', False):
            raise ValueError('--restore_continue given, but previous log_dir != current:\n'
                             f'  {previous_log_dir_name}\n!= {theoretical_log_dir}')
    print('Using {}...'.format(previous_log_dir))
    return previous_log_dir
def __init__(self, config_ms):
    super(MultiscaleNetwork, self).__init__()
    # Set for the RGB baselines. If set, make sure there is no backprop through sub_mean.
    self._rgb = config_ms.rgb_bicubic_baseline
    # True for L3C and RGB, not for RGB Shared.
    self._fuse_feat = config_ms.dec.skip
    self._show_input = global_config.get('showinp', False)

    # For the first scale, where the input is RGB with C=3.
    rgb_mean = (0.4488, 0.4371, 0.4040)
    rgb_std = (1.0, 1.0, 1.0)
    self.sub_rgb_mean = edsr.MeanShift(255., rgb_mean, rgb_std)  # to interval ~[-128, 128]

    self.scales = config_ms.num_scales
    self.config_ms = config_ms

    # NOTES about naming: see README.
    if not config_ms.rgb_bicubic_baseline:
        # Heads are used to make the code work for L3C as well as the RGB baselines.
        # For RGB, each encoder gets a bicubically downsampled RGB image as input, with 3 channels.
        # Otherwise, the encoder gets the final feature before the quantizer, with Cf channels.
        # The Heads map either of these to Cf channels, such that encoders always get a feature map
        # with Cf channels.
        heads = ([RGBHead(config_ms)] +
                 [Head(config_ms, Cin=self.get_Cin_for_scale(scale)) for scale in range(self.scales - 1)])
        nets = [Net(config_ms, scale) for scale in range(self.scales)]
        prob_clfs = ([AtrousProbabilityClassifier(config_ms, C=3)] +
                     [AtrousProbabilityClassifier(config_ms, config_ms.q.C) for _ in range(self.scales - 1)])
    else:
        print('*** Multiscale RGB Pyramid')
        # For the RGB baselines, we feed a subsampled version of the RGB image directly to the next
        # subsampler (see Fig. A2, A3 in the appendix of the paper). Thus, the heads are just the identity.
        heads = [pe.LambdaModule(lambda x: x, name='ID') for _ in range(self.scales)]
        nets = [Net(config_ms, scale) for scale in range(self.scales)]
        prob_clfs = [AtrousProbabilityClassifier(config_ms, C=3) for _ in range(self.scales)]

    self.heads = nn.ModuleList(heads)
    self.nets = nn.ModuleList(nets)
    self.prob_clfs = nn.ModuleList(prob_clfs)  # len == number of scales
    self.extra_repr_str = 'scales={} / {} nets / {} ps'.format(
            self.scales, len(self.nets), len(self.prob_clfs))
def __init__(self, config_ms, C=3, filter_size=3):
    super(Final, self).__init__()
    # Disabled: everything below the raise is unreachable and kept for reference only.
    raise NotImplementedError
    self.C = C
    self.K = config_ms.prob.K
    self.scales_conv = nn.Conv2d(self.C * self.K, self.C * self.K, filter_size,
                                 padding=filter_size // 2,
                                 bias=global_config.get('initbias', False))
    print('C', self.C, 'K', self.K, self)
def _maybe_auto_reg(l: NetworkOutput):
    if not global_config.get('s_autoreg', False):
        return l.means
    coeffs = torch.tanh(l.lambdas)  # NCKHW; basically coeffs_g_r, coeffs_b_r, coeffs_b_g
    means_r, means_g, means_b = l.means[:, 0, ...], l.means[:, 1, ...], l.means[:, 2, ...]  # each NKHW
    coeffs_g_r, coeffs_b_r, coeffs_b_g = coeffs[:, 0, ...], coeffs[:, 1, ...], coeffs[:, 2, ...]  # each NKHW
    x_reg = l.means
    return torch.stack(
            (means_r,
             means_g + coeffs_g_r * x_reg[:, 0, ...],
             means_b + coeffs_b_r * x_reg[:, 0, ...] + coeffs_b_g * x_reg[:, 1, ...]),
            dim=1)  # NCKHW again
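# Illustrative sketch only: the same channel-autoregressive update on dummy tensors, to make
# the NCKHW shapes above concrete. Here N=1, C=3 (RGB), K=2 mixture components, H=W=4; all
# values are random and stand in for network outputs.
import torch

means = torch.randn(1, 3, 2, 4, 4)                # NCKHW
coeffs = torch.tanh(torch.randn(1, 3, 2, 4, 4))   # NCKHW
means_r, means_g, means_b = means[:, 0], means[:, 1], means[:, 2]   # each NKHW
c_g_r, c_b_r, c_b_g = coeffs[:, 0], coeffs[:, 1], coeffs[:, 2]      # each NKHW
out = torch.stack(
        (means_r,
         means_g + c_g_r * means[:, 0],
         means_b + c_b_r * means[:, 0] + c_b_g * means[:, 1]),
        dim=1)
assert out.shape == means.shape  # NCKHW again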
def get_test_dataset_transform(crop):
    img_to_tensor_t = [images_loader.IndexImagesDataset.to_tensor_uint8_transform()]
    if global_config.get('ycbcr', False):
        print('Adding ->YCbCr to Testset')
        t = transforms.Lambda(lambda pil_img: pil_img.convert('YCbCr'))
        img_to_tensor_t.insert(0, t)
    if crop:
        print(f'Cropping Testset: {crop}')
        img_to_tensor_t.insert(0, transforms.CenterCrop(crop))
    return transforms.Compose(img_to_tensor_t)
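# Illustrative sketch only: the test-time pipeline that results when both 'ycbcr' and a crop
# are configured, written with plain torchvision transforms. The final Lambda is an assumption
# standing in for IndexImagesDataset.to_tensor_uint8_transform(); crop=256 is an example value.
import numpy as np
import torch
from torchvision import transforms

crop = 256
sketch_test_transform = transforms.Compose([
    transforms.CenterCrop(crop),                                   # inserted last, so applied first
    transforms.Lambda(lambda pil_img: pil_img.convert('YCbCr')),   # optional ->YCbCr
    transforms.Lambda(lambda pil_img:                              # PIL -> uint8 CHW tensor
                      torch.from_numpy(np.array(pil_img)).permute(2, 0, 1)),
])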
def get_ds_train(self):
    """ Dataset must return dicts with {'idx', 'raw'}, where 'raw' is 3HW uint8 """
    if self.config_dl.is_residual_dataset:
        return get_residual_dataset(self.config_dl.imgs_dir_train,
                                    random_transforms=True,
                                    random_scale=self.config_dl.random_scale,
                                    crop_size=self.config_dl.crop_size,
                                    mode='both' if self.style == 'enhancement' else 'diff',
                                    discard_shitty=self.config_dl.discard_shitty_train,
                                    filter_min_size=self.config_dl.train_filter_min_size,
                                    top_only=global_config.get('top_only', None),
                                    is_training=True)
    else:
        assert self.style != 'enhancement', 'style == enhancement -> expected residual dataset'
        to_tensor_transform = transforms.Compose([
            transforms.RandomCrop(self.config_dl.crop_size),
            transforms.RandomHorizontalFlip(),
            images_loader.IndexImagesDataset.to_tensor_uint8_transform()])
        if global_config.get('ycbcr', False):
            print('Adding ->YCbCr')
            t = transforms.Lambda(lambda pil_img: pil_img.convert('YCbCr'))
            to_tensor_transform.transforms.insert(-2, t)
        ds_syn = global_config.get('ds_syn', None)
        if ds_syn:
            ds_train = self._get_syn(ds_syn, 30 * 10000)
        else:
            ds_train = images_loader.IndexImagesDataset(
                    images=cached_listdir_imgs(self.config_dl.imgs_dir_train,
                                               min_size=self.config_dl.crop_size,
                                               discard_shitty=self.config_dl.discard_shitty_train),
                    to_tensor_transform=to_tensor_transform)
        return ds_train
def _get_jpg(self, im):
    tmp_dir = os.path.join('/dev/shm', str(os.getpid()))
    os.makedirs(tmp_dir, exist_ok=True)
    img_p = os.path.join(tmp_dir, 'img.jpg')
    quality = self.quality
    if global_config.get('rand_quality'):
        quality = random.randint(95, 99)
    im.save(img_p, quality=quality)
    bpp = os.path.getsize(img_p) * 8 / np.prod(im.size)
    return Image.open(img_p).convert('RGB'), bpp
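# Illustrative sketch only: bits-per-pixel of a JPEG-compressed PIL image, computed the same
# way as above (file size in bits divided by W*H). tempfile replaces the /dev/shm path used in
# the method; the quality value is an example.
import os
import tempfile

import numpy as np
from PIL import Image

def _sketch_jpeg_bpp(im: Image.Image, quality: int = 95) -> float:
    with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
        im.save(f.name, quality=quality)
        return os.path.getsize(f.name) * 8 / np.prod(im.size)  # im.size == (W, H)

# e.g. _sketch_jpeg_bpp(Image.new('RGB', (128, 128)), quality=95)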
def __init__(self, config_ms):
    super(RGBHead, self).__init__()
    assert 'Subsampling' not in config_ms.enc.cls, 'For Subsampling encoders, head should be ID'
    head = [Head(config_ms, Cin=3)]
    if global_config.get('gdn', False):
        print('*** Adding GDN')
        head.append(GDN(config_ms.Cf))
    self.head = nn.Sequential(
            # Note: this actually shifts the data to ~[-1, 1]
            # edsr.MeanShift(0, (0., 0., 0.), (128., 128., 128.)),
            *head)
    self._repr = 'MeanShift//Head(C=3)'
def __init__(self, rgb_scale: bool, L):
    """
    :param rgb_scale: Whether this is the loss for the RGB scale. In that case, use_coeffs=True
        and _num_params=_NUM_PARAMS_RGB == 4, since we also predict the coefficients lambda.
        See note above.
    :param L: number of symbols
    """
    super(DiscretizedMixLogisticLoss, self).__init__()
    self.rgb_scale = rgb_scale
    self.L = L
    self._means_oracle = global_config.get('means_oracle', None)
    if self._means_oracle:
        print('*** Means oracle,', self._means_oracle)
    self._self_auto_reg = global_config.get('s_autoreg', False)
    # Adapted bounds for our case.
    self.bin_width = 2 / (L - 1)
    # Lp = L + 1
    self.targets = torch.linspace(-1 - self.bin_width / 2, 1 + self.bin_width / 2,
                                  self.L + 1, dtype=torch.float32, device=pe.DEVICE)
    self.min_sigma = global_config.get('minsigma', -9.)
    self._extra_repr = (f'DMLL: '
                        f'L={self.L}, '
                        f'bin_width={self.bin_width}, min_sigma={self.min_sigma}')
    self._alpha = 1
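# Illustrative sketch only: the discretization grid implied by the constructor above, for
# symbols mapped to [-1, 1]. For L levels, bin_width = 2/(L-1), and `targets` holds L+1 bin
# edges that extend half a bin beyond each end of [-1, 1]; their spacing equals bin_width.
# L=25 is an example value.
import torch

L = 25
bin_width = 2 / (L - 1)
targets = torch.linspace(-1 - bin_width / 2, 1 + bin_width / 2, L + 1)
assert targets.numel() == L + 1
assert abs((targets[1] - targets[0]).item() - bin_width) < 1e-6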
def _tail(fw_=3):
    if global_config.get('atrous', None):
        print('Atrous Tail')
        assert 'long_sigma' in global_config
        assert 'long_means' in global_config
        return [prob_clf.StackedAtrousConvs(atrous_rates_str='1,2,4',
                                            Cin=Cf,
                                            Cout=prob_clf.ProbClfTail.get_cout(config_en),
                                            Catrous=Cf // 2,
                                            bias=False,
                                            activation=nn.LeakyReLU(inplace=True))]
    else:  # default so far
        return [
            pe.default_conv(Cf, Cf, fw_),
            nn.LeakyReLU(inplace=True),
            pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
        ]
def iterator(self, epoch):
    """ :return: an iterator over tuples (itr, batch) """
    skip_epochs, skip_batches = self.epochs_to_skip()
    if epoch < skip_epochs:
        print('Skipping epoch {}'.format(epoch))
        return []  # nothing to iterate
    if epoch > skip_epochs or (epoch == skip_epochs and skip_batches == 0):
        # Iterate like normal.
        return enumerate(self.dl_train, epoch * len(self.dl_train))
    # If we get here, we are in the first epoch that should not be skipped entirely,
    # so skip the first `skip_batches` batches.
    it = iter(self.dl_train)
    for i in range(skip_batches):
        print('\rDropping batch {: 10d}...'.format(i), end='')
        if not global_config.get('drop_batches', False):
            # Would be nice to not load the images, but this is hard to do, as the DataLoader
            # caches the Dataset's responses and might even be immutable.
            next(it)  # drop batch
    print(' -- dropped {} batches'.format(skip_batches))
    return enumerate(it, epoch * len(self.dl_train) + skip_batches)
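# Illustrative sketch only: how the global iteration counter lines up when resuming mid-epoch.
# A plain list stands in for the DataLoader; epoch, epoch_len and skip_batches are example values.
dl_train_sketch = list(range(100))   # pretend batches
epoch, epoch_len, skip_batches = 3, len(dl_train_sketch), 40

it = iter(dl_train_sketch)
for _ in range(skip_batches):        # drop the already-seen batches
    next(it)
itr, batch = next(enumerate(it, epoch * epoch_len + skip_batches))
assert itr == 3 * 100 + 40 and batch == 40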
def __init__(self, config_ms, scale, C=3):
    super(DeepProbabilityClassifier, self).__init__()
    Cf = config_ms.Cf
    kernel_size = 3
    m_body = [edsr.ResBlock(conv, Cf, kernel_size,
                            act=act.make(Cf, inverse=True),
                            res_scale=global_config.get('res_scale', 1))
              for _ in range(3)]
    m_body.append(conv(Cf, Cf, kernel_size))
    self.body = nn.Sequential(*m_body)
    K = config_ms.prob.K
    # For RGB (scale 0), generate the outputs specified by config_ms.prob.rgb_outputs;
    # otherwise, generate means, sigmas, pis.
    tail_outputs = (_parse_outputs_flag(config_ms.prob.rgb_outputs) if scale == 0
                    else RequiredOutputs(True, True, True, lambdas=False))
    self.tail = ProbClfTail(Cf, C, K, outputs=tail_outputs)
def make(C, inverse):
    return {
        'relu': lambda: nn.ReLU(True),
        'lrelu': lambda: nn.LeakyReLU(inplace=True),
        'GDN': lambda: gdn.GDN(C, inverse=inverse),
    }[global_config.get('act', 'relu')]()
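# Illustrative sketch only: the same "dict of constructors" pattern with plain torch.nn
# activations, selected by a config string. `act_name` is a hypothetical key, defaulting to
# 'relu'; wrapping each entry in a lambda keeps construction lazy, so only the selected
# module is instantiated.
import torch.nn as nn

def _sketch_make_act(act_name: str = 'relu') -> nn.Module:
    return {
        'relu': lambda: nn.ReLU(inplace=True),
        'lrelu': lambda: nn.LeakyReLU(inplace=True),
    }[act_name]()

assert isinstance(_sketch_make_act('lrelu'), nn.LeakyReLU)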
def __init__(self, config_clf):
    super(ClassifierNetwork, self).__init__()
    self.config_clf = config_clf
    Cf = config_clf.Cf
    num_classes = config_clf.num_classes
    head = config_clf.head
    nB = config_clf.n_resblock

    norm = {
        'bn': nn.BatchNorm2d,
        'gdn': gdn.GDN,
        'identity': lambda _: pe.IdentityModule(),
    }[config_clf.norm]

    if head == 'down3':
        head = [
            pe.default_conv(3, Cf // 4, 5, stride=2), norm(Cf // 4), nn.ReLU(inplace=True),
            pe.default_conv(Cf // 4, Cf // 2, 5, stride=2), norm(Cf // 2), nn.ReLU(inplace=True),
            pe.default_conv(Cf // 2, Cf, 5, stride=2), norm(Cf), nn.ReLU(inplace=True),
        ]
    elif head == 'down2':
        head = [
            pe.default_conv(3, Cf // 2, 5, stride=2), norm(Cf // 2), nn.ReLU(inplace=True),
            pe.default_conv(Cf // 2, Cf, 5, stride=2), norm(Cf), nn.ReLU(inplace=True),
        ]
    self.head = nn.Sequential(*head)

    norm_cls = lambda: norm(Cf)
    model = [ResBlock(pe.default_conv, Cf, kernel_size=3,
                      act=nn.ReLU(inplace=True), norm_cls=norm_cls)
             for _ in range(nB)]

    final_Cf = Cf
    if config_clf.num_res_down == 2:
        model.append(pe.default_conv(Cf, 2 * Cf, 5, stride=2))
        norm_cls = lambda: norm(2 * Cf)
        model += [ResBlock(pe.default_conv, 2 * Cf, kernel_size=3,
                           act=nn.ReLU(inplace=True), norm_cls=norm_cls)
                  for _ in range(nB)]
        final_Cf = 2 * Cf

    if global_config.get('final_conv', False):
        model += [pe.default_conv(final_Cf, final_Cf, 3),
                  nn.LeakyReLU(inplace=True)]

    self.model = nn.Sequential(*model, ChannelAverage())

    if config_clf.deep_tail:
        tail = [nn.Linear(final_Cf, 2 * final_Cf),
                nn.LeakyReLU(inplace=True),
                nn.Linear(2 * final_Cf, num_classes)]
    else:
        tail = [nn.Linear(final_Cf, num_classes)]
    self.tail = nn.Sequential(*tail)
def get_residual_dataset(imgs_dir, random_transforms: bool, random_scale, crop_size: int, mode: str,
                         max_imgs=None, discard_shitty=True, filter_min_size=None, top_only=None,
                         is_training=False, sort=False):
    if top_only:
        assert top_only < 1
    multiple_ds = False
    if isinstance(imgs_dir, dict):
        assert 'raw' in imgs_dir and 'compressed' in imgs_dir, imgs_dir.keys()
        raw_p, compressed_p = imgs_dir['raw'], imgs_dir['compressed']
        multiple_ds = isinstance(imgs_dir['raw'], list)
        if multiple_ds:
            assert len(raw_p) == len(compressed_p)
    elif ';' in imgs_dir:
        raw_p, compressed_p = imgs_dir.split(';')
    else:
        raise ValueError('Invalid imgs_dir, should be dict or string with ;, got', imgs_dir)

    # Works fine if p_ is a list.
    get_imgs = lambda p_: cached_listdir_imgs(p_, min_size=filter_min_size or crop_size,
                                              discard_shitty=discard_shitty)

    if compressed_p == 'JPG':
        print('*** Using JPG...')
        imgs = get_imgs(raw_p)
        return JPGDataset(imgs,
                          random_crops=crop_size if random_transforms else None,
                          random_flips=random_transforms,
                          random_scale=random_scale,
                          center_crops=crop_size if not random_transforms else None,
                          max_imgs=max_imgs)

    if is_training and global_config.get('filter_imgs', False):
        assert not multiple_ds
        print('*** filtering', imgs_dir)
        filter_imgs = global_config['filter_imgs']
        if not isinstance(filter_imgs, int):
            filter_imgs = 680
        print(filter_imgs)
        get_imgs = lambda p_: cached_listdir_imgs_max(p_, max_size=filter_imgs, discard_shitty=True)

    raw, compressed = map(get_imgs, (raw_p, compressed_p))

    if top_only:
        # Keep only the `top_only` fraction of images with the highest bpp.
        sorted_imgs = sorted((Compressor.bpp_from_compressed_file(p), p) for p in compressed.ps)
        top_only_imgs = sorted_imgs[-int(top_only * len(sorted_imgs)):]
        top_only_ps = [p for _, p in top_only_imgs]
        compressed = Images(top_only_ps, compressed.id + f'_top{top_only:.2f}')
        print(f'*** Using {len(top_only_ps)} of {len(sorted_imgs)} images only')

    if sort:
        print('Sorting...')
        raw = raw.sort()
        compressed = compressed.sort()

    return ResidualDataset(compressed, raw, mode=mode,
                           random_crops=crop_size if random_transforms else None,
                           random_flips=random_transforms,
                           random_scale=random_scale,
                           center_crops=crop_size if not random_transforms else None,
                           max_imgs=max_imgs)
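# Illustrative sketch only: the `top_only` selection above on toy data, keeping the fraction
# of files with the highest bits-per-pixel. File names and bpp values are made up.
files_with_bpp = [(0.8, 'a.png'), (1.6, 'b.png'), (0.5, 'c.png'), (2.1, 'd.png')]
top_only = 0.5

sorted_imgs = sorted(files_with_bpp)                            # ascending by bpp
kept = [p for _, p in sorted_imgs[-int(top_only * len(sorted_imgs)):]]
assert kept == ['b.png', 'd.png']                               # the top 50% by bpp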
def get_padding_fac(self):
    return 2 if global_config.get('down_up', None) else 0
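# Illustrative sketch only: why a padding factor matters when the 'down_up' path is enabled.
# A stride-2 downsample followed by a 2x upsample only restores the input shape when H and W
# are multiples of 2, so the sketch below assumes inputs are padded to that factor first; the
# helper name and replicate-padding choice are assumptions, not the repo's implementation.
import torch
import torch.nn.functional as F

def _sketch_pad_to_factor(x: torch.Tensor, fac: int = 2) -> torch.Tensor:
    if fac == 0:
        return x
    h, w = x.shape[-2:]
    pad_h, pad_w = (-h) % fac, (-w) % fac
    return F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

x = torch.randn(1, 3, 31, 45)
assert _sketch_pad_to_factor(x, 2).shape[-2:] == (32, 46)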
def __init__(self, loss_dmol_rgb: DiscretizedMixLogisticLoss):
    self.loss_dmol_rgb = loss_dmol_rgb
    self._plot_loss = global_config.get('test.plot_loss', False)
    self._info = global_config.get('test.info', '')
    self._mode = global_config.get('test.mode', 'direct')  # direct, channel
    self._rand_init = global_config.get('test.rand_init', False)
    self._grid = global_config.get('test.grid', 1)
    self._full = global_config.get('test.full', False)
    self._num_iter = global_config.get('test.num_iter', 50)
    self._early_stop = global_config.get('test.early_stop', False)
    self._ignore_overhead = global_config.get('test.ignore_overhead', False)
    self._subsampling = global_config.get('test.subsampling', 4)
    self._optim_cls = global_config.get('test.optim', 'SGD')
    self._lr = global_config.get('test.lr', 9e-2)
    self._optim_params = global_config.get_as_dict(
            'test.optim_params',
            'dict(momentum=0.9)' if self._optim_cls != 'Adam' else 'dict()')
    self._summary = _Summary()
def make_res_block(_act, _use_norm=True):
    return edsr.ResBlock(pe.default_conv, Cf, kernel_size,
                         act=_act,
                         norm_cls=norm_cls if _use_norm else None,
                         res_scale=global_config.get('res_scale', 0.1))
def __init__(self, config_p, dl_config_p, log_dir_root, log_config: LogConfig, num_workers,
             saver: Saver, restorer: TrainRestorer = None,
             sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
    """
    :param config_p: Path to the network config file, see README
    :param dl_config_p: Path to the dataloader config file, see README
    :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
    :param log_config: Instance of train.trainer.LogConfig, contains intervals.
    :param num_workers: Number of workers to use for DataLoading, see train.py
    :param saver: Saver instance to use.
    :param restorer: Instance of TrainRestorer, if we need to restore
    """
    self.style = MultiscaleTrainer.get_style_from_config(config_p)
    self.blueprint_cls = {
        'enhancement': EnhancementBlueprint,
        'classifier': ClassifierBlueprint,
    }[self.style]

    global_config.declare_used('filter_imgs')

    # Read configs:
    #   config    = config for the network
    #   config_dl = config for data loading
    (self.config, self.config_dl), rel_paths = ft.unzip(map(config_parser.parse, [config_p, dl_config_p]))
    # TODO: only read by enhancement classes
    self.config.is_residual = self.config_dl.is_residual_dataset

    # Update global_config given config.global_config
    global_config_config_keys = global_config.add_from_str_without_overwriting(self.config.global_config)
    # Update config_ms depending on global_config
    global_config.update_config(self.config)

    if self.style == 'enhancement':
        EnhancementBlueprint.read_evenly_spaced_bins(self.config_dl)

    self._custom_init()

    # Create data loaders
    dl_train, self.ds_val, self.fixed_first_val = self._get_dataloaders(num_workers)

    # Create blueprint. A blueprint collects the network as well as the losses in one class,
    # for easy reuse during testing.
    self.blueprint = self.blueprint_cls(self.config)
    print('Network:', self.blueprint.net)

    # Setup optimizer
    optim_cls = {
        'RMSprop': optim.RMSprop,
        'Adam': optim.Adam,
        'SGD': optim.SGD,
    }[self.config.optim]
    net = self.blueprint.net
    self.optim = optim_cls(net.parameters(), self.config.lr.initial,
                           weight_decay=self.config.weight_decay)

    # Calculate a rough estimate of the time per batch (does not take into account that CUDA
    # is async, but good enough to get a feeling during training).
    self.time_accumulator = timer.TimeAccumulator()

    # Restore network if requested
    skip_to_itr = self.maybe_restore(restorer)
    if skip_to_itr is not None:  # i.e., we have a restorer
        print('Skipping to {}...'.format(skip_to_itr))

    # Create LR schedule to update parameters
    self.lr_schedule = lr_schedule.from_spec(self.config.lr.schedule, self.config.lr.initial,
                                             [self.optim], epoch_len=len(dl_train))

    # --- All nn.Modules are set up ---
    print('-' * 80)

    # Create log dir and summary writer
    self.log_dir_root = log_dir_root
    global_config_values = global_config.values(ignore=global_config_config_keys)
    self.log_dir = Trainer.get_log_dir(log_dir_root, rel_paths, restorer,
                                       global_config_values=global_config_values)
    self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
    self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
    print(f'Checkpoints will be saved to {self.ckpt_dir}')
    saver.set_out_dir(self.ckpt_dir)

    if global_config.get('ds_syn', None):
        underlying = dl_train.dataset
        while not isinstance(underlying, _CheckerboardDataset):
            underlying = underlying.ds
        underlying.save_all(self.log_dir)

    # Create summary writer
    sw = sw_cls(self.log_dir)
    self.summarizer = vis.summarizable_module.Summarizer(sw)
    net.register_summarizer(self.summarizer)
    self.blueprint.register_summarizer(self.summarizer)

    # Try to write the file names of the training set to the log dir.
    try:
        dl_train.dataset.write_file_names_to_txt(self.log_dir)
    except AttributeError:
        raise AttributeError(f'dl_train.dataset of type {type(dl_train.dataset)} does not support '
                             f'write_file_names_to_txt(log_dir)!')

    # Superclass setup
    super(MultiscaleTrainer, self).__init__(dl_train, [self.optim], net, sw,
                                            max_epochs=self.config_dl.max_epochs,
                                            log_config=log_config, saver=saver,
                                            skip_to_itr=skip_to_itr)
def __init__(self, config_en):
    super(EnhancementNetwork, self).__init__()
    self.config_en = config_en
    Cf = config_en.Cf
    kernel_size = config_en.kernel_size
    n_resblock = config_en.n_resblock

    more_gdn = global_config.get('more_gdn', False)
    more_act = global_config.get('more_act', False)
    act_body = act.make(Cf, inverse=False)
    act_tail = act.make(Cf, inverse=False)

    self.head = pe.default_conv(3, Cf, 3)
    if more_act:
        self.head = nn.Sequential(self.head, act_body)

    self._down_up = global_config.get('down_up', None)
    if self._down_up:
        self.down = pe.default_conv(Cf, Cf, global_config.get('fw_du', 3), stride=2)
        if more_gdn:
            self.down = nn.Sequential(self.down, GDN(Cf))
        if more_act:
            self.down = nn.Sequential(self.down, act_body)
    else:
        self.down = lambda x: x

    self.unet_skip_conv = None
    if global_config.get('unet_skip', None):
        self.unet_skip_conv = nn.Sequential(
                pe.default_conv(2 * Cf, Cf, 3),
                nn.ReLU(inplace=True))
        assert not global_config.get('learned_skip', False)

    # Cf_resnet = global_config.get('Cf_resnet', Cf)
    # print('*** Cf_resnet ==', Cf_resnet)

    norm_cls = None
    if global_config.get('inorm', False):
        print('*** Using Instance Norm!')
        norm_cls = lambda: nn.InstanceNorm2d(Cf, affine=True)
    if global_config.get('gdn', False):
        norm_cls = lambda: GDN(Cf)

    use_norm_for_long = not global_config.get('no_norm_final', False)
    if not use_norm_for_long:
        print('*** no norm for final')

    def make_res_block(_act, _use_norm=True):
        return edsr.ResBlock(pe.default_conv, Cf, kernel_size,
                             act=_act,
                             norm_cls=norm_cls if _use_norm else None,
                             res_scale=global_config.get('res_scale', 0.1))

    norm_in_body = True
    if global_config.get('gdn_as_nl', False):
        print('*** GDN as non-linearity!')
        norm_cls = None
        act_body = GDN(Cf)
        if not global_config.get('gdnfreetail', False):
            act_tail = GDN(Cf)
        norm_in_body = False
        use_norm_for_long = False

    m_body = [make_res_block(act_body, norm_in_body) for _ in range(n_resblock)]
    m_body.append(pe.default_conv(Cf, Cf, kernel_size))
    self.body = nn.Sequential(*m_body)

    if self._down_up:
        if self._down_up == 'deconv':
            up = ups.DeconvUp(config_en)
        elif self._down_up == 'nn':
            up = ups.ResizeConvUp(config_en)
        else:
            up = edsr.Upsampler(pe.default_conv, 2, Cf, act=False)
        if more_gdn:
            up = nn.Sequential(up, GDN(Cf, inverse=True))
        if more_act:
            up = nn.Sequential(up, act_body)
        print('*** DownUp, adding', up)
        self.after_skip = up
    else:
        self.after_skip = lambda x: x

    tail_networks = {}
    if global_config.get('deeptails', False):
        raise NotImplementedError
        # num_blocks = global_config['deeptails']
        # for name in ('sigmas', 'means'):
        #     tail_networks[name] = lambda: SequentialWithSkip(
        #             body=nn.Sequential(*[make_res_block(nn.LeakyReLU(inplace=True))
        #                                  for _ in range(num_blocks)]),
        #             final=pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1))

    def _tail(fw_=3):
        if global_config.get('atrous', None):
            print('Atrous Tail')
            assert 'long_sigma' in global_config
            assert 'long_means' in global_config
            return [prob_clf.StackedAtrousConvs(atrous_rates_str='1,2,4',
                                                Cin=Cf,
                                                Cout=prob_clf.ProbClfTail.get_cout(config_en),
                                                Catrous=Cf // 2,
                                                bias=False,
                                                activation=nn.LeakyReLU(inplace=True))]
        else:  # default so far
            return [
                pe.default_conv(Cf, Cf, fw_),
                nn.LeakyReLU(inplace=True),
                pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
            ]

    if global_config.get('long_sigma', False):
        fw_sigma = global_config.get('fw_s', 5)
        print('filter_width for sigma =', fw_sigma)
        modules = [make_res_block(act_tail, use_norm_for_long), *_tail(fw_sigma)]
        if global_config.get('fc2', False):
            print('Adding another 1x1 conv!')
            modules.insert(-1, pe.default_conv(Cf, Cf, 1))
        tail_networks['sigmas'] = lambda: pe.FeatureMapSaverSequential(
                *modules,
                saver=None  # pe.FeatureMapSaver()
        )
        print('Did set tail_networks.sigmas')

    if global_config.get('long_means', False):
        tail_networks['means'] = lambda: pe.FeatureMapSaverSequential(
                make_res_block(act_tail, use_norm_for_long),
                *_tail()
                # saver=self.savers['final_sigmas'], idx=-2
        )
        print('Did set tail_networks.means')

    if global_config.get('long_pis', False):
        tail_networks['pis'] = lambda: pe.FeatureMapSaverSequential(
                make_res_block(act_tail, use_norm_for_long),
                pe.default_conv(Cf, Cf, 3),    # no crazy smoothing
                nn.LeakyReLU(inplace=True),    # a non-linearity
                pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1),  # final 1x1
                saver=None)
        print('Did set tail_networks.pis')

    if global_config.get('long_lambdas', False):
        tail_networks['lambdas'] = lambda: nn.Sequential(
                make_res_block(act_tail, use_norm_for_long),
                pe.default_conv(Cf, Cf, 3),    # no crazy smoothing
                nn.LeakyReLU(inplace=True),    # a non-linearity
                pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1))  # final 1x1
        print('Did set tail_networks.lambdas')

    if global_config.get('longer_lambda', False):
        tail_networks['lambdas'] = lambda: nn.Sequential(
                pe.default_conv(Cf, Cf, 3),
                nn.LeakyReLU(inplace=True),
                pe.default_conv(Cf, prob_clf.ProbClfTail.get_cout(config_en), 1))  # final 1x1
        print('Did set tail_networks.lambdas')

    self.side_information_mode = False
    if global_config.get('side_information', False):
        print('*** Using side_information!')
        self.side_information_mode = True
        Ccond = global_config['side_information']
        self.side_information_net = SideInformationNetwork(Ccond)
        self.side_information_conv = ConditionalConvolution(Cf, Cf, Ccond, kernel_size,
                                                            activation=nn.LeakyReLU(inplace=True))

    print('Setting tail_networks[', tail_networks.keys(), ']')
    self.tail = prob_clf.ProbClfTail.from_config(config_en, tail_networks=tail_networks)