Ejemplo n.º 1
0
    def test_prepare_tuple_argument(self):
        """Checks list expansion and length validation of the helper."""
        ndims = 3

        # A scalar input must come back as a list with `ndims` entries.
        expanded = convolution_util.prepare_tuple_argument(
            self.make_integer_input(2),
            n=ndims,
            arg_name='arg',
            validate_args=True)
        self.assertIsInstance(expanded, list)
        self.assertLen(expanded, ndims)

        # An input that already holds `ndims` entries must also come back as
        # a list with `ndims` entries.
        expanded = convolution_util.prepare_tuple_argument(
            self.make_integer_input([2, 3, 4]),
            n=ndims,
            arg_name='arg_2',
            validate_args=True)
        self.assertIsInstance(expanded, list)
        self.assertLen(expanded, ndims)

        # A length that is neither 1 nor `ndims` must be rejected.
        expected_message = 'must be equal to `1` or to the rank'
        with self.assertRaisesRegex(ValueError, expected_message):
            convolution_util.prepare_tuple_argument(
                self.make_integer_input([1, 2]),
                n=ndims,
                arg_name='invalid_arg',
                validate_args=True)
  def __init__(
      self,
      input_size,
      output_size,          # keras::Conv::filters
      # Conv specific.
      filter_shape,         # keras::Conv::kernel_size
      rank=2,               # keras::Conv::rank
      strides=1,            # keras::Conv::strides
      padding='VALID',      # keras::Conv::padding; 'CAUSAL' not implemented.
                            # keras::Conv::data_format is not implemented
      dilations=1,          # keras::Conv::dilation_rate
      # Weights
      kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
      bias_initializer=None,    # tf.initializers.zeros()
      make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias,
      dtype=tf.float32,
      index_dtype=tf.int32,
      batch_shape=(),
      # Misc
      activation_fn=None,
      validate_args=False,
      name=None):
    """Constructs layer.

    Note: `data_format` is not supported since all nn layers operate on
    the rightmost column. If your channel dimension is not rightmost, use
    `tf.transpose` before calling this layer. For example, if your channel
    dimension is second from the left, the following code will move it
    rightmost:

    ```python
    inputs = tf.transpose(inputs, tf.concat([
        [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
    ```

    Args:
      input_size: Number of input channels. Required.
        In Keras, this argument is inferred from the rightmost input shape,
        i.e., `tf.shape(inputs)[-1]`. Here it scales the second from the
        rightmost `kernel` dimension, whose size is
        `prod(filter_shape) * input_size`.
      output_size: Number of output channels.
        In Keras, this argument is called `filters`. This argument specifies the
        rightmost dimension size of both `kernel` and `bias`.
      filter_shape: An integer or sequence of integers giving the filter size;
        it is expanded to a `rank`-length list before use.
        In Keras, this argument is called `kernel_size`.
      rank: An integer, the rank of the convolution, e.g. "2" for 2D
        convolution. It is used to expand `filter_shape`, is forwarded to
        `convolution_util.make_convolution_fn`, and sets the number of
        singleton dimensions placed in front of `output_size` in the `bias`
        shape.
        In Keras, this argument has the same name and semantics.
        Default value: `2`.
      strides: An integer or tuple/list of n integers, specifying the stride
        length of the convolution.
        In Keras, this argument has the same name and semantics.
        Default value: `1`.
      padding: One of `"VALID"` or `"SAME"` (case-insensitive).
        In Keras, this argument has the same name and semantics (except we don't
        support `"CAUSAL"`).
        Default value: `'VALID'`.
      dilations: An integer or tuple/list of `rank` integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying any
        `dilations` value != 1 is incompatible with specifying any `strides`
        value != 1.
        In Keras, this argument is called `dilation_rate`.
        Default value: `1`.
      kernel_initializer: Forwarded unchanged to `make_kernel_bias_fn`.
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      bias_initializer: Forwarded unchanged to `make_kernel_bias_fn`.
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: Callable invoked as
        `make_kernel_bias_fn(kernel_shape, bias_shape, kernel_initializer,
        bias_initializer, batch_ndims, batch_ndims, dtype)`; it must return a
        `(kernel, bias)` pair.
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: Forwarded to `make_kernel_bias_fn` and to the parent constructor.
        Default value: `tf.float32`.
      index_dtype: Forwarded as `dtype` to
        `convolution_util.make_convolution_fn`.
        Default value: `tf.int32`.
      batch_shape: Leading dimensions prepended to both the `kernel` and the
        `bias` shapes. It is flattened and cast to `tf.int32`; `None` is
        treated as an empty batch shape.
        Default value: `()`.
      activation_fn: Forwarded to the parent constructor.
        Default value: `None`.
      validate_args: Forwarded to the argument-preparation helper, to
        `convolution_util.make_convolution_fn` and to the parent constructor.
        Default value: `False`.
      name: Forwarded to the parent constructor.
        Default value: `None` (i.e., `'ConvolutionV2'`).
    """
    filter_shape = convolution_util.prepare_tuple_argument(
        filter_shape, rank, arg_name='filter_shape',
        validate_args=validate_args)
    batch_shape = (tf.constant([], dtype=tf.int32) if batch_shape is None
                   else ps.cast(ps.reshape(batch_shape, shape=[-1]), tf.int32))
    batch_ndims = ps.size(batch_shape)

    # Forward the caller's `rank` instead of a literal `2`, so the kernel
    # function is built for the same rank that was used to expand
    # `filter_shape` above. Unsupported ranks are then left for
    # `make_convolution_fn` to reject.
    apply_kernel_fn = convolution_util.make_convolution_fn(
        filter_shape, rank=rank, strides=strides, padding=padding,
        dilations=dilations, dtype=index_dtype, validate_args=validate_args)

    kernel_shape = ps.concat(
        [batch_shape, [ps.reduce_prod(filter_shape) * input_size, output_size]],
        axis=0)
    # `tf.ones` defaults to float32; request int32 explicitly so that every
    # component of the shape vector is an integer like `batch_shape`.
    bias_shape = ps.concat(
        [batch_shape, tf.ones(rank, dtype=tf.int32), [output_size]], axis=0)
    kernel, bias = make_kernel_bias_fn(
        kernel_shape, bias_shape,
        kernel_initializer, bias_initializer,
        batch_ndims, batch_ndims,
        dtype)

    self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
    super(ConvolutionV2, self).__init__(
        kernel=kernel,
        bias=bias,
        apply_kernel_fn=apply_kernel_fn,
        dtype=dtype,
        activation_fn=activation_fn,
        validate_args=validate_args,
        name=name)
  def __init__(
      self,
      input_size,
      output_size,          # keras::Conv::filters
      # Conv specific.
      filter_shape,         # keras::Conv::kernel_size
      rank=2,               # keras::Conv::rank
      strides=1,            # keras::Conv::strides
      padding='VALID',      # keras::Conv::padding; 'CAUSAL' not implemented.
                            # keras::Conv::data_format is not implemented
      dilations=1,          # keras::Conv::dilation_rate
      # Weights
      kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
      bias_initializer=None,    # tf.initializers.zeros()
      make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag,
      make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab,
      posterior_value_fn=tfd.Distribution.sample,
      unpack_weights_fn=unpack_kernel_and_bias,
      dtype=tf.float32,
      # Misc
      activation_fn=None,
      seed=None,
      validate_args=False,
      name=None):
    """Constructs layer.

    Note: `data_format` is not supported since all nn layers operate on
    the rightmost column. If your channel dimension is not rightmost, use
    `tf.transpose` before calling this layer. For example, if your channel
    dimension is second from the left, the following code will move it
    rightmost:

    ```python
    inputs = tf.transpose(inputs, tf.concat([
        [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
    ```

    Args:
      input_size: Number of input channels. Required.
        In Keras, this argument is inferred from the rightmost input shape,
        i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of the
        second from the rightmost dimension of both `inputs` and `kernel`.
      output_size: Number of output channels.
        In Keras, this argument is called `filters`. This argument specifies the
        rightmost dimension size of both `kernel` and `bias`.
      filter_shape: An integer or sequence of integers giving the filter size;
        it is expanded to a `rank`-length list before use.
        In Keras, this argument is called `kernel_size`. This argument specifies
        the leftmost `rank` dimensions' sizes of `kernel`.
      rank: An integer, the rank of the convolution, e.g. "2" for 2D
        convolution. This argument implies the number of `kernel` dimensions,
        i.e., `kernel.shape.rank == rank + 2`.
        In Keras, this argument has the same name and semantics.
        Default value: `2`.
      strides: An integer or tuple/list of n integers, specifying the stride
        length of the convolution.
        In Keras, this argument has the same name and semantics.
        Default value: `1`.
      padding: One of `"VALID"` or `"SAME"` (case-insensitive).
        In Keras, this argument has the same name and semantics (except we don't
        support `"CAUSAL"`).
        Default value: `'VALID'`.
      dilations: An integer or tuple/list of `rank` integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying any
        `dilations` value != 1 is incompatible with specifying any `strides`
        value != 1.
        In Keras, this argument is called `dilation_rate`.
        Default value: `1`.
      kernel_initializer: Forwarded unchanged to both `make_posterior_fn` and
        `make_prior_fn`.
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      bias_initializer: Forwarded unchanged to both `make_posterior_fn` and
        `make_prior_fn`.
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_posterior_fn: Callable invoked as
        `make_posterior_fn(kernel_shape, [output_size], kernel_initializer,
        bias_initializer, 0, 0, dtype)`; its result is handed to the parent
        constructor as `posterior`.
        Default value:
          `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`.
      make_prior_fn: Callable invoked with the same arguments as
        `make_posterior_fn`; its result is handed to the parent constructor as
        `prior`.
        Default value:
          `tfp.experimental.nn.util.make_kernel_bias_prior_spike_and_slab`.
      posterior_value_fn: Forwarded to the parent constructor.
        Default value: `tfd.Distribution.sample`
      unpack_weights_fn: Forwarded to the parent constructor.
        Default value: `unpack_kernel_and_bias`
      dtype: Forwarded to `make_posterior_fn`, `make_prior_fn` and the parent
        constructor.
        Default value: `tf.float32`.
      activation_fn: Forwarded to the parent constructor.
        Default value: `None`.
      seed: Forwarded to the parent constructor.
        Default value: `None` (i.e., no seed).
      validate_args: Forwarded to the argument-preparation helper and to the
        parent constructor.
        Default value: `False`.
      name: Forwarded to the parent constructor.
        Default value: `None` (i.e.,
        `'ConvolutionVariationalFlipout'`).
    """
    spatial_dims = convolution_util.prepare_tuple_argument(
        filter_shape, rank, arg_name='filter_shape',
        validate_args=validate_args)
    # Kernel layout: spatial dimensions first, then input and output channels.
    weights_shape = ps.concat(
        [spatial_dims, [input_size, output_size]], axis=0)

    # Keep references to the factories on the instance (for variable
    # tracking) before any distribution is built.
    self._make_posterior_fn = make_posterior_fn
    self._make_prior_fn = make_prior_fn

    # This layer builds unbatched weights.
    weights_batch_ndims = 0
    # The posterior is built before the prior, matching argument order below.
    posterior = make_posterior_fn(
        weights_shape, [output_size],
        kernel_initializer, bias_initializer,
        weights_batch_ndims, weights_batch_ndims,
        dtype)
    prior = make_prior_fn(
        weights_shape, [output_size],
        kernel_initializer, bias_initializer,
        weights_batch_ndims, weights_batch_ndims,
        dtype)
    conv_fn = _make_convolution_fn(rank, strides, padding, dilations)

    super(ConvolutionVariationalFlipout, self).__init__(
        posterior=posterior,
        prior=prior,
        apply_kernel_fn=conv_fn,
        posterior_value_fn=posterior_value_fn,
        unpack_weights_fn=unpack_weights_fn,
        dtype=dtype,
        activation_fn=activation_fn,
        validate_args=validate_args,
        seed=seed,
        name=name)
    def __init__(
            self,
            input_size,
            output_size,  # keras::Conv::filters
            # Conv specific.
            filter_shape,  # keras::Conv::kernel_size
            rank=2,  # keras::Conv::rank
            strides=1,  # keras::Conv::strides
            padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
            # keras::Conv::data_format is not implemented
            dilations=1,  # keras::Conv::dilation_rate
            output_padding=None,  # keras::ConvTranspose::output_padding
            method='auto',
            # Weights
            kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
            bias_initializer=None,  # tf.initializers.zeros()
            make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias,
            dtype=tf.float32,
            index_dtype=tf.int32,
            # Misc
            activation_fn=None,
            validate_args=False,
            name=None):
        """Constructs layer.

        Note: `data_format` is not supported since all nn layers operate on
        the rightmost column. If your channel dimension is not rightmost, use
        `tf.transpose` before calling this layer. For example, if your channel
        dimension is second from the left, the following code will move it
        rightmost:

        ```python
        inputs = tf.transpose(inputs, tf.concat([
            [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
        ```

        Args:
          input_size: Number of input channels. Required.
            In Keras, this argument is inferred from the rightmost input
            shape, i.e., `tf.shape(inputs)[-1]`. Here it scales the leftmost
            `kernel` dimension, whose size is
            `prod(filter_shape) * input_size`.
          output_size: Number of output channels.
            In Keras, this argument is called `filters`. This argument
            specifies the rightmost dimension size of both `kernel` and
            `bias`.
          filter_shape: An integer or sequence of integers giving the filter
            size; it is expanded to a `rank`-length list before use.
            In Keras, this argument is called `kernel_size`.
          rank: An integer, the rank of the convolution, e.g. "2" for 2D
            convolution. It is used to expand `filter_shape` and is forwarded
            to the transposed-convolution factory.
            In Keras, this argument has the same name and semantics.
            Default value: `2`.
          strides: An integer or tuple/list of n integers, specifying the
            stride length of the convolution.
            In Keras, this argument has the same name and semantics.
            Default value: `1`.
          padding: One of `"VALID"` or `"SAME"` (case-insensitive).
            In Keras, this argument has the same name and semantics (except
            we don't support `"CAUSAL"`).
            Default value: `'VALID'`.
          dilations: An integer or tuple/list of `rank` integers, specifying
            the dilation rate to use for dilated convolution. Currently,
            specifying any `dilations` value != 1 is incompatible with
            specifying any `strides` value != 1.
            In Keras, this argument is called `dilation_rate`.
            Default value: `1`.
          output_padding: An `int` or length-`rank` tuple/list representing
            the amount of padding along the input spatial dimensions (e.g.,
            depth, height, width). A single `int` indicates the same value
            for all spatial dimensions. The amount of output padding along a
            given dimension must be lower than the stride along that same
            dimension.  If set to `None` (default), the output shape is
            inferred.
            In Keras, this argument has the same name and semantics.
            NOTE: this constructor accepts the argument but does not read or
            forward it; a non-`None` value has no effect here.
            Default value: `None` (i.e., inferred).
          method: Passed together with `strides` to
            `_get_convolution_transpose_fn`, which returns the factory used
            to build the kernel-application function.
            Default value: `'auto'`.
          kernel_initializer: Forwarded unchanged to `make_kernel_bias_fn`.
            Default value: `None` (i.e.,
            `tfp.experimental.nn.initializers.glorot_uniform()`).
          bias_initializer: Forwarded unchanged to `make_kernel_bias_fn`.
            Default value: `None` (i.e., `tf.initializers.zeros()`).
          make_kernel_bias_fn: Callable invoked as
            `make_kernel_bias_fn(kernel_shape, [output_size],
            kernel_initializer, bias_initializer, 0, 0, dtype)`; it must
            return a `(kernel, bias)` pair.
            Default value: `tfp.experimental.nn.util.make_kernel_bias`.
          dtype: Forwarded to `make_kernel_bias_fn` and to the parent
            constructor.
            Default value: `tf.float32`.
          index_dtype: Forwarded as `dtype` to the transposed-convolution
            factory.
            Default value: `tf.int32`.
          activation_fn: Forwarded to the parent constructor.
            Default value: `None`.
          validate_args: Forwarded to the argument-preparation helper, to the
            transposed-convolution factory and to the parent constructor.
            Default value: `False`.
          name: Forwarded to the parent constructor.
            Default value: `None` (i.e., `'ConvolutionTranspose'`).
        """
        filter_shape = convolution_util.prepare_tuple_argument(
            filter_shape, rank, 'filter_shape', validate_args)
        kernel_shape = ps.convert_to_shape_tensor(
            [ps.reduce_prod(filter_shape) * input_size, output_size])
        # This layer builds unbatched weights.
        batch_ndims = 0
        kernel, bias = make_kernel_bias_fn(kernel_shape, [output_size],
                                           kernel_initializer,
                                           bias_initializer, batch_ndims,
                                           batch_ndims, dtype)

        # Forward the caller's `rank` instead of a literal `2`, so the kernel
        # function is built for the same rank that was used to expand
        # `filter_shape` above. Unsupported ranks are then left for the
        # selected factory to reject.
        make_apply_kernel_fn = _get_convolution_transpose_fn(strides, method)
        apply_kernel_fn = make_apply_kernel_fn(
            filter_shape,
            strides,
            padding,
            rank=rank,
            dilations=dilations,
            dtype=index_dtype,
            validate_args=validate_args)

        super(ConvolutionTranspose, self).__init__(
            kernel=kernel,
            bias=bias,
            apply_kernel_fn=apply_kernel_fn,
            dtype=dtype,
            activation_fn=activation_fn,
            validate_args=validate_args,
            name=name)