def __init__(
      self,
      input_size,
      output_size,          # keras::Conv::filters
      # Conv specific.
      filter_shape,         # keras::Conv::kernel_size
      rank=2,               # keras::Conv::rank
      strides=1,            # keras::Conv::strides
      padding='VALID',      # keras::Conv::padding; 'CAUSAL' not implemented.
                            # keras::Conv::data_format is not implemented
      dilations=1,          # keras::Conv::dilation_rate
      # Weights
      kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
      bias_initializer=None,    # tf.initializers.zeros()
      make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias,
      dtype=tf.float32,
      index_dtype=tf.int32,
      batch_shape=(),
      # Misc
      activation_fn=None,
      validate_args=False,
      name=None):
    """Constructs layer.

    Note: `data_format` is not supported since all nn layers operate on
    the rightmost column. If your channel dimension is not rightmost, use
    `tf.transpose` before calling this layer. For example, if your channel
    dimension is second from the left, the following code will move it
    rightmost:

    ```python
    inputs = tf.transpose(inputs, tf.concat([
        [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
    ```

    The weights created here have the following shapes:

    * `kernel`: `batch_shape + [prod(filter_shape) * input_size, output_size]`
      (all convolution-window dims and the input channels are flattened into
      a single dimension).
    * `bias`: `batch_shape + [1] * rank + [output_size]`.

    Args:
      input_size: Number of input channels.
        In Keras, this argument is inferred from the rightmost input shape,
        i.e., `tf.shape(inputs)[-1]`. Here it is multiplied by
        `prod(filter_shape)` to form the second-from-rightmost dimension of
        the flattened `kernel`.
        This argument is required (it has no default).
      output_size: Number of output filters.
        In Keras, this argument is called `filters`. This argument specifies the
        rightmost dimension size of both `kernel` and `bias`.
      filter_shape: Size of the convolution window, one entry per convolution
        dimension; it is normalized to `rank` entries by
        `convolution_util.prepare_tuple_argument`.
        In Keras, this argument is called `kernel_size`.
      rank: An integer, the rank of the convolution, e.g. "2" for 2D
        convolution. It sets the number of `filter_shape` entries, the rank
        passed to the convolution function, and the number of singleton
        dimensions in `bias`.
        In Keras, this argument has the same name and semantics.
        Default value: `2`.
      strides: An integer or tuple/list of `rank` integers, specifying the
        stride length of the convolution.
        In Keras, this argument has the same name and semantics.
        Default value: `1`.
      padding: One of `"VALID"` or `"SAME"` (case-insensitive).
        In Keras, this argument has the same name and semantics (except we don't
        support `"CAUSAL"`).
        Default value: `'VALID'`.
      dilations: An integer or tuple/list of `rank` integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying any
        `dilations` value != 1 is incompatible with specifying any `strides`
        value != 1.
        In Keras, this argument is called `dilation_rate`.
        Default value: `1`.
      kernel_initializer: Initializer forwarded to `make_kernel_bias_fn`.
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      bias_initializer: Initializer forwarded to `make_kernel_bias_fn`.
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: Callable invoked as
        `make_kernel_bias_fn(kernel_shape, bias_shape, kernel_initializer,
        bias_initializer, kernel_batch_ndims, bias_batch_ndims, dtype)` and
        returning a `(kernel, bias)` pair.
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: Floating-point dtype of the weights.
        Default value: `tf.float32`.
      index_dtype: Dtype forwarded as `dtype` to
        `convolution_util.make_convolution_fn`.
        Default value: `tf.int32`.
      batch_shape: Shape of a batch of independent kernels/biases; it is
        prepended to both weight shapes. `None` is treated like `()`.
        Default value: `()`.
      activation_fn: Optional activation applied to the layer output.
        Default value: `None`.
      validate_args: Python `bool` forwarded to the argument-preparation and
        convolution helpers and to the base class.
        Default value: `False`.
      name: Name of the layer.
        Default value: `None` (i.e., `'ConvolutionV2'`).
    """
    filter_shape = convolution_util.prepare_tuple_argument(
        filter_shape, rank, arg_name='filter_shape',
        validate_args=validate_args)
    # Normalize `batch_shape` to a flat int32 vector so it can be concatenated
    # with the other (int32) shape pieces below.
    batch_shape = (tf.constant([], dtype=tf.int32) if batch_shape is None
                   else ps.cast(ps.reshape(batch_shape, shape=[-1]), tf.int32))
    batch_ndims = ps.size(batch_shape)

    # Use the caller's `rank` (the same one used to prepare `filter_shape`)
    # instead of a hard-coded 2, so the two can never silently disagree.
    apply_kernel_fn = convolution_util.make_convolution_fn(
        filter_shape, rank=rank, strides=strides, padding=padding,
        dilations=dilations, dtype=index_dtype, validate_args=validate_args)

    # Flattened kernel: window dims and input channels fold into one axis.
    kernel_shape = ps.concat(
        [batch_shape, [ps.reduce_prod(filter_shape) * input_size, output_size]],
        axis=0)
    # One singleton dim per convolution dim (presumably so the bias broadcasts
    # over the spatial output dims). Built as int32 to match `batch_shape`;
    # the default float32 `tf.ones` would mix dtypes inside the shape vector.
    bias_shape = ps.concat(
        [batch_shape, tf.ones([rank], dtype=tf.int32), [output_size]], axis=0)
    kernel, bias = make_kernel_bias_fn(
        kernel_shape, bias_shape,
        kernel_initializer, bias_initializer,
        batch_ndims, batch_ndims,
        dtype)

    self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
    super(ConvolutionV2, self).__init__(
        kernel=kernel,
        bias=bias,
        apply_kernel_fn=apply_kernel_fn,
        dtype=dtype,
        activation_fn=activation_fn,
        validate_args=validate_args,
        name=name)
Ejemplo n.º 2
0
    def __init__(
            self,
            input_size,
            output_size,  # keras::Conv::filters
            # Conv specific.
            filter_shape,  # keras::Conv::kernel_size
            rank=2,  # keras::Conv::rank
            strides=1,  # keras::Conv::strides
            padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
            # keras::Conv::data_format is not implemented
            dilations=1,  # keras::Conv::dilation_rate
            # Weights
            kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
            bias_initializer=None,  # tf.initializers.zeros()
            make_posterior_fn=(
                kernel_bias_lib.make_kernel_bias_posterior_mvn_diag),
            make_prior_fn=(
                kernel_bias_lib.make_kernel_bias_prior_spike_and_slab),
            posterior_value_fn=tfd.Distribution.sample,
            unpack_weights_fn=unpack_kernel_and_bias,
            dtype=tf.float32,
            index_dtype=tf.int32,
            # Misc
            activation_fn=None,
            seed=None,
            validate_args=False,
            name=None):
        """Constructs layer.

        Note: `data_format` is not supported since all nn layers operate on
        the rightmost column. If your channel dimension is not rightmost, use
        `tf.transpose` before calling this layer. For example, if your channel
        dimension is second from the left, the following code will move it
        rightmost:

        ```python
        inputs = tf.transpose(inputs, tf.concat([
            [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
        ```

        The posterior and prior are built for a kernel of shape
        `filter_shape + [input_size, output_size]` and a bias of shape
        `[output_size]`, with no batch dimensions.

        Args:
          input_size: Number of input channels.
            In Keras, this argument is inferred from the rightmost input shape,
            i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of
            the second-from-rightmost dimension of `kernel`.
            This argument is required (it has no default).
          output_size: Number of output filters.
            In Keras, this argument is called `filters`. This argument specifies
            the rightmost dimension size of both `kernel` and `bias`.
          filter_shape: Size of the convolution window, one entry per
            convolution dimension; it is normalized to `rank` entries by
            `convolution_util.prepare_tuple_argument`.
            In Keras, this argument is called `kernel_size`. This argument
            specifies the leftmost `rank` dimensions' sizes of `kernel`.
          rank: An integer, the rank of the convolution, e.g. "2" for 2D
            convolution. This argument implies the number of `kernel`
            dimensions, i.e., `kernel.shape.rank == rank + 2`.
            In Keras, this argument has the same name and semantics.
            Default value: `2`.
          strides: An integer or tuple/list of `rank` integers, specifying the
            stride length of the convolution.
            In Keras, this argument has the same name and semantics.
            Default value: `1`.
          padding: One of `"VALID"` or `"SAME"` (case-insensitive).
            In Keras, this argument has the same name and semantics (except we
            don't support `"CAUSAL"`).
            Default value: `'VALID'`.
          dilations: An integer or tuple/list of `rank` integers, specifying the
            dilation rate to use for dilated convolution. Currently, specifying
            any `dilations` value != 1 is incompatible with specifying any
            `strides` value != 1.
            In Keras, this argument is called `dilation_rate`.
            Default value: `1`.
          kernel_initializer: Initializer forwarded to both `make_posterior_fn`
            and `make_prior_fn`.
            Default value: `None` (i.e.,
            `tfp.experimental.nn.initializers.glorot_uniform()`).
          bias_initializer: Initializer forwarded to both `make_posterior_fn`
            and `make_prior_fn`.
            Default value: `None` (i.e., `tf.initializers.zeros()`).
          make_posterior_fn: Callable invoked as
            `make_posterior_fn(kernel_shape, bias_shape, kernel_initializer,
            bias_initializer, kernel_batch_ndims, bias_batch_ndims, dtype)`
            to build the weight posterior.
            Default value:
              `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`.
          make_prior_fn: Callable with the same signature as
            `make_posterior_fn`, used to build the weight prior.
            Default value:
              `tfp.experimental.nn.util.make_kernel_bias_prior_spike_and_slab`.
          posterior_value_fn: Forwarded to the base class.
            Default value: `tfd.Distribution.sample`.
          unpack_weights_fn: Forwarded to the base class.
            Default value: `unpack_kernel_and_bias`.
          dtype: Floating-point dtype of the weights.
            Default value: `tf.float32`.
          index_dtype: Dtype forwarded as `dtype` to
            `convolution_util.make_convolution_fn`.
            Default value: `tf.int32`.
          activation_fn: Optional activation applied to the layer output.
            Default value: `None`.
          seed: Forwarded to the base class.
            Default value: `None` (i.e., no seed).
          validate_args: Python `bool` forwarded to the argument-preparation and
            convolution helpers and to the base class.
            Default value: `False`.
          name: Name of the layer.
            Default value: `None` (i.e.,
            `'ConvolutionVariationalFlipoutV2'`).
        """
        filter_shape = convolution_util.prepare_tuple_argument(
            filter_shape,
            rank,
            arg_name='filter_shape',
            validate_args=validate_args)

        kernel_shape = ps.concat([filter_shape, [input_size, output_size]],
                                 axis=0)
        bias_shape = [output_size]
        self._make_posterior_fn = make_posterior_fn  # For variable tracking.
        self._make_prior_fn = make_prior_fn  # For variable tracking.
        batch_ndims = 0  # Weights of this layer carry no batch dimensions.

        # Use the caller's `rank` (the same one used to prepare
        # `filter_shape`) instead of a hard-coded 2, so the two cannot
        # silently disagree.
        apply_kernel_fn = convolution_util.make_convolution_fn(
            filter_shape,
            rank=rank,
            strides=strides,
            padding=padding,
            dilations=dilations,
            dtype=index_dtype,
            validate_args=validate_args)

        # TODO(emilyaf): Update kernel shape and remove this.
        def temp_apply_kernel_fn(x, k):
            # The convolution fn consumes a kernel flattened to
            # `[-1, output_size]`, while the weights are shaped
            # `filter_shape + [input_size, output_size]`.
            return apply_kernel_fn(x, tf.reshape(k, [-1, output_size]))

        # Build the posterior before the prior, as the original call order did.
        posterior = make_posterior_fn(kernel_shape, bias_shape,
                                      kernel_initializer, bias_initializer,
                                      batch_ndims, batch_ndims, dtype)
        prior = make_prior_fn(kernel_shape, bias_shape,
                              kernel_initializer, bias_initializer,
                              batch_ndims, batch_ndims, dtype)
        super(ConvolutionVariationalFlipoutV2, self).__init__(
            posterior=posterior,
            prior=prior,
            apply_kernel_fn=temp_apply_kernel_fn,
            posterior_value_fn=posterior_value_fn,
            unpack_weights_fn=unpack_weights_fn,
            dtype=dtype,
            activation_fn=activation_fn,
            seed=seed,
            validate_args=validate_args,
            name=name)