Exemplo n.º 1
0
  def _build(self, y_pred, y_true):
    """One-time setup of metric objects.

    Normalizes `self._metrics` and `self._weighted_metrics` against the
    structure of `y_pred`, converts their entries into `Metric` objects via
    `self._get_metric_objects`, flattens them up to the model outputs, and
    finally names and orders them. Sets `self._built` to True when done.

    Args:
      y_pred: Model predictions. Its nested structure defines the "outputs"
        that metrics are broadcast to, mapped over and flattened up to.
      y_true: Targets. Forwarded together with `y_pred` to
        `self._get_metric_objects`.
    """
    # Base-class setup; only the predictions are passed up.
    super(MetricsContainer, self)._build(y_pred)

    # Broadcast / conform the plain metrics to the structure of `y_pred`
    # (exact broadcasting rules live in the helpers, not shown here).
    self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
    self._metrics = self._conform_to_outputs(y_pred, self._metrics)

    # Same normalization for the weighted metrics.
    self._weighted_metrics = self._maybe_broadcast_to_outputs(
        y_pred, self._weighted_metrics)
    self._weighted_metrics = self._conform_to_outputs(y_pred,
                                                      self._weighted_metrics)

    # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
    y_pred = nest.list_to_tuple(y_pred)
    y_true = nest.list_to_tuple(y_true)
    self._metrics = nest.list_to_tuple(self._metrics)
    self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

    # Convert to `Metric` objects, potentially disambiguating based on output
    # properties.
    self._metrics = nest.map_structure_up_to(y_pred, self._get_metric_objects,
                                             self._metrics, y_true, y_pred)
    self._weighted_metrics = nest.map_structure_up_to(y_pred,
                                                      self._get_metric_objects,
                                                      self._weighted_metrics,
                                                      y_true, y_pred)

    # Flatten to one entry per output; `check_types=False` lets the metric
    # structures use different container types than `y_pred`.
    self._metrics = nest.flatten_up_to(y_pred, self._metrics, check_types=False)
    self._weighted_metrics = nest.flatten_up_to(
        y_pred, self._weighted_metrics, check_types=False)

    # Assumes metrics, weighted_metrics have been flattened up to outputs.
    self._set_metric_names()
    self._create_ordered_metrics()
    self._built = True
Exemplo n.º 2
0
    def _build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Resolves pseudo output names if none were given, broadcasts and maps
        `self._metrics` / `self._weighted_metrics` onto the outputs of
        `y_pred`, converts them to `Metric` objects, flattens them per output,
        names them, and caches their flat order in `self._metrics_in_order`.
        Sets `self._built` to True when done.

        Args:
          y_pred: Model predictions; its nested structure defines the outputs.
          y_true: Targets, forwarded with `y_pred` to
            `self._get_metric_objects`.
        """

        if self._output_names is None:
            # Subclass output names like 'output_1' are used for `Metric` names.
            self._output_names = create_pseudo_output_names(y_pred)

        # If a single metric or flat list of metrics, apply to all outputs.
        self._metrics = self._maybe_broadcast(self._metrics, y_pred)
        self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics,
                                                       y_pred)

        # Accept a dict of metrics keyed by output_name when outputs are a flat
        # list.
        self._metrics = map_to_output_names(y_pred, self._output_names,
                                            self._metrics)
        self._weighted_metrics = map_to_output_names(y_pred,
                                                     self._output_names,
                                                     self._weighted_metrics)

        # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
        # pylint: disable=protected-access
        y_pred = nest._list_to_tuple(y_pred)
        y_true = nest._list_to_tuple(y_true)
        self._metrics = nest._list_to_tuple(self._metrics)
        self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics)
        # pylint: enable=protected-access

        # Convert to `Metric` objects, potentially disambiguating based on output
        # properties.
        self._metrics = nest.map_structure_up_to(y_pred,
                                                 self._get_metric_objects,
                                                 self._metrics, y_true, y_pred)
        self._weighted_metrics = nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._weighted_metrics, y_true,
            y_pred)

        # One entry per output; `check_types=False` tolerates container-type
        # differences between `y_pred` and the metric structures.
        self._metrics = nest.flatten_up_to(y_pred,
                                           self._metrics,
                                           check_types=False)
        self._weighted_metrics = nest.flatten_up_to(y_pred,
                                                    self._weighted_metrics,
                                                    check_types=False)

        # Assumes metrics, weighted_metrics have been flattened up to outputs.
        self._set_metric_names()

        # Cache the flat order needed when returning metrics, for backwards compat.
        # Per output: its plain metrics first, then its weighted metrics, with
        # `None` placeholders skipped.
        self._metrics_in_order = []
        for output_metrics, output_weighted_metrics in zip(
                self._metrics, self._weighted_metrics):
            for m in nest.flatten(output_metrics):
                if m is not None:
                    self._metrics_in_order.append(m)
            for wm in nest.flatten(output_weighted_metrics):
                if wm is not None:
                    self._metrics_in_order.append(wm)

        self._built = True
  def testNestFlattenUpTo(self):
    """`flatten_up_to` honors `expand_composites` on composite tensors."""
    shallow = [[TestCompositeTensor(1, 2, 3)], 100,
               {'y': TestCompositeTensor(5, 6)}]
    deep = [[TestCompositeTensor(1, 2, 3)], 100,
            {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}]

    # Expanded: composite components become leaves. The inner composite stays
    # whole because the shallow tree's 'y' composite stops at its components.
    self.assertEqual(
        nest.flatten_up_to(shallow, deep, expand_composites=True),
        [1, 2, 3, 100, TestCompositeTensor(4, 5), 6])

    # Not expanded: each composite tensor is itself a leaf.
    self.assertEqual(
        nest.flatten_up_to(shallow, deep, expand_composites=False),
        [TestCompositeTensor(1, 2, 3), 100,
         TestCompositeTensor(TestCompositeTensor(4, 5), 6)])
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Args:
    function: Concrete function whose `graph.structured_input_signature`
      describes the accepted inputs.
    inputs: Structured inputs to check.
    allow_conversion: If True, inputs matched against a `TensorSpec` are first
      passed through `_try_convert_to_tensor_spec`.

  Returns:
    True if `inputs` has the signature's structure and EVERY flattened input
    is compatible with its signature entry; False otherwise.
  """
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      # Only reject here. Returning the compatibility result directly would
      # accept the call as soon as the first compatible TypeSpec entry was
      # seen, without checking any of the remaining arguments.
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      # Tensors in non-spec positions must be the very same object.
      if id(arg) != id(expected):
        return False
    elif arg != expected:
      return False
  return True
def _call_concrete_function(function, inputs):
    """Calls a restored Function with structured inputs.

    Unlike `function.__call__`, this takes structured inputs and converts
    them to tensors where the signature expects a `TensorSpec`. Composite
    values are expanded into their components before matching.

    Note: this does not check that non-tensor inputs match. That should be
    done beforehand via `_concrete_function_callable_with`.

    Args:
      function: ConcreteFunction to call.
      inputs: Structured inputs compatible with
        `function.graph.structured_input_signature`.

    Returns:
      The result of `function._call_flat`, or None when that result is an
      `ops.Operation`.
    """
    signature = function.graph.structured_input_signature
    flat_args = nest.flatten_up_to(signature,
                                   inputs,
                                   expand_composites=True)
    flat_specs = nest.flatten(signature, expand_composites=True)
    # Only positions whose signature entry is a `TensorSpec` are fed to
    # `_call_flat`, converted using the spec's dtype as a hint.
    tensor_args = [
        ops.convert_to_tensor(arg, dtype_hint=spec.dtype)
        for arg, spec in zip(flat_args, flat_specs)
        if isinstance(spec, tensor_spec.TensorSpec)
    ]
    outputs = function._call_flat(tensor_args, function._captured_inputs)  # pylint: disable=protected-access
    return None if isinstance(outputs, ops.Operation) else outputs
  def batch_shape_tensor(self, name='batch_shape_tensor'):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Args:
      name: name to give to the op

    Returns:
      batch_shape: `Tensor` of dtype int32. If `self.batch_shape` is not a
        `tf.TensorShape`, the result is a structure of such `Tensor`s shaped
        like `self.dtype`.
    """
    with self._name_and_control_scope(name):
      # Joint distributions may carry a structured batch shape, one per
      # component. A plain `tf.TensorShape` batch shape means it is not
      # structured, so `None` is used as the shallow structure.
      if isinstance(self.batch_shape, tf.TensorShape):
        shallow_structure = None
      else:
        shallow_structure = self.dtype

      flat_static_shapes = nest.flatten_up_to(
          shallow_structure, self.batch_shape, check_types=False)
      defined_flags = [tensorshape_util.is_fully_defined(s)
                       for s in flat_static_shapes]
      if all(defined_flags):
        # Every static shape is known: use them directly as Python lists.
        batch_shape = nest.map_structure_up_to(
            shallow_structure,
            tensorshape_util.as_list,
            self.batch_shape, check_types=False)
      else:
        # Otherwise fall back to the dynamically computed batch shape.
        batch_shape = self._batch_shape_tensor()

      def _as_int32_tensor(s):
        return tf.identity(
            tf.convert_to_tensor(s, dtype=tf.int32), name='batch_shape')

      return nest.map_structure_up_to(
          shallow_structure, _as_int32_tensor, batch_shape, check_types=False)
Exemplo n.º 7
0
    def call(self, observation, step_type=None, network_state=()):
        """Runs the encoder's preprocessing and postprocessing stack.

        Args:
          observation: Nested observation structured like
            `self.input_tensor_spec`. When `self._batch_squash` is set, its
            outer dims (as reported by `nest_utils.get_outer_rank`) are
            squashed before processing and restored afterwards.
          step_type: Unused.
          network_state: Returned unchanged.

        Returns:
          A `(states, network_state)` tuple.
        """
        del step_type  # unused.

        if self._batch_squash:
            squash = utils.BatchSquash(
                nest_utils.get_outer_rank(observation,
                                          self.input_tensor_spec))
            observation = tf.nest.map_structure(squash.flatten, observation)

        if self._preprocessing_layers is None:
            states = observation
        else:
            # One preprocessing layer per flattened observation component.
            flat_observations = nest.flatten_up_to(self.input_tensor_spec,
                                                   observation)
            states = [
                preprocess(obs) for obs, preprocess in zip(
                    flat_observations, self._preprocessing_layers)
            ]

        if self._preprocessing_combiner is not None:
            states = self._preprocessing_combiner(states)

        for postprocess in self._postprocessing_layers:
            states = postprocess(states)

        if self._batch_squash:
            states = tf.nest.map_structure(squash.unflatten, states)

        return states, network_state
def _concrete_function_callable_with(function, inputs, allow_conversion):
    """Returns whether concrete `function` can be called with `inputs`.

    Args:
      function: Concrete function whose `graph.structured_input_signature`
        describes the accepted inputs.
      inputs: Structured inputs to check.
      allow_conversion: If True, inputs matched against a `TensorSpec` are
        first passed through `_try_convert_to_tensor_spec`.

    Returns:
      True if every flattened input is compatible with its signature entry,
      False otherwise (including when `inputs` cannot be flattened up to the
      signature).
    """
    expected_structure = function.graph.structured_input_signature
    try:
        flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
    except (TypeError, ValueError):
        return False

    def _is_compatible(arg, expected):
        """Checks one flattened input against its signature entry."""
        if isinstance(expected, tensor_spec.TensorSpec):
            if allow_conversion:
                arg = _try_convert_to_tensor_spec(arg,
                                                  dtype_hint=expected.dtype)
            if not (_is_tensor(arg)
                    or isinstance(arg, tensor_spec.TensorSpec)):
                return False
            if arg.dtype != expected.dtype:
                return False
            return expected.shape.is_compatible_with(arg.shape)
        if isinstance(expected, type_spec.TypeSpec):
            return expected.is_compatible_with(arg)
        if _is_tensor(arg):
            # Tensors in non-spec positions must be the very same object.
            return id(arg) == id(expected)
        return not (arg != expected)

    flat_expected = nest.flatten(expected_structure)
    return all(
        _is_compatible(arg, expected)
        for arg, expected in zip(flatten_inputs, flat_expected))
Exemplo n.º 9
0
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Args:
    function: Concrete function whose `graph.structured_input_signature`
      describes the accepted inputs.
    inputs: Structured inputs to check.
    allow_conversion: If True, inputs matched against a `TensorSpec` are first
      passed through `_try_convert_to_tensor_spec`.

  Returns:
    True if `inputs` has the signature's structure and every flattened input
    is compatible with its signature entry; False otherwise.
  """
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    nest.assert_same_structure(inputs, repacked)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif _is_tensor(arg):
      # A tensor in a non-spec signature position can only match that exact
      # object. Under TF2 tensor-equality semantics `arg != expected` builds
      # an elementwise comparison, and using it in `if` raises instead of
      # returning False.
      if id(arg) != id(expected):
        return False
    elif arg != expected:
      return False
  return True
  def call(self, observation, step_type=None, network_state=(), training=False):
    """Runs the encoder's preprocessing and postprocessing stack.

    Args:
      observation: Nested observation. When `self._batch_squash` is set, its
        outer dims (as reported by `nest_utils.get_outer_rank` against
        `self.input_tensor_spec`) are squashed before processing and restored
        afterwards.
      step_type: Unused.
      network_state: Returned unchanged.
      training: Forwarded to every preprocessing and postprocessing layer.

    Returns:
      A `(states, network_state)` tuple.
    """
    del step_type  # unused.

    if self._batch_squash:
      squash = utils.BatchSquash(
          nest_utils.get_outer_rank(observation, self.input_tensor_spec))
      observation = tf.nest.map_structure(squash.flatten, observation)

    if self._flat_preprocessing_layers is None:
      states = observation
    else:
      flat_observations = nest.flatten_up_to(self._preprocessing_nest,
                                             observation)
      states = [
          preprocess(obs, training=training)
          for obs, preprocess in zip(flat_observations,
                                     self._flat_preprocessing_layers)
      ]
      if len(states) == 1 and self._preprocessing_combiner is None:
        # A single preprocessed observation with no combiner is used as-is
        # rather than wrapped in a one-element list.
        states = states[0]

    if self._preprocessing_combiner is not None:
      states = self._preprocessing_combiner(states)

    for postprocess in self._postprocessing_layers:
      states = postprocess(states, training=training)

    if self._batch_squash:
      states = tf.nest.map_structure(squash.unflatten, states)

    return states, network_state
Exemplo n.º 11
0
    def sample_distributions(self,
                             sample_shape=(),
                             seed=None,
                             value=None,
                             name='sample_distributions',
                             **kwargs):
        """Samples the model, returning per-component distributions and values.

        Args:
          sample_shape: Shape of the samples to draw.
          seed: Seed for sampling.
          value: Optional (possibly partially specified) values to condition
            on; resolved together with `kwargs` via `self._resolve_value`.
          name: Name scope for the ops created here.
          **kwargs: Forwarded to `self._resolve_value`.

        Returns:
          A pair of structures, both unflattened with `self._model_unflatten`:
          presumably the component distributions and their sampled values
          (TODO confirm against `_call_flat_sample_distributions`).

        Raises:
          NotImplementedError: If `self.use_vectorized_map` is set and the
            sample shape may be nonzero-size or `value` may carry extra
            sample dims.
        """
        with self._name_and_control_scope(name):
            value = self._resolve_value(value=value,
                                        allow_partially_specified=True,
                                        **kwargs)
            if value is None:
                value_might_have_sample_dims = False
            else:
                # Double-flatten in case any components have structured events.
                flat_value = nest.flatten_up_to(self._single_sample_ndims,
                                                self._model_flatten(value),
                                                check_types=False)
                value_might_have_sample_dims = _might_have_excess_ndims(
                    flat_value=flat_value,
                    flat_core_ndims=tf.nest.flatten(
                        self._single_sample_ndims))

            # TODO(b/157953455): Return distributions as CompositeTensors once
            # vectorized_map supports this.
            if self.use_vectorized_map and (
                    _might_have_nonzero_size(sample_shape)
                    or value_might_have_sample_dims):
                raise NotImplementedError(
                    '`sample_distributions` with nontrivial '
                    'sample shape is not yet supported '
                    'for autovectorized JointDistributions.')

            flat_dists, flat_values = self._call_flat_sample_distributions(
                sample_shape=sample_shape, seed=seed, value=value)
            return (self._model_unflatten(flat_dists),
                    self._model_unflatten(flat_values))
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Args:
    function: Concrete function whose `graph.structured_input_signature`
      describes the accepted inputs.
    inputs: Structured inputs to check.
    allow_conversion: If True, inputs matched against a `TensorSpec` are first
      passed through `_try_convert_to_tensor_spec`.

  Returns:
    True if `inputs` has the signature's structure (ignoring container types)
    and every flattened input is compatible with its signature entry; False
    otherwise.
  """
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif _is_tensor(arg):
      # A tensor in a non-spec signature position can only match that exact
      # object. Under TF2 tensor-equality semantics `arg != expected` builds
      # an elementwise comparison, and using it in `if` raises instead of
      # returning False.
      if id(arg) != id(expected):
        return False
    elif arg != expected:
      return False
  return True
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  Unlike `function.__call__`, this takes structured inputs and converts them
  to tensors where the signature expects a `TensorSpec`.

  Note: this does not check that non-tensor inputs match. That should be
  done beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
        `function.graph.structured_input_signature`.

  Returns:
    The result of `function._call_flat`, or None when that result is an
    `ops.Operation`.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs)
  # Only positions whose signature entry is a `TensorSpec` are fed to
  # `_call_flat`, converted using the spec's dtype as a hint.
  tensor_args = [
      ops.convert_to_tensor(arg, dtype_hint=spec.dtype)
      for arg, spec in zip(flat_args, nest.flatten(signature))
      if isinstance(spec, tensor_spec.TensorSpec)
  ]
  outputs = function._call_flat(tensor_args)  # pylint: disable=protected-access
  return None if isinstance(outputs, ops.Operation) else outputs
    def _call_execute_model(self,
                            sample_shape,
                            seed,
                            value=None,
                            sample_and_trace_fn=None):
        """Wraps the base `_call_execute_model` with vectorized_map.

        Delegates straight to `JointDistribution._call_execute_model` when
        `self.use_vectorized_map` is off, or when neither a nontrivial
        `sample_shape` nor extra sample dims in `value` is possible. Otherwise
        it makes a single-sample base call rank-polymorphic over `value` (via
        `vectorization_util.make_rank_polymorphic`). If `sample_shape` may be
        nontrivial, that call is further wrapped with
        `vectorization_util.iid_sample`.

        Args:
          sample_shape: Shape of the iid samples to draw.
          seed: Seed forwarded to the model execution.
          value: Optional structure of values. In the vectorized path,
            `None` entries are given no core ndims.
          sample_and_trace_fn: Forwarded unchanged to the base
            implementation.

        Returns:
          Whatever the base `_call_execute_model` returns, batched by the
          vectorization wrappers when they are used.
        """
        # True only if some component of `value` may carry dims beyond the
        # model's single-sample ndims.
        value_might_have_sample_dims = (
            value is not None and _might_have_excess_ndims(
                # Double-flatten in case any components have structured events.
                flat_value=nest.flatten_up_to(self._single_sample_ndims,
                                              self._model_flatten(value),
                                              check_types=False),
                flat_core_ndims=tf.nest.flatten(self._single_sample_ndims)))
        sample_shape_may_be_nontrivial = (
            distribution_util.shape_may_be_nontrivial(sample_shape))

        if not self.use_vectorized_map or not (sample_shape_may_be_nontrivial
                                               or  # pylint: disable=protected-access
                                               value_might_have_sample_dims):
            # No need to auto-vectorize.
            return joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
                self,
                sample_shape=sample_shape,
                seed=seed,
                value=value,
                sample_and_trace_fn=sample_and_trace_fn)

        # Set up for autovectorized sampling. To support the `value` arg, we need to
        # first understand which dims are from the model itself, then wrap
        # `_call_execute_model` to batch over all remaining dims.
        value_core_ndims = None
        if value is not None:
            value_core_ndims = tf.nest.map_structure(
                lambda v, nd: None if v is None else nd,
                value,
                self._model_unflatten(self._single_sample_ndims),
                check_types=False)

        # Single-sample base call (`sample_shape=()`); the polymorphic wrapper
        # batches it over any dims of `v` beyond `value_core_ndims`. The seed
        # gets `None` core ndims.
        vectorized_execute_model_helper = vectorization_util.make_rank_polymorphic(
            lambda v, seed: (  # pylint: disable=g-long-lambda
                joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
                    self,
                    sample_shape=(),
                    seed=seed,
                    value=v,
                    sample_and_trace_fn=sample_and_trace_fn)),
            core_ndims=[value_core_ndims, None],
            validate_args=self.validate_args)
        # Redefine the polymorphic fn to hack around `make_rank_polymorphic`
        # not currently supporting keyword args. This is needed because the
        # `iid_sample` wrapper below expects to pass through a `seed` kwarg.
        vectorized_execute_model = (
            lambda v, seed: vectorized_execute_model_helper(v, seed))  # pylint: disable=unnecessary-lambda

        if sample_shape_may_be_nontrivial:
            vectorized_execute_model = vectorization_util.iid_sample(
                vectorized_execute_model, sample_shape)

        return vectorized_execute_model(value, seed=seed)
Exemplo n.º 15
0
 def structured_event_to_vector(self, x):
   """Converts an event from the wrapped model to this model's event.

   The batch shape is taken from the first flattened component of `x`. It is
   that component's shape without the trailing dims that belong to the first
   component of the wrapped model's event shape.
   """
   leading_leaf = tf.nest.flatten(x)[0]
   leading_event_shape = nest.flatten_up_to(self.model.dtype,
                                            self.model.event_shape)[0]
   num_batch_dims = len(leading_leaf.shape) - len(leading_event_shape)
   return _flatten_and_concat(x, leading_leaf.shape[:num_batch_dims],
                              self.dtype)
Exemplo n.º 16
0
 def _flat_fn(*args):
   """Computes every flattened output of `out_dtype`, then repacks them.

   For output `i`, each input is first passed through `_pbroadcast_input`
   with that output's axis (from `map_out_axes`) and its own in-axis (from
   `map_in_axes`). The result is then fed to `_flat_fn_index(i, ...)`.
   """
   def _output_for(index, out_axis):
     broadcast_arg = functools.partial(_pbroadcast_input, out_axis)
     local_args = nest.map_structure_up_to(args, broadcast_arg, args,
                                           map_in_axes)
     return _flat_fn_index(index, *local_args)

   flat_out_axes = nest.flatten_up_to(out_dtype, map_out_axes)
   per_output = [_output_for(index, out_axis)
                 for index, out_axis in enumerate(flat_out_axes)]
   return tf.nest.pack_sequence_as(out_dtype, per_output)
Exemplo n.º 17
0
  def testNestFlattenUpTo(self, s1, s2, expected, paths,
                          expand_composites=True):
    """`flatten_up_to` and its tuple-path variant both match `expected`."""
    flat = nest.flatten_up_to(s1, s2, expand_composites=expand_composites)
    self.assertEqual(expected, flat)

    flat_with_paths = nest.flatten_with_tuple_paths_up_to(
        s1, s2, expand_composites=expand_composites)
    expected_with_paths = list(zip(paths, expected))
    self.assertEqual(flat_with_paths, expected_with_paths)
Exemplo n.º 18
0
  def testNestFlattenUpTo(self, s1, s2, expected, paths,
                          expand_composites=True):
    """Checks plain and path-annotated `flatten_up_to` against `expected`."""
    options = {'expand_composites': expand_composites}
    self.assertEqual(expected, nest.flatten_up_to(s1, s2, **options))

    actual_pairs = nest.flatten_with_tuple_paths_up_to(s1, s2, **options)
    self.assertEqual(actual_pairs, list(zip(paths, expected)))
    def testNestFlattenUpTo(self):
        """`flatten_up_to` honors `expand_composites` on composite tensors."""
        shallow = [[TestCompositeTensor(1, 2, 3)], 100,
                   {'y': TestCompositeTensor(5, 6)}]
        deep = [[TestCompositeTensor(1, 2, 3)], 100,
                {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}]

        # Expanded: composite components become leaves. The inner composite
        # stays whole because the shallow 'y' composite stops at its
        # components.
        self.assertEqual(
            nest.flatten_up_to(shallow, deep, expand_composites=True),
            [1, 2, 3, 100, TestCompositeTensor(4, 5), 6])

        # Not expanded: each composite tensor is itself a leaf.
        self.assertEqual(
            nest.flatten_up_to(shallow, deep, expand_composites=False),
            [TestCompositeTensor(1, 2, 3), 100,
             TestCompositeTensor(TestCompositeTensor(4, 5), 6)])
Exemplo n.º 20
0
    def build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Normalizes `self._metrics` and `self._weighted_metrics` against the
        structure of `y_pred`, converts them to `Metric` objects via
        `self._get_metric_objects`, flattens them up to the model outputs,
        names them (unless loaded from a serialized model), and orders them.
        Sets `self._built` to True when done.

        Args:
          y_pred: Model predictions; its nested structure defines the outputs.
          y_true: Targets, forwarded with `y_pred` to
            `self._get_metric_objects`.
        """
        # Base-class setup; only the predictions are passed up.
        super(MetricsContainer, self).build(y_pred)

        # Broadcast / conform the plain metrics to the structure of `y_pred`.
        self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
        self._metrics = self._conform_to_outputs(y_pred, self._metrics)

        # Same normalization for the weighted metrics.
        self._weighted_metrics = self._maybe_broadcast_to_outputs(
            y_pred, self._weighted_metrics)
        self._weighted_metrics = self._conform_to_outputs(
            y_pred, self._weighted_metrics)

        # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
        y_pred = nest.list_to_tuple(y_pred)
        y_true = nest.list_to_tuple(y_true)
        self._metrics = nest.list_to_tuple(self._metrics)
        self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

        # Convert to `Metric` objects, potentially disambiguating based on output
        # properties.
        self._metrics = nest.map_structure_up_to(y_pred,
                                                 self._get_metric_objects,
                                                 self._metrics, y_true, y_pred)
        self._weighted_metrics = nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._weighted_metrics, y_true,
            y_pred)

        # One entry per output; `check_types=False` tolerates container-type
        # differences between `y_pred` and the metric structures.
        self._metrics = nest.flatten_up_to(y_pred,
                                           self._metrics,
                                           check_types=False)
        self._weighted_metrics = nest.flatten_up_to(y_pred,
                                                    self._weighted_metrics,
                                                    check_types=False)

        # Assumes metrics, weighted_metrics have been flattened up to outputs.
        #
        # If we are loading a model that has been already serialized, we do not
        # want to re-apply any pre-processing metric renaming steps.
        if not self._from_serialized:
            self._set_metric_names()
        self._create_ordered_metrics()
        self._built = True
Exemplo n.º 21
0
def metropolis_hastings_step(current_state: State,
                             proposed_state: State,
                             energy_change: FloatTensor,
                             log_uniform: FloatTensor = None,
                             seed=None) -> Tuple[State, tf.Tensor, tf.Tensor]:
    """Metropolis-Hastings step.

    Probabilistically chooses between `current_state` and `proposed_state`
    based on `energy_change` so as to preserve detailed balance.

    Energy change is the negative of `log_accept_ratio`. Any `None` leaves
    of `current_state` are first filled in from the matching leaves of
    `proposed_state`.

    Args:
      current_state: Current state.
      proposed_state: Proposed state.
      energy_change: E(proposed_state) - E(previous_state).
      log_uniform: Optional logarithm of a uniformly distributed random sample
        in [0, 1]. It is used to accept/reject the current and proposed
        state. If omitted, one is drawn with `tf.random.uniform` using `seed`.
      seed: For reproducibility.

    Returns:
      new_state: The chosen state.
      is_accepted: Whether the proposed state was accepted.
      log_uniform: The random number that was used to select between the two
        states.
    """
    # Impute the None's in the current state from the proposal.
    current_leaves = tf.nest.flatten(current_state)
    proposed_leaves = nest.flatten_up_to(current_state, proposed_state)
    imputed_leaves = []
    for proposed_leaf, current_leaf in zip(proposed_leaves, current_leaves):
        imputed_leaves.append(
            proposed_leaf if current_leaf is None else current_leaf)
    current_state = tf.nest.pack_sequence_as(current_state, imputed_leaves)

    def _to_tensors(structure):
        return tf.nest.map_structure(tf.convert_to_tensor, structure)

    current_state = _to_tensors(current_state)
    proposed_state = _to_tensors(proposed_state)
    energy_change = tf.convert_to_tensor(value=energy_change)

    log_accept_ratio = -energy_change

    if log_uniform is None:
        uniform_sample = tf.random.uniform(
            shape=tf.shape(input=log_accept_ratio),
            dtype=log_accept_ratio.dtype.base_dtype,
            seed=seed)
        log_uniform = tf.math.log(uniform_sample)
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(is_accepted,
                                  proposed_state,
                                  current_state,
                                  name='choose_next_state')
    return next_state, is_accepted, log_uniform
Exemplo n.º 22
0
def _convert_inputs_to_signature(inputs, input_signature,
                                 flat_input_signature):
    """Converts inputs to pass into a function with an explicit signature.

    Args:
      inputs: Sequence of Python function inputs. Only the first
        `len(input_signature)` entries are matched against the signature.
      input_signature: The structured input signature.
      flat_input_signature: `input_signature` flattened with
        `expand_composites=True`; it is zipped element-wise against the
        flattened inputs.

    Returns:
      A tuple `(inputs, flat_inputs, filtered_flat_inputs)`:
        - `inputs`: repacked into the structure of `input_signature` if any
          value had to be converted to a tensor, otherwise unchanged.
        - `flat_inputs`: `inputs` flattened with `expand_composites=True`.
        - `filtered_flat_inputs`: the entries of `flat_inputs` that are
          `ops.Tensor` or `BaseResourceVariable` instances.

    Raises:
      ValueError: If the input structure does not match `input_signature`, if
        an input paired with a `TensorSpec` cannot be converted to a tensor,
        or if a flattened input is incompatible with its spec.
    """
    def format_error_message(inputs, input_signature):
        # Renders both sequences one item per line for the error messages.
        return ("  inputs: (\n" + "    " +
                ",\n    ".join(str(i) for i in inputs) + ")\n" +
                "  input_signature: (\n" + "    " +
                ",\n    ".join(str(i) for i in input_signature) + ")")

    try:
        # `check_types=False`: lists are converted to tuples for `tf.data`,
        # so the sequence types of inputs and signature may differ.
        flatten_inputs = nest.flatten_up_to(
            input_signature,
            inputs[:len(input_signature)],
            expand_composites=True,
            check_types=False)
    except ValueError as e:
        raise ValueError("Structure of Python function inputs does not match "
                         "input_signature:\n"
                         f"{format_error_message(inputs, input_signature)}."
                        ) from e

    need_packing = False
    for index, (value,
                spec) in enumerate(zip(flatten_inputs, flat_input_signature)):
        # Non-tensor values paired with a `TensorSpec` are converted, using
        # the spec's dtype as a hint.
        if (isinstance(spec, tensor_spec.TensorSpec)
                and not _pywrap_utils.IsTensor(value)):
            try:
                flatten_inputs[index] = ops.convert_to_tensor(
                    value, dtype_hint=spec.dtype)
                need_packing = True
            except ValueError as e:
                raise ValueError(
                    "When input_signature is provided, all inputs to "
                    "the Python function must be convertible to "
                    "tensors:\n"
                    f"{format_error_message(inputs, input_signature)}."
                ) from e

    if any(not spec.is_compatible_with(other)
           for spec, other in zip(flat_input_signature, flatten_inputs)):
        raise ValueError("Python inputs incompatible with input_signature:\n"
                         f"{format_error_message(inputs, input_signature)}.")

    if need_packing:
        # Rebuild the structured inputs so they carry the converted tensors.
        inputs = nest.pack_sequence_as(structure=input_signature,
                                       flat_sequence=flatten_inputs,
                                       expand_composites=True)

    flat_inputs = nest.flatten(inputs, expand_composites=True)

    return (inputs, flat_inputs, [
        t for t in flat_inputs
        if isinstance(t, (ops.Tensor,
                          resource_variable_ops.BaseResourceVariable))
    ])
Exemplo n.º 23
0
    def _build(self, y_pred, y_true):
        """Performs the one-time construction of the metric objects.

        Both `self._metrics` and `self._weighted_metrics` pass through the same
        four stages, in the same order: keyed to output names, broadcast over
        the outputs, resolved to `Metric` objects, and flattened up to the
        structure of `y_pred`. Afterwards metric names are assigned and the
        container is marked as built.

        Args:
          y_pred: Model predictions; their nested structure drives how the
            configured metrics are keyed, broadcast and flattened.
          y_true: Targets, forwarded together with `y_pred` to
            `self._get_metric_objects`.
        """
        if self._output_names is None:
            # Subclassed models have no explicit output names, so synthesized
            # names such as 'output_1' are used to name the `Metric` objects.
            self._output_names = create_output_names(y_pred)

        collections = ('_metrics', '_weighted_metrics')

        # Stage 1: a dict of metrics keyed by output name is accepted even
        # when the outputs are a flat list.
        for attr in collections:
            setattr(self, attr,
                    map_to_output_names(y_pred, self._output_names,
                                        getattr(self, attr)))

        # Stage 2: a single supplied metric is applied to every output.
        for attr in collections:
            setattr(self, attr,
                    self._maybe_broadcast(getattr(self, attr), y_pred))

        # Stage 3: resolve to `Metric` objects; the outputs' properties may be
        # used to disambiguate which object is created.
        for attr in collections:
            setattr(self, attr,
                    nest.map_structure_up_to(y_pred,
                                             self._get_metric_objects,
                                             getattr(self, attr),
                                             y_true, y_pred))

        # Stage 4: flatten so there is one entry per output.
        for attr in collections:
            setattr(self, attr,
                    nest.flatten_up_to(y_pred,
                                       getattr(self, attr),
                                       check_types=False))

        # `_set_metric_names` relies on the metrics being flattened up to the
        # outputs, which stage 4 guarantees.
        self._set_metric_names()

        self._built = True
Exemplo n.º 24
0
    def _sample_n(self, sample_shape, seed, value=None):
        """Draws `sample_shape` samples, auto-vectorizing when it may be needed.

        Args:
          sample_shape: Shape of the batch of samples to draw.
          seed: Seed forwarded to the underlying sampling calls.
          value: Optional structure (entries may be `None`) forwarded as the
            `value` argument of `_call_flat_sample_distributions`.
            NOTE(review): presumably the non-`None` entries are used in place
            of sampling those components -- confirm against that method.

        Returns:
          Element `[1]` of `_call_flat_sample_distributions`'s result (the flat
          samples), restructured with `self._model_unflatten`.
        """
        # NOTE(review): `_might_have_excess_ndims` presumably reports whether
        # any component of `value` has more dims than its single-sample ndims,
        # i.e. whether `value` carries extra sample dims -- confirm.
        value_might_have_sample_dims = (
            value is not None and _might_have_excess_ndims(
                # Double-flatten in case any components have structured events.
                flat_value=nest.flatten_up_to(self._single_sample_ndims,
                                              self._model_flatten(value),
                                              check_types=False),
                flat_core_ndims=tf.nest.flatten(self._single_sample_ndims)))

        # A single direct call suffices when vectorized map is disabled, or when
        # both `_might_have_nonzero_size(sample_shape)` and
        # `value_might_have_sample_dims` are false.
        if not self.use_vectorized_map or not (
                _might_have_nonzero_size(sample_shape)
                or value_might_have_sample_dims):
            # No need to auto-vectorize.
            xs = self._call_flat_sample_distributions(
                sample_shape=sample_shape, seed=seed, value=value)[1]
            return self._model_unflatten(xs)

        # Set up for autovectorized sampling. To support the `value` arg, we need to
        # first understand which dims are from the model itself, then wrap
        # `_call_flat_sample_distributions` to batch over all remaining dims.
        value_core_ndims = None
        if value is not None:
            # Per-component core ndims, with `None` wherever `value` is `None`.
            value_core_ndims = tf.nest.map_structure(
                lambda v, nd: None if v is None else nd,
                value,
                self._model_unflatten(self._single_sample_ndims),
                check_types=False)
        # Wrap a single-sample draw (`sample_shape=()`) so it is rank
        # polymorphic; `core_ndims` lists one entry per lambda argument
        # `(v, seed)`.
        batch_flat_sample = vectorization_util.make_rank_polymorphic(
            lambda v, seed: self._call_flat_sample_distributions(  # pylint: disable=g-long-lambda
                sample_shape=(),
                seed=seed,
                value=v)[1],
            core_ndims=[value_core_ndims, None],
            validate_args=self.validate_args)

        # Draw samples.
        vectorized_flat_sample = vectorization_util.iid_sample(
            # Redefine the polymorphic fn to hack around `make_rank_polymorphic`
            # not currently supporting keyword args.
            lambda v, seed: batch_flat_sample(v, seed),  # pylint: disable=unnecessary-lambda
            sample_shape)
        xs = vectorized_flat_sample(value, seed=seed)
        return self._model_unflatten(xs)
Exemplo n.º 25
0
 def generator_py_func(iterator_id):
     """Fetches the next element of generator `iterator_id` as checked arrays.

     The element is flattened up to `output_types` and every component is
     converted with `script_ops.FuncRegistry._convert`. Each resulting array
     is then validated against `flattened_types` and `flattened_shapes`.

     Raises:
       StopIteration: once the iterator is exhausted, after informing
         `generator_state` that iteration for `iterator_id` completed.
       TypeError: if an array's dtype differs from the expected numpy dtype.
       ValueError: if an array's shape is incompatible with the expected one.
     """
     try:
         element = next(generator_state.get_iterator(iterator_id))
     except StopIteration:
         generator_state.iterator_completed(iterator_id)
         raise StopIteration("Iteration finished.")

     converted = list(map(script_ops.FuncRegistry._convert,
                          nest.flatten_up_to(output_types, element)))

     for array, tf_dtype, shape in zip(converted, flattened_types,
                                       flattened_shapes):
         wanted_dtype = tf_dtype.as_numpy_dtype
         if array.dtype != wanted_dtype:
             raise TypeError(
                 "`generator` yielded an element of type %s where an element "
                 "of type %s was expected." % (array.dtype, wanted_dtype))
         if not shape.is_compatible_with(array.shape):
             raise ValueError(
                 "`generator` yielded an element of shape %s where an element "
                 "of shape %s was expected." % (array.shape, shape))
     return converted
Exemplo n.º 26
0
    def _to_obs_space_dtype(self, observation):
        """Coerces `observation` to the dtypes declared by the observation spec.

        Gym observation spaces lacked a dtype for a long time, and many
        environments still emit observations whose dtype disagrees with their
        space definition. The TensorFlow graph is built from that definition,
        so observations are converted here to match it.

        Args:
          observation: Observation to convert. It may be nested, or given as a
            list.

        Returns:
          The observation packed into the structure of
          `self._observation_spec`, with every component passed through
          `np.asarray` using the matching spec's dtype.
        """
        # `flatten_up_to` also accepts observations that arrive as a list.
        components = nest.flatten_up_to(self._observation_spec, observation)
        return tf.nest.pack_sequence_as(
            self._observation_spec,
            [np.asarray(component, dtype=spec.dtype)
             for spec, component in zip(self._flat_obs_spec, components)])
Exemplo n.º 27
0
def unpack_structs_like(template, packed):
  """Transposes a structure of tuples into a tuple of structures.

  Args:
    template: Structure that `packed` is flattened up to.
    packed: Structure like `template` whose leaves, at `template`'s depth, are
      sequences. They are zipped together, so the shortest one sets the
      result length.

  Returns:
    A tuple containing, for each zipped position, one structure packed like
    `template`.
  """
  leaves = nest.flatten_up_to(template, packed, check_types=False)
  per_position = zip(*leaves)
  return tuple(nest.pack_sequence_as(template, group) for group in per_position)
Exemplo n.º 28
0
    def testBijector(self, bijector_name, data):
        """Checks gradients and metadata consistency for a drawn bijector.

        Draws a bijector (via `self._draw_bijector`) and domain/codomain tensors
        from `data`. For each of `forward`, `forward_log_det_jacobian`,
        `inverse` and `inverse_log_det_jacobian`, it asserts that gradients
        reach both the input tensor and every floating-point trainable
        variable, with each call wrapped in
        `tfp_hps.assert_no_excessive_var_usage`.

        It also checks that:
          - for scalar bijectors (not excepted below), `_internal_is_increasing`
            agrees with the sign of the forward derivative;
          - `_is_permutation` implies a constant, zero log-det-Jacobian;
          - the `experimental_batch_shape(_tensor)` results match the batch
            shape observed on the inverse outputs;
          - `forward_dtype`/`inverse_dtype` match the produced dtypes.

        Args:
          bijector_name: Identifies which bijector `_draw_bijector` draws.
          data: Hypothesis data object used to draw values interactively.
        """
        # NOTE(review): presumably skips this case for 'Tanh' when running
        # under Guitar (see the referenced bug) -- confirm helper semantics.
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')

        bijector, event_dim = self._draw_bijector(bijector_name, data)

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
        # of numerical problem.
        def exception(bijector):
            # True only under Guitar, for a Softfloor possibly wrapped in any
            # number of inversions.
            if not tfp_hps.running_under_guitar():
                return False
            if isinstance(bijector, tfb.Softfloor):
                return True
            if is_invert(bijector):
                return exception(bijector.bijector)
            return False

        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0
                and not exception(bijector)):
            # Gradient of the forward mapping with respect to `xs`.
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(  # pylint: disable=protected-access
                dydx, 0)
            # Where dydx is finite, "increasing or zero slope" must hold exactly
            # where dydx >= 0. OR-ing with an all-False tensor leaves values
            # unchanged; presumably it broadcasts the right-hand side to the
            # shape of `incr_or_slope_eq0`.
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through forward log det jacobian with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=True)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check that
        # any variables are not referenced overmuch.
        ys = self._draw_codomain_tensor(bijector, data, event_dim)
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through inverse log det jacobian with respect
        # to the codomain "input" and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=False)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)

        # Verify that `_is_permutation` implies constant zero Jacobian.
        # (`ldj` here is the inverse log-det-Jacobian computed just above.)
        if bijector._is_permutation:
            self.assertTrue(bijector._is_constant_jacobian)
            self.assertAllEqual(ldj, 0.)

        # Verify correctness of batch shape.
        # `xs` is the inverse output and `event_ndims` the y event ndims drawn
        # for the ILDJ check; each part's batch shape is its shape minus the
        # corresponding x event ndims, and the parts are broadcast together.
        xs_batch_shapes = tf.nest.map_structure(
            lambda x, nd: ps.shape(x)[:ps.rank(x) - nd], xs,
            bijector.inverse_event_ndims(event_ndims))
        empirical_batch_shape = functools.reduce(
            ps.broadcast_shape,
            nest.flatten_up_to(bijector.forward_min_event_ndims,
                               xs_batch_shapes))
        batch_shape = bijector.experimental_batch_shape(
            y_event_ndims=event_ndims)
        # The static batch shape is only compared when fully defined; the
        # dynamic one is always compared.
        if tensorshape_util.is_fully_defined(batch_shape):
            self.assertAllEqual(empirical_batch_shape, batch_shape)
        self.assertAllEqual(
            empirical_batch_shape,
            bijector.experimental_batch_shape_tensor(
                y_event_ndims=event_ndims))

        # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
        # of the outputs of forward and inverse.
        self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
        self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
Exemplo n.º 29
0
  def generate_enqueue_ops(self, per_host_sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    per_host_sharded_inputs is a list, one for each replica, of lists of
    Tensors. per_host_sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    per_host_sharded_inputs[i][j] is partitioned by
    self._input_partition_dims[j].

    For example, if per_host_sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E, F, G, H]]
    self._input_partition_dims[j] is [2, 4].

    per_host_sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids:
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      per_host_sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if the partition dims are
        invalid, including when the number of entries in
        self._input_partition_dims differs from the number of tuple elements.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
    number_of_tuple_elements = len(per_host_sharded_inputs[0])

    # Validate explicitly instead of with `assert`, which is stripped under
    # `python -O`; the docstring promises a ValueError for invalid dims.
    if len(self._input_partition_dims) != number_of_tuple_elements:
      raise ValueError(
          "Expected {} input partition dims (one per tuple element), but got "
          "{}.".format(number_of_tuple_elements,
                       len(self._input_partition_dims)))

    per_host_enqueue_ops = []
    for replica_index, replica_inputs in enumerate(per_host_sharded_inputs):
      # Pair each tuple element of this replica with its partition dims.
      inputs_part_dims_flat = nest.flatten_up_to(replica_inputs,
                                                 self._input_partition_dims)
      # One iterator per tuple element over the pieces returned by
      # `_partition_or_replicate_on_host`; they are consumed in core order.
      inputs_parted_iters = [
          iter(self._partition_or_replicate_on_host(x, dims))
          for x, dims in zip(replica_inputs, inputs_part_dims_flat)
      ]

      for core_index in range(self._device_assignment.num_cores_per_replica):
        # Places different partitions to different logic cores.
        logical_core = self._get_logical_core(core_index)
        replica_id = self._device_assignment.lookup_replicas(
            self._host_id, logical_core)[replica_index]
        ordinal = self._device_assignment.tpu_ordinal(
            replica=replica_id, logical_core=logical_core)
        # Take the next piece of every tuple element for this core; an
        # exhausted iterator (or a `None` piece) contributes nothing.
        infeed_inputs = []
        for it in inputs_parted_iters:
          input_for_device = next(it, None)
          if input_for_device is not None:
            infeed_inputs.append(input_for_device)

        # Only emit an enqueue op for cores that receive at least one input.
        if infeed_inputs:
          per_host_enqueue_ops.append(
              tpu_ops.infeed_enqueue_tuple(
                  inputs=infeed_inputs,
                  shapes=[x.shape for x in infeed_inputs],
                  name="enqueue/replica_{0}/input_{1}".format(
                      replica_index, core_index),
                  device_ordinal=ordinal))
    return per_host_enqueue_ops
Exemplo n.º 30
0
  def testFlattenUpTo(self):
    """Tests `nest.flatten_up_to` on regular and edge-case structures.

    Covers shallow trees that stop at scalars or strings, dicts (values are
    yielded, not keys), namedtuples, and nested OrderedDicts/dicts/namedtuples.
    It also covers non-sequence shallow trees, where the whole input becomes
    the single leaf, and the TypeError raised when the shallow tree is a
    sequence but the input is not.
    """
    # Shallow tree ends at scalar.
    input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    shallow_tree = [[True, True], [False, True]]
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
    self.assertEqual(flattened_shallow_tree, [True, True, False, True])

    # Shallow tree ends at string.
    input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
    shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    input_tree_flattened = nest.flatten(input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

    # Make sure dicts are correctly flattened, yielding values, not keys.
    input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
    shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [1, {"c": 2}, 3, (4, 5)])

    # Namedtuples.
    ab_tuple = NestTest.ABTuple
    input_tree = ab_tuple(a=[0, 1], b=2)
    shallow_tree = ab_tuple(a=0, b=1)
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [[0, 1], 2])

    # Nested dicts, OrderedDicts and namedtuples.
    input_tree = collections.OrderedDict(
        [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
         ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
    shallow_tree = input_tree
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
    shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [ab_tuple(a=[0, {"b": 1}], b=2),
                      3,
                      collections.OrderedDict([("f", 4)])])
    shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [ab_tuple(a=[0, {"b": 1}], b=2),
                      {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

    ## Shallow non-list edge-case.
    # Using iterable elements.
    input_tree = ["input_tree"]
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = ["input_tree_0", "input_tree_1"]
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Using non-iterable elements.
    input_tree = [0]
    shallow_tree = 9
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = [0, 1]
    shallow_tree = 9
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Both non-list edge-case.
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = 0
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Input non-list edge-case.
    # `assertRaisesRegex` is used instead of the deprecated `assertRaisesRegexp`
    # alias, which was removed from `unittest` in Python 3.12. The
    # `(type|class)` group matches both `<type 'str'>` and `<class 'str'>`
    # spellings of the type repr.
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = ["shallow_tree"]
    expected_message = ("If shallow structure is a sequence, input must also "
                        "be a sequence. Input has type: <(type|class) 'str'>.")
    with self.assertRaisesRegex(TypeError, expected_message):
      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    input_tree = "input_tree"
    shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
    with self.assertRaisesRegex(TypeError, expected_message):
      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = [9]
    expected_message = ("If shallow structure is a sequence, input must also "
                        "be a sequence. Input has type: <(type|class) 'int'>.")
    with self.assertRaisesRegex(TypeError, expected_message):
      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    input_tree = 0
    shallow_tree = [9, 8]
    with self.assertRaisesRegex(TypeError, expected_message):
      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)
Exemplo n.º 31
0
    def testFlattenUpTo(self):
        """Tests `nest.flatten_up_to` on regular and edge-case structures.

        Covers shallow trees that stop at scalars or strings. It also covers
        non-sequence shallow trees, where the whole input becomes the single
        leaf, and the TypeError raised when the shallow tree is a sequence but
        the input is not.
        """
        input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        shallow_tree = [[True, True], [False, True]]
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree,
                         [[2, 2], [3, 3], [4, 9], [5, 5]])
        self.assertEqual(flattened_shallow_tree, [True, True, False, True])

        input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        input_tree_flattened = nest.flatten(input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [("a", 1),
                                                                ("b", 2),
                                                                ("c", 3),
                                                                ("d", 4)])
        self.assertEqual(input_tree_flattened,
                         ["a", 1, "b", 2, "c", 3, "d", 4])

        ## Shallow non-list edge-case.
        # Using iterable elements.
        input_tree = ["input_tree"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = ["input_tree_0", "input_tree_1"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = [0]
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = [0, 1]
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Both non-list edge-case.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = 0
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Input non-list edge-case.
        # `assertRaisesRegex` is used instead of the deprecated
        # `assertRaisesRegexp` alias, which was removed from `unittest` in
        # Python 3.12. The `(type|class)` group matches both `<type 'str'>`
        # and `<class 'str'>` spellings of the type repr.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = ["shallow_tree"]
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'str'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = "input_tree"
        shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
        with self.assertRaisesRegex(TypeError, expected_message):
            flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = [9]
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'int'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = 0
        shallow_tree = [9, 8]
        with self.assertRaisesRegex(TypeError, expected_message):
            flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)
Exemplo n.º 32
0
def _wrap_for_composites(func, inp, Tout):
    """Wraps user inputs to support composite tensors for `py_function`.

    1. Flattens `inp` to a list of Tensors (by flattening any composite tensors).
    2. Creates a wrapper function for `func` that expects flat inputs and:
       - Packs the inputs into the input structure expected by `func`.
       - Calls `func` with the packed inputs.
       - Checks that `func`'s output matches `Tout`.
       - Flattens `func`'s output to a list of Tensors (flattening any composite
         tensors).

    Args:
      func: The function to wrap (`func` argument to `py_function`).
      inp: The input arguments for func (`inp` argument to `py_function`).
      Tout: The expected output types for func (`Tout` argument to
        `py_function`). After flattening, `TensorSpec` entries are replaced by
        their dtype; other `TypeSpec` entries describe composite outputs.
        NOTE(review): the wrapper zips `func`'s outputs against `Tout`, so `Tout`
        is presumably always a list/tuple by the time it reaches here; confirm
        against the caller.

    Returns:
      A tuple `(func, inp, Tout, out_structure)`, where `func` is the wrapped
      function, `inp` is the flattened inputs, `Tout` is the list of expected
      dtypes for the flattened outputs, and `out_structure` is the expected
      output structure (which can be used to pack the output tensors).
    """
    # Composite inputs are kept in the shallow structure so they get expanded
    # into their component tensors; every other input is represented by the
    # placeholder leaf `1` and passed through as a single flat element.
    in_structure = [
        v if isinstance(v, composite_tensor.CompositeTensor) else 1
        for v in inp
    ]
    inp = nest.flatten_up_to(in_structure, inp, expand_composites=True)
    out_structure = Tout
    Tout = [
        v.dtype if isinstance(v, tensor_spec.TensorSpec) else v
        for v in nest.flatten(Tout, expand_composites=True)
    ]

    def wrapped_func(*flat_inp):
        """Repacks `flat_inp`, calls `func`, and flattens/validates its output."""
        structured_inp = nest.pack_sequence_as(in_structure,
                                               flat_inp,
                                               expand_composites=True)
        out = func(*structured_inp)
        if not out_structure:
            return []  # Ignore return value if none is requested/expected.
        if not isinstance(out, (list, tuple)):
            out = [out]  # func may return a single value instead of a list.
        flat_out = []
        # NOTE: `zip` stops at the shorter sequence, so a count mismatch between
        # `out` and `out_structure` is not detected here.
        for elt, expected_type in zip(out, out_structure):
            if (isinstance(expected_type, type_spec.TypeSpec)
                    and not isinstance(expected_type, tensor_spec.TensorSpec)):
                # Composite output expected: validate it, then expand it into
                # its component tensors.
                if not expected_type.is_compatible_with(elt):
                    raise ValueError(
                        f"py_function: func={func} returned {out!r}, "
                        f"which did not match Tout={out_structure!r}.\nIn particular, "
                        f"{elt!r} is not compatible with {expected_type!r}.")
                flat_out.extend(nest.flatten(elt, expand_composites=True))
            else:
                # Pro-actively check if the return value is a composite tensor when
                # we expect a Tensor.  We would catch this later (when we call
                # convert_to_tensor), but checking it here lets us give a better
                # error message.
                if isinstance(elt, composite_tensor.CompositeTensor):
                    raise ValueError(
                        f"py_function: func={func} returned {out!r}, "
                        f"which did not match Tout={out_structure!r}.\nIn particular, "
                        f"{elt!r} is not a Tensor.")
                flat_out.append(elt)
        return flat_out

    return wrapped_func, inp, Tout, out_structure
Exemplo n.º 33
0
def build_factored_surrogate_posterior(
        event_shape=None,
        constraining_bijectors=None,
        initial_unconstrained_loc=_sample_uniform_initial_loc,
        initial_unconstrained_scale=1e-2,
        trainable_distribution_fn=_build_trainable_normal_dist,
        seed=None,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

    By default, this method creates an independent trainable Normal distribution
    for each variable, transformed using a bijector (if provided) to
    match the support of that variable. This makes extremely strong
    assumptions about the posterior: that it is approximately normal (or
    transformed normal), and that all model variables are independent.

    Args:
      event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
        specifying the event shape(s) of the posterior variables.
      constraining_bijectors: Optional `tfb.Bijector` instance, or nested
        structure of such instances, defining support(s) of the posterior
        variables. The structure must match that of `event_shape` and may
        contain `None` values. A posterior variable will
        be modeled as `tfd.TransformedDistribution(underlying_dist,
        constraining_bijector)` if a corresponding constraining bijector is
        specified, otherwise it is modeled as supported on the
        unconstrained real line.
      initial_unconstrained_loc: Optional Python `callable` with signature
        `tensor = initial_unconstrained_loc(shape, seed)` used to sample
        real-valued initializations for the unconstrained representation of each
        variable. May alternately be a nested structure of
        `Tensor`s, giving specific initial locations for each variable; these
        must have structure matching `event_shape` and shapes determined by the
        inverse image of `event_shape` under `constraining_bijectors`, which
        may optionally be prefixed with a common batch shape.
        Default value: `functools.partial(tf.random.uniform,
          minval=-2., maxval=2., dtype=tf.float32)`.
      initial_unconstrained_scale: Optional scalar float `Tensor` initial
        scale for the unconstrained distributions, or a nested structure of
        `Tensor` initial scales for each variable.
        Default value: `1e-2`.
      trainable_distribution_fn: Optional Python `callable` with signature
        `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
        event_ndims, validate_args)`. This is called for each model variable to
        build the corresponding factor in the surrogate posterior. It is expected
        that the distribution returned is supported on unconstrained real values.
        Default value: `functools.partial(
          tfp.experimental.vi.build_trainable_location_scale_distribution,
          distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
      seed: Python integer to seed the random number generator. This is used
        only when `initial_unconstrained_loc` is a callable, i.e. when the
        initial locations are sampled rather than provided.
      validate_args: Python `bool`. Whether to validate input with asserts. This
        imposes a runtime cost. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` (i.e., 'build_factored_surrogate_posterior').

    Returns:
      surrogate_posterior: A `tfd.Distribution` instance whose samples have
        shape and structure matching that of `event_shape`.

    ### Examples

    Consider a Gamma model with unknown parameters, expressed as a joint
    Distribution:

    ```python
    Root = tfd.JointDistributionCoroutine.Root
    def model_fn():
      concentration = yield Root(tfd.Exponential(1.))
      rate = yield Root(tfd.Exponential(1.))
      y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                           sample_shape=4)
    model = tfd.JointDistributionCoroutine(model_fn)
    ```

    Let's use variational inference to approximate the posterior over the
    data-generating parameters for some observed `y`. We'll build a
    surrogate posterior distribution by specifying the shapes of the latent
    `concentration` and `rate` parameters, and that both are constrained to
    be positive.

    ```python
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
      constraining_bijectors=[tfb.Softplus(),   # Concentration is positive.
                              tfb.Softplus()])  # Rate is positive.
    ```

    This creates a trainable joint distribution, defined by variables in
    `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
    to fit this distribution by minimizing a divergence to the true posterior.

    ```python
    y = [0.2, 0.5, 0.3, 0.7]
    losses = tfp.vi.fit_surrogate_posterior(
      lambda concentration, rate: model.log_prob([concentration, rate, y]),
      surrogate_posterior=surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)

    # After optimization, samples from the surrogate will approximate
    # samples from the true posterior.
    samples = surrogate_posterior.sample(100)
    posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
    posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
    ```

    If we wanted to initialize the optimization at a specific location, we can
    specify one when we build the surrogate posterior. This function requires the
    initial location to be specified in *unconstrained* space; we do this by
    inverting the constraining bijectors (note this section also demonstrates the
    creation of a dict-structured model).

    ```python
    initial_loc = {'concentration': 0.4, 'rate': 0.2}
    constraining_bijectors={'concentration': tfb.Softplus(),  # Concentration is positive.
                            'rate': tfb.Softplus()}           # Rate is positive.
    initial_unconstrained_loc = tf.nest.map_structure(
      lambda b, x: b.inverse(x) if b is not None else x,
      constraining_bijectors, initial_loc)
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=tf.nest.map_structure(tf.shape, initial_loc),
      constraining_bijectors=constraining_bijectors,
      initial_unconstrained_loc=initial_unconstrained_loc,
      initial_unconstrained_scale=1e-4)
    ```

    """

    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        seed = tfp_util.SeedStream(seed,
                                   salt='build_factored_surrogate_posterior')

        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)
        flat_event_shapes = tf.nest.flatten(event_shape)

        # For simplicity, we'll work with flattened lists of state parts and
        # repack the structure at the end.
        if constraining_bijectors is not None:
            flat_bijectors = tf.nest.flatten(constraining_bijectors)
        else:
            flat_bijectors = [None for _ in flat_event_shapes]
        # Event shape of each part in the unconstrained space, i.e. the space
        # the underlying trainable distributions are defined on.
        flat_unconstrained_event_shapes = [
            b.inverse_event_shape_tensor(s) if b is not None else s
            for s, b in zip(flat_event_shapes, flat_bijectors)
        ]

        # Construct initial locations for the internal unconstrained dists.
        if callable(initial_unconstrained_loc):
            # Sample random initializations; this is the only use of `seed`.
            flat_unconstrained_locs = [
                initial_unconstrained_loc(shape=s, seed=seed())
                for s in flat_unconstrained_event_shapes
            ]
        else:
            # Use the provided initialization.
            flat_unconstrained_locs = nest.flatten_up_to(
                shallow_structure,
                initial_unconstrained_loc,
                check_types=False)

        # A non-nested scale is shared by every variable.
        if nest.is_nested(initial_unconstrained_scale):
            flat_unconstrained_scales = nest.flatten_up_to(
                shallow_structure,
                initial_unconstrained_scale,
                check_types=False)
        else:
            flat_unconstrained_scales = [
                initial_unconstrained_scale for _ in flat_unconstrained_locs
            ]

        # Extract the rank of each event, so that we build distributions with the
        # correct event shapes.
        flat_unconstrained_event_ndims = [
            prefer_static.rank_from_shape(s)
            for s in flat_unconstrained_event_shapes
        ]

        # Build the component surrogate posteriors.
        flat_component_dists = []
        for initial_loc, initial_scale, event_ndims, bijector in zip(
                flat_unconstrained_locs, flat_unconstrained_scales,
                flat_unconstrained_event_ndims, flat_bijectors):
            unconstrained_dist = trainable_distribution_fn(
                initial_loc=initial_loc,
                initial_scale=initial_scale,
                event_ndims=event_ndims,
                validate_args=validate_args)
            # Map the unconstrained factor onto the variable's support, if any.
            if bijector is not None:
                flat_component_dists.append(bijector(unconstrained_dist))
            else:
                flat_component_dists.append(unconstrained_dist)
        component_distributions = tf.nest.pack_sequence_as(
            event_shape, flat_component_dists)

        # Return a `Distribution` object whose events have the specified structure.
        return (joint_distribution_util.
                independent_joint_distribution_from_structure(
                    component_distributions, validate_args=validate_args))
 def _model_flatten(self, xs):
     """Flattens a model-structured value `xs`.

     If `self._sample_dtype` is set, `xs` is flattened up to that structure
     with `nest.flatten_up_to`. Otherwise `xs` is iterated and returned as a
     tuple.
     """
     if self._sample_dtype is not None:
         return nest.flatten_up_to(self._sample_dtype, xs)
     return tuple(xs)
Exemplo n.º 35
0
  def testFlattenUpTo(self):
    """Checks `nest.flatten_up_to` on nested, shallow and edge-case inputs."""
    input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    shallow_tree = [[True, True], [False, True]]
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
    self.assertEqual(flattened_shallow_tree, [True, True, False, True])

    input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
    shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    input_tree_flattened = nest.flatten(input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

    ## Shallow non-list edge-case.
    # Using iterable elements.
    input_tree = ["input_tree"]
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = ["input_tree_0", "input_tree_1"]
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Using non-iterable elements.
    input_tree = [0]
    shallow_tree = 9
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = [0, 1]
    shallow_tree = 9
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Both non-list edge-case.
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = 0
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Input non-list edge-case.
    # `assertRaisesRegex` is the canonical name; the `assertRaisesRegexp`
    # alias was deprecated in Python 3.2 and removed in Python 3.12.
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = ["shallow_tree"]
    expected_message = ("If shallow structure is a sequence, input must also "
                        "be a sequence. Input has type: <(type|class) 'str'>.")
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    input_tree = "input_tree"
    shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = [9]
    expected_message = ("If shallow structure is a sequence, input must also "
                        "be a sequence. Input has type: <(type|class) 'int'>.")
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    input_tree = 0
    shallow_tree = [9, 8]
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)
Exemplo n.º 36
0
 def _model_flatten(self, xs):
   """Flattens a model-structured value `xs` into a flat tuple/list.

   If `self._sample_dtype` is set, `xs` is flattened up to that structure with
   `nest.flatten_up_to`. Otherwise a mapping `xs` is read in the order given by
   `self._flat_resolve_names()`, and any other iterable is converted to a tuple
   as-is.
   """
   # `collections.Mapping` was removed in Python 3.10; the ABC lives in
   # `collections.abc` (available since Python 3.3).
   from collections.abc import Mapping  # pylint: disable=g-import-not-at-top
   if self._sample_dtype is None:
     if isinstance(xs, Mapping):
       return tuple(xs[k] for k in self._flat_resolve_names())
     return tuple(xs)
   return nest.flatten_up_to(self._sample_dtype, xs)
Exemplo n.º 37
0
def _setup_mcmc(model, n_chains, *, init_position=None, seed=None, **pins):
    """Construct bijector and transforms needed for windowed MCMC.

    This pins the model (if `pins` are given), constructs a bijector that
    unconstrains and flattens each state part, obtains an initial position with
    sample shape `n_chains` and maps it into the unconstrained space, and
    constructs a transformed log probability using the bijector.

    Note that we must manually construct this target log probability instead of
    using a transformed transition kernel because the TTK assumes the shape
    in is the same as the shape out.

    Args:
      model: `tfd.JointDistribution`
        The model to sample from.
      n_chains: list of ints
        Number of chains (independent examples) to run. Used as the
        `sample_shape` of the initial draw when `init_position` is `None`.
      init_position: Optional
        Structure of tensors at which to initialize sampling. Should have the
        same shape and structure as
        `model.experimental_pin(**pins).sample_unpinned(n_chains)`.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details. Only used
        when `init_position` is `None`.
      **pins:
        Values passed to `model.experimental_pin`.

    Returns:
      target_log_prob_fn: `tf.function` (autograph disabled) taking the
        unconstrained, flattened state parts as positional arguments. It returns
        the pinned model's unnormalized log probability at
        `bijector.inverse(args)` plus the inverse log-det-Jacobian computed with
        `event_ndims=1` for every part.
      initial_transformed_position: `bijector.forward(init_position)`, with each
        part passed through `tf.identity`. If `init_position` was `None`, it is
        drawn with `initialization.retry_init(..., sample_shape=n_chains)` from
        `initialization.init_near_unconstrained_zero(pinned_model)`.
      bijector: `tfb.Bijector` instance, which unconstrains and flattens.
      step_broadcast: Callable taking a step size. A nested step size is
        broadcast against `pinned_model.event_shape_tensor()` and passed through
        the step bijector; a non-nested step size is returned unchanged.
      batch_shape: Batch shape of the model as a shape `Tensor`. If the static
        batch shape(s) are fully defined, this is their static broadcast;
        otherwise it is the dynamic broadcast of `batch_shape_tensor()`.
      shard_axis_names: The model's shard axis names flattened to match
        `initial_transformed_position`, or `None` if no part has any.
    """
    pinned_model = model.experimental_pin(**pins) if pins else model
    bijector, step_bijector = _get_flat_unconstraining_bijector(pinned_model)

    if init_position is None:
        raw_init_dist = initialization.init_near_unconstrained_zero(
            pinned_model)
        init_position = initialization.retry_init(
            raw_init_dist.sample,
            target_fn=pinned_model.unnormalized_log_prob,
            sample_shape=n_chains,
            seed=seed)

    initial_transformed_position = tf.nest.map_structure(
        tf.identity, bijector.forward(init_position))

    # Prefer a static batch shape; fall back to a dynamic one if any dimension
    # is unknown. Nested (per-part) batch shapes are broadcast together.
    batch_shape = pinned_model.batch_shape
    if tf.nest.is_nested(batch_shape):
        batch_shape = functools.reduce(tf.broadcast_static_shape,
                                       tf.nest.flatten(batch_shape))

    if not tensorshape_util.is_fully_defined(batch_shape):
        batch_shape = pinned_model.batch_shape_tensor()
        if tf.nest.is_nested(batch_shape):
            batch_shape = functools.reduce(tf.broadcast_dynamic_shape,
                                           tf.nest.flatten(batch_shape))

    # This tf.function is not redundant with the ones on _fast_window
    # and _slow_window because the various kernels (like HMC) may invoke
    # `target_log_prob_fn` multiple times within one window.
    @tf.function(autograph=False)
    def target_log_prob_fn(*args):
        lp = pinned_model.unnormalized_log_prob(bijector.inverse(args))
        ldj = bijector.inverse_log_det_jacobian(
            args, event_ndims=[1 for _ in initial_transformed_position])
        return lp + ldj

    def step_broadcast(step_size):
        # Only nested step sizes are broadcast over the event structure and
        # transformed by `step_bijector`; anything else is returned unchanged.
        if tf.nest.is_nested(step_size):
            return step_bijector(
                nest_util.broadcast_structure(
                    pinned_model.event_shape_tensor(), step_size))
        else:
            return step_size

    shard_axis_names = pinned_model.experimental_shard_axis_names
    if any(tf.nest.flatten(shard_axis_names)):
        shard_axis_names = nest.flatten_up_to(
            initial_transformed_position,
            list(pinned_model._model_flatten(shard_axis_names)))  # pylint: disable=protected-access
    else:
        # No active shard axis names
        shard_axis_names = None

    return (target_log_prob_fn, initial_transformed_position, bijector,
            step_broadcast,
            ps.convert_to_shape_tensor(batch_shape,
                                       name='batch_shape'), shard_axis_names)