예제 #1
0
 def train_batch(self, feature, label, learning_rate=1e-4, **kwargs):
     """Run one training step and return the evaluated training metrics.

     The first entry of the loss list (the "G-net" update) is run only when
     ``kwargs['steps']`` is a multiple of ``self.nd_iter``; the remaining
     entries (the "D-net" updates) are run on every call.

     Args:
         feature: input value(s); entry ``i`` is fed to ``self.inputs[i]``.
         label: target value(s); entry ``i`` is fed to ``self.label[i]``.
         learning_rate: value fed to the ``self.learning_rate`` placeholder.
         **kwargs: must contain ``steps`` (current step index); may contain
             ``loss`` to override ``self.loss``.

     Returns:
         dict mapping each key of ``self.train_metric`` to its evaluated value.
     """
     features = to_list(feature)
     labels = to_list(label)
     self.feed_dict.update({
         self.training_phase: True,
         self.learning_rate: learning_rate
     })
     for idx, placeholder in enumerate(self.inputs):
         self.feed_dict[placeholder] = features[idx]
     for idx, placeholder in enumerate(self.label):
         self.feed_dict[placeholder] = labels[idx]
     train_ops = to_list(kwargs.get('loss') or self.loss)
     step = kwargs['steps']
     sess = tf.get_default_session()
     if step % self.nd_iter == 0:
         # G-net update, only every `nd_iter` steps
         sess.run(train_ops[0], feed_dict=self.feed_dict)
     # D-net update(s), on every step
     sess.run(train_ops[1:], feed_dict=self.feed_dict)
     metric_values = sess.run(list(self.train_metric.values()),
                              feed_dict=self.feed_dict)
     return dict(zip(self.train_metric, metric_values))
예제 #2
0
 def __init__(self, channels, depth=3, use_ca=False, name='CascadeRdn',
              **kwargs):
   """Create the sub-modules of a cascade of residual dense blocks.

   For every stage ``i`` this registers, in order, ``conv11_{i}`` (a 1x1
   conv), ``rdn_{i}`` (an ``Rdb``) and, when ``use_ca`` is set, ``rcab_{i}``
   (an ``Rcab``).

   Args:
     channels: expanded with ``to_list(channels, 2)`` into ``(in_c, out_c)``;
       also passed unchanged to every ``Rdb`` and ``Rcab``.
     depth: expanded with ``to_list(depth, 2)``; entry 0 is the number of
       stages, entry 1 is forwarded to ``Rdb`` as its depth.
     use_ca: when truthy, an ``Rcab`` is also created per stage.
     name: stored on ``self.name``.
     **kwargs: forwarded to every ``Rdb``.
   """
   super(CascadeRdn, self).__init__()
   self.name = name
   self.depth = to_list(depth, 2)
   self.ca = use_ca
   in_c, out_c = to_list(channels, 2)
   for stage in range(self.depth[0]):
     # the 1x1 conv input width grows by out_c per stage
     # (presumably dense concatenation in forward — not visible here)
     fused_in = in_c + out_c * (stage + 1)
     setattr(self, f'conv11_{stage}', nn.Conv2d(fused_in, out_c, 1))
     setattr(self, f'rdn_{stage}', Rdb(channels, self.depth[1], **kwargs))
     if use_ca:
       setattr(self, f'rcab_{stage}', Rcab(channels))
예제 #3
0
 def _find_last_ckpt(self):
     """Locate the most recent checkpoint in ``self.savedir``.

     TensorFlow's checkpoint-state file is consulted first. If it names no
     checkpoint, the newest ``*.ckpt.index`` file (by modification time) is
     used instead.

     Returns:
         Whatever ``tf.train.latest_checkpoint`` returns in the first case;
         ``self.savedir / <index file stem>`` in the fallback case; None when
         no index file exists.
     """
     state = tf.train.get_checkpoint_state(self.savedir)
     if state and state.model_checkpoint_path:
         return tf.train.latest_checkpoint(self.savedir)
     # fallback: scan the directory, oldest first by mtime
     candidates = sorted(to_list(self.savedir.glob('*.ckpt.index')),
                         key=lambda path: path.stat().st_mtime_ns)
     if not candidates:
         return None
     # `.stem` drops the trailing ".index", leaving the checkpoint prefix
     return self.savedir / candidates[-1].stem
예제 #4
0
 def _restore_model(self, sess):
     last_checkpoint_step = 0
     for name in self.savers:
         saver = self.savers.get(name)
         ckpt = to_list(self.savedir.glob(f'{name}*.index'))
         if ckpt:
             ckpt = sorted(ckpt, key=lambda x: x.stat().st_mtime_ns)
             ckpt = self.savedir / ckpt[-1].stem
             try:
                 saver.restore(sess, str(ckpt))
             except:
                 tf.logging.warning(
                     f'{name} of model {self.model.name} counld not be restored'
                 )
             last_checkpoint_step = self._parse_ckpt_name(ckpt)
     return last_checkpoint_step
예제 #5
0
 def __init__(self, channels, depth=3, scaling=1.0, name='Rdb', **kwargs):
   """Create the convolution stack of a residual dense block.

   Layer ``i`` (registered as ``conv_{i}``) maps ``in_c + out_c * i`` input
   channels to ``out_c``. Every layer except the last is wrapped together
   with an in-place ReLU.

   Args:
     channels: expanded with ``to_list(channels, 2)`` into ``(in_c, out_c)``.
     depth: number of conv layers.
     scaling: stored on ``self.scaling`` (not used in this constructor).
     name: stored on ``self.name``.
     **kwargs: optional conv settings with defaults ``kernel_size=3``,
       ``stride=1``, ``padding=kernel_size // 2``, ``dilation=1``,
       ``group=1`` and ``bias=True``.
   """
   super(Rdb, self).__init__()
   self.name = name
   self.depth = depth
   self.scaling = scaling
   in_c, out_c = to_list(channels, 2)
   kernel = kwargs.get('kernel_size', 3)
   conv_opts = dict(
     stride=kwargs.get('stride', 1),
     padding=kwargs.get('padding', kernel // 2),
     dilation=kwargs.get('dilation', 1),
     groups=kwargs.get('group', 1),
     bias=kwargs.get('bias', True))
   last = depth - 1
   for layer in range(depth):
     block = nn.Conv2d(in_c + out_c * layer, out_c, kernel, **conv_opts)
     if layer != last:  # the final layer stays linear
       block = nn.Sequential(block, nn.ReLU(True))
     setattr(self, f'conv_{layer}', block)
예제 #6
0
    def __init__(self, blocks=(4, 4), **kwargs):
        """Create the layers of a two-level down/up-sampling CascadeRdn network.

        Args:
            blocks: expanded with ``to_list(blocks, 2)`` and stored on
                ``self.blocks``; not otherwise used by this constructor.
            **kwargs: accepted but unused here.
        """
        super(Crdn, self).__init__()
        self.blocks = to_list(blocks, 2)

        # RGB (3 channels) <-> 32 feature channels
        self.entry = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2))
        self.exit = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1))
        # stride-2 convs: 32 -> 64 -> 128 channels
        self.down1 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.down2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.up1 = Upsample([128, 64])
        self.up2 = Upsample([64, 32])
        # cb1..cb6, registered in order, each with depth 3 and use_ca=True
        widths = (32, 64, 128, 128, 64, 32)
        for idx, width in enumerate(widths, start=1):
            setattr(self, f'cb{idx}', CascadeRdn(width, depth=3, use_ca=True))
예제 #7
0
    def predict(self, files, mode='pil-image1', depth=1, **kwargs):
        r"""Run the restored model over a set of input frames.

        The newest checkpoints are restored first; then every test sample is
        passed through `self.feature_callbacks`, the model's `test_batch`,
        and finally `self.output_callbacks`. Returns None.

        Args:
            files: a path or list of paths to the input frames.
            mode: file format. `pil-image1` for PIL supported images, or
                `NV12/YV12/RGB` for raw data.
            depth: length of each image sequence. 1 for images, >1 for videos.
            **kwargs: forwarded to both `Dataset` and `QuickLoader`.
        """
        sess = tf.get_default_session()
        ckpt_last = self._restore_model(sess)
        paths = [Path(f) for f in to_list(files)]
        dataset = Dataset(test=paths, mode=mode, depth=depth, modcrop=False,
                          **kwargs)
        loader = QuickLoader(1, dataset, 'test', self.model.scale, -1,
                             crop=None, **kwargs)
        iterator = loader.make_one_shot_iterator()
        if not len(iterator):
            # nothing to infer
            return
        print('===================================')
        print(f'Predicting model: {self.model.name} by {ckpt_last}')
        print('===================================')
        for sample in tqdm.tqdm(iterator, 'Infer', ascii=True):
            feature, _label, name = sample[self.fi], sample[self.li], sample[-1]
            tf.logging.debug('output: ' + str(name))
            for callback in self.feature_callbacks:
                feature = callback(feature, name=name)
            outputs = self.model.test_batch(feature, None)
            # output callbacks receive the raw (un-transformed) input sample
            for callback in self.output_callbacks:
                outputs = callback(outputs,
                                   input=sample[self.fi],
                                   label=sample[self.li],
                                   mode=loader.color_format,
                                   name=name)
예제 #8
0
 def __init__(self, channels, ratio=16, name='RCAB', **kwargs):
     """Create the layers of an RCAB block.

     ``c1`` and ``c2`` are ``kernel_size`` convolutions (``c1`` followed by a
     ReLU). ``c3``/``c4`` are 1x1 convs that reduce to ``out_c // ratio``
     channels (ReLU) and expand back to ``in_c`` channels (Sigmoid).
     ``pooling`` is a global average pool to 1x1.

     Args:
         channels: expanded with ``to_list(channels, 2)`` into ``(in_c, out_c)``.
         ratio: channel reduction factor for ``c3``.
         name: stored on ``self.name``.
         **kwargs: optional ``kernel_size`` (3), ``padding``
             (kernel_size // 2), ``group`` (1) and ``bias`` (True).
     """
     super(Rcab, self).__init__()
     self.name = name
     self.ratio = ratio
     in_c, out_c = to_list(channels, 2)
     kernel = kwargs.get('kernel_size', 3)
     pad = kwargs.get('padding', kernel // 2)
     groups = kwargs.get('group', 1)
     bias = kwargs.get('bias', True)
     reduced = out_c // ratio
     spatial = dict(stride=1, padding=pad, dilation=1, groups=groups,
                    bias=bias)
     self.c1 = nn.Sequential(
         nn.Conv2d(in_c, out_c, kernel, **spatial),
         nn.ReLU(True))
     self.c2 = nn.Conv2d(out_c, out_c, kernel, **spatial)
     self.c3 = nn.Sequential(
         nn.Conv2d(out_c, reduced, 1, groups=groups, bias=bias),
         nn.ReLU(True))
     # NOTE(review): c4 emits in_c channels while c2 emits out_c; this only
     # lines up if in_c == out_c (forward not visible here) — confirm.
     self.c4 = nn.Sequential(
         nn.Conv2d(reduced, in_c, 1, groups=groups, bias=bias),
         nn.Sigmoid())
     self.pooling = nn.AdaptiveAvgPool2d(1)
예제 #9
0
 def __init__(self,
              channels,
              kernel_size,
              activation=None,
              use_bias=True,
              use_bn=False,
              use_sn=False,
              act_first=None):
     """Create a two-convolution residual body with optional BN / spectral norm.

     Resulting ``self.body`` layer order:
       * ``use_bn`` falsy: conv1, act, conv2 (``act_first`` is ignored).
       * ``use_bn`` and ``act_first``: BN(in_c), act, conv1, BN(out_c), act,
         conv2.
       * ``use_bn`` only: conv1, BN(out_c), act, conv2, BN(out_c).
     A 1x1 ``self.shortcut`` conv is created only when ``in_c != out_c``.

     Args:
         channels: expanded with ``to_list(channels, 2)`` into ``(in_c, out_c)``.
         kernel_size: conv kernel size; padding is ``kernel_size // 2``.
         activation: passed to ``Activation`` (built with ``in_place=True``).
         use_bias: ``bias`` flag for both convs.
         use_bn: insert BatchNorm layers as listed above.
         use_sn: wrap both convs with ``nn.utils.spectral_norm``.
         act_first: with ``use_bn``, prepend BN + activation instead of
             appending a trailing BN.
     """
     super(RB, self).__init__()
     in_c, out_c = to_list(channels, 2)
     pad = kernel_size // 2
     first = nn.Conv2d(in_c, out_c, kernel_size, stride=1, padding=pad,
                       bias=use_bias)
     second = nn.Conv2d(out_c, out_c, kernel_size, stride=1, padding=pad,
                        bias=use_bias)
     if use_sn:
         first, second = (nn.utils.spectral_norm(first),
                          nn.utils.spectral_norm(second))
     # modules are constructed in the same order as before, so any
     # construction-time state (e.g. RNG draws) is unchanged
     middle_act = Activation(activation, in_place=True)
     if not use_bn:
         layers = [first, middle_act, second]
     else:
         middle_bn = nn.BatchNorm2d(out_c)
         core = [first, middle_bn, middle_act, second]
         if act_first:
             layers = [nn.BatchNorm2d(in_c),
                       Activation(activation, in_place=True)] + core
         else:
             layers = core + [nn.BatchNorm2d(out_c)]
     self.body = nn.Sequential(*layers)
     if in_c != out_c:
         self.shortcut = nn.Conv2d(in_c, out_c, 1)
예제 #10
0
 def __init__(self, channels):
   """Create the two 3x3 convolutions of the upsampling block.

   Args:
     channels: expanded with ``to_list(channels, 2)`` into ``(in_c, out_c)``;
       both ``c1`` and ``c2`` map ``in_c`` to ``out_c`` channels
       (stride 1, padding 1).
   """
   super(Upsample, self).__init__()
   in_c, out_c = to_list(channels, 2)
   conv3x3 = dict(kernel_size=3, stride=1, padding=1)
   self.c1 = nn.Conv2d(in_c, out_c, **conv3x3)
   self.c2 = nn.Conv2d(in_c, out_c, **conv3x3)