Exemplo n.º 1
0
    def _prepare_common_params(self, ode_fn, initial_state, initial_time):
        """Bundles the tensors and derived values the solver needs up front.

        Each leaf of `initial_state` is converted to a tensor and checked with
        `util.error_if_not_real_or_complex`. The common dtype of the state and
        its real counterpart are derived. `initial_time` and the solver's
        `_rtol`, `_atol` and `_safety_factor` settings are brought to that real
        dtype. Finally the state and `ode_fn` are converted to a single,
        concatenated vector form.

        Args:
          ode_fn: The ODE function; passed together with the state's shape
            structure to `util.get_ode_fn_vec`.
          initial_state: A (possibly nested) structure of tensor-convertible
            values.
          initial_time: The start time. It is cast to the real dtype of the
            state. Its shape is checked to be scalar only when
            `self._validate_args` is set.

        Returns:
          A `util.Bunch` with the fields `initial_state`, `initial_time`,
          `common_state_dtype`, `real_dtype`, `rtol`, `atol`, `safety_factor`,
          `state_shape`, `initial_state_vec`, `ode_fn_vec` and `num_odes`
          (the `tf.size` of the concatenated state vector).
        """
        check_leaf_dtype = functools.partial(
            util.error_if_not_real_or_complex, identifier='initial_state')

        state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
        tf.nest.map_structure(check_leaf_dtype, state)

        shapes = tf.nest.map_structure(ps.shape, state)
        state_dtype = dtype_util.common_dtype(state)
        real_dtype = dtype_util.real_dtype(state_dtype)

        # Differentiable inputs go through tf.cast rather than
        # tf.convert_to_tensor: the tf.custom_gradient decorator turns raw
        # Python floats into tf.float32 tensors, and those cannot be converted
        # to tf.float64 by tf.convert_to_tensor.
        start_time = tf.cast(initial_time, real_dtype)
        if self._validate_args:
            start_time = tf.ensure_shape(start_time, [])

        def to_real(value):
            # Solver settings are converted (not cast) to the state's real dtype.
            return tf.convert_to_tensor(value, dtype=real_dtype)

        rtol = to_real(self._rtol)
        atol = to_real(self._atol)
        safety_factor = to_real(self._safety_factor)
        if self._validate_args:
            safety_factor = tf.ensure_shape(safety_factor, [])

        # Convert the state and ode_fn to operate on a single, concatenated
        # vector form.
        state_vec = util.get_state_vec(state)
        vec_ode_fn = util.get_ode_fn_vec(ode_fn, shapes)

        return util.Bunch(
            initial_state=state,
            initial_time=start_time,
            common_state_dtype=state_dtype,
            real_dtype=real_dtype,
            rtol=rtol,
            atol=atol,
            safety_factor=safety_factor,
            state_shape=shapes,
            initial_state_vec=state_vec,
            ode_fn_vec=vec_ode_fn,
            num_odes=tf.size(state_vec),
        )
Exemplo n.º 2
0
  def _prepare_common_params(self, initial_state, initial_time):
    """Normalizes the initial state and time into a bundle of tensors.

    Each leaf of `initial_state` is converted to a tensor and checked with
    `util.error_if_not_real_or_complex`. The per-leaf dtypes are recorded, and
    the common state dtype and its real counterpart are derived. `initial_time`
    is cast to that real dtype.

    Args:
      initial_state: A (possibly nested) structure of tensor-convertible values.
      initial_time: The start time; cast to the real dtype of the state.

    Returns:
      A `util.Bunch` with the fields:
        initial_state: The converted structure.
        state_dtypes: A structure matching `initial_state` that holds each
          leaf's dtype.
        real_dtype: The real dtype of the common state dtype.
        initial_time: The cast start time.
    """

    def get_dtype(x):
      # A named local function (not a lambda bound to a name, PEP 8 / E731)
      # gives a meaningful name in tracebacks.
      return x.dtype

    error_if_wrong_dtype = functools.partial(
        util.error_if_not_real_or_complex, identifier='initial_state')

    initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
    # Validate every leaf's dtype before any dtype arithmetic below.
    tf.nest.map_structure(error_if_wrong_dtype, initial_state)

    state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
    common_state_dtype = dtype_util.common_dtype(initial_state)
    real_dtype = dtype_util.real_dtype(common_state_dtype)

    # tf.cast rather than tf.convert_to_tensor: the latter will not change the
    # dtype of an existing tensor (e.g. float32 -> float64).
    initial_time = tf.cast(initial_time, real_dtype)

    return util.Bunch(
        initial_state=initial_state,
        state_dtypes=state_dtypes,
        real_dtype=real_dtype,
        initial_time=initial_time,
    )