def __init__(
    self,
    input_size,
    output_size,  # keras::Conv::filters
    # Conv specific.
    filter_shape,  # keras::Conv::kernel_size
    rank=2,  # keras::Conv::rank
    strides=1,  # keras::Conv::strides
    padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
                      # keras::Conv::data_format is not implemented
    dilations=1,  # keras::Conv::dilation_rate
    # Weights
    kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
    bias_initializer=None,  # tf.initializers.zeros()
    make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias,
    dtype=tf.float32,
    index_dtype=tf.int32,
    batch_shape=(),
    # Misc
    activation_fn=None,
    validate_args=False,
    name=None):
  """Constructs layer.

  Note: `data_format` is not supported since all nn layers operate on
  the rightmost column. If your channel dimension is not rightmost, use
  `tf.transpose` before calling this layer. For example, if your channel
  dimension is second from the left, the following code will move it
  rightmost:

  ```python
  inputs = tf.transpose(inputs, tf.concat([
      [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
  ```

  Args:
    input_size: ...
      In Keras, this argument is inferred from the rightmost input shape,
      i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of the
      second from the rightmost dimension of both `inputs` and `kernel`.
    output_size: ...
      In Keras, this argument is called `filters`. This argument specifies
      the rightmost dimension size of both `kernel` and `bias`.
    filter_shape: ...
      In Keras, this argument is called `kernel_size`. This argument
      specifies the leftmost `rank` dimensions' sizes of `kernel`.
    rank: An integer, the rank of the convolution, e.g. "2" for 2D
      convolution. This argument implies the number of `kernel` dimensions,
      i.e., `kernel.shape.rank == rank + 2`.
      In Keras, this argument has the same name and semantics.
      Default value: `2`.
    strides: An integer or tuple/list of `rank` integers, specifying the
      stride length of the convolution.
      In Keras, this argument has the same name and semantics.
      Default value: `1`.
    padding: One of `"VALID"` or `"SAME"` (case-insensitive).
      In Keras, this argument has the same name and semantics (except we
      don't support `"CAUSAL"`).
      Default value: `'VALID'`.
    dilations: An integer or tuple/list of `rank` integers, specifying the
      dilation rate to use for dilated convolution. Currently, specifying
      any `dilations` value != 1 is incompatible with specifying any
      `strides` value != 1.
      In Keras, this argument is called `dilation_rate`.
      Default value: `1`.
    kernel_initializer: ...
      Default value: `None` (i.e.,
      `tfp.experimental.nn.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.initializers.zeros()`).
    make_kernel_bias_fn: ...
      Default value: `tfp.experimental.nn.util.make_kernel_bias`.
    dtype: ...
      Default value: `tf.float32`.
    index_dtype: ...
      Default value: `tf.int32`.
    batch_shape: ...
      Default value: `()`.
    activation_fn: ...
      Default value: `None`.
    validate_args: ...
      Default value: `False`.
    name: ...
      Default value: `None` (i.e., `'ConvolutionV2'`).
  """
  filter_shape = convolution_util.prepare_tuple_argument(
      filter_shape, rank, arg_name='filter_shape',
      validate_args=validate_args)
  # Normalize `batch_shape` to a rank-1 int32 shape vector.
  batch_shape = (tf.constant([], dtype=tf.int32) if batch_shape is None
                 else ps.cast(ps.reshape(batch_shape, shape=[-1]), tf.int32))
  batch_ndims = ps.size(batch_shape)

  # Forward the caller's `rank` (previously hard-coded to 2) so the
  # convolution helper agrees with how `filter_shape` was prepared above.
  apply_kernel_fn = convolution_util.make_convolution_fn(
      filter_shape, rank=rank, strides=strides, padding=padding,
      dilations=dilations, dtype=index_dtype, validate_args=validate_args)

  # The kernel is stored flattened:
  # batch_shape + [prod(filter_shape) * input_size, output_size].
  kernel_shape = ps.concat(
      [batch_shape, [ps.reduce_prod(filter_shape) * input_size, output_size]],
      axis=0)
  # The bias gets `rank` singleton dims so it broadcasts over the spatial
  # dims. The ones must be int32 to match the rest of this shape vector
  # (`tf.ones` defaults to float32).
  bias_shape = ps.concat(
      [batch_shape, tf.ones(rank, dtype=tf.int32), [output_size]], axis=0)
  kernel, bias = make_kernel_bias_fn(
      kernel_shape, bias_shape,
      kernel_initializer, bias_initializer,
      batch_ndims, batch_ndims,
      dtype)

  self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
  super(ConvolutionV2, self).__init__(
      kernel=kernel,
      bias=bias,
      apply_kernel_fn=apply_kernel_fn,
      dtype=dtype,
      activation_fn=activation_fn,
      validate_args=validate_args,
      name=name)
def __init__(
    self,
    input_size,
    output_size,  # keras::Conv::filters
    # Conv specific.
    filter_shape,  # keras::Conv::kernel_size
    rank=2,  # keras::Conv::rank
    strides=1,  # keras::Conv::strides
    padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
                      # keras::Conv::data_format is not implemented
    dilations=1,  # keras::Conv::dilation_rate
    # Weights
    kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
    bias_initializer=None,  # tf.initializers.zeros()
    make_posterior_fn=kernel_bias_lib.
    make_kernel_bias_posterior_mvn_diag,
    make_prior_fn=kernel_bias_lib.
    make_kernel_bias_prior_spike_and_slab,
    posterior_value_fn=tfd.Distribution.sample,
    unpack_weights_fn=unpack_kernel_and_bias,
    dtype=tf.float32,
    index_dtype=tf.int32,
    # Misc
    activation_fn=None,
    seed=None,
    validate_args=False,
    name=None):
  """Constructs layer.

  Note: `data_format` is not supported since all nn layers operate on
  the rightmost column. If your channel dimension is not rightmost, use
  `tf.transpose` before calling this layer. For example, if your channel
  dimension is second from the left, the following code will move it
  rightmost:

  ```python
  inputs = tf.transpose(inputs, tf.concat([
      [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
  ```

  Args:
    input_size: ...
      In Keras, this argument is inferred from the rightmost input shape,
      i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of the
      second from the rightmost dimension of both `inputs` and `kernel`.
    output_size: ...
      In Keras, this argument is called `filters`. This argument specifies
      the rightmost dimension size of both `kernel` and `bias`.
    filter_shape: ...
      In Keras, this argument is called `kernel_size`. This argument
      specifies the leftmost `rank` dimensions' sizes of `kernel`.
    rank: An integer, the rank of the convolution, e.g. "2" for 2D
      convolution. This argument implies the number of `kernel` dimensions,
      i.e., `kernel.shape.rank == rank + 2`.
      In Keras, this argument has the same name and semantics.
      Default value: `2`.
    strides: An integer or tuple/list of `rank` integers, specifying the
      stride length of the convolution.
      In Keras, this argument has the same name and semantics.
      Default value: `1`.
    padding: One of `"VALID"` or `"SAME"` (case-insensitive).
      In Keras, this argument has the same name and semantics (except we
      don't support `"CAUSAL"`).
      Default value: `'VALID'`.
    dilations: An integer or tuple/list of `rank` integers, specifying the
      dilation rate to use for dilated convolution. Currently, specifying
      any `dilations` value != 1 is incompatible with specifying any
      `strides` value != 1.
      In Keras, this argument is called `dilation_rate`.
      Default value: `1`.
    kernel_initializer: ...
      Default value: `None` (i.e.,
      `tfp.experimental.nn.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.initializers.zeros()`).
    make_posterior_fn: ...
      Default value:
      `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`.
    make_prior_fn: ...
      Default value:
      `tfp.experimental.nn.util.make_kernel_bias_prior_spike_and_slab`.
    posterior_value_fn: ...
      Default value: `tfd.Distribution.sample`.
    unpack_weights_fn: ...
      Default value: `unpack_kernel_and_bias`.
    dtype: ...
      Default value: `tf.float32`.
    index_dtype: ...
      Default value: `tf.int32`.
    activation_fn: ...
      Default value: `None`.
    seed: ...
      Default value: `None` (i.e., no seed).
    validate_args: ...
      Default value: `False`.
    name: ...
      Default value: `None` (i.e., `'ConvolutionVariationalFlipoutV2'`).
  """
  filter_shape = convolution_util.prepare_tuple_argument(
      filter_shape, rank, arg_name='filter_shape',
      validate_args=validate_args)
  # Unlike `ConvolutionV2`, the kernel here keeps its spatial dims:
  # filter_shape + [input_size, output_size].
  kernel_shape = ps.concat([filter_shape, [input_size, output_size]], axis=0)

  self._make_posterior_fn = make_posterior_fn  # For variable tracking.
  self._make_prior_fn = make_prior_fn  # For variable tracking.

  batch_ndims = 0
  # Forward the caller's `rank` (previously hard-coded to 2) so the
  # convolution helper agrees with how `filter_shape` was prepared above.
  apply_kernel_fn = convolution_util.make_convolution_fn(
      filter_shape, rank=rank, strides=strides, padding=padding,
      dilations=dilations, dtype=index_dtype, validate_args=validate_args)

  # TODO(emilyaf): Update kernel shape and remove this.
  # `apply_kernel_fn` is called here with the kernel flattened to
  # [-1, output_size], so reshape the unflattened kernel before applying it.
  temp_apply_kernel_fn = lambda x, k: apply_kernel_fn(  # pylint: disable=g-long-lambda
      x, tf.reshape(k, [-1, output_size]))

  super(ConvolutionVariationalFlipoutV2, self).__init__(
      posterior=make_posterior_fn(
          kernel_shape, [output_size], kernel_initializer,
          bias_initializer, batch_ndims, batch_ndims, dtype),
      prior=make_prior_fn(
          kernel_shape, [output_size], kernel_initializer,
          bias_initializer, batch_ndims, batch_ndims, dtype),
      apply_kernel_fn=temp_apply_kernel_fn,
      posterior_value_fn=posterior_value_fn,
      unpack_weights_fn=unpack_weights_fn,
      dtype=dtype,
      activation_fn=activation_fn,
      seed=seed,
      validate_args=validate_args,
      name=name)