def loop_body(i_, event_ind):
  i = i_ // strides
  j = i_ % strides
  i_ind = ps.range(
      i * fw, ps.maximum(i, fh) * fw, delta=strides * fw, dtype=dtype)
  j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)

  nc = cartesian_add([i_ind, j_ind])
  ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

  k = ps.reshape(
      cartesian_add([
          ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
          ps.range(ps.shape(nc)[1], dtype=dtype)
      ]),
      shape=[-1])
  last_j = strides - (fw - j - 1) % strides - 1
  last_i = strides - (fh - i - 1) % strides - 1
  kernel_ind = ps.stack(
      [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
  event_ind = ps.tensor_scatter_nd_update(
      event_ind, ind[..., tf.newaxis], kernel_ind)
  return i_ + 1, event_ind
def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
  if event_shape is None:
    event_shape = self.event_shape_tensor()

  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

  # Continuing the example from `_augment_sample_shape`, suppose we have:
  #   - sample shape of `[n]`,
  #   - underlying distribution batch shape of `[2, 1]`,
  #   - final broadcast batch shape of `[4, 2, 3]`.
  # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
  # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

  # First, we reshape to expand out the batch elements:
  #   `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
  # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
  # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
  underlying_bcast_shp = ps.concat(
      [ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
               dtype=underlying_batch_shape.dtype),
       underlying_batch_shape],
      axis=0)
  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
  x_with_doubled_batch = tf.reshape(
      x,
      ps.concat([sample_shape,
                 ps.where(is_dim_bcast, batch_shape, 1),
                 underlying_bcast_shp,
                 event_shape], axis=0))

  # Next, construct the permutation that interleaves the batch dimensions,
  # resulting in samples with shape
  #   `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
  # Note that each interleaved pair of batch dimensions contains exactly one
  # dim of size `1` and one of size `>= 1`.
  sample_ndims = ps.rank_from_shape(sample_shape)
  x_with_interleaved_batch = tf.transpose(
      x_with_doubled_batch,
      perm=ps.concat([
          ps.range(sample_ndims),
          sample_ndims + ps.reshape(
              ps.stack([ps.range(batch_rank),
                        ps.range(batch_rank) + batch_rank], axis=-1),
              [-1]),
          sample_ndims + 2 * batch_rank +
          ps.range(ps.rank_from_shape(event_shape))], axis=0))

  # Final reshape to remove the spurious `1` dimensions.
  return tf.reshape(
      x_with_interleaved_batch,
      ps.concat([sample_shape, batch_shape, event_shape], axis=0))
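# A minimal standalone sketch of the reshape / interleave-transpose / reshape
# trick above, using plain TensorFlow with the concrete shapes from the
# example in the comments (`n = 5` and an empty event shape are assumed here
# purely for illustration; this snippet is not part of the original module):
import tensorflow as tf

n = 5
x = tf.random.normal([n, 12, 2, 1])  # `[n] + [4 * 1 * 3] + [2, 1]`
# Expand the flat broadcast factor into per-dimension factors while keeping
# the underlying batch dims separate: `[n, 4, 1, 3, 1, 2, 1]`.
x = tf.reshape(x, [n, 4, 1, 3, 1, 2, 1])
# Interleave broadcast dims with underlying dims: `[n, 4, 1, 1, 2, 3, 1]`.
x = tf.transpose(x, perm=[0, 1, 4, 2, 5, 3, 6])
# Merge each (broadcast, underlying) pair; exactly one member of each pair
# has size 1, so this recovers the broadcast batch shape `[n, 4, 2, 3]`.
x = tf.reshape(x, [n, 4, 2, 3])
print(x.shape)  # (5, 4, 2, 3)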
def rank_to_has_batch_dimensions(cls, rank: TensorflowTreeTopology):
  """Returns `True` if any element of `rank` implies batch dimensions."""
  event_ndims = cls.get_event_ndims()
  batch_ndims = tf.nest.map_structure(
      lambda elem_rank, elem_event_ndims: elem_rank - elem_event_ndims,
      rank,
      event_ndims,
  )
  has_batch_dims_array = ps.stack(tf.nest.flatten(batch_ndims)) > 0
  return ps.reduce_any(has_batch_dims_array)
def _split_sample(self, x):
  result_batch_shape = self._calculate_batch_shape()
  sample_shape_size = (ps.rank(x)
                       - ps.shape(result_batch_shape)[0]
                       - ps.rank(self.event_shape))
  all_batch_shapes = [
      d.batch_shape.as_list()
      if tensorshape_util.is_fully_defined(d.batch_shape)
      else d.batch_shape_tensor()
      for d in self.distributions]
  original_shapes = ps.stack(all_batch_shapes, axis=0)
  all_compose_shapes = ps.gather(original_shapes, self._axis, axis=1)
  x_split = tf.split(x, all_compose_shapes,
                     axis=sample_shape_size + self._axis)
  return sample_shape_size, x_split
def _calculate_batch_shape(self):
  """Computes fully defined batch shape for the new distribution."""
  all_batch_shapes = [
      d.batch_shape.as_list()
      if tensorshape_util.is_fully_defined(d.batch_shape)
      else d.batch_shape_tensor()
      for d in self.distributions]
  original_shape = ps.stack(all_batch_shapes, axis=0)
  index_mask = ps.cast(
      ps.one_hot(self._axis, ps.shape(original_shape)[1]),
      dtype=tf.bool)
  new_concat_dim = ps.cast(
      ps.reduce_sum(original_shape, axis=0)[self._axis], dtype=tf.int32)
  return ps.where(index_mask,
                  new_concat_dim,
                  ps.reduce_max(original_shape, axis=0))
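# A small sketch of the one-hot mask arithmetic above with concrete numbers,
# using plain TensorFlow (the two component batch shapes and the concat axis
# are assumptions chosen purely for illustration):
import tensorflow as tf

# Two components with batch shapes [2, 3] and [5, 3], concatenated on axis 0;
# the resulting batch shape should be [7, 3].
original_shape = tf.constant([[2, 3], [5, 3]])
axis = 0
index_mask = tf.cast(tf.one_hot(axis, tf.shape(original_shape)[1]), tf.bool)
new_concat_dim = tf.reduce_sum(original_shape, axis=0)[axis]  # 2 + 5 = 7
batch_shape = tf.where(index_mask, new_concat_dim,
                       tf.reduce_max(original_shape, axis=0))
print(batch_shape.numpy())  # [7 3]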
def interpolate_backward_differences(backward_differences, order,
                                     step_size_ratio):
  """Updates backward differences when a change in the step size occurs."""
  state_dtype = backward_differences.dtype
  interpolation_matrix_ = interpolation_matrix(state_dtype, order,
                                               step_size_ratio)
  interpolation_matrix_unit_step_size_ratio = interpolation_matrix(
      state_dtype, order, 1.)
  interpolated_backward_differences_orders_one_to_five = tf.matmul(
      interpolation_matrix_unit_step_size_ratio,
      tf.matmul(interpolation_matrix_, backward_differences[1:MAX_ORDER + 1]))
  interpolated_backward_differences = tf.concat([
      tf.gather(backward_differences, [0]),
      interpolated_backward_differences_orders_one_to_five,
      ps.zeros(ps.stack([2, ps.shape(backward_differences)[1]]),
               dtype=state_dtype),
  ], 0)
  return interpolated_backward_differences
def _sample_n(self, n, seed=None):
  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)
  n_batch = ps.reduce_prod(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
  underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

  # Left pad underlying shape with any necessary ones.
  underlying_bcast_shp = ps.concat(
      [ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
               dtype=underlying_batch_shape.dtype),
       underlying_batch_shape],
      axis=0)

  # Determine how many underlying samples to produce.
  n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
  samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

  event_shape = self.event_shape_tensor()
  event_rank = ps.rank_from_shape(event_shape)
  shp = ps.concat([[n],
                   ps.where(is_dim_bcast, batch_shape, 1),
                   underlying_bcast_shp,
                   event_shape], axis=0)
  # Reshape to expand n_bcast_samples and the ones-padded
  # underlying_bcast_shp.
  samps = tf.reshape(samps, shp)
  # Interleave broadcast and underlying axis indices for transpose.
  interleaved_batch_axes = ps.reshape(
      ps.stack([ps.range(batch_rank),
                ps.range(batch_rank) + batch_rank], axis=-1),
      [-1]) + 1

  event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
  perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
  samps = tf.transpose(samps, perm=perm)
  # Finally, reshape to the fully-broadcast batch shape.
  return tf.reshape(samps,
                    ps.concat([[n], batch_shape, event_shape], axis=0))
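# For context, a hedged sketch of how this sampling path might be reached
# through the public `tfd.BatchBroadcast` wrapper (the shapes mirror the
# running example above; the wrapper usage here is an illustration under
# assumed constructor arguments, not part of this module):
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

underlying = tfd.Normal(loc=tf.zeros([2, 1]), scale=1.)  # batch shape [2, 1]
bcast = tfd.BatchBroadcast(underlying, [4, 2, 3])        # batch shape [4, 2, 3]
x = bcast.sample(5)
print(x.shape)  # (5, 4, 2, 3)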
def _initialize_solver_internal_state(
    self,
    ode_fn,
    initial_time,
    initial_state,
):
  p = self._prepare_common_params(
      ode_fn=ode_fn,
      initial_state=initial_state,
      initial_time=initial_time,
  )

  first_step_size = self._first_step_size
  if first_step_size is None:
    _, error_coefficients = self._prepare_coefficients(p.common_state_dtype)
    first_step_size = bdf_util.first_step_size(
        p.atol, error_coefficients[1], p.initial_state_vec, p.initial_time,
        p.ode_fn_vec, p.rtol, p.safety_factor)

  first_step_size = tf.convert_to_tensor(first_step_size, dtype=p.real_dtype)
  if self._validate_args:
    first_step_size = tf.ensure_shape(first_step_size, [])

  first_order_backward_difference = p.ode_fn_vec(
      p.initial_time, p.initial_state_vec) * tf.cast(first_step_size,
                                                     p.common_state_dtype)
  backward_differences = tf.concat(
      [
          p.initial_state_vec[tf.newaxis, :],
          first_order_backward_difference[tf.newaxis, :],
          tf.zeros(ps.stack([bdf_util.MAX_ORDER + 1, p.num_odes]),
                   dtype=p.common_state_dtype),
      ],
      axis=0,
  )
  return _BDFSolverInternalState(
      backward_differences=backward_differences,
      order=tf.ones([], tf.int32),
      step_size=first_step_size)
def _event_shape_tensor(self):
  dimension = self._scale.domain_dimension_tensor()
  return ps.stack([dimension, dimension])
def _flatten_row(jacobian_row, state_shape_part):
  state_size = ps.reduce_prod(state_shape_part)
  jacobian_row_mats = tf.nest.map_structure(
      lambda j: tf.reshape(j, ps.stack([state_size, -1], axis=0)),
      jacobian_row)
  return tf.concat(tf.nest.flatten(jacobian_row_mats), axis=-1)
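# A minimal standalone sketch of what `_flatten_row` does, with hypothetical
# shapes and plain TensorFlow in place of `ps` (the dict structure and sizes
# below are assumptions chosen only for illustration):
import tensorflow as tf

# One "row" of a structured Jacobian: derivatives of a [2, 3]-shaped state
# part with respect to two other parts of flattened sizes 4 and 5.
jacobian_row = {'a': tf.zeros([2, 3, 4]), 'b': tf.zeros([2, 3, 5])}
state_shape_part = [2, 3]

state_size = tf.reduce_prod(state_shape_part)  # 6
mats = tf.nest.map_structure(
    lambda j: tf.reshape(j, tf.stack([state_size, -1])), jacobian_row)
flat = tf.concat(tf.nest.flatten(mats), axis=-1)
print(flat.shape)  # (6, 9): 6 state elements x (4 + 5) downstream elements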