Example 1
def _log_abs_determinant(self):
    log_det = tfp_math.reduce_kahan_sum(
        tf.math.log(tf.math.abs(self._diag)), axis=[-1]).total
    if dtype_util.is_complex(self.dtype):
        log_det = tf.cast(log_det, dtype=self.dtype)
    return log_det
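A quick sanity check of the method above: for a diagonal operator, log|det| is the sum of log|diag_i|, and the Kahan reduction just computes that sum with compensated accuracy. A minimal sketch, assuming `tensorflow_probability` is installed and using its public `tfp.math.reduce_kahan_sum`:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

diag = np.array([2.0, -0.5, 3.0], dtype=np.float32)
# reduce_kahan_sum returns a (total, correction) pair; .total is the sum.
log_det = tfp.math.reduce_kahan_sum(
    tf.math.log(tf.math.abs(diag)), axis=[-1]).total
np.testing.assert_allclose(log_det.numpy(), np.sum(np.log(np.abs(diag))),
                           rtol=1e-6)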
Example 2
def error_if_not_real_or_complex(tensor, identifier):
  """Raise a `TypeError` if the `Tensor` is neither real nor complex."""
  if not (dtype_util.is_floating(tensor.dtype) or
          dtype_util.is_complex(tensor.dtype)):
    raise TypeError(
        '`{}` must have a floating point or complex floating point dtype.'
        .format(identifier))
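As a standalone illustration, the same guard can be written against the public `tf.DType` properties instead of TFP's internal `dtype_util`; this is a sketch, not the library's code:

import tensorflow as tf

def check_real_or_complex(tensor, identifier):
  if not (tensor.dtype.is_floating or tensor.dtype.is_complex):
    raise TypeError(
        '`{}` must have a floating point or complex floating point dtype.'
        .format(identifier))

check_real_or_complex(tf.constant([1.0]), 'x')   # float32: passes
check_real_or_complex(tf.constant([1j]), 'x')    # complex128: passes
# check_real_or_complex(tf.constant([1]), 'x')   # int32: raises TypeError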
Example 3
def abs_square(x):
    """Returns the squared value of `tf.abs(x)` for real and complex dtypes."""
    if dtype_util.is_complex(x.dtype):
        return tf.math.square(tf.math.real(x)) + tf.math.square(
            tf.math.imag(x))
    else:
        return tf.math.square(x)
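The identity being used is |x|^2 = Re(x)^2 + Im(x)^2, which avoids computing `tf.abs(x)` and then squaring away its square root. A minimal check, assuming only TensorFlow:

import tensorflow as tf

x = tf.constant([3 + 4j, 1 - 2j], dtype=tf.complex64)
via_parts = tf.math.square(tf.math.real(x)) + tf.math.square(tf.math.imag(x))
via_abs = tf.math.square(tf.math.abs(x))
print(via_parts.numpy(), via_abs.numpy())  # both [25. 5.]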
Example 4
def _one_part(tensor):
  tensor = tf.convert_to_tensor(tensor)
  if dtype_util.is_floating(tensor.dtype) or dtype_util.is_complex(
      tensor.dtype):
    return tf.stop_gradient(tensor)
  else:
    return tensor
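The effect is that gradients are blocked through float and complex tensors while integer (and other) tensors pass through untouched. A small sketch of that behaviour, using the public `tf.DType` properties in place of `dtype_util`:

import tensorflow as tf

def one_part(tensor):
  tensor = tf.convert_to_tensor(tensor)
  if tensor.dtype.is_floating or tensor.dtype.is_complex:
    return tf.stop_gradient(tensor)
  return tensor

x = tf.Variable(2.0)
with tf.GradientTape() as tape:
  y = one_part(x) * x  # y = stop_gradient(x) * x
# Only the un-stopped factor contributes: dy/dx = stop_gradient(x) = 2.0.
print(tape.gradient(y, x).numpy())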
Example 5
def _inverse(self, y):
    x = y
    if self.shift is not None:
        x = x - self.shift
    if self._is_only_identity_multiplier:
        s = (tf.math.conj(self._scale)
             if self.adjoint and dtype_util.is_complex(self._scale.dtype)
             else self._scale)
        return x / s
    # Solve fails if the op is singular so we may safely skip this assertion.
    x = self.scale.solvevec(x, adjoint=self.adjoint)
    return x
Example 6
def _forward(self, x):
    y = x
    if self._is_only_identity_multiplier:
        s = (tf.math.conj(self._scale)
             if self.adjoint and dtype_util.is_complex(self._scale.dtype)
             else self._scale)
        y = y * s
        if self.shift is not None:
            return y + self.shift
        return y
    with tf.control_dependencies(
            self._maybe_check_scale() if self.validate_args else []):
        y = self.scale.matvec(y, adjoint=self.adjoint)
    if self.shift is not None:
        y = y + self.shift
    return y
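Examples 5 and 6 are the two halves of the same affine bijector: in the identity-multiplier fast path, `_inverse` undoes `_forward` exactly, conjugating the scale when `adjoint=True` and the dtype is complex. A standalone round-trip sketch (the function names here are illustrative, not the bijector's API):

import tensorflow as tf

def forward(x, scale, shift, adjoint=False):
    s = tf.math.conj(scale) if adjoint and scale.dtype.is_complex else scale
    return x * s + shift

def inverse(y, scale, shift, adjoint=False):
    s = tf.math.conj(scale) if adjoint and scale.dtype.is_complex else scale
    return (y - shift) / s

x = tf.constant([1 + 1j], dtype=tf.complex64)
scale = tf.constant(2 - 1j, dtype=tf.complex64)
shift = tf.constant(0.5 + 0j, dtype=tf.complex64)
y = forward(x, scale, shift, adjoint=True)
print(inverse(y, scale, shift, adjoint=True).numpy())  # recovers [1.+1.j]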
Example 7
    def _sample_n(self, n, seed):
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        mixture_distribution, components_distribution = (
            self._get_distributions_with_broadcast_batch_shape())
        x = components_distribution.sample(  # [n, B, k, E]
            n, seed=components_seed)

        event_ndims = ps.rank_from_shape(self.event_shape_tensor,
                                         self.event_shape)
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        num_components = ps.dimension_size(x, idx=-1 - event_ndims)

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mix_sample = mixture_distribution.sample(n, seed=mix_seed)  # [n, B]
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]

        # Pad `mask` to [n, B, k, [1]*e].
        batch_ndims = ps.rank(x) - event_ndims - 1
        mask_batch_ndims = ps.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = ps.shape(mask)
        target_shape = ps.concat([
            mask_shape[:-1],
            ps.ones([pad_ndims], dtype=tf.int32),
            mask_shape[-1:],
            ps.ones([event_ndims], dtype=tf.int32),
        ], axis=0)
        mask = tf.reshape(mask, shape=target_shape)

        if dtype_util.is_floating(x.dtype) or dtype_util.is_complex(x.dtype):
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            ret = self._reparameterize_sample(
                ret, event_shape=components_distribution.event_shape_tensor())

        return ret
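The core trick above selects one component per draw without `tf.gather`: one-hot the mixture indices, broadcast the mask against the component samples, and sum out the components axis. `tf.math.multiply_no_nan` ensures that `inf`/`nan` values in unselected components cannot leak through the zeros of the mask. A toy demonstration, assuming only TensorFlow:

import tensorflow as tf

x = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])         # [n=2 draws, k=3 components]
mix_sample = tf.constant([2, 0])        # chosen component per draw
mask = tf.one_hot(mix_sample, depth=3)  # [n, k]
selected = tf.reduce_sum(tf.math.multiply_no_nan(x, mask), axis=-1)
print(selected.numpy())                 # [3. 4.]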
Example 8
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = ps.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated = x_rotated - tf.reduce_mean(
                x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = ps.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.
        # At the moment it is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = ps.cast(x_len, np.float64)
        target_length = ps.pow(np.float64(2.),
                               ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = ps.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,
                                              axis=-1,
                                              back=True,
                                              count=pad_length)

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad,
                dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension in the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - ps.range(0., max_lags + 1.)
        denominator = ps.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
                                                  -shift)
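In current TFP releases this estimator is exposed publicly as `tfp.stats.auto_correlation`; a quick usage sketch, assuming that entry point:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant(np.random.randn(1000).astype(np.float32))
rxx = tfp.stats.auto_correlation(x, max_lags=3)
# With normalize=True, rxx[0] is 1; for white noise the later lags are near 0.
print(rxx.numpy())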
Example 9
    def _sample_n(self, n, seed):
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        try:
            seed_stream = SeedStream(seed, salt='MixtureSameFamily')
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=components_seed)
            if seed_stream is not None:
                seed_stream()  # Advance even if unused.
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `components_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. {}')
            warnings.warn(
                msg.format(self.components_distribution.name,
                           type(self.components_distribution), str(e)))
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=seed_stream())

        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = ps.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `mixture_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(self.mixture_distribution.name,
                           type(self.mixture_distribution), str(e)))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = ps.rank(x) - event_ndims - 1
        mask_batch_ndims = ps.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = ps.shape(mask)
        target_shape = ps.concat([
            mask_shape[:-1],
            ps.ones([pad_ndims], dtype=tf.int32),
            mask_shape[-1:],
            ps.ones([event_ndims], dtype=tf.int32),
        ], axis=0)
        mask = tf.reshape(mask, shape=target_shape)

        if dtype_util.is_floating(x.dtype) or dtype_util.is_complex(x.dtype):
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
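The fallback machinery above bridges old stateful seeding (`SeedStream`) and the modern stateless path, which deterministically splits one seed into independent per-use seeds. A sketch of that split, assuming the public wrapper `tfp.random.split_seed` (the exported counterpart of the internal `samplers.split_seed`):

import tensorflow_probability as tfp

seed = [0, 42]  # a stateless seed: a pair of integers
components_seed, mix_seed = tfp.random.split_seed(
    seed, salt='MixtureSameFamily')
# The split is a pure function of (seed, salt): re-running it reproduces the
# same pair, which is what makes the sampling path stateless.
print(components_seed.numpy(), mix_seed.numpy())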
Example 10
def error_if_complex(dtype):
    if dtype_util.is_complex(dtype):
        raise NotImplementedError(
            'The adjoint sensitivity method does '
            'not support complex dtypes.')
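A standalone analogue of this guard, written against `tf.DType.is_complex` rather than the internal `dtype_util`:

import tensorflow as tf

def error_if_complex(dtype):
    if tf.as_dtype(dtype).is_complex:
        raise NotImplementedError('The adjoint sensitivity method does '
                                  'not support complex dtypes.')

error_if_complex(tf.float64)      # fine
try:
    error_if_complex(tf.complex64)
except NotImplementedError as e:
    print(e)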
Example 11
    def _solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
    ):
        # This function comprises the following sequential stages:
        # (1) Make static assertions.
        # (2) Initialize variables.
        # (3) Make non-static assertions.
        # (4) Solve up to final time.
        # (5) Return `Results` object.
        #
        # The stages can be found in the code by searching for (n) where n=1..5.
        #
        # By static vs. non-static assertions (see stages 1 and 3), we mean
        # assertions that can be made before the graph is run vs. those that can
        # only be made at run time. The latter are constructed as a list of
        # tf.Assert operations by the function `assert_ops` (see below).
        #
        # If `solution_times` is specified as a `Tensor`, stage 4 consists of three
        # nested loops, which can be conceptually understood as follows:
        # ```
        # current_time, current_state = initial_time, initial_state
        # order, step_size = 1, first_step_size
        # for solution_time in solution_times:
        #   while current_time < solution_time:
        #     while True:
        #       next_time = current_time + step_size
        #       next_state, error = (
        #           solve_nonlinear_equation_to_get_approximate_state_at_next_time(
        #           current_time, current_state, next_time, order))
        #       if error < tolerance:
        #         current_time, current_state = next_time, next_state
        #         order, step_size = (
        #           maybe_update_order_and_step_size(order, step_size))
        #         break
        #       else:
        #         step_size = decrease_step_size(step_size)
        # ```
        # The outermost loop advances the solver to the next `solution_time` (see
        # `advance_to_solution_time`). The middle loop advances the solver by a
        # small timestep (see `step`). The innermost loop determines the size of
        # that timestep (see `maybe_step`).
        #
        # If `solution_times` is specified as
        # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped
        # and `solution_time` in the middle loop is replaced by `final_time`.

        def advance_to_solution_time(n, diagnostics, iterand,
                                     solver_internal_state, state_vec_array,
                                     time_array):
            """Takes multiple steps to advance time to `solution_times[n]`."""
            def step_cond(next_time, diagnostics, iterand, *_):
                return (iterand.time < next_time) & (tf.equal(
                    diagnostics.status, 0))

            nth_solution_time = solution_time_array.read(n)
            [
                _, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ] = tf.while_loop(step_cond, step, [
                nth_solution_time, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ])
            state_vec_array = state_vec_array.write(
                n, solver_internal_state.backward_differences[0])
            time_array = time_array.write(n, nth_solution_time)
            return (n + 1, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def step(next_time, diagnostics, iterand, solver_internal_state,
                 state_vec_array, time_array):
            """Takes a single step."""
            distance_to_next_time = next_time - iterand.time
            overstepped = iterand.new_step_size > distance_to_next_time
            iterand = iterand._replace(
                new_step_size=tf1.where(
                    overstepped, distance_to_next_time, iterand.new_step_size),
                should_update_step_size=(
                    overstepped | iterand.should_update_step_size))

            if not self._evaluate_jacobian_lazily:
                diagnostics = diagnostics._replace(
                    num_jacobian_evaluations=(
                        diagnostics.num_jacobian_evaluations + 1))
                iterand = iterand._replace(
                    jacobian_mat=jacobian_fn_mat(
                        iterand.time,
                        solver_internal_state.backward_differences[0]),
                    jacobian_is_up_to_date=True)

            def maybe_step_cond(accepted, diagnostics, *_):
                return tf.logical_not(accepted) & tf.equal(
                    diagnostics.status, 0)

            _, diagnostics, iterand, solver_internal_state = tf.while_loop(
                maybe_step_cond, maybe_step,
                [False, diagnostics, iterand, solver_internal_state])

            if solution_times_chosen_by_solver:
                state_vec_array = state_vec_array.write(
                    state_vec_array.size(),
                    solver_internal_state.backward_differences[0])
                time_array = time_array.write(time_array.size(), iterand.time)

            return (next_time, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
            """Takes a single step only if the outcome has a low enough error."""
            [
                num_jacobian_evaluations, num_matrix_factorizations,
                num_ode_fn_evaluations, status
            ] = diagnostics
            [
                jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper
            ] = iterand
            [backward_differences, order, step_size] = solver_internal_state

            if max_num_steps is not None:
                status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

            backward_differences = tf1.where(
                should_update_step_size,
                bdf_util.interpolate_backward_differences(
                    backward_differences, order, new_step_size / step_size),
                backward_differences)
            step_size = tf1.where(should_update_step_size, new_step_size,
                                  step_size)
            should_update_factorization = should_update_step_size
            num_steps_same_size = tf1.where(should_update_step_size, 0,
                                            num_steps_same_size)

            def update_factorization():
                return bdf_util.newton_qr(
                    jacobian_mat, newton_coefficients_array.read(order),
                    step_size)

            if self._evaluate_jacobian_lazily:

                def update_jacobian_and_factorization():
                    new_jacobian_mat = jacobian_fn_mat(time,
                                                       backward_differences[0])
                    new_unitary, new_upper = update_factorization()
                    return [
                        new_jacobian_mat, True, num_jacobian_evaluations + 1,
                        new_unitary, new_upper
                    ]

                def maybe_update_factorization():
                    new_unitary, new_upper = tf.cond(
                        should_update_factorization, update_factorization,
                        lambda: [unitary, upper])
                    return [
                        jacobian_mat, jacobian_is_up_to_date,
                        num_jacobian_evaluations, new_unitary, new_upper
                    ]

                [
                    jacobian_mat, jacobian_is_up_to_date,
                    num_jacobian_evaluations, unitary, upper
                ] = tf.cond(should_update_jacobian,
                            update_jacobian_and_factorization,
                            maybe_update_factorization)
            else:
                unitary, upper = update_factorization()
                num_matrix_factorizations += 1

            tol = p.atol + p.rtol * tf.abs(backward_differences[0])
            newton_tol = newton_tol_factor * tf.norm(tol)

            [
                newton_converged, next_backward_difference, next_state_vec,
                newton_num_iters
            ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                                newton_coefficients_array.read(order),
                                p.ode_fn_vec, order, step_size, time,
                                newton_tol, unitary, upper)
            num_steps += 1
            num_ode_fn_evaluations += newton_num_iters

            # If Newton's method failed and the Jacobian was up to date, decrease the
            # step size.
            newton_failed = tf.logical_not(newton_converged)
            should_update_step_size = newton_failed & jacobian_is_up_to_date
            new_step_size = step_size * tf1.where(should_update_step_size,
                                                  newton_step_size_factor, 1.)

            # If Newton's method failed and the Jacobian was NOT up to date, update
            # the Jacobian.
            should_update_jacobian = newton_failed & tf.logical_not(
                jacobian_is_up_to_date)

            error_ratio = tf1.where(
                newton_converged,
                bdf_util.error_ratio(next_backward_difference,
                                     error_coefficients_array.read(order),
                                     tol), np.nan)
            accepted = error_ratio < 1.
            converged_and_rejected = newton_converged & tf.logical_not(
                accepted)

            # If Newton's method converged but the solution was NOT accepted, decrease
            # the step size.
            new_step_size = tf1.where(
                converged_and_rejected,
                util.next_step_size(step_size, order, error_ratio,
                                    p.safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = should_update_step_size | converged_and_rejected

            # If Newton's method converged and the solution was accepted, update the
            # matrix of backward differences.
            time = tf1.where(accepted, time + step_size, time)
            backward_differences = tf1.where(
                accepted,
                bdf_util.update_backward_differences(backward_differences,
                                                     next_backward_difference,
                                                     next_state_vec, order),
                backward_differences)
            jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
                accepted)
            num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                            num_steps_same_size)

            # Order and step size are only updated if we have taken strictly more than
            # order + 1 steps of the same size. This is to prevent the order from
            # being throttled.
            should_update_order_and_step_size = (
                accepted & (num_steps_same_size > order + 1))

            backward_differences_array = tf.TensorArray(
                backward_differences.dtype,
                size=bdf_util.MAX_ORDER + 3,
                clear_after_read=False,
                element_shape=next_backward_difference.shape).unstack(
                    backward_differences)
            new_order = order
            new_error_ratio = error_ratio
            for offset in [-1, +1]:
                proposed_order = tf.clip_by_value(order + offset, 1, max_order)
                proposed_error_ratio = bdf_util.error_ratio(
                    backward_differences_array.read(proposed_order + 1),
                    error_coefficients_array.read(proposed_order), tol)
                proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
                new_order = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_order, new_order)
                new_error_ratio = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_error_ratio,
                    new_error_ratio)
            order = new_order
            error_ratio = new_error_ratio

            new_step_size = tf1.where(
                should_update_order_and_step_size,
                util.next_step_size(step_size, order, error_ratio,
                                    p.safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = (should_update_step_size
                                       | should_update_order_and_step_size)

            diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                          num_matrix_factorizations,
                                          num_ode_fn_evaluations, status)
            iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date,
                                  new_step_size, num_steps,
                                  num_steps_same_size, should_update_jacobian,
                                  should_update_step_size, time, unitary,
                                  upper)
            solver_internal_state = _BDFSolverInternalState(
                backward_differences, order, step_size)
            return accepted, diagnostics, iterand, solver_internal_state

        # (1) Make static assertions.
        # TODO(b/138304296): Support specifying Jacobian sparsity patterns.
        if jacobian_sparsity is not None:
            raise NotImplementedError(
                'The BDF solver does not support specifying '
                'Jacobian sparsity patterns.')
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'The BDF solver does not support batching.')
        solution_times_chosen_by_solver = isinstance(
            solution_times, base.ChosenBySolver)

        with tf.name_scope(self._name):

            # (2) Initialize variables.
            p = self._prepare_common_params(
                ode_fn=ode_fn,
                initial_state=initial_state,
                initial_time=initial_time,
            )

            if jacobian_fn is None and dtype_util.is_complex(
                    p.common_state_dtype):
                raise NotImplementedError(
                    'The BDF solver does not support automatic '
                    'Jacobian computations for complex dtypes.')

            # Convert everything to operate on a single, concatenated vector form.
            jacobian_fn_mat = util.get_jacobian_fn_mat(
                jacobian_fn,
                p.ode_fn_vec,
                p.state_shape,
                dtype=p.common_state_dtype,
            )

            num_solution_times = 0
            if solution_times_chosen_by_solver:
                final_time = tf.cast(solution_times.final_time, p.real_dtype)
            else:
                solution_times = tf.cast(solution_times, p.real_dtype)
                final_time = tf.reduce_max(solution_times)
                num_solution_times = tf.size(solution_times)
                solution_time_array = tf.TensorArray(
                    solution_times.dtype,
                    size=num_solution_times,
                    element_shape=[]).unstack(solution_times)
                util.error_if_not_vector(solution_times, 'solution_times')
            min_step_size_factor = tf.convert_to_tensor(
                self._min_step_size_factor, dtype=p.real_dtype)
            max_step_size_factor = tf.convert_to_tensor(
                self._max_step_size_factor, dtype=p.real_dtype)
            max_num_steps = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
            max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32)
            max_num_newton_iters = self._max_num_newton_iters
            if max_num_newton_iters is not None:
                max_num_newton_iters = tf.convert_to_tensor(
                    max_num_newton_iters, dtype=tf.int32)
            newton_tol_factor = tf.convert_to_tensor(self._newton_tol_factor,
                                                     dtype=p.real_dtype)
            newton_step_size_factor = tf.convert_to_tensor(
                self._newton_step_size_factor, dtype=p.real_dtype)
            newton_coefficients, error_coefficients = self._prepare_coefficients(
                p.common_state_dtype)
            if self._validate_args:
                final_time = tf.ensure_shape(final_time, [])
                min_step_size_factor = tf.ensure_shape(min_step_size_factor,
                                                       [])
                max_step_size_factor = tf.ensure_shape(max_step_size_factor,
                                                       [])
                if max_num_steps is not None:
                    max_num_steps = tf.ensure_shape(max_num_steps, [])
                max_order = tf.ensure_shape(max_order, [])
                if max_num_newton_iters is not None:
                    max_num_newton_iters = tf.ensure_shape(
                        max_num_newton_iters, [])
                newton_tol_factor = tf.ensure_shape(newton_tol_factor, [])
                newton_step_size_factor = tf.ensure_shape(
                    newton_step_size_factor, [])
            newton_coefficients_array = tf.TensorArray(
                newton_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(newton_coefficients)
            error_coefficients_array = tf.TensorArray(
                error_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(error_coefficients)
            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
                solver_internal_state = self._initialize_solver_internal_state(
                    ode_fn=ode_fn,
                    initial_state=initial_state,
                    initial_time=initial_time,
                )
            state_vec_array = tf.TensorArray(
                p.common_state_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=p.initial_state_vec.shape)
            time_array = tf.TensorArray(
                p.real_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=tf.TensorShape([]))
            diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0,
                                          num_matrix_factorizations=0,
                                          num_ode_fn_evaluations=0,
                                          status=0)
            iterand = _BDFIterand(
                jacobian_mat=tf.zeros([p.num_odes, p.num_odes],
                                      dtype=p.common_state_dtype),
                jacobian_is_up_to_date=False,
                new_step_size=solver_internal_state.step_size,
                num_steps=0,
                num_steps_same_size=0,
                should_update_jacobian=True,
                should_update_step_size=False,
                time=p.initial_time,
                unitary=tf.zeros([p.num_odes, p.num_odes],
                                 dtype=p.common_state_dtype),
                upper=tf.zeros([p.num_odes, p.num_odes],
                               dtype=p.common_state_dtype),
            )

            # (3) Make non-static assertions.
            with tf.control_dependencies(
                    self._assert_ops(
                        previous_solver_internal_state=(
                            previous_solver_internal_state),
                        initial_state_vec=p.initial_state_vec,
                        final_time=final_time,
                        initial_time=p.initial_time,
                        solution_times=solution_times,
                        max_num_steps=max_num_steps,
                        max_num_newton_iters=max_num_newton_iters,
                        atol=p.atol,
                        rtol=p.rtol,
                        first_step_size=solver_internal_state.step_size,
                        safety_factor=p.safety_factor,
                        min_step_size_factor=min_step_size_factor,
                        max_step_size_factor=max_step_size_factor,
                        max_order=max_order,
                        newton_tol_factor=newton_tol_factor,
                        newton_step_size_factor=newton_step_size_factor,
                        solution_times_chosen_by_solver=(
                            solution_times_chosen_by_solver),
                    )):

                # (4) Solve up to final time.
                if solution_times_chosen_by_solver:

                    def step_cond(next_time, diagnostics, iterand, *_):
                        return (iterand.time < next_time) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(step_cond, step, [
                        final_time, diagnostics, iterand,
                        solver_internal_state, state_vec_array, time_array
                    ])

                else:

                    def advance_to_solution_time_cond(n, diagnostics, *_):
                        return (n < num_solution_times) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(
                        advance_to_solution_time_cond,
                        advance_to_solution_time, [
                            0, diagnostics, iterand, solver_internal_state,
                            state_vec_array, time_array
                        ])

                # (5) Return `Results` object.
                states = util.get_state_from_vec(state_vec_array.stack(),
                                                 p.state_shape)
                times = time_array.stack()
                if not solution_times_chosen_by_solver:
                    tensorshape_util.set_shape(times, solution_times.shape)
                    tf.nest.map_structure(
                        lambda s, ini_s: tensorshape_util.set_shape(  # pylint: disable=g-long-lambda
                            s,
                            tensorshape_util.concatenate(
                                solution_times.shape, ini_s.shape)),
                        states,
                        p.initial_state)
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
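End to end, all of this machinery sits behind a small public surface. A hedged usage sketch, assuming the entry point `tfp.math.ode.BDF`; passing `tfp.math.ode.ChosenBySolver(final_time)` as `solution_times` instead would exercise the solver-chosen-times branch above:

import tensorflow as tf
import tensorflow_probability as tfp

# Mildly stiff linear ODE: dy/dt = -10 * y with y(0) = 1, solution exp(-10 t).
ode_fn = lambda t, y: -10. * y
results = tfp.math.ode.BDF().solve(
    ode_fn,
    initial_time=0.,
    initial_state=tf.constant(1., dtype=tf.float64),
    solution_times=[0.1, 0.5, 1.0])
print(results.times.numpy())
print(results.states.numpy())  # close to exp(-10 * times)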