def _prepare_common_params(self, ode_fn, initial_state, initial_time):
  error_if_wrong_dtype = functools.partial(
      util.error_if_not_real_or_complex, identifier='initial_state')

  initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
  tf.nest.map_structure(error_if_wrong_dtype, initial_state)

  state_shape = tf.nest.map_structure(ps.shape, initial_state)
  common_state_dtype = dtype_util.common_dtype(initial_state)
  real_dtype = dtype_util.real_dtype(common_state_dtype)

  # Use tf.cast instead of tf.convert_to_tensor for differentiable
  # parameters because the tf.custom_gradient decorator converts raw floats
  # into tf.float32, which cannot be converted to tf.float64.
  initial_time = tf.cast(initial_time, real_dtype)
  if self._validate_args:
    initial_time = tf.ensure_shape(initial_time, [])

  rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
  atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
  safety_factor = tf.convert_to_tensor(self._safety_factor, dtype=real_dtype)
  if self._validate_args:
    safety_factor = tf.ensure_shape(safety_factor, [])

  # Convert everything to operate on a single, concatenated vector form.
  initial_state_vec = util.get_state_vec(initial_state)
  ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape)
  num_odes = tf.size(initial_state_vec)

  return util.Bunch(
      initial_state=initial_state,
      initial_time=initial_time,
      common_state_dtype=common_state_dtype,
      real_dtype=real_dtype,
      rtol=rtol,
      atol=atol,
      safety_factor=safety_factor,
      state_shape=state_shape,
      initial_state_vec=initial_state_vec,
      ode_fn_vec=ode_fn_vec,
      num_odes=num_odes,
  )
def _prepare_common_params(self, initial_state, initial_time):
  get_dtype = lambda x: x.dtype
  error_if_wrong_dtype = functools.partial(
      util.error_if_not_real_or_complex, identifier='initial_state')

  initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
  tf.nest.map_structure(error_if_wrong_dtype, initial_state)

  state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
  common_state_dtype = dtype_util.common_dtype(initial_state)
  real_dtype = dtype_util.real_dtype(common_state_dtype)
  initial_time = tf.cast(initial_time, real_dtype)

  return util.Bunch(
      initial_state=initial_state,
      state_dtypes=state_dtypes,
      real_dtype=real_dtype,
      initial_time=initial_time,
  )
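# A minimal, self-contained sketch (not part of the library) of what the
# `_prepare_common_params` helpers above do: convert a (possibly nested)
# `initial_state` to tensors, determine a common dtype, and cast
# `initial_time` to the matching *real* dtype so that time stays real even
# for complex-valued states. The helper name below is hypothetical and uses
# only public TensorFlow APIs in place of TFP's internal `dtype_util` and
# `util.Bunch`.
def _example_prepare_params(initial_state, initial_time):
  import tensorflow as tf

  initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
  flat_state = tf.nest.flatten(initial_state)
  # Assume one shared dtype across components for the purposes of the sketch.
  common_state_dtype = flat_state[0].dtype
  # `tf.DType.real_dtype` maps e.g. complex64 -> float32, float64 -> float64.
  real_dtype = common_state_dtype.real_dtype
  initial_time = tf.cast(initial_time, real_dtype)
  return dict(
      initial_state=initial_state,
      common_state_dtype=common_state_dtype,
      real_dtype=real_dtype,
      initial_time=initial_time)


# For example, a complex64 state keeps the time variable in float32:
#   _example_prepare_params(
#       {'z': tf.constant([1. + 2.j], dtype=tf.complex64)},
#       initial_time=0.)['real_dtype']  # ==> tf.float32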
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
  users often set `max_lags` small enough so that the entire output is
  meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide
  by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto
  correlation contains a slight bias, which goes to zero as
  `len(x) - m --> infinity`.

  Args:
    x: `float32` or `complex64` `Tensor`.
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
  # RXX[m].
  #
  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = ps.rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = distribution_util.rotate_transpose(x, shift)

    if center:
      x_rotated = x_rotated - tf.reduce_mean(
          x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = ps.shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = ps.cast(x_len, np.float64)
    target_length = ps.pow(
        np.float64(2.), ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = ps.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = distribution_util.pad(
        x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype_util.is_complex(dtype):
      if not dtype_util.is_floating(dtype):
        raise TypeError('Argument x must have either float or complex dtype'
                        ' found: {}'.format(dtype))
      x_rotated_pad = tf.complex(
          x_rotated_pad,
          dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension at the
    # far right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not tensorshape_util.is_fully_defined(x_rotated.shape):
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = tensorshape_util.as_list(x_rotated.shape)
      chopped_shape[-1] = min(x_len, max_lags + 1)
      tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
    max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
    denominator = x_len - ps.range(0., max_lags + 1.)
    denominator = ps.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return distribution_util.rotate_transpose(shifted_product_rotated, -shift)
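# A direct O(L * max_lags) NumPy reference for the estimator documented above,
#   rxx[m] = (L - m)**-1 sum_n w[n + m] Conj(w[n]),
# useful for sanity-checking the FFT-based implementation on small inputs.
# This helper is purely illustrative and is not part of the library; it only
# handles the last axis.
def _naive_auto_correlation(x, max_lags=None, center=True, normalize=True):
  import numpy as np

  x = np.asarray(x)
  length = x.shape[-1]
  w = x - x.mean(axis=-1, keepdims=True) if center else x
  if normalize:
    # Divide by s, where s**2 is the (biased) variance estimate, so that
    # rxx[0] == 1.
    w = w / np.sqrt(np.mean(np.abs(w)**2, axis=-1, keepdims=True))
  max_lags = length - 1 if max_lags is None else min(max_lags, length - 1)
  rxx = [
      np.sum(w[..., m:] * np.conj(w[..., :length - m]), axis=-1) / (length - m)
      for m in range(max_lags + 1)
  ]
  return np.stack(rxx, axis=-1)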
def averaged_sum_squares(input_tensor):
  num_elements_cast = tf.cast(
      num_elements, dtype=dtype_util.real_dtype(input_tensor.dtype))
  return tf.reduce_sum(abs_square(input_tensor)) / num_elements_cast
def grad_fn(*dresults, **kwargs): """Adjoint sensitivity method to compute gradients.""" dresults = tf.nest.pack_sequence_as(results, dresults) dstates = dresults.states # The signature grad_fn(*dresults, variables=None) is not valid Python 2 # so use kwargs instead. variables = kwargs.pop('variables', []) assert not kwargs # This assert should never fail. # TODO(b/138304303): Support complex types. with tf.name_scope('{}Gradients'.format(self._name)): get_dtype = lambda x: x.dtype def error_if_complex(dtype): if dtype.is_complex: raise NotImplementedError( 'The adjoint sensitivity method does ' 'not support complex dtypes.') state_dtypes = tf.nest.map_structure( get_dtype, initial_state) tf.nest.map_structure(error_if_complex, state_dtypes) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) # We add initial_time to ensure that we know where to stop. result_times = tf.concat( [[tf.cast(initial_time, real_dtype)], results.times], 0) num_result_times = tf.size(result_times) # First two components correspond to reverse and adjoint states. # the last component is adjoint state for variables. terminal_augmented_state = tuple([ rk_util.nest_constant(initial_state, 0.0), rk_util.nest_constant(initial_state, 0.0), tuple( rk_util.nest_constant(variable, 0.0) for variable in variables) ]) # The XLA compiler does not compile code which slices/indexes using # integer `Tensor`s. `TensorArray`s are used to get around this. result_time_array = tf.TensorArray( results.times.dtype, clear_after_read=False, size=num_result_times, element_shape=[]).unstack(result_times) # TensorArray shape should not include time dimension, hence shape[1:] result_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, element_shape=component.shape[1:]).unstack( component) for component in tf.nest.flatten(results.states) ] result_state_arrays = tf.nest.pack_sequence_as( results.states, result_state_arrays) dresult_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, element_shape=component.shape[1:]).unstack( component) for component in tf.nest.flatten(dstates) ] dresult_state_arrays = tf.nest.pack_sequence_as( results.states, dresult_state_arrays) def augmented_ode_fn(backward_time, augmented_state): """Dynamics function for the augmented system. Describes a differential equation that evolves the augmented state backwards in time to compute gradients using the adjoint method. Augmented state consists of 3 components `(state, adjoint_state, vars)` all evaluated at time `backward_time`: state: represents the solution of user provided `ode_fn`. The structure coincides with the `initial_state`. adjoint_state: represents the solution of adjoint sensitivity differential equation as discussed below. Has the same structure and shape as `state`. vars: represent the solution of the adjoint equation for variable gradients. Represented as a `Tuple(Tensor, ...)` with as many tensors as there are `variables`. Adjoint sensitivity equation describes the gradient of the solution with respect to the value of the solution at previous time t. 
      Its dynamics are given by
      d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
      Which is computed as:
      d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
      d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
      where in the last line we moved adj(t)_j under the derivative by
      removing its gradient.

      The adjoint equation for the gradient with respect to every
      `tf.Variable` theta follows:
      d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
      = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

      Args:
        backward_time: Floating `Tensor` representing current time.
        augmented_state: `Tuple(state, adjoint_state, variable_grads)`

      Returns:
        negative_derivatives: Structure of `Tensor`s equal to backwards time
          derivative of the `state` component.
        adjoint_ode: Structure of `Tensor`s equal to backwards time
          derivative of the `adjoint_state` component.
        adjoint_variables_ode: Structure of `Tensor`s equal to backwards time
          derivative of the `vars` component.
      """
      # The negative sign disappears after the change of variables.
      # The ODE solver cannot handle the case initial_time > final_time
      # and hence a change of variables backward_time = -time is used.
      time = -backward_time
      state, adjoint_state, _ = augmented_state

      with tf.GradientTape() as tape:
        tape.watch(variables)
        tape.watch(state)
        derivatives = ode_fn(time, state)
        adjoint_no_grad = tf.nest.map_structure(
            tf.stop_gradient, adjoint_state)
        negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])

        def dot_prod(tensor_a, tensor_b):
          return tf.reduce_sum(tensor_a * tensor_b)

        # See docstring for details.
        adjoint_dot_derivatives = tf.nest.map_structure(
            dot_prod, adjoint_no_grad, derivatives)
        adjoint_dot_derivatives = tf.squeeze(
            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))

      adjoint_ode, adjoint_variables_ode = tape.gradient(
          adjoint_dot_derivatives, (state, tuple(variables)),
          unconnected_gradients=tf.UnconnectedGradients.ZERO)

      return negative_derivatives, adjoint_ode, adjoint_variables_ode

    def reverse_to_result_time(n, augmented_state, _):
      """Integrates the augmented system backwards in time."""
      lower_bound_of_integration = result_time_array.read(n)
      upper_bound_of_integration = result_time_array.read(n - 1)
      _, adjoint_state, adjoint_variable_state = augmented_state
      initial_state = _read_solution_components(
          result_state_arrays, input_state_structure, n - 1)
      initial_adjoint = _read_solution_components(
          dresult_state_arrays, input_state_structure, n - 1)
      initial_adjoint_state = rk_util.weighted_sum(
          [1.0, 1.0], [adjoint_state, initial_adjoint])
      initial_augmented_state = (initial_state, initial_adjoint_state,
                                 adjoint_variable_state)
      # TODO(b/138304303): Allow the user to specify the Hessian of
      # `ode_fn` so that we can get the Jacobian of the adjoint system.
      # TODO(b/143624114): Support higher order derivatives.
      augmented_results = self._solve(
          ode_fn=augmented_ode_fn,
          initial_time=-lower_bound_of_integration,
          initial_state=initial_augmented_state,
          solution_times=[-upper_bound_of_integration],
          batch_ndims=batch_ndims)
      # Results added an extra time dim of size 1, squeeze it.
      select_result = lambda x: tf.squeeze(x, [0])
      result_state = augmented_results.states
      result_state = tf.nest.map_structure(select_result, result_state)
      status = augmented_results.diagnostics.status
      return n - 1, result_state, status

    _, augmented_state, _ = tf.while_loop(
        lambda n, _, status: (n >= 1) & tf.equal(status, 0),
        reverse_to_result_time,
        (num_result_times - 1, terminal_augmented_state, 0),
        back_prop=False)
    _, adjoint_state, adjoint_variables = augmented_state
    return adjoint_state, list(adjoint_variables)
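# Illustrative check (not part of the library) of the adjoint dynamics in the
# docstring above. For a linear ODE dz/dt = A @ z the Jacobian of `ode_fn`
# with respect to `z` is just `A`, so d/dt[adj] = -adj @ A. The hypothetical
# helper below evaluates that right-hand side both directly and via the
# stop-gradient trick used in `augmented_ode_fn`; the two results agree.
def _example_adjoint_rhs():
  import tensorflow as tf

  a_mat = tf.constant([[0., 1.], [-2., -3.]])
  ode_fn = lambda t, z: tf.linalg.matvec(a_mat, z)

  z = tf.constant([1., 0.5])
  adj = tf.constant([0.3, -0.7])

  # Direct form: -adj @ jacobian(ode_fn, z) = -(A^T @ adj).
  direct = -tf.linalg.matvec(a_mat, adj, transpose_a=True)

  # Stop-gradient form: -d/dz[ sum_j stop_gradient(adj)_j * ode_fn(t, z)_j ].
  with tf.GradientTape() as tape:
    tape.watch(z)
    dot = tf.reduce_sum(tf.stop_gradient(adj) * ode_fn(0., z))
  via_tape = -tape.gradient(dot, z)

  return direct, via_tape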
def vjp_bwd(results_constants, dresults, variables=()): """Adjoint sensitivity method to compute gradients.""" results, constants = results_constants adjoint_solver = self._make_adjoint_solver_fn() dstates = dresults.states # TODO(b/138304303): Support complex types. with tf.name_scope('{}Gradients'.format(self._name)): get_dtype = lambda x: x.dtype def error_if_complex(dtype): if dtype_util.is_complex(dtype): raise NotImplementedError( 'The adjoint sensitivity method does ' 'not support complex dtypes.') state_dtypes = tf.nest.map_structure(get_dtype, initial_state) tf.nest.map_structure(error_if_complex, state_dtypes) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) # We add initial_time to ensure that we know where to stop. result_times = tf.concat( [[tf.cast(initial_time, real_dtype)], results.times], 0) num_result_times = tf.size(result_times) # First two components correspond to reverse and adjoint states. # the last two component is adjoint state for variables and constants. terminal_augmented_state = tuple([ rk_util.nest_constant(initial_state, 0.0), rk_util.nest_constant(initial_state, 0.0), tuple( rk_util.nest_constant(variable, 0.0) for variable in variables), rk_util.nest_constant(constants, 0.0), ]) # The XLA compiler does not compile code which slices/indexes using # integer `Tensor`s. `TensorArray`s are used to get around this. result_time_array = tf.TensorArray( results.times.dtype, clear_after_read=False, size=num_result_times, element_shape=[]).unstack(result_times) # TensorArray shape should not include time dimension, hence shape[1:] result_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, clear_after_read=False, element_shape=component.shape[1:]).unstack(component) for component in tf.nest.flatten(results.states) ] result_state_arrays = tf.nest.pack_sequence_as( results.states, result_state_arrays) dresult_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, clear_after_read=False, element_shape=component.shape[1:]).unstack(component) for component in tf.nest.flatten(dstates) ] dresult_state_arrays = tf.nest.pack_sequence_as( results.states, dresult_state_arrays) def augmented_ode_fn(backward_time, augmented_state): """Dynamics function for the augmented system. Describes a differential equation that evolves the augmented state backwards in time to compute gradients using the adjoint method. Augmented state consists of 4 components `(state, adjoint_state, vars, constants)` all evaluated at time `backward_time`: state: represents the solution of user provided `ode_fn`. The structure coincides with the `initial_state`. adjoint_state: represents the solution of the adjoint sensitivity differential equation as discussed below. Has the same structure and shape as `state`. variables: represent the solution of the adjoint equation for variable gradients. Represented as a `Tuple(Tensor, ...)` with as many tensors as there are `variables` variable outside this function. constants: represent the solution of the adjoint equation for constant gradients. Has the same structure and shape as `constants` variable outside this function. The adjoint sensitivity equation describes the gradient of the solution with respect to the value of the solution at a previous time t. 
      Its dynamics are given by
      d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
      Which is computed as:
      d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
      d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
      where in the last line we moved adj(t)_j under the derivative by
      removing its gradient.

      The adjoint equation for the gradient with respect to every
      `tf.Variable` and constant theta follows:
      d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
      = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

      Args:
        backward_time: Floating `Tensor` representing current time.
        augmented_state: `Tuple(state, adjoint_state, variable_grads,
          constant_grads)`

      Returns:
        negative_derivatives: Structure of `Tensor`s equal to backwards time
          derivative of the `state` component.
        adjoint_ode: Structure of `Tensor`s equal to backwards time
          derivative of the `adjoint_state` component.
        adjoint_variables_ode: Structure of `Tensor`s equal to backwards time
          derivative of the `vars` component.
        adjoint_constants_ode: Structure of `Tensor`s equal to backwards time
          derivative of the `constants` component.
      """
      # The negative sign disappears after the change of variables.
      # The ODE solver cannot handle the case initial_time > final_time
      # and hence a change of variables backward_time = -time is used.
      time = -backward_time
      state, adjoint_state, _, _ = augmented_state

      # TODO(b/152464477): Doesn't work reliably in TF1.
      def grad_fn(state, variables, constants):
        # We compute these gradients via the GradientTape capturing them.
        del variables
        derivatives = ode_fn(time, state, **constants)
        adjoint_no_grad = tf.nest.map_structure(
            tf.stop_gradient, adjoint_state)
        negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])

        def dot_prod(tensor_a, tensor_b):
          return tf.reduce_sum(tensor_a * tensor_b)

        # See docstring for details.
        adjoint_dot_derivatives = tf.nest.map_structure(
            dot_prod, adjoint_no_grad, derivatives)
        adjoint_dot_derivatives = tf.squeeze(
            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
        return adjoint_dot_derivatives, negative_derivatives

      values = (state, tuple(variables), constants)
      ((_, negative_derivatives),
       gradients) = tfp_gradient.value_and_gradient(
           grad_fn, values, has_aux=True, use_gradient_tape=True)

      (adjoint_ode, adjoint_variables_ode,
       adjoint_constants_ode) = tf.nest.map_structure(
           lambda v, g: tf.zeros_like(v) if g is None else g,
           values, tuple(gradients))

      return (negative_derivatives, adjoint_ode, adjoint_variables_ode,
              adjoint_constants_ode)

    def make_augmented_state(n, prev_augmented_state):
      """Constructs the augmented state for step `n`."""
      (_, adjoint_state, adjoint_variable_state,
       adjoint_constant_state) = prev_augmented_state
      initial_state = _read_solution_components(
          result_state_arrays, input_state_structure, n - 1)
      initial_adjoint = _read_solution_components(
          dresult_state_arrays, input_state_structure, n - 1)
      initial_adjoint_state = rk_util.weighted_sum(
          [1.0, 1.0], [adjoint_state, initial_adjoint])
      augmented_state = (
          initial_state,
          initial_adjoint_state,
          adjoint_variable_state,
          adjoint_constant_state,
      )
      return augmented_state

    def reverse_to_result_time(n, augmented_state, solver_internal_state, _):
      """Integrates the augmented system backwards in time."""
      lower_bound_of_integration = result_time_array.read(n)
      upper_bound_of_integration = result_time_array.read(n - 1)
      initial_augmented_state = make_augmented_state(n, augmented_state)
      # TODO(b/138304303): Allow the user to specify the Hessian of
      # `ode_fn` so that we can get the Jacobian of the adjoint system.
# TODO(b/143624114): Support higher order derivatives. solver_internal_state = ( adjoint_solver. _adjust_solver_internal_state_for_state_jump( # pylint: disable=protected-access ode_fn=augmented_ode_fn, initial_time=-lower_bound_of_integration, initial_state=initial_augmented_state, previous_solver_internal_state= solver_internal_state, previous_state=augmented_state, )) augmented_results = adjoint_solver.solve( ode_fn=augmented_ode_fn, initial_time=-lower_bound_of_integration, initial_state=initial_augmented_state, solution_times=[-upper_bound_of_integration], batch_ndims=batch_ndims, previous_solver_internal_state=solver_internal_state, ) # Results added an extra time dim of size 1, squeeze it. select_result = lambda x: tf.squeeze(x, [0]) result_state = augmented_results.states result_state = tf.nest.map_structure( select_result, result_state) status = augmented_results.diagnostics.status return (n - 1, result_state, augmented_results.solver_internal_state, status) initial_n = num_result_times - 1 solver_internal_state = adjoint_solver._initialize_solver_internal_state( # pylint: disable=protected-access ode_fn=augmented_ode_fn, initial_time=result_time_array.read(initial_n), initial_state=make_augmented_state( initial_n, terminal_augmented_state), ) _, augmented_state, _, _ = tf.while_loop( lambda n, _as, _sis, status: (n >= 1) & tf.equal(status, 0), reverse_to_result_time, (initial_n, terminal_augmented_state, solver_internal_state, 0), back_prop=False, ) (_, adjoint_state, adjoint_variables, adjoint_constants) = augmented_state if variables: return (adjoint_state, adjoint_constants), list(adjoint_variables) else: return adjoint_state, adjoint_constants
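# Illustrative usage sketch (not part of the library): the adjoint machinery
# above is what makes `solve` differentiable, so gradients of the trajectory
# with respect to the initial state (and, where supported, to entries of the
# `constants` argument of `solve`) can be taken with an ordinary
# `tf.GradientTape`. The public entry point is assumed to be
# `tfp.math.ode.DormandPrince` (or `tfp.math.ode.BDF`).
def _example_solver_gradient():
  import tensorflow as tf
  import tensorflow_probability as tfp

  solver = tfp.math.ode.DormandPrince()
  initial_state = tf.constant([1.0, 0.0])

  def ode_fn(t, y):
    del t  # Autonomous system: simple harmonic oscillator.
    return tf.stack([y[1], -y[0]])

  with tf.GradientTape() as tape:
    tape.watch(initial_state)
    results = solver.solve(
        ode_fn,
        initial_time=0.0,
        initial_state=initial_state,
        solution_times=[0.5, 1.0])
    loss = tf.reduce_sum(results.states[-1]**2)
  # Computed with the adjoint method implemented by `vjp_bwd` above.
  return tape.gradient(loss, initial_state)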
def _solve( self, ode_fn, initial_time, initial_state, solution_times, jacobian_fn=None, jacobian_sparsity=None, batch_ndims=None, previous_solver_internal_state=None, ): # This function is comprised of the following sequential stages: # (1) Make static assertions. # (2) Initialize variables. # (3) Make non-static assertions. # (4) Solve up to final time. # (5) Return `Results` object. # # The stages can be found in the code by searching for (n) where n=1..5. # # By static vs. non-static assertions (see stages 1 and 3), we mean # assertions that can be made before the graph is run vs. those that can # only be made at run time. The latter are constructed as a list of # tf.Assert operations by the function `assert_ops` (see below). # # If `solution_times` is specified as a `Tensor`, stage 4 consists of three # nested loops, which can be conceptually understood as follows: # ``` # current_time, current_state = initial_time, initial_state # order, step_size = 1, first_step_size # for solution_time in solution_times: # while current_time < solution_time: # while True: # next_time = current_time + step_size # next_state, error = ( # solve_nonlinear_equation_to_get_approximate_state_at_next_time( # current_time, current_state, next_time, order)) # if error < tolerance: # current_time, current_state = next_time, next_state # order, step_size = ( # maybe_update_order_and_step_size(order, step_size)) # break # else: # step_size = decrease_step_size(step_size) # ``` # The outermost loop advances the solver to the next `solution_time` (see # `advance_to_solution_time`). The middle loop advances the solver by a # small timestep (see `step`). The innermost loop determines the size of # that timestep (see `maybe_step`). # # If `solution_times` is specified as # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped # and `solution_time` in the middle loop is replaced by `final_time`. def assert_ops(): """Creates a list of assert operations.""" if not self._validate_args: return [] assert_ops = [] if previous_solver_internal_state is not None: assert_initial_state_matches_previous_solver_internal_state = ( tf.debugging.assert_near( tf.norm( initial_state_vec - previous_solver_internal_state. 
backward_differences[0], np.inf), 0., message= '`previous_solver_internal_state` does not match ' '`initial_state`.')) assert_ops.append( assert_initial_state_matches_previous_solver_internal_state ) if solution_times_chosen_by_solver: assert_ops.append( util.assert_positive(final_time - initial_time, 'final_time - initial_time')) else: assert_ops += [ util.assert_increasing(solution_times, 'solution_times'), util.assert_nonnegative( solution_times[0] - initial_time, 'solution_times[0] - initial_time'), ] if max_num_steps is not None: assert_ops.append( util.assert_positive(max_num_steps, 'max_num_steps')) if max_num_newton_iters is not None: assert_ops.append( util.assert_positive(max_num_newton_iters, 'max_num_newton_iters')) assert_ops += [ util.assert_positive(rtol, 'rtol'), util.assert_positive(atol, 'atol'), util.assert_positive(first_step_size, 'first_step_size'), util.assert_positive(safety_factor, 'safety_factor'), util.assert_positive(min_step_size_factor, 'min_step_size_factor'), util.assert_positive(max_step_size_factor, 'max_step_size_factor'), tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER), [ '`max_order` must be between 1 and {}.'.format( bdf_util.MAX_ORDER) ]), util.assert_positive(newton_tol_factor, 'newton_tol_factor'), util.assert_positive(newton_step_size_factor, 'newton_step_size_factor'), ] return assert_ops def advance_to_solution_time(n, diagnostics, iterand, solver_internal_state, state_vec_array, time_array): """Takes multiple steps to advance time to `solution_times[n]`.""" def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & (tf.equal( diagnostics.status, 0)) nth_solution_time = solution_time_array.read(n) [ _, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ] = tf.while_loop(step_cond, step, [ nth_solution_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ]) state_vec_array = state_vec_array.write( n, solver_internal_state.backward_differences[0]) time_array = time_array.write(n, nth_solution_time) return (n + 1, diagnostics, iterand, solver_internal_state, state_vec_array, time_array) def step(next_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array): """Takes a single step.""" distance_to_next_time = next_time - iterand.time overstepped = iterand.new_step_size > distance_to_next_time iterand = iterand._replace(new_step_size=tf1.where( overstepped, distance_to_next_time, iterand.new_step_size), should_update_step_size=overstepped | iterand.should_update_step_size) if not self._evaluate_jacobian_lazily: diagnostics = diagnostics._replace( num_jacobian_evaluations=diagnostics. 
num_jacobian_evaluations + 1) iterand = iterand._replace(jacobian_mat=jacobian_fn_mat( iterand.time, solver_internal_state.backward_differences[0]), jacobian_is_up_to_date=True) def maybe_step_cond(accepted, diagnostics, *_): return tf.logical_not(accepted) & tf.equal( diagnostics.status, 0) _, diagnostics, iterand, solver_internal_state = tf.while_loop( maybe_step_cond, maybe_step, [False, diagnostics, iterand, solver_internal_state]) if solution_times_chosen_by_solver: state_vec_array = state_vec_array.write( state_vec_array.size(), solver_internal_state.backward_differences[0]) time_array = time_array.write(time_array.size(), iterand.time) return (next_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array) def maybe_step(accepted, diagnostics, iterand, solver_internal_state): """Takes a single step only if the outcome has a low enough error.""" [ num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status ] = diagnostics [ jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper ] = iterand [backward_differences, order, step_size] = solver_internal_state if max_num_steps is not None: status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0) backward_differences = tf1.where( should_update_step_size, bdf_util.interpolate_backward_differences( backward_differences, order, new_step_size / step_size), backward_differences) step_size = tf1.where(should_update_step_size, new_step_size, step_size) should_update_factorization = should_update_step_size num_steps_same_size = tf1.where(should_update_step_size, 0, num_steps_same_size) def update_factorization(): return bdf_util.newton_qr( jacobian_mat, newton_coefficients_array.read(order), step_size) if self._evaluate_jacobian_lazily: def update_jacobian_and_factorization(): new_jacobian_mat = jacobian_fn_mat(time, backward_differences[0]) new_unitary, new_upper = update_factorization() return [ new_jacobian_mat, True, num_jacobian_evaluations + 1, new_unitary, new_upper ] def maybe_update_factorization(): new_unitary, new_upper = tf.cond( should_update_factorization, update_factorization, lambda: [unitary, upper]) return [ jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations, new_unitary, new_upper ] [ jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations, unitary, upper ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization, maybe_update_factorization) else: unitary, upper = update_factorization() num_matrix_factorizations += 1 tol = atol + rtol * tf.abs(backward_differences[0]) newton_tol = newton_tol_factor * tf.norm(tol) [ newton_converged, next_backward_difference, next_state_vec, newton_num_iters ] = bdf_util.newton(backward_differences, max_num_newton_iters, newton_coefficients_array.read(order), ode_fn_vec, order, step_size, time, newton_tol, unitary, upper) num_steps += 1 num_ode_fn_evaluations += newton_num_iters # If Newton's method failed and the Jacobian was up to date, decrease the # step size. newton_failed = tf.logical_not(newton_converged) should_update_step_size = newton_failed & jacobian_is_up_to_date new_step_size = step_size * tf1.where(should_update_step_size, newton_step_size_factor, 1.) # If Newton's method failed and the Jacobian was NOT up to date, update # the Jacobian. 
should_update_jacobian = newton_failed & tf.logical_not( jacobian_is_up_to_date) error_ratio = tf1.where( newton_converged, bdf_util.error_ratio(next_backward_difference, error_coefficients_array.read(order), tol), np.nan) accepted = error_ratio < 1. converged_and_rejected = newton_converged & tf.logical_not( accepted) # If Newton's method converged but the solution was NOT accepted, decrease # the step size. new_step_size = tf1.where( converged_and_rejected, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, max_step_size_factor), new_step_size) should_update_step_size = should_update_step_size | converged_and_rejected # If Newton's method converged and the solution was accepted, update the # matrix of backward differences. time = tf1.where(accepted, time + step_size, time) backward_differences = tf1.where( accepted, bdf_util.update_backward_differences(backward_differences, next_backward_difference, next_state_vec, order), backward_differences) jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not( accepted) num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1, num_steps_same_size) # Order and step size are only updated if we have taken strictly more than # order + 1 steps of the same size. This is to prevent the order from # being throttled. should_update_order_and_step_size = accepted & (num_steps_same_size > order + 1) backward_differences_array = tf.TensorArray( backward_differences.dtype, size=bdf_util.MAX_ORDER + 3, clear_after_read=False, element_shape=next_backward_difference.get_shape()).unstack( backward_differences) new_order = order new_error_ratio = error_ratio for offset in [-1, +1]: proposed_order = tf.clip_by_value(order + offset, 1, max_order) proposed_error_ratio = bdf_util.error_ratio( backward_differences_array.read(proposed_order + 1), error_coefficients_array.read(proposed_order), tol) proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio new_order = tf1.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_order, new_order) new_error_ratio = tf1.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_error_ratio, new_error_ratio) order = new_order error_ratio = new_error_ratio new_step_size = tf1.where( should_update_order_and_step_size, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, max_step_size_factor), new_step_size) should_update_step_size = (should_update_step_size | should_update_order_and_step_size) diagnostics = _BDFDiagnostics(num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status) iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper) solver_internal_state = _BDFSolverInternalState( backward_differences, order, step_size) return accepted, diagnostics, iterand, solver_internal_state # (1) Make static assertions. # TODO(b/138304296): Support specifying Jacobian sparsity patterns. if jacobian_sparsity is not None: raise NotImplementedError( 'The BDF solver does not support specifying ' 'Jacobian sparsity patterns.') if batch_ndims is not None and batch_ndims != 0: raise NotImplementedError( 'The BDF solver does not support batching.') solution_times_chosen_by_solver = (isinstance(solution_times, base.ChosenBySolver)) with tf.name_scope(self._name): # (2) Convert to tensors. 
error_if_wrong_dtype = functools.partial( util.error_if_not_real_or_complex, identifier='initial_state') initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state) tf.nest.map_structure(error_if_wrong_dtype, initial_state) state_shape = tf.nest.map_structure(tf.shape, initial_state) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) if jacobian_fn is None and common_state_dtype.is_complex: raise NotImplementedError( 'The BDF solver does not support automatic ' 'Jacobian computations for complex dtypes.') # Convert everything to operate on a single, concatenated vector form. initial_state_vec = util.get_state_vec(initial_state) ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape) jacobian_fn_mat = util.get_jacobian_fn_mat( jacobian_fn, ode_fn_vec, state_shape, use_pfor=self._use_pfor_to_compute_jacobian, dtype=common_state_dtype, ) num_odes = tf.size(initial_state_vec) # Use tf.cast instead of tf.convert_to_tensor for differentiable # parameters because the tf.custom_gradient decorator converts raw floats # into tf.float32, which cannot be converted to tf.float64. initial_time = tf.cast(initial_time, real_dtype) num_solution_times = 0 if solution_times_chosen_by_solver: final_time = tf.cast(solution_times.final_time, real_dtype) else: solution_times = tf.cast(solution_times, real_dtype) num_solution_times = tf.size(solution_times) solution_time_array = tf.TensorArray( solution_times.dtype, size=num_solution_times, element_shape=[]).unstack(solution_times) util.error_if_not_vector(solution_times, 'solution_times') rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype) atol = tf.convert_to_tensor(self._atol, dtype=real_dtype) safety_factor = tf.convert_to_tensor(self._safety_factor, dtype=real_dtype) min_step_size_factor = tf.convert_to_tensor( self._min_step_size_factor, dtype=real_dtype) max_step_size_factor = tf.convert_to_tensor( self._max_step_size_factor, dtype=real_dtype) max_num_steps = self._max_num_steps if max_num_steps is not None: max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32) max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32) max_num_newton_iters = self._max_num_newton_iters if max_num_newton_iters is not None: max_num_newton_iters = tf.convert_to_tensor( max_num_newton_iters, dtype=tf.int32) newton_tol_factor = tf.convert_to_tensor(self._newton_tol_factor, dtype=real_dtype) newton_step_size_factor = tf.convert_to_tensor( self._newton_step_size_factor, dtype=real_dtype) bdf_coefficients = tf.cast( tf.concat([[0.], tf.convert_to_tensor(self._bdf_coefficients, dtype=real_dtype)], 0), common_state_dtype) util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients') if self._validate_args: initial_time = tf.ensure_shape(initial_time, []) if solution_times_chosen_by_solver: final_time = tf.ensure_shape(final_time, []) safety_factor = tf.ensure_shape(safety_factor, []) min_step_size_factor = tf.ensure_shape(min_step_size_factor, []) max_step_size_factor = tf.ensure_shape(max_step_size_factor, []) if max_num_steps is not None: max_num_steps = tf.ensure_shape(max_num_steps, []) max_order = tf.ensure_shape(max_order, []) if max_num_newton_iters is not None: max_num_newton_iters = tf.ensure_shape( max_num_newton_iters, []) newton_tol_factor = tf.ensure_shape(newton_tol_factor, []) newton_step_size_factor = tf.ensure_shape( newton_step_size_factor, []) bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6]) newton_coefficients = 1. / ( (1. 
- bdf_coefficients) * bdf_util.RECIPROCAL_SUMS) newton_coefficients_array = tf.TensorArray( newton_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(newton_coefficients) error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / ( bdf_util.ORDERS + 1) error_coefficients_array = tf.TensorArray( error_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(error_coefficients) first_step_size = self._first_step_size if first_step_size is None: first_step_size = bdf_util.first_step_size( atol, error_coefficients_array.read(1), initial_state_vec, initial_time, ode_fn_vec, rtol, safety_factor) elif previous_solver_internal_state is not None: tf.logging.warn( '`first_step_size` is ignored since' '`previous_solver_internal_state` was specified.') first_step_size = tf.convert_to_tensor(first_step_size, dtype=real_dtype) if self._validate_args: first_step_size = tf.ensure_shape(first_step_size, []) solver_internal_state = previous_solver_internal_state if solver_internal_state is None: first_order_backward_difference = ode_fn_vec( initial_time, initial_state_vec) * tf.cast( first_step_size, common_state_dtype) backward_differences = tf.concat([ initial_state_vec[tf.newaxis, :], first_order_backward_difference[tf.newaxis, :], tf.zeros(tf.stack([bdf_util.MAX_ORDER + 1, num_odes]), dtype=common_state_dtype), ], 0) solver_internal_state = _BDFSolverInternalState( backward_differences=backward_differences, order=1, step_size=first_step_size) state_vec_array = tf.TensorArray( common_state_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=initial_state_vec.get_shape()) time_array = tf.TensorArray( real_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=tf.TensorShape([])) diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0, num_matrix_factorizations=0, num_ode_fn_evaluations=0, status=0) iterand = _BDFIterand( jacobian_mat=tf.zeros([num_odes, num_odes], dtype=common_state_dtype), jacobian_is_up_to_date=False, new_step_size=solver_internal_state.step_size, num_steps=0, num_steps_same_size=0, should_update_jacobian=True, should_update_step_size=False, time=initial_time, unitary=tf.zeros([num_odes, num_odes], dtype=common_state_dtype), upper=tf.zeros([num_odes, num_odes], dtype=common_state_dtype)) # (3) Make non-static assertions. with tf.control_dependencies(assert_ops()): # (4) Solve up to final time. if solution_times_chosen_by_solver: def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & (tf.equal( diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ] = tf.while_loop(step_cond, step, [ final_time, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ]) else: def advance_to_solution_time_cond(n, diagnostics, *_): return (n < num_solution_times) & (tf.equal( diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ] = tf.while_loop( advance_to_solution_time_cond, advance_to_solution_time, [ 0, diagnostics, iterand, solver_internal_state, state_vec_array, time_array ]) # (6) Return `Results` object. 
states = util.get_state_from_vec(state_vec_array.stack(), state_shape) times = time_array.stack() if not solution_times_chosen_by_solver: times.set_shape(solution_times.get_shape()) tf.nest.map_structure( lambda s, ini_s: s.set_shape( solution_times.get_shape( # pylint: disable=g-long-lambda ).concatenate(ini_s.shape)), states, initial_state) return base.Results( times=times, states=states, diagnostics=diagnostics, solver_internal_state=solver_internal_state)
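# Illustrative usage sketch (not part of the library) for the solver routine
# above: `solution_times` may be an explicit vector of times, or
# `ChosenBySolver(final_time)`, in which case every accepted internal step up
# to `final_time` is reported. The entry points are assumed to be
# `tfp.math.ode.BDF` and `tfp.math.ode.ChosenBySolver`.
def _example_bdf_solve():
  import tensorflow as tf
  import tensorflow_probability as tfp

  # A mildly stiff linear system dy/dt = A @ y.
  a_mat = tf.constant([[-1.0, 0.0], [0.0, -1000.0]], dtype=tf.float64)
  ode_fn = lambda t, y: tf.linalg.matvec(a_mat, y)
  y0 = tf.constant([1.0, 1.0], dtype=tf.float64)

  solver = tfp.math.ode.BDF()

  # Explicit solution times: states are reported exactly at these times.
  fixed = solver.solve(ode_fn, 0.0, y0, solution_times=[0.1, 1.0])

  # Solver-chosen times: dense output up to `final_time`, with
  # `solver_internal_state` reusable to continue the integration later.
  dense = solver.solve(
      ode_fn, 0.0, y0,
      solution_times=tfp.math.ode.ChosenBySolver(final_time=1.0))
  continued = solver.solve(
      ode_fn, 1.0, dense.states[-1],
      solution_times=[2.0],
      previous_solver_internal_state=dense.solver_internal_state)
  return fixed, dense, continued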
def _solve( self, ode_fn, initial_time, initial_state, solution_times, jacobian_fn=None, jacobian_sparsity=None, batch_ndims=None, previous_solver_internal_state=None, ): # Static assertions del jacobian_fn, jacobian_sparsity # not used by DormandPrince if batch_ndims is not None and batch_ndims != 0: raise NotImplementedError( 'For homogeneous batching use `batch_ndims=0`.') solution_times_by_solver = isinstance(solution_times, base.ChosenBySolver) with tf.name_scope(self._name): # (2) Convert to tensors, determined dtypes. get_dtype = lambda x: x.dtype error_if_wrong_dtype = functools.partial( util.error_if_not_real_or_complex, identifier='initial_state') initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state) tf.nest.map_structure(error_if_wrong_dtype, initial_state) state_dtypes = tf.nest.map_structure(get_dtype, initial_state) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) initial_time = tf.cast(initial_time, real_dtype) max_num_steps = self._max_num_steps max_ode_fn_evals = self._max_num_steps if max_num_steps is not None: max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32) max_ode_fn_evals = max_num_steps * self.ODE_FN_EVALS_PER_STEP step_size = tf.convert_to_tensor(self._first_step_size, dtype=real_dtype) rtol = tf.convert_to_tensor(tf.cast(self._rtol, real_dtype)) atol = tf.convert_to_tensor(tf.cast(self._atol, real_dtype)) safety = tf.convert_to_tensor(self._safety_factor, dtype=real_dtype) # Use i(d)factor notation for increasing and decreasing factors. ifactor, dfactor = self._max_step_size_factor, self._min_step_size_factor ifactor = tf.convert_to_tensor(ifactor, dtype=real_dtype) dfactor = tf.convert_to_tensor(dfactor, dtype=real_dtype) solver_internal_state = previous_solver_internal_state if solver_internal_state is None: initial_derivative = ode_fn(initial_time, initial_state) initial_derivative = tf.nest.map_structure( tf.convert_to_tensor, initial_derivative) solver_internal_state = _RungeKuttaSolverInternalState( current_state=initial_state, current_derivative=initial_derivative, last_step_start=initial_time, current_time=initial_time, step_size=step_size, interpolating_coefficients=[initial_state] * self.ORDER) num_solution_times = 0 if solution_times_by_solver: final_time = tf.cast(solution_times.final_time, real_dtype) times_array = tf.TensorArray(real_dtype, size=num_solution_times, dynamic_size=True, element_shape=tf.TensorShape([])) else: solution_times = tf.cast(solution_times, real_dtype) util.error_if_not_vector(solution_times, 'solution_times') num_solution_times = tf.size(solution_times) times_array = tf.TensorArray( real_dtype, size=num_solution_times, dynamic_size=False, element_shape=[]).unstack(solution_times) solutions_arrays = [ tf.TensorArray(dtype=component_dtype, size=num_solution_times, dynamic_size=solution_times_by_solver) for component_dtype in tf.nest.flatten(state_dtypes) ] solutions_arrays = tf.nest.pack_sequence_as( initial_state, solutions_arrays) rk_step = functools.partial(self._step, max_ode_fn_evals=max_ode_fn_evals, ode_fn=ode_fn, atol=atol, rtol=rtol, safety=safety, ifactor=ifactor, dfactor=dfactor) advance_to_solution_time = functools.partial( _advance_to_solution_time, times_array=solution_times, step_fn=rk_step, validate_args=self._validate_args) assert_ops = self._assert_ops( ode_fn=ode_fn, initial_time=initial_time, initial_state=initial_state, solution_times=solution_times, previous_solver_state=previous_solver_internal_state, 
rtol=rtol, atol=atol, first_step_size=step_size, safety_factor=safety, min_step_size_factor=ifactor, max_step_size_factor=dfactor, max_num_steps=max_num_steps, solution_times_by_solver=solution_times_by_solver) with tf.control_dependencies(assert_ops): ode_evals_by_now = 1 if self._validate_args else 0 ode_evals_by_now += 1 if solver_internal_state is None else 0 diagnostics = _DopriDiagnostics( num_ode_fn_evaluations=ode_evals_by_now, num_jacobian_evaluations=0, num_matrix_factorizations=0, status=0) if solution_times_by_solver: r = _dense_solutions_to_final_time( final_time=final_time, solver_state=solver_internal_state, diagnostics=diagnostics, step_fn=rk_step, ode_fn=ode_fn, times_array=times_array, solutions_arrays=solutions_arrays, validate_args=self._validate_args) solver_internal_state, diagnostics, times_array, solutions_arrays = r else: def iterate_cond(time_id, *_): return time_id < num_solution_times [_, solver_internal_state, diagnostics, solutions_arrays ] = tf.while_loop(iterate_cond, advance_to_solution_time, [ 0, solver_internal_state, diagnostics, solutions_arrays ], back_prop=False) times = times_array.stack() stack_components = lambda x: x.stack() states = tf.nest.map_structure(stack_components, solutions_arrays) return base.Results( times=times, states=states, diagnostics=diagnostics, solver_internal_state=solver_internal_state)
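# Illustrative usage sketch (not part of the library): `initial_state` may be
# an arbitrary nest of Tensors, and the returned `Results.states` mirrors that
# structure with a leading time dimension, as built up by the per-component
# TensorArrays above. The entry point is assumed to be
# `tfp.math.ode.DormandPrince`.
def _example_dormand_prince_nested_state():
  import tensorflow as tf
  import tensorflow_probability as tfp

  # Exponential growth/decay on each component of a dict-valued state.
  ode_fn = lambda t, y: {'a': y['a'], 'b': -0.5 * y['b']}
  initial_state = {'a': tf.constant(1.0), 'b': tf.constant(2.0)}

  results = tfp.math.ode.DormandPrince(rtol=1e-6).solve(
      ode_fn,
      initial_time=0.0,
      initial_state=initial_state,
      solution_times=[1.0, 2.0, 3.0])

  # results.times has shape [3]; results.states['a'] and results.states['b']
  # each have shape [3]. results.diagnostics.num_ode_fn_evaluations reports
  # how much work the solver did.
  return results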