Esempio n. 1
0
 def test_nest_constant(self, dtype):
     """`nest_constant` with no fill value yields a same-structure nest of ones."""
     np_dtype = dtype_util.as_numpy_dtype(dtype)
     vector_of_ones = np.ones(4, dtype=np_dtype)
     identity_matrix = np.eye(3, dtype=np_dtype)
     vector_of_zeros = np.zeros(4, dtype=np_dtype)
     nested_input = (vector_of_ones, (identity_matrix, vector_of_zeros))

     nested_output = rk_util.nest_constant(nested_input)

     # The output must mirror the input's nesting exactly ...
     tf.nest.assert_same_structure(nested_input, nested_output)
     # ... and every leaf must be all ones, regardless of the input values.
     for leaf in tf.nest.flatten(nested_output):
         expected = tf.ones(shape=leaf.shape)
         self.assertAllClose(leaf, expected)
Esempio n. 2
0
            def grad_fn(*dresults, **kwargs):
                """Adjoint sensitivity method to compute gradients.

                Integrates an augmented ODE system backwards in time, from the
                last entry of `results.times` down to `initial_time`, adding the
                upstream state gradients (`dresults.states`) at each result
                time. `results`, `initial_state`, `initial_time`, `ode_fn`,
                `input_state_structure`, `batch_ndims` and
                `_read_solution_components` come from the enclosing scope.

                Args:
                  *dresults: Flat upstream gradients; repacked below into the
                    structure of `results`.
                  **kwargs: May contain only `variables`, a list of variables
                    to accumulate gradients for (defaults to `[]`).

                Returns:
                  `(adjoint_state, adjoint_variables)`: the adjoint state after
                  integrating back to `initial_time`, and a list with one
                  accumulated gradient per entry of `variables`.

                Raises:
                  NotImplementedError: If any component of `initial_state` has
                    a complex dtype.
                """
                dresults = tf.nest.pack_sequence_as(results, dresults)
                dstates = dresults.states
                # The signature grad_fn(*dresults, variables=None) is not valid Python 2
                # so use kwargs instead.
                variables = kwargs.pop('variables', [])
                assert not kwargs  # This assert should never fail.
                # TODO(b/138304303): Support complex types.
                with tf.name_scope('{}Gradients'.format(self._name)):
                    get_dtype = lambda x: x.dtype

                    def error_if_complex(dtype):
                        """Raises NotImplementedError if `dtype` is complex."""
                        if dtype.is_complex:
                            raise NotImplementedError(
                                'The adjoint sensitivity method does '
                                'not support complex dtypes.')

                    # Python-level check: every state component must be real.
                    state_dtypes = tf.nest.map_structure(
                        get_dtype, initial_state)
                    tf.nest.map_structure(error_if_complex, state_dtypes)
                    common_state_dtype = dtype_util.common_dtype(initial_state)
                    real_dtype = dtype_util.real_dtype(common_state_dtype)

                    # We add initial_time to ensure that we know where to stop.
                    # Index 0 of `result_times` is `initial_time`; index k >= 1
                    # holds `results.times[k - 1]`.
                    result_times = tf.concat(
                        [[tf.cast(initial_time, real_dtype)], results.times],
                        0)
                    num_result_times = tf.size(result_times)

                    # First two components correspond to reverse and adjoint states.
                    # The last component is the adjoint state for variables.
                    # Each component is `nest_constant(..., 0.0)` of its
                    # counterpart, i.e. the integration starts from zeros.
                    terminal_augmented_state = tuple([
                        rk_util.nest_constant(initial_state, 0.0),
                        rk_util.nest_constant(initial_state, 0.0),
                        tuple(
                            rk_util.nest_constant(variable, 0.0)
                            for variable in variables)
                    ])

                    # The XLA compiler does not compile code which slices/indexes using
                    # integer `Tensor`s. `TensorArray`s are used to get around this.
                    # clear_after_read=False: index n - 1 is read as the upper
                    # bound in one loop iteration and as the lower bound
                    # (index n) in the next one.
                    result_time_array = tf.TensorArray(
                        results.times.dtype,
                        clear_after_read=False,
                        size=num_result_times,
                        element_shape=[]).unstack(result_times)

                    # TensorArray shape should not include time dimension, hence shape[1:]
                    # These arrays keep the default clear_after_read: the loop
                    # below reads each index exactly once, as `n - 1` counts
                    # down by one per iteration.
                    result_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(results.states)
                    ]
                    result_state_arrays = tf.nest.pack_sequence_as(
                        results.states, result_state_arrays)
                    # NOTE(review): `dstates` is repacked with the structure of
                    # `results.states`; this assumes both share a structure.
                    dresult_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(dstates)
                    ]
                    dresult_state_arrays = tf.nest.pack_sequence_as(
                        results.states, dresult_state_arrays)

                    def augmented_ode_fn(backward_time, augmented_state):
                        """Dynamics function for the augmented system.

            Describes a differential equation that evolves the augmented state
            backwards in time to compute gradients using the adjoint method.
            Augmented state consists of 3 components `(state, adjoint_state,
            vars)` all evaluated at time `backward_time`:

            state: represents the solution of user provided `ode_fn`. The
              structure coincides with the `initial_state`.
            adjoint_state: represents the solution of adjoint sensitivity
              differential equation as discussed below. Has the same structure
              and shape as `state`.
            vars: represent the solution of the adjoint equation for variable
              gradients. Represented as a `Tuple(Tensor, ...)` with as many
              tensors as there are `variables`.

            Adjoint sensitivity equation describes the gradient of the solution
            with respect to the value of the solution at previous time t. Its
            dynamics are given by
            d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
            Which is computed as:
            d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
            d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
            where in the last line we moved adj(t)_j under derivative by
            removing gradient from it.

            Adjoint equation for the gradient with respect to every
            `tf.Variable` theta follows:
            d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
            = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

            Args:
              backward_time: Floating `Tensor` representing current time,
                negated (`backward_time = -time`).
              augmented_state: `Tuple(state, adjoint_state, variable_grads)`

            Returns:
              negative_derivatives: Structure of `Tensor`s equal to backwards
                time derivative of the `state` component.
              adjoint_ode: Structure of `Tensor`s equal to backwards time
                derivative of the `adjoint_state` component.
              adjoint_variables_ode: Structure of `Tensor`s equal to backwards
                time derivative of the `vars` component.
            """
                        # The negative signs disappears after the change of variables.
                        # The ODE solver cannot handle the case initial_time > final_time
                        # and hence a change of variables backward_time = -time is used.
                        time = -backward_time
                        state, adjoint_state, _ = augmented_state

                        with tf.GradientTape() as tape:
                            tape.watch(variables)
                            tape.watch(state)
                            derivatives = ode_fn(time, state)
                            adjoint_no_grad = tf.nest.map_structure(
                                tf.stop_gradient, adjoint_state)
                            negative_derivatives = rk_util.weighted_sum(
                                [-1.0], [derivatives])

                            def dot_prod(tensor_a, tensor_b):
                                """Sum of the elementwise product of two tensors."""
                                return tf.reduce_sum(tensor_a * tensor_b)

                            # See docstring for details.
                            adjoint_dot_derivatives = tf.nest.map_structure(
                                dot_prod, adjoint_no_grad, derivatives)
                            adjoint_dot_derivatives = tf.squeeze(
                                tf.add_n(
                                    tf.nest.flatten(adjoint_dot_derivatives)))

                        # Unconnected inputs get zero gradients instead of None.
                        adjoint_ode, adjoint_variables_ode = tape.gradient(
                            adjoint_dot_derivatives, (state, tuple(variables)),
                            unconnected_gradients=tf.UnconnectedGradients.ZERO)
                        return negative_derivatives, adjoint_ode, adjoint_variables_ode

                    def reverse_to_result_time(n, augmented_state, _):
                        """Integrates the augmented system backwards in time.

                        Integrates from `result_times[n]` back to
                        `result_times[n - 1]` and returns
                        `(n - 1, new_augmented_state, solver_status)`. The
                        ignored third argument is the previous solver status.
                        """
                        lower_bound_of_integration = result_time_array.read(n)
                        upper_bound_of_integration = result_time_array.read(n -
                                                                            1)
                        _, adjoint_state, adjoint_variable_state = augmented_state
                        # Index n - 1 of the state arrays is the solution at
                        # `results.times[n - 1]`, which is `result_times[n]`
                        # (the lower bound) because `initial_time` is prepended.
                        initial_state = _read_solution_components(
                            result_state_arrays, input_state_structure, n - 1)
                        initial_adjoint = _read_solution_components(
                            dresult_state_arrays, input_state_structure, n - 1)
                        # Inject the upstream gradient for this result time
                        # into the running adjoint.
                        initial_adjoint_state = rk_util.weighted_sum(
                            [1.0, 1.0], [adjoint_state, initial_adjoint])
                        initial_augmented_state = (initial_state,
                                                   initial_adjoint_state,
                                                   adjoint_variable_state)
                        # TODO(b/138304303): Allow the user to specify the Hessian of
                        # `ode_fn` so that we can get the Jacobian of the adjoint system.
                        # TODO(b/143624114): Support higher order derivatives.
                        augmented_results = self._solve(
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            solution_times=[-upper_bound_of_integration],
                            batch_ndims=batch_ndims)
                        # Results added an extra time dim of size 1, squeeze it.
                        select_result = lambda x: tf.squeeze(x, [0])
                        result_state = augmented_results.states
                        result_state = tf.nest.map_structure(
                            select_result, result_state)
                        status = augmented_results.diagnostics.status
                        return n - 1, result_state, status

                    # Walk back from n = num_result_times - 1 to n = 0
                    # (`initial_time`); the loop also stops as soon as the
                    # solver reports a nonzero status.
                    # NOTE(review): the final status is not checked afterwards,
                    # so an early stop silently returns a partially integrated
                    # adjoint; confirm this is intended.
                    _, augmented_state, _ = tf.while_loop(
                        lambda n, _, status: (n >= 1) & tf.equal(status, 0),
                        reverse_to_result_time,
                        (num_result_times - 1, terminal_augmented_state, 0),
                        back_prop=False)
                    _, adjoint_state, adjoint_variables = augmented_state
                    return adjoint_state, list(adjoint_variables)
Esempio n. 3
0
        def vjp_bwd(results_constants, dresults, variables=()):
            """Adjoint sensitivity method to compute gradients.

            Integrates an augmented ODE system backwards in time, from the last
            entry of `results.times` down to `initial_time`, adding the
            upstream state gradients (`dresults.states`) at each result time.
            `initial_state`, `initial_time`, `ode_fn`, `input_state_structure`,
            `batch_ndims` and `_read_solution_components` come from the
            enclosing scope.

            Args:
              results_constants: Pair `(results, constants)`; `results`
                provides `times` and `states`.
              dresults: Upstream gradients with a `states` attribute, assumed
                to share the structure of `results.states` (it is repacked
                with that structure below).
              variables: Sequence of variables to accumulate gradients for.

            Returns:
              If `variables` is non-empty:
              `((adjoint_state, adjoint_constants), list(adjoint_variables))`.
              Otherwise: `(adjoint_state, adjoint_constants)`.

            Raises:
              NotImplementedError: If any component of `initial_state` has a
                complex dtype.
            """
            results, constants = results_constants
            adjoint_solver = self._make_adjoint_solver_fn()
            dstates = dresults.states
            # TODO(b/138304303): Support complex types.
            with tf.name_scope('{}Gradients'.format(self._name)):
                get_dtype = lambda x: x.dtype

                def error_if_complex(dtype):
                    """Raises NotImplementedError if `dtype` is complex."""
                    if dtype_util.is_complex(dtype):
                        raise NotImplementedError(
                            'The adjoint sensitivity method does '
                            'not support complex dtypes.')

                # Python-level check: every state component must be real.
                state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
                tf.nest.map_structure(error_if_complex, state_dtypes)
                common_state_dtype = dtype_util.common_dtype(initial_state)
                real_dtype = dtype_util.real_dtype(common_state_dtype)

                # We add initial_time to ensure that we know where to stop.
                # Index 0 of `result_times` is `initial_time`; index k >= 1
                # holds `results.times[k - 1]`.
                result_times = tf.concat(
                    [[tf.cast(initial_time, real_dtype)], results.times], 0)
                num_result_times = tf.size(result_times)

                # First two components correspond to reverse and adjoint states.
                # The last two components are the adjoint states for variables
                # and constants. Each is `nest_constant(..., 0.0)` of its
                # counterpart, i.e. the integration starts from zeros.
                terminal_augmented_state = tuple([
                    rk_util.nest_constant(initial_state, 0.0),
                    rk_util.nest_constant(initial_state, 0.0),
                    tuple(
                        rk_util.nest_constant(variable, 0.0)
                        for variable in variables),
                    rk_util.nest_constant(constants, 0.0),
                ])

                # The XLA compiler does not compile code which slices/indexes using
                # integer `Tensor`s. `TensorArray`s are used to get around this.
                result_time_array = tf.TensorArray(
                    results.times.dtype,
                    clear_after_read=False,
                    size=num_result_times,
                    element_shape=[]).unstack(result_times)

                # TensorArray shape should not include time dimension, hence shape[1:]
                # clear_after_read=False is required: index `initial_n - 1` is
                # read twice, once by `make_augmented_state` when initializing
                # the solver internal state below and again in the first loop
                # iteration.
                result_state_arrays = [
                    tf.TensorArray(  # pylint: disable=g-complex-comprehension
                        dtype=component.dtype,
                        size=num_result_times - 1,
                        clear_after_read=False,
                        element_shape=component.shape[1:]).unstack(component)
                    for component in tf.nest.flatten(results.states)
                ]
                result_state_arrays = tf.nest.pack_sequence_as(
                    results.states, result_state_arrays)
                dresult_state_arrays = [
                    tf.TensorArray(  # pylint: disable=g-complex-comprehension
                        dtype=component.dtype,
                        size=num_result_times - 1,
                        clear_after_read=False,
                        element_shape=component.shape[1:]).unstack(component)
                    for component in tf.nest.flatten(dstates)
                ]
                dresult_state_arrays = tf.nest.pack_sequence_as(
                    results.states, dresult_state_arrays)

                def augmented_ode_fn(backward_time, augmented_state):
                    """Dynamics function for the augmented system.

          Describes a differential equation that evolves the augmented state
          backwards in time to compute gradients using the adjoint method.
          Augmented state consists of 4 components `(state, adjoint_state,
          vars, constants)` all evaluated at time `backward_time`:

          state: represents the solution of user provided `ode_fn`. The
            structure coincides with the `initial_state`.
          adjoint_state: represents the solution of the adjoint sensitivity
            differential equation as discussed below. Has the same structure
            and shape as `state`.
          vars: represent the solution of the adjoint equation for
            variable gradients. Represented as a `Tuple(Tensor, ...)` with as
            many tensors as there are `variables` variable outside this
            function.
          constants: represent the solution of the adjoint equation for
            constant gradients. Has the same structure and shape as
            `constants` variable outside this function.

          The adjoint sensitivity equation describes the gradient of the
          solution with respect to the value of the solution at a previous
          time t. Its dynamics are given by
          d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
          Which is computed as:
          d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
          d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
          where in the last line we moved adj(t)_j under derivative by
          removing gradient from it.

          Adjoint equation for the gradient with respect to every
          `tf.Variable` and constant theta follows:
          d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
          = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

          Args:
            backward_time: Floating `Tensor` representing current time,
              negated (`backward_time = -time`).
            augmented_state: `Tuple(state, adjoint_state, variable_grads,
              constant_grads)`

          Returns:
            negative_derivatives: Structure of `Tensor`s equal to backwards
              time derivative of the `state` component.
            adjoint_ode: Structure of `Tensor`s equal to backwards time
              derivative of the `adjoint_state` component.
            adjoint_variables_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `vars` component.
            adjoint_constants_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `constants` component.
          """
                    # The negative signs disappears after the change of variables.
                    # The ODE solver cannot handle the case initial_time > final_time
                    # and hence a change of variables backward_time = -time is used.
                    time = -backward_time
                    state, adjoint_state, _, _ = augmented_state

                    # TODO(b/152464477): Doesn't work reliably in TF1.
                    def grad_fn(state, variables, constants):
                        """Returns `(adjoint . derivatives, -derivatives)`.

                        The first element is the scalar that gets
                        differentiated; the second is auxiliary output
                        (`has_aux=True` below).
                        """
                        del variables  # We compute these gradients via the GradientTape
                        # capturing them.
                        derivatives = ode_fn(time, state, **constants)
                        adjoint_no_grad = tf.nest.map_structure(
                            tf.stop_gradient, adjoint_state)
                        negative_derivatives = rk_util.weighted_sum(
                            [-1.0], [derivatives])

                        def dot_prod(tensor_a, tensor_b):
                            """Sum of the elementwise product of two tensors."""
                            return tf.reduce_sum(tensor_a * tensor_b)

                        # See docstring for details.
                        adjoint_dot_derivatives = tf.nest.map_structure(
                            dot_prod, adjoint_no_grad, derivatives)
                        adjoint_dot_derivatives = tf.squeeze(
                            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
                        return adjoint_dot_derivatives, negative_derivatives

                    values = (state, tuple(variables), constants)
                    ((_, negative_derivatives),
                     gradients) = tfp_gradient.value_and_gradient(
                         grad_fn, values, has_aux=True, use_gradient_tape=True)

                    # Any gradient that comes back as None is replaced by zeros
                    # shaped like the corresponding value.
                    (adjoint_ode, adjoint_variables_ode,
                     adjoint_constants_ode) = tf.nest.map_structure(
                         lambda v, g: tf.zeros_like(v)
                         if g is None else g, values, tuple(gradients))
                    return (negative_derivatives, adjoint_ode,
                            adjoint_variables_ode, adjoint_constants_ode)

                def make_augmented_state(n, prev_augmented_state):
                    """Constructs the augmented state for step `n`.

                    The state component becomes the stored solution at index
                    `n - 1`. The stored upstream gradient at index `n - 1` is
                    added to the adjoint state via
                    `weighted_sum([1.0, 1.0], ...)`. The variable and constant
                    adjoints are carried over unchanged.
                    """
                    (_, adjoint_state, adjoint_variable_state,
                     adjoint_constant_state) = prev_augmented_state
                    initial_state = _read_solution_components(
                        result_state_arrays,
                        input_state_structure,
                        n - 1,
                    )
                    initial_adjoint = _read_solution_components(
                        dresult_state_arrays,
                        input_state_structure,
                        n - 1,
                    )
                    initial_adjoint_state = rk_util.weighted_sum(
                        [1.0, 1.0], [adjoint_state, initial_adjoint])
                    augmented_state = (
                        initial_state,
                        initial_adjoint_state,
                        adjoint_variable_state,
                        adjoint_constant_state,
                    )
                    return augmented_state

                def reverse_to_result_time(n, augmented_state,
                                           solver_internal_state, _):
                    """Integrates the augmented system backwards in time.

                    Integrates from `result_times[n]` back to
                    `result_times[n - 1]`, reusing and returning the adjoint
                    solver's internal state. Returns
                    `(n - 1, new_augmented_state, new_solver_internal_state,
                    solver_status)`. The ignored last argument is the
                    previous solver status.
                    """
                    lower_bound_of_integration = result_time_array.read(n)
                    upper_bound_of_integration = result_time_array.read(n - 1)
                    initial_augmented_state = make_augmented_state(
                        n, augmented_state)
                    # TODO(b/138304303): Allow the user to specify the Hessian of
                    # `ode_fn` so that we can get the Jacobian of the adjoint system.
                    # TODO(b/143624114): Support higher order derivatives.
                    # The augmented state jumps (new state values and the
                    # injected upstream gradient), so the carried solver state is
                    # adjusted before solving.
                    solver_internal_state = (
                        adjoint_solver.
                        _adjust_solver_internal_state_for_state_jump(  # pylint: disable=protected-access
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            previous_solver_internal_state=
                            solver_internal_state,
                            previous_state=augmented_state,
                        ))
                    augmented_results = adjoint_solver.solve(
                        ode_fn=augmented_ode_fn,
                        initial_time=-lower_bound_of_integration,
                        initial_state=initial_augmented_state,
                        solution_times=[-upper_bound_of_integration],
                        batch_ndims=batch_ndims,
                        previous_solver_internal_state=solver_internal_state,
                    )
                    # Results added an extra time dim of size 1, squeeze it.
                    select_result = lambda x: tf.squeeze(x, [0])
                    result_state = augmented_results.states
                    result_state = tf.nest.map_structure(
                        select_result, result_state)
                    status = augmented_results.diagnostics.status
                    return (n - 1, result_state,
                            augmented_results.solver_internal_state, status)

                initial_n = num_result_times - 1
                # NOTE(review): every `adjoint_solver.solve` call above uses
                # negated (backward) time, `-result_time_array.read(n)`, but this
                # initialization passes the un-negated time. If the initial
                # internal state depends on the time value, this may be
                # inconsistent. Confirm whether
                # `-result_time_array.read(initial_n)` was intended.
                solver_internal_state = adjoint_solver._initialize_solver_internal_state(  # pylint: disable=protected-access
                    ode_fn=augmented_ode_fn,
                    initial_time=result_time_array.read(initial_n),
                    initial_state=make_augmented_state(
                        initial_n, terminal_augmented_state),
                )

                # Walk back from n = initial_n to n = 0 (`initial_time`); the
                # loop also stops as soon as the solver reports a nonzero
                # status.
                # NOTE(review): the final status is not checked afterwards, so
                # an early stop silently returns a partially integrated adjoint.
                _, augmented_state, _, _ = tf.while_loop(
                    lambda n, _as, _sis, status:
                    (n >= 1) & tf.equal(status, 0),
                    reverse_to_result_time,
                    (initial_n, terminal_augmented_state,
                     solver_internal_state, 0),
                    back_prop=False,
                )
                (_, adjoint_state, adjoint_variables,
                 adjoint_constants) = augmented_state

                # Variable gradients are returned separately, and only when
                # there are variables.
                if variables:
                    return (adjoint_state,
                            adjoint_constants), list(adjoint_variables)
                else:
                    return adjoint_state, adjoint_constants
Esempio n. 4
0
    def _step(self, solver_state, diagnostics, max_ode_fn_evals, ode_fn, atol,
              rtol, safety, ifactor, dfactor):
        """Take an adaptive Runge-Kutta step.

        Args:
          solver_state: `_RungeKuttaSolverInternalState` - solver internal
            state.
          diagnostics: `_DopriDiagnostics` - info on the current `_solve` call.
          max_ode_fn_evals: Integer `Tensor` specifying the maximum number of
            ode_fn evaluations.
          ode_fn: Callable(t, y) -> dy_dt.
          atol: Absolute tolerance allowed, see `_solve` method for details.
          rtol: Relative tolerance allowed, see `_solve` method for details.
          safety: Safety factor, see `_solve` method for details.
          ifactor: Maximum factor by which the step size can increase, see
            `_solve` method for details.
          dfactor: Minimum factor by which the step size can decrease, see
            `_solve` method for details.

        Returns:
          solver_state: `_RungeKuttaSolverInternalState` holding new solver
            state. Note that the step might not advance the time if the error
            tolerance criteria were not met. In this case step_size is
            decreased.
          diagnostics: `_DopriDiagnostics` holding diagnostic values after the
            RK step.
        """
        y0, f0, _, t0, dt, interp_coeff = solver_state
        assertion_ops = []
        # TODO(dkochkov) Profile performance impact of `control_dependencies` here.
        with tf.name_scope('assertions'):
            if self._max_num_steps is not None:
                check_max_num_steps = tf.debugging.assert_less(
                    diagnostics.num_ode_fn_evaluations, max_ode_fn_evals,
                    'max_num_steps exceeded')
                assertion_ops.append(check_max_num_steps)

            if self._validate_args:
                check_underflow = tf.debugging.assert_greater(
                    t0 + dt, t0, 'underflow in dt')
                assert_finite = functools.partial(
                    tf.debugging.assert_all_finite,
                    message='non-finite values in solution')
                check_numerics = tf.nest.map_structure(assert_finite, y0)
                assertion_ops.append(check_underflow)
                # `y0` may be a nested structure, so `check_numerics` may be
                # one too. `tf.control_dependencies` needs a flat list of
                # tensors/ops, so add each check individually.
                assertion_ops.extend(tf.nest.flatten(check_numerics))

        with tf.control_dependencies(assertion_ops):
            y1, f1, y1_error, k = rk_util.runge_kutta_step(
                ode_fn, y0, f0, t0, dt, _TABLEAU)

        with tf.name_scope('error_ratio'):
            # We use the same criteria for accepting step as in scipy:
            # tolerance = atol * 1 + rtol * max(|y0|, |y1|), per component.
            abs_y0 = tf.nest.map_structure(tf.abs, y0)
            abs_y1 = tf.nest.map_structure(tf.abs, y1)
            max_y_vals = tf.nest.map_structure(tf.math.maximum, abs_y0, abs_y1)
            ones_nest = rk_util.nest_constant(abs_y0)
            error_tol = rk_util.weighted_sum([atol, rtol],
                                             [ones_nest, max_y_vals])

            def scale_errors(error_component, tol_scale):
                """Squared error magnitude divided by squared tolerance."""
                abs_square_error = rk_util.abs_square(error_component)
                abs_square_tol_scale = rk_util.abs_square(tol_scale)
                return tf.divide(abs_square_error, abs_square_tol_scale)

            scaled_errors = tf.nest.map_structure(scale_errors, y1_error,
                                                  error_tol)
            error_ratio = rk_util.nest_rms_norm(scaled_errors)
            accept_step = error_ratio <= 1

        with tf.name_scope('update/state'):
            # A rejected step keeps the previous state, derivative and time.
            y_next = rk_util.nest_where(accept_step, y1, y0)
            f_next = rk_util.nest_where(accept_step, f1, f0)
            t_next = tf.where(accept_step, t0 + dt, t0)

            new_coefficients = rk_util.rk_fourth_order_interpolation_coefficients(
                y0, y1, k, dt, _TABLEAU)
            interp_coeff = rk_util.nest_where(accept_step, new_coefficients,
                                              interp_coeff)

            dt_next = util.next_step_size(dt, self.ORDER, error_ratio, safety,
                                          dfactor, ifactor)
            # The third field receives `t0`, the start time of the step just
            # attempted.
            solver_state = _RungeKuttaSolverInternalState(
                y_next, f_next, t0, t_next, dt_next, interp_coeff)

        diagnostics = diagnostics._replace(
            num_ode_fn_evaluations=diagnostics.num_ode_fn_evaluations +
            self.ODE_FN_EVALS_PER_STEP)

        return solver_state, diagnostics