Example #1
0
    def test_prepare_conv_args(self):
        """Checks that `prepare_conv_args` normalizes a 2-D conv config."""
        prepared = convolution_util.prepare_conv_args(
            (3, 3), rank=2, strides=2, padding='same', dilations=(1, 1))
        filter_shape, rank, strides, padding, dilations = prepared

        # Each per-dimension argument must carry one entry per spatial dim.
        for per_dim_arg in (filter_shape, strides, dilations):
            self.assertLen(per_dim_arg, rank)

        # The lower-case padding string comes back in upper-case form.
        self.assertEqual(padding, 'SAME')
def _make_convolution_fn(rank, strides, padding, dilations):
  """Builds a batch-aware `tf.nn.convolution` callable.

  Args:
    rank: Number of spatial dimensions of the convolution.
    strides: Stride specification, normalized through
      `convolution_util.prepare_conv_args`.
    padding: Padding specification, normalized through
      `convolution_util.prepare_conv_args`.
    dilations: Dilation specification, normalized through
      `convolution_util.prepare_conv_args`.

  Returns:
    A function `(x, kernel) -> y` that runs the convolution through
    `nn_util_lib.batchify_op`, passing `rank + 1` as the op's event rank
    (presumably the spatial dims plus the channel dim — confirm against
    `batchify_op`).
  """
  # A dummy filter shape of 1 is supplied; only the other normalized
  # settings are kept.
  normalized = convolution_util.prepare_conv_args(
      1, rank, strides, padding, dilations)
  _, rank, strides, padding, dilations = normalized

  def op(x, kernel):
    # Bring both operands to a common dtype (float32 is only the hint).
    common_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, dtype=common_dtype, name='x')
    kernel = tf.convert_to_tensor(kernel, dtype=common_dtype, name='kernel')
    return tf.nn.convolution(
        x,
        kernel,
        strides=strides,
        padding=padding,
        data_format='NHWC',
        dilations=dilations)

  event_rank = rank + 1

  def batched_op(x, kernel):
    return nn_util_lib.batchify_op(op, event_rank, x, kernel)

  return batched_op
def _convolution_batch_nhwbc(
    x, kernel, rank, strides, padding, dilations, name):
  """Specialization of batch conv to NHWBC data format.

  Convolves a batch of inputs with a matching batch of kernels in a single
  op. The batch dimensions are folded into the channel dimension and one
  `tf.nn.depthwise_conv2d` is issued. The result is then unfolded and summed
  over input channels.

  Args:
    x: Input with shape `N + D + B + [c]`: sample dims, `rank` spatial dims,
      batch dims, and one input-channel dim.
    kernel: Filters with shape `B + F + [c, c']`: batch dims, `rank` filter
      dims, input channels, and output channels.
    rank: Number of spatial dimensions. NOTE(review): the convolution is done
      with `tf.nn.depthwise_conv2d`, which only supports 2-D spatial input.
      This path therefore presumably requires `rank == 2` — confirm with
      callers.
    strides: Stride specification. It is expanded by `prepare_strides` to
      length `rank + 2`.
    padding: Padding specification, normalized by
      `convolution_util.prepare_conv_args`.
    dilations: Dilation specification, normalized by
      `convolution_util.prepare_conv_args`.
    name: Optional name scope. Defaults to `'conv2d_nhwbc'` when falsy.

  Returns:
    Tensor of shape `N + D' + B + [c']`, where `D'` are the spatial output
    dims produced by the convolution.
  """
  with tf.name_scope(name or 'conv2d_nhwbc'):
    # Prepare arguments.
    [
        _,  # filter shape
        rank,
        _,  # strides (recomputed below with `prepare_strides`)
        padding,
        dilations,
    ] = convolution_util.prepare_conv_args(1, rank, strides, padding, dilations)
    # Strides are rebuilt with length `rank + 2`. This presumably adds the
    # batch and channel entries that `tf.nn.depthwise_conv2d` expects for
    # NHWC input — confirm against `prepare_strides`.
    strides = prepare_strides(strides, rank + 2, arg_name='strides')

    # Bring both operands to a common dtype (float32 is only the hint).
    dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, dtype=dtype, name='x')
    kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

    # Step 1: Transpose and double flatten kernel.
    # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
    kernel_shape = ps.shape(kernel)
    # The trailing `rank + 2` dims are F + [c, c']; everything before is B.
    kernel_batch_shape, kernel_event_shape = ps.split(
        kernel_shape,
        num_or_size_splits=[-1, rank + 2])
    kernel_batch_size = ps.reduce_prod(kernel_batch_shape)
    kernel_ndims = ps.rank(kernel)
    kernel_batch_ndims = kernel_ndims - rank - 2
    # Move F to the front so that B and c become adjacent. They can then be
    # flattened into one channel axis whose (b, c) ordering matches the
    # flattening of x in Step 2.
    perm = ps.concat([
        ps.range(kernel_batch_ndims, kernel_batch_ndims + rank),
        ps.range(0, kernel_batch_ndims),
        ps.range(kernel_batch_ndims + rank, kernel_ndims),
    ], axis=0)  # Eg, [1, 2, 0, 3, 4]
    kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
    # This is the depthwise filter layout:
    # [fh, fw, in_channels=b*c, channel_multiplier=c'].
    kernel = tf.reshape(
        kernel,
        shape=ps.concat([
            kernel_event_shape[:rank],
            [kernel_batch_size * kernel_event_shape[-2],
             kernel_event_shape[-1]],
        ], axis=0))  # F + [bc, c']

    # Step 2: Double flatten x.
    # x.shape = N + D + B + [c]
    x_shape = ps.shape(x)
    # NOTE(review): x's batch dims are split using the kernel's batch ndims.
    # The output is later reshaped with `kernel_batch_shape`, so x's batch
    # shape is assumed to equal the kernel's (no broadcasting) — confirm.
    [
        x_sample_shape,
        x_rank_shape,
        x_batch_shape,
        x_channel_shape,
    ] = ps.split(
        x_shape,
        num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
    # Collapse all sample dims into one batch axis, and B + [c] into one
    # channel axis.
    x = tf.reshape(
        x,  # N + D + B + [c]
        shape=ps.concat([
            [ps.reduce_prod(x_sample_shape)],
            x_rank_shape,
            [ps.reduce_prod(x_batch_shape) *
             ps.reduce_prod(x_channel_shape)],
        ], axis=0))  # [n] + D + [bc]

    # Step 3: Apply convolution.
    # Depthwise conv convolves each of the b*c input channels with its own c'
    # filters. Output channel index is in_channel * c' + q, so the last axis
    # is laid out as (b, c, c').
    y = tf.nn.depthwise_conv2d(
        x, kernel,
        strides=strides,
        padding=padding,
        data_format='NHWC',
        dilations=dilations)
    # (Shapes below assume unit strides and dilations.)
    #  SAME: y.shape = [n, h,      w,      bcc']
    # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

    # Step 4: Reshape/reduce for output.
    y_shape = ps.shape(y)
    # Restore the original sample dims and unfold the channel axis back into
    # B + [c, c'].
    y = tf.reshape(
        y,
        shape=ps.concat([
            x_sample_shape,
            y_shape[1:-1],
            kernel_batch_shape,
            kernel_event_shape[-2:],
        ], axis=0))  # N + D' + B + [c, c']
    # Summing over input channels turns the depthwise result into a full
    # convolution.
    y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

    return y