Example #1
def assert_shape_equal(shape_a, shape_b):
    """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes
  mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when the shapes are all static, or a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
    if (all(isinstance(dim, int) for dim in shape_a)
            and all(isinstance(dim, int) for dim in shape_b)):
        if shape_a != shape_b:
            raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
        else:
            return tf.no_op()
    else:
        return tf.assert_equal(shape_a, shape_b)
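A minimal usage sketch (not part of the original source), assuming TensorFlow 2.x eager execution, to illustrate both code paths:

import tensorflow as tf

# Static shapes: plain Python ints, compared eagerly.
assert_shape_equal([2, 3], [2, 3])        # returns tf.no_op()
# assert_shape_equal([2, 3], [2, 4])      # would raise ValueError

# Dynamic shapes: scalar Tensors, checked with a tf.assert_equal op at runtime.
x = tf.zeros([2, 3])
dynamic_shape = tf.unstack(tf.shape(x))   # list of scalar int32 Tensors
assert_shape_equal(dynamic_shape, [2, 3])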
Example #2
def _setup_mcmc(model, n_chains, *, init_position=None, seed=None, **pins):
    """Construct bijector and transforms needed for windowed MCMC.

  This pins the initial model, constructs a bijector that unconstrains and
  flattens each dimension and adds a leading batch shape of `n_chains`,
  initializes a point in the unconstrained space, and constructs a transformed
  log probability using the bijector.

  Note that we must manually construct this target log probability instead of
  using a transformed transition kernel (TTK), because the TTK assumes the
  input shape is the same as the output shape.

  Args:
    model: `tfd.JointDistribution`
      The model to sample from.
    n_chains: int
      Number of chains (independent examples) to run.
    init_position: Optional
      Structure of tensors at which to initialize sampling. Should have the
      same shape and structure as
      `model.experimental_pin(**pins).sample(n_chains)`.
    seed: A seed for reproducible sampling.
    **pins:
      Values passed to `model.experimental_pin`.


  Returns:
    target_log_prob_fn: Callable on the transformed space.
    initial_transformed_position: `tf.Tensor`, sampled from a uniform (-2, 2).
    bijector: `tfb.Bijector` instance, which unconstrains and flattens.
  """
    pinned_model = model.experimental_pin(**pins) if pins else model
    bijector = _get_flat_unconstraining_bijector(pinned_model)

    if init_position is None:
        raw_init_dist = initialization.init_near_unconstrained_zero(
            pinned_model)
        init_position = initialization.retry_init(
            raw_init_dist.sample,
            target_fn=pinned_model.unnormalized_log_prob,
            sample_shape=[n_chains],
            seed=seed)
    else:
        tf.nest.map_structure(lambda x, y: tf.assert_equal(x.shape, y.shape),
                              pinned_model.sample_unpinned(n_chains),
                              init_position)

    initial_transformed_position = tf.nest.map_structure(
        tf.identity, bijector.forward(init_position))

    def target_log_prob_fn(*args):
        lp = pinned_model.unnormalized_log_prob(bijector.inverse(args))
        tensorshape_util.set_shape(lp, [n_chains])
        ldj = bijector.inverse_log_det_jacobian(
            args, event_ndims=[1 for _ in initial_transformed_position])
        return lp + ldj

    return target_log_prob_fn, initial_transformed_position, bijector
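A hedged usage sketch (not from the original source): the model and the pinned value `obs` below are illustrative, and the two-integer seed assumes TFP's stateless seeding:

import tensorflow_probability as tfp
tfd = tfp.distributions

# A toy joint model; `obs` is observed and pinned, `scale` is sampled.
model = tfd.JointDistributionNamed(dict(
    scale=tfd.HalfNormal(1.),
    obs=lambda scale: tfd.Normal(0., scale)))

# Set up 4 chains in the flat, unconstrained space.
target_log_prob_fn, init_z, bijector = _setup_mcmc(
    model, n_chains=4, seed=(0, 1), obs=2.5)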
Example #3
  def update_state_across_models(self, activations1, activations2):
    """Accumulate minibatch HSIC values from different models.

    Args:
      activations1: A list of activations for all layers in model 1.
      activations2: A list of activations for all layers in model 2.
    """
    tf.assert_equal(
        tf.shape(self.hsic_accumulator)[0], len(activations1),
        'Number of activation vectors does not match num_layers.')
    tf.assert_equal(
        tf.shape(self.hsic_accumulator)[1], len(activations2),
        'Number of activation vectors does not match num_layers.')
    layer_grams1 = [self._generate_gram_matrix(x) for x in activations1]
    layer_grams1 = tf.stack(layer_grams1, 0)  # (n_layers, n_examples ** 2)
    layer_grams2 = [self._generate_gram_matrix(x) for x in activations2]
    layer_grams2 = tf.stack(layer_grams2, 0)
    self.hsic_accumulator.assign_add(
        tf.matmul(layer_grams1, layer_grams2, transpose_b=True))
    self.hsic_accumulator_model1.assign_add(
        tf.einsum('ij,ij->i', layer_grams1, layer_grams1))
    self.hsic_accumulator_model2.assign_add(
        tf.einsum('ij,ij->i', layer_grams2, layer_grams2))
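A small standalone sketch (shapes are made up) of the `tf.einsum('ij,ij->i', ...)` contraction used above, which yields one squared Frobenius norm per layer:

import tensorflow as tf

layer_grams = tf.random.normal([3, 16])   # (n_layers, n_examples ** 2)
per_layer = tf.einsum('ij,ij->i', layer_grams, layer_grams)
# Equivalent to tf.reduce_sum(layer_grams * layer_grams, axis=1); shape [3].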
Example #4
    def get_saved_values(self,
                         attr_name,
                         broadcast_to_input_shape=False,
                         unit_mask=None):
        """Returns the saved values of the most recent forward pass.

    All of 'mean', 'l2norm', 'mrs' and 'rs' have the same shape, and here we
    define a common getter operation for them.

    Args:
      attr_name: str, 'mean', 'l2norm', 'mrs' or 'rs'.
      broadcast_to_input_shape: bool, if True the values are broadcast to the
        input shape.
      unit_mask: Tensor, same shape as `self._<attr_name>` and it is multiplied
        with the saved tensor before broadcast operation.

    Returns:
      Tensor or None: None if no saved value exists.
    Raises:
      ValueError: when the `attr_name` is not valid.
    """
        if attr_name not in TaylorScorer._saved_values_set:
            raise ValueError('attr_name: %s is not valid. ' % attr_name)
        attr_name = '_' + attr_name
        if getattr(self, attr_name) is None:
            return None

        val, _ = getattr(self, attr_name)
        # TODO maybe get rid of this part. It doesn't belong here.
        if unit_mask is None:
            possibly_masked_mean = val
        else:
            tf.assert_equal(val.shape, unit_mask.shape)
            possibly_masked_mean = tf.multiply(val, unit_mask)
        if broadcast_to_input_shape:
            return tf.broadcast_to(possibly_masked_mean, self.xshape)
        else:
            return possibly_masked_mean
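An isolated sketch (hypothetical shapes, independent of the original class) of the mask-then-broadcast step above, assuming a per-unit value of shape [C] and an input shape of [N, H, W, C]:

import tensorflow as tf

val = tf.constant([0.5, 1.0, 2.0])         # per-unit saved value, shape [C]
unit_mask = tf.constant([1.0, 0.0, 1.0])   # zero out the second unit
masked = tf.multiply(val, unit_mask)
broadcast = tf.broadcast_to(masked, [2, 4, 4, 3])  # back to the input shape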
Example #5
    def _sample_n(self, n, seed=None, conditional_input=None, training=False):
        """Samples from the distribution, with optional conditional input.
        Args:
          n: `int`, number of samples desired.
          seed: `int`, seed for RNG. Setting a random seed enforces reproducibility
            of the samples between sessions (not within a single session).
          conditional_input: `Tensor` on which to condition the distribution (e.g.
            class labels), or `None`.
          training: `bool` or `None`. If `bool`, it controls the dropout layer,
            where `True` implies dropout is active. If `None`, it defers to Keras'
            handling of train/eval status.
        Returns:
          samples: a `Tensor` of shape `[n, height, width, num_channels]`.
        """
        if conditional_input is not None:
            conditional_input = tf.convert_to_tensor(conditional_input,
                                                     dtype=self.dtype)
            conditional_event_rank = tensorshape_util.rank(
                self.conditional_shape)
            conditional_input_shape = prefer_static.shape(conditional_input)
            conditional_sample_rank = prefer_static.rank(
                conditional_input) - conditional_event_rank

            # If `conditional_input` has no sample dimensions, prepend a sample
            # dimension
            if conditional_sample_rank == 0:
                conditional_input = conditional_input[tf.newaxis, ...]
                conditional_sample_rank = 1

            # Assert that the conditional event shape in the `PixelCnnNetwork` is the
            # same as that implied by `conditional_input`.
            conditional_event_shape = conditional_input_shape[
                conditional_sample_rank:]
            with tf.control_dependencies([
                    tf.assert_equal(self.conditional_shape,
                                    conditional_event_shape)
            ]):
                conditional_sample_shape = conditional_input_shape[:
                                                                   conditional_sample_rank]
                repeat = n // prefer_static.reduce_prod(
                    conditional_sample_shape)
                h = tf.reshape(
                    conditional_input,
                    prefer_static.concat([(-1, ), self.conditional_shape],
                                         axis=0))
                h = tf.tile(
                    h,
                    prefer_static.pad([repeat],
                                      paddings=[[0, conditional_event_rank]],
                                      constant_values=1))

        samples_0 = tf.random.uniform(prefer_static.concat(
            [(n, ), self.event_shape], axis=0),
                                      minval=-1.,
                                      maxval=1.,
                                      dtype=self.dtype,
                                      seed=seed)
        inputs = samples_0 if conditional_input is None else [samples_0, h]
        params_0 = self.network(inputs, training=training)
        samples_0 = self._sample_channels(*params_0, seed=seed)

        image_height, image_width, _ = tensorshape_util.as_list(
            self.event_shape)

        def loop_body(index, samples):
            """Loop for iterative pixel sampling.
            Args:
            index: 0D `Tensor` of type `int32`. Index of the current pixel.
            samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
              pixel `[index]`, with dimensions `[batch_size, height, width,
              num_channels]`.
            Returns:
            samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
              and including pixel `[index]`, with dimensions `[batch_size, height,
              width, num_channels]`.
            """
            inputs = samples if conditional_input is None else [samples, h]
            params = self.network(inputs, training=training)
            samples_new = self._sample_channels(*params, seed=seed)

            # Update the current pixel
            samples = tf.transpose(samples, [1, 2, 3, 0])
            samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
            row, col = index // image_width, index % image_width
            updates = samples_new[row, col, ...][tf.newaxis, ...]
            samples = tf.tensor_scatter_nd_update(samples, [[row, col]],
                                                  updates)
            samples = tf.transpose(samples, [3, 0, 1, 2])

            return index + 1, samples

        index0 = tf.zeros([], dtype=tf.int32)

        # Construct the while loop for sampling
        total_pixels = image_height * image_width
        loop_cond = lambda ind, _: tf.less(ind, total_pixels)  # noqa: E731
        init_vars = (index0, samples_0)
        _, samples = tf.while_loop(loop_cond,
                                   loop_body,
                                   init_vars,
                                   parallel_iterations=1)

        transformed_samples = (self._low + 0.5 * (self._high - self._low) *
                               (samples + 1.))
        return tf.round(transformed_samples)
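A hedged usage sketch: `_sample_n` is normally reached through `Distribution.sample`; the constructor arguments below are deliberately small illustrative values and assume `tfp.distributions.PixelCNN`:

import tensorflow_probability as tfp

dist = tfp.distributions.PixelCNN(
    image_shape=(8, 8, 1),
    num_resnet=1, num_hierarchies=1,
    num_filters=16, num_logistic_mix=2)
images = dist.sample(2)   # shape [2, 8, 8, 1]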
Example #6
        def collater_fn(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
            """Collater function for mention classification task. See BaseTask."""

            new_batch = {}

            # Sample mentions uniformly across batch
            mention_mask = tf.reshape(batch['mention_mask'],
                                      [n_candidate_mentions])
            sample_scores = tf.random.uniform(
                shape=[n_candidate_mentions]) * tf.cast(
                    mention_mask, tf.float32)

            mention_target_indices = tf.reshape(
                batch['mention_target_indices'], [bsz])

            # We want to make sure that the target mentions always have a priority
            # when we sample `max_batch_mentions` out of all available mentions.
            # Additionally, we want these target mentions to be in the same order as
            # their samples. In other words, we want the first sampled mention to be
            # the target mention from the first sample, the second sampled mention to
            # be the target mention from the second sample, etc.

            # Positions of target mentions in the flat array
            mention_target_indices_flat = (tf.cast(
                tf.range(bsz) * max_mentions_per_sample,
                mention_target_indices.dtype) + mention_target_indices)
            # These extra scores make sure that target mentions have priority and
            # will be sampled in the correct order.
            mention_target_extra_score_flat = tf.cast(
                tf.reverse(tf.range(bsz) + 1, axis=[0]), tf.float32)
            # The model assumes that there is only ONE target mention per sample.
            # Moreover, we want to select them according to the order of samples:
            # target mention from sample 0, target mention from sample 1, ..., etc.
            sample_scores = tf.tensor_scatter_nd_add(
                sample_scores, tf.expand_dims(mention_target_indices_flat, 1),
                mention_target_extra_score_flat)

            sampled_indices = tf.math.top_k(sample_scores,
                                            max_batch_mentions,
                                            sorted=True).indices

            # Double-check target mentions were selected correctly.
            assert_op = tf.assert_equal(
                sampled_indices[:bsz],
                tf.cast(mention_target_indices_flat, sampled_indices.dtype))

            with tf.control_dependencies([assert_op]):
                mention_mask = tf.gather(mention_mask, sampled_indices)
            dtype = batch['mention_start_positions'].dtype
            mention_start_positions = tf.gather(
                tf.reshape(batch['mention_start_positions'],
                           [n_candidate_mentions]), sampled_indices)
            mention_end_positions = tf.gather(
                tf.reshape(batch['mention_end_positions'],
                           [n_candidate_mentions]), sampled_indices)

            mention_batch_positions = tf.gather(
                tf.repeat(tf.range(bsz, dtype=dtype), max_mentions_per_sample),
                sampled_indices)

            new_batch['text_ids'] = batch['text_ids']
            new_batch['text_mask'] = batch['text_mask']
            new_batch['classifier_target'] = tf.reshape(
                batch['target'], [bsz, config.max_num_labels_per_sample])
            new_batch['classifier_target_mask'] = tf.reshape(
                batch['target_mask'], [bsz, config.max_num_labels_per_sample])

            new_batch['mention_mask'] = mention_mask
            new_batch['mention_start_positions'] = mention_start_positions
            new_batch['mention_end_positions'] = mention_end_positions
            new_batch['mention_batch_positions'] = mention_batch_positions
            new_batch['mention_target_indices'] = tf.range(bsz, dtype=dtype)

            if config.get('max_length_with_entity_tokens') is not None:
                batch_with_entity_tokens = mention_preprocess_utils.add_entity_tokens(
                    text_ids=new_batch['text_ids'],
                    text_mask=new_batch['text_mask'],
                    mention_mask=new_batch['mention_mask'],
                    mention_batch_positions=new_batch[
                        'mention_batch_positions'],
                    mention_start_positions=new_batch[
                        'mention_start_positions'],
                    mention_end_positions=new_batch['mention_end_positions'],
                    new_length=config.max_length_with_entity_tokens,
                )
                # Update `text_ids`, `text_mask`, `mention_mask`, `mention_*_positions`
                new_batch.update(batch_with_entity_tokens)
                # Update `max_length`
                max_length = config.max_length_with_entity_tokens
            else:
                max_length = encoder_config.max_length

            new_batch['mention_target_batch_positions'] = tf.gather(
                new_batch['mention_batch_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_start_positions'] = tf.gather(
                new_batch['mention_start_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_end_positions'] = tf.gather(
                new_batch['mention_end_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_weights'] = tf.ones(bsz)

            # Fake IDs -- some encoders (ReadTwice) need them
            new_batch['mention_target_ids'] = tf.zeros(bsz)

            new_batch['segment_ids'] = tf.zeros_like(new_batch['text_ids'])

            position_ids = tf.expand_dims(tf.range(max_length, dtype=dtype),
                                          axis=0)
            new_batch['position_ids'] = tf.tile(position_ids, (bsz, 1))

            return new_batch
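A standalone sketch (all sizes are hypothetical) of the target-priority trick explained in the comments above: decreasing extra scores added at the target positions guarantee that `tf.math.top_k` returns the targets first, and in sample order:

import tensorflow as tf

bsz = 2
scores = tf.random.uniform([6])            # base scores in [0, 1)
target_flat = tf.constant([1, 4])          # one target mention per sample
extra = tf.cast(tf.reverse(tf.range(bsz) + 1, axis=[0]), tf.float32)  # [2., 1.]
scores = tf.tensor_scatter_nd_add(
    scores, tf.expand_dims(target_flat, 1), extra)
top = tf.math.top_k(scores, k=4, sorted=True).indices
# top[:2] == [1, 4], i.e. the targets, in sample order.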
Example #7
    def _sample_paths(self,
                      times,
                      num_samples,
                      random_type,
                      skip,
                      seed,
                      normal_draws=None,
                      times_grid=None,
                      validate_args=False):
        """Returns a sample of paths from the process."""
        # Note: all the notations below are the same as in [1].
        num_requested_times = tf.shape(times)[0]
        params = [self._mean_reversion, self._volatility]
        if self._corr_matrix is not None:
            params = params + [self._corr_matrix]
        times, keep_mask = _prepare_grid(times, times_grid, *params)
        # Add zeros as a starting location
        dt = times[1:] - times[:-1]
        if dt.shape.is_fully_defined():
            steps_num = dt.shape.as_list()[-1]
        else:
            steps_num = tf.shape(dt)[-1]
            # TODO(b/148133811): Re-enable Sobol test when TF 2.2 is released.
            if random_type == random.RandomType.SOBOL:
                raise ValueError(
                    'Sobol sequence for Euler sampling is temporarily '
                    'unsupported when `time_step` or `times` have a '
                    'non-constant value')
        if normal_draws is None:
            # In order to use low-discrepancy random_type we need to generate the
            # sequence of independent random normals upfront. We also precompute
            # random numbers for stateless random type in order to ensure independent
            # samples for multiple function calls with different seeds.
            if random_type in (random.RandomType.SOBOL,
                               random.RandomType.HALTON,
                               random.RandomType.HALTON_RANDOMIZED,
                               random.RandomType.STATELESS,
                               random.RandomType.STATELESS_ANTITHETIC):
                normal_draws = utils.generate_mc_normal_draws(
                    num_normal_draws=self._dim,
                    num_time_steps=steps_num,
                    num_sample_paths=num_samples,
                    random_type=random_type,
                    seed=seed,
                    dtype=self._dtype,
                    skip=skip)
            else:
                normal_draws = None
        else:
            if validate_args:
                draws_times = tf.shape(normal_draws)[0]
                asserts = tf.assert_equal(
                    draws_times,
                    tf.shape(times)[0] - 1,  # We have added `0` to `times`
                    message='`tf.shape(normal_draws)[1]` should be equal to the '
                    'number of all `times` plus the number of all jumps of '
                    'the piecewise constant parameters.')
                with tf.compat.v1.control_dependencies([asserts]):
                    normal_draws = tf.identity(normal_draws)
        # The below is OK because we support exact discretization with piecewise
        # constant mr and vol.
        mean_reversion = self._mean_reversion(times)
        volatility = self._volatility(times)
        if self._corr_matrix is not None:
            corr_matrix = _get_parameters(times + tf.math.reduce_min(dt) / 2,
                                          self._corr_matrix)[0]
            corr_matrix_root = tf.linalg.cholesky(corr_matrix)
        else:
            corr_matrix_root = None

        exp_x_t = self._conditional_mean_x(times, mean_reversion, volatility)
        var_x_t = self._conditional_variance_x(times, mean_reversion,
                                               volatility)
        if self._dim == 1:
            mean_reversion = tf.expand_dims(mean_reversion, axis=0)

        cond_fn = lambda i, *args: i < tf.size(dt)

        def body_fn(i, written_count, current_x, rate_paths):
            """Simulate hull-white process to the next time point."""
            if normal_draws is None:
                normals = random.mv_normal_sample(
                    (num_samples, ),
                    mean=tf.zeros((self._dim, ), dtype=mean_reversion.dtype),
                    random_type=random_type,
                    seed=seed)
            else:
                normals = normal_draws[i]

            if corr_matrix_root is not None:
                normals = tf.linalg.matvec(corr_matrix_root[i], normals)
            vol_x_t = tf.math.sqrt(tf.nn.relu(tf.transpose(var_x_t)[i]))
            # If numerically `vol_x_t == 0`, the gradient of `vol_x_t` becomes `NaN`.
            # To prevent this, we explicitly set `vol_x_t` to a zero tensor at zero
            # values so that the gradient is set to zero at these values.
            vol_x_t = tf.where(vol_x_t > 0.0, vol_x_t, 0.0)
            next_x = (
                tf.math.exp(-tf.transpose(mean_reversion)[i + 1] * dt[i]) *
                current_x + tf.transpose(exp_x_t)[i] + vol_x_t * normals)
            f_0_t = self._instant_forward_rate_fn(times[i + 1])

            # Update `rate_paths`
            rate_paths = utils.maybe_update_along_axis(
                tensor=rate_paths,
                do_update=keep_mask[i + 1],
                ind=written_count,
                axis=1,
                new_tensor=tf.expand_dims(next_x, axis=1) + f_0_t)
            written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
            return (i + 1, written_count, next_x, rate_paths)

        rate_paths = tf.zeros((num_samples, num_requested_times, self._dim),
                              dtype=self._dtype)
        # Include initial state, if necessary
        f0_t = self._instant_forward_rate_fn(times[0])
        rate_paths = utils.maybe_update_along_axis(tensor=rate_paths,
                                                   do_update=keep_mask[0],
                                                   ind=0,
                                                   axis=1,
                                                   new_tensor=f0_t)
        written_count = tf.cast(keep_mask[0], dtype=tf.int32)
        initial_x = tf.zeros((num_samples, self._dim), dtype=self._dtype)
        # TODO(b/157232803): Use tf.cumsum instead?
        _, _, _, rate_paths = tf.while_loop(
            cond_fn, body_fn, (0, written_count, initial_x, rate_paths))

        return rate_paths
Example #8
def compute_alignment_loss(embs,
                           batch_size,
                           steps=None,
                           seq_lens=None,
                           stochastic_matching=False,
                           normalize_embeddings=False,
                           loss_type='classification',
                           similarity_type='l2',
                           num_cycles=20,
                           cycle_length=2,
                           temperature=0.1,
                           label_smoothing=0.1,
                           variance_lambda=0.001,
                           huber_delta=0.1,
                           normalize_indices=True):
    """Computes alignment loss between sequences of embeddings.

  This function is a wrapper around different variants of the alignment loss
  described in the deterministic_alignment.py and stochastic_alignment.py
  files. The structure of the library is as follows:
  i) loss_fns.py - Defines the different loss functions.
  ii) deterministic_alignment.py - Performs the alignment between sequences by
  deterministically sampling all steps of the sequences.
  iii) stochastic_alignment.py - Performs the alignment between sequences by
  stochastically sub-sampling a fixed number of steps from the sequences.

  There are four major hparams that need to be tuned while applying the loss:
  i) Should the loss be applied with L2 normalization on the embeddings or
  without it?
  ii) Should we perform stochastic alignment of sequences? This means should we
  use all the steps of the embedding or only choose a random subset for
  alignment?
  iii) Should we apply cycle-consistency constraints using a classification loss
  or a regression loss? (Section 3 in paper)
  iv) Should the similarity metric be based on an L2 distance or cosine
  similarity?

  Other hparams that can be used to control how hard/soft we want the alignment
  between different sequences to be:
  i) temperature (all losses)
  ii) label_smoothing (classification)
  iii) variance_lambda (regression_mse_var)
  iv) huber_delta (regression_huber)
  Each of these params is used in its respective loss type (in brackets) and
  allows the application of the cycle-consistency constraints in a controllable
  manner, but they do so in very different ways. Please refer to the paper for
  more details.

  The default hparams work well for frame embeddings of videos of humans
  performing actions. Other datasets might need different values of hparams.


  Args:
    embs: Tensor, sequential embeddings of the shape [N, T, D] where N is the
      batch size, T is the number of timesteps in the sequence, D is the size of
      the embeddings.
    batch_size: Integer, Size of the batch.
    steps: Tensor, step indices/frame indices of the embeddings of the shape
      [N, T] where N is the batch size, T is the number of the timesteps.
      If this is set to None, then we assume that the sampling was done in a
      uniform way and use tf.range(num_steps) as the steps.
    seq_lens: Tensor, Lengths of the sequences from which the sampling was done.
      This can provide additional information to the alignment loss. This is
      different from num_steps which is just the number of steps that have been
      sampled from the entire sequence.
    stochastic_matching: Boolean, Should the steps used for matching be sampled
      stochastically or deterministically? Deterministic is better for TPU.
      Stochastic is better for adding more randomness to the training process
      and handling long sequences.
    normalize_embeddings: Boolean, Should the embeddings be normalized or not?
      Default is to use raw embeddings. Be careful if you are normalizing the
      embeddings before calling this function.
    loss_type: String, This specifies the kind of loss function to use.
      Currently supported loss functions: classification, regression_mse,
      regression_mse_var, regression_huber.
    similarity_type: String, Currently supported similarity metrics: l2, cosine.
    num_cycles: Integer, number of cycles to match while aligning
      stochastically.  Only used in the stochastic version.
    cycle_length: Integer, Lengths of the cycle to use for matching. Only used
      in the stochastic version. By default, this is set to 2.
    temperature: Float, temperature scaling used to scale the similarity
      distributions calculated using the softmax function.
    label_smoothing: Float, Label smoothing argument used in
      tf.keras.losses.categorical_crossentropy function and described in this
      paper https://arxiv.org/pdf/1701.06548.pdf.
    variance_lambda: Float, Weight of the variance of the similarity
      predictions while cycling back. If this is high then the low variance
      similarities are preferred by the loss while making this term low results
      in high variance of the similarities (more uniform/random matching).
    huber_delta: float, Huber delta described in tf.keras.losses.huber_loss.
    normalize_indices: Boolean, If True, normalizes indices by sequence lengths.
      Useful for ensuring numerical instabilities don't arise, as sequence
      indices can be large numbers.

  Returns:
    loss: Tensor, Scalar loss tensor that imposes the chosen variant of the
      cycle-consistency loss.
  """

    ##############################################################################
    # Checking inputs and setting defaults.
    ##############################################################################

    # Get the number of timesteps in the sequence embeddings.
    num_steps = tf.shape(embs)[1]

    # If steps has not been provided assume sampling has been done uniformly.
    if steps is None:
        steps = tf.tile(tf.expand_dims(tf.range(num_steps), axis=0),
                        [batch_size, 1])

    # If seq_lens has not been provided, assume it is equal to the size of the
    # time axis in the embeddings.
    if seq_lens is None:
        seq_lens = tf.tile(tf.expand_dims(num_steps, 0), [batch_size])

    if not tf.executing_eagerly():
        # Check if the batch size of embs is consistent with the provided batch size.
        with tf.control_dependencies(
            [tf.assert_equal(batch_size,
                             tf.shape(embs)[0])]):
            embs = tf.identity(embs)
        # Check if number of timesteps in embs is consistent with provided steps.
        with tf.control_dependencies([
                tf.assert_equal(num_steps,
                                tf.shape(steps)[1]),
                tf.assert_equal(batch_size,
                                tf.shape(steps)[0])
        ]):
            steps = tf.identity(steps)
    else:
        tf.assert_equal(batch_size, tf.shape(steps)[0])
        tf.assert_equal(num_steps, tf.shape(steps)[1])
        tf.assert_equal(batch_size, tf.shape(embs)[0])

    ##############################################################################
    # Perform alignment and return loss.
    ##############################################################################

    if normalize_embeddings:
        embs = tf.nn.l2_normalize(embs, axis=-1)

    if stochastic_matching:
        loss = compute_stochastic_alignment_loss(
            embs=embs,
            steps=steps,
            seq_lens=seq_lens,
            num_steps=num_steps,
            batch_size=batch_size,
            loss_type=loss_type,
            similarity_type=similarity_type,
            num_cycles=num_cycles,
            cycle_length=cycle_length,
            temperature=temperature,
            label_smoothing=label_smoothing,
            variance_lambda=variance_lambda,
            huber_delta=huber_delta,
            normalize_indices=normalize_indices)
    else:
        loss = compute_deterministic_alignment_loss(
            embs=embs,
            steps=steps,
            seq_lens=seq_lens,
            num_steps=num_steps,
            batch_size=batch_size,
            loss_type=loss_type,
            similarity_type=similarity_type,
            temperature=temperature,
            label_smoothing=label_smoothing,
            variance_lambda=variance_lambda,
            huber_delta=huber_delta,
            normalize_indices=normalize_indices)

    return loss
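A hedged usage sketch (shapes made up for illustration), following the documented signature:

import tensorflow as tf

embs = tf.random.normal([4, 20, 128])   # [N, T, D] frame embeddings
# Deterministic alignment with the default classification loss.
loss = compute_alignment_loss(embs, batch_size=4)
# Stochastic alignment, sub-sampling cycles instead of using all steps.
stochastic_loss = compute_alignment_loss(
    embs, batch_size=4, stochastic_matching=True, num_cycles=10)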
Example #9
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init:
      axis_ = tf.get_static_value(self._axis)
      if axis_ is not None and axis_ < 0:
        raise ValueError('Axis should be non-negative, %d was given' % axis_)
      if axis_ is None:
        # `axis_` has no static value; assert non-negativity at runtime.
        assertions.append(tf.assert_greater_equal(self._axis, 0))

      all_event_shapes = [d.event_shape for d in self._distributions]
      if all(tensorshape_util.is_fully_defined(event_shape)
             for event_shape in all_event_shapes):
        if all_event_shapes[1:] != all_event_shapes[:-1]:
          raise ValueError('Distributions must have the same `event_shape`; '
                           'found: {}'.format(all_event_shapes))

      all_batch_shapes = [d.batch_shape for d in self._distributions]
      if all(tensorshape_util.is_fully_defined(batch_shape)
             for batch_shape in all_batch_shapes):
        batch_shape = all_batch_shapes[0].as_list()
        batch_shape[self._axis] = 1
        for b in all_batch_shapes[1:]:
          b = b.as_list()
          if len(batch_shape) != len(b):
            raise ValueError('Incompatible batch shape %s with %s' %
                             (batch_shape, b))
          b[self._axis] = 1
          tf.broadcast_static_shape(
              tensorshape_util.constant_value_as_shape(batch_shape),
              tensorshape_util.constant_value_as_shape(b))

    if not self.validate_args:
      return []

    if self.validate_args:
      # Validate that event shapes all match.
      all_event_shapes = [d.event_shape for d in self._distributions]
      if not all(tensorshape_util.is_fully_defined(event_shape)
                 for event_shape in all_event_shapes):
        all_event_shape_tensors = [d.event_shape_tensor() for
                                   d in self._distributions]
        def _get_shapes(static_shape, dynamic_shape):
          if tensorshape_util.is_fully_defined(static_shape):
            return static_shape
          else:
            return dynamic_shape
        event_shapes = tf.nest.map_structure(_get_shapes,
                                             all_event_shapes,
                                             all_event_shape_tensors)
        event_shapes = tf.nest.flatten(event_shapes)
        assertions.extend(
            assert_util.assert_equal(
                e1, e2, message='Distributions should have same event shapes.')
            for e1, e2 in zip(event_shapes[1:], event_shapes[:-1]))

      # Validate that batch shapes are broadcastable and concatenable along
      # the specified axis.
      if not all(tensorshape_util.is_fully_defined(d.batch_shape)
                 for d in self._distributions):
        for i, d in enumerate(self._distributions[:-1]):
          assertions.append(tf.assert_equal(
              tf.size(d.batch_shape_tensor()),
              tf.size(self._distributions[i+1].batch_shape_tensor())))

        batch_shape_tensors = [
            ps.tensor_scatter_nd_update(d.batch_shape_tensor(), updates=1,
                                        indices=[self._axis])
            for d in self._distributions
        ]
        assertions.append(
            functools.reduce(tf.broadcast_dynamic_shape,
                             batch_shape_tensors[1:],
                             batch_shape_tensors[0]))
    return assertions
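A standalone sketch (hypothetical shapes) of the dynamic batch-shape compatibility check used above: reducing pairwise with `tf.broadcast_dynamic_shape` raises an `InvalidArgumentError` if any pair of shapes cannot broadcast:

import functools
import tensorflow as tf

batch_shape_tensors = [
    tf.constant([2, 1, 3]), tf.constant([2, 5, 3]), tf.constant([1, 5, 3])]
common = functools.reduce(
    tf.broadcast_dynamic_shape, batch_shape_tensors[1:], batch_shape_tensors[0])
# common == [2, 5, 3]; incompatible entries would raise at runtime.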
Example #10
def sample(dim: int,
           drift_fn: Callable[..., types.RealTensor],
           volatility_fn: Callable[..., types.RealTensor],
           times: types.RealTensor,
           time_step: Optional[types.RealTensor] = None,
           num_time_steps: Optional[types.IntTensor] = None,
           num_samples: types.IntTensor = 1,
           initial_state: Optional[types.RealTensor] = None,
           random_type: Optional[random.RandomType] = None,
           seed: Optional[types.IntTensor] = None,
           swap_memory: bool = True,
           skip: types.IntTensor = 0,
           precompute_normal_draws: bool = True,
           times_grid: Optional[types.RealTensor] = None,
           normal_draws: Optional[types.RealTensor] = None,
           watch_params: Optional[List[types.RealTensor]] = None,
           validate_args: bool = False,
           tolerance: Optional[types.RealTensor] = None,
           dtype: Optional[tf.DType] = None,
           name: Optional[str] = None) -> types.RealTensor:
    """Returns a sample paths from the process using Euler method.

  For an Ito process,

  ```
    dX = a(t, X_t) dt + b(t, X_t) dW_t
    X(t=0) = x0
  ```
  with given drift `a` and volatility `b` functions, the Euler method generates
  a sequence `{X_n}` as

  ```
  X_{n+1} = X_n + a(t_n, X_n) dt + b(t_n, X_n) (N(0, t_{n+1}) - N(0, t_n)),
  X_0 = x0
  ```
  where `dt = t_{n+1} - t_n` and `N` is a sample from the Normal distribution.
  See [1] for details.

  #### Example
  Sampling from 2-dimensional Ito process of the form:

  ```none
  dX_1 = mu_1 * sqrt(t) dt + s11 * dW_1 + s12 * dW_2
  dX_2 = mu_2 * sqrt(t) dt + s21 * dW_1 + s22 * dW_2
  ```

  ```python
  import tensorflow as tf
  import tf_quant_finance as tff

  import numpy as np

  mu = np.array([0.2, 0.7])
  s = np.array([[0.3, 0.1], [0.1, 0.3]])
  num_samples = 10000
  dim = 2
  dtype = tf.float64

  # Define drift and volatility functions
  def drift_fn(t, x):
    return mu * tf.sqrt(t) * tf.ones([num_samples, dim], dtype=dtype)

  def vol_fn(t, x):
    return s * tf.ones([num_samples, dim, dim], dtype=dtype)

  # Set starting location
  x0 = np.array([0.1, -1.1])
  # Sample `num_samples` paths at specified `times` using Euler scheme.
  times = [0.1, 1.0, 2.0]
  paths = tff.models.euler_sampling.sample(
            dim=dim,
            drift_fn=drift_fn,
            volatility_fn=vol_fn,
            times=times,
            num_samples=num_samples,
            initial_state=x0,
            time_step=0.01,
            seed=42,
            dtype=dtype)
  # Expected: paths.shape = [10000, 3, 2]
  ```

  #### References
  [1]: Wikipedia. Euler-Maruyama method:
  https://en.wikipedia.org/wiki/Euler-Maruyama_method

  Args:
    dim: Python int greater than or equal to 1. The dimension of the Ito
      Process.
    drift_fn: A Python callable to compute the drift of the process. The
      callable should accept two real `Tensor` arguments of the same dtype.
      The first argument is the scalar time t, the second argument is the
      value of Ito process X - tensor of shape
      `batch_shape + [num_samples, dim]`. `batch_shape` is the shape of the
      independent stochastic processes being modelled and is inferred from the
      initial state `x0`.
      The result is value of drift a(t, X). The return value of the callable
      is a real `Tensor` of the same dtype as the input arguments and of shape
      `batch_shape + [num_samples, dim]`.
    volatility_fn: A Python callable to compute the volatility of the process.
      The callable should accept two real `Tensor` arguments of the same dtype
      and shape `times_shape`. The first argument is the scalar time t, the
      second argument is the value of Ito process X - tensor of shape
      `batch_shape + [num_samples, dim]`. The result is value of drift b(t, X).
      The return value of the callable is a real `Tensor` of the same dtype as
      the input arguments and of shape `batch_shape + [num_samples, dim, dim]`.
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: An optional scalar real `Tensor` - maximal distance between
      points in grid in Euler schema.
      Either this or `num_time_steps` should be supplied.
      Default value: `None`.
    num_time_steps: An optional scalar integer `Tensor` - the total number of
      time steps performed by the algorithm. The maximal distance between points
      in the grid is bounded by `times[-1] / (num_time_steps - times.shape[0])`.
      Either this or `time_step` should be supplied.
      Default value: `None`.
    num_samples: Positive scalar `int`. The number of paths to draw.
      Default value: 1.
    initial_state: `Tensor` of shape broadcastable with
      `batch_shape + [num_samples, dim]`. The initial state of the process.
      `batch_shape` represents the shape of the independent batches of the
      stochastic process. Note that `batch_shape` is inferred from
      the `initial_state` and hence when sampling is requested for a batch of
      stochastic processes, the shape of `initial_state` should be at least
      `batch_shape + [1, 1]`.
      Default value: None which maps to a zero initial state.
    random_type: Enum value of `RandomType`. The type of (quasi)-random
      number generator to use to generate the paths.
      Default value: None which maps to the standard pseudo-random numbers.
    seed: Seed for the random number generator. The seed is
      only relevant if `random_type` is one of
      `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
        STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
      `HALTON_RANDOMIZED` the seed should be a Python integer. For
      `STATELESS` and `STATELESS_ANTITHETIC` it must be supplied as an integer
      `Tensor` of shape `[2]`.
      Default value: `None` which means no seed is set.
    swap_memory: A Python bool. Whether GPU-CPU memory swap is enabled for this
      op. See an equivalent flag in `tf.while_loop` documentation for more
      details. Useful when computing a gradient of the op since `tf.while_loop`
      is used to propagate stochastic process in time.
      Default value: True.
    skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
      Halton sequence to skip. Used only when `random_type` is 'SOBOL',
      'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
      Default value: `0`.
    precompute_normal_draws: Python bool. Indicates whether the noise increments
      `N(0, t_{n+1}) - N(0, t_n)` are precomputed. For `HALTON` and `SOBOL`
      random types the increments are always precomputed. While the resulting
      graph consumes more memory, the performance gains might be significant.
      Default value: `True`.
    times_grid: An optional rank 1 `Tensor` representing time discretization
      grid. If `times` are not on the grid, then the nearest points from the
      grid are used. When supplied, `num_time_steps` and `time_step` are
      ignored.
      Default value: `None`, which means that times grid is computed using
      `time_step` and `num_time_steps`.
    normal_draws: A `Tensor` of shape broadcastable with
      `batch_shape + [num_samples, num_time_points, dim]` and the same
      `dtype` as `times`. Represents random normal draws to compute increments
      `N(0, t_{n+1}) - N(0, t_n)`. When supplied, `num_samples` argument is
      ignored and the first dimensions of `normal_draws` is used instead.
      Default value: `None` which means that the draws are generated by the
      algorithm. By default normal_draws for each model in the batch are
      independent.
    watch_params: An optional list of zero-dimensional `Tensor`s of the same
      `dtype` as `initial_state`. If provided, specifies `Tensor`s with respect
      to which the differentiation of the sampling function will happen.
      A more efficient algorithm is used when `watch_params` are specified.
      Note that the function becomes differentiable only w.r.t. these `Tensor`s
      and the `initial_state`. The gradient w.r.t. any other `Tensor` is set to
      zero.
    validate_args: Python `bool`. When `True` performs multiple checks:
      * That `times`  are increasing with the minimum increments of the
        specified tolerance.
      * If `normal_draws` are supplied, checks that `normal_draws.shape[1]` is
      equal to `num_time_steps` that is either supplied as an argument or
      computed from `time_step`.
      When `False`, invalid dimensions may silently render incorrect outputs.
      Default value: `False`.
    tolerance: A non-negative scalar `Tensor` specifying the minimum tolerance
      for discernible times on the time grid. Times that are closer than the
      tolerance are perceived to be the same.
      Default value: `None`, which maps to `1e-6` for single precision
        `dtype` and `1e-10` for double precision `dtype`.
    dtype: `tf.Dtype`. If supplied the dtype for the input and output `Tensor`s.
      Default value: None which means that the dtype implied by `times` is
      used.
    name: Python string. The name to give this op.
      Default value: `None` which maps to `euler_sample`.

  Returns:
   A real `Tensor` of shape `batch_shape_process + [num_samples, k, n]` where
     `k` is the size of `times` and `n` is the dimension of the process.

  Raises:
    ValueError:
      (a) When `times_grid` is not supplied, and neither `num_time_steps` nor
        `time_step` are supplied or if both are supplied.
      (b) If `normal_draws` is supplied and `dim` is mismatched.
    tf.errors.InvalidArgumentError: If `normal_draws` is supplied and
      `num_time_steps` is mismatched.
  """
    name = name or 'euler_sample'
    with tf.name_scope(name):
        times = tf.convert_to_tensor(times, dtype=dtype)
        if dtype is None:
            dtype = times.dtype
        asserts = []
        if tolerance is None:
            tolerance = 1e-10 if dtype == tf.float64 else 1e-6
        tolerance = tf.convert_to_tensor(tolerance, dtype=dtype)
        if validate_args:
            asserts.append(
                tf.assert_greater(
                    times[1:],
                    times[:-1] + tolerance,
                    message='`times` increments should be greater '
                    'than tolerance {0}'.format(tolerance)))
        if initial_state is None:
            initial_state = tf.zeros(dim, dtype=dtype)
        initial_state = tf.convert_to_tensor(initial_state,
                                             dtype=dtype,
                                             name='initial_state')
        batch_shape = tff_utils.get_shape(initial_state)[:-2]
        num_requested_times = tff_utils.get_shape(times)[0]
        # Create a time grid for the Euler scheme.
        if num_time_steps is not None and time_step is not None:
            raise ValueError(
                'When `times_grid` is not supplied only one of either '
                '`num_time_steps` or `time_step` should be defined but not both.'
            )
        if times_grid is None:
            if time_step is None:
                if num_time_steps is None:
                    raise ValueError(
                        'When `times_grid` is not supplied, either `num_time_steps` '
                        'or `time_step` should be defined.')
                num_time_steps = tf.convert_to_tensor(num_time_steps,
                                                      dtype=tf.int32,
                                                      name='num_time_steps')
                time_step = times[-1] / tf.cast(num_time_steps, dtype=dtype)
            else:
                time_step = tf.convert_to_tensor(time_step,
                                                 dtype=dtype,
                                                 name='time_step')
        else:
            times_grid = tf.convert_to_tensor(times_grid,
                                              dtype=dtype,
                                              name='times_grid')
            if validate_args:
                asserts.append(
                    tf.assert_greater(
                        times_grid[1:],
                        times_grid[:-1] + tolerance,
                        message='`times_grid` increments should be greater '
                        'than tolerance {0}'.format(tolerance)))
        times, keep_mask, time_indices = utils.prepare_grid(
            times=times,
            time_step=time_step,
            num_time_steps=num_time_steps,
            times_grid=times_grid,
            tolerance=tolerance,
            dtype=dtype)

        if normal_draws is not None:
            normal_draws = tf.convert_to_tensor(normal_draws,
                                                dtype=dtype,
                                                name='normal_draws')
            # Shape [num_time_points] + batch_shape + [num_samples, dim]
            normal_draws_rank = normal_draws.shape.rank
            perm = tf.concat(
                [[normal_draws_rank - 2],
                 tf.range(normal_draws_rank - 2), [normal_draws_rank - 1]],
                axis=0)
            normal_draws = tf.transpose(normal_draws, perm=perm)
            num_samples = tf.shape(normal_draws)[-2]
            draws_dim = normal_draws.shape[-1]
            if dim != draws_dim:
                raise ValueError(
                    '`dim` should be equal to `normal_draws.shape[2]` but are '
                    '{0} and {1} respectively'.format(dim, draws_dim))
            if validate_args:
                draws_times = tff_utils.get_shape(normal_draws)[0]
                asserts.append(
                    tf.assert_equal(
                        draws_times,
                        tf.shape(keep_mask)[0] - 1,
                        message='`num_time_steps` should be equal to '
                        '`tf.shape(normal_draws)[1]`'))
        if validate_args:
            with tf.control_dependencies(asserts):
                times = tf.identity(times)
        if watch_params is not None:
            watch_params = [
                tf.convert_to_tensor(param, dtype=dtype)
                for param in watch_params
            ]
        return _sample(dim=dim,
                       batch_shape=batch_shape,
                       drift_fn=drift_fn,
                       volatility_fn=volatility_fn,
                       times=times,
                       keep_mask=keep_mask,
                       num_requested_times=num_requested_times,
                       num_samples=num_samples,
                       initial_state=initial_state,
                       random_type=random_type,
                       seed=seed,
                       swap_memory=swap_memory,
                       skip=skip,
                       precompute_normal_draws=precompute_normal_draws,
                       normal_draws=normal_draws,
                       watch_params=watch_params,
                       time_indices=time_indices,
                       dtype=dtype)
Example #11
def sample(dim,
           drift_fn,
           volatility_fn,
           times,
           time_step=None,
           num_time_steps=None,
           num_samples=1,
           initial_state=None,
           random_type=None,
           seed=None,
           swap_memory=True,
           skip=0,
           precompute_normal_draws=True,
           times_grid=None,
           normal_draws=None,
           watch_params=None,
           validate_args=False,
           dtype=None,
           name=None):
  """Returns a sample paths from the process using Euler method.

  For an Ito process,

  ```
    dX = a(t, X_t) dt + b(t, X_t) dW_t
  ```
  with given drift `a` and volatility `b` functions, the Euler method generates
  a sequence `{X_n}` as

  ```
  X_{n+1} = X_n + a(t_n, X_n) dt + b(t_n, X_n) (N(0, t_{n+1}) - N(0, t_n)),
  ```
  where `dt = t_{n+1} - t_n` and `N` is a sample from the Normal distribution.
  See [1] for details.

  #### References
  [1]: Wikipedia. Euler-Maruyama method:
  https://en.wikipedia.org/wiki/Euler-Maruyama_method

  Args:
    dim: Python int greater than or equal to 1. The dimension of the Ito
      Process.
    drift_fn: A Python callable to compute the drift of the process. The
      callable should accept two real `Tensor` arguments of the same dtype.
      The first argument is the scalar time t, the second argument is the
      value of Ito process X - tensor of shape `batch_shape + [dim]`.
      The result is value of drift a(t, X). The return value of the callable
      is a real `Tensor` of the same dtype as the input arguments and of shape
      `batch_shape + [dim]`.
    volatility_fn: A Python callable to compute the volatility of the process.
      The callable should accept two real `Tensor` arguments of the same dtype
      and shape `times_shape`. The first argument is the scalar time t, the
      second argument is the value of Ito process X - tensor of shape
      `batch_shape + [dim]`. The result is value of drift b(t, X). The return
      value of the callable is a real `Tensor` of the same dtype as the input
      arguments and of shape `batch_shape + [dim, dim]`.
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: An optional scalar real `Tensor` - maximal distance between
      points in grid in Euler schema.
      Either this or `num_time_steps` should be supplied.
      Default value: `None`.
    num_time_steps: An optional scalar integer `Tensor` - the total number of
      time steps performed by the algorithm. The maximal distance between points
      in the grid is bounded by `times[-1] / (num_time_steps - times.shape[0])`.
      Either this or `time_step` should be supplied.
      Default value: `None`.
    num_samples: Positive scalar `int`. The number of paths to draw.
      Default value: 1.
    initial_state: `Tensor` of shape `[dim]`. The initial state of the
      process.
      Default value: None which maps to a zero initial state.
    random_type: Enum value of `RandomType`. The type of (quasi)-random
      number generator to use to generate the paths.
      Default value: None which maps to the standard pseudo-random numbers.
    seed: Seed for the random number generator. The seed is
      only relevant if `random_type` is one of
      `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
        STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
      `HALTON_RANDOMIZED` the seed should be a Python integer. For
      `STATELESS` and `STATELESS_ANTITHETIC` it must be supplied as an integer
      `Tensor` of shape `[2]`.
      Default value: `None` which means no seed is set.
    swap_memory: A Python bool. Whether GPU-CPU memory swap is enabled for this
      op. See an equivalent flag in `tf.while_loop` documentation for more
      details. Useful when computing a gradient of the op since `tf.while_loop`
      is used to propagate stochastic process in time.
      Default value: True.
    skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
      Halton sequence to skip. Used only when `random_type` is 'SOBOL',
      'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
      Default value: `0`.
    precompute_normal_draws: Python bool. Indicates whether the noise increments
      `N(0, t_{n+1}) - N(0, t_n)` are precomputed. For `HALTON` and `SOBOL`
      random types the increments are always precomputed. While the resulting
      graph consumes more memory, the performance gains might be significant.
      Default value: `True`.
    times_grid: An optional rank 1 `Tensor` representing time discretization
      grid. If `times` are not on the grid, then the nearest points from the
      grid are used. When supplied, `num_time_steps` and `time_step` are
      ignored.
      Default value: `None`, which means that times grid is computed using
      `time_step` and `num_time_steps`.
    normal_draws: A `Tensor` of shape `[num_samples, num_time_points, dim]`
      and the same `dtype` as `times`. Represents random normal draws to compute
      increments `N(0, t_{n+1}) - N(0, t_n)`. When supplied, `num_samples`
      argument is ignored and the first dimensions of `normal_draws` is used
      instead.
      Default value: `None` which means that the draws are generated by the
      algorithm.
    watch_params: An optional list of zero-dimensional `Tensor`s of the same
      `dtype` as `initial_state`. If provided, specifies `Tensor`s with respect
      to which the differentiation of the sampling function will happen.
      A more efficient algorithm is used when `watch_params` are specified.
      Note that the function becomes differentiable only w.r.t. these `Tensor`s
      and the `initial_state`. The gradient w.r.t. any other `Tensor` is set to
      zero.
    validate_args: Python `bool`. When `True` and `normal_draws` are supplied,
      checks that `tf.shape(normal_draws)[1]` is equal to `num_time_steps` that
      is either supplied as an argument or computed from `time_step`.
      When `False`, invalid dimensions may silently render incorrect outputs.
      Default value: `False`.
    dtype: `tf.Dtype`. If supplied the dtype for the input and output `Tensor`s.
      Default value: None which means that the dtype implied by `times` is
      used.
    name: Python string. The name to give this op.
      Default value: `None` which maps to `euler_sample`.

  Returns:
   A real `Tensor` of shape `[num_samples, k, n]` where `k` is the size of
      `times` and `n` is the dimension of the process.

  Raises:
    ValueError:
      (a) When `times_grid` is not supplied, and neither `num_time_steps` nor
        `time_step` are supplied or if both are supplied.
      (b) If `normal_draws` is supplied and `dim` is mismatched.
    tf.errors.InvalidArgumentError: If `normal_draws` is supplied and
      `num_time_steps` is mismatched.
  """
  name = name or 'euler_sample'
  with tf.name_scope(name):
    times = tf.convert_to_tensor(times, dtype=dtype)
    if dtype is None:
      dtype = times.dtype
    if initial_state is None:
      initial_state = tf.zeros(dim, dtype=dtype)
    initial_state = tf.convert_to_tensor(initial_state, dtype=dtype,
                                         name='initial_state')
    num_requested_times = tf.shape(times)[0]
    # Create a time grid for the Euler scheme.
    if num_time_steps is not None and time_step is not None:
      raise ValueError(
          'When `times_grid` is not supplied only one of either '
          '`num_time_steps` or `time_step` should be defined but not both.')
    if times_grid is None:
      if time_step is None:
        if num_time_steps is None:
          raise ValueError(
              'When `times_grid` is not supplied, either `num_time_steps` '
              'or `time_step` should be defined.')
        num_time_steps = tf.convert_to_tensor(
            num_time_steps, dtype=tf.int32, name='num_time_steps')
        time_step = times[-1] / tf.cast(num_time_steps, dtype=dtype)
      else:
        time_step = tf.convert_to_tensor(time_step, dtype=dtype,
                                         name='time_step')
    else:
      times_grid = tf.convert_to_tensor(times_grid, dtype=dtype,
                                        name='times_grid')
    times, keep_mask, time_indices = utils.prepare_grid(
        times=times,
        time_step=time_step,
        num_time_steps=num_time_steps,
        times_grid=times_grid,
        dtype=dtype)
    if normal_draws is not None:
      normal_draws = tf.convert_to_tensor(normal_draws, dtype=dtype,
                                          name='normal_draws')
      # Shape [num_time_points, num_samples, dim]
      normal_draws = tf.transpose(normal_draws, [1, 0, 2])
      num_samples = tf.shape(normal_draws)[1]
      draws_dim = normal_draws.shape[2]
      if dim != draws_dim:
        raise ValueError(
            '`dim` should be equal to `normal_draws.shape[2]` but are '
            '{0} and {1} respectively'.format(dim, draws_dim))
      if validate_args:
        draws_times = tf.shape(normal_draws)[0]
        asserts = tf.assert_equal(
            draws_times, tf.shape(keep_mask)[0] - 1,
            message='`num_time_steps` should be equal to '
                    '`tf.shape(normal_draws)[1]`')
        with tf.compat.v1.control_dependencies([asserts]):
          normal_draws = tf.identity(normal_draws)
    if watch_params is not None:
      watch_params = [tf.convert_to_tensor(param, dtype=dtype)
                      for param in watch_params]
    return _sample(
        dim=dim,
        drift_fn=drift_fn,
        volatility_fn=volatility_fn,
        times=times,
        keep_mask=keep_mask,
        num_requested_times=num_requested_times,
        num_samples=num_samples,
        initial_state=initial_state,
        random_type=random_type,
        seed=seed,
        swap_memory=swap_memory,
        skip=skip,
        precompute_normal_draws=precompute_normal_draws,
        normal_draws=normal_draws,
        watch_params=watch_params,
        time_indices=time_indices,
        dtype=dtype)
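A hedged usage sketch for this variant (constants are illustrative), mirroring the docstring example of the typed variant above:

import tensorflow as tf

def drift_fn(t, x):
  return 0.1 * tf.ones_like(x)                                  # [num_samples, dim]

def vol_fn(t, x):
  return 0.2 * tf.ones([tf.shape(x)[0], 1, 1], dtype=x.dtype)   # [num_samples, dim, dim]

paths = sample(
    dim=1,
    drift_fn=drift_fn,
    volatility_fn=vol_fn,
    times=[0.5, 1.0],
    time_step=0.05,
    num_samples=1000,
    dtype=tf.float64)
# Expected: paths.shape == [1000, 2, 1]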