def train_batch(self, feature, label, learning_rate=1e-4, **kwargs):
    """Run one training step and return the evaluated training metrics.

    The first loss op is run only when ``steps`` is a multiple of
    ``self.nd_iter`` (marked "update G-net" below); the remaining loss ops
    are run on every call (marked "update D-net").

    Args:
        feature: value(s) fed to the entries of ``self.inputs``, by position.
        label: value(s) fed to the entries of ``self.label``, by position.
        learning_rate: value fed to ``self.learning_rate``.
        **kwargs: must contain ``steps``; may contain ``loss`` to use in
            place of ``self.loss``.

    Returns:
        dict mapping every key of ``self.train_metric`` to its evaluated
        value.
    """
    features = to_list(feature)
    labels = to_list(label)
    self.feed_dict[self.training_phase] = True
    self.feed_dict[self.learning_rate] = learning_rate
    for idx, placeholder in enumerate(self.inputs):
        self.feed_dict[placeholder] = features[idx]
    for idx, placeholder in enumerate(self.label):
        self.feed_dict[placeholder] = labels[idx]
    train_ops = to_list(kwargs.get('loss') or self.loss)
    step = kwargs['steps']
    sess = tf.get_default_session()
    if step % self.nd_iter == 0:
        # update G-net
        sess.run(train_ops[0], feed_dict=self.feed_dict)
    # update D-net
    sess.run(train_ops[1:], feed_dict=self.feed_dict)
    metric_values = sess.run(list(self.train_metric.values()),
                             feed_dict=self.feed_dict)
    return dict(zip(self.train_metric, metric_values))
def __init__(self, channels, depth=3, use_ca=False, name='CascadeRdn',
             **kwargs):
    """Register the per-stage sub-modules of the cascade.

    For each stage ``i`` in ``range(self.depth[0])`` this creates, in order,
    a 1x1 convolution ``conv11_{i}``, an ``Rdb`` block ``rdn_{i}`` and, when
    ``use_ca`` is truthy, an ``Rcab`` block ``rcab_{i}``.

    Args:
        channels: expanded by ``to_list(channels, 2)`` into ``(in_c, out_c)``;
            also passed unchanged to ``Rdb`` and ``Rcab``.
        depth: expanded by ``to_list(depth, 2)``; element 0 is the number of
            stages, element 1 is forwarded to every ``Rdb``.
        use_ca: stored as ``self.ca``; enables the ``rcab_{i}`` modules.
        name: stored as ``self.name``.
        **kwargs: forwarded to every ``Rdb``.
    """
    super(CascadeRdn, self).__init__()
    self.name = name
    self.depth = to_list(depth, 2)
    self.ca = use_ca
    in_c, out_c = to_list(channels, 2)
    for stage in range(self.depth[0]):
        # Input width of the 1x1 conv grows by out_c with every stage.
        fused_width = in_c + out_c * (stage + 1)
        setattr(self, f'conv11_{stage}', nn.Conv2d(fused_width, out_c, 1))
        setattr(self, f'rdn_{stage}', Rdb(channels, self.depth[1], **kwargs))
        if not use_ca:
            continue
        setattr(self, f'rcab_{stage}', Rcab(channels))
def _find_last_ckpt(self):
    """Return the most recent checkpoint found under ``self.savedir``.

    Returns:
        The result of ``tf.train.latest_checkpoint`` when the checkpoint
        state in ``self.savedir`` records a checkpoint path. Otherwise
        ``self.savedir / <stem>`` of the most recently modified
        ``*.ckpt.index`` file, or ``None`` when no such file exists.
    """
    # restore the latest checkpoint in savedir
    state = tf.train.get_checkpoint_state(self.savedir)
    if state and state.model_checkpoint_path:
        return tf.train.latest_checkpoint(self.savedir)
    # try another way
    index_files = to_list(self.savedir.glob('*.ckpt.index'))
    # sort as modification time
    by_mtime = sorted(index_files, key=lambda x: x.stat().st_mtime_ns)
    if not by_mtime:
        return None
    newest = by_mtime[-1]
    return self.savedir / newest.stem
def _restore_model(self, sess):
    """Restore every saver in ``self.savers`` from its newest checkpoint.

    For each saver name, files matching ``{name}*.index`` in
    ``self.savedir`` are sorted by modification time and the newest one is
    restored into ``sess``. A failed restore is logged as a warning and
    does not abort the loop.

    Args:
        sess: session passed to ``saver.restore``.

    Returns:
        The value of ``self._parse_ckpt_name`` for the last checkpoint path
        that was tried, or 0 when no matching checkpoint file was found.
    """
    last_checkpoint_step = 0
    for name in self.savers:
        saver = self.savers.get(name)
        ckpt = to_list(self.savedir.glob(f'{name}*.index'))
        if not ckpt:
            continue
        ckpt = sorted(ckpt, key=lambda x: x.stat().st_mtime_ns)
        ckpt = self.savedir / ckpt[-1].stem
        try:
            saver.restore(sess, str(ckpt))
        except Exception:
            # Best effort: keep going with the remaining savers, but let
            # KeyboardInterrupt / SystemExit propagate (a bare ``except:``
            # would swallow them too).
            tf.logging.warning(
                f'{name} of model {self.model.name} counld not be restored'
            )
        # NOTE(review): the step is taken from the checkpoint name even when
        # the restore above failed -- confirm this is intended.
        last_checkpoint_step = self._parse_ckpt_name(ckpt)
    return last_checkpoint_step
def __init__(self, channels, depth=3, scaling=1.0, name='Rdb', **kwargs):
    """Register ``depth`` convolution layers named ``conv_0`` .. ``conv_{depth-1}``.

    Layer ``i`` takes ``in_c + out_c * i`` input channels and produces
    ``out_c`` channels. Every layer except the last is followed by an
    in-place ReLU.

    Args:
        channels: expanded by ``to_list(channels, 2)`` into ``(in_c, out_c)``.
        depth: number of convolution layers; stored as ``self.depth``.
        scaling: stored as ``self.scaling``.
        name: stored as ``self.name``.
        **kwargs: optional ``kernel_size`` (3), ``stride`` (1), ``padding``
            (``kernel_size // 2``), ``dilation`` (1), ``group`` (1) and
            ``bias`` (True), passed positionally to ``nn.Conv2d``.
    """
    super(Rdb, self).__init__()
    self.name = name
    self.depth = depth
    self.scaling = scaling
    in_c, out_c = to_list(channels, 2)
    ks = kwargs.get('kernel_size', 3)
    # Positional tail shared by every conv: kernel, stride, padding,
    # dilation, groups, bias.
    conv_tail = (
        ks,
        kwargs.get('stride', 1),
        kwargs.get('padding', ks // 2),
        kwargs.get('dilation', 1),
        kwargs.get('group', 1),
        kwargs.get('bias', True),
    )
    for idx in range(depth):
        layer = nn.Conv2d(in_c + out_c * idx, out_c, *conv_tail)
        is_last = idx == depth - 1
        if not is_last:
            # no activation after last layer
            layer = nn.Sequential(layer, nn.ReLU(True))
        setattr(self, f'conv_{idx}', layer)
def __init__(self, blocks=(4, 4), **kwargs):
    """Create the layers of the network.

    Registers, in this order: ``entry`` (7x7 then 5x5 conv, 3 -> 32 -> 32),
    ``exit`` (two 3x3 convs, 32 -> 32 -> 3), the stride-2 convs ``down1``
    (32 -> 64) and ``down2`` (64 -> 128), ``up1`` / ``up2`` (``Upsample``),
    and six ``CascadeRdn`` blocks ``cb1`` .. ``cb6`` with widths
    32, 64, 128, 128, 64, 32.

    Args:
        blocks: stored as ``self.blocks`` after ``to_list(blocks, 2)``.
        **kwargs: accepted but not used in this constructor.
    """
    super(Crdn, self).__init__()
    self.blocks = to_list(blocks, 2)
    head = [
        nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3),
        nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
    ]
    self.entry = nn.Sequential(*head)
    tail = [
        nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
        nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
    ]
    self.exit = nn.Sequential(*tail)
    self.down1 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
    self.down2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
    self.up1 = Upsample([128, 64])
    self.up2 = Upsample([64, 32])
    # cb1 .. cb6, registered in ascending order.
    for idx, width in enumerate((32, 64, 128, 128, 64, 32), start=1):
        setattr(self, f'cb{idx}', CascadeRdn(width, 3, True))
def predict(self, files, mode='pil-image1', depth=1, **kwargs):
    r"""Run the model on a list of frames and pass results to the callbacks.

    The model is first restored via ``self._restore_model``. Each loaded
    item is passed through ``self.feature_callbacks``, then
    ``self.model.test_batch``, then ``self.output_callbacks``. Nothing is
    returned; outputs are only seen by the output callbacks.

    Args:
        files: a list of frames as inputs
        mode: specify file format. `pil-image1` for PIL supported images,
            or `NV12/YV12/RGB` for raw data
        depth: specify length of sequence of images. 1 for images, >1 for
            videos
        **kwargs: forwarded to both ``Dataset`` and ``QuickLoader``.
    """
    sess = tf.get_default_session()
    ckpt_last = self._restore_model(sess)
    paths = [Path(entry) for entry in to_list(files)]
    data = Dataset(test=paths, mode=mode, depth=depth, modcrop=False,
                   **kwargs)
    loader = QuickLoader(1, data, 'test', self.model.scale, -1, crop=None,
                         **kwargs)
    it = loader.make_one_shot_iterator()
    if not len(it):
        # Nothing to predict.
        return
    print('===================================')
    print(f'Predicting model: {self.model.name} by {ckpt_last}')
    print('===================================')
    for img in tqdm.tqdm(it, 'Infer', ascii=True):
        feature = img[self.fi]
        label = img[self.li]  # read but not used below
        name = img[-1]
        tf.logging.debug('output: ' + str(name))
        for pre in self.feature_callbacks:
            feature = pre(feature, name=name)
        outputs = self.model.test_batch(feature, None)
        for post in self.output_callbacks:
            outputs = post(outputs, input=img[self.fi], label=img[self.li],
                           mode=loader.color_format, name=name)
def __init__(self, channels, ratio=16, name='RCAB', **kwargs):
    """Create the four convolution stages and the pooling layer.

    ``c1`` is conv + ReLU (``in_c -> out_c``), ``c2`` is a conv
    (``out_c -> out_c``), ``c3`` is a 1x1 conv + ReLU
    (``out_c -> out_c // ratio``), ``c4`` is a 1x1 conv + Sigmoid
    (``out_c // ratio -> in_c``), and ``pooling`` is
    ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        channels: expanded by ``to_list(channels, 2)`` into ``(in_c, out_c)``.
        ratio: divisor applied to ``out_c`` for the width between ``c3``
            and ``c4``; stored as ``self.ratio``.
        name: stored as ``self.name``.
        **kwargs: optional ``kernel_size`` (3), ``padding``
            (``kernel_size // 2``), ``group`` (1) and ``bias`` (True).
    """
    super(Rcab, self).__init__()
    self.name = name
    self.ratio = ratio
    in_c, out_c = to_list(channels, 2)
    ks = kwargs.get('kernel_size', 3)
    padding = kwargs.get('padding', ks // 2)
    group = kwargs.get('group', 1)
    bias = kwargs.get('bias', True)
    first_conv = nn.Conv2d(in_c, out_c, ks, 1, padding, 1, group, bias)
    self.c1 = nn.Sequential(first_conv, nn.ReLU(True))
    self.c2 = nn.Conv2d(out_c, out_c, ks, 1, padding, 1, group, bias)
    squeeze = nn.Conv2d(out_c, out_c // ratio, 1, groups=group, bias=bias)
    self.c3 = nn.Sequential(squeeze, nn.ReLU(True))
    expand = nn.Conv2d(out_c // ratio, in_c, 1, groups=group, bias=bias)
    self.c4 = nn.Sequential(expand, nn.Sigmoid())
    self.pooling = nn.AdaptiveAvgPool2d(1)
def __init__(self, channels, kernel_size, activation=None, use_bias=True,
             use_bn=False, use_sn=False, act_first=None):
    """Assemble a two-convolution body and, if needed, a 1x1 shortcut conv.

    The base body is ``conv1 -> Activation -> conv2``. With ``use_bn`` a
    ``BatchNorm2d(out_c)`` is inserted right after ``conv1``; then either a
    ``BatchNorm2d(in_c) + Activation`` pair is put in front of the body
    (``act_first`` truthy) or a ``BatchNorm2d(out_c)`` is appended at the
    end (``act_first`` falsy).

    Args:
        channels: expanded by ``to_list(channels, 2)`` into ``(in_c, out_c)``.
        kernel_size: kernel size of both convolutions; their padding is
            ``kernel_size // 2`` and their stride is 1.
        activation: forwarded to ``Activation``.
        use_bias: ``bias`` flag of both convolutions.
        use_bn: enables the batch-norm layers described above.
        use_sn: wraps both convolutions in ``nn.utils.spectral_norm``.
        act_first: selects where the extra batch-norm goes (see above).
    """
    super(RB, self).__init__()
    in_c, out_c = to_list(channels, 2)
    conv1 = nn.Conv2d(in_c, out_c, kernel_size, 1, kernel_size // 2,
                      bias=use_bias)
    conv2 = nn.Conv2d(out_c, out_c, kernel_size, 1, kernel_size // 2,
                      bias=use_bias)
    if use_sn:
        conv1 = nn.utils.spectral_norm(conv1)
        conv2 = nn.utils.spectral_norm(conv2)
    net = [conv1, Activation(activation, in_place=True), conv2]
    # NOTE(review): the act_first / else branch is assumed to be nested
    # under use_bn -- confirm against the intended layer layout.
    if use_bn:
        # Batch-norm directly after the first convolution.
        net.insert(1, nn.BatchNorm2d(out_c))
        if act_first:
            # Put norm + activation in front of the body.
            net = [nn.BatchNorm2d(in_c), Activation(activation, in_place=True)] + \
                net
        else:
            # Put a batch-norm after the second convolution.
            net.append(nn.BatchNorm2d(out_c))
    self.body = nn.Sequential(*net)
    # self.shortcut exists only when the channel counts differ.
    if in_c != out_c:
        self.shortcut = nn.Conv2d(in_c, out_c, 1)
def __init__(self, channels):
    """Create the two 3x3 convolutions ``c1`` and ``c2``.

    Both map ``in_c`` to ``out_c`` channels with stride 1 and padding 1.

    Args:
        channels: expanded by ``to_list(channels, 2)`` into ``(in_c, out_c)``.
    """
    super(Upsample, self).__init__()
    in_c, out_c = to_list(channels, 2)
    for attr in ('c1', 'c2'):
        conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)
        setattr(self, attr, conv)