Example #1
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    # Short-circuiting incoming lists and tuples here avoids both
    # Tensor packing / unpacking and numpy 1.20.+ pickiness about
    # np.array(tuple of Tensor).
    if isinstance(arg, (tuple, list)):
        if len(arg) == n:
            return tuple(arg)
        if len(arg) == 1:
            return (arg[0], ) * n

    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
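These snippets appear to come from TensorFlow Probability's internal convolution utilities, where `ps` is the `prefer_static` module and `assert_util` supplies graph-friendly assertions. As a rough standalone sketch of the same normalize-to-a-tuple pattern, using only plain TensorFlow in eager mode (names here are my own, not the library's):

    import tensorflow as tf

    def prepare_tuple_argument_sketch(arg, n, arg_name):
        # Lists and tuples are short-circuited, as in the example above.
        if isinstance(arg, (tuple, list)):
            if len(arg) == n:
                return tuple(arg)
            if len(arg) == 1:
                return (arg[0],) * n
        arg = tf.convert_to_tensor(arg)
        size = int(tf.size(arg))  # static check; eager execution assumed
        if size not in (1, n):
            raise ValueError('The size of `{}` must be 1 or {}, saw {}.'.format(
                arg_name, n, size))
        # Broadcast a size-1 argument to length `n`, then unstack into a tuple.
        return tuple(tf.unstack(tf.broadcast_to(tf.reshape(arg, [-1]), [n]), num=n))

    print([int(s) for s in prepare_tuple_argument_sketch(2, n=2, arg_name='strides')])       # [2, 2]
    print([int(s) for s in prepare_tuple_argument_sketch([1, 3], n=2, arg_name='strides')])  # [1, 3]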
Example #2
 def _inverse(self, y):
     ndims = prefer_static.rank(y)
     indices = prefer_static.reshape(prefer_static.add(self.axis, ndims),
                                     shape=[-1, 1])
     num_left, num_right = prefer_static.unstack(self.paddings,
                                                 num=2,
                                                 axis=-1)
     x = tf.slice(y,
                  begin=prefer_static.tensor_scatter_nd_update(
                      prefer_static.zeros(ndims, dtype=tf.int32), indices,
                      num_left),
                  size=prefer_static.tensor_scatter_nd_sub(
                      prefer_static.shape(y), indices,
                      num_left + num_right))
     if not self.validate_args:
         return x
     assertions = [
         assert_util.assert_equal(
             self._forward(x),
             y,
             message=('Argument `y` to `inverse` was not padded with '
                      '`constant_values`.')),
     ]
     with tf.control_dependencies(assertions):
         return tf.identity(x)
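The `_inverse` above (apparently from a `Pad`-style bijector) undoes the forward padding by slicing `num_left` and `num_right` entries off each padded axis. A self-contained sketch of just that slicing step, in plain TensorFlow and assuming negative axis indices as in the example:

    import tensorflow as tf

    def unpad_sketch(y, paddings, axis):
        # `axis` holds negative axis indices, e.g. [-1]; `paddings` is [num_axes, 2].
        paddings = tf.convert_to_tensor(paddings)
        num_left, num_right = tf.unstack(paddings, num=2, axis=-1)
        indices = tf.reshape(axis + tf.rank(y), shape=[-1, 1])
        begin = tf.tensor_scatter_nd_update(
            tf.zeros_like(tf.shape(y)), indices, num_left)
        size = tf.tensor_scatter_nd_sub(
            tf.shape(y), indices, num_left + num_right)
        return tf.slice(y, begin=begin, size=size)

    y = tf.pad(tf.ones([2, 3]), paddings=[[0, 0], [1, 2]])       # shape (2, 6)
    print(unpad_sketch(y, paddings=[[1, 2]], axis=[-1]).shape)   # (2, 3)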
Example #3
                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs
Example #4
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
Example #5
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  c_out, batch_shape,
                                                  event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)

                ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
                ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs

                outputs = tf.TensorArray(dtype=input_dtype,
                                         infer_shape=False,
                                         size=1,
                                         dynamic_size=True)

                _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                           [0, outputs])

                y = outputs.concat()

                m = tf.reduce_prod(ps.shape(y)[:-3])
                y_ = tf.reshape(y,
                                shape=ps.concat([[m], ps.shape(y)[-3:]],
                                                axis=0))
                y2 = tf.batch_to_space(y_,
                                       strides,
                                       crops=tf.zeros([2, 2], dtype=tf.int64))
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)
                y2 = tf.reshape(
                    y2,
                    ps.concat([broadcast_batch_shape,
                               ps.shape(y2)[-3:]],
                              axis=0))

                if padding == 'VALID':
                    out_height = fh + sh * (xh - 1)
                    out_width = fw + sw * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * sh
                    out_width = xw * sw

                return y2[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
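The `while_loop` above writes one partial result per stride offset, and `tf.batch_to_space` then interleaves those sh * sw slices back onto the finer output grid. A tiny standalone illustration of that interleaving (values chosen only to make the pattern visible):

    import tensorflow as tf

    # Four "sub-images" stacked along the batch axis, as the loop above produces.
    y_ = tf.reshape(tf.range(16.), [4, 2, 2, 1])
    y2 = tf.batch_to_space(y_, block_shape=[2, 2], crops=[[0, 0], [0, 0]])
    print(tf.squeeze(y2).numpy())
    # [[ 0.  4.  1.  5.]
    #  [ 8. 12.  9. 13.]
    #  [ 2.  6.  3.  7.]
    #  [10. 14. 11. 15.]]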
Example #6
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                if padding == 'VALID':
                    out_height = fh + strides * (xh - 1)
                    out_width = fw + strides * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * strides
                    out_width = xw * strides

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out
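In the strided case above, each output position carries strides**2 * c_out values in its channel dimension, and `tf.nn.depth_to_space` spreads them over a strides-times finer grid. A minimal illustration with strides = 2 and c_out = 1:

    import tensorflow as tf

    out = tf.reshape(tf.range(16.), [1, 2, 2, 4])   # [batch, h, w, strides**2 * c_out]
    print(tf.squeeze(tf.nn.depth_to_space(out, block_size=2)).numpy())
    # [[ 0.  1.  4.  5.]
    #  [ 2.  3.  6.  7.]
    #  [ 8.  9. 12. 13.]
    #  [10. 11. 14. 15.]]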
Example #7
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)
            kernel_shape = ps.shape(kernel)
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  kernel_shape[-1],
                                                  batch_shape, event_shape)

                idx, shape = im2row_index((xh * sh + sum(pad_values[0]),
                                           xw * sw + sum(pad_values[1]), c_in),
                                          block_shape=filter_shape,
                                          slice_step=(1, 1),
                                          dilations=dilations,
                                          dtype=dtype,
                                          transpose=True)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                # Interleave the rows and columns of the input with rows and columns of
                # zeros equal to the number of strides.
                x_half_dilated = tf.concat([
                    tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)],
                                       axis=0),
                             dtype=input_dtype),
                    tf.reshape(x,
                               shape=ps.concat(
                                   [batch_shape, (xh * xw, 1, c_in)], axis=0))
                ],
                                           axis=-2)
                y = tf.reshape(x_half_dilated,
                               shape=ps.concat(
                                   [batch_shape, (xh, 1, xw * sw, c_in)],
                                   axis=0))

                x = tf.reshape(tf.concat([
                    tf.zeros(ps.concat(
                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                             dtype=input_dtype), y
                ],
                                         axis=-3),
                               shape=ps.concat(
                                   [batch_shape, (xh * sh, xw * sw, c_in)],
                                   axis=0))
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                flat_shape = ps.pad(batch_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape),
                                   indices=idx,
                                   axis=-1)
                im_x = tf.reshape(flat_x,
                                  shape=ps.concat([batch_shape, shape],
                                                  axis=0))
                return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
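The concat/reshape sequence above interleaves (stride - 1) rows and columns of zeros between the input pixels, so that an ordinary convolution over the result behaves like a transposed convolution. A small standalone trace of just that interleaving, for a 2x2 input and stride 2:

    import tensorflow as tf

    xh = xw = 2
    sh = sw = 2
    c_in = 1
    x = tf.reshape(tf.range(1., 5.), [1, xh, xw, c_in])   # [[1, 2], [3, 4]]

    half = tf.concat([tf.zeros([1, xh * xw, sw - 1, c_in]),
                      tf.reshape(x, [1, xh * xw, 1, c_in])], axis=-2)
    y = tf.reshape(half, [1, xh, 1, xw * sw, c_in])
    dilated = tf.reshape(
        tf.concat([tf.zeros([1, xh, sh - 1, xw * sw, c_in]), y], axis=-3),
        [1, xh * sh, xw * sw, c_in])
    print(tf.squeeze(dilated).numpy())
    # [[0. 0. 0. 0.]
    #  [0. 1. 0. 2.]
    #  [0. 0. 0. 0.]
    #  [0. 3. 0. 4.]]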
Example #8
    def op(x, kernel):
        input_dtype = dtype_util.common_dtype([x, kernel],
                                              dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

        batch_shape, event_shape = ps.split(ps.shape(x),
                                            num_or_size_splits=[-1, 3])
        xh, xw, c_in = ps.unstack(event_shape, num=3)
        fh, fw = filter_shape

        assertions = _maybe_validate_input_shapes(ps.shape(kernel),
                                                  channels_in=c_in,
                                                  filter_height=fh,
                                                  filter_width=fw,
                                                  validate_args=validate_args)

        with tf.control_dependencies(assertions):
            if tf.get_static_value(ps.rank(kernel)) == 2:
                flat_x = tf.reshape(x,
                                    shape=ps.concat([[-1], event_shape],
                                                    axis=0))
                flat_y = tf.nn.conv2d(
                    flat_x,
                    filters=tf.reshape(kernel, shape=[fh, fw, c_in, -1]),
                    strides=strides,
                    padding=padding,
                    data_format='NHWC',
                    dilations=dilations)
                output_shape = ps.shape(flat_y)[-3:]
                return tf.reshape(flat_y,
                                  shape=ps.concat([batch_shape, output_shape],
                                                  axis=0))

            pad_values = [
                _get_conv_padding(xdim,
                                  filter_dim=k,
                                  stride=s,
                                  dilation=d,
                                  padding=padding)
                for (xdim, k, s,
                     d) in zip((xh, xw), filter_shape, strides, dilations)
            ]

            idx, shape = im2row_index(
                (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
                block_shape=filter_shape,
                slice_step=strides,
                dilations=dilations,
                dtype=dtype)

            if padding == 'SAME':
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x = tf.pad(x, paddings=paddings, constant_values=0)

            flat_shape = ps.pad(batch_shape,
                                paddings=[[0, 1]],
                                constant_values=-1)
            flat_x = tf.gather(tf.reshape(x, shape=flat_shape),
                               indices=idx,
                               axis=-1)
            im_x = tf.reshape(flat_x,
                              shape=ps.concat([batch_shape, shape], axis=0))
            return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
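The im2row_index / gather / matmul path above is the usual im2col-style formulation of convolution. For a kernel without batch dimensions, the same idea can be sanity-checked against `tf.nn.conv2d`, with `tf.image.extract_patches` standing in for the index construction; a rough sketch:

    import tensorflow as tf

    x = tf.random.normal([2, 8, 8, 3])           # NHWC input
    kernel = tf.random.normal([5 * 5 * 3, 4])    # [fh * fw * c_in, c_out]

    # Gather each 5x5x3 receptive field into the last axis, then matmul.
    patches = tf.image.extract_patches(x, sizes=[1, 5, 5, 1],
                                       strides=[1, 1, 1, 1],
                                       rates=[1, 1, 1, 1],
                                       padding='SAME')    # [2, 8, 8, 75]
    y = tf.matmul(patches, kernel)                         # [2, 8, 8, 4]

    y_ref = tf.nn.conv2d(x, tf.reshape(kernel, [5, 5, 3, 4]),
                         strides=1, padding='SAME')
    print(float(tf.reduce_max(tf.abs(y - y_ref))))         # ~0 (float rounding)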