Example #1
def trace(state, fn, num_steps, parallel_iterations=10):
    """TF implementation of `trace` operator, without the calling convention."""
    if tf.config.experimental_functions_run_eagerly() or tf.executing_eagerly(
    ):
        state, first_untraced, first_traced = fn(state)
        arrays = tf.nest.map_structure(
            lambda v: tf.TensorArray(  # pylint: disable=g-long-lambda
                v.dtype,
                size=num_steps,
                element_shape=v.shape).write(0, v),
            first_traced)
        start_idx = 1
    else:
        # We need the shapes and dtypes of the outputs of `fn` to create the
        # `TensorArray`s etc.; we can get them by pre-compiling the wrapper
        # function.
        input_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
        fn, (_, untraced_spec, traced_spec) = _eval_shape(fn, input_spec)

        arrays = tf.nest.map_structure(
            lambda spec: tf.TensorArray(  # pylint: disable=g-long-lambda
                spec.dtype,
                size=num_steps,
                element_shape=spec.shape),
            traced_spec)
        first_untraced = tf.nest.map_structure(
            lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec)
        start_idx = 0

    def body(i, state, _, arrays):
        state, untraced, traced = fn(state)
        arrays = tf.nest.map_structure(lambda a, e: a.write(i, e), arrays,
                                       traced)
        return i + 1, state, untraced, arrays

    def cond(i, *_):
        return i < num_steps

    _, state, untraced, arrays = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=(start_idx, state, first_untraced, arrays),
        parallel_iterations=parallel_iterations,
    )

    traced = tf.nest.map_structure(lambda a: a.stack(), arrays)

    static_length = tf.get_static_value(num_steps)

    def _merge_static_length(x):
        x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
        return x

    traced = tf.nest.map_structure(_merge_static_length, traced)

    return state, untraced, traced
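A minimal usage sketch for the `trace` operator above, assuming eager execution (so the graph-mode `_eval_shape` branch is not exercised). The step function `_count_fn` is illustrative; it must return a `(state, untraced, traced)` tuple:

import tensorflow as tf

def _count_fn(state):
    new_state = state + 1.
    # (next state, value kept only from the final step, value traced every step)
    return new_state, 2. * new_state, new_state

final_state, untraced, traced = trace(tf.constant(0.), _count_fn, num_steps=5)
# final_state == 5., untraced == 10., traced == [1., 2., 3., 4., 5.]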
Example #2
def _initialize_accumulated_quantities(observations, num_timesteps):
    """Initialize arrays passed through the filter loop."""
    initial_arrays = [
        tf.nest.map_structure(
            lambda x: tf.TensorArray(dtype=x.dtype, size=num_timesteps),
            observations) for _ in range(7)
    ]
    initial_arrays.append(
        tf.nest.map_structure(
            lambda _: tf.TensorArray(dtype=tf.int32, size=num_timesteps),
            observations))
    return KalmanFilterState(*initial_arrays)
Example #3
def _initialize_loop_variables(initial_step_results, num_timesteps, trace_fn,
                               step_indices_to_trace):
    """Initialize arrays and other quantities passed through the filter loop."""

    # Create arrays to store traced values (particles, likelihoods, etc).
    num_steps_to_trace = (num_timesteps if step_indices_to_trace is None else
                          ps.size0(step_indices_to_trace))
    traced_results = trace_fn(initial_step_results)
    trace_arrays = tf.nest.map_structure(
        lambda x: tf.TensorArray(dtype=x.dtype, size=num_steps_to_trace),
        traced_results)
    # If we are supposed to trace at step 0, write the traced values.
    num_steps_traced, trace_arrays = ps.cond(
        (True if step_indices_to_trace is None else ps.equal(
            step_indices_to_trace[0], 0)),
        lambda: (
            1,  # pylint: disable=g-long-lambda
            tf.nest.map_structure(lambda ta, x: ta.write(0, x), trace_arrays,
                                  traced_results)),
        lambda: (0, trace_arrays))

    return ParticleFilterLoopVariables(
        step=1,
        previous_step_results=initial_step_results,
        accumulated_traced_results=trace_arrays,
        num_steps_traced=num_steps_traced)
Example #4
def call(self, inputs):
    samples = tf.TensorArray(dtype=tf.float32, size=tf.shape(inputs)[0])
    i = 0
    for sample in inputs:
        samples = samples.write(i, tf.square(sample))
        i += 1
    return samples.stack()
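A hedged sketch of how the `call` above might be exercised: wrapped in a hypothetical `tf.keras.layers.Layer` subclass and applied eagerly (under `tf.function`, AutoGraph would convert the Python loop into graph ops):

import tensorflow as tf

class SquareEach(tf.keras.layers.Layer):
    def call(self, inputs):
        samples = tf.TensorArray(dtype=tf.float32, size=tf.shape(inputs)[0])
        i = 0
        for sample in inputs:
            samples = samples.write(i, tf.square(sample))
            i += 1
        return samples.stack()

print(SquareEach()(tf.constant([1., 2., 3.])))  # [1., 4., 9.]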
Example #5
def _sample(dim, drift_fn, volatility_fn, times, time_step, keep_mask,
            times_shape, num_samples, initial_state, random_type, seed,
            swap_memory, skip, dtype):
    """Returns a sample of paths from the process using Euler method."""
    dt = times[1:] - times[:-1]
    sqrt_dt = tf.sqrt(dt)
    current_state = initial_state + tf.zeros([num_samples, dim],
                                             dtype=initial_state.dtype)
    if dt.shape.is_fully_defined():
        steps_num = dt.shape.as_list()[-1]
    else:
        steps_num = tf.shape(dt)[-1]
        # TODO(b/148133811): Re-enable Sobol test when TF 2.2 is released.
        if random_type == random.RandomType.SOBOL:
            raise ValueError(
                'Sobol sequence for Euler sampling is temporarily '
                'unsupported when `time_step` or `times` have a '
                'non-constant value')
    # In order to use low-discrepancy random_type we need to generate the sequence
    # of independent random normals upfront.
    if random_type in (random.RandomType.SOBOL, random.RandomType.HALTON,
                       random.RandomType.HALTON_RANDOMIZED):
        normal_draws = utils.generate_mc_normal_draws(
            num_normal_draws=dim,
            num_time_steps=steps_num,
            num_sample_paths=num_samples,
            random_type=random_type,
            dtype=dtype,
            seed=seed,
            skip=skip)
        wiener_mean = None
    else:
        # If pseudo or antithetic sampling is used, proceed with random sampling
        # at each step.
        wiener_mean = tf.zeros((dim, ), dtype=dtype, name='wiener_mean')
        normal_draws = None
    cond_fn = lambda i, *args: i < steps_num

    # The maximum number of iterations is passed to the while loop below. It
    # improves performance of the while loop on a GPU and is needed for
    # XLA-compilation compatibility.
    def step_fn(i, written_count, current_state, result):
        return _euler_step(i, written_count, current_state, result, drift_fn,
                           volatility_fn, wiener_mean, num_samples, times, dt,
                           sqrt_dt, keep_mask, random_type, seed, normal_draws)

    maximum_iterations = (tf.cast(1. / time_step, dtype=tf.int32) +
                          tf.size(times))
    result = tf.TensorArray(dtype=dtype, size=times_shape[-1])
    _, _, _, result = tf.while_loop(cond_fn,
                                    step_fn, (0, 0, current_state, result),
                                    maximum_iterations=maximum_iterations,
                                    swap_memory=swap_memory)
    result = tf.transpose(result.stack(), (1, 0, 2))
    # Shape of `rate_paths` is dynamic in `times` dimension because of
    # `TensorArray`. In order to make the shape static, use `set_shape` method.
    # TODO(b/148854825): Consider removing TensorArray to make all shapes static.
    result.set_shape(current_state.shape[:1] + times_shape +
                     current_state.shape[-1:])
    return result
Example #6
def SanitizedAutoCorrelationMean(x,
                                 axis,
                                 reduce_axis,
                                 max_lags=None,
                                 **kwargs):
    shape_arr = np.array(list(x.shape))
    axes = list(sorted(set(range(len(shape_arr))) - set([reduce_axis])))
    mean_shape = shape_arr[axes]
    if max_lags is not None:
        mean_shape[axis] = max_lags + 1
    mean_state = fun_mc.running_mean_init(mean_shape, x.dtype)
    new_order = list(range(len(shape_arr)))
    new_order[0] = new_order[reduce_axis]
    new_order[reduce_axis] = 0
    x = tf.transpose(x, new_order)
    x_arr = tf.TensorArray(x.dtype, x.shape[0]).unstack(x)
    mean_state, _ = fun_mc.trace(
        state=mean_state,
        fn=lambda state: fun_mc.running_mean_step(  # pylint: disable=g-long-lambda
            state,
            SanitizedAutoCorrelation(x_arr.read(state.num_points),
                                     axis,
                                     max_lags=max_lags,
                                     **kwargs)),
        num_steps=x.shape[0],
        trace_fn=lambda *_: ())
    return mean_state.mean
Example #7
def _initialize_loop_variables(initial_step_results,
                               num_steps_state_history_to_pass, num_timesteps):
    """Initialize arrays and other quantities passed through the filter loop."""

    # Create arrays to store particles, indices, and likelihoods, and write
    # their initial values.
    step_results_arrays = tf.nest.map_structure(
        lambda x: tf.TensorArray(dtype=x.dtype, size=num_timesteps).write(
            0, x), initial_step_results)

    # Because `while_loop` requires Tensor values, we'll represent the lack of
    # state history by a static-shape empty Tensor.
    # This can be detected elsewhere by branching on
    #  `tf.is_tensor(state_history) and state_history.shape[0] == 0`.
    state_history = tf.zeros([0])
    if num_steps_state_history_to_pass:
        # Repeat the initial state, so that `state_history` always has length
        # `num_steps_state_history_to_pass`.
        state_history = tf.nest.map_structure(
            lambda x: tf.broadcast_to(  # pylint: disable=g-long-lambda
                x[tf.newaxis, ...],
                prefer_static.concat([[num_steps_state_history_to_pass],
                                      prefer_static.shape(x)],
                                     axis=0)),
            initial_step_results.particles)

    return ParticleFilterLoopVariables(
        step=1,
        previous_step_results=initial_step_results,
        accumulated_step_results=step_results_arrays,
        state_history=state_history)
Example #8
def nested_for_loops(m):
    l = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=())
    for i in m:
        s = 0
        for j in i:
            s = s * 10 + j
        l = l.write(l.size(), s)
    return l.stack()
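A quick eager-mode check of the function above; each row of the input is folded into a base-10 number:

import tensorflow as tf

print(nested_for_loops(tf.constant([[1, 2, 3], [4, 5, 6]])))  # [123, 456]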
Example #9
def _map_body(trace_state):
    if not tf.is_tensor(trace_state):
        trace_state = tf.convert_to_tensor(trace_state)
    return tf.TensorArray(dtype=trace_state.dtype,
                          size=size,
                          dynamic_size=dynamic_size,
                          element_shape=trace_state.shape,
                          clear_after_read=False)
Example #10
def _initialize_arrays(initial_values, num_steps):
    """Construct a structure of `TraceArray`s from initial values."""
    trace_arrays = tf.nest.map_structure(
        lambda t: tf.TensorArray(  # pylint: disable=g-long-lambda
            dtype=t.dtype,
            size=num_steps,  # Initial size.
            clear_after_read=False,  # Allow reading->tiling final value.
            element_shape=t.shape),
        initial_values)
    return tf.nest.map_structure(lambda ta, t: ta.write(0, t), trace_arrays,
                                 initial_values)
Example #11
def nested_while_loops(n1, n2):
    i = 0
    l = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=())
    while i < n1:
        j = 0
        s = 0
        while j < n2:
            s = s * 10 + i * j
            j += 1
        l = l.write(i, s)
        i += 1
    return l.stack()
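A similar eager-mode check for the nested while loops above, where the digit contributed at inner position `j` is `i * j`:

print(nested_while_loops(3, 3))  # [0, 12, 24]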
Example #12
def _init(shape_and_dtype):
  if USE_TENSORARRAY:
    return [
        tf.TensorArray(dtype=d,  # pylint: disable=g-complex-comprehension
                       size=self.max_tree_depth + 1,
                       element_shape=s,
                       clear_after_read=False)
        for (s, d) in shape_and_dtype]
  else:
    return [
        tf.zeros(  # pylint: disable=g-complex-comprehension
            tf.TensorShape([self.max_tree_depth + 1]).concatenate(s),
            dtype=d)
        for (s, d) in shape_and_dtype]
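A standalone sketch (with illustrative depth and shapes) of the trade-off the `USE_TENSORARRAY` flag above encodes: one `TensorArray` slot per tree depth versus a preallocated `tf.zeros` buffer updated by scatter:

import tensorflow as tf

max_tree_depth = 3
shape, dtype = tf.TensorShape([2]), tf.float32

ta = tf.TensorArray(dtype, size=max_tree_depth + 1, element_shape=shape,
                    clear_after_read=False)
ta = ta.write(1, tf.ones(shape))  # fill one slot
buf = tf.zeros(tf.TensorShape([max_tree_depth + 1]).concatenate(shape), dtype)
buf = tf.tensor_scatter_nd_update(buf, [[1]], [tf.ones(shape)])
# ta.read(1) and buf[1] are both [1., 1.]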
Example #13
def _initialize_arrays(initial_values, num_steps, truncate_at_convergence):
    """Construct a structure of `TraceArray`s from initial values."""
    num_steps_ = tf.get_static_value(tf.convert_to_tensor(num_steps))
    size_is_dynamic = (num_steps_ is None or truncate_at_convergence)
    trace_arrays = tf.nest.map_structure(
        lambda t: tf.TensorArray(  # pylint: disable=g-long-lambda
            dtype=t.dtype,
            size=1 if size_is_dynamic else num_steps_,  # Initial size.
            dynamic_size=size_is_dynamic,
            clear_after_read=False,  # Allow reading->tiling final value.
            element_shape=t.shape),
        initial_values)
    return tf.nest.map_structure(lambda ta, t: ta.write(0, t), trace_arrays,
                                 initial_values)
Example #14
def update_backward_differences(backward_differences, next_backward_difference,
                                next_state_vec, order):
    """Returns the backward differences for the next time."""
    backward_differences_array = tf.TensorArray(
        backward_differences.dtype,
        size=MAX_ORDER + 3,
        clear_after_read=False,
        element_shape=next_backward_difference.shape).unstack(
            backward_differences)
    new_backward_differences_array = tf.TensorArray(
        backward_differences.dtype,
        size=MAX_ORDER + 3,
        clear_after_read=False,
        element_shape=next_backward_difference.shape)
    new_backward_differences_array = new_backward_differences_array.write(
        order + 2,
        next_backward_difference - backward_differences_array.read(order + 1))
    new_backward_differences_array = new_backward_differences_array.write(
        order + 1, next_backward_difference)

    def body(k, new_backward_differences_array_):
        new_backward_differences_array_k = (
            new_backward_differences_array_.read(k + 1) +
            backward_differences_array.read(k))
        new_backward_differences_array_ = new_backward_differences_array_.write(
            k, new_backward_differences_array_k)
        return k - 1, new_backward_differences_array_

    _, new_backward_differences_array = tf.while_loop(
        lambda k, new_backward_differences_array: k > 0, body,
        [order, new_backward_differences_array])
    new_backward_differences_array = new_backward_differences_array.write(
        0, next_state_vec)
    new_backward_differences = new_backward_differences_array.stack()
    tensorshape_util.set_shape(new_backward_differences,
                               tf.TensorShape([MAX_ORDER + 3, None]))
    return new_backward_differences
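The unstack/read pattern used above, in isolation (illustrative shapes): an existing `Tensor` is unstacked into a `TensorArray` so that individual rows can be read by a dynamic index:

import tensorflow as tf

diffs = tf.reshape(tf.range(6.), [3, 2])
diffs_array = tf.TensorArray(diffs.dtype, size=3, clear_after_read=False,
                             element_shape=diffs.shape[1:]).unstack(diffs)
# diffs_array.read(1) == [2., 3.]; diffs_array.stack() recovers `diffs`.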
Example #15
def _initialize_accumulated_quantities(
    initial_kalman_filter_state,
    observations,
    num_timesteps,
):
  """Initialize quantities to accumulate, specifying dtype/shape."""
  initial_arrays = []
  for x in initial_kalman_filter_state:
    # pylint: disable=cell-var-from-loop
    initial_arrays.append(
        tf.nest.map_structure(
            lambda _: tf.TensorArray(  # pylint: disable=g-long-lambda
                dtype=x.dtype, element_shape=x.shape, size=num_timesteps),
            observations))
    # pylint: enable=cell-var-from-loop
  return linear_gaussian_ssm.KalmanFilterState(*initial_arrays)
Example #16
    def _sample_paths(self, times, grid_step, keep_mask, times_size,
                      num_samples, initial_state, random_type, seed,
                      swap_memory):
        """Returns a sample of paths from the process."""
        dt = times[1:] - times[:-1]
        sqrt_dt = tf.sqrt(dt)
        current_state = initial_state + tf.zeros(
            [num_samples, self.dim()], dtype=initial_state.dtype)
        steps_num = tf.shape(dt)[-1]
        wiener_mean = tf.zeros((self.dim(), 1), dtype=self._dtype)

        cond_fn = lambda i, *args: i < steps_num

        def step_fn(i, written_count, current_state, result):
            """Performs one step of Euler scheme."""
            current_time = times[i + 1]
            dw = random_ops.mv_normal_sample((num_samples, ),
                                             mean=wiener_mean,
                                             random_type=random_type,
                                             seed=seed)
            dw = dw * sqrt_dt[i]
            dt_inc = dt[i] * self.drift_fn()(current_time, current_state)  # pylint: disable=not-callable
            dw_inc = tf.squeeze(
                tf.matmul(self.volatility_fn()(current_time, current_state),
                          dw), -1)  # pylint: disable=not-callable
            next_state = current_state + dt_inc + dw_inc

            # Keep only the states for times requested by the user.
            result = tf.cond(keep_mask[i + 1],
                             (lambda: result.write(written_count, next_state)),
                             (lambda: result))
            written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
            return (i + 1, written_count, next_state, result)

        # The maximum number of iterations is passed to the while loop below. It
        # improves performance of the while loop on a GPU and is needed for
        # XLA-compilation compatibility.
        maximum_iterations = (tf.cast(1. / grid_step, dtype=tf.int32) +
                              tf.size(times))
        result = tf.TensorArray(dtype=self._dtype, size=times_size)
        _, _, _, result = tf.compat.v1.while_loop(
            cond_fn,
            step_fn, (0, 0, current_state, result),
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory)

        return tf.transpose(result.stack(), (1, 0, 2))
Example #17
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               parallel_iterations=10,
               name=None):
    """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of `loop_fn`
  for every iteration thereafter. `elem` is a slice of `elems` along the
  first dimension, accessed in order. Additionally, it calls `trace_fn` on the
  return value of `loop_fn`. The `Tensor`s in return values of `trace_fn` are
  stacked and returned from this function, such that the first dimension of
  those `Tensor`s matches the size of `elems`.

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of `elems`
      and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and returns
      a `Tensor` or a nested collection of `Tensor`s.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensors` in the return value
      of `trace_fn` for each slice of `elems`.
  """
    with tf.name_scope(name or 'trace_scan'), tf1.variable_scope(
            tf1.get_variable_scope()) as vs:
        if vs.caching_device is None and not tf.executing_eagerly():
            vs.set_caching_device(lambda op: op.device)

        initial_state = tf.nest.map_structure(
            lambda x: tf.convert_to_tensor(x, name='initial_state'),
            initial_state)
        elems = tf.convert_to_tensor(elems, name='elems')

        length = prefer_static.size0(elems)
        static_length = length if prefer_static.is_numpy(length) else None

        # This is a TensorArray in part because of XLA, which had trouble with
        # non-statically known indices. I.e. elems[i] errored, but
        # elems_array.read(i) worked.
        elems_array = tf.TensorArray(elems.dtype,
                                     size=length,
                                     element_shape=elems.shape[1:])
        elems_array = elems_array.unstack(elems)

        trace_arrays = tf.nest.map_structure(
            lambda x: tf.TensorArray(
                x.dtype, size=length, element_shape=x.shape),
            trace_fn(initial_state))

        def _body(i, state, trace_arrays):
            state = loop_fn(state, elems_array.read(i))
            trace_arrays = tf.nest.pack_sequence_as(trace_arrays, [
                a.write(i, v) for a, v in zip(tf.nest.flatten(trace_arrays),
                                              tf.nest.flatten(trace_fn(state)))
            ])
            return i + 1, state, trace_arrays

        _, final_state, trace_arrays = tf.while_loop(
            cond=lambda i, *args: i < length,
            body=_body,
            loop_vars=(0, initial_state, trace_arrays),
            parallel_iterations=parallel_iterations)

        stacked_trace = tf.nest.map_structure(lambda x: x.stack(),
                                              trace_arrays)

        # Restore the static length if we know it.
        def _merge_static_length(x):
            x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
            return x

        stacked_trace = tf.nest.map_structure(_merge_static_length,
                                              stacked_trace)
        return final_state, stacked_trace
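A hedged usage sketch of `trace_scan` (assuming the module-level imports it relies on, such as `tf1` and `prefer_static`, are in scope): a running sum whose intermediate states are traced in two forms:

import tensorflow as tf

final_state, trace = trace_scan(
    loop_fn=lambda state, elem: state + elem,
    initial_state=tf.constant(0.),
    elems=tf.constant([1., 2., 3., 4.]),
    trace_fn=lambda state: {'state': state, 'twice': 2. * state})
# final_state == 10.; trace['state'] == [1., 3., 6., 10.] and trace['twice'] is double that.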
Example #18
            def grad_fn(*dresults, **kwargs):
                """Adjoint sensitivity method to compute gradients."""
                dresults = tf.nest.pack_sequence_as(results, dresults)
                dstates = dresults.states
                # The signature grad_fn(*dresults, variables=None) is not valid Python 2
                # so use kwargs instead.
                variables = kwargs.pop('variables', [])
                assert not kwargs  # This assert should never fail.
                # TODO(b/138304303): Support complex types.
                with tf.name_scope('{}Gradients'.format(self._name)):
                    get_dtype = lambda x: x.dtype

                    def error_if_complex(dtype):
                        if dtype.is_complex:
                            raise NotImplementedError(
                                'The adjoint sensitivity method does '
                                'not support complex dtypes.')

                    state_dtypes = tf.nest.map_structure(
                        get_dtype, initial_state)
                    tf.nest.map_structure(error_if_complex, state_dtypes)
                    common_state_dtype = dtype_util.common_dtype(initial_state)
                    real_dtype = dtype_util.real_dtype(common_state_dtype)

                    # We add initial_time to ensure that we know where to stop.
                    result_times = tf.concat(
                        [[tf.cast(initial_time, real_dtype)], results.times],
                        0)
                    num_result_times = tf.size(result_times)

                    # The first two components correspond to reverse and adjoint
                    # states; the last component is the adjoint state for variables.
                    terminal_augmented_state = tuple([
                        rk_util.nest_constant(initial_state, 0.0),
                        rk_util.nest_constant(initial_state, 0.0),
                        tuple(
                            rk_util.nest_constant(variable, 0.0)
                            for variable in variables)
                    ])

                    # The XLA compiler does not compile code which slices/indexes using
                    # integer `Tensor`s. `TensorArray`s are used to get around this.
                    result_time_array = tf.TensorArray(
                        results.times.dtype,
                        clear_after_read=False,
                        size=num_result_times,
                        element_shape=[]).unstack(result_times)

                    # TensorArray shape should not include time dimension, hence shape[1:]
                    result_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(results.states)
                    ]
                    result_state_arrays = tf.nest.pack_sequence_as(
                        results.states, result_state_arrays)
                    dresult_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(dstates)
                    ]
                    dresult_state_arrays = tf.nest.pack_sequence_as(
                        results.states, dresult_state_arrays)

                    def augmented_ode_fn(backward_time, augmented_state):
                        """Dynamics function for the augmented system.

            Describes a differential equation that evolves the augmented state
            backwards in time to compute gradients using the adjoint method.
            Augmented state consists of 3 components `(state, adjoint_state,
            vars)` all evaluated at time `backward_time`:

            state: represents the solution of user provided `ode_fn`. The
              structure coincides with the `initial_state`.
            adjoint_state: represents the solution of adjoint sensitivity
              differential equation as discussed below. Has the same structure
              and shape as `state`.
            vars: represent the solution of the adjoint equation for variable
              gradients. Represented as a `Tuple(Tensor, ...)` with as many
              tensors as there are `variables`.

            Adjoint sensitivity equation describes the gradient of the solution
            with respect to the value of the solution at previous time t. Its
            dynamics are given by
            d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
            which is computed as:
            d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j)]
            d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
            where in the last line we moved adj(t)_j under derivative by
            removing gradient from it.

            Adjoint equation for the gradient with respect to every
            `tf.Variable` theta follows:
            d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
            = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

            Args:
              backward_time: Floating `Tensor` representing current time.
              augmented_state: `Tuple(state, adjoint_state, variable_grads)`

            Returns:
              negative_derivatives: Structure of `Tensor`s equal to backwards
                time derivative of the `state` component.
              adjoint_ode: Structure of `Tensor`s equal to backwards time
                derivative of the `adjoint_state` component.
              adjoint_variables_ode: Structure of `Tensor`s equal to backwards
                time derivative of the `vars` component.
            """
                        # The negative sign disappears after the change of variables.
                        # The ODE solver cannot handle the case initial_time > final_time
                        # and hence a change of variables backward_time = -time is used.
                        time = -backward_time
                        state, adjoint_state, _ = augmented_state

                        with tf.GradientTape() as tape:
                            tape.watch(variables)
                            tape.watch(state)
                            derivatives = ode_fn(time, state)
                            adjoint_no_grad = tf.nest.map_structure(
                                tf.stop_gradient, adjoint_state)
                            negative_derivatives = rk_util.weighted_sum(
                                [-1.0], [derivatives])

                            def dot_prod(tensor_a, tensor_b):
                                return tf.reduce_sum(tensor_a * tensor_b)

                            # See docstring for details.
                            adjoint_dot_derivatives = tf.nest.map_structure(
                                dot_prod, adjoint_no_grad, derivatives)
                            adjoint_dot_derivatives = tf.squeeze(
                                tf.add_n(
                                    tf.nest.flatten(adjoint_dot_derivatives)))

                        adjoint_ode, adjoint_variables_ode = tape.gradient(
                            adjoint_dot_derivatives, (state, tuple(variables)),
                            unconnected_gradients=tf.UnconnectedGradients.ZERO)
                        return negative_derivatives, adjoint_ode, adjoint_variables_ode

                    def reverse_to_result_time(n, augmented_state, _):
                        """Integrates the augmented system backwards in time."""
                        lower_bound_of_integration = result_time_array.read(n)
                        upper_bound_of_integration = result_time_array.read(n -
                                                                            1)
                        _, adjoint_state, adjoint_variable_state = augmented_state
                        initial_state = _read_solution_components(
                            result_state_arrays, input_state_structure, n - 1)
                        initial_adjoint = _read_solution_components(
                            dresult_state_arrays, input_state_structure, n - 1)
                        initial_adjoint_state = rk_util.weighted_sum(
                            [1.0, 1.0], [adjoint_state, initial_adjoint])
                        initial_augmented_state = (initial_state,
                                                   initial_adjoint_state,
                                                   adjoint_variable_state)
                        # TODO(b/138304303): Allow the user to specify the Hessian of
                        # `ode_fn` so that we can get the Jacobian of the adjoint system.
                        # TODO(b/143624114): Support higher order derivatives.
                        augmented_results = self._solve(
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            solution_times=[-upper_bound_of_integration],
                            batch_ndims=batch_ndims)
                        # Results added an extra time dim of size 1, squeeze it.
                        select_result = lambda x: tf.squeeze(x, [0])
                        result_state = augmented_results.states
                        result_state = tf.nest.map_structure(
                            select_result, result_state)
                        status = augmented_results.diagnostics.status
                        return n - 1, result_state, status

                    _, augmented_state, _ = tf.while_loop(
                        lambda n, _, status: (n >= 1) & tf.equal(status, 0),
                        reverse_to_result_time,
                        (num_result_times - 1, terminal_augmented_state, 0),
                        back_prop=False)
                    _, adjoint_state, adjoint_variables = augmented_state
                    return adjoint_state, list(adjoint_variables)
Example #19
def diag_jacobian(xs,
                  ys=None,
                  sample_shape=None,
                  fn=None,
                  parallel_iterations=10,
                  name=None):
    """Computes diagonal of the Jacobian matrix of `ys=fn(xs)` wrt `xs`.

  If `ys` is a tensor or a list of tensors of the form `(ys_1, .., ys_n)` and
  `xs` is of the form `(xs_1, .., xs_n)`, the function `jacobians_diag`
  computes the diagonal of the Jacobian matrix, i.e., the partial derivatives
  `(dys_1/dxs_1,.., dys_n/dxs_n`). For definition details, see
  https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant

  #### Example

  ##### Diagonal Hessian of the log-density of a 3D Gaussian distribution

  In this example we compute the diagonal of the Hessian of the log-density of
  a 3D Gaussian distribution.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  import numpy as np

  tfd = tfp.distributions

  dtype = np.float32
  with tf.Session(graph=tf.Graph()) as sess:
    true_mean = dtype([0, 0, 0])
    true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
    chol = tf.linalg.cholesky(true_cov)
    target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

    # Assume that the state is passed as a list of tensors `x` and `y`.
    # Then the target function is defined as follows:
    def target_fn(x, y):
      # Stack the input tensors together
      z = tf.concat([x, y], axis=-1) - true_mean
      return target.log_prob(z)

    sample_shape = [3, 5]
    state = [tf.ones(sample_shape + [2], dtype=dtype),
             tf.ones(sample_shape + [1], dtype=dtype)]
    fn_val, grads = tfp.math.value_and_gradient(target_fn, state)

    # We can either pass the `sample_shape` of the `state` or not, which impacts
    # computational speed of `diag_jacobian`
    _, diag_jacobian_shape_passed = diag_jacobian(
        xs=state, ys=grads, sample_shape=tf.shape(fn_val))
    _, diag_jacobian_shape_none = diag_jacobian(
        xs=state, ys=grads)

    diag_jacobian_shape_passed_ = sess.run(diag_jacobian_shape_passed)
    diag_jacobian_shape_none_ = sess.run(diag_jacobian_shape_none)

  print('hessian computed through `diag_jacobian`, sample_shape passed: ',
        np.concatenate(diag_jacobian_shape_passed_, -1))
  print('hessian computed through `diag_jacobian`, sample_shape skipped',
        np.concatenate(diag_jacobian_shape_none_, -1))

  ```

  Args:
    xs: `Tensor` or a python `list` of `Tensors` of real-like dtypes and shapes
      `sample_shape` + `event_shape_i`, where `event_shape_i` can be different
      for different tensors.
    ys: `Tensor` or a python `list` of `Tensors` of the same dtype as `xs`. Must
        broadcast with the shape of `xs`. Can be omitted if `fn` is provided.
    sample_shape: A common `sample_shape` of the input tensors of `xs`. If not
      provided, it is assumed to be `[1]`, which may result in slow performance
      of `jacobians_diag`.
    fn: Python callable that takes `xs` as an argument (or `*xs`, if it is a
      list) and returns `ys`. Might be skipped if `ys` is provided and
      `tf.enable_eager_execution()` is disabled.
    parallel_iterations: `int` that specifies the allowed number of coordinates
      of the input tensor `xs`, for which the partial derivatives `dys_i/dxs_i`
      can be computed in parallel.
    name: Python `str` name prefixed to `Ops` created by this function.
      Default value: `None` (i.e., "diag_jacobian").

  Returns:
    ys: a list, which coincides with the input `ys`, when provided.
      If the input `ys` is None, `fn(*xs)` gets computed and returned as a list.
    jacobians_diag_res: a `Tensor` or a Python list of `Tensor`s of the same
      dtypes and shapes as the input `xs`. This is the diagonal of the Jacobian
      of ys wrt xs.

  Raises:
    ValueError: if lists `xs` and `ys` have different length or both `ys` and
      `fn` are `None`, or `fn` is None in the eager execution mode.
  """
    with tf.name_scope(name or 'jacobians_diag'):
        if sample_shape is None:
            sample_shape = [1]
        # Output Jacobian diagonal
        jacobians_diag_res = []
        # Convert input `xs` to a list
        xs = list(xs) if _is_list_like(xs) else [xs]
        xs = [tf.convert_to_tensor(x) for x in xs]
        if not tf.executing_eagerly():
            if ys is None:
                if fn is None:
                    raise ValueError('Both `ys` and `fn` can not be `None`')
                else:
                    ys = fn(*xs)
            # Convert ys to a list
            ys = list(ys) if _is_list_like(ys) else [ys]
            if len(xs) != len(ys):
                raise ValueError('`xs` and `ys` should have the same length')
            for y, x in zip(ys, xs):
                # Broadcast `y` to the shape of `x`.
                y_ = y + tf.zeros_like(x)
                # Flatten `event_shape` to one dimension.
                y_ = tf.reshape(y_, tf.concat([sample_shape, [-1]], -1))

                # Declare an iterator and tensor array loop variables for the gradients.
                n = tf.size(x) / tf.cast(tf.reduce_prod(sample_shape),
                                         dtype=tf.int32)
                n = tf.cast(n, dtype=tf.int32)
                loop_vars = [0, tf.TensorArray(x.dtype, n)]

                def loop_body(j):
                    """Loop function to compute gradients of the each direction."""
                    # Gradient along direction `j`.
                    res = tf.gradients(ys=y_[..., j], xs=x)[0]  # pylint: disable=cell-var-from-loop
                    if res is None:
                        # Return zero, if the gradient is `None`.
                        res = tf.zeros(tf.concat([sample_shape, [1]], -1),
                                       dtype=x.dtype)  # pylint: disable=cell-var-from-loop
                    else:
                        # Reshape `event_shape` to 1D
                        res = tf.reshape(res,
                                         tf.concat([sample_shape, [-1]], -1))
                        # Add artificial dimension for the case of zero shape input tensor
                        res = res[tf.newaxis, ..., j]
                    return res  # pylint: disable=cell-var-from-loop

                # Iterate over all elements of the gradient and compute second order
                # derivatives.
                _, jacobian_diag_res = tf.while_loop(
                    cond=lambda j, _: j < n,  # pylint: disable=cell-var-from-loop
                    body=lambda j, result:
                    (j + 1, result.write(j, loop_body(j))),
                    loop_vars=loop_vars,
                    parallel_iterations=parallel_iterations)

                shape_x = ps.shape(x)
                # Stack gradients together and move flattened `event_shape` to the
                # zero position
                reshaped_jacobian_diag = tf.transpose(
                    a=jacobian_diag_res.stack())
                # Reshape to the original tensor
                reshaped_jacobian_diag = tf.reshape(reshaped_jacobian_diag,
                                                    shape_x)
                jacobians_diag_res.append(reshaped_jacobian_diag)

        else:
            if fn is None:
                raise ValueError(
                    '`fn` can not be `None` when eager execution is '
                    'enabled')
            if ys is None:
                ys = fn(*xs)

            def fn_slice(i, j):
                """Broadcast y[i], flatten event shape of y[i], return y[i][..., j]."""
                def fn_broadcast(*state):
                    res = fn(*state)
                    res = list(res) if _is_list_like(res) else [res]
                    if len(res) != len(state):
                        res *= len(state)
                    res = [
                        tf.reshape(r + tf.zeros_like(s),
                                   ps.concat([sample_shape, [-1]], -1))
                        for r, s in zip(res, state)
                    ]
                    return res

                # Expand dimensions before returning in order to support 0D input `xs`
                return lambda *state: tf.expand_dims(
                    fn_broadcast(*state)[i], 0)[..., j]

            def make_loop_body(i, x):
                """Loop function to compute gradients of the each direction."""
                def _fn(j, result):
                    res = value_and_gradient(fn_slice(i, j), xs)[1][i]
                    if res is None:
                        res = tf.zeros(sample_shape, dtype=x.dtype)
                    else:
                        res = tf.reshape(res,
                                         ps.concat([sample_shape, [-1]], -1))
                        res = res[..., j]
                    return j + 1, result.write(j, res)

                return _fn

            for i, x in enumerate(xs):
                # Declare an iterator and tensor array loop variables for the gradients.
                n = ps.size(x) / ps.cast(ps.reduce_prod(sample_shape),
                                         dtype=tf.int32)
                n = ps.cast(n, dtype=tf.int32)
                loop_vars = (0,
                             tf.TensorArray(x.dtype,
                                            n,
                                            element_shape=sample_shape))

                # Iterate over all elements of the gradient and compute second order
                # derivatives.
                _, jacobian_diag_res = tf.while_loop(
                    cond=lambda j, _: j < n,
                    body=make_loop_body(i, x),
                    loop_vars=loop_vars,
                    parallel_iterations=parallel_iterations)

                shape_x = ps.shape(x)
                # Stack gradients together and move flattened `event_shape` to the
                # zero position
                reshaped_jacobian_diag = tf.transpose(
                    jacobian_diag_res.stack())
                # Reshape to the original tensor
                reshaped_jacobian_diag = tf.reshape(reshaped_jacobian_diag,
                                                    shape_x)
                jacobians_diag_res.append(reshaped_jacobian_diag)

    return ys, jacobians_diag_res
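A small eager-mode sketch of `diag_jacobian` (assuming the module imports it uses, e.g. `ps` and `value_and_gradient`, are in scope): for `ys = xs**2` the Jacobian diagonal should come out as `2 * xs`:

import tensorflow as tf

x = tf.constant([[1., 2., 3.]])
_, jac_diag = diag_jacobian(xs=x, fn=lambda x: x**2, sample_shape=[1])
# jac_diag[0] is approximately [[2., 4., 6.]]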
Example #20
    def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
        """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a batch of input examples.
    Returns:
      a dict mapping feature keys to packed `Tensor`s.
    """
        keys = list(feature_lengths)
        partial = empty_example.copy()
        first_key, *_ = keys
        dynamic_batch_size = tf.shape(x[first_key])[0]
        outputs = {}
        for k in keys:
            outputs[k] = tf.TensorArray(tf.int32,
                                        size=0,
                                        dynamic_size=True,
                                        element_shape=[feature_lengths[k]])
            outputs[k + "_positions"] = tf.TensorArray(
                tf.int32,
                size=0,
                dynamic_size=True,
                element_shape=[feature_lengths[k]])

        for i in tf.range(0, dynamic_batch_size):
            tf.autograph.experimental.set_loop_options(shape_invariants=[(
                partial, {k: tf.TensorShape([None])
                          for k in keys_etc}
            ), (outputs, {k: tf.TensorShape(None)
                          for k in keys_etc})])

            can_append = True
            one_example = {}
            for k in keys:
                val = tf.cast(x[k][i], tf.int32)
                val = val[:tf.
                          reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
                one_example[k] = val
            for k in keys:
                can_append = tf.logical_and(
                    can_append,
                    tf.less_equal(
                        tf.size(partial[k]) + tf.size(one_example[k]),
                        feature_lengths[k]))

            if not can_append:
                partial, outputs = _write_packed_example(partial, outputs)

            new_partial = {}
            for k in keys:
                new_seq = one_example[k][:feature_lengths[k]]
                new_seq_len = tf.size(new_seq)
                new_partial[k] = tf.concat([partial[k], new_seq], 0)
                new_partial[k + "_positions"] = tf.concat([
                    partial[k + "_positions"],
                    tf.range(new_seq_len, dtype=tf.int32)
                ], 0)
            partial = new_partial

        partial, outputs = _write_packed_example(partial, outputs)
        packed = {k: outputs[k].stack() for k in keys_etc}
        for k in keys:
            packed[k + "_segment_ids"] = (tf.cumsum(
                tf.cast(tf.equal(packed[k + "_positions"], 0), tf.int32),
                axis=1) * tf.cast(tf.not_equal(packed[k], 0), tf.int32))
        return packed
Example #21
    def _solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
    ):
        # Static assertions
        del jacobian_fn, jacobian_sparsity  # not used by DormandPrince
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'For homogeneous batching use `batch_ndims=0`.')
        solution_times_by_solver = isinstance(solution_times,
                                              base.ChosenBySolver)

        with tf.name_scope(self._name):
            # (2) Convert to tensors, determine dtypes.
            p = self._prepare_common_params(initial_state, initial_time)

            max_num_steps = self._max_num_steps
            max_ode_fn_evals = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
                max_ode_fn_evals = max_num_steps * self.ODE_FN_EVALS_PER_STEP
            step_size = tf.convert_to_tensor(self._step_size,
                                             dtype=p.real_dtype)
            rtol = tf.convert_to_tensor(tf.cast(self._rtol, p.real_dtype))
            atol = tf.convert_to_tensor(tf.cast(self._atol, p.real_dtype))
            # Use i(d)factor notation for increasing and decreasing factors.
            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
                solver_internal_state = self._initialize_solver_internal_state(
                    ode_fn=ode_fn,
                    initial_state=p.initial_state,
                    initial_time=p.initial_time,
                )

            num_solution_times = 0
            if solution_times_by_solver:
                final_time = tf.cast(solution_times.final_time, p.real_dtype)
                times_array = tf.TensorArray(p.real_dtype,
                                             size=num_solution_times,
                                             dynamic_size=True,
                                             element_shape=tf.TensorShape([]))
            else:
                solution_times = tf.cast(solution_times, p.real_dtype)
                util.error_if_not_vector(solution_times, 'solution_times')
                num_solution_times = tf.size(solution_times)
                times_array = tf.TensorArray(
                    p.real_dtype,
                    size=num_solution_times,
                    dynamic_size=False,
                    element_shape=[]).unstack(solution_times)

            solutions_arrays = [
                tf.TensorArray(dtype=component_dtype,
                               size=num_solution_times,
                               dynamic_size=solution_times_by_solver)
                for component_dtype in tf.nest.flatten(p.state_dtypes)
            ]
            solutions_arrays = tf.nest.pack_sequence_as(
                p.initial_state, solutions_arrays)

            rk_step = functools.partial(self._step,
                                        max_ode_fn_evals=max_ode_fn_evals,
                                        ode_fn=ode_fn,
                                        atol=atol,
                                        rtol=rtol,
                                        safety=safety,
                                        ifactor=ifactor,
                                        dfactor=dfactor)
            advance_to_solution_time = functools.partial(
                _advance_to_solution_time,
                times_array=solution_times,
                step_fn=rk_step,
                validate_args=self._validate_args)

            assert_ops = self._assert_ops(
                ode_fn=ode_fn,
                initial_time=p.initial_time,
                initial_state=p.initial_state,
                solution_times=solution_times,
                previous_solver_state=previous_solver_internal_state,
                rtol=rtol,
                atol=atol,
                first_step_size=step_size,
                safety_factor=safety,
                min_step_size_factor=ifactor,
                max_step_size_factor=dfactor,
                max_num_steps=max_num_steps,
                solution_times_by_solver=solution_times_by_solver)
            with tf.control_dependencies(assert_ops):
                ode_evals_by_now = 1 if self._validate_args else 0
                ode_evals_by_now += 1 if solver_internal_state is None else 0
                diagnostics = _DopriDiagnostics(
                    num_ode_fn_evaluations=ode_evals_by_now,
                    num_jacobian_evaluations=0,
                    num_matrix_factorizations=0,
                    status=0)

                if solution_times_by_solver:
                    r = _dense_solutions_to_final_time(
                        final_time=final_time,
                        solver_state=solver_internal_state,
                        diagnostics=diagnostics,
                        step_fn=rk_step,
                        ode_fn=ode_fn,
                        times_array=times_array,
                        solutions_arrays=solutions_arrays,
                        validate_args=self._validate_args)
                    solver_internal_state, diagnostics, times_array, solutions_arrays = r
                else:

                    def iterate_cond(time_id, *_):
                        return time_id < num_solution_times

                    [
                        _, solver_internal_state, diagnostics, solutions_arrays
                    ] = tf.while_loop(iterate_cond, advance_to_solution_time, [
                        0, solver_internal_state, diagnostics, solutions_arrays
                    ])

                times = times_array.stack()
                stack_components = lambda x: x.stack()
                states = tf.nest.map_structure(stack_components,
                                               solutions_arrays)
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
Example #22
        def vjp_bwd(results_constants, dresults, variables=()):
            """Adjoint sensitivity method to compute gradients."""
            results, constants = results_constants
            adjoint_solver = self._make_adjoint_solver_fn()
            dstates = dresults.states
            # TODO(b/138304303): Support complex types.
            with tf.name_scope('{}Gradients'.format(self._name)):
                get_dtype = lambda x: x.dtype

                def error_if_complex(dtype):
                    if dtype_util.is_complex(dtype):
                        raise NotImplementedError(
                            'The adjoint sensitivity method does '
                            'not support complex dtypes.')

                state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
                tf.nest.map_structure(error_if_complex, state_dtypes)
                common_state_dtype = dtype_util.common_dtype(initial_state)
                real_dtype = dtype_util.real_dtype(common_state_dtype)

                # We add initial_time to ensure that we know where to stop.
                result_times = tf.concat(
                    [[tf.cast(initial_time, real_dtype)], results.times], 0)
                num_result_times = tf.size(result_times)

                # The first two components correspond to reverse and adjoint states;
                # the last two are the adjoint states for variables and constants.
                terminal_augmented_state = tuple([
                    rk_util.nest_constant(initial_state, 0.0),
                    rk_util.nest_constant(initial_state, 0.0),
                    tuple(
                        rk_util.nest_constant(variable, 0.0)
                        for variable in variables),
                    rk_util.nest_constant(constants, 0.0),
                ])

                # The XLA compiler does not compile code which slices/indexes using
                # integer `Tensor`s. `TensorArray`s are used to get around this.
                result_time_array = tf.TensorArray(
                    results.times.dtype,
                    clear_after_read=False,
                    size=num_result_times,
                    element_shape=[]).unstack(result_times)

                # TensorArray shape should not include time dimension, hence shape[1:]
                result_state_arrays = [
                    tf.TensorArray(  # pylint: disable=g-complex-comprehension
                        dtype=component.dtype,
                        size=num_result_times - 1,
                        clear_after_read=False,
                        element_shape=component.shape[1:]).unstack(component)
                    for component in tf.nest.flatten(results.states)
                ]
                result_state_arrays = tf.nest.pack_sequence_as(
                    results.states, result_state_arrays)
                dresult_state_arrays = [
                    tf.TensorArray(  # pylint: disable=g-complex-comprehension
                        dtype=component.dtype,
                        size=num_result_times - 1,
                        clear_after_read=False,
                        element_shape=component.shape[1:]).unstack(component)
                    for component in tf.nest.flatten(dstates)
                ]
                dresult_state_arrays = tf.nest.pack_sequence_as(
                    results.states, dresult_state_arrays)

                def augmented_ode_fn(backward_time, augmented_state):
                    """Dynamics function for the augmented system.

          Describes a differential equation that evolves the augmented state
          backwards in time to compute gradients using the adjoint method.
          Augmented state consists of 4 components `(state, adjoint_state,
          vars, constants)` all evaluated at time `backward_time`:

          state: represents the solution of user provided `ode_fn`. The
            structure coincides with the `initial_state`.
          adjoint_state: represents the solution of the adjoint sensitivity
            differential equation as discussed below. Has the same structure
            and shape as `state`.
          variables: represent the solution of the adjoint equation for
            variable gradients. Represented as a `Tuple(Tensor, ...)` with as
            many tensors as there are `variables` variable outside this
            function.
          constants: represent the solution of the adjoint equation for
            constant gradients. Has the same structure and shape as
            `constants` variable outside this function.

          The adjoint sensitivity equation describes the gradient of the
          solution with respect to the value of the solution at a previous
          time t. Its dynamics are given by
          d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
          which is computed as:
          d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j)]
          d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
          where in the last line we moved adj(t)_j under derivative by
          removing gradient from it.

          Adjoint equation for the gradient with respect to every
          `tf.Variable` and constant theta follows:
          d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
          = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

          Args:
            backward_time: Floating `Tensor` representing current time.
            augmented_state: `Tuple(state, adjoint_state, variable_grads,
              constant_grads)`, as described above.

          Returns:
            negative_derivatives: Structure of `Tensor`s equal to backwards
              time derivative of the `state` component.
            adjoint_ode: Structure of `Tensor`s equal to backwards time
              derivative of the `adjoint_state` component.
            adjoint_variables_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `vars` component.
            adjoint_constants_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `constants` component.
          """
                    # The negative sign disappears after the change of variables.
                    # The ODE solver cannot handle the case initial_time > final_time
                    # and hence a change of variables backward_time = -time is used.
                    time = -backward_time
                    state, adjoint_state, _, _ = augmented_state

                    # TODO(b/152464477): Doesn't work reliably in TF1.
                    def grad_fn(state, variables, constants):
                        del variables  # We compute these gradients via the GradientTape
                        # capturing them.
                        derivatives = ode_fn(time, state, **constants)
                        adjoint_no_grad = tf.nest.map_structure(
                            tf.stop_gradient, adjoint_state)
                        negative_derivatives = rk_util.weighted_sum(
                            [-1.0], [derivatives])

                        def dot_prod(tensor_a, tensor_b):
                            return tf.reduce_sum(tensor_a * tensor_b)

                        # See docstring for details.
                        adjoint_dot_derivatives = tf.nest.map_structure(
                            dot_prod, adjoint_no_grad, derivatives)
                        adjoint_dot_derivatives = tf.squeeze(
                            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
                        return adjoint_dot_derivatives, negative_derivatives

                    values = (state, tuple(variables), constants)
                    ((_, negative_derivatives),
                     gradients) = tfp_gradient.value_and_gradient(
                         grad_fn, values, has_aux=True, use_gradient_tape=True)

                    (adjoint_ode, adjoint_variables_ode,
                     adjoint_constants_ode) = tf.nest.map_structure(
                         lambda v, g: tf.zeros_like(v)
                         if g is None else g, values, tuple(gradients))
                    return (negative_derivatives, adjoint_ode,
                            adjoint_variables_ode, adjoint_constants_ode)

                def make_augmented_state(n, prev_augmented_state):
                    """Constructs the augmented state for step `n`."""
                    (_, adjoint_state, adjoint_variable_state,
                     adjoint_constant_state) = prev_augmented_state
                    initial_state = _read_solution_components(
                        result_state_arrays,
                        input_state_structure,
                        n - 1,
                    )
                    initial_adjoint = _read_solution_components(
                        dresult_state_arrays,
                        input_state_structure,
                        n - 1,
                    )
                    initial_adjoint_state = rk_util.weighted_sum(
                        [1.0, 1.0], [adjoint_state, initial_adjoint])
                    augmented_state = (
                        initial_state,
                        initial_adjoint_state,
                        adjoint_variable_state,
                        adjoint_constant_state,
                    )
                    return augmented_state

                def reverse_to_result_time(n, augmented_state,
                                           solver_internal_state, _):
                    """Integrates the augmented system backwards in time."""
                    lower_bound_of_integration = result_time_array.read(n)
                    upper_bound_of_integration = result_time_array.read(n - 1)
                    initial_augmented_state = make_augmented_state(
                        n, augmented_state)
                    # TODO(b/138304303): Allow the user to specify the Hessian of
                    # `ode_fn` so that we can get the Jacobian of the adjoint system.
                    # TODO(b/143624114): Support higher order derivatives.
                    solver_internal_state = (
                        adjoint_solver.
                        _adjust_solver_internal_state_for_state_jump(  # pylint: disable=protected-access
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            previous_solver_internal_state=
                            solver_internal_state,
                            previous_state=augmented_state,
                        ))
                    augmented_results = adjoint_solver.solve(
                        ode_fn=augmented_ode_fn,
                        initial_time=-lower_bound_of_integration,
                        initial_state=initial_augmented_state,
                        solution_times=[-upper_bound_of_integration],
                        batch_ndims=batch_ndims,
                        previous_solver_internal_state=solver_internal_state,
                    )
                    # Results added an extra time dim of size 1, squeeze it.
                    select_result = lambda x: tf.squeeze(x, [0])
                    result_state = augmented_results.states
                    result_state = tf.nest.map_structure(
                        select_result, result_state)
                    status = augmented_results.diagnostics.status
                    return (n - 1, result_state,
                            augmented_results.solver_internal_state, status)

                initial_n = num_result_times - 1
                solver_internal_state = adjoint_solver._initialize_solver_internal_state(  # pylint: disable=protected-access
                    ode_fn=augmented_ode_fn,
                    initial_time=result_time_array.read(initial_n),
                    initial_state=make_augmented_state(
                        initial_n, terminal_augmented_state),
                )

                _, augmented_state, _, _ = tf.while_loop(
                    lambda n, _as, _sis, status:
                    (n >= 1) & tf.equal(status, 0),
                    reverse_to_result_time,
                    (initial_n, terminal_augmented_state,
                     solver_internal_state, 0),
                    back_prop=False,
                )
                (_, adjoint_state, adjoint_variables,
                 adjoint_constants) = augmented_state

                if variables:
                    return (adjoint_state,
                            adjoint_constants), list(adjoint_variables)
                else:
                    return adjoint_state, adjoint_constants
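
For context, a minimal sketch of how this adjoint machinery is usually exercised: differentiating an ODE solution with respect to its initial state through TFP's public solver API. The class and keyword names below (`tfp.math.ode.DormandPrince`, `solve(...)`) reflect recent TFP releases and should be treated as assumptions, not as part of the example above.

import tensorflow as tf
import tensorflow_probability as tfp

def ode_fn(t, y):
    # Simple linear ODE dy/dt = -y.
    return -y

y0 = tf.constant(1.5, dtype=tf.float64)
with tf.GradientTape() as tape:
    tape.watch(y0)
    results = tfp.math.ode.DormandPrince().solve(
        ode_fn, initial_time=0., initial_state=y0, solution_times=[1.])
    y1 = results.states[0]
# d y(1) / d y0; for dy/dt = -y this is approximately exp(-1). The gradient is
# obtained by evolving the augmented adjoint system backwards in time, as in
# the example above.
grad = tape.gradient(y1, y0)
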
Exemplo n.º 23
0
def scan(f, init, xs, length=None, reverse=False):
    """Scan a function over leading array axes while carrying along state.

  See the docstring of `jax.lax.scan`
  (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) for
  details.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output. Note that
      the input and output carry must have the same dtype.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof. This value
      must have the same structure as the first element of the pair returned
      by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
    init, xs = tf.nest.map_structure(
        lambda x: tf_np.asarray(x) if x is not None else None, (init, xs))
    init, xs = _np_to_tf((init, xs))

    def get_length(x):
        if x is None:
            return None
        if x.shape.rank == 0:
            raise ValueError(
                "Some array in `xs` doesn't have a leading dimension")
        return x.shape[0]

    lengths = tf.nest.flatten(tf.nest.map_structure(get_length, xs))
    for l in lengths:
        if l is not None:
            if length is None:
                length = l
            elif length != l:
                raise ValueError(
                    "There are two different leading-dimension lengths: "
                    f"{length} and {l}")
    if length is None:
        raise ValueError(
            "Can't determine length. Please set the `length` argument.")
    xs_ta = tf.nest.map_structure(
        lambda t: (
            tf.TensorArray(t.dtype, size=0, dynamic_size=True).unstack(t)  # pylint: disable=g-long-lambda
            if t is not None else None),
        xs)

    def body(i, carry, ys_ta):
        if reverse:
            i_ = length - 1 - i
        else:
            i_ = i
        xs = tf.nest.map_structure(
            lambda x_ta: x_ta.read(i_) if x_ta is not None else None, xs_ta)
        carry, ys = _np_to_tf(f(*_tf_to_np((carry, xs))))
        ys_ta = tf.nest.map_structure(
            lambda y_ta, y: (y_ta.write(i_, y)
                             if y is not None else y_ta), ys_ta, ys)
        i = i + 1
        return i, carry, ys_ta

    xs_spec = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype)
        if t is not None else None, xs)
    _, ys_spec = eval_on_shapes(f)(init, xs_spec)
    # ys_ta can't contain None because tf.while_loop doesn't allow None in
    # loop_vars.
    ys_ta = tf.nest.map_structure(
        lambda y: tf.TensorArray(
            y.dtype if y is not None else tf.float32,
            size=0,  # pylint: disable=g-long-lambda
            dynamic_size=True),
        ys_spec)
    _, carry, ys_ta = tf.while_loop(lambda i, *_: i < length, body,
                                    (0, init, ys_ta))

    def _stack(a, spec):
        if spec is None:
            return None
        a = a.stack()
        a.set_shape((length, ) + a.shape[1:])
        return a

    ys = tf.nest.map_structure(_stack, ys_ta, ys_spec)
    return _tf_to_np((carry, ys))
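
A small usage sketch for the `scan` above (assuming it and its helper module are in scope), computing a running sum. As with `jax.lax.scan`, ``f`` returns the new carry and the per-step output as a pair.

import numpy as np

def cumsum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (new carry, per-step output)

final, ys = scan(cumsum_step, np.float32(0.), np.arange(1, 6, dtype=np.float32))
# final == 15.0, ys == [1., 3., 6., 10., 15.]
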
Exemplo n.º 24
0
def _sample_multinomial_as_iterated_binomial(num_samples, num_classes, probs,
                                             num_trials, dtype, seed):
    """Sample a multinomial by drawing one binomial sample per class.

  The batch shape is given by broadcasting num_trials with
  remove_last_dimension(probs).

  The loop over binomial samples is a `tf.while_loop`, thus supporting a dynamic
  number of classes.

  Args:
    num_samples: Singleton integer Tensor: number of multinomial samples to
      draw.
    num_classes: Singleton integer Tensor: number of classes.
    probs: Floating Tensor with last dimension `num_classes`, of normalized
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial consists
      of.  num_trials[..., tf.newaxis] must broadcast with probs.
    dtype: dtype at which to emit samples.
    seed: Random seed.

  Returns:
    samples: Tensor of given dtype and shape [num_samples] + batch_shape +
      [num_classes].
  """
    with tf.name_scope('draw_sample'):
        # `convert_to_tensor(num_classes)` here avoids unstacking inside
        # `split_seed`.  We can't take advantage of the Python-list code path
        # anyway because the index at which we will take the seed is a Tensor.
        seeds = samplers.split_seed(seed,
                                    n=tf.convert_to_tensor(num_classes),
                                    salt='multinomial_draw_sample')

        def fn(i, num_trials, consumed_prob, accum):
            """Sample the counts for one class using binomial."""
            probs_here = tf.gather(probs, i, axis=-1)
            binomial_probs = tf.clip_by_value(
                probs_here / (1. - consumed_prob), 0, 1)
            seed_here = tf.gather(seeds, i, axis=0)
            binom = binomial.Binomial(total_count=num_trials,
                                      probs=binomial_probs)
            # Not passing `num_samples` to `binom.sample`, as it is already in
            # `num_trials.shape`.
            sample = binom.sample(seed=seed_here)
            accum = accum.write(i, tf.cast(sample, dtype=dtype))
            return i + 1, num_trials - sample, consumed_prob + probs_here, accum

        num_trials = tf.cast(num_trials, probs.dtype)
        # Pre-broadcast with probs
        num_trials += tf.zeros_like(probs[..., 0])
        # Pre-enlarge for different output samples
        num_trials = _replicate_along_left(num_trials, num_samples)
        i = tf.constant(0)
        consumed_prob = tf.zeros_like(probs[..., 0])
        accum = tf.TensorArray(dtype,
                               size=num_classes,
                               element_shape=num_trials.shape)
        _, num_trials_left, _, accum = tf.while_loop(
            cond=lambda index, _0, _1, _2: tf.less(index, num_classes - 1),
            body=fn,
            loop_vars=(i, num_trials, consumed_prob, accum))
        # Force the last iteration to put all the trials into the last bucket,
        # because probs[..., -1] / (1. - consumed_prob) might numerically not be 1.
        # Also saves one iteration around the while_loop and one run of the binomial
        # sampler.
        accum = accum.write(num_classes - 1,
                            tf.cast(num_trials_left, dtype=dtype))
        # This stop_gradient is necessary to prevent spurious zero gradients coming
        # from b/138796859, and a spurious gradient through num_trials_left.
        results = tf.stop_gradient(accum.stack())
        return distribution_util.move_dimension(results, 0, -1)
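
The decomposition used above can be illustrated with a short, self-contained NumPy sketch (not the function above): class k draws Binomial(remaining_trials, p_k / remaining_prob) counts, and the final class absorbs whatever trials are left.

import numpy as np

def multinomial_via_binomials(n, probs, rng=np.random.default_rng()):
    remaining, consumed, counts = n, 0.0, []
    for p in probs[:-1]:
        # Conditional binomial for this class, given the trials already used.
        q = min(max(p / (1.0 - consumed), 0.0), 1.0)
        c = rng.binomial(remaining, q)
        counts.append(c)
        remaining -= c
        consumed += p
    counts.append(remaining)  # force the leftover trials into the last bucket
    return np.array(counts)
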
Exemplo n.º 25
0
    def map_fn(x):
        """Internal function to flat_map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a dict mapping each feature key to a batched Tensor (one batch of
        input examples).
    Returns:
      a tf.data.Dataset
    """
        partial = empty_example.copy()
        i = tf.zeros([], dtype=tf.int32)
        dynamic_batch_size = tf.shape(x[keys[0]])[0]
        outputs = {}
        for k in keys:
            outputs[k] = tf.TensorArray(tf.int32,
                                        size=0,
                                        dynamic_size=True,
                                        element_shape=[length[k]])
            outputs[k + '_position'] = tf.TensorArray(
                tf.int32, size=0, dynamic_size=True, element_shape=[length[k]])

        def cond_fn(i, partial, outputs):
            del partial, outputs
            return i < dynamic_batch_size

        def body_fn(i, partial, outputs):
            """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray
      Returns:
        A triple containing the new values of the inputs.
      """
            can_append = True
            one_example = {}
            for k in keys:
                val = tf.cast(x[k][i], tf.int32)
                val = val[:tf.
                          reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
                one_example[k] = val
            for k in keys:
                can_append = tf.logical_and(
                    can_append,
                    tf.less_equal(
                        tf.size(partial[k]) + tf.size(one_example[k]),
                        length[k]))

            def false_fn():
                return write_packed_example(partial, outputs)

            def true_fn():
                return partial, outputs

            partial, outputs = tf.cond(can_append, true_fn, false_fn)
            new_partial = {}
            for k in keys:
                new_seq = one_example[k][:length[k]]
                new_seq_len = tf.size(new_seq)
                new_partial[k] = tf.concat([partial[k], new_seq], 0)
                new_partial[k + '_position'] = tf.concat([
                    partial[k + '_position'],
                    tf.range(new_seq_len, dtype=tf.int32)
                ], 0)
            partial = new_partial
            return i + 1, partial, outputs

        i, partial, outputs = \
            tf.while_loop(
                cond_fn, body_fn, (i, partial, outputs),
                shape_invariants=(
                    tf.TensorShape([]),
                    {k: tf.TensorShape([None]) for k in keys_etc},
                    {k: tf.TensorShape(None) for k in keys_etc},
                )
            )
        partial, outputs = write_packed_example(partial, outputs)
        packed = {k: outputs[k].stack() for k in keys_etc}
        for k in keys:
            packed[k + '_segmentation'] = (tf.cumsum(
                tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32),
                axis=1) * tf.cast(tf.not_equal(packed[k], 0), tf.int32))
        return packed
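
A tiny worked illustration of the `*_segmentation` computation at the end of `map_fn`: positions restart at 0 for every packed example, so a cumulative sum of the ``position == 0`` indicator labels each token with the (1-based) index of the example it came from, and multiplying by the nonzero mask keeps padding at 0. The tensors below are hypothetical.

import tensorflow as tf

packed_inputs = tf.constant([[7, 8, 9, 3, 4, 0]])  # two packed examples, then padding
positions = tf.constant([[0, 1, 2, 0, 1, 0]])
segmentation = (tf.cumsum(tf.cast(tf.equal(positions, 0), tf.int32), axis=1)
                * tf.cast(tf.not_equal(packed_inputs, 0), tf.int32))
# segmentation == [[1, 1, 1, 2, 2, 0]]
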
Exemplo n.º 26
0
 def slice_first_element_with_from_tensor_high_rank(self, t):
     ta = tf.TensorArray(dtype=tf.float32,
                         size=STATIC_SIZE,
                         element_shape=[STATIC_SIZE])
     ta = ta.unstack(t)
     return ta.read(0)
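
A brief usage sketch for the method above, assuming (hypothetically) that `STATIC_SIZE == 3` and that the method is available on some object `m`: unstacking a `[3, 3]` tensor row-wise and reading element 0 returns the first row.

import tensorflow as tf

t = tf.reshape(tf.range(9, dtype=tf.float32), [3, 3])
# m.slice_first_element_with_from_tensor_high_rank(t)  ->  [0., 1., 2.]
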
Exemplo n.º 27
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
        start_trajectory_seed, loop_seed = samplers.split_seed(seed)

        with tf.name_scope(self.name + '.one_step'):
            state_structure = current_state
            current_state = tf.nest.flatten(current_state)
            if (tf.nest.is_nested(state_structure)
                    and (not mcmc_util.is_list_like(state_structure)
                         or len(current_state) != len(state_structure))):
                # TODO(b/170865194): Support dictionaries and other non-list-like state.
                raise TypeError(
                    'NUTS does not currently support nested or '
                    'non-list-like state structures (saw: {}).'.format(
                        state_structure))

            current_target_log_prob = previous_kernel_results.target_log_prob
            [init_momentum, init_energy, log_slice_sample
             ] = self._start_trajectory_batched(current_state,
                                                current_target_log_prob,
                                                seed=start_trajectory_seed)

            def _copy(v):
                return v * ps.ones(ps.pad(
                    [2], paddings=[[0, ps.rank(v)]], constant_values=1),
                                   dtype=v.dtype)

            initial_state = TreeDoublingState(
                momentum=init_momentum,
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.grads_target_log_prob
            )
            initial_step_state = tf.nest.map_structure(_copy, initial_state)

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            candidate_state = TreeDoublingStateCandidate(
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.
                grads_target_log_prob,
                energy=init_energy,
                weight=init_weight)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=candidate_state,
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                momentum_sum=init_momentum,
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instructions into TensorArrays so that
            # they are compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=len(self._write_instruction),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(tf.int32,
                                              size=len(self._read_instruction),
                                              clear_after_read=False).unstack(
                                                  self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            _, _, _, new_step_metastate = tf.while_loop(
                cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
                    (iter_ < self.max_tree_depth) & tf.reduce_any(
                        metastate.continue_tree)),
                body=lambda iter_, seed, state, metastate: self.
                _loop_tree_doubling(  # pylint: disable=g-long-lambda
                    previous_kernel_results.step_size, previous_kernel_results.
                    momentum_state_memory, current_step_meta_info, iter_,
                    state, metastate, seed),
                loop_vars=(tf.zeros([], dtype=tf.int32,
                                    name='iter'), loop_seed,
                           initial_step_state, initial_step_metastate),
                parallel_iterations=self.parallel_iterations,
            )

            kernel_results = NUTSKernelResults(
                target_log_prob=new_step_metastate.candidate_state.target,
                grads_target_log_prob=(
                    new_step_metastate.candidate_state.target_grad_parts),
                momentum_state_memory=previous_kernel_results.
                momentum_state_memory,
                step_size=previous_kernel_results.step_size,
                log_accept_ratio=tf.math.log(
                    new_step_metastate.energy_diff_sum /
                    tf.cast(new_step_metastate.leapfrog_count,
                            dtype=new_step_metastate.energy_diff_sum.dtype)),
                leapfrogs_taken=(new_step_metastate.leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=new_step_metastate.is_accepted,
                reach_max_depth=new_step_metastate.continue_tree,
                has_divergence=~new_step_metastate.not_divergence,
                energy=new_step_metastate.candidate_state.energy,
                seed=seed,
            )

            result_state = tf.nest.pack_sequence_as(
                state_structure, new_step_metastate.candidate_state.state)
            return result_state, kernel_results
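
For context, a minimal sketch of driving a NUTS kernel like this one through TFP's public API; the class and keyword names (`tfp.mcmc.NoUTurnSampler`, `tfp.mcmc.sample_chain`) reflect recent TFP releases and should be treated as assumptions.

import tensorflow as tf
import tensorflow_probability as tfp

# Standard-normal target over 2 dimensions, run as 4 independent chains.
kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1),
    step_size=0.5)
samples = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros([4, 2]),
    kernel=kernel,
    trace_fn=None,
    seed=[0, 42])
# `samples` has shape [100, 4, 2]; each call to `one_step` above builds the
# doubling tree inside a tf.while_loop over TensorArray-backed instructions.
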
Exemplo n.º 28
0
 def concat_with_tensorlist_stack(self, a, b):
     ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[])
     ta = ta.write(0, a)
     ta = ta.write(1, b)
     return ta.stack()
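
For reference, the two writes followed by `stack()` above produce the same result as `tf.stack([a, b])`; a brief equivalence sketch, assuming the method lives on some object `m`:

import tensorflow as tf

a, b = tf.constant(1.0), tf.constant(2.0)
# m.concat_with_tensorlist_stack(a, b)  ->  [1., 2.], i.e. tf.stack([a, b])
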
Exemplo n.º 29
0
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(self.name + '.one_step'):
            unwrap_state_list = not tf.nest.is_nested(current_state)
            if unwrap_state_list:
                current_state = [current_state]

            current_target_log_prob = previous_kernel_results.target_log_prob
            [init_momentum, init_energy, log_slice_sample
             ] = self._start_trajectory_batched(current_state,
                                                current_target_log_prob)

            def _copy(v):
                return v * prefer_static.ones(prefer_static.pad(
                    [2],
                    paddings=[[0, prefer_static.rank(v)]],
                    constant_values=1),
                                              dtype=v.dtype)

            initial_state = TreeDoublingState(
                momentum=init_momentum,
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.grads_target_log_prob
            )
            initial_step_state = tf.nest.map_structure(_copy, initial_state)

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            candidate_state = TreeDoublingStateCandidate(
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.
                grads_target_log_prob,
                energy=init_energy,
                weight=init_weight)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=candidate_state,
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                momentum_sum=init_momentum,
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instructions into TensorArrays so that
            # they are compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=len(self._write_instruction),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(tf.int32,
                                              size=len(self._read_instruction),
                                              clear_after_read=False).unstack(
                                                  self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            _, _, new_step_metastate = tf.while_loop(
                cond=lambda iter_, state, metastate: (  # pylint: disable=g-long-lambda
                    (iter_ < self.max_tree_depth) & tf.reduce_any(
                        metastate.continue_tree)),
                body=lambda iter_, state, metastate: self.loop_tree_doubling(  # pylint: disable=g-long-lambda
                    previous_kernel_results.step_size, previous_kernel_results.
                    momentum_state_memory, current_step_meta_info, iter_,
                    state, metastate),
                loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'),
                           initial_step_state, initial_step_metastate),
                parallel_iterations=self.parallel_iterations,
            )

            kernel_results = NUTSKernelResults(
                target_log_prob=new_step_metastate.candidate_state.target,
                grads_target_log_prob=(
                    new_step_metastate.candidate_state.target_grad_parts),
                momentum_state_memory=previous_kernel_results.
                momentum_state_memory,
                step_size=previous_kernel_results.step_size,
                log_accept_ratio=tf.math.log(
                    new_step_metastate.energy_diff_sum /
                    tf.cast(new_step_metastate.leapfrog_count,
                            dtype=new_step_metastate.energy_diff_sum.dtype)),
                leapfrogs_taken=(new_step_metastate.leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=new_step_metastate.is_accepted,
                reach_max_depth=new_step_metastate.continue_tree,
                has_divergence=~new_step_metastate.not_divergence,
                energy=new_step_metastate.candidate_state.energy)

            result_state = new_step_metastate.candidate_state.state
            if unwrap_state_list:
                result_state = result_state[0]

            return result_state, kernel_results
Exemplo n.º 30
0
    def _sample_paths(self, times, times_size, current_log_spot, current_var,
                      num_samples, random_type, keep_mask, seed, tolerance):
        """Returns a sample of paths from the process."""
        # Note: all the notations below are the same as in [1].

        # Time increments between consecutive sample times.
        dt = times[1:] - times[:-1]
        kappa, theta, epsilon, rho = _get_parameters(  # pylint: disable=unbalanced-tuple-unpacking
            times + tf.reduce_min(dt) / 2, self._kappa, self._theta,
            self._epsilon, self._rho)
        cond_fn = lambda i, *args: i < tf.size(dt)

        def body_fn(i, written_count, current_var, current_log_spot, vol_paths,
                    log_spot_paths):
            """Simulate Heston process to the next time point."""
            time_step = dt[i]

            def _next_vol_fn():
                return _update_variance(i, kappa[i], theta[i], epsilon[i],
                                        rho[i], current_var, time_step,
                                        num_samples, random_type, seed)

            # Do not update variance if `time_step <= tolerance`
            next_vol = tf.cond(
                time_step > tolerance,
                lambda: _next_vol_fn(),  # pylint: disable=unnecessary-lambda
                lambda: current_var)

            def _next_log_spot_fn():
                return _update_log_spot(i, kappa[i], theta[i], epsilon[i],
                                        rho[i], current_var, next_vol,
                                        current_log_spot, time_step,
                                        num_samples, random_type, seed)

            # Do not update state if `time_step <= tolerance`
            next_log_spot = tf.cond(
                time_step > tolerance,
                lambda: _next_log_spot_fn(),  # pylint: disable=unnecessary-lambda
                lambda: current_log_spot)

            vol_paths = tf.cond(
                keep_mask[i + 1],
                lambda: vol_paths.write(written_count, next_vol),
                lambda: vol_paths)
            log_spot_paths = tf.cond(
                keep_mask[i + 1],
                lambda: log_spot_paths.write(written_count, next_log_spot),
                lambda: log_spot_paths)
            written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
            return (i + 1, written_count, next_vol, next_log_spot, vol_paths,
                    log_spot_paths)

        log_spot_paths = tf.TensorArray(dtype=self._dtype, size=times_size)
        vol_paths = tf.TensorArray(dtype=self._dtype, size=times_size)
        _, _, _, _, vol_paths, log_spot_paths = tf.compat.v2.while_loop(
            cond_fn, body_fn,
            (0, 0, current_var, current_log_spot, vol_paths, log_spot_paths))
        return tf.stack([
            tf.transpose(log_spot_paths.stack()),
            tf.transpose(vol_paths.stack())
        ], -1)
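
A simplified, self-contained sketch of one step of the loop body above, using a full-truncation Euler discretisation of the Heston dynamics under zero drift; this is only an illustration and is not the scheme implemented by `_update_variance` / `_update_log_spot`.

import numpy as np

def heston_euler_step(x, v, dt, kappa, theta, epsilon, rho,
                      rng=np.random.default_rng()):
    # Correlated standard normals for the spot and variance drivers.
    z1 = rng.standard_normal(np.shape(v))
    z2 = rho * z1 + np.sqrt(1.0 - rho**2) * rng.standard_normal(np.shape(v))
    v_pos = np.maximum(v, 0.0)  # full truncation: negative variance treated as 0
    x_next = x - 0.5 * v_pos * dt + np.sqrt(v_pos * dt) * z1
    v_next = v + kappa * (theta - v_pos) * dt + epsilon * np.sqrt(v_pos * dt) * z2
    return x_next, v_next
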