def main(argv):
  del argv

  root = tf.train.Checkpoint()
  # Create a cell and attach to our trackable.
  root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)

  # Wrap the rnn_cell.__call__ function and assign to next_state.
  root.next_state = tf.function(root.rnn_cell.__call__, autograph=False)

  # Wrap the rnn_cell.get_initial_state method using a decorator and assign to
  # an attribute with the same name.
  @tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
  def get_initial_state(tensor):
    return root.rnn_cell.get_initial_state(tensor, None, None)

  root.get_initial_state = get_initial_state

  # Construct an initial_state, then call next_state explicitly to trigger a
  # trace for serialization (we need an explicit call, because next_state has
  # not been annotated with an input_signature).
  initial_state = root.get_initial_state(
      tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
  root.next_state(
      tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
      initial_state)

  tf.saved_model.save(root, FLAGS.export_dir)
# Example 2
def main(argv):
  del argv
  cell = tf.saved_model.load(FLAGS.model_dir)

  initial_state = cell.get_initial_state(
      tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))

  cell.next_state(
      tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
      initial_state)
def main(argv):
  del argv

  sentences = [
      "<S> sentence <E>", "<S> second sentence <E>", "<S> third sentence<E>"
  ]

  model = tf.saved_model.load(FLAGS.model_dir)
  model.train(tf.constant(sentences))
  decoded = model.decode_greedy(
      sequence_length=10, first_word=tf.constant("<S>"))
  _ = [d.numpy() for d in decoded]
def main(argv):
  del argv

  sentences = ["<S> hello there <E>", "<S> how are you doing today <E>"]
  vocab = [
      "<S>", "<E>", "hello", "there", "how", "are", "you", "doing", "today"
  ]

  module = TextRnnModel(vocab=vocab, emb_dim=10, buckets=100, state_size=128)

  for _ in range(100):
    _ = module.train(tf.constant(sentences))

  # We have to call this function explicitly if we want it exported, because it
  # has no input_signature in the @tf.function decorator.
  decoded = module.decode_greedy(
      sequence_length=10, first_word=tf.constant("<S>"))
  _ = [d.numpy() for d in decoded]

  tf.saved_model.save(module, FLAGS.export_dir)
def wrap_keras_model_for_export(model, batch_input_shape,
                                set_hparams, default_hparams):
  """Wraps `model` for saving and loading as SavedModel."""
  if default_hparams is None: default_hparams = {}
  hparam_keys = list(default_hparams.keys())
  hparam_defaults = tuple(default_hparams.values())
  # The goal is to save a function with this argspec...
  argspec = tf_inspect.FullArgSpec(
      args=(['inputs', 'training'] + hparam_keys),
      defaults=((False,) + hparam_defaults),
      varargs=None, varkw=None,
      kwonlyargs=[], kwonlydefaults=None,
      annotations={})
  # ...and this behavior:
  def call_fn(inputs, training, *args):
    if FLAGS.export_print_hparams:
      args = [tf.keras.backend.print_tensor(args[i], 'training=%s and %s='
                                            % (training, hparam_keys[i]))
              for i in range(len(args))]
    kwargs = dict(zip(hparam_keys, args))
    if kwargs: set_hparams(model, **kwargs)
    return model(inputs, training=training)
  # We cannot spell out `args` in def statement for call_fn, but since
  # tf.function uses tf_inspect, we can use tf_decorator to wrap it with
  # the desired argspec.
  def wrapped(*args, **kwargs):  # TODO(arnoegw): Can we use call_fn itself?
    return call_fn(*args, **kwargs)
  traced_call_fn = tf.function(autograph=False)(
      tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
  # Now we need to trigger traces for
  # - training set to Python values True or False (hence two traces),
  # - tensor inputs of the expected nesting, shape and dtype,
  # - tensor-valued kwargs for hparams, with caller-side defaults.
  # Tracing with partially determined shapes requires an input signature,
  # so we initiate tracing from a helper function with only tensor inputs.
  @tf.function(autograph=False)
  def trigger_traces(inputs, **kwargs):
    return tuple(traced_call_fn(inputs, training=training, **kwargs)
                 for training in (True, False))
  inputs_spec = tf.TensorSpec(shape=batch_input_shape, dtype=tf.float32)
  hparams_spec = {name: tf.TensorSpec.from_tensor(tf.constant(value))
                  for name, value in default_hparams.items()}
  _ = trigger_traces.get_concrete_function(inputs_spec, **hparams_spec)

  # Assemble the output object.
  obj = tf.train.Checkpoint()
  obj.__call__ = traced_call_fn
  obj.trainable_variables = model.trainable_variables
  obj.variables = model.trainable_variables + model.non_trainable_variables
  obj.regularization_losses = [_get_traced_loss(model, i)
                               for i in range(len(model.losses))]
  return obj
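
# --- Usage sketch (added; not part of the original source) ---
# A minimal, standalone illustration of the trace-triggering pattern used in
# `wrap_keras_model_for_export` above: a helper tf.function traces the wrapped
# call once per Python value of `training`, with tensor inputs described only
# by a TensorSpec. The toy function, shapes and names below are assumptions
# for illustration, not the original test's configuration.
import tensorflow as tf

@tf.function(autograph=False)
def _toy_call(inputs, training):
  # Each Python value of `training` (True/False) yields its own concrete trace.
  return inputs * (0.5 if training else 1.0)

@tf.function(autograph=False)
def _trigger_toy_traces(inputs):
  return tuple(_toy_call(inputs, training=training)
               for training in (True, False))

# Tracing with a partially determined batch dimension via an input signature.
_ = _trigger_toy_traces.get_concrete_function(
    tf.TensorSpec(shape=[None, 8], dtype=tf.float32))
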
  def _tokenize(self, sentences):
    # Perform a minimalistic text preprocessing by removing punctuation and
    # splitting on spaces.
    normalized_sentences = tf.strings.regex_replace(
        input=sentences, pattern=r"\pP", rewrite="")
    sparse_tokens = tf.string_split(normalized_sentences, " ")

    # Deal with a corner case: there is one empty sentence.
    sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
    # Deal with a corner case: all sentences are empty.
    sparse_tokens = tf.sparse.reset_shape(sparse_tokens)

    return (sparse_tokens.indices, sparse_tokens.values,
            sparse_tokens.dense_shape)
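
# --- Standalone sketch (added; an assumption, not part of the original model) ---
# The two corner cases handled in `_tokenize` above, demonstrated on toy data
# with the TF2 string APIs: an empty sentence produces an empty sparse row,
# which `fill_empty_rows` pads, and `reset_shape` recomputes a tight dense
# shape from the indices.
import tensorflow as tf

sentences = tf.constant(["hello there", "", "how are you"])
normalized = tf.strings.regex_replace(sentences, pattern=r"\pP", rewrite="")
# Whitespace splitting drops empty tokens, so the empty sentence yields an
# empty row in the sparse representation.
tokens = tf.strings.split(normalized).to_sparse()
tokens, _ = tf.sparse.fill_empty_rows(tokens, tf.constant(""))
tokens = tf.sparse.reset_shape(tokens)
print(tokens.indices.numpy(), tokens.values.numpy(), tokens.dense_shape.numpy())
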
# Example 7
 def testNotReparameterized(self):
     p = tf.constant([0.2, 0.6])
     _, grad_p = tfp.math.value_and_gradient(
         lambda x: tfd.Bernoulli(probs=x, validate_args=True).sample(100),
         p)
     self.assertIsNone(grad_p)
 def loss():
   rep_id = (tf.distribute.get_replica_context().replica_id_in_sync_group)
   # The last element of last replica's gradient is NaN.
   return tf.cond(
       tf.equal(rep_id, 0), lambda: var * 2.,
       lambda: var * tf.constant([1., float('NaN')]))
# Example 9
def _random_gamma_cpu(
    shape, concentration, rate=None, log_rate=None, seed=None, log_space=False):
  """Sample using *fast* `tf.random.stateless_gamma`."""
  bad_concentration = (concentration <= 0.) | tf.math.is_nan(concentration)
  safe_concentration = tf.where(
      bad_concentration,
      dtype_util.as_numpy_dtype(concentration.dtype)(100.), concentration)

  if rate is None:
    if log_rate is None:
      rate = tf.ones([], concentration.dtype)
      log_rate = tf.zeros([], concentration.dtype)
    else:
      rate = tf.math.exp(log_rate)

  bad_rate = (rate <= 0.) | tf.math.is_nan(rate)

  if log_space:
    # The underlying gamma sampler uses a recurrence for conc < 1.  When
    # a ~ gamma(conc + 1) and x ~ uniform(0, 1), we have
    #   b = a * x ** (1/conc) ~ gamma(conc)
    # Given that we want log(b) anyway, it's more accurate to just ask the
    # sampler for a (by passing conc + 1 to it in the first place) and
    # do the correction in log-space below.
    orig_safe_concentration = safe_concentration
    safe_concentration = tf.where(
        orig_safe_concentration < 1,
        orig_safe_concentration + 1.,
        orig_safe_concentration)
    seed, conc_fix_seed = samplers.split_seed(seed)
    log_rate = tf.math.log(rate) if log_rate is None else log_rate
    rate = tf.ones_like(log_rate)  # Do the division later in log-space.

  safe_rate = tf.where(
      bad_rate,
      dtype_util.as_numpy_dtype(concentration.dtype)(100.), rate)
  samples = tf.random.stateless_gamma(
      shape=shape, seed=seed, alpha=safe_concentration,
      beta=safe_rate, dtype=concentration.dtype)

  if log_space:
    # Apply the concentration < 1 recurrence here, in log-space.
    samples = tf.math.log(samples)
    conc_fix_unif = samplers.uniform(  # in [0, 1)
        shape, dtype=samples.dtype, seed=conc_fix_seed)

    conc_lt_one_fix = tf.where(
        orig_safe_concentration < 1,
        # Why do we use log1p(-x)? x is in [0, 1) and log(0) = -inf, is bad.
        # x ~ U(0,1) => 1-x ~ U(0,1)
        # But at the boundary, 1-x in (0, 1]. Good.
        # So we can take log(unif(0,1)) safely as log(1-unif(0,1)).
        # log1p(-0) = 0, and log1p(-almost_one) = -not_quite_inf. Good.
        tf.math.log1p(-conc_fix_unif) / orig_safe_concentration,
        tf.zeros((), dtype=samples.dtype))
    samples += (conc_lt_one_fix - log_rate)

  # A rate of 0 means infinite scale, which implies a +inf sample. Note that
  # the `log_space` branch above replaced `rate` with ones, so we check
  # `log_rate` in that case.
  return tf.where(
      (log_rate <= -np.inf if log_space else tf.equal(rate, 0.)),
      tf.constant(np.inf, dtype=concentration.dtype),
      tf.where(
          bad_rate | bad_concentration,
          dtype_util.as_numpy_dtype(concentration.dtype)(np.nan), samples))
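
# --- Numeric sketch (added; not from the original library) ---
# A quick check of the concentration < 1 recurrence used above: if
# a ~ Gamma(conc + 1, 1) and u ~ Uniform(0, 1), then a * u**(1 / conc) is
# distributed as Gamma(conc, 1), whose mean is `conc`. The seeds and sample
# size below are arbitrary assumptions for illustration.
import tensorflow as tf

conc = 0.3
n = 100000
a = tf.random.stateless_gamma([n], seed=[1, 2], alpha=conc + 1.)
u = tf.random.stateless_uniform([n], seed=[3, 4])
b = a * u ** (1. / conc)
direct = tf.random.stateless_gamma([n], seed=[5, 6], alpha=conc)
# Both sample means should be close to conc (= 0.3) up to Monte Carlo error.
print(tf.reduce_mean(b).numpy(), tf.reduce_mean(direct).numpy())
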
# Example 10
def moving_mean_variance_zero_debiased(moving_mean,
                                       moving_variance=None,
                                       zero_debias_count=None,
                                       decay=0.99,
                                       name=None):
    """Compute zero debiased versions of `moving_mean` and `moving_variance`.

  Since `moving_*` variables initialized with `0`s will be biased (toward `0`),
  this function rescales the `moving_mean` and `moving_variance` by the factor
  `1 - decay**zero_debias_count`, i.e., such that the `moving_mean` is unbiased.
  For more details, see [Kingma (2014)][1].

  Args:
    moving_mean: `float`-like `tf.Variable` representing the exponentially
      weighted moving mean. Same shape as `moving_variance` and `value`. This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
    moving_variance: `float`-like `tf.Variable` representing the exponentially
      weighted moving variance. Same shape as `moving_mean` and `value`.  This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
      Default value: `None` (i.e., no moving variance is computed).
    zero_debias_count: `int`-like `tf.Variable` representing the number of times
      this function has been called on streaming input (*not* the number of
      reduced values used in this function's computation). When not `None` (the
      default) the returned values for `moving_mean` and `moving_variance` are
      "zero debiased", i.e., corrected for their presumed all zeros
      initialization. Note: the `tf.Variable`s `moving_mean` and
      `moving_variance` *always* store the biased calculation, regardless of
      setting this argument. To obtain unbiased calculations from these
      `tf.Variable`s, see `tfp.stats.moving_mean_variance_zero_debiased`.
      Default value: `None` (i.e., no zero debiasing calculation is made).
    decay: A `float`-like `Tensor` representing the moving mean decay. Typically
      close to `1.`, e.g., `0.99`.
      Default value: `0.99`.
    name: Python `str` prepended to op names created by this function.
      Default value: `None` (i.e., 'moving_mean_variance_zero_debiased').

  Returns:
    moving_mean: The zero debiased exponentially weighted moving mean.
    moving_variance: The zero debiased exponentially weighted moving variance.

  Raises:
    TypeError: if `moving_mean` does not have float type `dtype`.
    TypeError: if `moving_mean`, `moving_variance`, `decay` have different
      `base_dtype`.

  #### References

  [1]: Diederik P. Kingma, Jimmy Ba. Adam: A Method for Stochastic Optimization.
        _arXiv preprint arXiv:1412.6980_, 2014.
       https://arxiv.org/abs/1412.6980
  """
    with tf.name_scope(name or 'moving_mean_variance_zero_debiased'):
        if zero_debias_count is None:
            raise ValueError(
                'Argument `zero_debias_count` must be specified.')
        base_dtype = dtype_util.base_dtype(moving_mean.dtype)
        if not dtype_util.is_floating(base_dtype):
            raise TypeError(
                'Argument `moving_mean` is not float type (saw {}).'.format(
                    dtype_util.name(moving_mean.dtype)))
        t = tf.cast(zero_debias_count, dtype=base_dtype)
        # Could have used:
        #   bias_correction = -tf.math.expm1(t * tf.math.log(decay))
        # however, since we expect decay to be nearly 1, this would not yield a
        # significant improvement in accuracy, yet would incur higher
        # computational cost.
        t = tf.where(t > 0., t, tf.constant(np.inf, base_dtype))
        bias_correction = 1. - decay**t
        unbiased_mean = moving_mean / bias_correction
        if moving_variance is None:
            return unbiased_mean
        if base_dtype != dtype_util.base_dtype(moving_variance.dtype):
            raise TypeError(
                'Arguments `moving_mean` and `moving_variance` do not '
                'have same base `dtype` (saw {}, {}).'.format(
                    dtype_util.name(moving_mean.dtype),
                    dtype_util.name(moving_variance.dtype)))
        unbiased_variance = moving_variance / bias_correction
        return unbiased_mean, unbiased_variance
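
# --- Toy numeric check (added; an assumption, not part of the library) ---
# The zero-debiasing factor used above in action: an exponential moving
# average initialized at 0 underestimates a constant signal by exactly
# 1 - decay**t, so dividing by that factor recovers the signal.
import tensorflow as tf

decay = 0.99
signal = 5.0
ema = tf.Variable(0.0)
num_updates = 20
for _ in range(num_updates):
  # Same recurrence as in the docstring: new = old + (1 - decay) * (value - old).
  ema.assign(ema + (1. - decay) * (signal - ema))
bias_correction = 1. - decay ** num_updates
print(ema.numpy() / bias_correction)  # ==> ~5.0, the debiased estimate.
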
    def testTrainAndServe(self, use_adapt, load_under_strategy):

        with self.coordinator.strategy.scope():

            feature_ps, label_ps = self.define_kpls_for_training(use_adapt)

            def dataset_fn():
                def feature_and_label_gen():
                    while True:
                        features = random.sample(FEATURE_VOCAB, 3)
                        label = ["yes"] if "avenger" in features else ["no"]
                        yield {"features": features, "label": label}

                # The dataset will be created on the coordinator.
                raw_dataset = tf.data.Dataset.from_generator(
                    feature_and_label_gen,
                    output_signature={
                        "features": tf.TensorSpec([3], tf.string),
                        "label": tf.TensorSpec([1], tf.string)
                    }).shuffle(100).batch(32)

                train_dataset = raw_dataset.map(lambda x: (  # pylint: disable=g-long-lambda
                    {
                        "features": feature_ps(x["features"])
                    }, label_ps(x["label"])))
                return train_dataset

            # Create the model. The input needs to be compatible with KPLs.
            model_input = tf.keras.Input(shape=(3, ),
                                         dtype=tf.int64,
                                         name="model_input")

            # input_dim includes a mask token and an oov token.
            emb_output = tf.keras.layers.Embedding(
                input_dim=len(FEATURE_VOCAB) + 2, output_dim=20)(model_input)
            emb_output = tf.reduce_mean(emb_output, axis=1)
            dense_output = tf.keras.layers.Dense(
                units=1, activation="sigmoid")(emb_output)
            model = tf.keras.Model({"features": model_input}, dense_output)

            optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
            accuracy = tf.keras.metrics.Accuracy()

        @tf.function
        def worker_fn(iterator):
            def replica_fn(iterator):
                batch_data, labels = next(iterator)
                with tf.GradientTape() as tape:
                    pred = model(batch_data, training=True)
                    loss = tf.nn.compute_average_loss(
                        tf.keras.losses.BinaryCrossentropy(
                            reduction=tf.keras.losses.Reduction.NONE)(labels,
                                                                      pred))
                    gradients = tape.gradient(loss, model.trainable_variables)

                optimizer.apply_gradients(
                    zip(gradients, model.trainable_variables))

                actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
                accuracy.update_state(labels, actual_pred)

            self.coordinator.strategy.run(replica_fn, args=(iterator, ))

        distributed_dataset = self.coordinator.create_per_worker_dataset(
            dataset_fn)
        distributed_iterator = iter(distributed_dataset)
        for _ in range(4):
            accuracy.reset_state()
            for _ in range(7):
                self.coordinator.schedule(worker_fn,
                                          args=(distributed_iterator, ))
            self.coordinator.join()
        self.assertGreater(accuracy.result().numpy(), 0.5)

        # Create a saved model.
        model.feature_ps = feature_ps
        model.label_ps = label_ps
        model.label_inverse_lookup_layer = self.define_reverse_lookup_layer()

        def create_serving_signature(model):
            @tf.function
            def serve_fn(raw_features):
                raw_features = tf.expand_dims(raw_features, axis=0)
                transformed_features = model.feature_ps(raw_features)
                outputs = model(transformed_features)
                outputs = tf.squeeze(outputs, axis=0)
                outputs = tf.cast(tf.greater(outputs, 0.5), tf.int64)
                decoded_outputs = model.label_inverse_lookup_layer(outputs)
                return tf.squeeze(decoded_outputs, axis=0)

            # The serving input does NOT have a batch dimension.
            return serve_fn.get_concrete_function(
                tf.TensorSpec(shape=(3,), dtype=tf.string, name="example"))

        serving_fn = create_serving_signature(model)

        saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        model.save(saved_model_dir, signatures={"serving_default": serving_fn})

        if load_under_strategy:
            with self.coordinator.strategy.scope():

                loaded_serving_fn = tf.keras.models.load_model(
                    saved_model_dir).signatures["serving_default"]

            outputs = []
            for _ in range(7):
                outputs.append(
                    self.coordinator.schedule(
                        loaded_serving_fn,
                        args=(tf.constant(["avenger", "ironman",
                                           "avenger"]), )))
            self.coordinator.join()
            for prediction0 in outputs:
                self.assertIn(prediction0._get_values()["output_0"],
                              ("yes", "no"))
        else:
            loaded_serving_fn = tf.keras.models.load_model(
                saved_model_dir).signatures["serving_default"]

            # check the result w/ and w/o avenger.
            prediction0 = loaded_serving_fn(
                tf.constant(["avenger", "ironman", "avenger"]))["output_0"]
            self.assertIn(prediction0, ("yes", "no"))

            prediction1 = loaded_serving_fn(
                tf.constant(["ironman", "ironman", "unknown"]))["output_0"]
            self.assertIn(prediction1, ("yes", "no"))
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    steps_between_eval=None,
    steps_before_eval_start=None,
    stop_threshold=None,
    metric_fn=None,
    init_checkpoint=None,
    custom_callbacks=None,
    run_eagerly=False,
    sub_model_export_name=None,
    explicit_allreduce=False,
    device_warmup=False,
    synthetic_train_input_fn=None,
    pre_allreduce_callbacks=None,
    post_allreduce_callbacks=None,
    allreduce_bytes_per_pack=0,
    enable_checkpoint_and_summary=False,
    num_accumulation_steps=1,
    stop_steps=None):
  """Run BERT pretrain model training using low-level API.

  Arguments:
      _sentinel: Used to prevent positional parameters. Internal, do not use.
      strategy: Distribution strategy on which to run low level training loop.
      model_fn: Function that returns a tuple (model, sub_model,
        sub_pretrain_model). Caller of this function should add an optimizer to
        the `model`, either via the `model.compile()` API or by manually setting
        the `model.optimizer` attribute. `sub_model` is an optional sub model
        that is exported to intermediate checkpoints when
        `sub_model_export_name` is set, and `sub_pretrain_model` is the model
        onto which `init_checkpoint` is restored -- if provided.
      loss_fn: Function with signature func(labels, logits) and returns a loss
        tensor.
      model_dir: Model directory used during training for restoring/saving model
        weights.
      train_input_fn: Function that returns a tf.data.Dataset used for training.
      steps_per_epoch: Number of steps to run per epoch. At the end of each
        epoch, model checkpoint will be saved and evaluation will be conducted
        if evaluation dataset is provided.
      steps_per_loop: Number of steps per graph-mode loop. In order to reduce
        communication in eager context, training logs are printed every
        steps_per_loop.
      epochs: Number of epochs to train.
      eval_input_fn: Function that returns evaluation dataset. If none,
        evaluation is skipped.
      eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
        is not none.
      steps_between_eval: Number of steps between evaluations.
      steps_before_eval_start: Number of steps to skip before starting
        evaluation.
      stop_threshold: MLPerf accuracy threshold; training stops once it is
        achieved.
      metric_fn: A metrics function that returns a Keras Metric object to record
        evaluation result using evaluation dataset or with training dataset
        after every epoch.
      init_checkpoint: Optional checkpoint to load onto `sub_pretrain_model`
        returned by `model_fn`.
      custom_callbacks: A list of Keras Callbacks objects to run during
        training. More specifically, `on_batch_begin()`, `on_batch_end()`,
        methods are invoked during training.
      run_eagerly: Whether to run model training in pure eager execution. This
        should be disabled for TPUStrategy.
      sub_model_export_name: If not None, will export `sub_model` returned by
        `model_fn` into checkpoint files. The name of intermediate checkpoint
        file is {sub_model_export_name}_step_{step}.ckpt and the last
        checkpoint's name is {sub_model_export_name}.ckpt;
        if None, `sub_model` will not be exported as checkpoint.
      explicit_allreduce: Whether to explicitly perform gradient allreduce,
        instead of relying on implicit allreduce in optimizer.apply_gradients().
        default is False. For now, if training using FP16 mixed precision,
        explicit allreduce will aggregate gradients in FP16 format. For TPU and
        GPU training using FP32, explicit allreduce will aggregate gradients in
        FP32 format.
      device_warmup: Whether or not to enable device warmup. This
        runs the training and eval loop on synthetic data to pre-compile XLA
        and TF tracing before accessing data.
      synthetic_train_input_fn: Function that returns synthetic training
        dataset. This is used in device warmup.
      pre_allreduce_callbacks: A list of callback functions that take gradient
        and model variable pairs as input, manipulate them, and return new
        gradient and model variable pairs. The callback functions will be
        invoked in the list order and before gradients are allreduced.
        Default is no callbacks. Only used when explicit_allreduce=True.
      post_allreduce_callbacks: A list of callback functions that take
        gradient and model variable pairs as input, manipulate them, and
        return new gradient and model variable pairs. The callback functions
        will be invoked in the list order and right before gradients are
        applied to variables for updates. Default is no callbacks. Only used
        when explicit_allreduce=True.
      allreduce_bytes_per_pack: A non-negative integer. Breaks collective
        operations into packs of certain size. If it's zero, all gradients are
        in one pack.
      enable_checkpoint_and_summary: Whether to save checkpoint and summary.
      stop_steps: The number of steps to run before stopping the training loop.

  Returns:
      Trained model.

  Raises:
      ValueError: (1) When the model returned by `model_fn` does not have an
        optimizer attribute or when required parameters are set to None. (2)
        Eval args are not specified correctly. (3) `metric_fn` is specified but
        not callable. (4) `sub_model_export_name` is specified, but `sub_model`
        returned by `model_fn` is None.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if [arg for arg in required_arguments if arg is None]:
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
                     'parameters.')

  if steps_between_eval % steps_per_loop != 0:
    raise ValueError(
        'steps_between_eval should be a multiple of steps_per_loop.')

  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        ' steps_per_epoch: %d, we will use steps_per_epoch as'
        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  assert tf.executing_eagerly()

  if run_eagerly:
    if steps_per_loop > 1:
      raise ValueError(
          'steps_per_loop is used for performance optimization. When you want '
          'to run eagerly, you cannot leverage graph mode loop.')
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      raise ValueError(
          'TPUStrategy should not run eagerly as it heavily relies on graph'
          ' optimization for the distributed system.')

  if eval_input_fn and (eval_steps is None):
    raise ValueError(
        '`eval_steps` is required when `eval_input_fn` '
        'is not None.')
  if device_warmup and (synthetic_train_input_fn is None):
    raise ValueError('`synthetic_train_input_fn` is required when '
                     'device_warmup is enabled.')

  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  if stop_steps:
    total_training_steps = stop_steps
  else:
    total_training_steps = steps_per_epoch * epochs

  if stop_steps and stop_steps > steps_per_epoch * epochs:
    raise ValueError('`stop_steps` should not be greater than '
                     '`num_train_steps_per_epoch` * `num_epochs`.')

  # To reduce unnecessary send/receive input pipeline operations, we place
  # input pipeline ops in the worker task.
  train_iterator = _get_input_iterator(train_input_fn, strategy)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
    # model and optimizer should be created in scope.
    model, sub_model, sub_pretrain_model = model_fn()
    if not hasattr(model, 'optimizer'):
      raise ValueError('User should set optimizer attribute to model '
                       'inside `model_fn`.')
    if sub_model_export_name and sub_model is None:
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    optimizer = model.optimizer

    train_loss_metric = tf.keras.metrics.Mean(
        'training_loss', dtype=tf.float32)
    if eval_input_fn:
      eval_metric_num = tf.keras.metrics.Sum('masked_lm_num', dtype=tf.float32)
      eval_metric_denom = tf.keras.metrics.Sum(
          'masked_lm_denom', dtype=tf.float32)

    # If evaluation is required, make a copy of metric as it will be used by
    # both train and evaluation.
    train_metrics = [
        tf.keras.metrics.Mean('masked_lm_accuracy', dtype=tf.float32)
    ]

    # Create summary writers
    summary_dir = os.path.join(model_dir, 'summaries')
    if enable_checkpoint_and_summary:
      eval_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'eval'))
    else:
      eval_summary_writer = tf.summary.create_noop_writer()
    if steps_per_loop >= _MIN_SUMMARY_STEPS and enable_checkpoint_and_summary:
      # Only writes summary when the stats are collected sufficiently over
      # enough steps.
      train_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'train'))
    else:
      train_summary_writer = tf.summary.create_noop_writer()

    # Collects training variables.
    training_vars = model.trainable_variables

    @tf.function(experimental_compile=True)
    def _compiled_local_step(inputs, labels, training_vars, accum_vars):
      """Replicated training step."""
      with tf.GradientTape() as tape:
        model_outputs, metric_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
      if isinstance(optimizer,
                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        with tape:
          scaled_loss = optimizer.get_scaled_loss(loss)
        scaled_grads = tape.gradient(scaled_loss, training_vars)
        grads = optimizer.get_unscaled_gradients(scaled_grads)
      else:
        grads = tape.gradient(loss, training_vars)
      (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

      if accum_vars is None:
        return grads, loss, model_outputs, metric_outputs
      else:
        new_accum_vars = []
        for i, grad in enumerate(grads):
          new_accum_vars.append(
              accum_vars[i] +
              tf.math.scalar_mul(1.0 / num_accumulation_steps, grad))
        return new_accum_vars, loss, model_outputs, metric_outputs

    def get_input_slice(input_dict, idx):
      split_input = {}
      for key in input_dict:
        split_input[key] = input_dict[key][idx]
      return split_input

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, labels = inputs
      if explicit_allreduce:
        # TODO(b/155523821): Fix OOM issue so we use experimental_compile with
        # multi-worker mirrored strategy.
        with tf.GradientTape() as tape:
          model_outputs, metric_outputs = model(inputs, training=True)
          loss = loss_fn(labels, model_outputs)

        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     training_vars,
                                                     pre_allreduce_callbacks,
                                                     post_allreduce_callbacks,
                                                     allreduce_bytes_per_pack)
      else:
        if num_accumulation_steps > 1:
          accum_vars = [
              tf.zeros_like(tvar, dtype=tf.float32) for tvar in training_vars
          ]
          for key in inputs:
            inputs[key] = tf.split(inputs[key], num_accumulation_steps)

          split_labels = tf.split(labels, num_accumulation_steps)
          for local_step in range(num_accumulation_steps):
            accum_vars, loss, model_outputs, metric_outputs = _compiled_local_step(
                get_input_slice(inputs, local_step), split_labels[local_step],
                training_vars, accum_vars)

          optimizer.apply_gradients(zip(accum_vars, training_vars))
        else:
          grads, loss, model_outputs, metric_outputs = _compiled_local_step(
              inputs, labels, training_vars, None)
          optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(loss)
      for metric in train_metrics:
        metric.update_state(metric_outputs['masked_lm_accuracy'])

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: a tf.int32 integer tensor specifying the number of steps to run
          inside the host training loop.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError('steps should be a Tensor. A Python object may cause '
                         'retracing.')

      for _ in tf.range(steps):
        strategy.run(_replicated_step, args=(next(iterator),))

    def train_single_step(iterator):
      """Performs a distributed training step.

      Args:
        iterator: the distributed iterator of training datasets.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      strategy.run(_replicated_step, args=(next(iterator),))

    def test_step(iterator):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""

        inputs, labels = inputs
        model_outputs, metric_outputs = model(inputs, training=False)
        eval_metric_num.update_state(metric_outputs['masked_lm_num'])
        eval_metric_denom.update_state(metric_outputs['masked_lm_denom'])
      strategy.run(_test_step_fn, args=(next(iterator),))

    if not run_eagerly:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    def _run_evaluation(current_training_step, test_iterator):
      """Runs validation steps and aggregate metrics."""
      mlperf_epoch_num = int(current_training_step / steps_between_eval)
      mlp_log.mlperf_print(
          'eval_start', None, metadata={'epoch_num': mlperf_epoch_num})
      for _ in range(eval_steps):
        test_step(test_iterator)
      mlp_log.mlperf_print(
          'eval_stop', None, metadata={'epoch_num': mlperf_epoch_num})

      with eval_summary_writer.as_default():
        masked_lm_accuracy = (
            _float_metric_value(eval_metric_num) /
            _float_metric_value(eval_metric_denom))
        logging.info('Step: [%d] Validation %s = %f', current_training_step,
                     'masked_lm_accuracy', masked_lm_accuracy)
        tf.summary.scalar(
            'masked_lm_accuracy',
            masked_lm_accuracy,
            step=current_training_step)
        mlp_log.mlperf_print(
            'eval_accuracy',
            masked_lm_accuracy,
            metadata={'epoch_num': mlperf_epoch_num})
        eval_summary_writer.flush()
      return masked_lm_accuracy

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      # While BERT pretraining does not have epochs,
      # to make the logging consistent with other MLPerf models,
      # epochs are reported as steps in all mlp_log entries.
      mlp_log.mlperf_print(
          'block_start',
          None,
          metadata={
              'first_epoch_num': int(batch),
              'epoch_count': int(steps_per_loop),
          })
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_begin(batch)

    def _run_callbacks_on_batch_end(batch, logs):
      """Runs custom callbacks at the end of every step."""
      mlp_log.mlperf_print(
          'block_stop', None, metadata={
              'first_epoch_num': int(batch),
          })
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_end(batch, logs)

    # Training loop starts here.
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    sub_model_checkpoint = tf.train.Checkpoint(
        model=sub_model) if sub_model_export_name else None

    # TODO: commenting this out, as we always load from an initial checkpoint
    # latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    # if latest_checkpoint_file:
    #   logging.info(
    #       'Checkpoint file %s found and restoring from '
    #       'checkpoint', latest_checkpoint_file)
    #   checkpoint.restore(latest_checkpoint_file)
    #   logging.info('Loading from checkpoint file completed')

    current_step = optimizer.iterations.numpy()
    checkpoint_name = 'ctl_step_{step}.ckpt'
    checkpoint_save_dir = model_dir if enable_checkpoint_and_summary else None

    if init_checkpoint:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'initial checkpoint for core model.', init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=sub_pretrain_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
      logging.info('Loading from checkpoint file completed')

    if device_warmup:
      synthetic_train_iterator = _get_input_iterator(synthetic_train_input_fn,
                                                     strategy)
      logging.info('Running device warmup for 1 step.')
      train_steps(synthetic_train_iterator, tf.constant(1, dtype=tf.int32))
      # Reset the global step.
      tf.keras.backend.set_value(optimizer.iterations, 0)
      current_step = optimizer.iterations.numpy()

    masked_lm_accuracy = 0
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    while current_step < total_training_steps:
      # Training loss/metrics are averaged over the steps inside the micro
      # training loop, so we reset their values before each round.
      train_loss_metric.reset_states()
      for metric in train_metrics + model.metrics:
        metric.reset_states()

      _run_callbacks_on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

      train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
      current_step += steps

      # Updates training logging.
      training_status = 'Train Step: %d/%d  / loss = %s' % (
          current_step, total_training_steps, train_loss)

      with train_summary_writer.as_default():
        tf.summary.scalar(
            train_loss_metric.name, train_loss, step=current_step)
        for metric in train_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          training_status += '  %s = %f' % (metric.name, metric_value)
          tf.summary.scalar(metric.name, metric_value, step=current_step)
        train_summary_writer.flush()
      logging.info(training_status)

      # Save model checkpoints and run validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
        # step of training.
        if current_step < total_training_steps:
          _save_checkpoint(checkpoint, checkpoint_save_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
                sub_model_checkpoint, checkpoint_save_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
      if eval_input_fn and (current_step % (steps_between_eval) == 0) and (
          current_step >= steps_before_eval_start):
        logging.info('Running evaluation after step: %s.', current_step)
        masked_lm_accuracy = _run_evaluation(
            current_step, _get_input_iterator(eval_input_fn, strategy))
        if masked_lm_accuracy >= stop_threshold:
          mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
          break

        # Re-initialize evaluation metric.
        eval_metric_num.reset_states()
        eval_metric_denom.reset_states()

    if masked_lm_accuracy < stop_threshold:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'aborted'})

    _save_checkpoint(checkpoint, checkpoint_save_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
      _save_checkpoint(sub_model_checkpoint, checkpoint_save_dir,
                       '%s.ckpt' % sub_model_export_name)

    if enable_checkpoint_and_summary:
      training_summary = {
          'total_training_steps': total_training_steps,
          'train_loss': _float_metric_value(train_loss_metric),
      }
      if train_metrics:
        # TODO(hongkuny): Cleans up summary reporting in text.
        training_summary['last_train_metrics'] = _float_metric_value(
            train_metrics[0])
        #training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])

      write_txt_summary(training_summary, summary_dir)

    return model, masked_lm_accuracy, current_step
# Example 13
 def _event_shape_tensor(self):
   return tf.constant([], dtype=tf.int32)
 def test_shift_right_by_one(self):
     x = tf.constant([3, 8, 1, 0, 0])
     shifted = feature_converters._shift_right_by_one(x)
     expected = [0, 3, 8, 1, 0]
     actual = self.evaluate(shifted)
     self.assertAllEqual(actual, expected)
 def test_shift_right_by_one_nonzero_last_position(self):
     x = tf.constant([3, 8, 8, 9, 4])
     shifted = feature_converters._shift_right_by_one(x)
     expected = [0, 3, 8, 8, 9]
     actual = self.evaluate(shifted)
     self.assertAllEqual(actual, expected)
 def testParamShapes(self):
     desired_shape = [10, 3, 4]
     self._testParamShapes(desired_shape)
     self._testParamShapes(tf.constant(desired_shape))
 def test_non_padding_position(self):
     x = tf.constant([3, 8, 5, 0, 0, 2, 0])
     non_padding_position = feature_converters.non_padding_position(x)
     expected = [1, 1, 1, 0, 0, 1, 0]
     actual = self.evaluate(non_padding_position)
     self.assertAllEqual(actual, expected)
# Example 18
 def _event_shape_tensor(self):
     return tf.constant([self.dimension, self.dimension], dtype=tf.int32)
def interpolate(x: types.RealTensor,
                x_data: types.RealTensor,
                y_data: types.RealTensor,
                left_slope: types.RealTensor = None,
                right_slope: types.RealTensor = None,
                validate_args: bool = False,
                optimize_for_tpu: bool = False,
                dtype: tf.DType = None,
                name: str = None):
  """Performs linear interpolation for supplied points.

  Given a set of knots whose x- and y- coordinates are in `x_data` and `y_data`,
  this function returns y-values for x-coordinates in `x` via piecewise
  linear interpolation.

  `x_data` must be non-decreasing, but `y_data` does not need to be, because we
  do not require the function approximated by these knots to be monotonic.

  #### Examples

  ```python
  import tf_quant_finance as tff
  x = [-10, -1, 1, 3, 6, 7, 8, 15, 18, 25, 30, 35]
  x_data = [-1, 2, 6, 8, 18, 30.0]
  y_data = [10, -1, -5, 7, 9, 20]

  tff.math.interpolation.linear.interpolate(x, x_data, y_data,
                                            dtype=tf.float64)
  # Expected: [ 10, 10, 2.66666667, -2, -5, 1, 7, 8.4, 9, 15.41666667, 20, 20]
  ```

  Args:
    x: x-coordinates for which we need to get interpolation. An N-D
      `Tensor` of real dtype. The first N-1 dimensions represent batching
      dimensions.
    x_data: x coordinates. An N-D `Tensor` of real dtype. Should be sorted
      in non-decreasing order. The first N-1 dimensions represent batching
      dimensions.
    y_data: y coordinates. An N-D `Tensor` of real dtype. Should have a shape
      compatible with `x_data`. The first N-1 dimensions represent batching
      dimensions.
    left_slope: The slope to use for extrapolation with x-coordinate smaller
      than the min `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None`, which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the leftmost `y_data`.
    right_slope: The slope to use for extrapolation with x-coordinate greater
      than the max `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None` which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the rightmost `y_data`.
    validate_args: Python `bool` that indicates whether the function performs
      the check if the shapes of `x_data` and `y_data` are equal and that the
      elements in `x_data` are non decreasing. If this value is set to `False`
      and the elements in `x_data` are not increasing, the result of linear
      interpolation may be wrong.
      Default value: `False`.
    optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot
      encoding to lookup indices of `x` in `x_data`. This significantly
      improves performance of the algorithm on a TPU device but may slow down
      performance on the CPU.
      Default value: `False`.
    dtype: Optional tf.dtype for `x`, `x_data`, `y_data`, `left_slope` and
      `right_slope`.
      Default value: `None` which means that the `dtype` is inferred from
        `x`.
    name: Python str. The name prefixed to the ops created by this function.
      Default value: `None` which maps to 'linear_interpolation'.

  Returns:
    An N-D `Tensor` of real dtype holding the interpolated y-values for the
    x-values in `x`.
  """
  name = name or 'linear_interpolation'
  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, dtype=dtype, name='x')
    dtype = dtype or x.dtype
    x_data = tf.convert_to_tensor(x_data, dtype=dtype, name='x_data')
    y_data = tf.convert_to_tensor(y_data, dtype=dtype, name='y_data')
    # Try broadcast batch_shapes
    x, x_data, y_data = tff_utils.broadcast_common_batch_shape(
        x, x_data, y_data)

    # Rank of the inputs is known
    batch_rank = x.shape.rank - 1
    if batch_rank == 0:
      x = tf.expand_dims(x, 0)
      x_data = tf.expand_dims(x_data, 0)
      y_data = tf.expand_dims(y_data, 0)

    if left_slope is None:
      left_slope = tf.constant(0.0, dtype=x.dtype, name='left_slope')
    else:
      left_slope = tf.convert_to_tensor(left_slope, dtype=dtype,
                                        name='left_slope')
    if right_slope is None:
      right_slope = tf.constant(0.0, dtype=x.dtype, name='right_slope')
    else:
      right_slope = tf.convert_to_tensor(right_slope, dtype=dtype,
                                         name='right_slope')
    control_deps = []
    if validate_args:
      # Check that `x_data` elements are non-decreasing
      diffs = x_data[..., 1:] - x_data[..., :-1]
      assertion = tf.debugging.assert_greater_equal(
          diffs,
          tf.zeros_like(diffs),
          message='x_data is not sorted in non-decreasing order.')
      control_deps.append(assertion)
      # Check that the shapes of `x_data` and `y_data` are equal
      control_deps.append(
          tf.compat.v1.assert_equal(tff_utils.get_shape(x_data),
                                    tff_utils.get_shape(y_data)))

    with tf.control_dependencies(control_deps):
      # Get upper bound indices for `x`.
      upper_indices = tf.searchsorted(x_data, x, side='left', out_type=tf.int32)
      x_data_size = tff_utils.get_shape(x_data)[-1]
      at_min = tf.equal(upper_indices, 0)
      at_max = tf.equal(upper_indices, x_data_size)
      # Create tensors in order to be used by `tf.where`.
      # `values_min` are extrapolated values for x-coordinates less than or
      # equal to `x_data[..., 0]`.
      # `values_max` are extrapolated values for x-coordinates greater than
      # `x_data[..., -1]`.

      values_min = tf.expand_dims(y_data[..., 0], -1) + left_slope * (
          x - tf.broadcast_to(
              tf.expand_dims(x_data[..., 0], -1),
              shape=tff_utils.get_shape(x)))
      values_max = tf.expand_dims(y_data[..., -1], -1) + right_slope * (
          x - tf.broadcast_to(
              tf.expand_dims(x_data[..., -1], -1),
              shape=tff_utils.get_shape(x)))

      # `tf.where` evaluates all branches, need to cap indices to ensure it
      # won't go out of bounds.
      lower_encoding = tf.math.maximum(upper_indices - 1, 0)
      upper_encoding = tf.math.minimum(upper_indices, x_data_size - 1)
      # Prepare indices for `tf.gather` or `tf.one_hot`
      # TODO(b/156720909): Extract get_slice logic into a common utilities
      # module for cubic and linear interpolation
      if optimize_for_tpu:
        lower_encoding = tf.one_hot(lower_encoding, x_data_size,
                                    dtype=dtype)
        upper_encoding = tf.one_hot(upper_encoding, x_data_size,
                                    dtype=dtype)
      def get_slice(x, encoding):
        if optimize_for_tpu:
          return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) * encoding,
                                    axis=-1)
        else:
          return tf.gather(x, encoding, axis=-1, batch_dims=x.shape.rank - 1)
      x_data_lower = get_slice(x_data, lower_encoding)
      x_data_upper = get_slice(x_data, upper_encoding)
      y_data_lower = get_slice(y_data, lower_encoding)
      y_data_upper = get_slice(y_data, upper_encoding)

      # NaNs in unselected branches could propagate through the gradient
      # calculation, hence we need to clip the values to ensure no NaN occurs.
      # In this case we need to ensure there is no division by zero.
      x_data_diff = x_data_upper - x_data_lower
      floor_x_diff = tf.where(at_min | at_max, x_data_diff + 1, x_data_diff)
      interpolated = y_data_lower + (x - x_data_lower) * (
          tf.math.divide_no_nan(y_data_upper - y_data_lower, floor_x_diff))

      interpolated = tf.where(at_min, values_min, interpolated)
      interpolated = tf.where(at_max, values_max, interpolated)
      if batch_rank > 0:
        return interpolated
      else:
        return tf.squeeze(interpolated, 0)
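
# --- Standalone sketch (added; not part of the library) ---
# The `optimize_for_tpu` lookup used above, on toy data: summing over a
# one-hot encoding selects the same elements as `tf.gather` with `batch_dims`,
# which is the TPU-friendly trade-off the flag controls.
import tensorflow as tf

x_data = tf.constant([[1., 2., 3., 4.],
                      [5., 6., 7., 8.]])
idx = tf.constant([[0, 2],
                   [1, 3]])
gathered = tf.gather(x_data, idx, axis=-1, batch_dims=1)
one_hot = tf.one_hot(idx, depth=4, dtype=x_data.dtype)
via_one_hot = tf.math.reduce_sum(
    tf.expand_dims(x_data, axis=-2) * one_hot, axis=-1)
print(gathered.numpy(), via_one_hot.numpy())  # Identical selections.
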
# Example 20
def make_bernoulli(batch_shape, dtype=tf.int32):
    p = np.random.uniform(size=list(batch_shape))
    p = tf.constant(p, dtype=tf.float32)
    return tfd.Bernoulli(probs=p, dtype=dtype, validate_args=True)
# Example 21
 def test_false_for_base_case(self):
     self.assertFalse(util.is_namedtuple_like(tuple([1, 2])))
     self.assertFalse(util.is_namedtuple_like(list([3., 4.])))
     self.assertFalse(util.is_namedtuple_like(dict(a=5, b=6)))
     self.assertFalse(util.is_namedtuple_like(tf.constant(1.)))
     self.assertFalse(util.is_namedtuple_like(np.int32()))
 def test_autoregressive_inputs_unpacked(self):
     x = tf.constant([3, 8, 9, 5, 1, 0, 0])
     autoreg_inputs = feature_converters.autoregressive_inputs(x)
     actual = self.evaluate(autoreg_inputs)
     expected = [0, 3, 8, 9, 5, 1, 0]
     self.assertAllEqual(actual, expected)
# Example 23
 def testGradientWorksDespiteBijectorCaching(self):
     x = tf.constant(2.)
     fn_result, grads = util.maybe_call_fn_and_grads(
         lambda x_: tfd.LogNormal(loc=0., scale=1.).log_prob(x_), x)
     self.assertAllEqual(False, fn_result is None)
     self.assertAllEqual([False], [g is None for g in grads])
 def test_autoregressive_inputs_different_dtypes(self):
     x = tf.constant([3, 8, 1, 9, 1, 5, 4, 1, 0, 0])
     sequence_id = tf.constant([1, 1, 1, 2, 2, 3, 3, 3, 0, 0], tf.int32)
     autoreg_inputs = feature_converters.autoregressive_inputs(
         x, sequence_id=sequence_id, output_dtype=tf.int64)
     self.assertEqual(autoreg_inputs.dtype, tf.int64)
    def test_TimeDistributed_with_mimo(self):
        dense_1 = keras.layers.Dense(8)
        dense_2 = keras.layers.Dense(16)

        class TestLayer(keras.layers.Layer):
            def __init__(self):
                super().__init__()
                self.dense_1 = dense_1
                self.dense_2 = dense_2

            def call(self, inputs):
                return self.dense_1(inputs[0]), self.dense_2(inputs[1])

            def compute_output_shape(self, input_shape):
                output_shape_1 = self.dense_1.compute_output_shape(
                    input_shape[0]
                )
                output_shape_2 = self.dense_2.compute_output_shape(
                    input_shape[1]
                )
                return output_shape_1, output_shape_2

        np.random.seed(100)
        layer = TestLayer()

        data_1 = tf.constant(
            [
                [[[1.0], [1.0]], [[2.0], [2.0]]],
                [[[4.0], [4.0]], [[5.0], [5.0]]],
                [[[7.0], [7.0]], [[8.0], [8.0]]],
            ]
        )

        data_2 = tf.constant(
            [
                [[[1.0], [1.0]], [[2.0], [2.0]]],
                [[[4.0], [4.0]], [[5.0], [5.0]]],
                [[[7.0], [7.0]], [[8.0], [8.0]]],
            ]
        )

        x1 = keras.Input(shape=(None, 2, 1), dtype="float32")
        x2 = keras.Input(shape=(None, 2, 1), dtype="float32")
        y1, y2 = keras.layers.TimeDistributed(layer)([x1, x2])
        model_1 = keras.models.Model([x1, x2], [y1, y2])
        model_1.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        output_1 = model_1.predict((data_1, data_2), steps=1)

        y1 = dense_1(x1)
        y2 = dense_2(x2)
        model_2 = keras.models.Model([x1, x2], [y1, y2])
        output_2 = model_2.predict((data_1, data_2), steps=1)

        self.assertAllClose(output_1, output_2)

        model_1.fit(
            x=[
                np.random.random((10, 2, 2, 1)),
                np.random.random((10, 2, 2, 1)),
            ],
            y=[
                np.random.random((10, 2, 2, 8)),
                np.random.random((10, 2, 2, 16)),
            ],
            epochs=1,
            batch_size=3,
        )
 def my_fn(ex):
     inputs = ex["inputs"]
     res = ex.copy()
     res["inputs"] = tf.where(tf.greater(inputs, 15),
                              tf.constant(50, inputs.dtype), inputs)
     return res
# Example 27
 def get_results():
     start = tf.constant(start_position)
     return self.evaluate(
         tfp.optimizer.lbfgs_minimize(rastrigin,
                                      initial_position=start,
                                      tolerance=1e-5))
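
# --- Self-contained sketch (added; the original test defines its own
# `rastrigin` and `start_position`, which are assumed here for illustration) ---
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

def rastrigin(x):
  # Classic Rastrigin objective; its global minimum value is 0 at x = 0.
  return tf.reduce_sum(x ** 2 - 10. * tf.cos(2. * np.pi * x) + 10., axis=-1)

start_position = np.array([0.1, -0.1], dtype=np.float64)
results = tfp.optimizer.lbfgs_minimize(
    # L-BFGS expects a function returning both the value and its gradient.
    lambda x: tfp.math.value_and_gradient(rastrigin, x),
    initial_position=tf.constant(start_position),
    tolerance=1e-5)
print(results.converged.numpy(), results.position.numpy())
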
# Example 28
 def testLambertWGradient(self, value, expected):
     """Tests the gradient of the LambertW function on some known identities."""
     x = tf.constant(value, dtype=tf.float64)
     _, dy_dx = tfp.math.value_and_gradient(tfp.math.lambertw, x)
     self.assertAllClose(dy_dx, expected)
# Example 29
def assign_moving_mean_variance(value,
                                moving_mean,
                                moving_variance=None,
                                zero_debias_count=None,
                                decay=0.99,
                                axis=(),
                                name=None):
    """Compute one update to the exponentially weighted moving mean and variance.

  The `value` updated exponentially weighted moving `moving_mean` and
  `moving_variance` are conceptually given by the following recurrence
  relations ([Welford (1962)][1]):

  ```python
  new_mean = old_mean + (1 - decay) * (value - old_mean)
  new_var  = old_var  + (1 - decay) * (value - old_mean) * (value - new_mean)
  ```

  This function implements the above recurrences in a numerically stable manner
  and also uses the `assign_add` op to allow concurrent lockless updates to the
  supplied variables.

  For additional references see [this John D. Cook blog post](
  https://www.johndcook.com/blog/standard_deviation/)
  (whereas we use `1 - decay = 1 / k`) and
  [Finch (2009; Eq. 143)][2] (whereas we use `1 - decay = alpha`).

  Since variables that are initialized to a `0` value will be `0` biased,
  providing `zero_debias_count` triggers scaling the `moving_mean` and
  `moving_variance` by the factor of `1 - decay ** (zero_debias_count + 1)`.
  For more details, see `tfp.stats.moving_mean_variance_zero_debiased`.

  Args:
    value: `float`-like `Tensor` representing one or more streaming
      observations. When `axis` is non-empty, `value` is reduced (by mean) for
      the updated `moving_mean` and `moving_variance`. Presumed to have the same shape
      as `moving_mean` and `moving_variance`.
    moving_mean: `float`-like `tf.Variable` representing the exponentially
      weighted moving mean. Same shape as `moving_variance` and `value`. This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
    moving_variance: `float`-like `tf.Variable` representing the exponentially
      weighted moving variance. Same shape as `moving_mean` and `value`.  This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
      Default value: `None` (i.e., no moving variance is computed).
    zero_debias_count: `int`-like `tf.Variable` representing the number of times
      this function has been called on streaming input (*not* the number of
      reduced values used in this function's computation). When not `None`, the
      returned values for `moving_mean` and `moving_variance` are
      "zero debiased", i.e., corrected for their presumed all-zeros
      initialization. Note: the `tf.Variable`s `moving_mean` and
      `moving_variance` *always* store the non-debiased calculation, regardless
      of this argument. To obtain zero-debiased calculations from these
      `tf.Variable`s, see `tfp.stats.moving_mean_variance_zero_debiased`.
      Default value: `None` (i.e., no zero debiasing calculation is made).
    decay: A `float`-like `Tensor` representing the moving mean decay. Typically
      close to `1.`, e.g., `0.99`.
      Default value: `0.99`.
    axis: The dimensions to reduce. If `()` (the default) no dimensions are
      reduced. If `None` all dimensions are reduced. Must be in the range
      `[-rank(value), rank(value))`.
      Default value: `()` (i.e., no reduction is made).
    name: Python `str` prepended to op names created by this function.
      Default value: `None` (i.e., 'assign_moving_mean_variance').

  Returns:
    moving_mean: The `value`-updated exponentially weighted moving mean.
      Debiased if `zero_debias_count is not None`.
    moving_variance: The `value`-updated exponentially weighted moving variance.
      Debiased if `zero_debias_count is not None`.

  Raises:
    TypeError: if `moving_mean` does not have float type `dtype`.
    TypeError: if `moving_mean`, `moving_variance`, `value`, `decay` have
      different `base_dtype`.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  d = tfd.MultivariateNormalTriL(
      loc=[-1., 1.],
      scale_tril=tf.linalg.cholesky([[0.75, 0.05],
                                     [0.05, 0.5]]))
  d.mean()
  # ==> [-1.,  1.]
  d.variance()
  # ==> [0.75, 0.5]
  moving_mean = tf.Variable(tf.zeros(2))
  moving_variance = tf.Variable(tf.zeros(2))
  zero_debias_count = tf.Variable(0)
  for _ in range(100):
    m, v = tfp.stats.assign_moving_mean_variance(
      value=d.sample(3),
      moving_mean=moving_mean,
      moving_variance=moving_variance,
      zero_debias_count=zero_debias_count,
      decay=0.99,
      axis=-2)
    print(m.numpy(), v.numpy())
  # ==> [-1.0334632  0.9545268] [0.8126194 0.5118788]
  # ==> [-1.0293456   0.96070296] [0.8115873  0.50947404]
  # ...
  # ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]

  m1, v1 = tfp.stats.moving_mean_variance_zero_debiased(
    moving_mean,
    moving_variance,
    zero_debias_count,
    decay=0.99)
  print(m1.numpy(), v1.numpy())
  # ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]
  assert(all(m == m1))
  assert(all(v == v1))
  ```

  #### References

  [1]: B. P. Welford. Note on a Method for Calculating Corrected Sums of
       Squares and Products. _Technometrics_, Vol. 4, No. 3 (Aug., 1962),
       pp. 419-420.
       http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.302.7503&rep=rep1&type=pdf
       http://www.jstor.org/stable/1266577

  [2]: Tony Finch. Incremental calculation of weighted mean and variance.
       _Technical Report_, 2009.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  """
    with tf.name_scope(name or 'assign_moving_mean_variance'):
        base_dtype = dtype_util.base_dtype(moving_mean.dtype)
        if not dtype_util.is_floating(base_dtype):
            raise TypeError(
                'Argument `moving_mean` is not float type (saw {}).'.format(
                    dtype_util.name(moving_mean.dtype)))

        value = tf.convert_to_tensor(value, dtype=base_dtype, name='value')
        decay = tf.convert_to_tensor(decay, dtype=base_dtype, name='decay')
        # Force a read of `moving_mean` as we'll need it twice.
        old_mean = tf.convert_to_tensor(moving_mean,
                                        dtype=base_dtype,
                                        name='old_mean')

        updated_mean = moving_mean.assign_add(
            (1. - decay) * (tf.reduce_mean(value, axis=axis) - old_mean))

        if zero_debias_count is not None:
            t = tf.cast(zero_debias_count.assign_add(1), base_dtype)
            # Could have used:
            #   bias_correction = -tf.math.expm1(t * tf.math.log(decay))
            # however, since we expect `decay` to be nearly 1, this would not
            # yield a significant accuracy improvement, yet would incur a
            # higher computational cost.
            bias_correction = 1. - decay**t
            with tf.control_dependencies([updated_mean]):
                updated_mean = updated_mean / bias_correction

        if moving_variance is None:
            return updated_mean

        if base_dtype != dtype_util.base_dtype(moving_variance.dtype):
            raise TypeError(
                'Arguments `moving_mean` and `moving_variance` do not '
                'have same base `dtype` (saw {}, {}).'.format(
                    dtype_util.name(moving_mean.dtype),
                    dtype_util.name(moving_variance.dtype)))

        if zero_debias_count is not None:
            old_t = tf.where(t > 1., t - 1., tf.constant(np.inf, base_dtype))
            old_bias_correction = 1. - decay**old_t
            old_mean = old_mean / old_bias_correction

        mean_sq_diff = tf.reduce_mean(tf.math.squared_difference(
            value, old_mean),
                                      axis=axis)
        updated_variance = moving_variance.assign_add(
            (1. - decay) * (decay * mean_sq_diff - moving_variance))

        if zero_debias_count is not None:
            with tf.control_dependencies([updated_variance]):
                updated_variance = updated_variance / bias_correction

        return updated_mean, updated_variance
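To make the recurrences and the zero-debias factor from the docstring concrete, here is a NumPy-only sketch of the scalar, no-reduction case (illustrative, not part of the library):

import numpy as np

decay = 0.99
values = np.random.default_rng(0).normal(loc=2.0, scale=3.0, size=2000)

mean = var = 0.0  # zero-initialized, hence biased toward zero early on
for t, value in enumerate(values, start=1):
  old_mean = mean
  mean = old_mean + (1 - decay) * (value - old_mean)
  var = var + (1 - decay) * ((value - old_mean) * (value - mean) - var)

bias_correction = 1 - decay**t  # `t` plays the role of `zero_debias_count + 1`
print(mean / bias_correction, var / bias_correction)  # close to 2.0 and 9.0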
Example #30
    def test_single_sharded_shared_embedding_softmax_layer(
            self, soft_cap_logits, lookup_style, scale_sqrt_depth):
        class_ids = np.random.randint(1, 50, [8, 10, 1])
        p = embedding_softmax.SingleShardSharedEmbeddingSoftmax.Params().Set(
            name='jax_softmax',
            num_classes=50,
            input_dims=40,
            soft_cap_logits=soft_cap_logits,
            lookup_style=lookup_style,
            scale_sqrt_depth=scale_sqrt_depth)
        softmax_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = softmax_layer.instantiate_variables(prng_key)
        npy_input = np.random.normal(1.5, 2.0, [8, 10, p.input_dims])
        inputs = jnp.asarray(npy_input)
        class_weights = np.random.normal(1.5, 2.0, [8, 10, 1])
        outputs = test_utils.apply(softmax_layer,
                                   initial_vars,
                                   softmax_layer.fprop,
                                   inputs,
                                   class_weights,
                                   class_ids=class_ids)
        ids = np.squeeze(class_ids, axis=-1)
        emb_lookup_outputs = test_utils.apply(softmax_layer,
                                              initial_vars,
                                              softmax_layer.emb_lookup,
                                              ids=jnp.asarray(ids))
        # Test whether tf Softmax layer returns same output
        # Modify initial_vars to use TF compatible params
        tf_initial_vars = initial_vars
        tf_initial_vars.linear = py_utils.NestedMap()
        tf_initial_vars.linear.w = initial_vars.logits_ffn.linear.w
        tf_initial_vars.bias = py_utils.NestedMap()
        tf_initial_vars.bias.b = initial_vars.logits_ffn.bias.b
        tf_p = lingvo_layers.SingleShardSharedEmbeddingSoftmax.Params().Set(
            name='tf_softmax',
            num_classes=p.num_classes,
            input_dim=p.input_dims,
            vocab_size=p.num_classes,
            embedding_dim=p.input_dims,
            logits_soft_max=soft_cap_logits,
            scale_sqrt_depth=scale_sqrt_depth)
        tf_softmax_layer = tf_p.Instantiate()
        tf_output = tf_softmax_layer.FProp(tf_initial_vars,
                                           tf.constant(inputs,
                                                       dtype=tf.float32),
                                           class_weights,
                                           class_ids=class_ids)
        tf_emb_lookup_output = tf_softmax_layer.EmbLookup(tf_initial_vars,
                                                          ids=tf.constant(ids))

        # Check all entries in the NestedMap and ensure it matches TF
        np_logits = to_np(outputs.logits)
        tf_np_logits = to_np(tf_output.logits)
        self.assertAllClose(np_logits, tf_np_logits, atol=1e-6)
        for k in outputs.keys():
            self.assertAllClose(to_np(outputs[k]),
                                to_np(tf_output[k]),
                                atol=1e-6)
        np_emb_lookup_output = to_np(emb_lookup_outputs)
        tf_np_emb_lookup_output = to_np(tf_emb_lookup_output)
        self.assertAllClose(tf_np_emb_lookup_output,
                            np_emb_lookup_output,
                            atol=1e-6)
Example #31
    def testEventShapes(self):
        shape_static = [5, 4, 3, 2]
        shape_dynamic = tf1.placeholder_with_default(tf.constant(shape_static),
                                                     shape=None)

        def make_bijector(perm=None, rightmost_transposed_ndims=None):
            if perm is not None:
                perm = tf.convert_to_tensor(value=perm)
                if not self.is_static:
                    perm = tf1.placeholder_with_default(perm, shape=perm.shape)
            return tfb.Transpose(
                perm, rightmost_transposed_ndims=rightmost_transposed_ndims)

        for is_shape_static, shape, shape_t in [
            (True, tf.zeros(shape_static).shape, tf.constant(shape_static)),
            (False, tf.zeros(shape_dynamic).shape, shape_dynamic)
        ]:

            # pylint: disable=cell-var-from-loop
            def event_shape(b, direction):
                shape_fn = getattr(b, '{}_event_shape'.format(direction))
                if (is_shape_static
                        and self.is_static) or tf.executing_eagerly():
                    result = shape_fn(shape)
                    self.assertTrue(tensorshape_util.is_fully_defined(result))
                    return result
                if is_shape_static:
                    self.assertEqual(len(shape), shape_fn(shape).ndims)
                else:
                    self.assertIsNone(shape_fn(shape).ndims)
                shape_tensor_fn = getattr(
                    b, '{}_event_shape_tensor'.format(direction))
                return self.evaluate(shape_tensor_fn(shape_t))

            # pylint: enable=cell-var-from-loop

            self.assertAllEqual((5, 3, 4, 2),
                                event_shape(make_bijector([1, 0, 2]),
                                            'forward'))
            self.assertAllEqual((5, 2, 4, 3),
                                event_shape(make_bijector([2, 0, 1]),
                                            'forward'))
            self.assertAllEqual(
                (5, 4, 2, 3),
                event_shape(make_bijector(rightmost_transposed_ndims=2),
                            'forward'))
            self.assertAllEqual(
                (5, 2, 3, 4),
                event_shape(make_bijector(rightmost_transposed_ndims=3),
                            'forward'))
            self.assertAllEqual((5, 3, 4, 2),
                                event_shape(make_bijector([1, 0, 2]),
                                            'inverse'))
            self.assertAllEqual((5, 3, 2, 4),
                                event_shape(make_bijector([2, 0, 1]),
                                            'inverse'))
            self.assertAllEqual(
                (5, 4, 2, 3),
                event_shape(make_bijector(rightmost_transposed_ndims=2),
                            'inverse'))
            self.assertAllEqual(
                (5, 2, 3, 4),
                event_shape(make_bijector(rightmost_transposed_ndims=3),
                            'inverse'))
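For reference, a standalone sketch of the first assertion above (not part of the test class): `perm=[1, 0, 2]` permutes the rightmost three event dimensions of a `[5, 4, 3, 2]` input.

b = tfb.Transpose(perm=[1, 0, 2])
x = tf.zeros([5, 4, 3, 2])
print(b.forward(x).shape)  # ==> (5, 3, 4, 2)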
Example #32
 def test_single_sharded_softmax_layer(self, soft_cap_logits, use_class_ids,
                                       use_class_probabilities,
                                       label_smoothing_prob):
     if use_class_ids:
         class_ids = np.random.randint(0, 50, [8, 10, 1])
     else:
         class_ids = None
     if use_class_probabilities:
         class_probabilities = np.random.normal(1.5, 2.0, [8, 10, 50])
     else:
         class_probabilities = None
     p = embedding_softmax.SingleShardFullSoftmax.Params().Set(
         name='jax_softmax',
         num_classes=50,
         input_dims=40,
         soft_cap_logits=soft_cap_logits,
         label_smoothing_prob=label_smoothing_prob)
     softmax_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=1234)
     initial_vars = softmax_layer.instantiate_variables(prng_key)
     npy_input = np.random.normal(1.5, 2.0, [8, 10, p.input_dims])
     inputs = jnp.asarray(npy_input)
     class_weights = np.random.normal(1.5, 2.0, [8, 10, 1])
     if class_probabilities is not None:
         class_probabilities /= np.sum(class_probabilities,
                                       axis=-1,
                                       keepdims=True)
     logits = test_utils.apply(softmax_layer, initial_vars,
                               softmax_layer.get_logits, inputs)
     outputs = test_utils.apply(softmax_layer,
                                initial_vars,
                                softmax_layer.fprop,
                                inputs,
                                class_weights,
                                class_ids=class_ids,
                                class_probabilities=class_probabilities)
     # Test whether tf Softmax layer returns same output
     # Modify initial_vars to use TF compatible params
     tf_initial_vars = initial_vars
     tf_initial_vars.linear = py_utils.NestedMap()
     tf_initial_vars.linear.w = initial_vars.logits_ffn.linear.w
     tf_initial_vars.bias = py_utils.NestedMap()
     tf_initial_vars.bias.b = initial_vars.logits_ffn.bias.b
     tf_p = lingvo_layers.SingleShardFullSoftmax.Params().Set(
         name='tf_softmax',
         num_classes=p.num_classes,
         input_dim=p.input_dims,
         logits_soft_max=soft_cap_logits)
     tf_softmax_layer = tf_p.Instantiate()
     tf_logits = tf_softmax_layer.Logits(
         tf_initial_vars, tf.constant(inputs, dtype=tf.float32))
     if use_class_ids and label_smoothing_prob > 0:
         class_probabilities = np.zeros([8, 10, 50])
         index = np.indices([8, 10])
         class_probabilities[index[0], index[1],
                             np.squeeze(class_ids, 2)] = 1
         class_probabilities = (
             class_probabilities * (1 - label_smoothing_prob) +
             (1 - class_probabilities) * label_smoothing_prob /
             (p.num_classes - 1))
         class_ids = None
     tf_output = tf_softmax_layer.FProp(
         tf_initial_vars,
         tf.constant(inputs, dtype=tf.float32),
         class_weights,
         class_ids=class_ids,
         class_probabilities=class_probabilities)
     # Check all entries in the NestedMap and ensure it matches TF
     np_get_logits = to_np(logits)
     tf_np_get_logits = to_np(tf_logits)
     self.assertAllClose(np_get_logits, tf_np_get_logits, atol=1e-6)
     # Note: The argmax-related values are very sensitive to numerical errors.
     for k in outputs.keys():
         self.assertAllClose(to_np(outputs[k]),
                             to_np(tf_output[k]),
                             atol=1e-6)
Example #33
  def rejection_sample(concentration):
    """Gamma rejection sampler."""
    # Note, concentration here already has a shape that is broadcast with rate.
    cast_concentration = tf.cast(concentration, internal_dtype)

    good_params_mask = (concentration > 0.)
    # When replacing bad (non-positive or NaN) concentrations, use 100., since
    # that gives a high likelihood of the rejection sampler accepting on the
    # first pass.
    safe_concentration = tf.where(good_params_mask, cast_concentration, 100.)

    modified_safe_concentration = tf.where(
        safe_concentration < 1., safe_concentration + 1., safe_concentration)

    one_third = tf.constant(1. / 3, dtype=internal_dtype)
    d = modified_safe_concentration - one_third
    c = one_third * tf.math.rsqrt(d)

    def generate_and_test_samples(seed):
      """Generate and test samples."""
      v_seed, u_seed = samplers.split_seed(seed)

      def generate_positive_v():
        """Generate positive v."""
        def _inner(seed):
          x = samplers.normal(shape, dtype=internal_dtype, seed=seed)
          # This implicitly broadcasts concentration up to sample shape.
          v = 1 + c * x
          return (x, v), v > 0.

        # Note: It should be possible to remove this 'inner' call to
        # `batched_las_vegas_algorithm` and merge the v > 0 check into the
        # overall check for a good sample. This would lead to a slightly simpler
        # implementation; it is unclear whether it would be faster. We include
        # the inner loop so this implementation is more easily comparable to
        # Ref. [1] and other implementations.
        return brs.batched_las_vegas_algorithm(_inner, v_seed)[0]

      (x, v) = generate_positive_v()
      logv = tf.math.log1p(c * x)
      x2 = x * x
      v3 = v * v * v
      logv3 = logv * 3

      u = samplers.uniform(
          shape, dtype=internal_dtype, seed=u_seed)

      # In [1], the suggestion is to first check u < 1 - 0.331 * x2 * x2, and to
      # run the check below only if it fails, in order to avoid the relatively
      # expensive logarithm calls. Our algorithm operates in batch mode: we will
      # have to compute or not compute the logarithms for the entire batch, and
      # as the batch gets larger, the odds we compute it grow. Therefore we
      # don't bother with the "cheap" check.
      good_sample_mask = tf.math.log(u) < (x2 / 2. + d * (1 - v3 + logv3))

      return logv3 if log_space else v3, good_sample_mask

    samples = brs.batched_las_vegas_algorithm(
        generate_and_test_samples, seed=generate_and_test_samples_seed)[0]

    concentration_fix_unif = samplers.uniform(  # in [0, 1)
        shape, dtype=internal_dtype, seed=concentration_fix_seed)

    if log_space:
      concentration_lt_one_fix = tf.where(
          safe_concentration < 1.,
          # Why do we use log1p(-x)? x is in [0, 1) and log(0) = -inf, is bad.
          # x ~ U(0,1) => 1-x ~ U(0,1)
          # But at the boundary, 1-x in (0, 1]. Good.
          # So we can take log(unif(0,1)) safely as log(1-unif(0,1)).
          # log1p(-0) = 0, and log1p(-almost_one) = -not_quite_inf. Good.
          tf.math.log1p(-concentration_fix_unif) / safe_concentration,
          tf.zeros((), dtype=internal_dtype))
      samples = samples + tf.math.log(d) + concentration_lt_one_fix
    else:
      concentration_lt_one_fix = tf.where(
          safe_concentration < 1.,
          tf.math.pow(concentration_fix_unif,
                      tf.math.reciprocal(safe_concentration)),
          tf.ones((), dtype=internal_dtype))
      samples = samples * d * concentration_lt_one_fix

    samples = tf.where(good_params_mask, samples, np.nan)
    output_type_samples = tf.cast(samples, output_dtype)

    return output_type_samples
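The accept/reject step above is easier to follow in scalar form. Below is a NumPy-only sketch of the same Marsaglia-Tsang test for `concentration >= 1` (illustrative only; the batched, seed-threaded TF code above is the actual implementation):

import numpy as np

def sample_gamma_scalar(concentration, rng=np.random.default_rng()):
  d = concentration - 1. / 3
  c = (1. / 3) / np.sqrt(d)
  while True:
    x = rng.normal()
    v = 1. + c * x
    if v <= 0.:
      continue  # the scalar analogue of the inner "generate positive v" loop
    u = rng.uniform()
    # Same acceptance test as `good_sample_mask` above.
    if np.log(u) < x * x / 2. + d * (1. - v**3 + 3. * np.log(v)):
      return d * v**3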
Example #34
            if 'store_parameters_in_results' in new_kernel.parameters:
                self.assertTrue(
                    new_kernel.parameters['store_parameters_in_results'])

    def testNoParameters(self):
        kernel = FakeInnerNoParameters()
        new_kernel = util.enable_store_parameters_in_results(kernel)
        self.assertIs(kernel, new_kernel)


class TensorConvertible(object):
    pass


tf.register_tensor_conversion_function(
    TensorConvertible, conversion_func=lambda *args: tf.constant(0))


class SimpleTensorWarningTest(test_util.TestCase):

    # We must defer creating the TF objects until the body of the test.
    # pylint: disable=unnecessary-lambda
    @parameterized.parameters([lambda: tf.Variable(0)],
                              [lambda: tfp.util.DeferredTensor(
                                  tf.Variable(0.), tf.identity)],
                              [lambda: TensorConvertible()])
    @test_util.disable_test_for_backend(disable_numpy=True,
                                        disable_jax=True,
                                        reason='Variable/DeferredTensor')
    def testWarn(self, tensor_callable):
        tensor = tensor_callable()
        warnings.simplefilter('always')
Example #35
def _create_corr_matrix(rho, dtype):
    """Create correlation matrix with scalar `rho`."""
    one = tf.constant(1.0, dtype=dtype)
    rho = tf.convert_to_tensor(rho, dtype=dtype)
    # Use `tf.stack` rather than `tf.concat`, which cannot concatenate scalars.
    m1 = tf.stack([one, rho], axis=0)
    m2 = tf.stack([rho, one], axis=0)
    return tf.stack([m1, m2])
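Illustrative usage (the values are assumptions, not from the original module): a scalar correlation of 0.5 yields the 2x2 matrix [[1.0, 0.5], [0.5, 1.0]].

corr = _create_corr_matrix(tf.constant(0.5, dtype=tf.float64), tf.float64)
# ==> [[1. , 0.5],
#      [0.5, 1. ]]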
Example #36
 def testLogBetaDtype(self, dtype):
     x = tf.constant([1., 2.], dtype=dtype)
     y = tf.constant([3., 4.], dtype=dtype)
     result = tfp_math.lbeta(x, y)
     self.assertEqual(result.dtype, dtype)
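As a quick sanity check of what `lbeta` computes (a sketch, not part of the test suite): log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y).

x = tf.constant([1., 2.], dtype=tf.float64)
y = tf.constant([3., 4.], dtype=tf.float64)
expected = tf.math.lgamma(x) + tf.math.lgamma(y) - tf.math.lgamma(x + y)
# tfp_math.lbeta(x, y) should be close to `expected`.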