Example #1
        def loop_body(i_, event_ind):
            i = i_ // strides
            j = i_ % strides

            i_ind = ps.range(i * fw,
                             ps.maximum(i, fh) * fw,
                             delta=strides * fw,
                             dtype=dtype)
            j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)

            nc = cartesian_add([i_ind, j_ind])
            ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

            k = ps.reshape(
                cartesian_add([
                    ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
                    ps.range(ps.shape(nc)[1], dtype=dtype)
                ]),
                shape=[-1])
            last_j = strides - (fw - j - 1) % strides - 1
            last_i = strides - (fh - i - 1) % strides - 1
            kernel_ind = ps.stack(
                [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
            event_ind = ps.tensor_scatter_nd_update(
                event_ind, ind[..., tf.newaxis], kernel_ind)

            return i_ + 1, event_ind
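
Note: `loop_body` above is the body of a `tf.while_loop` inside a larger (omitted) function; names such as `strides`, `fw`, `fh`, `sub_fw`, `dtype` and the helper `cartesian_add` (which appears to form a broadcast/outer sum of the index ranges) come from that enclosing scope. The `ps.stack(..., axis=1)` call pairs each flattened event index with a kernel index. A minimal standalone sketch of that pairing pattern, with hypothetical values:

from tensorflow_probability.python.internal import prefer_static as ps

# Hypothetical flattened event indices and a constant kernel offset.
k = ps.range(4)        # [0, 1, 2, 3]
kernel_offset = 5
# One row per event: (event index, kernel index).
kernel_ind = ps.stack([k, ps.ones_like(k) * kernel_offset], axis=1)
print(kernel_ind)  # rows: [0 5], [1 5], [2 5], [3 5]
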
Example #2
    def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
        if event_shape is None:
            event_shape = self.event_shape_tensor()

        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

        # Continuing the example from `_augment_sample_shape`, suppose we have:
        #   - sample shape of `[n]`,
        #   - underlying distribution batch shape of `[2, 1]`,
        #   - final broadcast batch shape of `[4, 2, 3]`.
        # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
        # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

        # First, we reshape to expand out the batch elements:
        # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
        # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
        # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
        underlying_bcast_shp = ps.concat(
            [ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                     dtype=underlying_batch_shape.dtype),
             underlying_batch_shape],
            axis=0)
        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
        x_with_doubled_batch = tf.reshape(
            x,
            ps.concat([sample_shape,
                       ps.where(is_dim_bcast, batch_shape, 1),
                       underlying_bcast_shp,
                       event_shape], axis=0))

        # Next, construct the permutation that interleaves the batch dimensions,
        # resulting in samples with shape
        # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
        # Note that each interleaved pair of batch dimensions contains exactly one
        # dim of size `1` and one of size `>= 1`.
        sample_ndims = ps.rank_from_shape(sample_shape)
        x_with_interleaved_batch = tf.transpose(
            x_with_doubled_batch,
            perm=ps.concat([
                ps.range(sample_ndims),
                sample_ndims + ps.reshape(
                    ps.stack([ps.range(batch_rank),
                              ps.range(batch_rank) + batch_rank], axis=-1),
                    [-1]),
                sample_ndims + 2 * batch_rank + ps.range(
                    ps.rank_from_shape(event_shape))
            ], axis=0))

        # Final reshape to remove the spurious `1` dimensions.
        return tf.reshape(
            x_with_interleaved_batch,
            ps.concat([sample_shape, batch_shape, event_shape], axis=0))
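
The interleaving permutation in the middle of this method comes from `ps.stack(..., axis=-1)` followed by a flattening reshape, which zips each broadcast batch axis with its underlying batch axis. A standalone sketch of that index trick, assuming a batch rank of 3:

from tensorflow_probability.python.internal import prefer_static as ps

batch_rank = 3
interleaved = ps.reshape(
    ps.stack([ps.range(batch_rank),
              ps.range(batch_rank) + batch_rank], axis=-1),
    [-1])
print(interleaved)  # [0 3 1 4 2 5]: axis i is paired with axis i + batch_rank.
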
Example #3
 def rank_to_has_batch_dimensions(cls, rank: TensorflowTreeTopology):
     event_ndims = cls.get_event_ndims()
     batch_ndims = tf.nest.map_structure(
         lambda elem_rank, elem_event_ndims: elem_rank - elem_event_ndims,
         rank,
         event_ndims,
     )
     has_batch_dims_array = ps.stack(tf.nest.flatten(batch_ndims)) > 0
     return ps.reduce_any(has_batch_dims_array)
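
Here `ps.stack` turns the flattened structure of per-leaf batch ranks into one vector, so a single `ps.reduce_any` can report whether any leaf carries batch dimensions. A minimal sketch with a hypothetical two-leaf structure (the keys are illustrative only):

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

# Hypothetical per-leaf batch ranks; one leaf has 2 batch dims, the other none.
batch_ndims = {'parent_indices': 2, 'child_indices': 0}
has_batch_dims_array = ps.stack(tf.nest.flatten(batch_ndims)) > 0
print(has_batch_dims_array)                 # [False  True] (flatten sorts dict keys)
print(ps.reduce_any(has_batch_dims_array))  # True
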
Example #4
 def _split_sample(self, x):
   result_batch_shape = self._calculate_batch_shape()
   sample_shape_size = (ps.rank(x)
                        - ps.shape(result_batch_shape)[0]
                        - ps.rank(self.event_shape))
   all_batch_shapes = [d.batch_shape.as_list()
                       if tensorshape_util.is_fully_defined(d.batch_shape)
                       else d.batch_shape_tensor() for d in self.distributions]
   original_shapes = ps.stack(all_batch_shapes, axis=0)
   all_compose_shapes = ps.gather(original_shapes, self._axis, axis=1)
   x_split = tf.split(x, all_compose_shapes, axis=sample_shape_size+self._axis)
   return sample_shape_size, x_split
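
`ps.stack(all_batch_shapes, axis=0)` lines up the component batch shapes as rows of a matrix, so `ps.gather(..., self._axis, axis=1)` can pull out the sizes along the concatenation axis to use as `tf.split` sizes. A small sketch under assumed shapes:

from tensorflow_probability.python.internal import prefer_static as ps

# Two hypothetical components with batch shapes [2, 3] and [2, 5], joined on axis 1.
original_shapes = ps.stack([[2, 3], [2, 5]], axis=0)   # [[2 3], [2 5]]
concat_axis = 1
all_compose_shapes = ps.gather(original_shapes, concat_axis, axis=1)
print(all_compose_shapes)  # [3 5]: the per-component sizes along the joined axis.
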
Example #5
 def _calculate_batch_shape(self):
   """Computes fully defined batch shape for the new distribution."""
   all_batch_shapes = [d.batch_shape.as_list()
                       if tensorshape_util.is_fully_defined(d.batch_shape)
                       else d.batch_shape_tensor() for d in self.distributions]
   original_shape = ps.stack(all_batch_shapes, axis=0)
   index_mask = ps.cast(
       ps.one_hot(self._axis, ps.shape(original_shape)[1]),
       dtype=tf.bool)
   new_concat_dim = ps.cast(
       ps.reduce_sum(original_shape, axis=0)[self._axis], dtype=tf.int32)
   return ps.where(index_mask, new_concat_dim,
                   ps.reduce_max(original_shape, axis=0))
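
The one-hot mask selects, per batch dimension, whether to take the summed size (on the concatenation axis) or the broadcast maximum over components. A worked sketch with the same hypothetical shapes as above, [2, 3] and [2, 5] joined on axis 1:

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

original_shape = ps.stack([[2, 3], [2, 5]], axis=0)
concat_axis = 1
index_mask = ps.cast(ps.one_hot(concat_axis, ps.shape(original_shape)[1]),
                     dtype=tf.bool)                       # [False  True]
new_concat_dim = ps.cast(ps.reduce_sum(original_shape, axis=0)[concat_axis],
                         dtype=tf.int32)                  # 3 + 5 = 8
batch_shape = ps.where(index_mask, new_concat_dim,
                       ps.reduce_max(original_shape, axis=0))
print(batch_shape)  # [2 8]: max over components off-axis, summed size on the axis.
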
Example #6
def interpolate_backward_differences(backward_differences, order,
                                     step_size_ratio):
    """Updates backward differences when a change in the step size occurs."""
    state_dtype = backward_differences.dtype
    interpolation_matrix_ = interpolation_matrix(state_dtype, order,
                                                 step_size_ratio)
    interpolation_matrix_unit_step_size_ratio = interpolation_matrix(
        state_dtype, order, 1.)
    interpolated_backward_differences_orders_one_to_five = tf.matmul(
        interpolation_matrix_unit_step_size_ratio,
        tf.matmul(interpolation_matrix_,
                  backward_differences[1:MAX_ORDER + 1]))
    interpolated_backward_differences = tf.concat([
        tf.gather(backward_differences, [0]),
        interpolated_backward_differences_orders_one_to_five,
        ps.zeros(ps.stack([2, ps.shape(backward_differences)[1]]),
                 dtype=state_dtype),
    ], 0)
    return interpolated_backward_differences
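
In this example `ps.stack` only assembles the shape vector `[2, num_odes]` for the block of zero rows appended below the interpolated differences; routing it through `ps` keeps the shape static whenever `ps.shape(backward_differences)[1]` is statically known. A minimal sketch of that shape-building pattern, with hypothetical sizes:

import numpy as np
from tensorflow_probability.python.internal import prefer_static as ps

backward_differences = np.zeros([8, 4])   # hypothetical: 8 difference rows, 4 ODEs.
pad = ps.zeros(ps.stack([2, ps.shape(backward_differences)[1]]),
               dtype=backward_differences.dtype)
print(ps.shape(pad))  # [2 4]
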
Example #7
    def _sample_n(self, n, seed=None):
        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)
        n_batch = ps.reduce_prod(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
        underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

        # Left pad underlying shape with any necessary ones.
        underlying_bcast_shp = ps.concat(
            [ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                     dtype=underlying_batch_shape.dtype),
             underlying_batch_shape],
            axis=0)

        # Determine how many underlying samples to produce.
        n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
        samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

        event_shape = self.event_shape_tensor()
        event_rank = ps.rank_from_shape(event_shape)
        shp = ps.concat([[n],
                         ps.where(is_dim_bcast, batch_shape, 1),
                         underlying_bcast_shp, event_shape],
                        axis=0)
        # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
        samps = tf.reshape(samps, shp)
        # Interleave broadcast and underlying axis indices for transpose.
        interleaved_batch_axes = ps.reshape(
            ps.stack([ps.range(batch_rank),
                      ps.range(batch_rank) + batch_rank],
                     axis=-1), [-1]) + 1

        event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
        perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
        samps = tf.transpose(samps, perm=perm)
        # Finally, reshape to the fully-broadcast batch shape.
        return tf.reshape(samps,
                          ps.concat([[n], batch_shape, event_shape], axis=0))
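
This is the sampling counterpart of Example #2: samples are drawn with a flat `n_bcast_samples` axis, reshaped so each broadcast batch dimension sits next to its underlying batch dimension, transposed with the interleaving permutation built via `ps.stack(..., axis=-1)`, and finally collapsed to the broadcast batch shape. A sketch of the first reshape target for a hypothetical case (underlying batch `[2, 1]` broadcast against `[4, 2, 3]`):

from tensorflow_probability.python.internal import prefer_static as ps

n = 7
batch_shape = [4, 2, 3]
underlying_bcast_shp = [1, 2, 1]   # underlying [2, 1], left-padded with a 1.
event_shape = [5]
is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)  # [True False True]
shp = ps.concat([[n],
                 ps.where(is_dim_bcast, batch_shape, 1),
                 underlying_bcast_shp, event_shape], axis=0)
print(shp)  # [7 4 1 3 1 2 1 5]: broadcast dims, then underlying dims, then event.
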
Example #8
    def _initialize_solver_internal_state(
        self,
        ode_fn,
        initial_time,
        initial_state,
    ):
        p = self._prepare_common_params(
            ode_fn=ode_fn,
            initial_state=initial_state,
            initial_time=initial_time,
        )

        first_step_size = self._first_step_size
        if first_step_size is None:
            _, error_coefficients = self._prepare_coefficients(
                p.common_state_dtype)
            first_step_size = bdf_util.first_step_size(
                p.atol, error_coefficients[1], p.initial_state_vec,
                p.initial_time, p.ode_fn_vec, p.rtol, p.safety_factor)
        first_step_size = tf.convert_to_tensor(first_step_size,
                                               dtype=p.real_dtype)
        if self._validate_args:
            first_step_size = tf.ensure_shape(first_step_size, [])

        first_order_backward_difference = p.ode_fn_vec(
            p.initial_time, p.initial_state_vec) * tf.cast(
                first_step_size, p.common_state_dtype)
        backward_differences = tf.concat(
            [
                p.initial_state_vec[tf.newaxis, :],
                first_order_backward_difference[tf.newaxis, :],
                tf.zeros(ps.stack([bdf_util.MAX_ORDER + 1, p.num_odes]),
                         dtype=p.common_state_dtype),
            ],
            axis=0,
        )
        return _BDFSolverInternalState(
            backward_differences=backward_differences,
            order=tf.ones([], tf.int32),
            step_size=first_step_size)
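
Here `ps.stack([bdf_util.MAX_ORDER + 1, p.num_odes])` builds the shape of the zero block that pads the initial backward-difference table: row 0 holds the state, row 1 the first-order difference, and the remaining rows start at zero. A sketch of that layout with hypothetical values (MAX_ORDER is 5 in `bdf_util`):

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

MAX_ORDER = 5
num_odes = 3
initial_state_vec = tf.ones([num_odes])
first_order_diff = 0.1 * initial_state_vec   # step_size * ode_fn(t0, y0), hypothetical.
backward_differences = tf.concat([
    initial_state_vec[tf.newaxis, :],
    first_order_diff[tf.newaxis, :],
    tf.zeros(ps.stack([MAX_ORDER + 1, num_odes])),
], axis=0)
print(backward_differences.shape)  # (8, 3)
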
Example #9
 def _event_shape_tensor(self):
     dimension = self._scale.domain_dimension_tensor()
     return ps.stack([dimension, dimension])
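
The simplest usage on this page: `ps.stack` builds a `[dimension, dimension]` shape vector. When `dimension` is a statically known integer, `prefer_static.stack` keeps the result static (a NumPy value) rather than creating a Tensor, which is the point of routing it through `ps`. A one-line sketch under that assumption:

from tensorflow_probability.python.internal import prefer_static as ps

dimension = 3                              # e.g. a statically known domain dimension.
print(ps.stack([dimension, dimension]))    # [3 3], still a static (NumPy) value.
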
Example #10
 def _flatten_row(jacobian_row, state_shape_part):
     state_size = ps.reduce_prod(state_shape_part)
     jacobian_row_mats = tf.nest.map_structure(
         lambda j: tf.reshape(j, ps.stack([state_size, -1], axis=0)),
         jacobian_row)
     return tf.concat(tf.nest.flatten(jacobian_row_mats), axis=-1)
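
Each entry of `jacobian_row` has leading axes matching `state_shape_part`; `ps.stack([state_size, -1], axis=0)` builds the reshape target that flattens those leading axes into a single axis of length `state_size` and everything else into the trailing axis. A sketch with hypothetical shapes:

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

state_shape_part = [2, 3]                       # hypothetical shape of one state part.
state_size = ps.reduce_prod(state_shape_part)   # 6
j = tf.zeros([2, 3, 4, 5])                      # hypothetical Jacobian block.
j_mat = tf.reshape(j, ps.stack([state_size, -1], axis=0))
print(j_mat.shape)  # (6, 20)
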