Пример #1
0
    def test_batch(self, inputs, label=None, **kwargs):
        """Run the model on one batch in inference mode.

        Args:
            inputs: LR images; a single object or one entry per placeholder
              in `self.inputs`.
            label: optional ground truth. When `to_list(label)` is non-empty
              the ops in `self.metrics` are evaluated together with the
              outputs.
            kwargs: for future use.

        Return:
            A tuple `(outputs, metrics)`. `outputs` holds the evaluated
            `self.outputs`; `metrics` is a dict keyed like `self.metrics`
            and is empty when no label is supplied.
        """

        features = to_list(inputs)
        targets = to_list(label)
        self.feed_dict.update({self.training_phase: False})
        for idx, placeholder in enumerate(self.inputs):
            self.feed_dict[placeholder] = features[idx]

        if not targets:
            # No ground truth: only the outputs can be evaluated.
            outputs = tf.get_default_session().run(self.outputs,
                                                   feed_dict=self.feed_dict)
            return outputs, {}

        for idx, placeholder in enumerate(self.label):
            self.feed_dict[placeholder] = targets[idx]
        # Fetch outputs and metrics in one run, then split the flat result.
        fetches = self.outputs + list(self.metrics.values())
        fetched = tf.get_default_session().run(fetches,
                                               feed_dict=self.feed_dict)
        n_out = len(self.outputs)
        outputs, values = fetched[:n_out], fetched[n_out:]
        return outputs, dict(zip(self.metrics, values))
Пример #2
0
    def train_batch(self, feature, label, learning_rate=1e-4, **kwargs):
        """Train on one batch for one step.

        Args:
            feature: input tensors, LR images for the SR use case
            label: labels, HR images for the SR use case
            learning_rate: update step size in current calculation
            kwargs: an optional `loss` entry replaces `self.loss` as the
              op(s) to run; other entries are ignored here.

        Return:
            a dict mapping every key of `self.train_metric` to its evaluated
            value. The loss ops are run in the same call but their results
            are not returned.
        """

        features = to_list(feature)
        targets = to_list(label)
        self.feed_dict[self.training_phase] = True
        self.feed_dict[self.learning_rate] = learning_rate
        for idx, placeholder in enumerate(self.inputs):
            self.feed_dict[placeholder] = features[idx]
        for idx, placeholder in enumerate(self.label):
            self.feed_dict[placeholder] = targets[idx]

        train_ops = to_list(kwargs.get('loss') or self.loss)
        # Metrics come first so that zipping with the metric names below
        # drops the trailing results of the optimize ops.
        fetches = list(self.train_metric.values()) + train_ops
        fetched = tf.get_default_session().run(fetches,
                                               feed_dict=self.feed_dict)
        return dict(zip(self.train_metric, fetched))
Пример #3
0
 def train_batch(self, feature, label, learning_rate=1e-4, **kwargs):
     """Train on one batch, alternating between two groups of update ops.

     Args:
         feature: input tensors, one entry per placeholder in `self.inputs`
         label: labels, one entry per placeholder in `self.label`
         learning_rate: update step size in current calculation
         kwargs: must contain `steps` (the current step counter); an
           optional `loss` entry replaces `self.loss`.

     Return:
         a dict mapping every key of `self.train_metric` to its value,
         evaluated in a separate run after the updates.
     """
     features = to_list(feature)
     targets = to_list(label)
     self.feed_dict[self.training_phase] = True
     self.feed_dict[self.learning_rate] = learning_rate
     for idx, placeholder in enumerate(self.inputs):
         self.feed_dict[placeholder] = features[idx]
     for idx, placeholder in enumerate(self.label):
         self.feed_dict[placeholder] = targets[idx]

     update_ops = to_list(kwargs.get('loss') or self.loss)
     step = kwargs['steps']
     sess = tf.get_default_session()
     generator_turn = step % self.nd_iter == 0
     if generator_turn:
         # update G-net only once every `self.nd_iter` steps
         sess.run(update_ops[0], feed_dict=self.feed_dict)
     # update D-net on every step
     sess.run(update_ops[1:], feed_dict=self.feed_dict)

     values = sess.run(list(self.train_metric.values()),
                       feed_dict=self.feed_dict)
     return dict(zip(self.train_metric, values))
Пример #4
0
 def __init__(self, layers=3, filters=64, kernel=(9, 5, 5),
              custom_upsample=False,
              name='srcnn', **kwargs):
   """Configure an SRCNN model.

   Args:
       layers: number of layers; `self.kernel_size` is padded up to this
         length.
       filters: stored as-is in `self.filters`.
       kernel: a single kernel size or a sequence of sizes. When fewer
         sizes than `layers` are given, the last one is repeated.
       custom_upsample: stored negated as `self.do_up`.
       name: model name.
       kwargs: forwarded to the parent constructor.
   """
   super(SRCNN, self).__init__(**kwargs)
   self.name = name
   self.do_up = not custom_upsample
   self.layers = layers
   self.filters = filters
   # Copy, so the in-place padding below can never extend a list object
   # that still belongs to the caller.
   self.kernel_size = list(to_list(kernel))
   missing = self.layers - len(self.kernel_size)
   if missing > 0:
     # Pad from the normalized list rather than from `kernel` itself:
     # `kernel` may be a plain scalar, which cannot be indexed.
     self.kernel_size += to_list(self.kernel_size[-1], missing)
Пример #5
0
 def __init__(self,
              layers=3,
              filters=(64, 32),
              kernel=(5, 3, 3),
              name='espcn',
              **kwargs):
     """Configure an ESPCN model.

     Args:
         layers: number of layers.
         filters: filter counts, normalized with
           `to_list(filters, layers - 1)`.
         kernel: kernel sizes, normalized with `to_list(kernel, layers)`;
           if still shorter than `layers`, padded with `kernel[-1]`.
         name: model name.
         kwargs: forwarded to the parent constructor.
     """
     super(ESPCN, self).__init__(**kwargs)
     self.name, self.layers = name, layers
     self.filters = to_list(filters, layers - 1)
     self.kernel_size = to_list(kernel, layers)
     shortfall = self.layers - len(self.kernel_size)
     if shortfall > 0:
         # repeat the last given kernel size for the remaining layers
         self.kernel_size += to_list(kernel[-1], shortfall)
Пример #6
0
    def __init__(self, scale, channel, weight_decay=0, **kwargs):
        """Initialize the attributes common to all models.

        Args:
            scale: the scale factor, can be a list of 2 integers to specify
              different stretch in width and height
            channel: input color channel
            weight_decay: decay of L2 regularization on trainable weights
            kwargs: kept untouched in `self.unknown_args`
        """

        # --- user supplied configuration ---
        self.scale = to_list(scale, repeat=2)
        self.channel = channel
        self.weight_decay = weight_decay  # weights regularization
        self.unknown_args = kwargs
        self.rgba = False  # deprecated
        self._trainer = VSR  # default trainer

        # --- graph containers, filled in later ---
        self.inputs = []  # placeholders for model inputs
        # image processing applied to inputs (i.e. RGB->YUV, if you need)
        self.inputs_preproc = []
        self.label = []  # placeholders for model labels
        self.outputs = []  # output tensors
        self.loss = []  # this is the optimize op
        self.train_metric = {}  # metrics shown at training phase
        self.metrics = {}  # metrics recorded in tf.summary / benchmark
        self.feed_dict = {}
        self.savers = {}

        # --- handles that stay unset until assigned elsewhere ---
        self.global_steps = None
        self.training_phase = None  # only useful for bn
        self.learning_rate = None
        self.summary_op = self.summary_writer = None
        self.pre_ckpt = None
        self.compiled = False
Пример #7
0
 def __init__(self, layers=16, filter_size=(5, 5), depth=7, name='duf',
              **kwargs):
   """Configure a DUF model.

   Args:
       layers: stored in `self.layers`.
       filter_size: an int or a pair, normalized with
         `to_list(filter_size, 2)`.
       depth: stored in `self.depth`.
       name: model name.
       kwargs: forwarded to the parent constructor.
   """
   super(DUF, self).__init__(**kwargs)
   self.name = name
   self.depth = depth
   self.layers = layers
   self.filter_size = to_list(filter_size, 2)
Пример #8
0
 def __init__(self,
              name='drsr_v2',
              noise_config=None,
              weights=(1, 10, 1e-5),
              level=1,
              mean_shift=(0, 0, 0),
              arch=None,
              auto_shift=None,
              **kwargs):
     """Configure a DRSR model.

     Args:
         name: model name.
         noise_config: a dict or `Config` merged over the default noise
           settings. When given, its `crf` entry is loaded with `np.load`
           and `offset` / `max` are divided by 255.
         weights: stored in `self.weights`.
         level: stored in `self.level`.
         mean_shift: shift bound into `self.norm` / `self.denorm`; when
           None those two attributes are not set here.
         arch: stored in `self.arch`.
         auto_shift: stored in `self.auto`.
         kwargs: forwarded to the parent constructor.
     """
     super(DRSR, self).__init__(**kwargs)
     self.name = name
     self.weights = weights
     self.level = level
     self.arch = arch
     self.auto = auto_shift
     self.to_sum = []

     self.noise = Config(scale=0, offset=0, penalty=0.5, max=0, layers=7)
     if isinstance(noise_config, (dict, Config)):
         self.noise.update(**noise_config)
         self.noise.crf = np.load(self.noise.crf)
         # expand the offset to 4 entries and rescale by 255
         self.noise.offset = [
             x / 255 for x in to_list(self.noise.offset, 4)
         ]
         self.noise.max /= 255

     if mean_shift is not None:
         self.norm = partial(_normalize, shift=mean_shift)
         self.denorm = partial(_denormalize, shift=mean_shift)
Пример #9
0
def _make_displacement(x, patch=3, max_dis=1, stride1=1, stride2=1):
    """[B, H, W, C]->[B, H, W, V, d*d]

    Pads `x`, then for every displacement offset concatenates the
    `k1 * k2` shifted windows along the channel axis, and finally stacks
    the per-displacement results on a new last axis.

    Args:
        x: input feature map
        patch: an int or [k1, k2], the comparison window size
        max_dis: an int or [d1, d2], the max displacement per direction
        stride1: step used when slicing the spatial axes
        stride2: step used for both displacement loops
    """
    k1, k2 = to_list(patch, 2)
    d1, d2 = to_list(max_dis, 2)
    h, w = tf.shape(x)[1], tf.shape(x)[2]
    pad_h, pad_w = k1 // 2 + d1, k2 // 2 + d2
    padded_x = tf.pad(x, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]])
    disp = []
    for dy in range(0, 2 * d1 + 1, stride2):
        for dx in range(0, 2 * d2 + 1, stride2):
            # every window position of the patch at this displacement
            windows = [
                padded_x[:, dy + ky:dy + ky + h:stride1,
                         dx + kx:dx + kx + w:stride1, :]
                for ky in range(k1)
                for kx in range(k2)
            ]
            disp.append(tf.concat(windows, -1))
    return tf.stack(disp, axis=-1)
Пример #10
0
def _make_vector(x, patch=3, stride=1):
    """[B, H, W, C]->[B, H, W, c*k1*k2]

    Pads `x` by half the patch size and concatenates the `k1 * k2` shifted
    windows along the channel axis.

    Args:
        x: input feature map
        patch: an int or [k1, k2], the window size
        stride: step used when slicing the spatial axes
    """
    k1, k2 = to_list(patch, 2)
    h, w = tf.shape(x)[1], tf.shape(x)[2]
    half_h, half_w = k1 // 2, k2 // 2
    padded_x = tf.pad(x, [[0, 0], [half_h, half_h], [half_w, half_w], [0, 0]])
    windows = [
        padded_x[:, ky:ky + h:stride, kx:kx + w:stride, :]
        for ky in range(k1)
        for kx in range(k2)
    ]
    return tf.concat(windows, axis=-1)
Пример #11
0
 def __init__(self, layers, epsilon=1e-3, name='lapsrn', **kwargs):
     """Configure a LapSRN model.

     Args:
         layers: layer count(s), normalized with
           `to_list(layers, self.level)`.
         epsilon: stored squared in `self.epsilon2`.
         name: model name.
         kwargs: forwarded to the parent constructor, which is expected to
           provide `self.scale`.

     Raises:
         ValueError: if either scale factor is not a power of 2.
     """
     super(LapSRN, self).__init__(**kwargs)
     self.name = name
     self.epsilon2 = epsilon**2
     s0, s1 = self.scale
     exponents = np.log2([s0, s1])
     if np.any(exponents != np.round(exponents)):
         raise ValueError('Scale factor must be pow of 2.'
                          'Error: scale={},{}'.format(s0, s1))
     assert s0 == s1
     # number of x2 stages needed to reach the requested scale
     self.level = int(np.log2(s0))
     self.layers = to_list(layers, self.level)
Пример #12
0
 def _find_last_ckpt(self):
   """Locate the most recent checkpoint under `self._saved`.

   Return:
       the path reported by `tf.train.latest_checkpoint` when a checkpoint
       state is found; otherwise `self._saved / <stem>` of the most
       recently modified `*.ckpt.index` file; None when `self._saved` is
       None or nothing is found.
   """
   if self._saved is None:
     return None
   state = tf.train.get_checkpoint_state(self._saved)
   if state and state.model_checkpoint_path:
     return tf.train.latest_checkpoint(self._saved)
   # try another way: scan the index files directly
   candidates = to_list(self._saved.glob('*.ckpt.index'))
   if not candidates:
     return None
   # sort as modification time, newest last
   candidates = sorted(candidates, key=lambda p: p.stat().st_mtime_ns)
   return self._saved / candidates[-1].stem
Пример #13
0
def bicubic_rescale(img, scale):
    """Resize image in tensorflow (deprecated).

    NOTE: tf.image.resize_bicubic uses different boundary to
      PIL.Image.resize, try to use resize_area without aligned corners.

    Args:
        img: a tensor whose 2nd and 3rd dimensions are resized
        scale: a scalar or a pair of factors; `scale[0]` multiplies
          dimension 1 and `scale[1]` multiplies dimension 2
    """
    print("bicubic_rescale is deprecated. Use bicubic_resize instead.")
    with tf.name_scope('Bicubic'):
        shape = tf.shape(img)
        # leave the first and last dimensions untouched
        factors = [1, *to_list(scale, 2), 1]
        target = tf.to_int32(tf.to_float(shape) * factors)
        return tf.image.resize_bicubic(img, target[1:3], False)
Пример #14
0
def crop_to_batch(image, scale):
    """Crop image into `scale[0]*scale[1]` parts, and concat them into the
    batch dimension.

    Args:
        image: A 4-D tensor with [N, H, W, C]
        scale: A 1-D tensor or scalar of scale factor for width and height
    """

    factors = to_list(scale, 2)
    with tf.name_scope('BatchEnhance'):
        # split along axis 1 first, stacking the strips on the batch axis
        strips = tf.concat(tf.split(image, factors[1], axis=1), axis=0)
        # then split every strip along axis 2
        return tf.concat(tf.split(strips, factors[0], axis=2), axis=0)
Пример #15
0
def cascade_rdn(layers: Layers,
                inputs,
                depth,
                use_ca=False,
                scope=None,
                reuse=None,
                **kwargs):
    """Cascaded residual dense block.

    Each stage applies `rdn` (optionally followed by `rcab`), concatenates
    the result with the inputs and all previous stage results, and fuses
    them with a 1x1 convolution.

    Args:
        layers: child class of Layers
        inputs: input tensor
        depth: an int or list of 2 ints, representing number of rdbs
        use_ca: insert channel attention layer
        scope: scope name
        reuse: reuse variables
        kwargs: `kernel_size` (3), `filters` (64) and `activation` ('relu')
          are consumed here, `name` is discarded, the rest is forwarded.
    """

    shared = {
        'kernel_size': kwargs.pop('kernel_size', 3),
        'filters': kwargs.pop('filters', 64),
        'activation': kwargs.pop('activation', 'relu'),
    }
    kwargs.pop('name', None)
    n_stage, n_inner = to_list(depth, 2)[:2]
    with tf.variable_scope(scope, 'CascadeRDN', reuse=reuse):
        collected = [inputs]
        x = inputs
        for _ in range(n_stage):
            x = rdn(layers, x, n_inner, **shared, **kwargs)
            if use_ca:
                x = rcab(layers, x, **shared, **kwargs)
            collected.append(x)
            # fuse everything gathered so far back to `filters` channels
            x = layers.conv2d(tf.concat(collected, -1), shared['filters'], 1,
                              **kwargs)
        return x
Пример #16
0
def pixel_shift(image, scale, channel=1):
    """Efficient Sub-pixel Convolution,
    see paper: https://arxiv.org/abs/1609.05158

    Args:
        image: A 4-D tensor of [N, H, W, C*scale[0]*scale[1]]
        scale: A scalar or 1-D tensor with 2 elements, the scale factor for
          width and height respectively
        channel: specify the channel number

    Return:
        A 4-D tensor of [N, H*scale[1], W*scale[0], C]
    """

    with tf.name_scope('PixelShift'):
        factors = to_list(scale, 2)
        dims = tf.shape(image)
        height, width = dims[1], dims[2]
        # unpack the channel axis into (r_h, r_w, C)
        blocks = tf.reshape(
            image, [-1, height, width, factors[1], factors[0], channel])
        # B, H, r, W, r, C
        interleaved = tf.transpose(blocks, perm=[0, 1, 3, 2, 4, 5])
        return tf.reshape(
            interleaved,
            [-1, height * factors[1], width * factors[0], channel])
Пример #17
0
 def _restore_model(self, sess):
   """Restore every saver in `self.savers` from its newest checkpoint.

   Checkpoints are looked up in `self.model.pre_ckpt` when that is set,
   otherwise in `self._saved`.

   Args:
       sess: the session passed to `saver.restore`.

   Return:
       the step parsed from the name of the last checkpoint processed, or
       0 when there is no directory or no checkpoint. The step is parsed
       even when the restore itself logged a NotFoundError warning.
   """
   pre_ckpt = self.model.pre_ckpt
   ckpt_dir = Path(pre_ckpt) if pre_ckpt is not None else self._saved
   last_step = 0
   if ckpt_dir is None:
     return last_step
   for name, saver in self.savers.items():
     found = to_list(ckpt_dir.glob('{}*.index'.format(name)))
     if not found:
       continue
     # newest index file by modification time
     newest = sorted(found, key=lambda p: p.stat().st_mtime_ns)[-1]
     target = ckpt_dir / newest.stem
     try:
       saver.restore(sess, str(target))
     except tf.errors.NotFoundError:
       LOG.warning(
           '{} of model {} could not be restored'.format(
               name, self.model.name))
     last_step = _parse_ckpt_name(target)
   return last_step
Пример #18
0
def correlation(f1, f2, patch, max_displacement, stride1=1, stride2=1):
    """Calculate correlation between feature map "f1" and "f2".

    See "FlowNet: Learning Optical Flow with Convolutional Networks" for
    details.

    Args:
        f1: a 4-D tensor with shape [B, H, W, C]
        f2: a 4-D tensor with shape [B, H, W, C]
        patch: an integer or a list like [k1, k2], window size for
          comparison
        max_displacement: an integer, representing the max searching
          distance
        stride1: stride for patch
        stride2: stride for displacement

    Returns:
        a 4-D correlation tensor with shape [B, H, W, d*d]
    """
    channel = f1.shape[-1]
    # normalizer: product of the window sizes and the channel count
    norm = np.prod(to_list(patch, 2) + [channel])
    query = tf.expand_dims(_make_vector(f1, patch, stride1), -2)
    candidates = _make_displacement(f2, patch, max_displacement, stride1,
                                    stride2)
    scores = tf.matmul(query, candidates) / tf.to_float(norm)
    return tf.squeeze(scores, axis=-2)
Пример #19
0
def shrink_mod_scale(x, scale):
    """Clip each dim of x down to a multiple of the matching scale factor.

    Only as many entries as `zip(x, to_list(scale, 2))` yields are
    returned.
    """
    clipped = []
    for size, factor in zip(x, to_list(scale, 2)):
        clipped.append(size - size % factor)
    return clipped
    def upscale(self,
                image,
                method='espcn',
                scale=None,
                direct_output=True,
                **kwargs):
        """Image up-scale layer

        Upsample `image` width and height by scale factor `scale[0]` and
        `scale[1]`. Perform upsample progressively: i.e. x12:= x2->x2->x3

        Args:
            image: tensors to upsample
            method: method could be 'espcn', 'nearest' or 'deconv'
            scale: None or int or [int, int]. If None, `scale`=`self.scale`
            direct_output: output channel is the desired RGB or Grayscale,
              if False, keep the same channels as `image`
            kwargs: optional `activator`, `kernel_initializer`,
              `kernel_regularizer` and `use_bias`
        """
        _allowed_method = ('espcn', 'nearest', 'deconv')
        assert str(method).lower() in _allowed_method
        method = str(method).lower()
        act = kwargs.get('activator')
        ki = kwargs.get('kernel_initializer', 'he_normal')
        kr = kwargs.get('kernel_regularizer', 'l2')
        use_bias = kwargs.get('use_bias', True)

        scale_x, scale_y = to_list(scale, 2) or self.scale
        features = self.channel if direct_output else image.shape.as_list()[-1]

        def _enlarge(tensor, sx, sy, strides):
            # One up-sampling stage by (sx, sy) with the selected method,
            # followed by the optional activator.
            if method == 'espcn':
                expanded = self.conv2d(tensor,
                                       features * sx * sy,
                                       3,
                                       use_bias=use_bias,
                                       kernel_initializer=ki,
                                       kernel_regularizer=kr)
                tensor = pixel_shift(expanded, [sx, sy], features)
            elif method == 'nearest':
                tensor = pixel_shift(tf.concat([tensor] * sx * sy, -1),
                                     [sx, sy], tensor.shape[-1])
            elif method == 'deconv':
                tensor = self.deconv2d(tensor,
                                       features,
                                       3,
                                       strides=strides,
                                       use_bias=use_bias,
                                       kernel_initializer=ki,
                                       kernel_regularizer=kr)
            return act(tensor) if act else tensor

        while scale_x > 1 or scale_y > 1:
            if scale_x % 2 == 1 or scale_y % 2 == 1:
                # An odd factor is left: apply the remainder in one stage
                # and stop.
                return _enlarge(image, scale_x, scale_y, [scale_y, scale_x])
            # Both factors are even: peel off one x2 stage.
            scale_x //= 2
            scale_y //= 2
            image = _enlarge(image, 2, 2, 2)
        return image
    def resblock3d(self,
                   x,
                   filters,
                   kernel_size,
                   strides=(1, 1, 1),
                   padding='same',
                   data_format='channels_last',
                   activation=None,
                   use_bias=True,
                   use_batchnorm=False,
                   kernel_initializer='he_normal',
                   kernel_regularizer='l2',
                   placement=None,
                   **kwargs):
        """Make a 3-D residual block: two `conv3d` layers plus a shortcut.

        Args:
            x: input tensor
            filters: number of filters of both convolutions
            kernel_size: kernel size of both convolutions
            strides: an int or 3 ints, applied to the second convolution
              and to the shortcut projection
            padding: forwarded to `conv3d`
            data_format: forwarded to `conv3d`
            activation: activation of the first convolution; removed
              before the second one is built
            use_bias: forwarded to `conv3d`
            use_batchnorm: whether to use batch normalization
            kernel_initializer: forwarded to `conv3d`
            kernel_regularizer: forwarded to `conv3d`
            placement: 'front' or 'behind', use BN layer in front of or
              behind the 1st conv3d layer. Defaults to 'behind'.
            kwargs: optional `name` and `reuse` for the variable scope,
              the rest is forwarded to `conv3d`

        Return:
            the sum of the (possibly projected) input and the output of
            the second convolution
        """

        kwargs.update({
            'padding': padding,
            'data_format': data_format,
            'activation': activation,
            'use_bias': use_bias,
            'use_batchnorm': use_batchnorm,
            'kernel_initializer': kernel_initializer,
            'kernel_regularizer': kernel_regularizer
        })
        if placement is None:
            placement = 'behind'
        assert placement in ('front', 'behind')
        name = pop_dict_wo_keyerror(kwargs, 'name')
        reuse = pop_dict_wo_keyerror(kwargs, 'reuse')
        strides = to_list(strides, 3)
        with tf.variable_scope(name, 'ResBlock', reuse=reuse):
            ori = x
            if placement == 'front':
                act = self._act(activation)
                if use_batchnorm:
                    x = tf.layers.batch_normalization(
                        x, training=self.training_phase)
                if act:
                    x = act(x)
            x = self.conv3d(x, filters, kernel_size, **kwargs)
            # the second convolution never gets an activation
            kwargs.pop('activation')
            if placement == 'front':
                kwargs.pop('use_batchnorm')
            x = self.conv3d(x, filters, kernel_size, strides=strides, **kwargs)
            # A stride along ANY axis shrinks `x`, not only along the
            # first one, so each of them must trigger the projection;
            # otherwise the addition below sees mismatched shapes.
            downsampled = any(s > 1 for s in strides)
            if ori.shape[-1] != x.shape[-1] or downsampled:
                # short cut
                ori = self.conv3d(ori,
                                  x.shape[-1],
                                  1,
                                  strides=strides,
                                  kernel_initializer=kernel_initializer)
            ori += x
        return ori