示例#1
0
  def testMapStructureUpTo(self):
    """Checks map_structure_up_to on namedtuples and nested lists."""
    # Namedtuple case: combine each value with its matching op record.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    values = ab_tuple(a=2, b=3)
    ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))

    def add_then_mul(val, op):
      return (val + op.add) * op.mul

    result = nest.map_structure_up_to(values, add_then_mul, values, ops)
    self.assertEqual(result.a, 6)
    self.assertEqual(result.b, 15)

    # List case: mapping stops at the leaves of `name_list`, so each call
    # receives a whole data sub-list.
    data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
    name_list = ["evens", ["odds", "primes"]]

    def describe(name, sec):
      return "first_{}_{}".format(len(sec), name)

    result = nest.map_structure_up_to(name_list, describe, name_list, data_list)
    self.assertEqual(
        result, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
示例#2
0
def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
  """Maps `map_fn` over `args` up to `shallow_structure`, with a call index.

  Behaves like `nest.map_structure_up_to(shallow_structure, map_fn, *args,
  **kwargs)`, except that `map_fn` also receives, as its first positional
  argument, a counter that starts at 0 and increases by one after each
  completed call.
  """
  calls_done = 0

  def indexed_fn(*inner_args, **inner_kwargs):
    nonlocal calls_done
    result = map_fn(calls_done, *inner_args, **inner_kwargs)
    calls_done += 1
    return result

  return nest.map_structure_up_to(
      shallow_structure, indexed_fn, *args, **kwargs)
示例#3
0
    def event_shape(self):
        """Shape of a single sample, as a `TensorShape`.

        May be partially defined or unknown.

        Returns:
          event_shape: `TensorShape` (one per part of `self.dtype`), possibly
            unknown.
        """
        raw_shape = self._event_shape
        # Convert each part of `_event_shape`, up to the structure of
        # `self.dtype`, into a `TensorShape`.
        return nest.map_structure_up_to(self.dtype, tf.TensorShape, raw_shape)
def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
  """Like `nest.map_structure_up_to`, but passes a running call index.

  `map_fn` is invoked as `map_fn(index, *leaves, ...)`, where `index` starts
  at 0 and is incremented after every completed call. `args` and `kwargs` are
  forwarded to `nest.map_structure_up_to` unchanged.
  """
  position = 0

  def with_position(*leaves, **leaf_kwargs):
    nonlocal position
    value = map_fn(position, *leaves, **leaf_kwargs)
    position += 1
    return value

  return nest.map_structure_up_to(shallow_structure, with_position, *args,
                                  **kwargs)
示例#5
0
    def one_step(self,
                 new_chain_state,
                 current_reducer_state,
                 previous_kernel_results=None,
                 axis=None):
        """Update the `current_reducer_state` with a new chain state.

        Chunking semantics are similar to those of batching and are specified
        by the `axis` parameter. If chunking is enabled (axis is not `None`),
        all elements along the specified `axis` will be treated as separate
        samples. If a single scalar value is provided for a non-scalar sample
        structure, that value will be used for all elements in the structure.
        If not, an identical structure must be provided.

        Args:
          new_chain_state: A (possibly nested) structure of incoming chain
            state(s) with shape and dtype compatible with those used to
            initialize the `current_reducer_state`.
          current_reducer_state: `CovarianceReducerState`s representing the
            current state of the running covariance.
          previous_kernel_results: A (possibly nested) structure of `Tensor`s
            representing internal calculations made in a related
            `TransitionKernel`. For streaming covariance, this argument has no
            influence on computation; hence, it is `None` by default. However,
            it's still accepted to fit the `Reducer` base class.
          axis: If chunking is desired, this is a (possibly nested) structure
            of integers that specifies the axis with chunked samples. For
            individual samples, set this to `None`. By default, samples are not
            chunked (`axis` is None).

        Returns:
          new_reducer_state: `CovarianceReducerState` with updated running
            statistics. Its `cov_state` field has an identical structure to the
            `current_reducer_state`.
        """
        scope_name = mcmc_util.make_name(self.name, 'covariance_reducer',
                                         'one_step')
        with tf.name_scope(scope_name):
            init_structure = current_reducer_state.init_structure
            streams = _prepare_args(init_structure, self.event_ndims)
            chain_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                new_chain_state)
            # A single (non-nested) `axis` is broadcast to every state part.
            if not nest.is_nested(axis):
                axis = nest_util.broadcast_structure(chain_state, axis)

            def _update_stream(stream, *update_args):
                return stream.update(*update_args)

            updated_cov_state = nest.map_structure_up_to(
                init_structure,
                _update_stream,
                streams,
                current_reducer_state.cov_state,
                chain_state,
                axis,
                check_types=False,
            )
            return CovarianceReducerState(init_structure, updated_cov_state)
示例#6
0
def _flatten_nested_jacobian(jacobian, state_shape):
  """Flattens a nested Jacobian into a single matrix.

  The nest structure is treated as a leading 'axis': if the input has 'shape'
  [input_structure, A, B] and the output has 'shape' [output_structure, C, D],
  the Jacobian is expected to have 'shape'
  [input_structure, output_structure, A, B, C, D]. As with the regular axes,
  the encoding is input major.

  Args:
    jacobian: A nested Jacobian.
    state_shape: A nested collection of state shapes.

  Returns:
    jacobian_mat: The Jacobian matrix. Each input part contributes
      `prod(state_shape_part)` rows; within a row block the output parts are
      concatenated along the columns in `tf.nest.flatten` order.

  #### Examples

  Non-structured state:

  ```python
  input = tf.zeros([1, 2])
  output = tf.zeros([3])
  jacobian = tf.zeros([1, 2, 3])
  ```

  Structured state:

  ```python
  input = {'x': tf.zeros([1, 2])}
  output = {'y': tf.zeros([3])}
  jacobian = {'x': {'y': tf.zeros([1, 2, 3])}}
  ```

  A more complicated structure:

  ```python
  input = [tf.zeros([1, 2]), tf.zeros([])]
  output = {'y': tf.zeros([3])}
  jacobian = [{'y': tf.zeros([1, 2, 3])}, {'y': tf.zeros([3])}]
  ```

  """

  def _row_to_matrix(row, part_shape):
    # Every Jacobian block for this input part gets one row per scalar input
    # element; the remaining (output) dimensions are folded into columns.
    num_rows = ps.reduce_prod(part_shape)
    target_shape = ps.stack([num_rows, -1], axis=0)
    blocks = [tf.reshape(j, target_shape) for j in tf.nest.flatten(row)]
    return tf.concat(blocks, axis=-1)

  row_matrices = nest.map_structure_up_to(state_shape, _row_to_matrix,
                                          jacobian, state_shape)
  return tf.concat(tf.nest.flatten(row_matrices), axis=0)
示例#7
0
  def testMapStructureUpTo(self):
    """Checks `nest.map_structure_up_to` on several container types.

    Covers namedtuples, nested lists, dicts and a custom mapping, and checks
    that dicts / mappings with mismatched keys raise a `ValueError` whose
    message matches "same keys".
    """

    def add_then_mul(val, ops):
      # `ops` is a mapping with "add" and "mul" entries.
      return (val + ops["add"]) * ops["mul"]

    # Named tuples.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    inp_val = ab_tuple(a=2, b=3)
    inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    out = nest.map_structure_up_to(
        inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
    self.assertEqual(out.a, 6)
    self.assertEqual(out.b, 15)

    # Lists.
    data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
    name_list = ["evens", ["odds", "primes"]]
    out = nest.map_structure_up_to(
        name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
        name_list, data_list)
    self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])

    # Dicts.
    inp_val = dict(a=2, b=3)
    inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
    out = nest.map_structure_up_to(inp_val, add_then_mul, inp_val, inp_ops)
    self.assertEqual(out["a"], 6)
    self.assertEqual(out["b"], 15)

    # Non-equal dicts.
    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported name.
    inp_val = dict(a=2, b=3)
    inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
    with self.assertRaisesRegex(ValueError, "same keys"):
      nest.map_structure_up_to(inp_val, add_then_mul, inp_val, inp_ops)

    # Dict+custom mapping.
    inp_val = dict(a=2, b=3)
    inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
    out = nest.map_structure_up_to(inp_val, add_then_mul, inp_val, inp_ops)
    self.assertEqual(out["a"], 6)
    self.assertEqual(out["b"], 15)

    # Non-equal dict/mapping.
    inp_val = dict(a=2, b=3)
    inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
    with self.assertRaisesRegex(ValueError, "same keys"):
      nest.map_structure_up_to(inp_val, add_then_mul, inp_val, inp_ops)
示例#8
0
    def build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Aligns `self._metrics` and `self._weighted_metrics` with the model
        outputs, converts them to `Metric` objects, flattens them up to the
        structure of `y_pred`, then names (unless loaded from a serialized
        model) and orders them.

        Statement order matters: each step reassigns `self._metrics` /
        `self._weighted_metrics` before the next helper runs, and the metric
        objects for `_metrics` are created before those for
        `_weighted_metrics`.

        Args:
          y_pred: (Possibly nested) structure of model outputs; used as the
            shallow structure for mapping and flattening.
          y_true: (Possibly nested) structure of targets, passed alongside
            `y_pred` to `self._get_metric_objects`.
        """
        super(MetricsContainer, self).build(y_pred)

        # Align the user-supplied metric specs with the output structure:
        # first `_maybe_broadcast_to_outputs`, then `_conform_to_outputs`.
        self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
        self._metrics = self._conform_to_outputs(y_pred, self._metrics)

        self._weighted_metrics = self._maybe_broadcast_to_outputs(
            y_pred, self._weighted_metrics)
        self._weighted_metrics = self._conform_to_outputs(
            y_pred, self._weighted_metrics)

        # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
        y_pred = nest.list_to_tuple(y_pred)
        y_true = nest.list_to_tuple(y_true)
        self._metrics = nest.list_to_tuple(self._metrics)
        self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

        # Convert to `Metric` objects, potentially disambiguating based on output
        # properties.
        self._metrics = nest.map_structure_up_to(y_pred,
                                                 self._get_metric_objects,
                                                 self._metrics, y_true, y_pred)
        self._weighted_metrics = nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._weighted_metrics, y_true,
            y_pred)

        # Flatten to one entry per output of `y_pred`; container types are not
        # compared (`check_types=False`).
        self._metrics = nest.flatten_up_to(y_pred,
                                           self._metrics,
                                           check_types=False)
        self._weighted_metrics = nest.flatten_up_to(y_pred,
                                                    self._weighted_metrics,
                                                    check_types=False)

        # Assumes metrics, weighted_metrics have been flattened up to outputs.
        #
        # If we are loading a model that has been already serialized, we do not
        # want to re-apply any pre-processing metric renaming steps.
        if not self._from_serialized:
            self._set_metric_names()
        self._create_ordered_metrics()
        # Mark setup as done (the docstring describes this as one-time setup).
        self._built = True
示例#9
0
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions.

        Drives `self._model_coroutine()`. Each yielded distribution (optionally
        wrapped in `self.Root`) is recorded. Its value is then either taken
        from `value[index]` (converted to tensors up to the distribution's
        dtype structure) or sampled, and that value is sent back into the
        coroutine.

        Args:
          sample_shape: Sample shape used only for distributions wrapped in
            `Root`; non-root distributions are sampled with shape `()`.
          seed: Seed passed to `SeedStream`. One seed is drawn per yielded
            distribution, whether or not its value was supplied.
          value: Optional sequence of per-distribution values in yield order.
            An entry that is missing (past `len(value)`) or `None` means
            "sample this one".

        Returns:
          ds: List of the (unwrapped) distributions, in yield order.
          values_out: List of the values sent back to the coroutine, in the
            same order.

        Raises:
          ValueError: If `self._require_root` is set and the first yielded
            distribution is not wrapped in `Root`.
        """
        ds = []
        values_out = []
        seed = SeedStream(seed, salt='JointDistributionCoroutine')
        gen = self._model_coroutine()
        index = 0
        # NOTE(review): `next(gen)` is outside the `try` below, so a coroutine
        # that yields nothing raises StopIteration here instead of returning
        # empty lists.
        d = next(gen)
        if self._require_root and not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        try:
            while True:
                # `Root` is only a marker; record the wrapped distribution.
                actual_distribution = d.distribution if isinstance(
                    d, self.Root) else d
                ds.append(actual_distribution)
                if (value is not None and len(value) > index
                        and value[index] is not None):
                    # Consume a seed even though nothing is sampled, so later
                    # distributions see the same seeds regardless of which
                    # values were supplied.
                    seed(
                    )  # Ensure reproducibility even when xs are (partially) set.

                    def convert_tree_to_tensor(x, dtype_hint):
                        return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

                    # This signature does not allow kwarg names. Applies
                    # `convert_to_tensor` on the next value.
                    next_value = nest.map_structure_up_to(
                        ds[-1].dtype,  # shallow_tree
                        convert_tree_to_tensor,  # func
                        value[index],  # x
                        ds[-1].dtype)  # dtype_hint
                else:
                    next_value = actual_distribution.sample(
                        sample_shape=sample_shape
                        if isinstance(d, self.Root) else (),
                        seed=seed())

                if self._validate_args:
                    # Route the value through `tf.identity` under the shape
                    # assertions so they are control dependencies of the
                    # recorded value.
                    with tf.control_dependencies(
                            self._assert_compatible_shape(
                                index, sample_shape, next_value)):
                        values_out.append(
                            tf.nest.map_structure(tf.identity, next_value))
                else:
                    values_out.append(next_value)

                index += 1
                # Feed the value back; the coroutine yields the next
                # distribution or raises StopIteration when the model is done.
                d = gen.send(next_value)
        except StopIteration:
            pass
        return ds, values_out
示例#10
0
    def testMapStructureUpTo(self):
        """Checks map_structure_up_to on namedtuples and nested lists."""
        # Named tuples: each value is combined with its matching op record.
        ab_tuple = collections.namedtuple("ab_tuple", "a, b")
        op_tuple = collections.namedtuple("op_tuple", "add, mul")
        values = ab_tuple(a=2, b=3)
        ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))

        def apply_ops(val, op):
            return (val + op.add) * op.mul

        mapped = nest.map_structure_up_to(values, apply_ops, values, ops)
        self.assertEqual(mapped.a, 6)
        self.assertEqual(mapped.b, 15)

        # Lists: mapping stops at the leaves of `names`, so each call receives
        # a whole data sub-list.
        data = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
        names = ["evens", ["odds", "primes"]]
        mapped = nest.map_structure_up_to(
            names, lambda name, sec: "first_{}_{}".format(len(sec), name),
            names, data)
        self.assertEqual(
            mapped, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
示例#11
0
 def log_prob_ratio_parts_fn(x_y):
   """Per-part log-prob ratios of `p` at `x` versus `q` at `y`.

   Args:
     x_y: Structure whose leaves are indexable pairs. Element 0 of each leaf
       is the value for `p` and element 1 is the value for `q`.

   Returns:
     A structure (shaped like `p`'s distributions) of
     `lp_ratio.log_prob_ratio` results.
   """
   x = tf.nest.map_structure(lambda pair: pair[0], x_y)
   y = tf.nest.map_structure(lambda pair: pair[1], x_y)
   # Rebuild each model's distributions conditioned on the given values,
   # using a fixed zeros seed.
   p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
   q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]

   def _part_ratio(p_part, x_part, q_part, y_part, axis_names):
     # Ensure sharded distributions defer reductions.
     extra_kwargs = {'reduce_over_shards': False} if axis_names else {}
     return lp_ratio.log_prob_ratio(p_part, x_part, q_part, y_part,
                                    **extra_kwargs)

   return nest.map_structure_up_to(
       p_dists, _part_ratio, p_dists, x, q_dists, y, p_axis_names)
 def _loop_body(step, current_estimate, accumulated_quantities):
   """Body of the `while_loop` for running forward filtering.

   Gathers the observations at `step`, runs `run_ekf_step` on them, records
   the result with `_write_accumulated_quantities`, and advances `step`.
   """

   def _at_step(obs):
     return tf.gather(obs, step)

   observations_at_step = tf.nest.map_structure(_at_step, observations)
   new_estimate = nest.map_structure_up_to(
       observations_at_step, run_ekf_step, current_estimate,
       observations_at_step)
   new_accumulated = _write_accumulated_quantities(
       step, accumulated_quantities, new_estimate)
   return step + 1, new_estimate, new_accumulated
示例#13
0
    def normalize(self,
                  tensor,
                  clip_value=5.0,
                  center_mean=True,
                  variance_epsilon=1e-3):
        """Applies normalization to tensor.

        Args:
          tensor: Tensor to normalize.
          clip_value: Clips normalized observations between +/- this value if
            clip_value > 0, otherwise does not apply clipping.
          center_mean: If true, subtracts off mean from normalized tensor.
          variance_epsilon: Epsilon to avoid division by zero in normalization.

        Returns:
          normalized_tensor: Tensor after applying normalization.
        """
        tf.nest.assert_same_structure(tensor, self._tensor_spec)
        as_float = tf.nest.map_structure(lambda t: tf.cast(t, tf.float32),
                                         tensor)

        with tf.name_scope(self._scope + '/normalize'):
            mean_estimate, var_estimate = self._get_mean_var_estimates()
            # Without centering, a zero mean leaves values uncentered.
            if center_mean:
                mean = mean_estimate
            else:
                mean = tf.nest.map_structure(tf.zeros_like, mean_estimate)

            result = nest.map_structure_up_to(
                self._tensor_spec,
                lambda t, m, v: tf.nn.batch_normalization(
                    t,
                    m,
                    v,
                    offset=None,
                    scale=None,
                    variance_epsilon=variance_epsilon,
                    name='normalized_tensor'),
                as_float, mean, var_estimate)

            if clip_value > 0:
                result = tf.nest.map_structure(
                    lambda t: tf.clip_by_value(
                        t,
                        -clip_value,
                        clip_value,
                        name='clipped_normalized_tensor'),
                    result)

        return result
    def test_constrained_affine_from_distributions(self,
                                                   dist_classes,
                                                   event_shape,
                                                   operators,
                                                   initial_loc,
                                                   implicit_batch_shape,
                                                   bijector,
                                                   dtype,
                                                   is_static,
                                                   is_stateless=JAX_MODE):
        """Builds an affine surrogate from base distributions and checks it.

        Checks shapes and dtype, and gradients when not stateless.
        """
        if not (tf.executing_eagerly() or is_static):
            self.skipTest(
                'tfb.Reshape requires statically known shapes in graph'
                ' mode.')

        seeds = samplers.split_seed(
            test_util.test_seed(sampler_type='stateless'), n=4)
        init_seed, grads_seed, shapes_seed, dtype_seed = seeds

        def _as_loc(value):
            return self.maybe_static(np.array(value, dtype=dtype),
                                     is_static=is_static)

        def _make_base(dist_cls, loc, shape):
            return tfd.Independent(
                dist_cls(loc=loc, scale=1.),
                reinterpreted_batch_ndims=ps.rank_from_shape(shape))

        initial_loc = tf.nest.map_structure(_as_loc, initial_loc)
        distributions = nest.map_structure_up_to(
            dist_classes, _make_base, dist_classes, initial_loc, event_shape)

        surrogate_posterior = self._initialize_surrogate(
            'build_affine_surrogate_posterior_from_base_distribution',
            is_stateless=is_stateless,
            seed=init_seed,
            base_distribution=distributions,
            operators=operators,
            bijector=bijector,
            validate_args=True)

        # Expected event shape: the base event shape, pushed through each
        # (non-None) bijector part when a bijector is given.
        expected_event_shape = nest.map_structure(
            lambda d: d.event_shape_tensor(), distributions)
        if bijector is not None:
            expected_event_shape = nest.map_structure(
                lambda b, s: s if b is None else b.forward_event_shape_tensor(s),
                bijector, expected_event_shape)

        self._test_shapes(surrogate_posterior,
                          batch_shape=implicit_batch_shape,
                          event_shape=expected_event_shape,
                          seed=shapes_seed)
        self._test_dtype(surrogate_posterior, dtype, dtype_seed)
        if not is_stateless:
            self._test_gradients(surrogate_posterior, seed=grads_seed)
示例#15
0
    def _build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Resolves output names if needed, maps and broadcasts the metric specs
        onto the outputs, converts them to `Metric` objects (`_metrics` first,
        then `_weighted_metrics`), flattens them up to `y_pred`, names them,
        and marks the container as built. Each step reassigns the metric
        attributes on `self`, so statement order matters.

        Args:
          y_pred: (Possibly nested) structure of model outputs; used as the
            shallow structure for mapping and flattening.
          y_true: (Possibly nested) structure of targets, passed alongside
            `y_pred` to `self._get_metric_objects`.
        """

        if self._output_names is None:
            # Subclass output names like 'output_1' are used for `Metric` names.
            self._output_names = create_output_names(y_pred)

        # Accept a dict of metrics keyed by output_name when outputs are a flat
        # list.
        self._metrics = map_to_output_names(y_pred, self._output_names,
                                            self._metrics)
        self._weighted_metrics = map_to_output_names(y_pred,
                                                     self._output_names,
                                                     self._weighted_metrics)

        # If a single metric is supplied, apply to all outputs.
        self._metrics = self._maybe_broadcast(self._metrics, y_pred)
        self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics,
                                                       y_pred)

        # Convert to `Metric` objects, potentially disambiguating based on output
        # properties.
        self._metrics = nest.map_structure_up_to(y_pred,
                                                 self._get_metric_objects,
                                                 self._metrics, y_true, y_pred)
        self._weighted_metrics = nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._weighted_metrics, y_true,
            y_pred)

        # Flatten to one entry per output of `y_pred`; container types are not
        # compared (`check_types=False`).
        self._metrics = nest.flatten_up_to(y_pred,
                                           self._metrics,
                                           check_types=False)
        self._weighted_metrics = nest.flatten_up_to(y_pred,
                                                    self._weighted_metrics,
                                                    check_types=False)

        # Assumes metrics, weighted_metrics have been flattened up to outputs.
        self._set_metric_names()

        self._built = True
    def _build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Aligns `self._metrics` / `self._weighted_metrics` with the outputs,
        normalizes lists to tuples, converts the specs to `Metric` objects
        (`_metrics` first, then `_weighted_metrics`), flattens them up to
        `y_pred`, then names and orders them. Each step reassigns the metric
        attributes on `self`, so statement order matters.

        Args:
          y_pred: (Possibly nested) structure of model outputs; used as the
            shallow structure for mapping and flattening.
          y_true: (Possibly nested) structure of targets, passed alongside
            `y_pred` to `self._get_metric_objects`.
        """
        super(MetricsContainer, self)._build(y_pred)

        # Align the user-supplied metric specs with the output structure:
        # first `_maybe_broadcast_to_outputs`, then `_conform_to_outputs`.
        self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
        self._metrics = self._conform_to_outputs(y_pred, self._metrics)

        self._weighted_metrics = self._maybe_broadcast_to_outputs(
            y_pred, self._weighted_metrics)
        self._weighted_metrics = self._conform_to_outputs(
            y_pred, self._weighted_metrics)

        # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
        # pylint: disable=protected-access
        y_pred = nest._list_to_tuple(y_pred)
        y_true = nest._list_to_tuple(y_true)
        self._metrics = nest._list_to_tuple(self._metrics)
        self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics)
        # pylint: enable=protected-access

        # Convert to `Metric` objects, potentially disambiguating based on output
        # properties.
        self._metrics = nest.map_structure_up_to(y_pred,
                                                 self._get_metric_objects,
                                                 self._metrics, y_true, y_pred)
        self._weighted_metrics = nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._weighted_metrics, y_true,
            y_pred)

        # Flatten to one entry per output of `y_pred`; container types are not
        # compared (`check_types=False`).
        self._metrics = nest.flatten_up_to(y_pred,
                                           self._metrics,
                                           check_types=False)
        self._weighted_metrics = nest.flatten_up_to(y_pred,
                                                    self._weighted_metrics,
                                                    check_types=False)

        # Assumes metrics, weighted_metrics have been flattened up to outputs.
        self._set_metric_names()
        self._create_ordered_metrics()
        self._built = True
示例#17
0
def nested_distributions_from_specs(specs, parameters):
    """Builds a nest of distributions from a nest of specs.

    Args:
      specs: A nest of distribution specs.
      parameters: A nest of distribution kwargs, matching `specs` at least up
        to the leaves of `specs`.

    Returns:
      Nest of distribution instances with the same structure as the given
      specs; each is `spec.build_distribution(**spec_kwargs)`.
    """

    def _build(spec, spec_kwargs):
        return spec.build_distribution(**spec_kwargs)

    return nest.map_structure_up_to(specs, _build, specs, parameters)
示例#18
0
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how each tensor is partitioned.

  Returns:
    The same dequeues with appropriate xla_sharding attribute.
  """
  # Validate up front that `dequeues` is a shallow structure of `dims`.
  nest.assert_shallow_structure(dequeues, dims)
  tagged_dequeues = nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
  return tagged_dequeues
def reduce_tensors(structures, shallow=False):
    if len(structures) == 1:
        reduced_structure = structures[0]
    else:
        if shallow:
            if isinstance(structures[0], dict):
                shallow_tree = type(structures[0])([(k, None) for k in structures[0]])
            else:
                shallow_tree = type(structures[0])([None for _ in structures[0]])
            reduced_structure = nest.map_structure_up_to(shallow_tree, _reduce_entries, *structures)
        else:
            reduced_structure = nest.map_structure(_reduce_entries, *structures)
    return reduced_structure
示例#20
0
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how each tensor is partitioned.

  Returns:
    The same dequeues with appropriate xla_sharding attribute.
  """
  # Validate up front that `dequeues` is a shallow structure of `dims`.
  nest.assert_shallow_structure(dequeues, dims)
  tagged_dequeues = nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
  return tagged_dequeues
示例#21
0
 def _from_compatible_tensor_list(
         self, tensor_list: List["ops.Tensor"]
 ) -> composite_tensor.CompositeTensor:
     """Reconstructs a value from a compatible flat list of `ops.Tensor`.

     The flat list is regrouped per component (using each component's
     batchable flat tensor specs), each group is decoded with
     `batchable_from_tensor_list`, and the decoded components are passed to
     `self._from_components`.
     """
     specs = self._component_specs
     flat_specs_per_component = nest.map_structure(
         lambda s: get_batchable_flat_tensor_specs(s, context_spec=self),
         specs)
     grouped_tensors = nest.pack_sequence_as(flat_specs_per_component,
                                             tensor_list)
     components = nest.map_structure_up_to(specs, batchable_from_tensor_list,
                                           specs, grouped_tensors)
     return self._from_components(components)
示例#22
0
文件: core.py 项目: jxzhangjhu/ucate
 def mc_sample(self,
               x,
               batch_size=None,
               steps=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
     """Runs `make_mc_sample_function()` over `x` and concatenates the outputs.

     Batches come from a `data_adapter.DataHandler` (one epoch, starting at
     epoch 0). Per-batch outputs are collected leaf-wise into lists and
     combined with `concat`.

     Args:
       x: Input data, passed to `DataHandler` as `x`.
       batch_size: Passed to `DataHandler` as `batch_size`.
       steps: Passed to `DataHandler` as `steps_per_epoch`.
       max_queue_size: Passed to `DataHandler`.
       workers: Passed to `DataHandler`.
       use_multiprocessing: Passed to `DataHandler`.

     Returns:
       The concatenated outputs, converted with
       `tf_utils.to_numpy_or_python_type`.

     Raises:
       ValueError: If the data handler produced no batches.
     """
     outputs = None
     batch_outputs = None
     with self.distribute_strategy.scope():
         data_handler = data_adapter.DataHandler(
             x=x,
             batch_size=batch_size,
             steps_per_epoch=steps,
             initial_epoch=0,
             epochs=1,
             max_queue_size=max_queue_size,
             workers=workers,
             use_multiprocessing=use_multiprocessing,
             model=self)
         predict_function = self.make_mc_sample_function()
         for _, iterator in data_handler.enumerate_epochs():
             with data_handler.catch_stop_iteration():
                 for _ in data_handler.steps():
                     tmp_batch_outputs = predict_function(iterator)
                     if not data_handler.inferred_steps:
                         # Step count was not inferred: block on pending async
                         # work before accepting this batch's result.
                         context.async_wait()
                     # Only commit the result once the wait above succeeded.
                     batch_outputs = tmp_batch_outputs
                     if outputs is None:
                         # First batch: start one list per output leaf.
                         outputs = nest.map_structure(
                             lambda batch_output: [batch_output],
                             batch_outputs)
                     else:
                         # Later batches: append each leaf to its list.
                         nest.map_structure_up_to(
                             batch_outputs,
                             lambda output, batch_output: output.append(
                                 batch_output), outputs, batch_outputs)
     if outputs is None:
         # Without this check an empty input would fail below with an
         # UnboundLocalError / None structure instead of a clear message.
         raise ValueError('Expected input data to be non-empty.')
     all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
     return tf_utils.to_numpy_or_python_type(all_outputs)
  def testNestMapStructureUpTo(self):
    """map_structure_up_to with expand_composites stops at s1's leaves.

    The nested composite in the deeper input is passed through whole, while
    int leaves are incremented by 10.
    """
    shallow = [[TestCompositeTensor(1, 2, 3)], 100,
               {'y': TestCompositeTensor(5, 6)}]
    deep = [[TestCompositeTensor(1, 2, 3)], 100,
            {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}]

    def add_ten_to_ints(value):
      if isinstance(value, int):
        return value + 10
      return value

    result = nest.map_structure_up_to(
        shallow, add_ten_to_ints, deep, expand_composites=True)
    expected = [[TestCompositeTensor(11, 12, 13)], 110,
                {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 16)}]
    self.assertEqual(result, expected)
示例#24
0
        def internal_trace_fn(curr_state, kr):
            """Calls `trace_fn`, appending finalized reducer results if any.

            With a `reducer`, `trace_fn` receives a third positional argument:
            each reducer's `finalize` applied to its part of
            `kr.reduction_results`. Otherwise only two arguments are passed.
            """
            if not reducer:
                return trace_fn(curr_state, kr.inner_results)
            finalized = nest.map_structure_up_to(
                reducer,
                lambda r, red_state: r.finalize(red_state),
                reducer,
                kr.reduction_results)
            return trace_fn(curr_state, kr.inner_results, finalized)
  def event_shape_tensor(self, name='event_shape_tensor'):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Static shapes are used when every part of `self.event_shape` is fully
    defined; otherwise `self._event_shape_tensor()` is used.

    Args:
      name: name to give to the op

    Returns:
      event_shape: `Tensor` (one per part of `self.dtype`).
    """
    with self._name_and_control_scope(name):
      fully_defined = all(
          tensorshape_util.is_fully_defined(s)
          for s in nest.flatten(self.event_shape))
      if fully_defined:
        shapes = nest.map_structure_up_to(
            self.dtype, tensorshape_util.as_list, self.event_shape,
            check_types=False)
      else:
        shapes = self._event_shape_tensor()

      def _to_int32_tensor(s):
        return tf.identity(
            tf.convert_to_tensor(s, dtype=tf.int32), name='event_shape')

      return nest.map_structure_up_to(
          self.dtype, _to_int32_tensor, shapes, check_types=False)
示例#26
0
        def value_grad(v, value_axis_names, term_grads):
            """Computes reductions of output gradients.

            A `log_prob_parts` function takes in a list of values and outputs
            a log density for each input to the function. The vector-Jacobian
            product (VJP) of a `log_prob_parts` function thus needs to compute
            the gradient of each output term w.r.t. each input value. This
            function overrides the default VJP of an output term `j` w.r.t. an
            input value `i` to include an all-reduce-sum when:
            1) The gradient of `j` w.r.t. `i` is connected.
            2) `j` is a sharded term and `i` is an unsharded value.

            If these conditions do not hold, the gradient remains the same and
            either corresponds to:
            1) The gradient of a sharded term w.r.t. a sharded value.
            2) The gradient of an unsharded term w.r.t. an unsharded value.
            3) The gradient of an unsharded term w.r.t. a sharded value.
            In any of these cases, no all-reduce-sum is necessary.

            Args:
              v: Not used by this function's computation.
              value_axis_names: Axis names the value is sharded over; falsy
                when the value is unsharded.
              term_grads: Object whose `.grads` is a structure of per-term
                gradients (entries may be `None`), matched with `map_axes`.

            Returns:
              `None` if every gradient is `None`; otherwise the sum of the
              (possibly all-reduced) gradients that pass
              `tfp_custom_gradient.is_valid_gradient`.
            """
            grads = term_grads.grads

            def _maybe_all_reduce(grad, term_axis_names):
                # Only an unsharded value receiving a sharded term's gradient
                # needs an all-reduce-sum; everything else passes through.
                if grad is None or value_axis_names or not term_axis_names:
                    return grad
                # TODO(https://github.com/google/jax/issues/6022): This cast
                # shouldn't be here.
                return tf.cast(psum(grad, axis_name=term_axis_names),
                               grad.dtype)

            total_grad = nest.map_structure_up_to(grads, _maybe_all_reduce,
                                                  grads, map_axes)
            flat_total = tf.nest.flatten(total_grad)
            if all(g is None for g in flat_total):
                return None
            return tf.add_n([
                g for g in flat_total
                if tfp_custom_gradient.is_valid_gradient(g)
            ])
示例#27
0
    def initialize(self):
        """Initializes an empty `RunningPotentialScaleReductionState`.

        Returns:
          state: `RunningPotentialScaleReductionState` representing a stream
            of no inputs.
        """
        ndims = self.independent_chain_ndims
        # Pair every part of `independent_chain_ndims` with a dtype.
        dtypes = nest_util.broadcast_structure(ndims, self.dtype)
        chain_var = nest.map_structure_up_to(ndims,
                                             RunningVariance.from_shape,
                                             self.shape,
                                             dtypes,
                                             check_types=False)
        return RunningPotentialScaleReductionState(chain_var)
示例#28
0
 def testExampleDoc1(self):
   """Random values pushed through the default event-space bijectors are >= 0."""
   seed = test_util.test_seed_stream()
   model = TestModel()

   def _random_normal(dtype, shape):
     return tf.random.normal(shape, dtype=dtype, seed=seed())

   unconstrained_values = tf.nest.map_structure(
       _random_normal, model.dtype, model.event_shape)
   bijectors = model.default_event_space_bijector
   constrained_values = nest.map_structure_up_to(
       bijectors, lambda b, v: b(v), bijectors, unconstrained_values)
   self.assertGreaterEqual(self.evaluate(constrained_values), 0.)
    def _action(self, time_step, policy_state):
        """Returns the next scripted action and the advanced policy state.

        `policy_state` is `[action_index, num_repeats]` (`None` means
        `[0, 0]`). Entries of `self._action_script` are `(n, action)` pairs.
        An action is emitted until it has been repeated `n` times, after which
        the script advances; entries with `n <= 0` are skipped.

        Raises:
          ValueError: If the episode runs past the end of the script, or if
            the scripted action does not match `self._action_spec`.
        """
        del time_step  # Unused.
        policy_state = [0, 0] if policy_state is None else policy_state
        index, repeats = policy_state  #  pylint: disable=unpacking-non-sequence

        def _fail_if_past_end():
            if index >= len(self._action_script):
                raise ValueError(
                    "Episode is longer than the provided scripted policy. Consider "
                    "setting a TimeLimit wrapper that stops episodes within the length"
                    " of your scripted policy.")

        _fail_if_past_end()
        n, current_action = self._action_script[index]

        # Advance past entries that are used up (or disabled via n <= 0).
        while repeats >= n:
            index += 1
            repeats = 0
            _fail_if_past_end()
            n, current_action = self._action_script[index]

        repeats += 1

        # Script entries may be plain lists; lift every leaf into a numpy
        # array with the spec's dtype so the nest can be checked against the
        # spec.
        current_action = nest.map_structure_up_to(
            self._action_spec,
            lambda spec, leaf: np.asarray(leaf, dtype=spec.dtype),
            self._action_spec,
            current_action)

        if not array_spec.check_arrays_nest(current_action, self._action_spec):
            raise ValueError(
                "Action at index {} does not match the environment's action_spec. "
                "Got: {}. Expected {}.".format(index, current_action,
                                               self._action_spec))

        logging.info("Policy_state: %r", policy_state)
        return policy_step.PolicyStep(current_action, [index, repeats])
示例#30
0
def batchable_from_tensor_list(spec, tensor_list):
    """Reconstructs a value described by `spec` from a flat list of tensors.

    Three cases are handled:

    * `spec` is a `tensor_spec.TensorSpec`: `tensor_list` must hold exactly
      one tensor (checked with `assert`), which is returned unchanged.
    * `spec` has a `__batch_encoder__`: the encoder's `encoding_specs(spec)`
      describe the encoded components. `tensor_list` is regrouped per
      component (using `get_batchable_flat_tensor_specs` on each component
      spec), every component is decoded recursively, and the encoder's
      `decode(spec, ...)` rebuilds the final value.
    * Otherwise decoding is delegated to
      `spec._from_compatible_tensor_list(tensor_list)`.
    """
    if isinstance(spec, tensor_spec.TensorSpec):
        assert len(tensor_list) == 1
        return tensor_list[0]

    if not hasattr(spec, "__batch_encoder__"):
        return spec._from_compatible_tensor_list(tensor_list)  # pylint: disable=protected-access

    encoder = spec.__batch_encoder__
    component_specs = encoder.encoding_specs(spec)
    # Per-component flat specs give `pack_sequence_as` the grouping needed to
    # hand each component exactly its own slice of `tensor_list`.
    per_component_flat_specs = nest.map_structure(
        get_batchable_flat_tensor_specs, component_specs)
    grouped_tensors = nest.pack_sequence_as(per_component_flat_specs,
                                            tensor_list)
    decoded_components = nest.map_structure_up_to(
        component_specs, batchable_from_tensor_list, component_specs,
        grouped_tensors)
    return encoder.decode(spec, decoded_components)
示例#31
0
def _nested_convert_to_tensor(struct, dtype=None, name=None):
  """Eagerly converts struct to Tensor, recursing upon failure.

  Conversion is attempted in this order:

  1. If `dtype` is given, or `struct` is not nested, `struct` is converted
     directly with `tf.convert_to_tensor` (errors propagate).
  2. Otherwise, if `_maybe_convertible_to_tensor(struct)` holds, the whole
     structure is converted in one call; a `ValueError`/`TypeError` from that
     attempt is swallowed and we fall through to step 3.
  3. Each child of `struct`, as delimited by `_get_shallow_structure(struct)`,
     is converted recursively.

  Args:
    struct: A Python value or (possibly nested) structure of values.
    dtype: Optional dtype. When set, `struct` is converted in one shot with
      this dtype and no recursion happens.
    name: Optional name forwarded to every `tf.convert_to_tensor` call.

  Returns:
    A `Tensor`, or a structure shaped like `_get_shallow_structure(struct)`
    whose entries are the recursively converted children.
  """
  if dtype is not None or not tf.nest.is_nested(struct):
    # Forward `name` here as well, matching the wholesale conversion below.
    return tf.convert_to_tensor(struct, dtype=dtype, name=name)

  if _maybe_convertible_to_tensor(struct):
    try:
      # Try converting the structure wholesale.
      return tf.convert_to_tensor(value=struct, name=name)
    except (ValueError, TypeError):
      # Unfortunately Eager/Graph mode don't agree on the error type.
      pass
  # Try converting all of its children.
  shallow_struct = _get_shallow_structure(struct)
  return nest.map_structure_up_to(
      shallow_struct, lambda s: _nested_convert_to_tensor(s, name=name), struct)
    def testNestMapStructureUpTo(self):
        """`map_structure_up_to` with `expand_composites=True` on composites.

        The first structure is the shallow structure. Where the second
        structure holds a nested composite at a position that is a plain leaf
        in the first, that whole subtree is handed to the mapped function
        (which leaves non-ints untouched), so it survives unchanged.
        """
        def add_ten_to_ints(value):
            if isinstance(value, int):
                return value + 10
            return value

        shallow = [[TestCompositeTensor(1, 2, 3)], 100,
                   {'y': TestCompositeTensor(5, 6)}]
        deep = [[TestCompositeTensor(1, 2, 3)], 100,
                {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}]

        got = nest.map_structure_up_to(shallow, add_ten_to_ints, deep,
                                       expand_composites=True)

        want = [[TestCompositeTensor(11, 12, 13)], 110,
                {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 16)}]
        self.assertEqual(got, want)
示例#33
0
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      validate_args: Python `bool`. Whether the joint distribution (and every
        nested joint distribution built for sub-structures) should validate
        input with asserts. This imposes a runtime cost. If `validate_args` is
        `False`, and the inputs are invalid, correct behavior is not
        guaranteed.
        Default value: `False`.
    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
        If `structure_of_distributions` was indeed a structure (as opposed to
        a single `Distribution` instance), this will be a `JointDistribution`
        with the corresponding structure: `JointDistributionNamed` for
        mappings and namedtuples, `JointDistributionSequential` otherwise.
    Raises:
      TypeError: if any leaves of the input structure are not
        `tfd.Distribution` instances. (This function does not check this
        itself; presumably the `JointDistribution` constructors raise it.)
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC lives in
    # `collections.abc` (available on all Python 3 versions).
    from collections.abc import Mapping  # pylint: disable=g-import-not-at-top

    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # If this structure contains other structures (ie, has elements at depth > 1),
    # recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        # Forward `validate_args` so nested JDs honor the caller's setting.
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            lambda sub: independent_joint_distribution_from_structure(
                sub, validate_args=validate_args),
            structure_of_distributions)

    # Build a JD from the current structure (now only one level deep).
    if (hasattr(structure_of_distributions, '_asdict')
            or isinstance(structure_of_distributions, Mapping)):
        return joint_distribution_named.JointDistributionNamed(
            structure_of_distributions, validate_args=validate_args)
    return joint_distribution_sequential.JointDistributionSequential(
        structure_of_distributions, validate_args=validate_args)
示例#34
0
  def call(self, observations, step_type, network_state):
    """Maps observations to a structure of `tfp.distributions.Normal`.

    Only the first flattened observation is used. It is cast to float32 and
    passed through `self.layers` in order. The final output is split into two
    equal halves along axis 1, used as the `Normal` location and scale. Each
    half is reshaped to `[-1] + <shape of the first leaf of
    self._output_tensor_spec>` and packed into the structure of
    `self._output_tensor_spec`.

    Args:
      observations: Nested observations; only the first leaf is read.
      step_type: Unused.
      network_state: Passed through unchanged.

    Returns:
      A `(distribution, network_state)` tuple.
    """
    del step_type  # Not used by this network.

    first_observation = tf.nest.flatten(observations)[0]
    net = tf.cast(first_observation, tf.float32)
    for layer in self.layers:
      net = layer(net)

    output_spec = self._output_tensor_spec
    target_shape = [-1] + tf.nest.flatten(output_spec)[0].shape.as_list()

    loc_flat, scale_flat = tf.split(net, 2, axis=1)
    locs = tf.nest.pack_sequence_as(
        output_spec, [tf.reshape(loc_flat, target_shape)])
    scales = tf.nest.pack_sequence_as(
        output_spec, [tf.reshape(scale_flat, target_shape)])

    distribution = nest.map_structure_up_to(
        output_spec, tfp.distributions.Normal, locs, scales)
    return distribution, network_state
    def test_constrained_affine_from_distributions(self, dist_classes,
                                                   event_shape, operators,
                                                   initial_loc,
                                                   implicit_batch_shape,
                                                   bijector, dtype, is_static):
        """Affine surrogate posterior built from base distributions.

        Wraps each `dist_classes` entry (at `initial_loc`, unit scale) in a
        `tfd.Independent` over its `event_shape`, then builds the surrogate
        via `build_affine_surrogate_posterior_from_base_distribution` and
        checks its shapes, gradients and dtype. Expected event shapes are
        pushed through `bijector` where one is given.
        """
        if not (tf.executing_eagerly() or is_static):
            self.skipTest(
                'tfb.Reshape requires statically known shapes in graph'
                ' mode.')

        def _as_maybe_static(loc):
            return self.maybe_static(np.array(loc, dtype=dtype),
                                     is_static=is_static)

        def _make_base_distribution(dist_class, loc, shape):
            return tfd.Independent(
                dist_class(loc=loc, scale=1.),
                reinterpreted_batch_ndims=ps.rank_from_shape(shape))

        initial_loc = tf.nest.map_structure(_as_maybe_static, initial_loc)
        distributions = nest.map_structure_up_to(
            dist_classes, _make_base_distribution, dist_classes, initial_loc,
            event_shape)

        build_surrogate = (
            tfp.experimental.vi.
            build_affine_surrogate_posterior_from_base_distribution)
        surrogate_posterior = build_surrogate(distributions,
                                              operators=operators,
                                              bijector=bijector,
                                              validate_args=True)

        expected_event_shape = nest.map_structure(
            lambda dist: dist.event_shape_tensor(), distributions)
        if bijector is not None:

            def _forward_shape(part_bijector, shape):
                if part_bijector is None:
                    return shape
                return part_bijector.forward_event_shape_tensor(shape)

            expected_event_shape = nest.map_structure(
                _forward_shape, bijector, expected_event_shape)

        initializers = [
            var.initializer for var in surrogate_posterior.trainable_variables
        ]
        self.evaluate(initializers)

        # Seeds are drawn in a fixed order: shapes, gradients, dtype.
        seed = test_util.test_seed_stream()
        self._test_shapes(surrogate_posterior,
                          batch_shape=implicit_batch_shape,
                          event_shape=expected_event_shape,
                          seed=seed())
        self._test_gradients(surrogate_posterior, seed=seed())
        self._test_dtype(surrogate_posterior, dtype, seed())
示例#36
0
def py_func(func,
            args=(),
            kwargs=None,
            output_types=None,
            output_shapes=None,
            stateful=True,
            name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  This function is a wrapper around `tf.compat.v1.py_func` that adds support
  for keyword arguments (`kwargs`), nested output structures and
  `output_shapes`. It also renames some of the arguments.

  Given a python function `func`, which takes numpy arrays as its
  inputs and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  inp = tf.compat.v1.placeholder(tf.float32)
  y = tf.compat.v1.py_func(my_func, [inp], tf.float32)
  ```


  **N.B.** The `tf.compat.v1.py_func()` operation has the following known
  limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.compat.v1.py_func()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as
    the program that calls `tf.compat.v1.py_func()` and you must pin the
    created operation to a device in that server (e.g. using
    `with tf.device():`).

  Args:
    func: A Python function, called as `func(*args, **kwargs)` with the values
      of the tensors in `args` and `kwargs` (as delivered by the underlying
      `_py_func`). It must return a structure matching `output_types`; this is
      checked with `nest.assert_shallow_structure`.
    args: A list or tuple of `Tensor` objects.
    kwargs: A dict with `Tensor` objects as values (optional).
    output_types: A nested structure of TensorFlow data types, or a single
      data type if there is only one, describing what `func` returns. May also
      be a callable with the same signature as `func`; it is then called with
      `*args, **kwargs` to compute that structure.
    output_shapes: Same as `output_types`, except that the types are replaced
      with shapes (optional). May also be a callable, like `output_types`.
    stateful: (Boolean.) If True, the function should be considered stateful.
      If a function is stateless, when given the same input it will return the
      same output and have no observable side effects. Optimizations such as
      common subexpression elimination are only performed on stateless
      operations.
    name: A name for the operation (optional).

  Returns:
    The output tensors of the wrapping op, packed into the structure of
    `output_types`.

  Raises:
    TypeError: If `args` is not a list/tuple or `kwargs` is not a dict.
  """

  if kwargs is None:
    kwargs = {}

  if not isinstance(args, (list, tuple)):
    raise TypeError('args must be list and not {}. args: {}'.format(
        type(args), args))

  if not isinstance(kwargs, dict):
    raise TypeError('kwargs must be dict and not {}. kwargs: {}'.format(
        type(kwargs), kwargs))

  # For dynamic type inference, `output_types` / `output_shapes` may be
  # callables with the same signature as `func`; they are called with the
  # input tensors to produce the actual structure.
  if callable(output_types):
    output_types = output_types(*args, **kwargs)
  if callable(output_shapes):
    output_shapes = output_shapes(*args, **kwargs)

  flat_output_types = nest.flatten(output_types)
  # Bundle positional and keyword inputs into one structure so they can cross
  # the op boundary as a flat list and be re-packed inside the wrapper.
  structured_args = (args, kwargs)
  flat_args = nest.flatten(structured_args)

  def python_function_wrapper(*py_args):
    py_args, py_kwargs = nest.pack_sequence_as(structured_args, py_args)

    ret = func(*py_args, **py_kwargs)
    # TODO(alextp): Catch exceptions and improve the message, because
    # TensorFlow is not able to preserve the traceback, i.e. the exception
    # does not contain any information about where it was raised.
    nest.assert_shallow_structure(output_types, ret)
    return nest.flatten(ret)

  flat_values = _py_func(
      python_function_wrapper,
      flat_args,
      flat_output_types,
      stateful=stateful,
      name=name)

  if output_shapes is not None:
    # Use `output_types` as the shallow structure so that each shape spec
    # (which may itself be a list such as `[None, 3]`) is treated as a single
    # leaf and converted to a `TensorShape`, rather than being flattened into
    # individual dimensions by `nest.flatten` below.
    output_shapes = nest.map_structure_up_to(output_types,
                                             tensor_shape.as_shape,
                                             output_shapes)

    flattened_shapes = nest.flatten(output_shapes)
    for ret_t, shape in zip(flat_values, flattened_shapes):
      ret_t.set_shape(shape)

  return nest.pack_sequence_as(output_types, flat_values)
 def testNestMapStructureUpTo(self, s1, s2, expected):
   """Checks `nest.map_structure_up_to(s1, fn, s2, expand_composites=True)`.

   The mapped function adds 10 to `int` values and returns anything else
   unchanged; the mapped result must equal `expected`.
   """
   def add_ten_if_int(value):
     if isinstance(value, int):
       return value + 10
     return value

   self.assertEqual(
       nest.map_structure_up_to(s1, add_ten_if_int, s2,
                                expand_composites=True),
       expected)