def loop_body(i_, event_ind):
  i = i_ // strides
  j = i_ % strides
  i_ind = ps.range(
      i * fw, ps.maximum(i, fh) * fw, delta=strides * fw, dtype=dtype)
  j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)
  nc = cartesian_add([i_ind, j_ind])
  ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])
  k = ps.reshape(cartesian_add([
      ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
      ps.range(ps.shape(nc)[1], dtype=dtype)
  ]), shape=[-1])
  last_j = strides - (fw - j - 1) % strides - 1
  last_i = strides - (fh - i - 1) % strides - 1
  kernel_ind = ps.stack(
      [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
  event_ind = ps.tensor_scatter_nd_update(
      event_ind, ind[..., tf.newaxis], kernel_ind)
  return i_ + 1, event_ind
def _log_average_probs_process_args(
    logits, validate_args, sample_axis, event_axis):
  """Processes args for `log_average_probs`."""
  rank = ps.rank(logits)
  if sample_axis is None or validate_args:
    event_axis = ps.reshape(
        ps.non_negative_axis(event_axis, rank), shape=[-1])
  if sample_axis is None:
    sample_axis = ps.setdiff1d(ps.range(rank), event_axis)
  elif validate_args:
    sample_axis = ps.reshape(
        ps.non_negative_axis(sample_axis, rank), shape=[-1])
  return sample_axis, event_axis
def _forward_event_shape_tensor(self, input_shape, is_inverse=False):
  ndims = ps.size(input_shape)
  indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
  extra_sizes = ps.reduce_sum(self.paddings, axis=-1)
  update_fn = (ps.tensor_scatter_nd_sub if is_inverse
               else ps.tensor_scatter_nd_add)
  return update_fn(ps.identity(input_shape), indices, extra_sizes)
def _sample_n(self, n, seed, **kwargs):
  sample_shape = ps.reshape(self.sample_shape, shape=[-1])
  x = self.distribution.sample(
      ps.concat([[n], sample_shape], axis=0), seed=seed, **kwargs)
  return tf.transpose(
      a=x, perm=self._sampling_permutation(sample_ndims=1))
def _sample_n(self, n, seed, **kwargs):
  sample_shape = prefer_static.reshape(self.sample_shape, shape=[-1])
  fake_sample_ndims = prefer_static.rank_from_shape(sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor, self.distribution.event_shape)
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor, self.distribution.batch_shape)
  perm = prefer_static.concat([
      [0],
      prefer_static.range(1 + fake_sample_ndims,
                          1 + fake_sample_ndims + batch_ndims,
                          dtype=tf.int32),
      prefer_static.range(1, 1 + fake_sample_ndims, dtype=tf.int32),
      prefer_static.range(
          1 + fake_sample_ndims + batch_ndims,
          1 + fake_sample_ndims + batch_ndims + event_ndims,
          dtype=tf.int32),
  ], axis=0)
  x = self.distribution.sample(
      prefer_static.concat([[n], sample_shape], axis=0),
      seed=seed, **kwargs)
  return tf.transpose(a=x, perm=perm)
def _inverse(self, y):
  ndims = prefer_static.rank(y)
  indices = prefer_static.reshape(
      prefer_static.add(self.axis, ndims), shape=[-1, 1])
  num_left, num_right = prefer_static.unstack(self.paddings, num=2, axis=-1)
  x = tf.slice(
      y,
      begin=prefer_static.tensor_scatter_nd_update(
          prefer_static.zeros(ndims, dtype=tf.int32), indices, num_left),
      size=prefer_static.tensor_scatter_nd_sub(
          prefer_static.shape(y), indices, num_left + num_right))
  if not self.validate_args:
    return x
  assertions = [
      assert_util.assert_equal(
          self._forward(x), y,
          message=('Argument `y` to `inverse` was not padded with '
                   '`constant_values`.')),
  ]
  with tf.control_dependencies(assertions):
    return tf.identity(x)
def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
  if event_shape is None:
    event_shape = self.event_shape_tensor()

  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

  # Continuing the example from `_augment_sample_shape`, suppose we have:
  #   - sample shape of `[n]`,
  #   - underlying distribution batch shape of `[2, 1]`,
  #   - final broadcast batch shape of `[4, 2, 3]`,
  # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
  # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

  # First, we reshape to expand out the batch elements:
  # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
  # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
  # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
  underlying_bcast_shp = ps.concat([
      ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
              dtype=underlying_batch_shape.dtype),
      underlying_batch_shape
  ], axis=0)
  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
  x_with_doubled_batch = tf.reshape(
      x,
      ps.concat([sample_shape,
                 ps.where(is_dim_bcast, batch_shape, 1),
                 underlying_bcast_shp,
                 event_shape], axis=0))

  # Next, construct the permutation that interleaves the batch dimensions,
  # resulting in samples with shape
  # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
  # Note that each interleaved pair of batch dimensions contains exactly one
  # dim of size `1` and one of size `>= 1`.
  sample_ndims = ps.rank_from_shape(sample_shape)
  x_with_interleaved_batch = tf.transpose(
      x_with_doubled_batch,
      perm=ps.concat([
          ps.range(sample_ndims),
          sample_ndims + ps.reshape(
              ps.stack([ps.range(batch_rank),
                        ps.range(batch_rank) + batch_rank], axis=-1),
              [-1]),
          sample_ndims + 2 * batch_rank + ps.range(
              ps.rank_from_shape(event_shape))
      ], axis=0))

  # Final reshape to remove the spurious `1` dimensions.
  return tf.reshape(
      x_with_interleaved_batch,
      ps.concat([sample_shape, batch_shape, event_shape], axis=0))
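# A minimal NumPy sketch (not from the source) of the interleaving
# permutation constructed above, assuming `batch_rank = 3` as in the
# `[4, 2, 3]` example from the comments: broadcast batch axis `i` is paired
# with underlying batch axis `i`.
import numpy as np

batch_rank = 3
interleaved = np.reshape(
    np.stack([np.arange(batch_rank), np.arange(batch_rank) + batch_rank],
             axis=-1), [-1])
print(interleaved)  # [0 3 1 4 2 5]: interleaves [4, 1, 3] with [1, 2, 1]
                    # into [4, 1] + [1, 2] + [3, 1].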
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    rank = tf.convert_to_tensor(rank, dtype=tf.int32)
    expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
    expand_shape = prefer_static.pad(
        prefer_static.shape(x),
        paddings=[[0, expand_ndims]],
        constant_values=1)
    return prefer_static.reshape(x, expand_shape)
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    expand_ndims = ps.maximum(rank - ps.rank(x), 0)
    expand_shape = ps.concat(
        [ps.shape(x), ps.ones(shape=[expand_ndims], dtype=tf.int32)],
        axis=0)
    return ps.reshape(x, expand_shape)
def _sample_direction_part(state_part, part_seed):
  state_part_shape = ps.shape(state_part)
  batch_shape = state_part_shape[:batch_rank]
  dimension = ps.reduce_prod(state_part_shape[batch_rank:])
  return ps.reshape(
      random_ops.spherical_uniform(
          shape=batch_shape,
          dimension=dimension,
          dtype=state_part.dtype,
          seed=part_seed),
      state_part_shape)
def __init__(
    self,
    input_size,
    output_size,
    # Weights
    init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
    init_bias_fn=None,  # tf.initializers.zeros()
    make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
    dtype=tf.float32,
    batch_shape=(),
    # Misc
    activation_fn=None,
    name=None):
  """Constructs layer.

  Args:
    input_size: ...
    output_size: ...
    init_kernel_fn: ...
      Default value: `None` (i.e.,
      `tfp.experimental.nn.initializers.glorot_uniform()`).
    init_bias_fn: ...
      Default value: `None` (i.e., `tf.initializers.zeros()`).
    make_kernel_bias_fn: ...
      Default value: `tfp.experimental.nn.util.make_kernel_bias`.
    dtype: ...
      Default value: `tf.float32`.
    batch_shape: ...
      Default value: `()`.
    activation_fn: ...
      Default value: `None`.
    name: ...
      Default value: `None` (i.e., `'Affine'`).
  """
  batch_shape = (
      tf.constant([], dtype=tf.int32) if batch_shape is None
      else prefer_static.cast(
          prefer_static.reshape(batch_shape, shape=[-1]), tf.int32))
  batch_ndims = prefer_static.size(batch_shape)
  kernel_shape = prefer_static.concat(
      [batch_shape, [input_size, output_size]], axis=0)
  bias_shape = prefer_static.concat([batch_shape, [output_size]], axis=0)
  apply_kernel_fn = lambda x, k: tf.matmul(  # pylint: disable=g-long-lambda
      x[..., tf.newaxis, :], k)[..., 0, :]
  kernel, bias = make_kernel_bias_fn(
      kernel_shape, bias_shape,
      init_kernel_fn, init_bias_fn,
      batch_ndims, batch_ndims,
      dtype)
  self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
  super(Affine, self).__init__(
      kernel=kernel,
      bias=bias,
      apply_kernel_fn=apply_kernel_fn,
      activation_fn=activation_fn,
      dtype=dtype,
      name=name)
def _forward(self, x):
  ndims = ps.rank(x)
  indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
  return tf.pad(
      x,
      paddings=ps.tensor_scatter_nd_update(
          ps.zeros([ndims, 2], dtype=tf.int32), indices, self.paddings),
      mode=self.mode,
      constant_values=ps.cast(self.constant_values, dtype=x.dtype))
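# A usage sketch of the padding semantics in `_forward` above, assuming the
# method belongs to `tfp.bijectors.Pad` (an assumption based on the
# attributes it reads).
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Pad(paddings=[[1, 2]])            # pad rightmost axis: 1 left, 2 right
print(bij.forward(tf.constant([3., 4.])))   # [0., 3., 4., 0., 0.]
print(bij.forward_event_shape_tensor([2]))  # [5]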
def _dummy_indices_like(indices):
  """Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
  indices_shape = ps.shape(indices)
  num_particles = indices_shape[0]
  return tf.broadcast_to(
      ps.reshape(
          ps.range(num_particles),
          ps.pad([num_particles],
                 paddings=[[0, ps.rank_from_shape(indices_shape) - 1]],
                 constant_values=1)),
      indices_shape)
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    rank = tf.convert_to_tensor(rank, dtype=tf.int32)
    expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
    expand_shape = prefer_static.concat([
        prefer_static.shape(x),
        prefer_static.ones(shape=[expand_ndims], dtype=tf.int32)
    ], axis=0)
    return prefer_static.reshape(x, expand_shape)
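# A minimal usage sketch of `left_justified_expand_dims_to`, assuming
# TensorFlow is imported as `tf` and the function above (with its module's
# `prefer_static` import) is in scope.
import tensorflow as tf

x = tf.ones([2, 3])                           # rank 2
y = left_justified_expand_dims_to(x, rank=4)
print(y.shape)                                # (2, 3, 1, 1): ones padded on the right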
def _dummy_indices_like(indices):
  """Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
  indices_shape = ps.shape(indices)
  num_particles = indices_shape[0]
  return tf.broadcast_to(
      ps.reshape(
          ps.range(num_particles),
          ps.concat([[num_particles],
                     ps.ones([ps.rank_from_shape(indices_shape) - 1],
                             dtype=np.int32)], axis=0)),
      indices_shape)
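# A small sketch of `_dummy_indices_like`, assuming `ps` is
# `tensorflow_probability.python.internal.prefer_static` and the function
# above is in scope: for a [4, 3] `indices` tensor (4 particles, 3 chains),
# row `p` of the result is filled with `p`, so each column reads [0, 1, 2, 3].
import tensorflow as tf

indices = tf.zeros([4, 3], dtype=tf.int32)
print(_dummy_indices_like(indices).numpy())
# [[0 0 0]
#  [1 1 1]
#  [2 2 2]
#  [3 3 3]]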
def _initialize(shape, dtype, batch_ndims, scale, mode, distribution,
                seed=None):
  """Samples a random `Tensor` per specified args."""
  if not dtype_util.is_floating(dtype):
    raise TypeError('Argument `dtype` must be float type (saw: "{}").'.format(
        dtype))
  shape = prefer_static.reshape(shape, shape=[-1])  # Ensure shape is vector.
  fan_in, fan_out = _compute_fans_from_shape(shape, batch_ndims)
  fans = _summarize_fans(fan_in, fan_out, mode, dtype)
  scale = prefer_static.cast(scale, dtype)
  return _sample_distribution(shape, scale / fans, distribution, seed, dtype)
def sample(self, sample_shape=(), seed=None, name='sample'):  # pylint: disable=unused-argument
  return tf.zeros(
      ps.concat(
          [
              # sample_shape might be a scalar
              ps.reshape(
                  ps.convert_to_shape_tensor(sample_shape, tf.int32),
                  shape=[-1]),
              self.batch_shape_tensor(),
              self.event_shape_tensor()
          ],
          axis=0))
def sample_shape(self):
  sample_shape = ps.reshape(self._sample_shape, shape=[-1])
  shard_axis_size = sample_shape[self.shard_axis]
  num_devices = self.num_devices
  if shard_axis_size % num_devices != 0:
    raise ValueError('Does not shard evenly.')
  shard_size = shard_axis_size // num_devices
  sample_shape = ps.concat([
      sample_shape[:self.shard_axis],
      [shard_size],
      sample_shape[self.shard_axis + 1:]
  ], axis=0)
  return sample_shape
def update_running_variance():
  diags = [variance_part.variance() for variance_part in variance_parts]
  new_state_parts = tf.nest.flatten(new_state)
  new_variance_parts = []
  for variance_part, diag, state_part in zip(
      variance_parts, diags, new_state_parts):
    # Compute new variance for each variance part, accounting for partial
    # batching of the variance calculation across chains (ie, some, all,
    # or none of the chains may share the estimated mass matrix).
    #
    # For example, say
    #
    #   state_part has shape    [2, 3, 4] + [5, 6]  (batch + event)
    #   variance_part has shape       [4] + [5, 6]
    #   log_prob has shape      [2, 3, 4]
    #
    # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
    # matrices, each being shared across a [2, 3]-batch of chains. Note
    # this division is inferred from the shapes of the state part, the
    # log_prob, and the user-provided initial running variances.
    #
    # Until RunningVariance supports rank > 1 chunking, we need to flatten
    # the states that go into updating the variance estimates. In the
    # above example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
    # fed to `RunningVariance.update(state_part, axis=0)`, recording
    # 6 new observations in the running variance calculation.
    # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
    # the resulting momentum distribution will have batch shape of
    # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
    state_rank = ps.rank(state_part)
    variance_rank = ps.rank(diag)
    num_reduce_dims = state_rank - variance_rank

    state_part_shape = ps.shape(state_part)
    # This reshape adds a 1 when reduce_dims==0, and collapses all the
    # lead dimensions to a single one otherwise.
    reshaped_state = ps.reshape(
        state_part,
        ps.concat(
            [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
             state_part_shape[num_reduce_dims:]],
            axis=0))
    # The `axis=0` here removes the leading dimension we got from the
    # reshape above, so the new_variance_parts have the correct shape
    # again.
    new_variance_parts.append(variance_part.update(reshaped_state, axis=0))
  return new_variance_parts
def _finish_log_prob(self, lp, aux):
  (sample_ndims, extra_sample_ndims, batch_ndims) = aux
  # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
  #     full sample shape in the sample axes, before we reduce.
  bcast_lp_shape = ps.broadcast_shape(
      ps.shape(lp),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1]),
                 ps.ones([batch_ndims], tf.int32)], axis=0))
  lp = tf.broadcast_to(lp, bcast_lp_shape)
  # (2) Make the final reduction.
  axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return self._sum_fn()(lp, axis=axis)
def _bcast_and_reduce_logdet(self, underlying_ldj):
  # Ensure ldj is fully broadcast in the sample dims, i.e. ensure ldj has
  # full sample shape in the sample axes, before we reduce.
  batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                   self.distribution.batch_shape)
  extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
  sample_ndims = ps.rank(underlying_ldj) - extra_sample_ndims - batch_ndims
  bcast_ldj_shape = ps.broadcast_shape(
      ps.shape(underlying_ldj),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.ones([batch_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1])], axis=0))
  ldj = tf.broadcast_to(underlying_ldj, bcast_ldj_shape)
  return self._sum_fn(ldj, axis=-1 - ps.range(extra_sample_ndims))
def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
  sample_ndims = ps.rank_from_shape(sample_shape)
  batch_ndims = ps.rank_from_shape(
      self.distribution.batch_shape_tensor, self.distribution.batch_shape)
  extra_sample_shape = ps.reshape(self.sample_shape, shape=[-1])
  extra_sample_ndims = ps.rank_from_shape(extra_sample_shape)
  x, lp = self.distribution.experimental_sample_and_log_prob(
      ps.concat([sample_shape, extra_sample_shape], axis=0),
      seed=seed, **kwargs)
  return (
      tf.transpose(x, perm=self._sampling_permutation(sample_ndims)),
      self._finish_log_prob(
          lp, aux=(sample_ndims, extra_sample_ndims, batch_ndims)))
def loop_body(i, outputs):
  subkernel_ind = kernels_ind.read(i)
  fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
  eh = ex_h + fh_ - 1
  ew = ex_w + fw_ - 1

  subkernel_ind = ps.reshape(
      ps.reshape(subkernel_ind * c_in, shape=[-1])[:, tf.newaxis]
      + ps.range(c_in), shape=[-1])

  k = tf.gather(kernel, subkernel_ind, axis=-2)
  ind, shape = im2row_index([eh, ew, c_in],
                            block_shape=(fh_, fw_),
                            slice_step=(1, 1),
                            dilations=dilations)
  x_i = x_pad[..., :eh, :ew, :]
  x_i_shape = ps.shape(x_i)
  flat_shape = ps.pad(
      x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1)
  flat_x = tf.reshape(x_i, flat_shape)
  x_ = tf.gather(flat_x, ind, axis=-1)
  im_x = tf.reshape(x_, ps.concat([x_i_shape[:-3], shape], axis=0))
  outputs = outputs.write(
      i,
      tf.matmul(
          im_x,
          tf.reshape(
              k,
              ps.concat([kernel_batch, [1, fh_ * fw_ * c_in, c_out]],
                        axis=0))))
  return i + 1, outputs
def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps):
  """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc."""
  step_indices_to_trace = tf.convert_to_tensor(
      step_indices_to_trace,
      dtype_hint=tf.int32)  # Warning: breaks gradients.
  traced_steps_have_rank_zero = ps.equal(
      ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0)
  # Canonicalize negative step indices as positive.
  step_indices_to_trace = ps.where(step_indices_to_trace < 0,
                                   num_timesteps + step_indices_to_trace,
                                   step_indices_to_trace)
  # Canonicalize scalars as length-one vectors.
  return (ps.reshape(step_indices_to_trace,
                     [ps.size(step_indices_to_trace)]),
          traced_steps_have_rank_zero)
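# A usage sketch of `_canonicalize_steps_to_trace`, with the expected
# results worked out by hand from the logic above (assuming the function and
# its module imports are in scope).
steps, was_scalar = _canonicalize_steps_to_trace(3, num_timesteps=10)
# steps == [3], was_scalar == True (scalar promoted to a length-one vector).
steps, was_scalar = _canonicalize_steps_to_trace([-2, -1], num_timesteps=10)
# steps == [8, 9], was_scalar == False (negative indices wrapped modulo N).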
def _fn(self, **kwargs):
  """Implements summary statistic, eg, mean, stddev, mode."""
  sample_shape = ps.reshape(self.sample_shape, shape=[-1])
  x = getattr(self.distribution, attr)(**kwargs)
  shape = ps.concat([
      self.distribution.batch_shape_tensor(),
      ps.ones(ps.rank_from_shape(sample_shape), dtype=sample_shape.dtype),
      self.distribution.event_shape_tensor(),
  ], axis=0)
  x = tf.reshape(x, shape=shape)
  shape = ps.concat([
      self.distribution.batch_shape_tensor(),
      sample_shape,
      self.distribution.event_shape_tensor(),
  ], axis=0)
  return tf.broadcast_to(x, shape)
def _sampling_permutation(self, sample_ndims):
  fake_sample_ndims = ps.rank_from_shape(
      ps.reshape(self.sample_shape, shape=[-1]))
  event_ndims = ps.rank_from_shape(
      self.distribution.event_shape_tensor, self.distribution.event_shape)
  batch_ndims = ps.rank_from_shape(
      self.distribution.batch_shape_tensor, self.distribution.batch_shape)
  return ps.concat([
      ps.range(sample_ndims),
      ps.range(sample_ndims + fake_sample_ndims,
               sample_ndims + fake_sample_ndims + batch_ndims,
               dtype=tf.int32),
      ps.range(sample_ndims, sample_ndims + fake_sample_ndims,
               dtype=tf.int32),
      ps.range(sample_ndims + fake_sample_ndims + batch_ndims,
               sample_ndims + fake_sample_ndims + batch_ndims + event_ndims,
               dtype=tf.int32),
  ], axis=0)
def _sample_n(self, n, seed=None):
  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)
  n_batch = ps.reduce_prod(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
  underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

  # Left pad underlying shape with any necessary ones.
  underlying_bcast_shp = ps.concat([
      ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
              dtype=underlying_batch_shape.dtype),
      underlying_batch_shape
  ], axis=0)

  # Determine how many underlying samples to produce.
  n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
  samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

  event_shape = self.event_shape_tensor()
  event_rank = ps.rank_from_shape(event_shape)
  shp = ps.concat([[n],
                   ps.where(is_dim_bcast, batch_shape, 1),
                   underlying_bcast_shp,
                   event_shape], axis=0)
  # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
  samps = tf.reshape(samps, shp)
  # Interleave broadcast and underlying axis indices for transpose.
  interleaved_batch_axes = ps.reshape(
      ps.stack([ps.range(batch_rank),
                ps.range(batch_rank) + batch_rank], axis=-1),
      [-1]) + 1

  event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
  perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
  samps = tf.transpose(samps, perm=perm)
  # Finally, reshape to the fully-broadcast batch shape.
  return tf.reshape(samps,
                    ps.concat([[n], batch_shape, event_shape], axis=0))
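# A usage sketch of the broadcast-sampling behavior above, assuming the
# method belongs to `tfd.BatchBroadcast` (an assumption): an underlying
# [2, 1]-batch Normal is broadcast to a [4, 2, 3] batch shape.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.BatchBroadcast(tfd.Normal(loc=tf.zeros([2, 1]), scale=1.),
                          with_shape=[4, 2, 3])
x = dist.sample(5)
print(x.shape)  # (5, 4, 2, 3): n followed by the fully-broadcast batch shape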
def _log_prob(self, x, **kwargs):
  batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                   self.distribution.batch_shape)
  extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
  event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                   self.distribution.event_shape)
  ndims = ps.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(
      x,
      shape=ps.pad(ps.shape(x),
                   paddings=[[ps.maximum(0, -d), 0]],
                   constant_values=1))
  ndims = ps.rank(x)
  sample_ndims = ps.maximum(0, d)
  # (2) Transpose x's dims.
  sample_dims = ps.range(0, sample_ndims)
  batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = ps.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = ps.range(
      sample_ndims + batch_ndims + extra_sample_ndims, ndims)
  perm = ps.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)
  # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
  #     full sample shape in the sample axes, before we reduce.
  bcast_lp_shape = ps.broadcast_shape(
      ps.shape(lp),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1]),
                 ps.ones([batch_ndims], tf.int32)], axis=0))
  lp = tf.broadcast_to(lp, bcast_lp_shape)
  # (5) Make the final reduction in x.
  axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
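# A shape illustration of the expand/transpose/reduce steps above, assuming
# the method belongs to `tfd.Sample` (an assumption): an underlying Normal
# with batch shape [2] wrapped with sample_shape=3.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Sample(tfd.Normal(loc=[0., 1.], scale=1.), sample_shape=3)
x = dist.sample(5)     # shape [5, 2, 3]: sample + batch + extra sample dims
lp = dist.log_prob(x)  # shape [5, 2]: the extra sample dims are summed out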
def one_step(self, current_state, previous_kernel_results, seed=None):
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
                          'one_step')):
    variance_parts = previous_kernel_results.running_variance
    diags = [variance_part.variance() for variance_part in variance_parts]
    # Set the momentum.
    batch_ndims = ps.rank(
        unnest.get_innermost(previous_kernel_results, 'target_log_prob'))
    state_parts = tf.nest.flatten(current_state)
    new_momentum_distribution = _make_momentum_distribution(
        diags, state_parts, batch_ndims)
    inner_results = self.momentum_distribution_setter_fn(
        previous_kernel_results.inner_results, new_momentum_distribution)

    # Step the inner kernel.
    inner_kwargs = {} if seed is None else dict(seed=seed)
    new_state, new_inner_results = self.inner_kernel.one_step(
        current_state, inner_results, **inner_kwargs)
    new_state_parts = tf.nest.flatten(new_state)
    new_variance_parts = []
    for variance_part, diag, state_part in zip(variance_parts, diags,
                                               new_state_parts):
      # Compute new variance for each variance part, accounting for partial
      # batching of the variance calculation across chains (ie, some, all, or
      # none of the chains may share the estimated mass matrix).
      #
      # For example, say
      #
      #   state_part has shape    [2, 3, 4] + [5, 6]  (batch + event)
      #   variance_part has shape       [4] + [5, 6]
      #   log_prob has shape      [2, 3, 4]
      #
      # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
      # matrices, each being shared across a [2, 3]-batch of chains. Note this
      # division is inferred from the shapes of the state part, the log_prob,
      # and the user-provided initial running variances.
      #
      # Until RunningVariance supports rank > 1 chunking, we need to flatten
      # the states that go into updating the variance estimates. In the above
      # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
      # fed to `RunningVariance.update(state_part, axis=0)`, recording
      # 6 new observations in the running variance calculation.
      # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
      # the resulting momentum distribution will have batch shape of
      # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
      state_rank = ps.rank(state_part)
      variance_rank = ps.rank(diag)
      num_reduce_dims = state_rank - variance_rank

      state_part_shape = ps.shape(state_part)
      # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
      # dimensions to a single one otherwise.
      reshaped_state = ps.reshape(
          state_part,
          ps.concat(
              [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
               state_part_shape[num_reduce_dims:]], axis=0))
      # The `axis=0` here removes the leading dimension we got from the
      # reshape above, so the new_variance_parts have the correct shape again.
      new_variance_parts.append(variance_part.update(reshaped_state, axis=0))

    new_kernel_results = previous_kernel_results._replace(
        inner_results=new_inner_results,
        running_variance=new_variance_parts)

    return new_state, new_kernel_results
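# A NumPy shape sketch (not from the source) of the flattening described in
# the comments above, using the [2, 3, 4] + [5, 6] example.
import numpy as np

state_part = np.zeros([2, 3, 4, 5, 6])  # batch [2, 3, 4] + event [5, 6]
diag = np.zeros([4, 5, 6])               # variance shared across a [2, 3] batch
num_reduce_dims = state_part.ndim - diag.ndim            # == 2
lead = int(np.prod(state_part.shape[:num_reduce_dims]))  # == 6
reshaped = state_part.reshape((lead,) + state_part.shape[num_reduce_dims:])
print(reshaped.shape)  # (6, 4, 5, 6): six new observations per variance update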
def __init__(
    self,
    input_size,
    output_size,          # keras::Conv::filters
    # Conv specific.
    filter_shape,         # keras::Conv::kernel_size
    rank=2,               # keras::Conv::rank
    strides=1,            # keras::Conv::strides
    padding='VALID',      # keras::Conv::padding; 'CAUSAL' not implemented.
    # keras::Conv::data_format is not implemented
    dilations=1,          # keras::Conv::dilation_rate
    # Weights
    init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
    init_bias_fn=None,    # tf.initializers.zeros()
    make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
    dtype=tf.float32,
    batch_shape=(),
    # Misc
    activation_fn=None,
    name=None):
  """Constructs layer.

  Note: `data_format` is not supported since all nn layers operate on the
  rightmost column. If your channel dimension is not rightmost, use
  `tf.transpose` before calling this layer. For example, if your channel
  dimension is second from the left, the following code will move it
  rightmost:

  ```python
  inputs = tf.transpose(inputs, tf.concat([
      [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
  ```

  Args:
    input_size: ...
      In Keras, this argument is inferred from the rightmost input shape,
      i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of the
      second from the rightmost dimension of both `inputs` and `kernel`.
      Default value: `None`.
    output_size: ...
      In Keras, this argument is called `filters`. This argument specifies
      the rightmost dimension size of both `kernel` and `bias`.
    filter_shape: ...
      In Keras, this argument is called `kernel_size`. This argument
      specifies the leftmost `rank` dimensions' sizes of `kernel`.
    rank: An integer, the rank of the convolution, e.g. "2" for 2D
      convolution. This argument implies the number of `kernel` dimensions,
      i.e., `kernel.shape.rank == rank + 2`. In Keras, this argument has the
      same name and semantics.
      Default value: `2`.
    strides: An integer or tuple/list of n integers, specifying the stride
      length of the convolution. In Keras, this argument has the same name
      and semantics.
      Default value: `1`.
    padding: One of `"VALID"` or `"SAME"` (case-insensitive). In Keras, this
      argument has the same name and semantics (except we don't support
      `"CAUSAL"`).
      Default value: `'VALID'`.
    dilations: An integer or tuple/list of `rank` integers, specifying the
      dilation rate to use for dilated convolution. Currently, specifying
      any `dilations` value != 1 is incompatible with specifying any
      `strides` value != 1. In Keras, this argument is called
      `dilation_rate`.
      Default value: `1`.
    init_kernel_fn: ...
      Default value: `None` (i.e.,
      `tfp.experimental.nn.initializers.glorot_uniform()`).
    init_bias_fn: ...
      Default value: `None` (i.e., `tf.initializers.zeros()`).
    make_kernel_bias_fn: ...
      Default value: `tfp.experimental.nn.util.make_kernel_bias`.
    dtype: ...
      Default value: `tf.float32`.
    batch_shape: ...
      Default value: `()`.
    activation_fn: ...
      Default value: `None`.
    name: ...
      Default value: `None` (i.e., `'Convolution'`).
  """
  filter_shape = prepare_tuple_argument(filter_shape, rank,
                                        arg_name='filter_shape')
  batch_shape = (np.array([], dtype=np.int32) if batch_shape is None
                 else prefer_static.reshape(batch_shape, shape=[-1]))
  batch_ndims = prefer_static.size(batch_shape)
  if tf.get_static_value(batch_ndims) == 0:
    # In this branch, we statically know there are no batch dims.
    kernel_shape = filter_shape + (input_size, output_size)
    bias_shape = [output_size]
    apply_kernel_fn = _make_convolution_fn(rank, strides, padding, dilations)
  else:
    # In this branch, there are either static/dynamic batch dims or
    # dynamically no batch dims.
    kernel_shape = prefer_static.concat(
        [batch_shape, filter_shape, [input_size, output_size]], axis=0)
    bias_shape = prefer_static.concat([batch_shape, [output_size]], axis=0)
    apply_kernel_fn = lambda x, k: convolution_batch(  # pylint: disable=g-long-lambda
        x, k,
        rank=rank,
        strides=strides,
        padding=padding,
        data_format='NHWBC',
        dilations=dilations)
  kernel, bias = make_kernel_bias_fn(
      kernel_shape, bias_shape,
      init_kernel_fn, init_bias_fn,
      batch_ndims, batch_ndims,
      dtype)
  self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
  super(Convolution, self).__init__(
      kernel=kernel,
      bias=bias,
      apply_kernel_fn=apply_kernel_fn,
      dtype=dtype,
      activation_fn=activation_fn,
      name=name)