Example #1
    def _prepare_common_params(self, ode_fn, initial_state, initial_time):
        error_if_wrong_dtype = functools.partial(
            util.error_if_not_real_or_complex, identifier='initial_state')

        initial_state = tf.nest.map_structure(tf.convert_to_tensor,
                                              initial_state)
        tf.nest.map_structure(error_if_wrong_dtype, initial_state)

        state_shape = tf.nest.map_structure(ps.shape, initial_state)
        common_state_dtype = dtype_util.common_dtype(initial_state)
        real_dtype = dtype_util.real_dtype(common_state_dtype)
        # Use tf.cast instead of tf.convert_to_tensor for differentiable
        # parameters because the tf.custom_gradient decorator converts raw floats
        # into tf.float32, which cannot be converted to tf.float64.
        initial_time = tf.cast(initial_time, real_dtype)
        if self._validate_args:
            initial_time = tf.ensure_shape(initial_time, [])

        rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
        atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
        safety_factor = tf.convert_to_tensor(self._safety_factor,
                                             dtype=real_dtype)

        if self._validate_args:
            safety_factor = tf.ensure_shape(safety_factor, [])

        # Convert everything to operate on a single, concatenated vector form.
        initial_state_vec = util.get_state_vec(initial_state)
        ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape)
        num_odes = tf.size(initial_state_vec)

        return util.Bunch(
            initial_state=initial_state,
            initial_time=initial_time,
            common_state_dtype=common_state_dtype,
            real_dtype=real_dtype,
            rtol=rtol,
            atol=atol,
            safety_factor=safety_factor,
            state_shape=state_shape,
            initial_state_vec=initial_state_vec,
            ode_fn_vec=ode_fn_vec,
            num_odes=num_odes,
        )
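
A minimal standalone sketch of the cast-vs-convert point in the comment above: re-converting an existing float32 tensor to float64 with tf.convert_to_tensor raises, while tf.cast succeeds. The variable names here are illustrative only.

import tensorflow as tf

t = tf.constant(0.5)  # tf.custom_gradient hands differentiable parameters in as float32
try:
    tf.convert_to_tensor(t, dtype=tf.float64)  # ValueError: cannot convert float32 -> float64
except ValueError as e:
    print('convert_to_tensor failed:', e)
print(tf.cast(t, tf.float64).dtype)  # tf.float64: casting works
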
Example #2
  def _prepare_common_params(self, initial_state, initial_time):
    get_dtype = lambda x: x.dtype
    error_if_wrong_dtype = functools.partial(
        util.error_if_not_real_or_complex, identifier='initial_state')

    initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
    tf.nest.map_structure(error_if_wrong_dtype, initial_state)

    state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
    common_state_dtype = dtype_util.common_dtype(initial_state)
    real_dtype = dtype_util.real_dtype(common_state_dtype)

    initial_time = tf.cast(initial_time, real_dtype)

    return util.Bunch(
        initial_state=initial_state,
        state_dtypes=state_dtypes,
        real_dtype=real_dtype,
        initial_time=initial_time,
    )
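
A rough sketch of what the dtype helpers above produce, assuming TFP's internal dtype_util module is importable from tensorflow_probability.python.internal: for a complex state, the solver does its time bookkeeping in the corresponding real dtype.

import tensorflow as tf
from tensorflow_probability.python.internal import dtype_util

state = (tf.constant([1. + 2.j], tf.complex128),)
common = dtype_util.common_dtype(state)   # tf.complex128
real = dtype_util.real_dtype(common)      # tf.float64, used for times/tolerances
initial_time = tf.cast(0., real)          # raw Python floats land in float64
print(common, real, initial_time.dtype)
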
Example #3
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = ps.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated = x_rotated - tf.reduce_mean(
                x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = ps.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment it is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = ps.cast(x_len, np.float64)
        target_length = ps.pow(np.float64(2.),
                               ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = ps.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,
                                              axis=-1,
                                              back=True,
                                              count=pad_length)

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad,
                dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension in the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - ps.range(0., max_lags + 1.)
        denominator = ps.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
                                                  -shift)
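
For reference, TFP exposes this function as tfp.stats.auto_correlation. A quick usage sketch on white noise, where the normalized estimate should be ~1 at lag 0 and near 0 at higher lags:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant(np.random.randn(10000), dtype=tf.float32)
rxx = tfp.stats.auto_correlation(x, max_lags=3)
print(rxx.shape)    # (4,), i.e. max_lags + 1
print(rxx.numpy())  # approximately [1., 0., 0., 0.] for white noise
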
Example #4
 def averaged_sum_squares(input_tensor):
     num_elements_cast = tf.cast(num_elements,
                                 dtype=dtype_util.real_dtype(
                                     input_tensor.dtype))
     return tf.reduce_sum(abs_square(input_tensor)) / num_elements_cast
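
averaged_sum_squares closes over num_elements and abs_square from its enclosing scope; below is a self-contained sketch with hypothetical stand-ins for both, computing the mean of |x|**2:

import tensorflow as tf

def abs_square(x):  # hypothetical stand-in: |x|**2 as a real-valued tensor
    return tf.math.square(tf.abs(x))

num_elements = 2  # hypothetical stand-in for the closed-over element count

def averaged_sum_squares(input_tensor):
    num_elements_cast = tf.cast(num_elements,
                                dtype=input_tensor.dtype.real_dtype)
    return tf.reduce_sum(abs_square(input_tensor)) / num_elements_cast

print(averaged_sum_squares(tf.constant([1. + 1.j, 2. + 0.j], tf.complex64)))  # 3.0
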
Example #5
            def grad_fn(*dresults, **kwargs):
                """Adjoint sensitivity method to compute gradients."""
                dresults = tf.nest.pack_sequence_as(results, dresults)
                dstates = dresults.states
                # The signature grad_fn(*dresults, variables=None) is not valid Python 2
                # so use kwargs instead.
                variables = kwargs.pop('variables', [])
                assert not kwargs  # This assert should never fail.
                # TODO(b/138304303): Support complex types.
                with tf.name_scope('{}Gradients'.format(self._name)):
                    get_dtype = lambda x: x.dtype

                    def error_if_complex(dtype):
                        if dtype.is_complex:
                            raise NotImplementedError(
                                'The adjoint sensitivity method does '
                                'not support complex dtypes.')

                    state_dtypes = tf.nest.map_structure(
                        get_dtype, initial_state)
                    tf.nest.map_structure(error_if_complex, state_dtypes)
                    common_state_dtype = dtype_util.common_dtype(initial_state)
                    real_dtype = dtype_util.real_dtype(common_state_dtype)

                    # We add initial_time to ensure that we know where to stop.
                    result_times = tf.concat(
                        [[tf.cast(initial_time, real_dtype)], results.times],
                        0)
                    num_result_times = tf.size(result_times)

                    # First two components correspond to reverse and adjoint states.
                    # The last component is the adjoint state for variables.
                    terminal_augmented_state = tuple([
                        rk_util.nest_constant(initial_state, 0.0),
                        rk_util.nest_constant(initial_state, 0.0),
                        tuple(
                            rk_util.nest_constant(variable, 0.0)
                            for variable in variables)
                    ])

                    # The XLA compiler does not compile code which slices/indexes using
                    # integer `Tensor`s. `TensorArray`s are used to get around this.
                    result_time_array = tf.TensorArray(
                        results.times.dtype,
                        clear_after_read=False,
                        size=num_result_times,
                        element_shape=[]).unstack(result_times)

                    # TensorArray shape should not include time dimension, hence shape[1:]
                    result_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(results.states)
                    ]
                    result_state_arrays = tf.nest.pack_sequence_as(
                        results.states, result_state_arrays)
                    dresult_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(dstates)
                    ]
                    dresult_state_arrays = tf.nest.pack_sequence_as(
                        results.states, dresult_state_arrays)

                    def augmented_ode_fn(backward_time, augmented_state):
                        """Dynamics function for the augmented system.

            Describes a differential equation that evolves the augmented state
            backwards in time to compute gradients using the adjoint method.
            Augmented state consists of 3 components `(state, adjoint_state,
            vars)` all evaluated at time `backward_time`:

            state: represents the solution of user provided `ode_fn`. The
              structure coincides with the `initial_state`.
            adjoint_state: represents the solution of adjoint sensitivity
              differential equation as discussed below. Has the same structure
              and shape as `state`.
            vars: represents the solution of the adjoint equation for variable
              gradients. Represented as a `Tuple(Tensor, ...)` with as many
              tensors as there are `variables`.

            The adjoint sensitivity equation describes the gradient of the
            solution with respect to the value of the solution at a previous
            time t. Its dynamics are given by
            d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
            which is computed as
            d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
            d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
            where in the last line we moved adj(t)_j under the derivative by
            removing the gradient from it.

            The adjoint equation for the gradient with respect to every
            `tf.Variable` theta follows:
            d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
            = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

            Args:
              backward_time: Floating `Tensor` representing current time.
              augmented_state: `Tuple(state, adjoint_state, variable_grads)`

            Returns:
              negative_derivatives: Structure of `Tensor`s equal to backwards
                time derivative of the `state` component.
              adjoint_ode: Structure of `Tensor`s equal to backwards time
                derivative of the `adjoint_state` component.
              adjoint_variables_ode: Structure of `Tensor`s equal to backwards
                time derivative of the `vars` component.
            """
                        # The negative sign disappears after the change of variables.
                        # The ODE solver cannot handle the case initial_time > final_time
                        # and hence a change of variables backward_time = -time is used.
                        time = -backward_time
                        state, adjoint_state, _ = augmented_state

                        with tf.GradientTape() as tape:
                            tape.watch(variables)
                            tape.watch(state)
                            derivatives = ode_fn(time, state)
                            adjoint_no_grad = tf.nest.map_structure(
                                tf.stop_gradient, adjoint_state)
                            negative_derivatives = rk_util.weighted_sum(
                                [-1.0], [derivatives])

                            def dot_prod(tensor_a, tensor_b):
                                return tf.reduce_sum(tensor_a * tensor_b)

                            # See docstring for details.
                            adjoint_dot_derivatives = tf.nest.map_structure(
                                dot_prod, adjoint_no_grad, derivatives)
                            adjoint_dot_derivatives = tf.squeeze(
                                tf.add_n(
                                    tf.nest.flatten(adjoint_dot_derivatives)))

                        adjoint_ode, adjoint_variables_ode = tape.gradient(
                            adjoint_dot_derivatives, (state, tuple(variables)),
                            unconnected_gradients=tf.UnconnectedGradients.ZERO)
                        return negative_derivatives, adjoint_ode, adjoint_variables_ode

                    def reverse_to_result_time(n, augmented_state, _):
                        """Integrates the augmented system backwards in time."""
                        lower_bound_of_integration = result_time_array.read(n)
                        upper_bound_of_integration = result_time_array.read(n -
                                                                            1)
                        _, adjoint_state, adjoint_variable_state = augmented_state
                        initial_state = _read_solution_components(
                            result_state_arrays, input_state_structure, n - 1)
                        initial_adjoint = _read_solution_components(
                            dresult_state_arrays, input_state_structure, n - 1)
                        initial_adjoint_state = rk_util.weighted_sum(
                            [1.0, 1.0], [adjoint_state, initial_adjoint])
                        initial_augmented_state = (initial_state,
                                                   initial_adjoint_state,
                                                   adjoint_variable_state)
                        # TODO(b/138304303): Allow the user to specify the Hessian of
                        # `ode_fn` so that we can get the Jacobian of the adjoint system.
                        # TODO(b/143624114): Support higher order derivatives.
                        augmented_results = self._solve(
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            solution_times=[-upper_bound_of_integration],
                            batch_ndims=batch_ndims)
                        # Results added an extra time dim of size 1, squeeze it.
                        select_result = lambda x: tf.squeeze(x, [0])
                        result_state = augmented_results.states
                        result_state = tf.nest.map_structure(
                            select_result, result_state)
                        status = augmented_results.diagnostics.status
                        return n - 1, result_state, status

                    _, augmented_state, _ = tf.while_loop(
                        lambda n, _, status: (n >= 1) & tf.equal(status, 0),
                        reverse_to_result_time,
                        (num_result_times - 1, terminal_augmented_state, 0),
                        back_prop=False)
                    _, adjoint_state, adjoint_variables = augmented_state
                    return adjoint_state, list(adjoint_variables)
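
End to end, the grad_fn above is what makes gradients flow through a solve. Here is a small sketch under the public API; the exact solver and the printed value are incidental to the adjoint machinery:

import tensorflow as tf
import tensorflow_probability as tfp

a = tf.Variable(2., dtype=tf.float64)
ode_fn = lambda t, y: -a * y  # solution: y(t) = y0 * exp(-a * t)

with tf.GradientTape() as tape:
    results = tfp.math.ode.BDF().solve(
        ode_fn,
        initial_time=0.,
        initial_state=tf.constant(1., dtype=tf.float64),
        solution_times=[1.])
    y1 = results.states[-1]
print(tape.gradient(y1, a))  # d y(1) / d a = -exp(-a) ~= -0.135
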
Example #6
        def vjp_bwd(results_constants, dresults, variables=()):
            """Adjoint sensitivity method to compute gradients."""
            results, constants = results_constants
            adjoint_solver = self._make_adjoint_solver_fn()
            dstates = dresults.states
            # TODO(b/138304303): Support complex types.
            with tf.name_scope('{}Gradients'.format(self._name)):
                get_dtype = lambda x: x.dtype

                def error_if_complex(dtype):
                    if dtype_util.is_complex(dtype):
                        raise NotImplementedError(
                            'The adjoint sensitivity method does '
                            'not support complex dtypes.')

                state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
                tf.nest.map_structure(error_if_complex, state_dtypes)
                common_state_dtype = dtype_util.common_dtype(initial_state)
                real_dtype = dtype_util.real_dtype(common_state_dtype)

                # We add initial_time to ensure that we know where to stop.
                result_times = tf.concat(
                    [[tf.cast(initial_time, real_dtype)], results.times], 0)
                num_result_times = tf.size(result_times)

                # First two components correspond to reverse and adjoint states.
                # The last two components are the adjoint states for variables and constants.
                terminal_augmented_state = tuple([
                    rk_util.nest_constant(initial_state, 0.0),
                    rk_util.nest_constant(initial_state, 0.0),
                    tuple(
                        rk_util.nest_constant(variable, 0.0)
                        for variable in variables),
                    rk_util.nest_constant(constants, 0.0),
                ])

                # The XLA compiler does not compile code which slices/indexes using
                # integer `Tensor`s. `TensorArray`s are used to get around this.
                result_time_array = tf.TensorArray(
                    results.times.dtype,
                    clear_after_read=False,
                    size=num_result_times,
                    element_shape=[]).unstack(result_times)

                # TensorArray shape should not include time dimension, hence shape[1:]
                result_state_arrays = [
                    tf.TensorArray(  # pylint: disable=g-complex-comprehension
                        dtype=component.dtype,
                        size=num_result_times - 1,
                        clear_after_read=False,
                        element_shape=component.shape[1:]).unstack(component)
                    for component in tf.nest.flatten(results.states)
                ]
                result_state_arrays = tf.nest.pack_sequence_as(
                    results.states, result_state_arrays)
                dresult_state_arrays = [
                    tf.TensorArray(  # pylint: disable=g-complex-comprehension
                        dtype=component.dtype,
                        size=num_result_times - 1,
                        clear_after_read=False,
                        element_shape=component.shape[1:]).unstack(component)
                    for component in tf.nest.flatten(dstates)
                ]
                dresult_state_arrays = tf.nest.pack_sequence_as(
                    results.states, dresult_state_arrays)

                def augmented_ode_fn(backward_time, augmented_state):
                    """Dynamics function for the augmented system.

          Describes a differential equation that evolves the augmented state
          backwards in time to compute gradients using the adjoint method.
          Augmented state consists of 4 components `(state, adjoint_state,
          vars, constants)` all evaluated at time `backward_time`:

          state: represents the solution of user provided `ode_fn`. The
            structure coincides with the `initial_state`.
          adjoint_state: represents the solution of the adjoint sensitivity
            differential equation as discussed below. Has the same structure
            and shape as `state`.
          variables: represents the solution of the adjoint equation for
            variable gradients. Represented as a `Tuple(Tensor, ...)` with as
            many tensors as there are `variables` outside this function.
          constants: represents the solution of the adjoint equation for
            constant gradients. Has the same structure and shape as the
            `constants` variable outside this function.

          The adjoint sensitivity equation describes the gradient of the
          solution with respect to the value of the solution at a previous
          time t. Its dynamics are given by
          d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
          which is computed as
          d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
          d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
          where in the last line we moved adj(t)_j under the derivative by
          removing the gradient from it.

          The adjoint equation for the gradient with respect to every
          `tf.Variable` and constant theta follows:
          d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
          = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

          Args:
            backward_time: Floating `Tensor` representing current time.
            augmented_state: `Tuple(state, adjoint_state, variable_grads)`

          Returns:
            negative_derivatives: Structure of `Tensor`s equal to backwards
              time derivative of the `state` component.
            adjoint_ode: Structure of `Tensor`s equal to backwards time
              derivative of the `adjoint_state` component.
            adjoint_variables_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `vars` component.
            adjoint_constants_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `constants` component.
          """
                    # The negative sign disappears after the change of variables.
                    # The ODE solver cannot handle the case initial_time > final_time
                    # and hence a change of variables backward_time = -time is used.
                    time = -backward_time
                    state, adjoint_state, _, _ = augmented_state

                    # TODO(b/152464477): Doesn't work reliably in TF1.
                    def grad_fn(state, variables, constants):
                        del variables  # We compute these gradients via the GradientTape
                        # capturing them.
                        derivatives = ode_fn(time, state, **constants)
                        adjoint_no_grad = tf.nest.map_structure(
                            tf.stop_gradient, adjoint_state)
                        negative_derivatives = rk_util.weighted_sum(
                            [-1.0], [derivatives])

                        def dot_prod(tensor_a, tensor_b):
                            return tf.reduce_sum(tensor_a * tensor_b)

                        # See docstring for details.
                        adjoint_dot_derivatives = tf.nest.map_structure(
                            dot_prod, adjoint_no_grad, derivatives)
                        adjoint_dot_derivatives = tf.squeeze(
                            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
                        return adjoint_dot_derivatives, negative_derivatives

                    values = (state, tuple(variables), constants)
                    ((_, negative_derivatives),
                     gradients) = tfp_gradient.value_and_gradient(
                         grad_fn, values, has_aux=True, use_gradient_tape=True)

                    (adjoint_ode, adjoint_variables_ode,
                     adjoint_constants_ode) = tf.nest.map_structure(
                         lambda v, g: tf.zeros_like(v)
                         if g is None else g, values, tuple(gradients))
                    return (negative_derivatives, adjoint_ode,
                            adjoint_variables_ode, adjoint_constants_ode)

                def make_augmented_state(n, prev_augmented_state):
                    """Constructs the augmented state for step `n`."""
                    (_, adjoint_state, adjoint_variable_state,
                     adjoint_constant_state) = prev_augmented_state
                    initial_state = _read_solution_components(
                        result_state_arrays,
                        input_state_structure,
                        n - 1,
                    )
                    initial_adjoint = _read_solution_components(
                        dresult_state_arrays,
                        input_state_structure,
                        n - 1,
                    )
                    initial_adjoint_state = rk_util.weighted_sum(
                        [1.0, 1.0], [adjoint_state, initial_adjoint])
                    augmented_state = (
                        initial_state,
                        initial_adjoint_state,
                        adjoint_variable_state,
                        adjoint_constant_state,
                    )
                    return augmented_state

                def reverse_to_result_time(n, augmented_state,
                                           solver_internal_state, _):
                    """Integrates the augmented system backwards in time."""
                    lower_bound_of_integration = result_time_array.read(n)
                    upper_bound_of_integration = result_time_array.read(n - 1)
                    initial_augmented_state = make_augmented_state(
                        n, augmented_state)
                    # TODO(b/138304303): Allow the user to specify the Hessian of
                    # `ode_fn` so that we can get the Jacobian of the adjoint system.
                    # TODO(b/143624114): Support higher order derivatives.
                    solver_internal_state = (
                        adjoint_solver.
                        _adjust_solver_internal_state_for_state_jump(  # pylint: disable=protected-access
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            previous_solver_internal_state=
                            solver_internal_state,
                            previous_state=augmented_state,
                        ))
                    augmented_results = adjoint_solver.solve(
                        ode_fn=augmented_ode_fn,
                        initial_time=-lower_bound_of_integration,
                        initial_state=initial_augmented_state,
                        solution_times=[-upper_bound_of_integration],
                        batch_ndims=batch_ndims,
                        previous_solver_internal_state=solver_internal_state,
                    )
                    # Results added an extra time dim of size 1, squeeze it.
                    select_result = lambda x: tf.squeeze(x, [0])
                    result_state = augmented_results.states
                    result_state = tf.nest.map_structure(
                        select_result, result_state)
                    status = augmented_results.diagnostics.status
                    return (n - 1, result_state,
                            augmented_results.solver_internal_state, status)

                initial_n = num_result_times - 1
                solver_internal_state = adjoint_solver._initialize_solver_internal_state(  # pylint: disable=protected-access
                    ode_fn=augmented_ode_fn,
                    initial_time=result_time_array.read(initial_n),
                    initial_state=make_augmented_state(
                        initial_n, terminal_augmented_state),
                )

                _, augmented_state, _, _ = tf.while_loop(
                    lambda n, _as, _sis, status:
                    (n >= 1) & tf.equal(status, 0),
                    reverse_to_result_time,
                    (initial_n, terminal_augmented_state,
                     solver_internal_state, 0),
                    back_prop=False,
                )
                (_, adjoint_state, adjoint_variables,
                 adjoint_constants) = augmented_state

                if variables:
                    return (adjoint_state,
                            adjoint_constants), list(adjoint_variables)
                else:
                    return adjoint_state, adjoint_constants
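
The vjp_bwd above additionally propagates gradients to `constants`, which solve forwards to ode_fn as keyword arguments (note ode_fn(time, state, **constants) in grad_fn). A sketch, assuming a TFP version whose solve accepts a constants argument:

import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant(2., dtype=tf.float64)  # a plain tensor rather than a tf.Variable
ode_fn = lambda t, y, a: -a * y        # constants arrive as keyword arguments

with tf.GradientTape() as tape:
    tape.watch(a)
    results = tfp.math.ode.BDF().solve(
        ode_fn,
        initial_time=0.,
        initial_state=tf.constant(1., dtype=tf.float64),
        solution_times=[1.],
        constants={'a': a})
print(tape.gradient(results.states, a))  # ~= -exp(-2), matching the variable case
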
Example #7
    def _solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
    ):
        # This function consists of the following sequential stages:
        # (1) Make static assertions.
        # (2) Initialize variables.
        # (3) Make non-static assertions.
        # (4) Solve up to final time.
        # (5) Return `Results` object.
        #
        # The stages can be found in the code by searching for (n) where n=1..5.
        #
        # By static vs. non-static assertions (see stages 1 and 3), we mean
        # assertions that can be made before the graph is run vs. those that can
        # only be made at run time. The latter are constructed as a list of
        # tf.Assert operations by the function `assert_ops` (see below).
        #
        # If `solution_times` is specified as a `Tensor`, stage 4 consists of three
        # nested loops, which can be conceptually understood as follows:
        # ```
        # current_time, current_state = initial_time, initial_state
        # order, step_size = 1, first_step_size
        # for solution_time in solution_times:
        #   while current_time < solution_time:
        #     while True:
        #       next_time = current_time + step_size
        #       next_state, error = (
        #           solve_nonlinear_equation_to_get_approximate_state_at_next_time(
        #           current_time, current_state, next_time, order))
        #       if error < tolerance:
        #         current_time, current_state = next_time, next_state
        #         order, step_size = (
        #           maybe_update_order_and_step_size(order, step_size))
        #         break
        #       else:
        #         step_size = decrease_step_size(step_size)
        # ```
        # The outermost loop advances the solver to the next `solution_time` (see
        # `advance_to_solution_time`). The middle loop advances the solver by a
        # small timestep (see `step`). The innermost loop determines the size of
        # that timestep (see `maybe_step`).
        #
        # If `solution_times` is specified as
        # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped
        # and `solution_time` in the middle loop is replaced by `final_time`.

        def assert_ops():
            """Creates a list of assert operations."""
            if not self._validate_args:
                return []
            assert_ops = []
            if previous_solver_internal_state is not None:
                assert_initial_state_matches_previous_solver_internal_state = (
                    tf.debugging.assert_near(
                        tf.norm(
                            initial_state_vec - previous_solver_internal_state.
                            backward_differences[0], np.inf),
                        0.,
                        message=
                        '`previous_solver_internal_state` does not match '
                        '`initial_state`.'))
                assert_ops.append(
                    assert_initial_state_matches_previous_solver_internal_state
                )
            if solution_times_chosen_by_solver:
                assert_ops.append(
                    util.assert_positive(final_time - initial_time,
                                         'final_time - initial_time'))
            else:
                assert_ops += [
                    util.assert_increasing(solution_times, 'solution_times'),
                    util.assert_nonnegative(
                        solution_times[0] - initial_time,
                        'solution_times[0] - initial_time'),
                ]
            if max_num_steps is not None:
                assert_ops.append(
                    util.assert_positive(max_num_steps, 'max_num_steps'))
            if max_num_newton_iters is not None:
                assert_ops.append(
                    util.assert_positive(max_num_newton_iters,
                                         'max_num_newton_iters'))
            assert_ops += [
                util.assert_positive(rtol, 'rtol'),
                util.assert_positive(atol, 'atol'),
                util.assert_positive(first_step_size, 'first_step_size'),
                util.assert_positive(safety_factor, 'safety_factor'),
                util.assert_positive(min_step_size_factor,
                                     'min_step_size_factor'),
                util.assert_positive(max_step_size_factor,
                                     'max_step_size_factor'),
                tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER),
                          [
                              '`max_order` must be between 1 and {}.'.format(
                                  bdf_util.MAX_ORDER)
                          ]),
                util.assert_positive(newton_tol_factor, 'newton_tol_factor'),
                util.assert_positive(newton_step_size_factor,
                                     'newton_step_size_factor'),
            ]
            return assert_ops

        def advance_to_solution_time(n, diagnostics, iterand,
                                     solver_internal_state, state_vec_array,
                                     time_array):
            """Takes multiple steps to advance time to `solution_times[n]`."""
            def step_cond(next_time, diagnostics, iterand, *_):
                return (iterand.time < next_time) & (tf.equal(
                    diagnostics.status, 0))

            nth_solution_time = solution_time_array.read(n)
            [
                _, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ] = tf.while_loop(step_cond, step, [
                nth_solution_time, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ])
            state_vec_array = state_vec_array.write(
                n, solver_internal_state.backward_differences[0])
            time_array = time_array.write(n, nth_solution_time)
            return (n + 1, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def step(next_time, diagnostics, iterand, solver_internal_state,
                 state_vec_array, time_array):
            """Takes a single step."""
            distance_to_next_time = next_time - iterand.time
            overstepped = iterand.new_step_size > distance_to_next_time
            iterand = iterand._replace(new_step_size=tf1.where(
                overstepped, distance_to_next_time, iterand.new_step_size),
                                       should_update_step_size=overstepped
                                       | iterand.should_update_step_size)

            if not self._evaluate_jacobian_lazily:
                diagnostics = diagnostics._replace(
                    num_jacobian_evaluations=diagnostics.
                    num_jacobian_evaluations + 1)
                iterand = iterand._replace(jacobian_mat=jacobian_fn_mat(
                    iterand.time,
                    solver_internal_state.backward_differences[0]),
                                           jacobian_is_up_to_date=True)

            def maybe_step_cond(accepted, diagnostics, *_):
                return tf.logical_not(accepted) & tf.equal(
                    diagnostics.status, 0)

            _, diagnostics, iterand, solver_internal_state = tf.while_loop(
                maybe_step_cond, maybe_step,
                [False, diagnostics, iterand, solver_internal_state])

            if solution_times_chosen_by_solver:
                state_vec_array = state_vec_array.write(
                    state_vec_array.size(),
                    solver_internal_state.backward_differences[0])
                time_array = time_array.write(time_array.size(), iterand.time)

            return (next_time, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
            """Takes a single step only if the outcome has a low enough error."""
            [
                num_jacobian_evaluations, num_matrix_factorizations,
                num_ode_fn_evaluations, status
            ] = diagnostics
            [
                jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper
            ] = iterand
            [backward_differences, order, step_size] = solver_internal_state

            if max_num_steps is not None:
                status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

            backward_differences = tf1.where(
                should_update_step_size,
                bdf_util.interpolate_backward_differences(
                    backward_differences, order, new_step_size / step_size),
                backward_differences)
            step_size = tf1.where(should_update_step_size, new_step_size,
                                  step_size)
            should_update_factorization = should_update_step_size
            num_steps_same_size = tf1.where(should_update_step_size, 0,
                                            num_steps_same_size)

            def update_factorization():
                return bdf_util.newton_qr(
                    jacobian_mat, newton_coefficients_array.read(order),
                    step_size)

            if self._evaluate_jacobian_lazily:

                def update_jacobian_and_factorization():
                    new_jacobian_mat = jacobian_fn_mat(time,
                                                       backward_differences[0])
                    new_unitary, new_upper = update_factorization()
                    return [
                        new_jacobian_mat, True, num_jacobian_evaluations + 1,
                        new_unitary, new_upper
                    ]

                def maybe_update_factorization():
                    new_unitary, new_upper = tf.cond(
                        should_update_factorization, update_factorization,
                        lambda: [unitary, upper])
                    return [
                        jacobian_mat, jacobian_is_up_to_date,
                        num_jacobian_evaluations, new_unitary, new_upper
                    ]

                [
                    jacobian_mat, jacobian_is_up_to_date,
                    num_jacobian_evaluations, unitary, upper
                ] = tf.cond(should_update_jacobian,
                            update_jacobian_and_factorization,
                            maybe_update_factorization)
            else:
                unitary, upper = update_factorization()
                num_matrix_factorizations += 1

            tol = atol + rtol * tf.abs(backward_differences[0])
            newton_tol = newton_tol_factor * tf.norm(tol)

            [
                newton_converged, next_backward_difference, next_state_vec,
                newton_num_iters
            ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                                newton_coefficients_array.read(order),
                                ode_fn_vec, order, step_size, time, newton_tol,
                                unitary, upper)
            num_steps += 1
            num_ode_fn_evaluations += newton_num_iters

            # If Newton's method failed and the Jacobian was up to date, decrease the
            # step size.
            newton_failed = tf.logical_not(newton_converged)
            should_update_step_size = newton_failed & jacobian_is_up_to_date
            new_step_size = step_size * tf1.where(should_update_step_size,
                                                  newton_step_size_factor, 1.)

            # If Newton's method failed and the Jacobian was NOT up to date, update
            # the Jacobian.
            should_update_jacobian = newton_failed & tf.logical_not(
                jacobian_is_up_to_date)

            error_ratio = tf1.where(
                newton_converged,
                bdf_util.error_ratio(next_backward_difference,
                                     error_coefficients_array.read(order),
                                     tol), np.nan)
            accepted = error_ratio < 1.
            converged_and_rejected = newton_converged & tf.logical_not(
                accepted)

            # If Newton's method converged but the solution was NOT accepted, decrease
            # the step size.
            new_step_size = tf1.where(
                converged_and_rejected,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = should_update_step_size | converged_and_rejected

            # If Newton's method converged and the solution was accepted, update the
            # matrix of backward differences.
            time = tf1.where(accepted, time + step_size, time)
            backward_differences = tf1.where(
                accepted,
                bdf_util.update_backward_differences(backward_differences,
                                                     next_backward_difference,
                                                     next_state_vec, order),
                backward_differences)
            jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
                accepted)
            num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                            num_steps_same_size)

            # Order and step size are only updated if we have taken strictly more than
            # order + 1 steps of the same size. This is to prevent the order from
            # being throttled.
            should_update_order_and_step_size = accepted & (num_steps_same_size
                                                            > order + 1)

            backward_differences_array = tf.TensorArray(
                backward_differences.dtype,
                size=bdf_util.MAX_ORDER + 3,
                clear_after_read=False,
                element_shape=next_backward_difference.get_shape()).unstack(
                    backward_differences)
            new_order = order
            new_error_ratio = error_ratio
            for offset in [-1, +1]:
                proposed_order = tf.clip_by_value(order + offset, 1, max_order)
                proposed_error_ratio = bdf_util.error_ratio(
                    backward_differences_array.read(proposed_order + 1),
                    error_coefficients_array.read(proposed_order), tol)
                proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
                new_order = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_order, new_order)
                new_error_ratio = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_error_ratio,
                    new_error_ratio)
            order = new_order
            error_ratio = new_error_ratio

            new_step_size = tf1.where(
                should_update_order_and_step_size,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = (should_update_step_size
                                       | should_update_order_and_step_size)

            diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                          num_matrix_factorizations,
                                          num_ode_fn_evaluations, status)
            iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date,
                                  new_step_size, num_steps,
                                  num_steps_same_size, should_update_jacobian,
                                  should_update_step_size, time, unitary,
                                  upper)
            solver_internal_state = _BDFSolverInternalState(
                backward_differences, order, step_size)
            return accepted, diagnostics, iterand, solver_internal_state

        # (1) Make static assertions.
        # TODO(b/138304296): Support specifying Jacobian sparsity patterns.
        if jacobian_sparsity is not None:
            raise NotImplementedError(
                'The BDF solver does not support specifying '
                'Jacobian sparsity patterns.')
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'The BDF solver does not support batching.')
        solution_times_chosen_by_solver = (isinstance(solution_times,
                                                      base.ChosenBySolver))

        with tf.name_scope(self._name):

            # (2) Convert to tensors.
            error_if_wrong_dtype = functools.partial(
                util.error_if_not_real_or_complex, identifier='initial_state')

            initial_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                  initial_state)
            tf.nest.map_structure(error_if_wrong_dtype, initial_state)

            state_shape = tf.nest.map_structure(tf.shape, initial_state)
            common_state_dtype = dtype_util.common_dtype(initial_state)
            real_dtype = dtype_util.real_dtype(common_state_dtype)

            if jacobian_fn is None and common_state_dtype.is_complex:
                raise NotImplementedError(
                    'The BDF solver does not support automatic '
                    'Jacobian computations for complex dtypes.')

            # Convert everything to operate on a single, concatenated vector form.
            initial_state_vec = util.get_state_vec(initial_state)
            ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape)
            jacobian_fn_mat = util.get_jacobian_fn_mat(
                jacobian_fn,
                ode_fn_vec,
                state_shape,
                use_pfor=self._use_pfor_to_compute_jacobian,
                dtype=common_state_dtype,
            )

            num_odes = tf.size(initial_state_vec)
            # Use tf.cast instead of tf.convert_to_tensor for differentiable
            # parameters because the tf.custom_gradient decorator converts raw floats
            # into tf.float32, which cannot be converted to tf.float64.
            initial_time = tf.cast(initial_time, real_dtype)
            num_solution_times = 0
            if solution_times_chosen_by_solver:
                final_time = tf.cast(solution_times.final_time, real_dtype)
            else:
                solution_times = tf.cast(solution_times, real_dtype)
                num_solution_times = tf.size(solution_times)
                solution_time_array = tf.TensorArray(
                    solution_times.dtype,
                    size=num_solution_times,
                    element_shape=[]).unstack(solution_times)
                util.error_if_not_vector(solution_times, 'solution_times')
            rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
            atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
            safety_factor = tf.convert_to_tensor(self._safety_factor,
                                                 dtype=real_dtype)
            min_step_size_factor = tf.convert_to_tensor(
                self._min_step_size_factor, dtype=real_dtype)
            max_step_size_factor = tf.convert_to_tensor(
                self._max_step_size_factor, dtype=real_dtype)
            max_num_steps = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
            max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32)
            max_num_newton_iters = self._max_num_newton_iters
            if max_num_newton_iters is not None:
                max_num_newton_iters = tf.convert_to_tensor(
                    max_num_newton_iters, dtype=tf.int32)
            newton_tol_factor = tf.convert_to_tensor(self._newton_tol_factor,
                                                     dtype=real_dtype)
            newton_step_size_factor = tf.convert_to_tensor(
                self._newton_step_size_factor, dtype=real_dtype)
            bdf_coefficients = tf.cast(
                tf.concat([[0.],
                           tf.convert_to_tensor(self._bdf_coefficients,
                                                dtype=real_dtype)], 0),
                common_state_dtype)
            util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients')
            if self._validate_args:
                initial_time = tf.ensure_shape(initial_time, [])
                if solution_times_chosen_by_solver:
                    final_time = tf.ensure_shape(final_time, [])
                safety_factor = tf.ensure_shape(safety_factor, [])
                min_step_size_factor = tf.ensure_shape(min_step_size_factor,
                                                       [])
                max_step_size_factor = tf.ensure_shape(max_step_size_factor,
                                                       [])
                if max_num_steps is not None:
                    max_num_steps = tf.ensure_shape(max_num_steps, [])
                max_order = tf.ensure_shape(max_order, [])
                if max_num_newton_iters is not None:
                    max_num_newton_iters = tf.ensure_shape(
                        max_num_newton_iters, [])
                newton_tol_factor = tf.ensure_shape(newton_tol_factor, [])
                newton_step_size_factor = tf.ensure_shape(
                    newton_step_size_factor, [])
                bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6])
            newton_coefficients = 1. / (
                (1. - bdf_coefficients) * bdf_util.RECIPROCAL_SUMS)
            newton_coefficients_array = tf.TensorArray(
                newton_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(newton_coefficients)
            error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / (
                bdf_util.ORDERS + 1)
            error_coefficients_array = tf.TensorArray(
                error_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(error_coefficients)
            first_step_size = self._first_step_size
            if first_step_size is None:
                first_step_size = bdf_util.first_step_size(
                    atol, error_coefficients_array.read(1), initial_state_vec,
                    initial_time, ode_fn_vec, rtol, safety_factor)
            elif previous_solver_internal_state is not None:
                tf.get_logger().warning(
                    '`first_step_size` is ignored since '
                    '`previous_solver_internal_state` was specified.')
            first_step_size = tf.convert_to_tensor(first_step_size,
                                                   dtype=real_dtype)
            if self._validate_args:
                first_step_size = tf.ensure_shape(first_step_size, [])
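            # Reuse the caller's solver state if provided; otherwise cold-start
            # at order 1, seeding the backward-difference table with the
            # initial state and its scaled first derivative.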
            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
                first_order_backward_difference = ode_fn_vec(
                    initial_time, initial_state_vec) * tf.cast(
                        first_step_size, common_state_dtype)
                backward_differences = tf.concat([
                    initial_state_vec[tf.newaxis, :],
                    first_order_backward_difference[tf.newaxis, :],
                    tf.zeros(tf.stack([bdf_util.MAX_ORDER + 1, num_odes]),
                             dtype=common_state_dtype),
                ], 0)
                solver_internal_state = _BDFSolverInternalState(
                    backward_differences=backward_differences,
                    order=1,
                    step_size=first_step_size)
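            # Accumulators for output states and times; dynamically sized when
            # the solver picks its own time grid.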
            state_vec_array = tf.TensorArray(
                common_state_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=initial_state_vec.get_shape())
            time_array = tf.TensorArray(
                real_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=tf.TensorShape([]))
            diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0,
                                          num_matrix_factorizations=0,
                                          num_ode_fn_evaluations=0,
                                          status=0)
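            # Mutable per-step loop state: the Jacobian, the (unitary, upper)
            # factors used by the Newton solves, and step-size bookkeeping.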
            iterand = _BDFIterand(
                jacobian_mat=tf.zeros([num_odes, num_odes],
                                      dtype=common_state_dtype),
                jacobian_is_up_to_date=False,
                new_step_size=solver_internal_state.step_size,
                num_steps=0,
                num_steps_same_size=0,
                should_update_jacobian=True,
                should_update_step_size=False,
                time=initial_time,
                unitary=tf.zeros([num_odes, num_odes],
                                 dtype=common_state_dtype),
                upper=tf.zeros([num_odes, num_odes], dtype=common_state_dtype))

            # (3) Make non-static assertions.
            with tf.control_dependencies(assert_ops()):

                # (4) Solve up to final time.
                if solution_times_chosen_by_solver:

                    def step_cond(next_time, diagnostics, iterand, *_):
                        return (iterand.time < next_time) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(step_cond, step, [
                        final_time, diagnostics, iterand,
                        solver_internal_state, state_vec_array, time_array
                    ])

                else:

                    def advance_to_solution_time_cond(n, diagnostics, *_):
                        return (n < num_solution_times) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(
                        advance_to_solution_time_cond,
                        advance_to_solution_time, [
                            0, diagnostics, iterand, solver_internal_state,
                            state_vec_array, time_array
                        ])

                # (5) Return `Results` object.
                states = util.get_state_from_vec(state_vec_array.stack(),
                                                 state_shape)
                times = time_array.stack()
                if not solution_times_chosen_by_solver:
                    times.set_shape(solution_times.get_shape())
                    tf.nest.map_structure(
                        lambda s, ini_s: s.set_shape(
                            solution_times.get_shape(  # pylint: disable=g-long-lambda
                            ).concatenate(ini_s.shape)),
                        states,
                        initial_state)
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
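For context, here is a minimal usage sketch of a BDF solve, assuming the public
`tfp.math.ode.BDF` interface; the stiff test problem is purely illustrative:

import tensorflow as tf
import tensorflow_probability as tfp

# Stiff linear system dy/dt = matrix @ y with widely separated eigenvalues.
matrix = tf.constant([[-1., 0.], [0., -1000.]], dtype=tf.float64)
ode_fn = lambda t, y: tf.linalg.matvec(matrix, y)

results = tfp.math.ode.BDF().solve(
    ode_fn,
    initial_time=0.,
    initial_state=tf.constant([1., 1.], dtype=tf.float64),
    solution_times=[0., 0.5, 1.])
# `results` is the `base.Results` namedtuple constructed above: `times`,
# `states`, `diagnostics`, and `solver_internal_state`.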

    def _solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
    ):
        # (1) Static assertions
        del jacobian_fn, jacobian_sparsity  # not used by DormandPrince
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'For homogeneous batching use `batch_ndims=0`.')
        solution_times_by_solver = isinstance(solution_times,
                                              base.ChosenBySolver)

        with tf.name_scope(self._name):
            # (2) Convert to tensors and determine dtypes.
            get_dtype = lambda x: x.dtype
            error_if_wrong_dtype = functools.partial(
                util.error_if_not_real_or_complex, identifier='initial_state')

            initial_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                  initial_state)
            tf.nest.map_structure(error_if_wrong_dtype, initial_state)

            state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
            common_state_dtype = dtype_util.common_dtype(initial_state)
            real_dtype = dtype_util.real_dtype(common_state_dtype)

            initial_time = tf.cast(initial_time, real_dtype)
            max_num_steps = self._max_num_steps
            max_ode_fn_evals = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
                max_ode_fn_evals = max_num_steps * self.ODE_FN_EVALS_PER_STEP
            step_size = tf.convert_to_tensor(self._first_step_size,
                                             dtype=real_dtype)
            rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
            atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
            safety = tf.convert_to_tensor(self._safety_factor,
                                          dtype=real_dtype)
            # Use i(d)factor notation for increasing and decreasing factors.
            ifactor, dfactor = self._max_step_size_factor, self._min_step_size_factor
            ifactor = tf.convert_to_tensor(ifactor, dtype=real_dtype)
            dfactor = tf.convert_to_tensor(dfactor, dtype=real_dtype)

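            # Reuse the previous Runge-Kutta state when supplied; otherwise
            # seed it with the initial state and one extra `ode_fn` evaluation.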
            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
                initial_derivative = ode_fn(initial_time, initial_state)
                initial_derivative = tf.nest.map_structure(
                    tf.convert_to_tensor, initial_derivative)
                solver_internal_state = _RungeKuttaSolverInternalState(
                    current_state=initial_state,
                    current_derivative=initial_derivative,
                    last_step_start=initial_time,
                    current_time=initial_time,
                    step_size=step_size,
                    interpolating_coefficients=[initial_state] * self.ORDER)

            num_solution_times = 0
            if solution_times_by_solver:
                final_time = tf.cast(solution_times.final_time, real_dtype)
                times_array = tf.TensorArray(real_dtype,
                                             size=num_solution_times,
                                             dynamic_size=True,
                                             element_shape=tf.TensorShape([]))
            else:
                solution_times = tf.cast(solution_times, real_dtype)
                util.error_if_not_vector(solution_times, 'solution_times')
                num_solution_times = tf.size(solution_times)
                times_array = tf.TensorArray(
                    real_dtype,
                    size=num_solution_times,
                    dynamic_size=False,
                    element_shape=[]).unstack(solution_times)

            solutions_arrays = [
                tf.TensorArray(dtype=component_dtype,
                               size=num_solution_times,
                               dynamic_size=solution_times_by_solver)
                for component_dtype in tf.nest.flatten(state_dtypes)
            ]
            solutions_arrays = tf.nest.pack_sequence_as(
                initial_state, solutions_arrays)

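            # Bind the fixed solver hyperparameters once so the loop bodies
            # below only thread the evolving state.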
            rk_step = functools.partial(self._step,
                                        max_ode_fn_evals=max_ode_fn_evals,
                                        ode_fn=ode_fn,
                                        atol=atol,
                                        rtol=rtol,
                                        safety=safety,
                                        ifactor=ifactor,
                                        dfactor=dfactor)
            advance_to_solution_time = functools.partial(
                _advance_to_solution_time,
                times_array=solution_times,
                step_fn=rk_step,
                validate_args=self._validate_args)

            assert_ops = self._assert_ops(
                ode_fn=ode_fn,
                initial_time=initial_time,
                initial_state=initial_state,
                solution_times=solution_times,
                previous_solver_state=previous_solver_internal_state,
                rtol=rtol,
                atol=atol,
                first_step_size=step_size,
                safety_factor=safety,
                min_step_size_factor=dfactor,
                max_step_size_factor=ifactor,
                max_num_steps=max_num_steps,
                solution_times_by_solver=solution_times_by_solver)
            with tf.control_dependencies(assert_ops):
                ode_evals_by_now = 1 if self._validate_args else 0
                # One extra `ode_fn` call was made above to compute the initial
                # derivative when no previous solver state was supplied.
                ode_evals_by_now += 1 if previous_solver_internal_state is None else 0
                diagnostics = _DopriDiagnostics(
                    num_ode_fn_evaluations=ode_evals_by_now,
                    num_jacobian_evaluations=0,
                    num_matrix_factorizations=0,
                    status=0)

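                # Either integrate out to `final_time`, letting the solver emit
                # its own grid, or advance through the user-supplied
                # `solution_times`.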
                if solution_times_by_solver:
                    r = _dense_solutions_to_final_time(
                        final_time=final_time,
                        solver_state=solver_internal_state,
                        diagnostics=diagnostics,
                        step_fn=rk_step,
                        ode_fn=ode_fn,
                        times_array=times_array,
                        solutions_arrays=solutions_arrays,
                        validate_args=self._validate_args)
                    solver_internal_state, diagnostics, times_array, solutions_arrays = r
                else:

                    def iterate_cond(time_id, *_):
                        return time_id < num_solution_times

                    [_, solver_internal_state, diagnostics, solutions_arrays
                     ] = tf.while_loop(iterate_cond,
                                       advance_to_solution_time, [
                                           0, solver_internal_state,
                                           diagnostics, solutions_arrays
                                       ],
                                       back_prop=False)

                times = times_array.stack()
                stack_components = lambda x: x.stack()
                states = tf.nest.map_structure(stack_components,
                                               solutions_arrays)
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
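Similarly, a minimal sketch of a Dormand-Prince solve, assuming the public
`tfp.math.ode.DormandPrince` interface, here letting the solver choose its own
time grid via `ChosenBySolver`:

import tensorflow as tf
import tensorflow_probability as tfp

# Non-stiff scalar problem: exponential decay dy/dt = -y.
ode_fn = lambda t, y: -y

results = tfp.math.ode.DormandPrince().solve(
    ode_fn,
    initial_time=0.,
    initial_state=tf.constant(1., dtype=tf.float64),
    solution_times=tfp.math.ode.ChosenBySolver(final_time=1.))
# With `ChosenBySolver`, `results.times` is the grid the solver itself emitted.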