def test_batching(self, input_batch_shape, kernel_batch_shape):
    """Checks batched outputs against per-example `tf.nn.conv2d_transpose`.

    The input and kernel are broadcast to their common batch shape, all batch
    dimensions are collapsed into a single leading axis, and
    `tf.nn.conv2d_transpose` is mapped over that axis to build the reference
    result that the batched convolution output is compared against.

    Args:
      input_batch_shape: Batch shape forwarded to `_make_input_and_kernel` for
        the input.
      kernel_batch_shape: Batch shape forwarded to `_make_input_and_kernel`
        for the kernel.
    """
    # Fixed problem configuration for this test.
    event_shape = (12, 12, 2)
    kernel_hw = (3, 3)
    num_out = 4
    stride = 2
    dilation_rates = (1, 1)
    pad_mode = 'SAME'
    spatial_rank = 2

    inputs, kernel = _make_input_and_kernel(
        self.make_input,
        input_batch_shape=input_batch_shape,
        input_shape=event_shape,
        kernel_batch_shape=kernel_batch_shape,
        filter_shape=kernel_hw,
        channels_out=num_out,
        dtype=self.dtype)

    batched_conv = self.make_conv_fn(
        kernel_hw, stride, pad_mode, dilation_rates)
    batched_out = batched_conv(inputs, kernel)

    # Materialize the broadcast so both operands carry the full batch shape.
    full_batch = ps.broadcast_shape(input_batch_shape, kernel_batch_shape)
    inputs_bcast_shape = ps.concat([full_batch, event_shape], axis=0)
    inputs_bcast = tf.broadcast_to(inputs, shape=inputs_bcast_shape)
    kernel_bcast_shape = ps.concat(
        [full_batch, ps.shape(kernel)[-2:]], axis=0)
    kernel_bcast = tf.broadcast_to(kernel, shape=kernel_bcast_shape)

    # Collapse every batch dimension into one leading axis by prepending -1 to
    # the trailing (non-batch) dimensions.
    out_flat_shape = ps.pad(
        ps.shape(batched_out)[-3:], paddings=[[1, 0]], constant_values=-1)
    out_flat = tf.reshape(batched_out, shape=out_flat_shape)
    inputs_flat_shape = ps.pad(
        event_shape, paddings=[[1, 0]], constant_values=-1)
    inputs_flat = tf.reshape(inputs_bcast, shape=inputs_flat_shape)

    # Unpack each kernel to `[*kernel_hw, channels_in, num_out]`, then swap
    # the trailing two axes before handing it to `tf.nn.conv2d_transpose`.
    kernel_unpacked_shape = ps.concat(
        [(-1,), kernel_hw, (event_shape[-1], num_out)], axis=0)
    kernel_unpacked = tf.reshape(kernel_bcast, shape=kernel_unpacked_shape)
    kernel_flat = tf.einsum('...ij->...ji', kernel_unpacked)

    expected_shape, tf_strides = convolution_util._get_output_shape(
        rank=spatial_rank, strides=(stride,) * spatial_rank, padding=pad_mode,
        dilations=dilation_rates, input_shape=event_shape,
        output_size=num_out, filter_shape=kernel_hw)

    def _reference_conv_transpose(pair):
      """Applies `tf.nn.conv2d_transpose` to one (input, kernel) pair."""
      return tf.nn.conv2d_transpose(
          pair[0][tf.newaxis],  # Add a singleton batch axis.
          pair[1],
          output_shape=ps.concat([[1], expected_shape], axis=0),
          strides=tf_strides,
          padding=pad_mode)

    expected = tf.vectorized_map(
        _reference_conv_transpose, elems=(inputs_flat, kernel_flat))

    # Drop the singleton batch axis added inside the mapped function.
    [actual_np, expected_np] = self.evaluate(
        [out_flat, tf.squeeze(expected, axis=1)])
    self.assertAllClose(expected_np, actual_np, rtol=1e-5, atol=0)
  def test_works_like_conv2d_transpose(
      self, input_shape, filter_shape, channels_out, strides, padding,
      dilations):
    """Compares the convolution under test with a TensorFlow reference.

    The reference is `tf.nn.conv2d_transpose` when no dilation exceeds 1, and
    a `tf.keras.layers.Conv2DTranspose` layer (with its kernel overwritten)
    otherwise.

    Args:
      input_shape: Shape forwarded to `_make_input_and_kernel` for the input;
        its last entry is used as the number of input channels.
      filter_shape: Spatial shape of the filter.
      channels_out: Number of output channels.
      strides: Indexable pair of strides. When `self.unequal_strides_ok` is
        false, cases whose two entries differ are skipped and `strides[0]` is
        what gets passed to `self.make_conv_fn`.
      padding: Padding specification forwarded to the convolutions.
      dilations: Iterable of dilation rates.
    """
    # The reference output-shape computation always uses the full tuple.
    strides_tuple = strides
    conv_strides = strides
    if not self.unequal_strides_ok:
      if strides[0] != strides[1]:
        # Skip this test case if the method does not support unequal strides.
        return
      conv_strides = strides[0]

    # A singleton `kernel_batch_shape` is used (rather than an empty one) to
    # avoid the short circuit to `conv2d_transpose`.
    inputs, kernel = _make_input_and_kernel(
        self.make_input,
        input_batch_shape=[],
        input_shape=input_shape,
        kernel_batch_shape=[1],
        filter_shape=filter_shape,
        channels_out=channels_out,
        dtype=self.dtype)
    conv_under_test = self.make_conv_fn(
        filter_shape, conv_strides, padding, dilations)
    actual = conv_under_test(inputs, kernel)

    expected_shape, tf_strides = convolution_util._get_output_shape(
        rank=2, strides=strides_tuple, padding=padding, dilations=dilations,
        input_shape=input_shape, output_size=channels_out,
        filter_shape=filter_shape)

    # Unpack the kernel to `[*filter_shape, channels_in, channels_out]` and
    # then swap the two channel axes for the TensorFlow reference ops.
    unpacked_kernel_shape = ps.concat(
        [filter_shape, [input_shape[-1], channels_out]], axis=0)
    unpacked_kernel = tf.reshape(kernel, unpacked_kernel_shape)
    reference_kernel = tf.transpose(unpacked_kernel, perm=[0, 1, 3, 2])

    if not any(d > 1 for d in dilations):
      expected = tf.nn.conv2d_transpose(
          inputs, reference_kernel, output_shape=expected_shape,
          strides=tf_strides, padding=padding, dilations=dilations)
    else:
      # conv2d_transpose does not support dilations > 1; use Keras instead.
      reference_layer = tf.keras.layers.Conv2DTranspose(
          filters=channels_out,
          kernel_size=filter_shape,
          strides=conv_strides,
          padding=padding,
          dilation_rate=dilations,
          use_bias=False)
      # Call the layer once so it builds its kernel, then overwrite that
      # kernel with the one under test.
      _ = reference_layer(inputs)
      reference_layer.kernel = reference_kernel
      expected = reference_layer(inputs)

    [expected_np, actual_np] = self.evaluate([expected, actual])
    self.assertAllClose(expected_np, actual_np, rtol=1e-5, atol=0)