Code example #1
def main(argv):
  del argv

  root = tf.train.Checkpoint()
  # Create a cell and attach to our trackable.
  root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)

  # Wrap the rnn_cell.__call__ function and assign to next_state.
  root.next_state = tf.function(root.rnn_cell.__call__, autograph=False)

  # Wrap the rnn_cell.get_initial_state function using a decorator and assign
  # to an attribute with the same name.
  @tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
  def get_initial_state(tensor):
    return root.rnn_cell.get_initial_state(tensor, None, None)

  root.get_initial_state = get_initial_state

  # Construct an initial_state, then call next_state explicitly to trigger a
  # trace for serialization (we need an explicit call, because next_state has
  # not been annotated with an input_signature).
  initial_state = root.get_initial_state(
      tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
  root.next_state(
      tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
      initial_state)

  tf.saved_model.save(root, FLAGS.export_dir)
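A minimal loading sketch for the SavedModel exported above (assuming the usual `import numpy as np` / `import tensorflow as tf`; the directory path is illustrative, and input shapes must match the traces recorded at export time):

loaded = tf.saved_model.load('/tmp/rnn_cell_export')  # hypothetical export_dir
state = loaded.get_initial_state(
    tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
batch = tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32))
output, new_state = loaded.next_state(batch, state)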
Code example #2
def wrap_keras_model_for_export(model, batch_input_shape,
                                set_hparams, default_hparams):
  """Wraps `model` for saving and loading as SavedModel."""
  if default_hparams is None: default_hparams = {}
  hparam_keys = list(default_hparams.keys())
  hparam_defaults = tuple(default_hparams.values())
  # The goal is to save a function with this argspec...
  argspec = tf_inspect.FullArgSpec(
      args=(['inputs', 'training'] + hparam_keys),
      defaults=((False,) + hparam_defaults),
      varargs=None, varkw=None,
      kwonlyargs=[], kwonlydefaults=None,
      annotations={})
  # ...and this behavior:
  def call_fn(inputs, training, *args):
    if FLAGS.export_print_hparams:
      args = [tf.keras.backend.print_tensor(args[i], 'training=%s and %s='
                                            % (training, hparam_keys[i]))
              for i in range(len(args))]
    kwargs = dict(zip(hparam_keys, args))
    if kwargs: set_hparams(model, **kwargs)
    return model(inputs, training=training)
  # We cannot spell out `args` in def statement for call_fn, but since
  # tf.function uses tf_inspect, we can use tf_decorator to wrap it with
  # the desired argspec.
  def wrapped(*args, **kwargs):  # TODO(arnoegw): Can we use call_fn itself?
    return call_fn(*args, **kwargs)
  traced_call_fn = tf.function(autograph=False)(
      tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
  # Now we need to trigger traces for
  # - training set to Python values True or False (hence two traces),
  # - tensor inputs of the expected nesting, shape and dtype,
  # - tensor-valued kwargs for hparams, with caller-side defaults.
  # Tracing with partially determined shapes requires an input signature,
  # so we initiate tracing from a helper function with only tensor inputs.
  @tf.function(autograph=False)
  def trigger_traces(inputs, **kwargs):
    return tuple(traced_call_fn(inputs, training=training, **kwargs)
                 for training in (True, False))
  inputs_spec = tf.TensorSpec(shape=batch_input_shape, dtype=tf.float32)
  hparams_spec = {name: tf.TensorSpec.from_tensor(tf.constant(value))
                  for name, value in default_hparams.items()}
  _ = trigger_traces.get_concrete_function(inputs_spec, **hparams_spec)

  # Assemble the output object.
  obj = tf.train.Checkpoint()
  obj.__call__ = traced_call_fn
  obj.trainable_variables = model.trainable_variables
  obj.variables = model.trainable_variables + model.non_trainable_variables
  obj.regularization_losses = [_get_traced_loss(model, i)
                               for i in range(len(model.losses))]
  return obj
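A usage sketch under stated assumptions: the setter `set_dropout_rate`, the hparam name `dropout_rate`, and the paths are all illustrative, and `model` is any Keras model matching `batch_input_shape`:

def set_dropout_rate(model, dropout_rate):  # hypothetical hparam setter
  for layer in model.layers:
    if isinstance(layer, tf.keras.layers.Dropout):
      layer.rate = dropout_rate

obj = wrap_keras_model_for_export(
    model, batch_input_shape=(None, 28, 28, 1),
    set_hparams=set_dropout_rate,
    default_hparams=dict(dropout_rate=0.25))
tf.saved_model.save(obj, '/tmp/exported_model')  # illustrative path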
Code example #3
def _get_traced_loss(model, i):
  """Returns tf.function for model.losses[i] with a trace for zero args.

  The intended usage is
    [_get_traced_loss(model, i) for i in range(len(model.losses))]
  This is better than
    [tf.function(lambda: model.losses[i], input_signature=[]) for i ...]
  because it avoids capturing a loop index in a lambda, and removes any
  chance of deferring the trace.

  Args:
    model: a Keras Model.
    i: an integer from 0 up to, but excluding, len(model.losses).
  """
  f = tf.function(lambda: model.losses[i])
  _ = f.get_concrete_function()
  return f
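The pitfall the docstring warns about, in miniature (plain Python; every lambda closes over the name `i`, not its value at creation time):

fns = [lambda: i for i in range(3)]
print([f() for f in fns])  # prints [2, 2, 2], not [0, 1, 2]

Passing `i` as a function argument, and forcing the trace right away with get_concrete_function(), sidesteps both the capture and the deferred-trace problems.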
Code example #4
def make_distributed_tensor(strategy, tensors):
    stacked = tf.stack(tensors, axis=0)
    fn = tf.function(lambda t: t[xla.replica_id()])
    return strategy.run(fn, args=(stacked, ))
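A usage sketch, assuming a two-replica strategy and that `xla.replica_id()` is available from the surrounding project's XLA helpers:

strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])  # illustrative
per_replica = make_distributed_tensor(
    strategy, [tf.constant([1., 2.]), tf.constant([3., 4.])])
# Replica 0 receives [1., 2.]; replica 1 receives [3., 4.].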
Code example #5
File: trainers.py  Project: zeeps31/ddsp
 def build(self, batch):
     """Build the model by running a distributed batch through it."""
     logging.info('Building the model...')
     _ = self.run(tf.function(self.model.__call__), batch)
     self.model.summary()
Code example #6
def xla_compile(f):
    """Decorator for XLA compilation."""
    return tf.function(f, autograph=False, experimental_compile=True)
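Usage sketch (the first call compiles the traced function with XLA, when a suitable device and compiler are available):

@xla_compile
def scaled_add(x, y):
    return 2.0 * x + y

result = scaled_add(tf.ones([4]), tf.ones([4]))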
Code example #7
 def test_creates_var_imbalanced_illegal(self, target, c, type_, exc_type,
                                         exc_regex):
     c = type_(c)
     with self.assertRaisesRegex(exc_type, exc_regex):
         tf.function(target)(c)
Code example #8
File: utils.py  Project: jarotter/probability
def tfcompile(func=None,
              tf_function=True,
              xla_best_effort=True,
              xla_compile_all=False):
    """Centralizes TF compilation related options.

  Args:
    func: Python `callable` to be wrapped with the specified TF compilation
      directives.
      Default value: `None`.
    tf_function: `bool` representing whether the resulting function should be
      `tf.function` decorated.
      Default value: `True`.
    xla_best_effort: `bool` representing whether XLA auto-clustering compilation
      should be performed. (This argument is ignored if the function is executed
      eagerly.)
      Default value: `True`.
    xla_compile_all: `bool` representing whether XLA compilation should be
      performed. (This argument overrides both `tf_function` and
      `xla_best_effort`.)
      Default value: `False`.

  Returns:
    wrapped_func: A Python `callable` with the specified compilation directives
      embedded.

  ### Example Usage

  ```python
  tfn = tfp.experimental.nn

  # Use style #1.
  @tfn.util.tfcompile(xla_compile_all=True)
  def foo(...):
    ...

  # Use style #2.
  def foo(...):
    ...
  foo = tfn.util.tfcompile(xla_compile_all=True)(foo)
  ```

  """
    # Note: xla_compile_all overrides both tf_function and xla_best_effort.
    tf_function = tf_function or xla_compile_all
    xla_best_effort = xla_best_effort and not xla_compile_all
    maybe_tf_function = (tf.function(autograph=False,
                                     experimental_compile=xla_compile_all)
                         if tf_function else _dummy_context())

    def decorator(f):
        @functools.wraps(f)
        @maybe_tf_function
        def wrapped(*args, **kwargs):
            maybe_xla_best_effort = (tf.xla.experimental.jit_scope(
                compile_ops=True) if not tf.executing_eagerly()
                                     and xla_best_effort else _dummy_context())
            with maybe_xla_best_effort:
                return f(*args, **kwargs)

        return wrapped

    if func is None:
        # This branch handles the following use case:
        #   @tfcompile(...)
        #   def foo(...):
        #      ...
        return decorator
    else:
        # This branch handles the following use case:
        #   foo = tfcompile(...)(foo)
        return decorator(func)
Code example #9
 def test_nested(self):
     lop = AutoBlockDiag([AutoDiag(tf.ones([2]) * 2), AutoIdentity(1)])
     self.assertAllClose(
         tf.constant([6., 6, 3]),
         tf.function(lambda lop: lop.matvec(3. * tf.ones([3])))(lop))
Code example #10
File: test_util.py  Project: iamyourboss/probability
def run_hmc_on_model(
    model,
    num_chains,
    num_steps,
    num_leapfrog_steps,
    step_size,
    target_accept_prob=0.9,
    seed=None,
    dtype=tf.float32,
    use_xla=False,
):
    """Runs HMC on a target.

  Args:
    model: The model to validate.
    num_chains: Number of chains to run in parallel.
    num_steps: Total number of steps to take. The first half are used to warm up
      the sampler.
    num_leapfrog_steps: Number of leapfrog steps to take.
    step_size: Step size to use.
    target_accept_prob: Target acceptance probability.
    seed: Optional seed to use. By default, `test_util.test_seed()` is used.
    dtype: DType to use for the algorithm.
    use_xla: Whether to use XLA.

  Returns:
    mcmc_results: `MCMCResults`.
  """
    step_size = tf.convert_to_tensor(step_size, dtype)

    def target_log_prob_fn(*x):
        x = tf.nest.pack_sequence_as(model.dtype, x)
        return model.unnormalized_log_prob(x)

    if seed is None:
        seed = test_util.test_seed()
    if tf.executing_eagerly():
        # TODO(b/141368747): HMC doesn't like you passing the seed in when in
        # eager mode.
        seed = None
    current_state = tf.nest.map_structure(
        lambda b, e: b(  # pylint: disable=g-long-lambda
            tf.zeros([num_chains] + list(e), dtype=dtype)),
        model.default_event_space_bijector,
        model.event_shape)

    # tfp.mcmc only works well with lists.
    current_state = tf.nest.flatten(current_state)

    hmc = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=num_leapfrog_steps,
        step_size=[tf.fill(s.shape, step_size) for s in current_state],
        seed=seed)
    hmc = tfp.mcmc.TransformedTransitionKernel(
        hmc, tf.nest.flatten(model.default_event_space_bijector))
    hmc = tfp.mcmc.DualAveragingStepSizeAdaptation(
        hmc,
        num_adaptation_steps=int(num_steps // 2 * 0.8),
        target_accept_prob=target_accept_prob)

    chain, is_accepted = tf.function(
        lambda: tfp.mcmc.sample_chain(  # pylint: disable=g-long-lambda
            current_state=current_state,
            kernel=hmc,
            num_results=num_steps // 2,
            num_burnin_steps=num_steps // 2,
            trace_fn=lambda _, pkr:  # pylint: disable=g-long-lambda
            (pkr.inner_results.inner_results.is_accepted)),
        autograph=False,
        experimental_compile=use_xla)()

    accept_rate = tf.reduce_mean(tf.cast(is_accepted, dtype))
    ess = tf.nest.map_structure(
        lambda c: tfp.mcmc.effective_sample_size(  # pylint: disable=g-long-lambda
            c,
            cross_chain_dims=1,
            filter_beyond_positive_pairs=True),
        chain)
    r_hat = tf.nest.map_structure(tfp.mcmc.potential_scale_reduction, chain)

    mcmc_results = MCMCResults(
        chain=tf.nest.pack_sequence_as(model.default_event_space_bijector,
                                       chain),
        accept_rate=accept_rate,
        ess=ess,
        r_hat=r_hat,
    )
    return mcmc_results
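A call sketch, assuming `model` implements the interface relied on above (`dtype`, `event_shape`, `default_event_space_bijector`, and `unnormalized_log_prob`):

mcmc_results = run_hmc_on_model(
    model, num_chains=16, num_steps=2000, num_leapfrog_steps=10,
    step_size=0.1)
print(mcmc_results.accept_rate)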
Code example #11
    def testKernelResultsHaveCorrectThingsWhenExchangeAdjacentOnly(self):
        target = tfd.Normal(0., 1.)

        def make_kernel_fn(target_log_prob_fn, seed):
            return tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_log_prob_fn,
                seed=seed,
                step_size=0.3,
                num_leapfrog_steps=3)

        inverse_temperatures = [1., 0.99, 0.01]
        remc = tfp.mcmc.ReplicaExchangeMC(
            target_log_prob_fn=tf.function(target.log_prob, autograph=False),
            inverse_temperatures=inverse_temperatures,
            make_kernel_fn=make_kernel_fn,
            seed=_set_seed())

        num_results = 400
        num_replicas = len(inverse_temperatures)

        samples, kernel_results = tfp.mcmc.sample_chain(
            num_results=num_results,
            current_state=target.sample(seed=_set_seed()),
            kernel=remc,
            num_burnin_steps=20,
            trace_fn=lambda _, results: results,
            parallel_iterations=1)  # For determinism.

        self.assertAllEqual((num_results, ), samples.shape)

        kernel_results = self.evaluate(kernel_results)

        # Boring checks of existence/shape.
        self.assertLen(kernel_results.replica_states, num_replicas)
        self.assertLen(kernel_results.replica_results, num_replicas)
        self.assertLen(kernel_results.sampled_replica_results, num_replicas)
        self.assertAllEqual((num_results, num_replicas - 1),
                            kernel_results.is_exchange_proposed.shape)
        self.assertAllEqual((num_results, num_replicas - 1),
                            kernel_results.is_exchange_accepted.shape)

        # Exciting checks of correctness!

        # Tests below assume these particular temperatures.
        self.assertAllEqual([1., 0.99, 0.01], inverse_temperatures)

        # Default exchange proposed function is to exchange every time.  There are
        # two exchanges possible, and they are mutually exclusive, so each is
        # proposed about 1/2 the time.
        self.assertAllClose(
            np.mean(kernel_results.is_exchange_proposed, axis=0),
            [0.5, 0.5],
            # Treating exchange proposals as Bernoulli(p=.5), the tolerance below is
            # set so that in only about 0.1% of random test runs (with num_results
            # set to 400 above) will the measured proposal rate fall outside the
            # implied bounds, which are [0.418, 0.582].
            rtol=0.165)

        # P[ExchangeAccepted | ExchangeProposed]
        conditional_accept_prob = (
            np.sum(kernel_results.is_exchange_accepted, axis=0) /
            np.sum(kernel_results.is_exchange_proposed, axis=0))

        # The first exchange is between inverse temps 1 and 0.99, which are
        # basically the same distributions, so usually accepted.
        self.assertAllClose(0.99, conditional_accept_prob[0], rtol=0.1)

        # The second exchange is between 0.99 and 0.01, which are totally different,
        # so will mostly be rejected.
        self.assertLess(conditional_accept_prob[1], 0.5)
Code example #12
    def testRWM2DMixNormal(self):
        """Sampling from a 2-D Mixture Normal Distribution."""
        dtype = np.float32

        # By symmetry, target has mean [0, 0]
        # Therefore, Var = E[X^2] = E[E[X^2 | c]], where c is the component.
        # Now..., for the first component,
        #   E[X1^2] =  Var[X1] + Mean[X1]^2
        #           =  0.3^2 + 1^2,
        # and similarly for the second.  As a result,
        # Var[mixture] = 1.09.
        target = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
            components_distribution=tfd.MultivariateNormalDiag(
                loc=[[-1., -1], [1., 1.]],
                scale_identity_multiplier=[0.3, 0.3]))

        inverse_temperatures = 10.**tf.linspace(0., -2., 4)
        step_sizes = tf.constant([0.3, 0.6, 1.2, 2.4])

        def make_kernel_fn(target_log_prob_fn, seed):
            kernel = tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_log_prob_fn,
                seed=seed,
                step_size=step_sizes[make_kernel_fn.idx],
                num_leapfrog_steps=2)
            make_kernel_fn.idx += 1
            return kernel

        # TODO(b/124770732): Remove this hack.
        make_kernel_fn.idx = 0

        remc = tfp.mcmc.ReplicaExchangeMC(
            target_log_prob_fn=tf.function(target.log_prob, autograph=False),
            # Verified that test fails if inverse_temperatures = [1.]
            inverse_temperatures=inverse_temperatures,
            make_kernel_fn=make_kernel_fn,
            seed=_set_seed())

        def _trace_log_accept_ratio(state, results):
            del state
            return [
                r.log_accept_ratio for r in results.sampled_replica_results
            ]

        num_results = 1000
        samples, log_accept_ratios = tfp.mcmc.sample_chain(
            num_results=num_results,
            # Start at one of the modes, in order to make mode jumping necessary
            # if we want to pass test.
            current_state=np.ones(2, dtype=dtype),
            kernel=remc,
            num_burnin_steps=500,
            trace_fn=_trace_log_accept_ratio,
            parallel_iterations=1)  # For determinism.
        self.assertAllEqual((num_results, 2), samples.shape)
        log_accept_ratios = [
            tf.reduce_mean(input_tensor=tf.exp(tf.minimum(0., lar)))
            for lar in log_accept_ratios
        ]

        sample_mean = tf.reduce_mean(input_tensor=samples, axis=0)
        sample_std = tf.sqrt(
            tf.reduce_mean(input_tensor=tf.math.squared_difference(
                samples, sample_mean),
                           axis=0))
        [sample_mean_, sample_std_, log_accept_ratios_
         ] = self.evaluate([sample_mean, sample_std, log_accept_ratios])
        tf1.logging.vlog(1, 'log_accept_ratios: %s  eager: %s',
                         log_accept_ratios_, tf.executing_eagerly())

        self.assertAllClose(sample_mean_, [0., 0.], atol=0.3, rtol=0.3)
        self.assertAllClose(sample_std_,
                            [np.sqrt(1.09), np.sqrt(1.09)],
                            atol=0.1,
                            rtol=0.1)
Code example #13
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    steps_between_eval=None,
    steps_before_eval_start=None,
    stop_threshold=None,
    metric_fn=None,
    init_checkpoint=None,
    custom_callbacks=None,
    run_eagerly=False,
    sub_model_export_name=None,
    explicit_allreduce=False,
    device_warmup=False,
    synthetic_train_input_fn=None,
    pre_allreduce_callbacks=None,
    post_allreduce_callbacks=None,
    allreduce_bytes_per_pack=0,
    enable_checkpoint_and_summary=False,
    num_accumulation_steps=1,
    stop_steps=None):
  """Run BERT pretrain model training using low-level API.

  Arguments:
      _sentinel: Used to prevent positional parameters. Internal, do not use.
      strategy: Distribution strategy on which to run low level training loop.
      model_fn: Function that returns a tuple (model, sub_model,
        sub_pretrain_model). Caller of this function should add an optimizer to
        the `model` via calling `model.compile()` API or manually setting
        `model.optimizer` attribute. The second element of the returned tuple
        (sub_model) is an optional sub model to be used for the initial
        checkpoint -- if provided.
      loss_fn: Function with signature func(labels, logits) and returns a loss
        tensor.
      model_dir: Model directory used during training for restoring/saving model
        weights.
      train_input_fn: Function that returns a tf.data.Dataset used for training.
      steps_per_epoch: Number of steps to run per epoch. At the end of each
        epoch, model checkpoint will be saved and evaluation will be conducted
        if evaluation dataset is provided.
      steps_per_loop: Number of steps per graph-mode loop. In order to reduce
        communication in eager context, training logs are printed every
        steps_per_loop.
      epochs: Number of epochs to train.
      eval_input_fn: Function that returns evaluation dataset. If none,
        evaluation is skipped.
      eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
        is not none.
      steps_between_eval: Number of steps between evaluations.
      steps_before_eval_start: Number of steps to skip before starting
        evaluation.
      stop_threshold: Stop threshold for MLPerf; training stops early once this
        accuracy is achieved.
      metric_fn: A metrics function that returns a Keras Metric object to record
        evaluation result using evaluation dataset or with training dataset
        after every epoch.
      init_checkpoint: Optional checkpoint to load to `sub_model` returned by
        `model_fn`.
      custom_callbacks: A list of Keras Callbacks objects to run during
        training. More specifically, `on_batch_begin()`, `on_batch_end()`,
        methods are invoked during training.
      run_eagerly: Whether to run model training in pure eager execution. This
        should be disabled for TPUStrategy.
      sub_model_export_name: If not None, will export `sub_model` returned by
        `model_fn` into checkpoint files. The name of intermediate checkpoint
        file is {sub_model_export_name}_step_{step}.ckpt and the last
        checkpoint's name is {sub_model_export_name}.ckpt;
        if None, `sub_model` will not be exported as checkpoint.
      explicit_allreduce: Whether to explicitly perform gradient allreduce,
        instead of relying on implicit allreduce in optimizer.apply_gradients().
        default is False. For now, if training using FP16 mixed precision,
        explicit allreduce will aggregate gradients in FP16 format. For TPU and
        GPU training using FP32, explicit allreduce will aggregate gradients in
        FP32 format.
      device_warmup: Whether or not to enable device warmup. This
        runs the training and eval loop on synthetic data to pre-compile XLA
        and TF tracing before accessing data.
      synthetic_train_input_fn: Function that returns synthetic training
        dataset. This is used in device warmup.
      pre_allreduce_callbacks: A list of callback functions that takes
        gradients and model variables pairs as input, manipulates them, and
        returns new gradients and model variables pairs. The callback functions
        will be invoked in the list order and before gradients are allreduced.
        Default is no callbacks. Only used when explicit_allreduce=True.
      post_allreduce_callbacks: A list of callback functions that takes
        gradients and model variables pairs as input, manipulates them, and
        returns new gradients and model variables pairs. The callback
        functions will be invoked in the list order and right before gradients
        are applied to variables for updates. Default is no callbacks. Only used
        when explicit_allreduce=True.
      allreduce_bytes_per_pack: A non-negative integer. Breaks collective
        operations into packs of certain size. If it's zero, all gradients are
        in one pack.
      enable_checkpoint_and_summary: Whether to save checkpoint and summary.
      stop_steps: The number of steps to run before stopping the training loop.

  Returns:
      A tuple of (trained model, final masked-LM accuracy, final training
      step).

  Raises:
      ValueError: (1) When model returned by `model_fn` does not have optimizer
        attribute or when required parameters are set to none. (2) eval args are
        not specified correctly. (3) metric_fn must be a callable if specified.
        (4) sub_model_checkpoint_name is specified, but `sub_model` returned
        by `model_fn` is None.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if [arg for arg in required_arguments if arg is None]:
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
                     'parameters.')

  if steps_between_eval % steps_per_loop != 0:
    raise ValueError('steps_between_eval should be multiple of steps_per_loop.')

  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        ' steps_per_epoch: %d, we will use steps_per_epoch as'
        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  assert tf.executing_eagerly()

  if run_eagerly:
    if steps_per_loop > 1:
      raise ValueError(
          'steps_per_loop is used for performance optimization. When you want '
          'to run eagerly, you cannot leverage graph mode loop.')
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      raise ValueError(
          'TPUStrategy should not run eagerly as it heavily relies on graph'
          ' optimization for the distributed system.')

  if eval_input_fn and (eval_steps is None):
    raise ValueError(
        '`eval_steps` is required when `eval_input_fn` is not none.')
  if device_warmup and (synthetic_train_input_fn is None):
    raise ValueError('`synthetic_train_input_fn` is required when '
                     'device_warmup is enabled.')

  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  if stop_steps:
    total_training_steps = stop_steps
  else:
    total_training_steps = steps_per_epoch * epochs

  if stop_steps and stop_steps > steps_per_epoch * epochs:
    raise ValueError('`stop_steps` should not be greater than '
                     '`num_train_steps_per_epoch` * `num_epochs`.')

  # To reduce unnecessary send/receive input pipeline operation, we place input
  # pipeline ops in worker task.
  train_iterator = _get_input_iterator(train_input_fn, strategy)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
    # model and optimizer should be created in scope.
    model, sub_model, sub_pretrain_model = model_fn()
    if not hasattr(model, 'optimizer'):
      raise ValueError('User should set optimizer attribute to model '
                       'inside `model_fn`.')
    if sub_model_export_name and sub_model is None:
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    optimizer = model.optimizer

    train_loss_metric = tf.keras.metrics.Mean(
        'training_loss', dtype=tf.float32)
    if eval_input_fn:
      eval_metric_num = tf.keras.metrics.Sum('masked_lm_num', dtype=tf.float32)
      eval_metric_denom = tf.keras.metrics.Sum(
          'masked_lm_denom', dtype=tf.float32)

    # If evaluation is required, make a copy of metric as it will be used by
    # both train and evaluation.
    train_metrics = [
        tf.keras.metrics.Mean('masked_lm_accuracy', dtype=tf.float32)
    ]

    # Create summary writers
    summary_dir = os.path.join(model_dir, 'summaries')
    if enable_checkpoint_and_summary:
      eval_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'eval'))
    else:
      eval_summary_writer = tf.summary.create_noop_writer()
    if steps_per_loop >= _MIN_SUMMARY_STEPS and enable_checkpoint_and_summary:
      # Only writes summary when the stats are collected sufficiently over
      # enough steps.
      train_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'train'))
    else:
      train_summary_writer = tf.summary.create_noop_writer()

    # Collects training variables.
    training_vars = model.trainable_variables

    @tf.function(experimental_compile=True)
    def _compiled_local_step(inputs, labels, training_vars, accum_vars):
      """Replicated training step."""
      with tf.GradientTape() as tape:
        model_outputs, metric_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
      if isinstance(optimizer,
                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        with tape:
          scaled_loss = optimizer.get_scaled_loss(loss)
        scaled_grads = tape.gradient(scaled_loss, training_vars)
        grads = optimizer.get_unscaled_gradients(scaled_grads)
      else:
        grads = tape.gradient(loss, training_vars)
      (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

      if accum_vars is None:
        return grads, loss, model_outputs, metric_outputs
      else:
        new_accum_vars = []
        for i, grad in enumerate(grads):
          new_accum_vars.append(
              accum_vars[i] +
              tf.math.scalar_mul(1.0 / num_accumulation_steps, grad))
        return new_accum_vars, loss, model_outputs, metric_outputs

    def get_input_slice(input_dict, idx):
      split_input = {}
      for key in input_dict:
        split_input[key] = input_dict[key][idx]
      return split_input

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, labels = inputs
      if explicit_allreduce:
        # TODO(b/155523821): Fix OOM issue so we use experimental_compile with
        # multi-worker mirrored strategy.
        with tf.GradientTape() as tape:
          model_outputs, metric_outputs = model(inputs, training=True)
          loss = loss_fn(labels, model_outputs)

        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     training_vars,
                                                     pre_allreduce_callbacks,
                                                     post_allreduce_callbacks,
                                                     allreduce_bytes_per_pack)
      else:
        if num_accumulation_steps > 1:
          accum_vars = [
              tf.zeros_like(tvar, dtype=tf.float32) for tvar in training_vars
          ]
          for key in inputs:
            inputs[key] = tf.split(inputs[key], num_accumulation_steps)

          split_labels = tf.split(labels, num_accumulation_steps)
          for local_step in range(num_accumulation_steps):
            accum_vars, loss, model_outputs, metric_outputs = _compiled_local_step(
                get_input_slice(inputs, local_step), split_labels[local_step],
                training_vars, accum_vars)

          optimizer.apply_gradients(zip(accum_vars, training_vars))
        else:
          grads, loss, model_outputs, metric_outputs = _compiled_local_step(
              inputs, labels, training_vars, None)
          optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(loss)
      for metric in train_metrics:
        metric.update_state(metric_outputs['masked_lm_accuracy'])

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: a tf.int32 integer tensor specifying the number of steps to run
          inside the host training loop.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError('steps should be a Tensor. A Python object may cause '
                         'retracing.')

      for _ in tf.range(steps):
        strategy.run(_replicated_step, args=(next(iterator),))

    def train_single_step(iterator):
      """Performs a distributed training step.

      Args:
        iterator: the distributed iterator of training datasets.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      strategy.run(_replicated_step, args=(next(iterator),))

    def test_step(iterator):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""

        inputs, labels = inputs
        model_outputs, metric_outputs = model(inputs, training=False)
        eval_metric_num.update_state(metric_outputs['masked_lm_num'])
        eval_metric_denom.update_state(metric_outputs['masked_lm_denom'])
      strategy.run(_test_step_fn, args=(next(iterator),))

    if not run_eagerly:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    def _run_evaluation(current_training_step, test_iterator):
      """Runs validation steps and aggregate metrics."""
      mlperf_epoch_num = int(current_training_step / steps_between_eval)
      mlp_log.mlperf_print(
          'eval_start', None, metadata={'epoch_num': mlperf_epoch_num})
      for _ in range(eval_steps):
        test_step(test_iterator)
      mlp_log.mlperf_print(
          'eval_stop', None, metadata={'epoch_num': mlperf_epoch_num})

      with eval_summary_writer.as_default():
        masked_lm_accuracy = (
            _float_metric_value(eval_metric_num) /
            _float_metric_value(eval_metric_denom))
        logging.info('Step: [%d] Validation %s = %f', current_training_step,
                     'masked_lm_accuracy', masked_lm_accuracy)
        tf.summary.scalar(
            'masked_lm_accuracy',
            masked_lm_accuracy,
            step=current_training_step)
        mlp_log.mlperf_print(
            'eval_accuracy',
            masked_lm_accuracy,
            metadata={'epoch_num': mlperf_epoch_num})
        eval_summary_writer.flush()
      return masked_lm_accuracy

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      # While BERT pretraining does not have epochs,
      # to make the logging consistent with other mlperf models,
      # in all the mlp_log, epochs are steps.
      mlp_log.mlperf_print(
          'block_start',
          None,
          metadata={
              'first_epoch_num': int(batch),
              'epoch_count': int(steps_per_loop),
          })
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_begin(batch)

    def _run_callbacks_on_batch_end(batch, logs):
      """Runs custom callbacks at the end of every step."""
      mlp_log.mlperf_print(
          'block_stop', None, metadata={
              'first_epoch_num': int(batch),
          })
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_end(batch, logs)

    # Training loop starts here.
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    sub_model_checkpoint = tf.train.Checkpoint(
        model=sub_model) if sub_model_export_name else None

    # TODO: commenting this out, as we always load from an initial checkpoint
    # latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    # if latest_checkpoint_file:
    #   logging.info(
    #       'Checkpoint file %s found and restoring from '
    #       'checkpoint', latest_checkpoint_file)
    #   checkpoint.restore(latest_checkpoint_file)
    #   logging.info('Loading from checkpoint file completed')

    current_step = optimizer.iterations.numpy()
    checkpoint_name = 'ctl_step_{step}.ckpt'
    checkpoint_save_dir = model_dir if enable_checkpoint_and_summary else None

    if init_checkpoint:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'initial checkpoint for core model.', init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=sub_pretrain_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
      logging.info('Loading from checkpoint file completed')

    if device_warmup:
      synthetic_train_iterator = _get_input_iterator(synthetic_train_input_fn,
                                                     strategy)
      logging.info('Running device warmup for 1 step.')
      train_steps(synthetic_train_iterator, tf.constant(1, dtype=tf.int32))
      # Reset the global step.
      tf.keras.backend.set_value(optimizer.iterations, 0)
      current_step = optimizer.iterations.numpy()

    masked_lm_accuracy = 0
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    while current_step < total_training_steps:
      # Training loss/metric take the average over steps inside the micro
      # training loop. We reset their values before each round.
      train_loss_metric.reset_states()
      for metric in train_metrics + model.metrics:
        metric.reset_states()

      _run_callbacks_on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

      train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
      current_step += steps

      # Updates training logging.
      training_status = 'Train Step: %d/%d  / loss = %s' % (
          current_step, total_training_steps, train_loss)

      with train_summary_writer.as_default():
        tf.summary.scalar(
            train_loss_metric.name, train_loss, step=current_step)
        for metric in train_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          training_status += '  %s = %f' % (metric.name, metric_value)
          tf.summary.scalar(metric.name, metric_value, step=current_step)
        train_summary_writer.flush()
      logging.info(training_status)

      # Saves model checkpoints and run validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
        # step of training.
        if current_step < total_training_steps:
          _save_checkpoint(checkpoint, checkpoint_save_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
                sub_model_checkpoint, checkpoint_save_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
      if eval_input_fn and (current_step % (steps_between_eval) == 0) and (
          current_step >= steps_before_eval_start):
        logging.info('Running evaluation after step: %s.', current_step)
        masked_lm_accuracy = _run_evaluation(
            current_step, _get_input_iterator(eval_input_fn, strategy))
        if masked_lm_accuracy >= stop_threshold:
          mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
          break

        # Re-initialize evaluation metric.
        eval_metric_num.reset_states()
        eval_metric_denom.reset_states()

    if masked_lm_accuracy < stop_threshold:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'aborted'})

    _save_checkpoint(checkpoint, checkpoint_save_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
      _save_checkpoint(sub_model_checkpoint, checkpoint_save_dir,
                       '%s.ckpt' % sub_model_export_name)

    if enable_checkpoint_and_summary:
      training_summary = {
          'total_training_steps': total_training_steps,
          'train_loss': _float_metric_value(train_loss_metric),
      }
      if train_metrics:
        # TODO(hongkuny): Cleans up summary reporting in text.
        training_summary['last_train_metrics'] = _float_metric_value(
            train_metrics[0])
        #training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])

      write_txt_summary(training_summary, summary_dir)

    return model, masked_lm_accuracy, current_step
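A minimal invocation sketch; every `my_*` name is an illustrative placeholder for a caller-supplied function matching the docstring above:

model, masked_lm_accuracy, final_step = run_customized_training_loop(
    strategy=tf.distribute.MirroredStrategy(),
    model_fn=my_model_fn,              # hypothetical: returns (model, sub_model, sub_pretrain_model)
    loss_fn=my_loss_fn,                # hypothetical: func(labels, logits) -> loss tensor
    model_dir='/tmp/bert_pretrain',    # illustrative path
    train_input_fn=my_train_input_fn,  # hypothetical: returns a tf.data.Dataset
    eval_input_fn=my_eval_input_fn,    # hypothetical: returns a tf.data.Dataset
    eval_steps=100,
    steps_per_epoch=1000,
    steps_per_loop=100,
    epochs=3,
    steps_between_eval=1000,
    steps_before_eval_start=0,
    stop_threshold=0.712)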
Code example #14
def get_ess(samples):
    ess_fn = tf.function(
        functools.partial(
            tfp.mcmc.effective_sample_size,
            filter_beyond_positive_pairs=True),
        autograph=False)
    return ess_fn(samples).numpy()
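Usage sketch; `tfp.mcmc.effective_sample_size` treats the leftmost axis as the sample dimension, so `samples` is shaped [num_steps, num_chains]:

samples = tf.random.normal([1000, 4])  # illustrative chain draws
print(get_ess(samples))  # per-chain effective sample sizes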
Code example #15
    def test_forecasts_match_reference(self, use_slope):
        seed = test_util.test_seed()
        num_observed_steps = 5
        num_forecast_steps = 4
        model, observed_time_series, is_missing = self._build_test_model(
            num_timesteps=num_observed_steps + num_forecast_steps,
            true_slope_scale=0.5 if use_slope else None,
            batch_shape=[3])

        samples = tf.function(lambda: gibbs_sampler.fit_with_gibbs_sampling(  # pylint: disable=g-long-lambda
            model,
            tfp.sts.MaskedTimeSeries(
                observed_time_series[..., :num_observed_steps, tf.newaxis],
                is_missing[..., :num_observed_steps]),
            num_results=10000,
            num_warmup_steps=100,
            seed=seed))()
        predictive_dist = gibbs_sampler.one_step_predictive(
            model,
            samples,
            num_forecast_steps=num_forecast_steps,
            thin_every=1)
        predictive_mean, predictive_stddev = self.evaluate(
            (predictive_dist.mean(), predictive_dist.stddev()))
        self.assertAllEqual(predictive_mean.shape,
                            [3, num_observed_steps + num_forecast_steps])
        self.assertAllEqual(predictive_stddev.shape,
                            [3, num_observed_steps + num_forecast_steps])

        if use_slope:
            parameter_samples = (samples.observation_noise_scale,
                                 samples.level_scale, samples.slope_scale,
                                 samples.weights)
        else:
            parameter_samples = (samples.observation_noise_scale,
                                 samples.level_scale, samples.weights)

        # Note that although we expect the Gibbs-sampled forecasts to match a
        # reference implementation, we *don't* expect the one-step predictions to
        # match `tfp.sts.one_step_predictive`, because that makes predictions using
        # a filtered posterior (i.e., given only previous observations) whereas the
        # Gibbs-sampled latent `level`s will incorporate some information from
        # future observations.
        reference_forecast_dist = tfp.sts.forecast(
            model,
            observed_time_series=observed_time_series[
                ..., :num_observed_steps],
            parameter_samples=parameter_samples,
            num_steps_forecast=num_forecast_steps)

        reference_forecast_mean = self.evaluate(
            reference_forecast_dist.mean()[..., 0])
        reference_forecast_stddev = self.evaluate(
            reference_forecast_dist.stddev()[..., 0])

        self.assertAllClose(predictive_mean[..., -num_forecast_steps:],
                            reference_forecast_mean,
                            atol=1.0 if use_slope else 0.3)
        self.assertAllClose(predictive_stddev[..., -num_forecast_steps:],
                            reference_forecast_stddev,
                            atol=2.0 if use_slope else 1.0)
Code example #16
File: use_mnist_cnn.py  Project: perfmjs/tensorflow
 def _scale_one_loss(l):  # Separate def avoids lambda capture of loop var.
   f = tf.function(lambda: tf.multiply(multiplier, l()))
   _ = f.get_concrete_function()
   return f
Code example #17
def run_ncf_custom_training(params,
                            strategy,
                            keras_model,
                            optimizer,
                            callbacks,
                            train_input_dataset,
                            eval_input_dataset,
                            num_train_steps,
                            num_eval_steps,
                            generate_input_online=True):
    """Runs custom training loop.

  Args:
    params: Dictionary containing training parameters.
    strategy: Distribution strategy to be used for distributed training.
    keras_model: Model used for training.
    optimizer: Optimizer used for training.
    callbacks: Callbacks to be invoked between batches/epochs.
    train_input_dataset: tf.data.Dataset used for training.
    eval_input_dataset: tf.data.Dataset used for evaluation.
    num_train_steps: Total number of steps to run for training.
    num_eval_steps: Total number of steps to run for evaluation.
    generate_input_online: Whether input data was generated by the data
      producer. When data is generated by the data producer, the train dataset
      must be re-initialized after every epoch.

  Returns:
    A tuple of train loss and a list of training and evaluation results.
  """
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction="sum", from_logits=True)
    train_input_iterator = iter(
        strategy.experimental_distribute_dataset(train_input_dataset))

    def train_step(train_iterator):
        """Called once per step to train the model."""
        def step_fn(features):
            """Computes loss and applied gradient per replica."""
            with tf.GradientTape() as tape:
                softmax_logits = keras_model(features)
                # The loss can overflow in float16, so we cast to float32.
                softmax_logits = tf.cast(softmax_logits, "float32")
                labels = features[rconst.TRAIN_LABEL_KEY]
                loss = loss_object(
                    labels,
                    softmax_logits,
                    sample_weight=features[rconst.VALID_POINT_MASK])
                loss *= (1.0 / params["batch_size"])
                if FLAGS.dtype == "fp16":
                    loss = optimizer.get_scaled_loss(loss)

            grads = tape.gradient(loss, keras_model.trainable_variables)
            if FLAGS.dtype == "fp16":
                grads = optimizer.get_unscaled_gradients(grads)
            # Converting gradients to dense form helps in perf on GPU for NCF
            grads = neumf_model.sparse_to_dense_grads(
                list(zip(grads, keras_model.trainable_variables)))
            optimizer.apply_gradients(grads)
            return loss

        per_replica_losses = strategy.run(step_fn,
                                          args=(next(train_iterator), ))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                    per_replica_losses,
                                    axis=None)
        return mean_loss

    def eval_step(eval_iterator):
        """Called once per eval step to compute eval metrics."""
        def step_fn(features):
            """Computes eval metrics per replica."""
            softmax_logits = keras_model(features)
            in_top_k, metric_weights = metric_fn(
                softmax_logits, features[rconst.DUPLICATE_MASK], params)
            hr_sum = tf.reduce_sum(in_top_k * metric_weights)
            hr_count = tf.reduce_sum(metric_weights)
            return hr_sum, hr_count

        per_replica_hr_sum, per_replica_hr_count = (strategy.run(
            step_fn, args=(next(eval_iterator), )))
        hr_sum = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 per_replica_hr_sum,
                                 axis=None)
        hr_count = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                   per_replica_hr_count,
                                   axis=None)
        return hr_sum, hr_count

    if not FLAGS.run_eagerly:
        train_step = tf.function(train_step)
        eval_step = tf.function(eval_step)

    for callback in callbacks:
        callback.on_train_begin()

    # Not writing tensorboard summaries if running in MLPerf.
    if FLAGS.ml_perf:
        eval_summary_writer, train_summary_writer = None, None
    else:
        summary_dir = os.path.join(FLAGS.model_dir, "summaries")
        eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, "eval"))
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, "train"))

    train_loss = 0
    for epoch in range(FLAGS.train_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)

        # As the NCF dataset is sampled with randomness, not repeating
        # data elements in each epoch has a significant impact on
        # convergence. Therefore, offline-generated TF record files
        # contain all epochs' worth of data, so we do not need to
        # re-initialize the dataset when reading from tf record files.
        if generate_input_online:
            train_input_iterator = iter(
                strategy.experimental_distribute_dataset(train_input_dataset))

        train_loss = 0
        for step in range(num_train_steps):
            current_step = step + epoch * num_train_steps
            for c in callbacks:
                c.on_batch_begin(current_step)

            train_loss += train_step(train_input_iterator)

            # Write train loss once in every 1000 steps.
            if train_summary_writer and step % 1000 == 0:
                with train_summary_writer.as_default():
                    tf.summary.scalar("training_loss",
                                      train_loss / (step + 1),
                                      step=current_step)

            for c in callbacks:
                c.on_batch_end(current_step)

        train_loss /= num_train_steps
        logging.info("Done training epoch %s, epoch loss=%s.", epoch + 1,
                     train_loss)

        eval_input_iterator = iter(
            strategy.experimental_distribute_dataset(eval_input_dataset))
        hr_sum = 0
        hr_count = 0
        for _ in range(num_eval_steps):
            step_hr_sum, step_hr_count = eval_step(eval_input_iterator)
            hr_sum += step_hr_sum
            hr_count += step_hr_count

        logging.info("Done eval epoch %s, hit_rate=%s.", epoch + 1,
                     hr_sum / hr_count)
        if eval_summary_writer:
            with eval_summary_writer.as_default():
                tf.summary.scalar("hit_rate",
                                  hr_sum / hr_count,
                                  step=current_step)

        if (FLAGS.early_stopping
                and float(hr_sum / hr_count) > params["hr_threshold"]):
            break

    for c in callbacks:
        c.on_train_end()

    # Saving the model at the end of training.
    if not FLAGS.ml_perf:
        checkpoint = tf.train.Checkpoint(model=keras_model,
                                         optimizer=optimizer)
        checkpoint_path = os.path.join(FLAGS.model_dir, "ctl_checkpoint")
        checkpoint.save(checkpoint_path)
        logging.info("Saving model as TF checkpoint: %s", checkpoint_path)

    return train_loss, [None, hr_sum / hr_count]
Code example #18
def trace(
    state: State,
    fn: TransitionOperator,
    num_steps: IntTensor,
    trace_fn: Callable[[State, TensorNest], TensorNest],
    parallel_iterations: int = 10,
) -> Tuple[State, TensorNest]:
    """`TransitionOperator` that runs `fn` repeatedly and traces its outputs.

  Args:
    state: A nest of `Tensor`s or None.
    fn: A `TransitionOperator`.
    num_steps: Number of steps to run the function for. Must be greater than 1.
    trace_fn: Callable that takes the updated state and the extra outputs of
      `fn` and returns a nest of `Tensor`s. These will be stacked and returned.
    parallel_iterations: Number of iterations of the while loop to run in
      parallel.

  Returns:
    state: The final state returned by `fn`.
    traces: Stacked outputs of `trace_fn`.
  """
    state = tf.nest.map_structure(
        lambda t: t if t is None else tf.convert_to_tensor(t), state)

    def wrapper(state):
        state, extra = tf.nest.map_structure(tf.convert_to_tensor,
                                             call_fn(fn, state))
        trace_element = tf.nest.map_structure(tf.convert_to_tensor,
                                              trace_fn(state, extra))
        return state, trace_element

    if any(e is None
           for e in tf.nest.flatten(state)) or tf.executing_eagerly():
        state, first_trace = wrapper(state)
        trace_arrays = tf.nest.map_structure(
            lambda v: tf.TensorArray(  # pylint: disable=g-long-lambda
                v.dtype,
                size=num_steps,
                element_shape=v.shape).write(0, v),
            first_trace)
        start_idx = 1
    else:
        state_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
        # We need the shapes and dtypes of the outputs of `wrapper` function to
        # create the `TensorArray`s, we can get it by pre-compiling the wrapper
        # function.
        wrapper = tf.function(autograph=False)(wrapper)
        concrete_wrapper = wrapper.get_concrete_function(state_spec)
        _, trace_dtypes = concrete_wrapper.output_dtypes
        _, trace_shapes = concrete_wrapper.output_shapes
        trace_arrays = tf.nest.map_structure(
            lambda dtype, shape: tf.TensorArray(  # pylint: disable=g-long-lambda
                dtype,
                size=num_steps,
                element_shape=shape),
            trace_dtypes,
            trace_shapes)
        wrapper = lambda state: concrete_wrapper(*tf.nest.flatten(state))
        start_idx = 0

    def body(i, state, trace_arrays):
        state, trace_element = wrapper(state)
        trace_arrays = tf.nest.map_structure(lambda a, v: a.write(i, v),
                                             trace_arrays, trace_element)
        return i + 1, state, trace_arrays

    def cond(i, *_):
        return i < num_steps

    _, state, trace_arrays = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=(start_idx, state, trace_arrays),
        parallel_iterations=parallel_iterations)

    stacked_trace = tf.nest.map_structure(lambda x: x.stack(), trace_arrays)

    static_length = tf.get_static_value(num_steps)

    def _merge_static_length(x):
        x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
        return x

    stacked_trace = tf.nest.map_structure(_merge_static_length, stacked_trace)

    return state, stacked_trace
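A usage sketch with a trivial transition operator; this assumes the FunMC convention that `fn` returns a `(new_state, extra)` pair:

def count_up(x):
    return x + 1., x + 1.  # (new_state, extra)

final_state, cumulative = trace(
    state=tf.zeros([]), fn=count_up, num_steps=5,
    trace_fn=lambda state, extra: extra)
# final_state == 5.0; cumulative == [1., 2., 3., 4., 5.]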
Code example #19
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tf2xla_pb2

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

def f(a, b):
  return jax.lax.add(a, b).sum()

f = tf.function(jax2tf.convert(f))
a = b = tf.ones([1, 1])
cf = f.get_concrete_function(a, b)

graph_def = cf.graph.as_graph_def()
with open('graph.pb', 'wb') as fp:
  fp.write(graph_def.SerializeToString())

config = tf2xla_pb2.Config()
batch_size = 1

feeds = [o.name for o in cf.graph.get_operations() if o.name.startswith('jax2tf_arg')]
fetches = [o.name for o in cf.graph.get_operations() if o.name.startswith('jax2tf_out')]

for idx, x in enumerate(cf.inputs):
  x.set_shape([batch_size] + list(x.shape)[1:])
  feed = config.feed.add()
  feed.id.node_name = feeds[idx]
  feed.shape.MergeFrom(x.shape.as_proto())
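A sketch of the matching fetch entries and config serialization; this assumes the tf2xla Config proto (tensorflow/compiler/tf2xla/tf2xla.proto), where fetches carry only a TensorId:

for name in fetches:
  fetch = config.fetch.add()
  fetch.id.node_name = name

with open('graph.config.pbtxt', 'w') as fp:
  fp.write(str(config))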
Code example #20
    def testCompositeTensor(self, bijector_name, data):

        bijector, event_dim = self._draw_bijector(
            bijector_name,
            data,
            batch_shape=[],
            validate_args=True,
            allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
                               set(COMPOSITE_TENSOR_IS_BROKEN)))

        # TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector
        # when AutoCT is enabled for meta-bijectors and LinearOperator.
        if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN:
            composite_bij = experimental.as_composite(bijector)
        else:
            composite_bij = bijector

        if not tf.executing_eagerly():
            composite_bij = tf.nest.map_structure(
                lambda x: (
                    tf.convert_to_tensor(x)  # pylint: disable=g-long-lambda
                    if isinstance(x, DeferredTensor) else x),
                composite_bij,
                expand_composites=True)

        self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor)
        flat = tf.nest.flatten(composite_bij, expand_composites=True)
        unflat = tf.nest.pack_sequence_as(composite_bij,
                                          flat,
                                          expand_composites=True)

        # Compare forward maps before and after compositing.
        n = 3
        xs = self._draw_domain_tensor(bijector,
                                      data,
                                      event_dim,
                                      sample_shape=[n])
        before_ys = bijector.forward(xs)
        after_ys = unflat.forward(xs)
        self.assertAllClose(*self.evaluate((before_ys, after_ys)))

        # Compare inverse maps before and after compositing.
        ys = self._draw_codomain_tensor(bijector,
                                        data,
                                        event_dim,
                                        sample_shape=[n])
        before_xs = bijector.inverse(ys)
        after_xs = unflat.inverse(ys)
        self.assertAllClose(*self.evaluate((before_xs, after_xs)))

        # Input to tf.function
        self.assertAllClose(
            before_ys,
            tf.function(lambda b: b.forward(xs))(composite_bij),
            rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
            atol=COMPOSITE_TENSOR_ATOL[bijector_name])

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        wrt_vars = [xs] + [
            v for v in composite_bij.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            tape.watch(wrt_vars)
            # TODO(b/73073515): Fix graph mode gradients with bijector caching.
            ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)
Code example #21
 def test_function(self):
     lop = AutoDiag(2. * tf.ones([3]))
     self.assertAllClose(
         6. * tf.ones([3]),
         tf.function(lambda lop: lop.matvec(3. * tf.ones([3])))(lop))
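`AutoDiag` is a test-local helper, not a public API. A self-contained sketch of the same pattern (a composite value passed straight into a `tf.function` as a single argument), using the public `tf.experimental.ExtensionType`; the `Diag` class here is invented for illustration:

```python
import tensorflow as tf

class Diag(tf.experimental.ExtensionType):
  """Toy composite value: a diagonal operator stored as a vector."""
  diag: tf.Tensor

  def matvec(self, x):
    return self.diag * x

@tf.function
def apply_op(op, x):
  # `op` is decomposed into its component tensors when the function is traced.
  return op.matvec(x)

print(apply_op(Diag(diag=2. * tf.ones([3])), 3. * tf.ones([3])))  # [6. 6. 6.]
```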
Code example #22
File: sequential_test.py  Project: paolodedios/keras
 def __init__(self, name=None):
     super().__init__(name=name)
     self.call = tf.function(self.call)
Code example #23
    def testCompositeTensor(self, bijector_name, data):

        bijector, event_dim = self._draw_bijector(
            bijector_name,
            data,
            batch_shape=[],
            validate_args=True,
            allowed_bijectors=(set(bhps.INSTANTIABLE_BIJECTORS) -
                               set(COMPOSITE_TENSOR_IS_BROKEN)))

        if type(bijector) is invert_lib._Invert:  # pylint: disable=unidiomatic-typecheck
            if isinstance(bijector.bijector, tf.__internal__.CompositeTensor):
                raise TypeError(
                    '`_Invert` should wrap only non-`CompositeTensor` '
                    'bijectors.')
            self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.')

        if not tf.executing_eagerly():
            bijector = tf.nest.map_structure(
                lambda x: (
                    tf.convert_to_tensor(x)  # pylint: disable=g-long-lambda
                    if isinstance(x, DeferredTensor) else x),
                bijector,
                expand_composites=True)

        self.assertIsInstance(bijector, tf.__internal__.CompositeTensor)
        flat = tf.nest.flatten(bijector, expand_composites=True)
        unflat = tf.nest.pack_sequence_as(bijector,
                                          flat,
                                          expand_composites=True)

        # Compare forward maps before and after compositing.
        n = 3
        xs = self._draw_domain_tensor(bijector,
                                      data,
                                      event_dim,
                                      sample_shape=[n])
        before_ys = bijector.forward(xs)
        after_ys = unflat.forward(xs)
        self.assertAllClose(*self.evaluate((before_ys, after_ys)))

        # Compare inverse maps before and after compositing.
        ys = self._draw_codomain_tensor(bijector,
                                        data,
                                        event_dim,
                                        sample_shape=[n])
        before_xs = bijector.inverse(ys)
        after_xs = unflat.inverse(ys)
        self.assertAllClose(*self.evaluate((before_xs, after_xs)))

        # Input to tf.function
        self.assertAllClose(before_ys,
                            tf.function(lambda b: b.forward(xs))(bijector),
                            rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
                            atol=COMPOSITE_TENSOR_ATOL[bijector_name])

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            tape.watch(wrt_vars)
            # TODO(b/73073515): Fix graph mode gradients with bijector caching.
            ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)
Code example #24
 def __init__(self, name=None):
   super(MySequential, self).__init__(name=name)
   self.call = tf.function(self.call)
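Examples #22 and #24 show only the constructor. A minimal runnable sketch of the surrounding class, assuming (as the sequential_test.py file name suggests) it subclasses `tf.keras.Sequential`:

```python
import tensorflow as tf

class MySequential(tf.keras.Sequential):
  """Sequential model whose call method always executes as a tf.function."""

  def __init__(self, name=None):
    super().__init__(name=name)
    self.call = tf.function(self.call)

model = MySequential()
model.add(tf.keras.layers.Dense(4, input_shape=(8,)))
print(model(tf.ones([2, 8])).shape)  # (2, 4)
```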
Code example #25
 def test_returns_none_illegal(self, target, c):
     c = tf.constant(c)
     with self.assertRaisesRegex(ValueError, '"i" is None'):
         tf.function(target)(c)
Code example #26
File: loss_scale_benchmark.py  Project: ttigong/keras
    def _benchmark(self, gradient_type, num_gpus, mode, loss_scaling):
        """Benchmarks loss scaling.

    We run a simple model with several scalar variables. The loss is the sum of
    all variables. The model is simple because we want to measure only the
    performance of loss scaling, not the performance of the model itself.

    Args:
      gradient_type: "optimizer" or "gradient_tape". How gradients are computed.
        "optimizer" uses Optimizer.minimize. "gradient_tape" uses
        GradientTape.gradient along with LossScaleOptimizer.get_scaled_loss and
        LossScaleOptimizer.get_unscaled_gradients.
      num_gpus: The number of GPUs to use. Must be at least 1.
      mode: "eager" or "tf_function". "tf_function" causes all computations to
        be wrapped in a tf.function, while "eager" runs computations eagerly.
      loss_scaling: "fixed", "dynamic", or None. The type of loss scaling to
        use. None means use no loss scaling, which is useful as a baseline to
        see how much slower loss scaling is in comparison.
    """
        ls_str = loss_scaling or 'no_loss_scaling'
        name = '%s_%d_GPU_%s_%s' % (gradient_type, num_gpus, mode, ls_str)
        with tf.__internal__.eager_context.eager_mode(), _get_strategy(
                num_gpus).scope() as strategy:
            opt = adam.Adam()
            if loss_scaling == 'fixed':
                loss_scale = tf.mixed_precision.experimental.FixedLossScale(2.)
            elif loss_scaling == 'dynamic':
                # Make increment_period so high that it's effectively infinite. This
                # means the loss scale will never change. Any performance overhead
                # from increasing/decreasing the loss scale is typically negligible
                # since it happens infrequently, so we only benchmark the common case
                # of the loss scale not changing.
                increment_period = 1000000
                loss_scale = tf.mixed_precision.experimental.DynamicLossScale(
                    initial_loss_scale=2., increment_period=increment_period)
            else:
                assert loss_scaling is None
                loss_scale = None
            if loss_scale:
                opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)

            num_vars = 200
            num_warmup_iters = 1
            num_iters = 20
            # By using scalar variables, we reduce overhead of the actual GPU work of
            # multiplying variables, dividing gradients, and checking gradients for
            # NaNs. Measuring these overheads isn't very useful as there is little we
            # can do to reduce them (one such way would be to fuse dividing gradients
            # and checking them for NaNs). We still have all other overheads, such as
            # all-reducing the `is_finite` values and having a tf.cond or
            # tf.while_loop based on whether gradients are NaNs. Currently, these
            # other overheads are much more significant than the GPU work.
            var_list = [
                tf.Variable(i, dtype='float32') for i in range(num_vars)
            ]

            def get_loss():
                return tf.add_n(var_list)

            if gradient_type == 'gradient_tape':
                if loss_scale is None:

                    def minimize_fn():
                        with tf.GradientTape() as tape:
                            loss = get_loss()
                        grads = tape.gradient(loss, var_list)
                        return opt.apply_gradients(zip(grads, var_list))
                else:

                    def minimize_fn():
                        with tf.GradientTape() as tape:
                            loss = get_loss()
                            scaled_loss = opt.get_scaled_loss(loss)
                        scaled_grads = tape.gradient(scaled_loss, var_list)
                        grads = opt.get_unscaled_gradients(scaled_grads)
                        return opt.apply_gradients(zip(grads, var_list))
            else:
                assert gradient_type == 'optimizer'

                def minimize_fn():
                    return opt.minimize(get_loss, var_list)

            def run_fn():
                strategy.run(minimize_fn)

            if mode == 'tf_function':
                run_fn = tf.function(run_fn)

            for _ in range(num_warmup_iters):
                run_fn()

            start = time.time()
            for _ in range(num_iters):
                run_fn()
            end = time.time()
            self.report_benchmark(iters=num_iters,
                                  wall_time=(end - start) / num_iters,
                                  name=name)
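The benchmark above uses the old `tf.mixed_precision.experimental` loss scales. A minimal sketch of the "gradient_tape" path it describes, written against the current non-experimental API (single variable, no distribution strategy):

```python
import tensorflow as tf

var = tf.Variable(1.0)
opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(), dynamic=True)

@tf.function
def step():
  with tf.GradientTape() as tape:
    loss = var * var
    scaled_loss = opt.get_scaled_loss(loss)          # loss * loss_scale
  scaled_grads = tape.gradient(scaled_loss, [var])
  grads = opt.get_unscaled_gradients(scaled_grads)   # grads / loss_scale
  opt.apply_gradients(zip(grads, [var]))

step()
```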
Code example #27
def train():
  """Trains model and evaluates on relevant downstream tasks."""
  CONFIG.LOGDIR = FLAGS.logdir
  logdir = CONFIG.LOGDIR
  setup_train_dir(logdir)

  # Common code for multigpu and single gpu. Set devices here if you don't
  # want to use all the GPUs on the machine. Default is to use all GPUs.
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    algo = get_algo(CONFIG.TRAINING_ALGO)

    # Setup summary writer.
    summary_writer = tf.summary.create_file_writer(
        os.path.join(logdir, 'train_logs'), flush_millis=10000)

    learning_rate, optimizer, global_step = get_lr_opt_global_step()
    ckpt_manager, _, _ = restore_ckpt(
        logdir=logdir, optimizer=optimizer, **algo.model)

    global_step_value = global_step.numpy()

    # Remember that in eager mode the learning rate variable needs to be
    # updated manually; call lr_fn each iteration to get the current rate.
    lr_fn = get_lr_fn(CONFIG.OPTIMIZER)

    # Setup Dataset Iterators from train and val datasets.
    batch_size_per_replica = CONFIG.TRAIN.BATCH_SIZE
    total_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
    train_ds = create_dataset('train', mode='train',
                              batch_size=total_batch_size,
                              return_iterator=False)
    train_iterator = strategy.make_dataset_iterator(train_ds)

    def train_step(data):
      steps = data['chosen_steps']
      seq_lens = data['seq_lens']
      loss = algo.train_one_iter(data, steps, seq_lens, global_step, optimizer)
      return loss

    # This reduction only affects reporting, not the gradients.
    # pylint: disable=g-long-lambda
    dist_train = lambda it: strategy.reduce(
        tf.distribute.ReduceOp.SUM, strategy.experimental_run(train_step, it),
        axis=None)
    # pylint: enable=g-long-lambda
    if FLAGS.defun:
      dist_train = tf.function(dist_train)

    stopwatch = Stopwatch()

    try:
      while global_step_value < CONFIG.TRAIN.MAX_ITERS:
        with summary_writer.as_default():
          with tf.summary.record_if(
              global_step_value % CONFIG.LOGGING.REPORT_INTERVAL == 0):

            loss = dist_train(train_iterator)

            # Update learning rate based on lr_fn.
            learning_rate.assign(lr_fn(learning_rate, global_step))

            tf.summary.scalar('loss', loss, step=global_step)
            tf.summary.scalar('learning_rate', learning_rate, step=global_step)

            # Save checkpoint.
            if global_step_value % CONFIG.CHECKPOINT.SAVE_INTERVAL == 0:
              ckpt_manager.save()
              logging.info('Checkpoint saved at iter %d.', global_step_value)

            # Update global step.
            global_step_value = global_step.numpy()

            time_per_iter = stopwatch.elapsed()

            tf.summary.scalar(
                'timing/time_per_iter', time_per_iter, step=global_step)

            logging.info('Iter[{}/{}], {:.1f}s/iter, Loss: {:.3f}'.format(
                global_step_value, CONFIG.TRAIN.MAX_ITERS, time_per_iter,
                loss.numpy()))
            # Reset stopwatch after iter is complete.
            stopwatch.reset()

    except KeyboardInterrupt:
      logging.info('Caught keyboard interrupt. Saving model before quitting.')

    finally:
      # Save the final checkpoint.
      ckpt_manager.save()
      logging.info('Checkpoint saved at iter %d', global_step_value)
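`make_dataset_iterator` and `experimental_run` come from an early tf.distribute API that has since been removed. A sketch of the same training step under the current API, reusing `train_ds`, `strategy`, and `train_step` from above:

```python
# Distribute the dataset and iterate explicitly.
dist_ds = strategy.experimental_distribute_dataset(train_ds)
train_iterator = iter(dist_ds)

@tf.function
def dist_train(it):
  per_replica_loss = strategy.run(train_step, args=(next(it),))
  # As above, this reduction only affects reporting, not the gradients.
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
```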
Code example #28
 def test_iterator_two_vars_loop(self):
   with self.assertRaises(RuntimeError):
     tf.function(iterator_two_vars_loop)(self.ds, self.dds)
Code example #29
def skip_if_no_xla(skip_test_fn):
  try:
    tf.function(lambda: tf.constant(0), jit_compile=True)()
  except tf.errors.UnimplementedError as e:
    if 'Could not find compiler' in str(e):
      skip_test_fn('XLA not available')
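Hypothetical usage inside a test case: the helper takes the test's skip function, so it composes with both `unittest` and `tf.test.TestCase`:

```python
import tensorflow as tf

class MyXlaTest(tf.test.TestCase):

  def test_compiled_add(self):
    skip_if_no_xla(self.skipTest)  # Skips cleanly on XLA-less builds.
    f = tf.function(lambda x: x + 1, jit_compile=True)
    self.assertAllEqual(f(tf.constant(1)), 2)
```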
Code example #30
 def test_iterator_next_with_catching_stop_iteration(self):
   with self.assertRaises(tf.errors.OutOfRangeError):
     tf.function(iterator_next_with_catching_stop_iteration)(self.ds, self.dds,
                                                             tf.constant(True))
Code example #31
File: micro_benchmarks.py  Project: MFChunga/poo
 def benchmark_tf_np_tf_function_mlp_inference_batch_1_cpu(self):
     with tf.device('/CPU:0'):
         model = tf_numpy_mlp.MLP()
         x = tfnp.ones(shape=(1, 10)).astype(np.float32)
         self._benchmark_and_report(self._get_name(),
                                    tf.function(lambda: model.inference(x)))
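The `tfnp` alias is presumably TensorFlow's NumPy-compatible API (`tf_numpy_mlp` is project-internal and not reproduced here); a sketch of the imports the benchmark appears to assume:

```python
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tfnp

x = tfnp.ones(shape=(1, 10)).astype(np.float32)  # ND array backed by a tf.Tensor
```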
Code example #32
def _build_sampler_loop_body(model,
                             observed_time_series,
                             is_missing=None,
                             default_pseudo_observations=None,
                             experimental_use_dynamic_cholesky=False,
                             experimental_use_weight_adjustment=False):
    """Builds a Gibbs sampler for the given model and observed data.

  Args:
    model: A `tf.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_sampling`.
    observed_time_series: Float `Tensor` time series of shape `[...,
      num_timesteps]`.
    is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A
      `True` value indicates that the observation for that timestep is missing.
    default_pseudo_observations: Optional scalar float `Tensor` Controls the
      number of pseudo-observations for the prior precision matrix over the
      weights.
    experimental_use_dynamic_cholesky: Optional bool - in case of spike and slab
      sampling, will dynamically select the subset of the design matrix with
      active features to perform the Cholesky decomposition. This may provide
      a speedup when the number of true features is small compared to the size
      of the design matrix.
    experimental_use_weight_adjustment: Optional bool - use a nonstandard
      update for the posterior precision of the weight in case of a spike and
      slab sampler.

  Returns:
    sampler_loop_body: Python callable that performs a single cycle of Gibbs
      sampling. Its first argument is a `GibbsSamplerState`, and it returns a
      new `GibbsSamplerState`. The second argument (passed by `tf.scan`) is
      ignored.
  """
    if JAX_MODE and experimental_use_dynamic_cholesky:
        raise ValueError('Dynamic Cholesky decomposition not supported in JAX')
    level_component = model.components[0]
    if not (isinstance(level_component, sts.LocalLevel)
            or isinstance(level_component, sts.LocalLinearTrend)):
        raise ValueError(
            'Expected the first model component to be an instance of '
            '`tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`; '
            'instead saw {}'.format(level_component))
    model_has_slope = isinstance(level_component, sts.LocalLinearTrend)

    # TODO(kloveless): When we add support for more flexible models, remove
    # this assumption.
    regression_component = (None if len(model.components) != 2 else
                            model.components[1])
    if regression_component:
        if not (isinstance(regression_component, sts.LinearRegression)
                or isinstance(regression_component,
                              SpikeAndSlabSparseLinearRegression)):
            raise ValueError(
                'Expected the second model component to be an instance of '
                '`tfp.sts.LinearRegression` or '
                '`SpikeAndSlabSparseLinearRegression`; '
                'instead saw {}'.format(regression_component))
        model_has_spike_slab_regression = isinstance(
            regression_component, SpikeAndSlabSparseLinearRegression)

    if is_missing is not None:  # Ensure series does not contain NaNs.
        observed_time_series = tf.where(is_missing,
                                        tf.zeros_like(observed_time_series),
                                        observed_time_series)
    num_observed_steps = prefer_static.shape(observed_time_series)[-1]

    design_matrix = _get_design_matrix(model)
    num_missing = 0.
    if design_matrix is not None:
        design_matrix = design_matrix.to_dense()[:num_observed_steps]
        if is_missing is None:
            num_missing = 0.
            is_missing = tf.zeros(num_observed_steps, dtype=bool)
        else:
            # Replace design matrix with zeros at unobserved timesteps. This ensures
            # they will not affect the posterior on weights.
            design_matrix = tf.where(is_missing[..., tf.newaxis],
                                     tf.zeros_like(design_matrix),
                                     design_matrix)
            num_missing = tf.reduce_sum(tf.cast(is_missing,
                                                design_matrix.dtype),
                                        axis=-1)

    # Untransform scale priors -> variance priors by reaching through the
    # Sqrt bijector.
    observation_noise_param = model.parameters[0]
    if 'observation_noise' not in observation_noise_param.name:
        raise ValueError(
            'Model parameters {} do not match the expected sampler '
            'state.'.format(model.parameters))
    observation_noise_variance_prior = observation_noise_param.prior.distribution
    if model_has_slope:
        level_scale_variance_prior, slope_scale_variance_prior = [
            p.prior.distribution for p in level_component.parameters
        ]
    else:
        level_scale_variance_prior = (
            level_component.parameters[0].prior.distribution)

    if regression_component:
        if model_has_spike_slab_regression:
            if experimental_use_dynamic_cholesky:
                sampler = dynamic_spike_and_slab.DynamicSpikeSlabSampler
            else:
                sampler = spike_and_slab.SpikeSlabSampler
            spike_and_slab_sampler = sampler(
                design_matrix,
                weights_prior_precision=(
                    regression_component._weights_prior_precision),  # pylint: disable=protected-access
                nonzero_prior_prob=(
                    regression_component._sparse_weights_nonzero_prob),  # pylint: disable=protected-access
                observation_noise_variance_prior_concentration=(
                    observation_noise_variance_prior.concentration),
                observation_noise_variance_prior_scale=(
                    observation_noise_variance_prior.scale),
                observation_noise_variance_upper_bound=(
                    # The given bound is for the scale, so it must be squared
                    # to get the upper bound for the variance.
                    tf.math.square(
                        observation_noise_variance_prior.upper_bound)
                    if hasattr(observation_noise_variance_prior, 'upper_bound')
                    else None),
                num_missing=num_missing,
                **({'default_pseudo_observations': default_pseudo_observations}
                   if default_pseudo_observations is not None else {}))
            # In case the nonzero probability is exactly one, any proposal with any
            # zero weights will have log prob of -infinity, so we will pin the
            # proposals to one.
            # TODO(colcarroll): Can we short-circuit the feature selection loop in
            # case this is `True`?
            pin_to_nonzero = tf.greater_equal(
                regression_component._sparse_weights_nonzero_prob, 1.)  # pylint: disable=protected-access

        else:
            weights_prior_scale = (
                regression_component.parameters[0].prior.scale)

    # Sub-selects in `forward_filter_sequential` take up a lot of the runtime
    # with a dynamic Cholesky, but compiling here seems to help.
    # TODO(b/234726324): Should this always be compiled?
    if experimental_use_dynamic_cholesky:
        resample_latents = tf.function(jit_compile=True,
                                       autograph=False)(_resample_latents)
        resample_scale = tf.function(jit_compile=True,
                                     autograph=False)(_resample_scale)
    else:
        resample_latents = _resample_latents
        resample_scale = _resample_scale

    def sampler_loop_body(previous_sample, _):
        """Runs one sampler iteration, resampling all model variables."""

        (weights_seed, level_seed, observation_noise_scale_seed,
         level_scale_seed,
         loop_seed) = samplers.split_seed(previous_sample.seed,
                                          n=5,
                                          salt='sampler_loop_body')
        # Preserve backward-compatible seed behavior by splitting slope separately.
        slope_scale_seed, = samplers.split_seed(previous_sample.seed,
                                                n=1,
                                                salt='sampler_loop_body_slope')

        if regression_component:
            # We encourage a reasonable initialization by sampling the weights first,
            # so at the first step they are regressed directly against the observed
            # time series. If we instead sampled the level first it might 'explain
            # away' some observed variation that we would ultimately prefer to explain
            # through the regression weights, because the level can represent
            # arbitrary variation, while the weights are limited to representing
            # variation in the subspace given by the design matrix.
            if model_has_spike_slab_regression:
                if experimental_use_weight_adjustment:
                    previous_observation_noise_variance = tf.square(
                        previous_sample.observation_noise_scale)
                else:
                    previous_observation_noise_variance = 1.
                targets = tf.where(
                    is_missing, tf.zeros_like(observed_time_series),
                    observed_time_series - previous_sample.level)
                observation_noise_variance, weights = (
                    spike_and_slab_sampler.sample_noise_variance_and_weights(
                        initial_nonzeros=tf.math.logical_or(
                            tf.not_equal(previous_sample.weights, 0.),
                            pin_to_nonzero),
                        previous_observation_noise_variance=(
                            previous_observation_noise_variance),
                        targets=targets,
                        seed=weights_seed))
                observation_noise_scale = tf.sqrt(observation_noise_variance)

            else:
                weights = _resample_weights(
                    design_matrix=design_matrix,
                    target_residuals=(
                        observed_time_series - previous_sample.level),
                    observation_noise_scale=(
                        previous_sample.observation_noise_scale),
                    weights_prior_scale=weights_prior_scale,
                    seed=weights_seed)
                # Noise scale will be resampled below.
                observation_noise_scale = previous_sample.observation_noise_scale

            regression_residuals = observed_time_series - tf.linalg.matvec(
                design_matrix, weights)
        else:
            # If there is no regression, then the entire timeseries is a residual.
            regression_residuals = observed_time_series
            # Noise scale will be resampled below.
            observation_noise_scale = previous_sample.observation_noise_scale
            weights = previous_sample.weights

        latents = resample_latents(
            observed_residuals=regression_residuals,
            level_scale=previous_sample.level_scale,
            slope_scale=(previous_sample.slope_scale
                         if model_has_slope else None),
            observation_noise_scale=observation_noise_scale,
            initial_state_prior=level_component.initial_state_prior,
            is_missing=is_missing,
            seed=level_seed)
        level = latents[..., 0]
        level_residuals = level[..., 1:] - level[..., :-1]
        if model_has_slope:
            slope = latents[..., 1]
            level_residuals -= slope[..., :-1]
            slope_residuals = slope[..., 1:] - slope[..., :-1]

        # Estimate level scale from the empirical changes in level.
        level_scale = resample_scale(prior=level_scale_variance_prior,
                                     observed_residuals=level_residuals,
                                     is_missing=None,
                                     seed=level_scale_seed)
        if model_has_slope:
            slope_scale = resample_scale(prior=slope_scale_variance_prior,
                                         observed_residuals=slope_residuals,
                                         is_missing=None,
                                         seed=slope_scale_seed)
        if not (regression_component and model_has_spike_slab_regression):
            # Estimate noise scale from the residuals.
            observation_noise_scale = resample_scale(
                prior=observation_noise_variance_prior,
                observed_residuals=regression_residuals - level,
                is_missing=is_missing,
                seed=observation_noise_scale_seed)

        return GibbsSamplerState(
            observation_noise_scale=observation_noise_scale,
            level_scale=level_scale,
            slope_scale=(slope_scale
                         if model_has_slope else previous_sample.slope_scale),
            weights=weights,
            level=level,
            slope=(slope if model_has_slope else previous_sample.slope),
            seed=loop_seed)

    return sampler_loop_body
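As the docstring notes, the returned callable is shaped for `tf.scan`. A hypothetical driver, where `num_results` and `initial_state` (a `GibbsSamplerState` whose `seed` field holds a stateless seed) are assumed to exist:

```python
sampler_loop_body = _build_sampler_loop_body(model, observed_time_series)
# Each scan step consumes the previous GibbsSamplerState and ignores the
# element from `elems`, matching sampler_loop_body's (state, _) signature.
samples = tf.scan(sampler_loop_body,
                  elems=tf.range(num_results),
                  initializer=initial_state)
```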
Code example #33
 def skip_if_no_xla(self):
     try:
         tf.function(lambda: tf.constant(0), experimental_compile=True)()
     except (tf.errors.UnimplementedError, NotImplementedError) as e:
         if 'Could not find compiler' in str(e):
             self.skipTest('XLA not available')
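`experimental_compile` is the older spelling of the flag; newer TF releases use `jit_compile`, as in code example #29 above:

```python
tf.function(lambda: tf.constant(0), jit_compile=True)()
```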
Code example #34
  def __init__(self,
               parameter_prior,
               parameterized_initial_state_prior_fn,
               parameterized_transition_fn,
               parameterized_observation_fn,
               parameterized_initial_state_proposal_fn=None,
               parameterized_proposal_fn=None,
               parameter_constraining_bijector=None,
               name=None):
    """Builds an iterated filter for parameter estimation in sequential models.

    Iterated filtering is a parameter estimation method in which parameters
    are included in an augmented state space, with dynamics that introduce
    parameter perturbations, and a filtering
    algorithm such as particle filtering is run several times with perturbations
    of decreasing size. This class implements the IF2 algorithm of
    [Ionides et al., 2015][1], for which, under appropriate conditions
    (including a uniform prior), the final parameter distribution approaches a
    point mass at the maximum likelihood estimate. If a non-uniform prior is
    provided, the final parameter distribution will (under appropriate
    conditions) approach a point mass at the maximum a posteriori (MAP) value.

    This class augments the state space of a sequential model to include
    parameter perturbations, and provides utilities to run particle filtering
    on that augmented model. Alternately, the augmented components may be passed
    directly into a filtering algorithm of the user's choice.

    Args:
      parameter_prior: prior `tfd.Distribution` over parameters (may be a joint
        distribution).
      parameterized_initial_state_prior_fn: `callable` with signature
        `initial_state_prior = parameterized_initial_state_prior_fn(parameters)`
        where `parameters` has the form of a sample from `parameter_prior`,
        and `initial_state_prior` is a distribution over the initial state.
      parameterized_transition_fn: `callable` with signature
        `next_state_dist = parameterized_transition_fn(
        step, state, parameters, **kwargs)`.
      parameterized_observation_fn: `callable` with signature
        `observation_dist = parameterized_observation_fn(
        step, state, parameters, **kwargs)`.
      parameterized_initial_state_proposal_fn: optional `callable` with
        signature `initial_state_proposal =
        parameterized_initial_state_proposal_fn(parameters)` where `parameters`
        has the form of a sample from `parameter_prior`, and
        `initial_state_proposal` is a distribution over the initial state.
      parameterized_proposal_fn: optional `callable` with signature
        `next_state_dist = parameterized_transition_fn(
        step, state, parameters, **kwargs)`.
        Default value: `None`.
      parameter_constraining_bijector: optional `tfb.Bijector` instance
        such that `parameter_constraining_bijector.forward(x)` returns valid
        parameters for any real-valued `x` of the same structure and shape
        as `parameters`. If `None`, the default bijector of the provided
        `parameter_prior` will be used.
        Default value: `None`.
      name: `str` name for ops constructed by this object.
        Default value: `IteratedFilter`.

    #### Example

    We'll walk through applying iterated filtering to a toy
    Susceptible-Infected-Recovered (SIR) model, a [compartmental model](
    https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#The_SIR_model)
    of infectious disease. Note that the model we use here is extremely
    simplified and is intended as a pedagogical example; it should not be
    interpreted to describe disease spread in the real world.

    We begin by specifying a prior distribution over the parameters to be
    inferred, thus defining the structure of the parameter space and the support
    of the parameters (which will imply a default constraining bijector). Here
    we'll use uniform priors over ranges that we expect to contain the
    parameters:

    ```python
    parameter_prior = tfd.JointDistributionNamed({
        'infection_rate': tfd.Uniform(low=0., high=3.),
        'recovery_rate': tfd.Uniform(low=0., high=3.),
    })
    ```

    The model specification itself is identical to that used by
    `tfp.experimental.mcmc.infer_trajectories`, except that each component
    accepts an additional `parameters` keyword argument. We start by specifying
    a parameterized prior on initial states. In this case, our state
    includes the current number of susceptible and infected individuals
    (the third compartment, recovered individuals, is implicitly defined
    to include the remaining population). We'll also include, as auxiliary
    variables, the daily counts of new infections and new recoveries; these
    will help ensure that people shift consistently across compartments.

    ```python
    population_size = 1000
    initial_state_prior_fn = lambda parameters: tfd.JointDistributionNamed({
        'new_infections': tfd.Poisson(parameters['infection_rate']),
        'new_recoveries': tfd.Deterministic(
            tf.broadcast_to(0., tf.shape(parameters['recovery_rate']))),
        'susceptible': (lambda new_infections:
                        tfd.Deterministic(population_size - new_infections)),
        'infected': (lambda new_infections:
                     tfd.Deterministic(new_infections))})
    ```

    **Note**: the state prior must have the same batch shape as the
    passed-in parameters; equivalently, it must sample a full state for each
    parameter particle. If any part of the state prior does not depend
    on the parameters, you must manually ensure that it has the appropriate
    batch shape. For example, in the definition of `new_recoveries` above,
    applying `broadcast_to` with the shape of a parameter ensures that
    the batch shape is maintained.

    Next, we specify a transition model. This takes the state at the
    previous day, along with parameters, and returns a distribution
    over the state for the current day.

    ```python
    def parameterized_infection_dynamics(_, previous_state, parameters):
      new_infections = tfd.Poisson(
          parameters['infection_rate'] * previous_state['infected'] *
          previous_state['susceptible'] / population_size)
      new_recoveries = tfd.Poisson(
          previous_state['infected'] * parameters['recovery_rate'])
      return tfd.JointDistributionNamed({
          'new_infections': new_infections,
          'new_recoveries': new_recoveries,
          'susceptible': lambda new_infections: tfd.Deterministic(
            tf.maximum(0., previous_state['susceptible'] - new_infections)),
          'infected': lambda new_infections, new_recoveries: tfd.Deterministic(
            tf.maximum(0.,
                       (previous_state['infected'] +
                        new_infections - new_recoveries)))})
    ```

    Finally, assume that every day we get to observe noisy counts of new
    infections and recoveries.

    ```python
    def parameterized_infection_observations(_, state, parameters):
      del parameters  # Not used.
      return tfd.JointDistributionNamed({
          'new_infections': tfd.Poisson(state['new_infections'] + 0.1),
          'new_recoveries': tfd.Poisson(state['new_recoveries'] + 0.1)})
    ```

    Combining these components, an `IteratedFilter` augments
    the state space to include parameters that may change over time.

    ```python
    iterated_filter = tfp.experimental.sequential.IteratedFilter(
      parameter_prior=parameter_prior,
      parameterized_initial_state_prior_fn=initial_state_prior_fn,
      parameterized_transition_fn=parameterized_infection_dynamics,
      parameterized_observation_fn=parameterized_infection_observations)
    ```

    We may then run the filter to estimate parameters from a series
    of observations:

    ```python
     # Simulated with `infection_rate=1.2` and `recovery_rate=0.1`.
     observed_values = {
       'new_infections': tf.convert_to_tensor([
          2., 7., 14., 24., 45., 93., 160., 228., 252., 158.,  17.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
       'new_recoveries': tf.convert_to_tensor([
          0., 0., 3., 4., 3., 8., 12., 31., 49., 73., 85., 65., 71.,
          58., 42., 65., 36., 31., 32., 27., 31., 20., 19., 19., 14., 27.])
     }
     parameter_particles = iterated_filter.estimate_parameters(
         observations=observed_values,
         num_iterations=20,
         num_particles=4096,
         initial_perturbation_scale=1.0,
         cooling_schedule=(
             tfp.experimental.sequential.geometric_cooling_schedule(
                 0.001, k=20)),
         seed=test_util.test_seed())
     print('Mean of parameter particles from final iteration: {}'.format(
       tf.nest.map_structure(lambda x: tf.reduce_mean(x[-1], axis=0),
                             parameter_particles)))
     print('Standard deviation of parameter particles from '
           'final iteration: {}'.format(
           tf.nest.map_structure(lambda x: tf.math.reduce_std(x[-1], axis=0),
                                 parameter_particles)))
    ```

    For more control, we could alternately choose to run filtering iterations
    on the augmented model manually, using the filter of our choice.
    For example, manually invoking `infer_trajectories` would allow us
    to inspect the parameter and state values at all timesteps, and their
    corresponding log-probabilities:

    ```python
    trajectories, lps = tfp.experimental.mcmc.infer_trajectories(
      observations=observations,
      initial_state_prior=iterated_filter.joint_initial_state_prior,
      transition_fn=functools.partial(
          iterated_filter.joint_transition_fn,
          perturbation_scale=perturbation_scale),
      observation_fn=iterated_filter.joint_observation_fn,
      proposal_fn=iterated_filter.joint_proposal_fn,
      initial_state_proposal=iterated_filter.joint_initial_state_proposal(
          initial_unconstrained_parameters),
      num_particles=4096)
    ```

    #### References:

    [1] Edward L. Ionides, Dao Nguyen, Yves Atchade, Stilian Stoev, and Aaron A.
    King. Inference for dynamic and latent variable models via iterated,
    perturbed Bayes maps. _Proceedings of the National Academy of Sciences_
    112, no. 3: 719-724, 2015.
    https://www.pnas.org/content/pnas/112/3/719.full.pdf
    """
    name = name or 'IteratedFilter'
    with tf.name_scope(name):
      self._parameter_prior = parameter_prior
      self._parameterized_initial_state_prior_fn = (
          parameterized_initial_state_prior_fn)

      if parameter_constraining_bijector is None:
        parameter_constraining_bijector = (
            parameter_prior.experimental_default_event_space_bijector())
      self._parameter_constraining_bijector = parameter_constraining_bijector

      # Augment the prior to include both parameters and states.
      self._joint_initial_state_prior = joint_prior_on_parameters_and_state(
          parameter_prior,
          parameterized_initial_state_prior_fn,
          parameter_constraining_bijector,
          prior_is_constrained=True)

      # Check that prior samples have a consistent number of particles.
      # TODO(davmre): remove the need for dummy shape dependencies,
      # and this check, by using `JointDistributionNamedAutoBatched` with
      # auto-vectorization enabled in `joint_prior_on_parameters_and_state`.
      num_particles_canary = 13
      prior_static_sample_shapes = tf.function(
          lambda: self._joint_initial_state_prior.sample(num_particles_canary),
          autograph=False).get_concrete_function().output_shapes
      if not all([s[:1].is_compatible_with([num_particles_canary])
                  for s in tf.nest.flatten(prior_static_sample_shapes)]):
        raise ValueError('The specified prior does not generate consistent '
                         'shapes when sampled. Please verify that all parts of '
                         '`initial_state_prior_fn` have batch shape matching '
                         'that of the parameters. This may require creating '
                         '"dummy" dependencies on parameters; for example: '
                         '`tf.broadcast_to(value, tf.shape(parameter))`. (In a '
                         'test sample with {} particles, we expected all '
                         'values to have shape compatible with [{}, ...]; '
                         'saw shapes {}.)'.format(num_particles_canary,
                                                  num_particles_canary,
                                                  prior_static_sample_shapes))

      # Augment the transition and observation fns to cover both
      # parameters and states.
      self._joint_transition_fn = augment_transition_fn_with_parameters(
          parameter_prior,
          parameterized_transition_fn,
          parameter_constraining_bijector)
      self._joint_observation_fn = augment_observation_fn_with_parameters(
          parameterized_observation_fn,
          parameter_constraining_bijector)

      # If given a proposal for the initial state, augment it into a joint
      # proposal over parameters and states.
      joint_initial_state_proposal = None
      if parameterized_initial_state_proposal_fn:
        joint_initial_state_proposal = joint_prior_on_parameters_and_state(
            parameter_prior,
            parameterized_initial_state_proposal_fn,
            parameter_constraining_bijector)
      else:
        parameterized_initial_state_proposal_fn = (
            parameterized_initial_state_prior_fn)
      self._joint_initial_state_proposal = joint_initial_state_proposal
      self._parameterized_initial_state_proposal_fn = (
          parameterized_initial_state_proposal_fn)

      # If given a conditional proposal fn (for non-initial states), augment
      # it to be joint over states and parameters.
      self._joint_proposal_fn = None
      if parameterized_proposal_fn:
        self._joint_proposal_fn = augment_transition_fn_with_parameters(
            parameter_prior,
            parameterized_proposal_fn,
            parameter_constraining_bijector)

      self._batch_ndims = tf.nest.map_structure(
          prefer_static.rank_from_shape,
          parameter_prior.batch_shape_tensor())
      self._name = name