def _prepare_common_params(self, ode_fn, initial_state, initial_time):
    """Normalizes solver inputs into tensors sharing one dtype family.

    Every leaf of `initial_state` is converted to a tensor and checked with
    `util.error_if_not_real_or_complex`. The common dtype of the state
    determines a real dtype, which is used for the start time, `rtol`, `atol`
    and the step-size safety factor. Finally, the state is flattened into a
    single vector and `ode_fn` is wrapped to operate on that vector form.

    Args:
        ode_fn: The ODE function; passed to `util.get_ode_fn_vec` together
            with the per-leaf state shapes to obtain a vectorized version.
        initial_state: Possibly nested structure of initial state values.
        initial_time: Start time; cast to the real dtype of the state.

    Returns:
        A `util.Bunch` with fields `initial_state`, `initial_time`,
        `common_state_dtype`, `real_dtype`, `rtol`, `atol`, `safety_factor`,
        `state_shape`, `initial_state_vec`, `ode_fn_vec` and `num_odes`.
    """
    check_dtype = functools.partial(
        util.error_if_not_real_or_complex, identifier='initial_state')
    state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
    tf.nest.map_structure(check_dtype, state)
    shapes = tf.nest.map_structure(ps.shape, state)

    state_dtype = dtype_util.common_dtype(state)
    real_dtype = dtype_util.real_dtype(state_dtype)
    as_real_tensor = functools.partial(tf.convert_to_tensor, dtype=real_dtype)

    # tf.cast (rather than tf.convert_to_tensor) is used for differentiable
    # parameters: the tf.custom_gradient decorator turns raw Python floats
    # into tf.float32, which tf.convert_to_tensor refuses to turn into
    # tf.float64.
    time = tf.cast(initial_time, real_dtype)
    if self._validate_args:
        # The start time must be a scalar.
        time = tf.ensure_shape(time, [])

    rtol = as_real_tensor(self._rtol)
    atol = as_real_tensor(self._atol)
    safety = as_real_tensor(self._safety_factor)
    if self._validate_args:
        # The safety factor must be a scalar.
        safety = tf.ensure_shape(safety, [])

    # Everything downstream works on one concatenated state vector.
    state_vec = util.get_state_vec(state)
    vec_fn = util.get_ode_fn_vec(ode_fn, shapes)

    return util.Bunch(
        initial_state=state,
        initial_time=time,
        common_state_dtype=state_dtype,
        real_dtype=real_dtype,
        rtol=rtol,
        atol=atol,
        safety_factor=safety,
        state_shape=shapes,
        initial_state_vec=state_vec,
        ode_fn_vec=vec_fn,
        num_odes=tf.size(state_vec),
    )
def _prepare_common_params(self, initial_state, initial_time):
    """Converts the initial state and time into tensors of consistent dtype.

    Every leaf of `initial_state` is converted to a tensor and checked with
    `util.error_if_not_real_or_complex`. The per-leaf dtypes are recorded, and
    `initial_time` is cast to the real dtype derived from the common dtype of
    the state.

    Args:
        initial_state: Possibly nested structure of initial state values.
        initial_time: Start time; cast to the real dtype of the state.

    Returns:
        A `util.Bunch` with fields `initial_state` (tensor structure),
        `state_dtypes` (structure of per-leaf dtypes), `real_dtype` and
        `initial_time`.
    """
    error_if_wrong_dtype = functools.partial(
        util.error_if_not_real_or_complex, identifier='initial_state')
    initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
    tf.nest.map_structure(error_if_wrong_dtype, initial_state)
    # Keep each leaf's own dtype, mirroring the structure of the state.
    state_dtypes = tf.nest.map_structure(lambda leaf: leaf.dtype, initial_state)
    common_state_dtype = dtype_util.common_dtype(initial_state)
    real_dtype = dtype_util.real_dtype(common_state_dtype)
    # tf.cast (rather than tf.convert_to_tensor) is used for differentiable
    # parameters: the tf.custom_gradient decorator turns raw Python floats
    # into tf.float32, which tf.convert_to_tensor refuses to turn into
    # tf.float64.
    initial_time = tf.cast(initial_time, real_dtype)
    return util.Bunch(
        initial_state=initial_state,
        state_dtypes=state_dtypes,
        real_dtype=real_dtype,
        initial_time=initial_time,
    )