def testMapStructureUpTo(self):
  """Checks `nest.map_structure_up_to` on namedtuples and nested lists."""
  # Namedtuples: each leaf of `values` is paired with an `op_tuple`.
  pair_type = collections.namedtuple("ab_tuple", "a, b")
  ops_type = collections.namedtuple("op_tuple", "add, mul")
  values = pair_type(a=2, b=3)
  operations = pair_type(a=ops_type(add=1, mul=2), b=ops_type(add=2, mul=3))

  def apply_ops(val, ops):
    return (val + ops.add) * ops.mul

  result = nest.map_structure_up_to(values, apply_ops, values, operations)
  self.assertEqual(result.a, 6)
  self.assertEqual(result.b, 15)

  # Nested lists: the shallow structure stops at the string names, so each
  # name is paired with a whole inner list from `sections`.
  sections = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  names = ["evens", ["odds", "primes"]]

  def describe(name, sec):
    return "first_{}_{}".format(len(sec), name)

  result = nest.map_structure_up_to(names, describe, names, sections)
  self.assertEqual(result,
                   ["first_4_evens", ["first_5_odds", "first_3_primes"]])
def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
  """Like `nest.map_structure_up_to`, but prepends a running call index.

  `map_fn` receives the number of previously completed calls as its first
  positional argument, followed by whatever `nest.map_structure_up_to` passes
  for the current position. The index therefore follows that function's
  traversal order, starting at 0.

  Args:
    shallow_structure: Shallow structure forwarded to
      `nest.map_structure_up_to`.
    map_fn: Callable taking `(index, *leaves)`.
    *args: Structures to map over.
    **kwargs: Forwarded unchanged to `nest.map_structure_up_to`.

  Returns:
    The result of `nest.map_structure_up_to`.
  """
  call_count = 0

  def enumerated_fn(*inner_args, **inner_kwargs):
    nonlocal call_count
    result = map_fn(call_count, *inner_args, **inner_kwargs)
    # Only advance after a successful call, matching the original counter.
    call_count += 1
    return result

  return nest.map_structure_up_to(
      shallow_structure, enumerated_fn, *args, **kwargs)
def event_shape(self):
  """Shape of a single sample, as a `TensorShape`.

  May be partially defined or unknown.

  Returns:
    event_shape: `TensorShape` (one per part, structured up to `self.dtype`),
      possibly unknown.
  """
  to_tensor_shape = tf.TensorShape
  return nest.map_structure_up_to(
      self.dtype, to_tensor_shape, self._event_shape)
def one_step(self,
             new_chain_state,
             current_reducer_state,
             previous_kernel_results=None,
             axis=None):
  """Updates `current_reducer_state` with a new chain state.

  Chunking semantics mirror those of batching and are controlled by `axis`.
  When `axis` is not `None`, every element along that axis is treated as a
  separate sample. A single, non-nested `axis` value is broadcast to every part
  of `new_chain_state`; otherwise `axis` must have a matching structure.

  Args:
    new_chain_state: A (possibly nested) structure of incoming chain state(s)
      with shape and dtype compatible with those used to initialize the
      `current_reducer_state`.
    current_reducer_state: `CovarianceReducerState`s representing the current
      state of the running covariance.
    previous_kernel_results: Not read by this method; accepted only so the
      signature fits the `Reducer` base class. Defaults to `None`.
    axis: If chunking is desired, a (possibly nested) structure of integers
      giving the axis with chunked samples. Use `None` (the default) for
      individual, unchunked samples.

  Returns:
    new_reducer_state: `CovarianceReducerState` with updated running
      statistics. Its `cov_state` field has an identical structure to the
      `current_reducer_state`.
  """
  scope_name = mcmc_util.make_name(self.name, 'covariance_reducer', 'one_step')
  with tf.name_scope(scope_name):
    init_structure = current_reducer_state.init_structure
    streams = _prepare_args(init_structure, self.event_ndims)
    chain_state = tf.nest.map_structure(tf.convert_to_tensor, new_chain_state)
    if not nest.is_nested(axis):
      # One (possibly `None`) axis value applies to every state part.
      axis = nest_util.broadcast_structure(chain_state, axis)

    def _update_stream(stream, *update_args):
      return stream.update(*update_args)

    updated_cov_state = nest.map_structure_up_to(
        init_structure,
        _update_stream,
        streams,
        current_reducer_state.cov_state,
        chain_state,
        axis,
        check_types=False,
    )
    return CovarianceReducerState(init_structure, updated_cov_state)
def _flatten_nested_jacobian(jacobian, state_shape):
  """Flattens a nested Jacobian into a matrix.

  The flattening and concatenation follows the interpretation of the structure
  as being a leading 'axis', meaning that if the input has 'shape':
  [input_structure, A, B], and the output has 'shape': [output_structure, C,
  D], the Jacobian should have the 'shape': [input_structure,
  output_structure, A, B, C, D]. As with the regular axes, the encoding is
  input major.

  Args:
    jacobian: A nested Jacobian.
    state_shape: A nested collection of state shapes.

  Returns:
    jacobian_mat: The Jacobian matrix.

  #### Examples

  Non-structured state:

  ```python
  input = tf.zeros([1, 2])
  output = tf.zeros([3])
  jacobian = tf.zeros([1, 2, 3])
  ```

  Structured state:

  ```python
  input = {'x': tf.zeros([1, 2])}
  output = {'y': tf.zeros([3])}
  jacobian = {'x': {'y': tf.zeros([1, 2, 3])}}
  ```

  A more complicated structure:

  ```python
  input = [tf.zeros([1, 2]), tf.zeros([])]
  output = {'y': tf.zeros([3])}
  jacobian = [{'y': tf.zeros([1, 2, 3])}, {'y': tf.zeros([3])}]
  ```
  """

  def _flatten_row(jacobian_row, state_shape_part):
    # Every output block in this row is reshaped to prod(state_shape_part)
    # rows, then the blocks are concatenated column-wise.
    num_rows = ps.reduce_prod(state_shape_part)
    target_shape = ps.stack([num_rows, -1], axis=0)
    row_blocks = tf.nest.map_structure(
        lambda block: tf.reshape(block, target_shape), jacobian_row)
    return tf.concat(tf.nest.flatten(row_blocks), axis=-1)

  # One row-matrix per input part; stack them vertically (input major).
  row_matrices = nest.map_structure_up_to(
      state_shape, _flatten_row, jacobian, state_shape)
  return tf.concat(tf.nest.flatten(row_matrices), axis=0)
def testMapStructureUpTo(self):
  """Exercises map_structure_up_to on namedtuples, lists, dicts and mappings.

  Uses `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated and
  was removed from `unittest` in Python 3.12.
  """

  def apply_dict_ops(val, ops):
    return (val + ops["add"]) * ops["mul"]

  # Named tuples.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = nest.map_structure_up_to(
      inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
  self.assertEqual(out.a, 6)
  self.assertEqual(out.b, 15)

  # Lists.
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ["evens", ["odds", "primes"]]
  out = nest.map_structure_up_to(
      name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)
  self.assertEqual(out,
                   ["first_4_evens", ["first_5_odds", "first_3_primes"]])

  # Dicts.
  inp_val = dict(a=2, b=3)
  inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
  out = nest.map_structure_up_to(inp_val, apply_dict_ops, inp_val, inp_ops)
  self.assertEqual(out["a"], 6)
  self.assertEqual(out["b"], 15)

  # Non-equal dicts.
  inp_val = dict(a=2, b=3)
  inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
  with self.assertRaisesRegex(ValueError, "same keys"):
    nest.map_structure_up_to(inp_val, apply_dict_ops, inp_val, inp_ops)

  # Dict+custom mapping.
  inp_val = dict(a=2, b=3)
  inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
  out = nest.map_structure_up_to(inp_val, apply_dict_ops, inp_val, inp_ops)
  self.assertEqual(out["a"], 6)
  self.assertEqual(out["b"], 15)

  # Non-equal dict/mapping.
  inp_val = dict(a=2, b=3)
  inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
  with self.assertRaisesRegex(ValueError, "same keys"):
    nest.map_structure_up_to(inp_val, apply_dict_ops, inp_val, inp_ops)
def build(self, y_pred, y_true):
  """One-time setup of metric objects.

  Reshapes the configured metrics and weighted metrics to follow the
  structure of `y_pred`. It then converts them to `Metric` objects, flattens
  them up to the outputs, names and orders them, and marks the container as
  built.

  Args:
    y_pred: Structure of model predictions; defines the output structure the
      metrics are mapped and flattened against.
    y_true: Structure of targets, passed alongside `y_pred` to
      `_get_metric_objects`.
  """
  super(MetricsContainer, self).build(y_pred)

  # Presumably broadcasts a single metric spec to all outputs and then
  # conforms the result to the output structure (per helper names).
  self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
  self._metrics = self._conform_to_outputs(y_pred, self._metrics)

  self._weighted_metrics = self._maybe_broadcast_to_outputs(
      y_pred, self._weighted_metrics)
  self._weighted_metrics = self._conform_to_outputs(
      y_pred, self._weighted_metrics)

  # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
  y_pred = nest.list_to_tuple(y_pred)
  y_true = nest.list_to_tuple(y_true)
  self._metrics = nest.list_to_tuple(self._metrics)
  self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

  # Convert to `Metric` objects, potentially disambiguating based on output
  # properties.
  self._metrics = nest.map_structure_up_to(y_pred, self._get_metric_objects,
                                           self._metrics, y_true, y_pred)
  self._weighted_metrics = nest.map_structure_up_to(y_pred,
                                                    self._get_metric_objects,
                                                    self._weighted_metrics,
                                                    y_true, y_pred)

  # Flatten to one entry per leaf of `y_pred`.
  self._metrics = nest.flatten_up_to(y_pred, self._metrics, check_types=False)
  self._weighted_metrics = nest.flatten_up_to(
      y_pred, self._weighted_metrics, check_types=False)

  # Assumes metrics, weighted_metrics have been flattened up to outputs.
  #
  # If we are loading a model that has been already serialized, we do not
  # want to re-apply any pre-processing metric renaming steps.
  if not self._from_serialized:
    self._set_metric_names()
  self._create_ordered_metrics()
  self._built = True
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
  """Executes `model`, creating both samples and distributions.

  Drives the model coroutine. Each yielded distribution (unwrapped from
  `Root` if needed) is recorded, and a value for it is sent back: either the
  caller-supplied `value[index]` converted to tensors, or a fresh sample.

  Args:
    sample_shape: Shape of samples drawn for `Root` distributions. Non-root
      distributions are sampled with shape `()`.
    seed: Seed used to build a `SeedStream`; one seed is consumed per
      distribution whether or not a value is supplied for it.
    value: Optional sequence of per-distribution values. Entries that are
      missing (index past the end) or `None` are sampled instead.

  Returns:
    ds: List of the distributions yielded by the coroutine, in order.
    values_out: List of the values sent back for each distribution.

  Raises:
    ValueError: if `self._require_root` and the first yielded distribution is
      not wrapped in `Root`.
  """
  ds = []
  values_out = []
  seed = SeedStream(seed, salt='JointDistributionCoroutine')
  gen = self._model_coroutine()
  index = 0
  d = next(gen)
  if self._require_root and not isinstance(d, self.Root):
    raise ValueError('First distribution yielded by coroutine must '
                     'be wrapped in `Root`.')
  try:
    while True:
      actual_distribution = d.distribution if isinstance(d, self.Root) else d
      ds.append(actual_distribution)
      if (value is not None and len(value) > index and
          value[index] is not None):
        seed()  # Ensure reproducibility even when xs are (partially) set.

        def convert_tree_to_tensor(x, dtype_hint):
          return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

        # This signature does not allow kwarg names. Applies
        # `convert_to_tensor` on the next value.
        next_value = nest.map_structure_up_to(
            ds[-1].dtype,  # shallow_tree
            convert_tree_to_tensor,  # func
            value[index],  # x
            ds[-1].dtype)  # dtype_hint
      else:
        # Only `Root` distributions get the full `sample_shape`.
        next_value = actual_distribution.sample(
            sample_shape=sample_shape if isinstance(d, self.Root) else (),
            seed=seed())

      if self._validate_args:
        # Tie the shape assertion to the recorded value via identity ops.
        with tf.control_dependencies(
            self._assert_compatible_shape(index, sample_shape, next_value)):
          values_out.append(tf.nest.map_structure(tf.identity, next_value))
      else:
        values_out.append(next_value)

      index += 1
      # Feed the value back; the coroutine yields the next distribution or
      # raises StopIteration when the model is exhausted.
      d = gen.send(next_value)
  except StopIteration:
    pass
  return ds, values_out
def testMapStructureUpTo(self):
  """Covers namedtuple and nested-list inputs to map_structure_up_to."""
  # Named tuples.
  Pair = collections.namedtuple("ab_tuple", "a, b")
  Ops = collections.namedtuple("op_tuple", "add, mul")
  numbers = Pair(a=2, b=3)
  transforms = Pair(a=Ops(add=1, mul=2), b=Ops(add=2, mul=3))
  mapped = nest.map_structure_up_to(
      numbers,
      lambda num, op: (num + op.add) * op.mul,
      numbers,
      transforms)
  self.assertEqual(mapped.a, 6)
  self.assertEqual(mapped.b, 15)

  # Lists: shallow at the name level, so whole sublists reach the lambda.
  buckets = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  labels = ["evens", ["odds", "primes"]]
  mapped = nest.map_structure_up_to(
      labels,
      lambda label, bucket: "first_{}_{}".format(len(bucket), label),
      labels,
      buckets)
  expected = ["first_4_evens", ["first_5_odds", "first_3_primes"]]
  self.assertEqual(mapped, expected)
def log_prob_ratio_parts_fn(x_y):
  """Applies `lp_ratio.log_prob_ratio` part-by-part to paired values.

  Args:
    x_y: Structure whose leaves are indexable pairs. Element 0 of each pair is
      the value for `p` and element 1 is the value for `q`.

  Returns:
    A structure (up to `p`'s distributions) of per-part log-prob ratios.
  """
  x = tf.nest.map_structure(lambda pair: pair[0], x_y)
  y = tf.nest.map_structure(lambda pair: pair[1], x_y)
  # Only element [0] of each result is kept (presumably the distributions);
  # a fixed zero seed is used for both.
  p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
  q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]

  def _part_ratio(p_part, x_part, q_part, y_part, axis_names):
    # Ensure sharded distributions defer reductions.
    extra_kwargs = {'reduce_over_shards': False} if axis_names else {}
    return lp_ratio.log_prob_ratio(p_part, x_part, q_part, y_part,
                                   **extra_kwargs)

  return nest.map_structure_up_to(
      p_dists, _part_ratio, p_dists, x, q_dists, y, p_axis_names)
def _loop_body(step, current_estimate, accumulated_quantities):
  """Body of the `while_loop` for running forward filtering.

  Gathers the observation at index `step` from every part of `observations`,
  applies `run_ekf_step`, and records the result.

  Returns:
    Tuple `(step + 1, updated_estimate, new_accumulated_quantities)`.
  """

  def _observation_at_step(series):
    return tf.gather(series, step)

  current_observations = tf.nest.map_structure(
      _observation_at_step, observations)
  updated_estimate = nest.map_structure_up_to(
      current_observations, run_ekf_step,
      current_estimate, current_observations)
  new_accumulated_quantities = _write_accumulated_quantities(
      step, accumulated_quantities, updated_estimate)
  return step + 1, updated_estimate, new_accumulated_quantities
def normalize(self,
              tensor,
              clip_value=5.0,
              center_mean=True,
              variance_epsilon=1e-3):
  """Applies normalization to tensor.

  Args:
    tensor: Tensor (or nest matching `self._tensor_spec`) to normalize. It is
      cast to float32 first.
    clip_value: When > 0, normalized values are clipped to
      [-clip_value, clip_value]; otherwise no clipping is applied.
    center_mean: If True, the estimated mean is subtracted; otherwise a zero
      mean is used.
    variance_epsilon: Epsilon to avoid division by zero in normalization.

  Returns:
    normalized_tensor: Tensor after applying normalization.
  """
  tf.nest.assert_same_structure(tensor, self._tensor_spec)
  float_tensor = tf.nest.map_structure(
      lambda part: tf.cast(part, tf.float32), tensor)

  with tf.name_scope(self._scope + '/normalize'):
    mean_estimate, var_estimate = self._get_mean_var_estimates()
    if center_mean:
      mean = mean_estimate
    else:
      mean = tf.nest.map_structure(tf.zeros_like, mean_estimate)

    def _normalize_single_tensor(single_tensor, single_mean, single_var):
      return tf.nn.batch_normalization(
          single_tensor,
          single_mean,
          single_var,
          offset=None,
          scale=None,
          variance_epsilon=variance_epsilon,
          name='normalized_tensor')

    normalized = nest.map_structure_up_to(
        self._tensor_spec, _normalize_single_tensor,
        float_tensor, mean, var_estimate)

    if clip_value > 0:

      def _clip(part):
        return tf.clip_by_value(part, -clip_value, clip_value,
                                name='clipped_normalized_tensor')

      normalized = tf.nest.map_structure(_clip, normalized)

    return normalized
def test_constrained_affine_from_distributions(self,
                                               dist_classes,
                                               event_shape,
                                               operators,
                                               initial_loc,
                                               implicit_batch_shape,
                                               bijector,
                                               dtype,
                                               is_static,
                                               is_stateless=JAX_MODE):
  """Builds an affine surrogate posterior from base distributions and checks it.

  Each `dist_classes` entry is instantiated at its `initial_loc` (unit scale)
  and wrapped in `tfd.Independent` using the rank of its `event_shape`. The
  resulting surrogate's shapes and dtype are then tested, and its gradients
  are tested when not stateless.
  """
  if not tf.executing_eagerly() and not is_static:
    self.skipTest(
        'tfb.Reshape requires statically known shapes in graph'
        ' mode.')
  # Split once up front so each check gets its own fixed seed.
  init_seed, grads_seed, shapes_seed, dtype_seed = samplers.split_seed(
      test_util.test_seed(sampler_type='stateless'), n=4)
  # pylint: disable=g-long-lambda
  initial_loc = tf.nest.map_structure(
      lambda s: self.maybe_static(np.array(s, dtype=dtype),
                                  is_static=is_static),
      initial_loc)
  distributions = nest.map_structure_up_to(
      dist_classes,
      lambda d, loc, s: tfd.Independent(
          d(loc=loc, scale=1.),
          reinterpreted_batch_ndims=ps.rank_from_shape(s)),
      dist_classes, initial_loc, event_shape)
  # pylint: enable=g-long-lambda
  surrogate_posterior = self._initialize_surrogate(
      'build_affine_surrogate_posterior_from_base_distribution',
      is_stateless=is_stateless,
      seed=init_seed,
      base_distribution=distributions,
      operators=operators,
      bijector=bijector,
      validate_args=True)

  # Expected event shape: the base event shapes, pushed through the bijector
  # (per part, skipping `None` entries) when one is given.
  event_shape = nest.map_structure(lambda d: d.event_shape_tensor(),
                                   distributions)
  if bijector is not None:
    event_shape = nest.map_structure(
        lambda b, s: s if b is None else b.forward_event_shape_tensor(s),
        bijector, event_shape)

  self._test_shapes(surrogate_posterior,
                    batch_shape=implicit_batch_shape,
                    event_shape=event_shape,
                    seed=shapes_seed)
  self._test_dtype(surrogate_posterior, dtype, dtype_seed)
  if not is_stateless:
    self._test_gradients(surrogate_posterior, seed=grads_seed)
def _build(self, y_pred, y_true):
  """One-time setup of metric objects.

  Ensures output names exist and maps name-keyed metrics onto the outputs. It
  then broadcasts single metrics, converts them to `Metric` objects, flattens
  them up to the outputs, names them, and marks the container as built.

  Args:
    y_pred: Structure of model predictions; defines the output structure.
    y_true: Structure of targets, passed alongside `y_pred` to
      `_get_metric_objects`.
  """
  if self._output_names is None:
    # Subclass output names like 'output_1' are used for `Metric` names.
    self._output_names = create_output_names(y_pred)

  # Accept a dict of metrics keyed by output_name when outputs are a flat
  # list.
  self._metrics = map_to_output_names(y_pred, self._output_names,
                                      self._metrics)
  self._weighted_metrics = map_to_output_names(y_pred, self._output_names,
                                               self._weighted_metrics)

  # If a single metric is supplied, apply to all outputs.
  self._metrics = self._maybe_broadcast(self._metrics, y_pred)
  self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics,
                                                 y_pred)

  # Convert to `Metric` objects, potentially disambiguating based on output
  # properties.
  self._metrics = nest.map_structure_up_to(y_pred, self._get_metric_objects,
                                           self._metrics, y_true, y_pred)
  self._weighted_metrics = nest.map_structure_up_to(y_pred,
                                                    self._get_metric_objects,
                                                    self._weighted_metrics,
                                                    y_true, y_pred)

  # Flatten to one entry per leaf of `y_pred`.
  self._metrics = nest.flatten_up_to(y_pred, self._metrics, check_types=False)
  self._weighted_metrics = nest.flatten_up_to(
      y_pred, self._weighted_metrics, check_types=False)

  # Assumes metrics, weighted_metrics have been flattened up to outputs.
  self._set_metric_names()
  self._built = True
def _build(self, y_pred, y_true):
  """One-time setup of metric objects.

  Reshapes metrics to follow `y_pred` and standardizes lists to tuples. It
  then converts them to `Metric` objects, flattens them up to the outputs,
  names and orders them, and marks the container as built.

  Args:
    y_pred: Structure of model predictions; defines the output structure.
    y_true: Structure of targets, passed alongside `y_pred` to
      `_get_metric_objects`.
  """
  super(MetricsContainer, self)._build(y_pred)

  # Presumably broadcasts a single metric spec to all outputs and then
  # conforms the result to the output structure (per helper names).
  self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
  self._metrics = self._conform_to_outputs(y_pred, self._metrics)

  self._weighted_metrics = self._maybe_broadcast_to_outputs(
      y_pred, self._weighted_metrics)
  self._weighted_metrics = self._conform_to_outputs(
      y_pred, self._weighted_metrics)

  # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
  # pylint: disable=protected-access
  y_pred = nest._list_to_tuple(y_pred)
  y_true = nest._list_to_tuple(y_true)
  self._metrics = nest._list_to_tuple(self._metrics)
  self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics)
  # pylint: enable=protected-access

  # Convert to `Metric` objects, potentially disambiguating based on output
  # properties.
  self._metrics = nest.map_structure_up_to(y_pred, self._get_metric_objects,
                                           self._metrics, y_true, y_pred)
  self._weighted_metrics = nest.map_structure_up_to(y_pred,
                                                    self._get_metric_objects,
                                                    self._weighted_metrics,
                                                    y_true, y_pred)

  # Flatten to one entry per leaf of `y_pred`.
  self._metrics = nest.flatten_up_to(y_pred, self._metrics, check_types=False)
  self._weighted_metrics = nest.flatten_up_to(
      y_pred, self._weighted_metrics, check_types=False)

  # Assumes metrics, weighted_metrics have been flattened up to outputs.
  self._set_metric_names()
  self._create_ordered_metrics()
  self._built = True
def nested_distributions_from_specs(specs, parameters):
  """Builds a nest of distributions from a nest of specs.

  Each spec's `build_distribution` is called with the matching entry of
  `parameters` unpacked as keyword arguments.

  Args:
    specs: A nest of distribution specs.
    parameters: A nest of distribution kwargs, matching `specs` up to its
      leaves.

  Returns:
    Nest of distribution instances with the same structure as the given specs.
  """

  def _build_one(spec, spec_kwargs):
    return spec.build_distribution(**spec_kwargs)

  return nest.map_structure_up_to(specs, _build_one, specs, parameters)
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how each tensor is partitioned. Must
      be structurally compatible with `dequeues` (checked up front).

  Returns:
    The same dequeues with appropriate xla_sharding attribute.
  """
  nest.assert_shallow_structure(dequeues, dims)
  tagged = nest.map_structure_up_to(
      dequeues,
      _tag_sharding_attribute_for_dequeued_tensor,
      dequeues,
      dims)
  return tagged
def reduce_tensors(structures, shallow=False):
  """Combines a sequence of identically structured values via `_reduce_entries`.

  Args:
    structures: A non-empty sequence of structures with matching layout.
    shallow: If True, only the top level is traversed, so `_reduce_entries`
      receives whole top-level entries (which may themselves be nested).
      Otherwise every leaf is reduced.

  Returns:
    `structures[0]` unchanged when only one structure is given; otherwise the
    reduced structure.
  """
  if len(structures) == 1:
    return structures[0]
  if not shallow:
    return nest.map_structure(_reduce_entries, *structures)

  first = structures[0]
  if isinstance(first, dict):
    shallow_tree = type(first)([(k, None) for k in first])
  elif isinstance(first, tuple) and hasattr(first, '_fields'):
    # namedtuple constructors take their fields positionally; passing a single
    # list (as done for plain sequences) raises TypeError.
    shallow_tree = type(first)(*[None for _ in first])
  else:
    shallow_tree = type(first)([None for _ in first])
  return nest.map_structure_up_to(shallow_tree, _reduce_entries, *structures)
def _from_compatible_tensor_list(
    self, tensor_list: List["ops.Tensor"]
) -> composite_tensor.CompositeTensor:
  """Reconstructs a value from a compatible flat list of `ops.Tensor`.

  The flat list is regrouped per component spec (using each spec's batchable
  flat tensor specs), decoded component by component, and handed to
  `self._from_components`.
  """
  to_flat_specs = functools.partial(
      get_batchable_flat_tensor_specs, context_spec=self)
  flat_specs = nest.map_structure(to_flat_specs, self._component_specs)
  grouped_tensors = nest.pack_sequence_as(flat_specs, tensor_list)
  components = nest.map_structure_up_to(
      self._component_specs,
      batchable_from_tensor_list,
      self._component_specs,
      grouped_tensors)
  return self._from_components(components)
def mc_sample(self,
              x,
              batch_size=None,
              steps=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
  """Runs `make_mc_sample_function` over `x` and concatenates batch outputs.

  Args:
    x: Input data, passed to `data_adapter.DataHandler`.
    batch_size: Passed to `DataHandler`.
    steps: Passed to `DataHandler` as `steps_per_epoch`.
    max_queue_size: Passed to `DataHandler`.
    workers: Passed to `DataHandler`.
    use_multiprocessing: Passed to `DataHandler`.

  Returns:
    Per-batch outputs concatenated with `concat`, converted via
    `tf_utils.to_numpy_or_python_type`.

  Raises:
    ValueError: if `x` yields no batches. Previously this surfaced as an
      UnboundLocalError on `batch_outputs`.
  """
  outputs = None
  batch_outputs = None
  with self.distribute_strategy.scope():
    data_handler = data_adapter.DataHandler(
        x=x,
        batch_size=batch_size,
        steps_per_epoch=steps,
        initial_epoch=0,
        epochs=1,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self)
    predict_function = self.make_mc_sample_function()
    for _, iterator in data_handler.enumerate_epochs():
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          tmp_batch_outputs = predict_function(iterator)
          if not data_handler.inferred_steps:
            context.async_wait()
          # Only commit once the (possibly async) step finished without error.
          batch_outputs = tmp_batch_outputs
          if outputs is None:
            outputs = nest.map_structure(
                lambda batch_output: [batch_output], batch_outputs)
          else:
            nest.map_structure_up_to(
                batch_outputs,
                lambda output, batch_output: output.append(batch_output),
                outputs, batch_outputs)
  if batch_outputs is None:
    raise ValueError('Expect x to be a non-empty array or dataset.')
  all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
  return tf_utils.to_numpy_or_python_type(all_outputs)
def testNestMapStructureUpTo(self):
  """map_structure_up_to with expand_composites stops at `shallow`'s leaves.

  The nested `TestCompositeTensor(4, 5)` in `deep` sits where `shallow` has a
  leaf, so it is passed to the function whole and left unchanged.
  """
  shallow = [[TestCompositeTensor(1, 2, 3)], 100,
             {'y': TestCompositeTensor(5, 6)}]
  deep = [[TestCompositeTensor(1, 2, 3)], 100,
          {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}]

  def add_ten_to_ints(leaf):
    if isinstance(leaf, int):
      return leaf + 10
    return leaf

  result = nest.map_structure_up_to(
      shallow, add_ten_to_ints, deep, expand_composites=True)
  self.assertEqual(
      result,
      [[TestCompositeTensor(11, 12, 13)], 110,
       {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 16)}])
def internal_trace_fn(curr_state, kr):
  """Calls `trace_fn`, appending finalized reducer results when `reducer` is set.

  Without a reducer, `trace_fn(curr_state, kr.inner_results)` is called.
  Otherwise each reducer's `finalize` is applied to its entry of
  `kr.reduction_results`, and the finalized structure is passed as a third
  positional argument.
  """
  if not reducer:
    return trace_fn(curr_state, kr.inner_results)

  def _finalize(one_reducer, reducer_state):
    return one_reducer.finalize(reducer_state)

  finalized = nest.map_structure_up_to(
      reducer, _finalize, reducer, kr.reduction_results)
  return trace_fn(curr_state, kr.inner_results, finalized)
def event_shape_tensor(self, name='event_shape_tensor'):
  """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

  Uses the static `event_shape` when every part is fully defined; otherwise
  falls back to `self._event_shape_tensor()`.

  Args:
    name: name to give to the op

  Returns:
    event_shape: `Tensor` (one per part, structured up to `self.dtype`).
  """
  with self._name_and_control_scope(name):
    static_parts = nest.flatten(self.event_shape)
    if all(tensorshape_util.is_fully_defined(part) for part in static_parts):
      shapes = nest.map_structure_up_to(
          self.dtype, tensorshape_util.as_list, self.event_shape,
          check_types=False)
    else:
      shapes = self._event_shape_tensor()

    def _to_int32_tensor(shape):
      return tf.identity(tf.convert_to_tensor(shape, dtype=tf.int32),
                         name='event_shape')

    return nest.map_structure_up_to(
        self.dtype, _to_int32_tensor, shapes, check_types=False)
def value_grad(v, value_axis_names, term_grads):
  """Computes reductions of output gradients.

  A `log_prob_parts` function takes in a list of values and outputs a log
  density for each input to the function. The vector-Jacobian product (VJP) of
  a `log_prob_parts` function thus needs to compute the gradient of each output
  term w.r.t. each input value. This function overrides the default VJP of an
  output term `j` w.r.t. an input value `i` to include an all-reduce-sum when:
  1) The gradient of `j` w.r.t. `i` is connected.
  2) `j` is a sharded term and `i` is an unsharded value.

  If these conditions do not hold, the gradient remains the same and
  corresponds to one of:
  1) The gradient of a sharded term w.r.t. a sharded value.
  2) The gradient of an unsharded term w.r.t. an unsharded value.
  3) The gradient of an unsharded term w.r.t. a sharded value.
  In any of these cases, no all-reduce-sum is necessary.

  Args:
    v: The output term of a `log_prob_part` function. Not read by this
      function.
    value_axis_names: A list of axis names indicating whether or not the value
      is sharded; `None` if no sharding.
    term_grads: Object whose `.grads` holds the gradient of the output term
      w.r.t. each of the input values to the `log_prob_part` function.

  Returns:
    The vector Jacobian product of `v` w.r.t. the input parts of the
    `log_prob_parts` function, or `None` if every gradient is `None`.
  """
  grads = term_grads.grads

  def psum_grads(term_grad, term_axis_names):
    needs_psum = (term_grad is not None and not value_axis_names and
                  term_axis_names)
    if not needs_psum:
      return term_grad
    # TODO(https://github.com/google/jax/issues/6022): This cast
    # shouldn't be here.
    return tf.cast(psum(term_grad, axis_name=term_axis_names), term_grad.dtype)

  total_grad = nest.map_structure_up_to(grads, psum_grads, grads, map_axes)
  flat_total = tf.nest.flatten(total_grad)
  if all(grad is None for grad in flat_total):
    return None
  return tf.add_n([
      grad for grad in flat_total
      if tfp_custom_gradient.is_valid_gradient(grad)
  ])
def initialize(self):
  """Initializes an empty `RunningPotentialScaleReductionState`.

  One `RunningVariance` is created per part of `self.independent_chain_ndims`,
  from the matching `self.shape` entry and `self.dtype` (broadcast to that
  structure).

  Returns:
    state: `RunningPotentialScaleReductionState` representing a stream
      of no inputs.
  """
  part_dtypes = nest_util.broadcast_structure(
      self.independent_chain_ndims, self.dtype)
  running_variances = nest.map_structure_up_to(
      self.independent_chain_ndims,
      RunningVariance.from_shape,
      self.shape,
      part_dtypes,
      check_types=False)
  return RunningPotentialScaleReductionState(running_variances)
def testExampleDoc1(self):
  """Pushes normal draws through the default event-space bijectors."""
  seed = test_util.test_seed_stream()
  model = TestModel()

  def _draw_normal(part_dtype, part_shape):
    return tf.random.normal(part_shape, dtype=part_dtype, seed=seed())

  unconstrained_values = tf.nest.map_structure(
      _draw_normal, model.dtype, model.event_shape)

  def _apply_bijector(bij, value):
    return bij(value)

  constrained_values = nest.map_structure_up_to(
      model.default_event_space_bijector,
      _apply_bijector,
      model.default_event_space_bijector,
      unconstrained_values,
  )
  self.assertGreaterEqual(self.evaluate(constrained_values), 0.)
def _action(self, time_step, policy_state): del time_step # Unused. if policy_state is None: policy_state = [0, 0] action_index, num_repeats = policy_state # pylint: disable=unpacking-non-sequence def _check_episode_length(): if action_index >= len(self._action_script): raise ValueError( "Episode is longer than the provided scripted policy. Consider " "setting a TimeLimit wrapper that stops episodes within the length" " of your scripted policy.") _check_episode_length() n, current_action = self._action_script[action_index] # If the policy has been executed n times get the next scripted action. # Allow users to disable entries in the scripted policy by setting n <= 0. while num_repeats >= n: action_index += 1 num_repeats = 0 _check_episode_length() n, current_action = self._action_script[action_index] num_repeats += 1 # To make it easier for the user we allow the actions in the script to be # lists instead of numpy arrays. Checking the arrays_nest requires us to # have the leaves be objects and not lists so we lift them into numpy # arrays. def actions_as_array(action_spec, action): return np.asarray(action, dtype=action_spec.dtype) current_action = nest.map_structure_up_to(self._action_spec, actions_as_array, self._action_spec, current_action) if not array_spec.check_arrays_nest(current_action, self._action_spec): raise ValueError( "Action at index {} does not match the environment's action_spec. " "Got: {}. Expected {}.".format(action_index, current_action, self._action_spec)) logging.info("Policy_state: %r", policy_state) return policy_step.PolicyStep(current_action, [action_index, num_repeats])
def batchable_from_tensor_list(spec, tensor_list):
  """Returns a value with type `spec` decoded from `tensor_list`.

  Three cases are handled:
  - `TensorSpec`: expects exactly one tensor and returns it.
  - Specs with a `__batch_encoder__`: regroup the list per encoding spec,
    decode each part recursively, then call the encoder's `decode`.
  - Anything else: delegates to `spec._from_compatible_tensor_list`.
  """
  if isinstance(spec, tensor_spec.TensorSpec):
    assert len(tensor_list) == 1
    return tensor_list[0]

  if not hasattr(spec, "__batch_encoder__"):
    return spec._from_compatible_tensor_list(tensor_list)  # pylint: disable=protected-access

  encoder = spec.__batch_encoder__
  encoded_specs = encoder.encoding_specs(spec)
  flat_specs = nest.map_structure(get_batchable_flat_tensor_specs,
                                  encoded_specs)
  grouped_flats = nest.pack_sequence_as(flat_specs, tensor_list)
  encoded_value = nest.map_structure_up_to(
      encoded_specs, batchable_from_tensor_list, encoded_specs, grouped_flats)
  return encoder.decode(spec, encoded_value)
def _nested_convert_to_tensor(struct, dtype=None, name=None):
  """Eagerly converts `struct` to a Tensor, recursing into children on failure.

  Args:
    struct: A value or (possibly nested) structure to convert.
    dtype: Optional dtype. When given (or when `struct` is not nested),
      `struct` is converted in a single call with no recursion.
    name: Optional name forwarded to every `tf.convert_to_tensor` call.

  Returns:
    A Tensor, or a structure of converted children when `struct` is nested
    and could not be converted wholesale.
  """
  if dtype is not None or not tf.nest.is_nested(struct):
    # `name` was previously dropped on this path although the other branches
    # forward it; pass it for consistency.
    return tf.convert_to_tensor(struct, dtype=dtype, name=name)

  if _maybe_convertible_to_tensor(struct):
    try:
      # Try converting the structure wholesale.
      return tf.convert_to_tensor(value=struct, name=name)
    except (ValueError, TypeError):
      # Unfortunately Eager/Graph mode don't agree on the error type.
      pass
  # Try converting all of its children.
  shallow_struct = _get_shallow_structure(struct)
  return nest.map_structure_up_to(
      shallow_struct, lambda s: _nested_convert_to_tensor(s, name=name),
      struct)
def testNestMapStructureUpTo(self):
  """Ints are bumped by 10; composites below the shallow leaves pass whole."""
  shallow_tree = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(5, 6)},
  ]
  input_tree = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
  ]

  def bump(leaf):
    return leaf + 10 if isinstance(leaf, int) else leaf

  expected = [
      [TestCompositeTensor(11, 12, 13)],
      110,
      {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 16)},
  ]
  actual = nest.map_structure_up_to(
      shallow_tree, bump, input_tree, expand_composites=True)
  self.assertEqual(actual, expected)
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
  """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    validate_args: Python `bool`. Whether the joint distribution should
      validate input with asserts. This imposes a runtime cost. If
      `validate_args` is `False`, and the inputs are invalid, correct behavior
      is not guaranteed. Applied to nested joint distributions as well as the
      outermost one.
      Default value: `False`.

  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.

  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances.
  """
  # `collections.Mapping` was removed in Python 3.10; the ABC lives in
  # `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  # If input is already a Distribution, just return it.
  if dist_util.is_distribution_instance(structure_of_distributions):
    return structure_of_distributions

  # If this structure contains other structures (ie, has elements at depth > 1),
  # recursively turn them into JDs, propagating `validate_args`.
  element_depths = nest.map_structure_with_tuple_paths(
      lambda path, x: len(path), structure_of_distributions)
  if max(tf.nest.flatten(element_depths)) > 1:
    next_level_shallow_structure = nest.get_traverse_shallow_structure(
        traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
        structure=element_depths)
    structure_of_distributions = nest.map_structure_up_to(
        next_level_shallow_structure,
        lambda sub: independent_joint_distribution_from_structure(  # pylint: disable=g-long-lambda
            sub, validate_args=validate_args),
        structure_of_distributions)

  # Build a JD from the (now single-level) structure: named for namedtuples
  # and mappings, sequential otherwise.
  if (hasattr(structure_of_distributions, '_asdict') or
      isinstance(structure_of_distributions, collections_abc.Mapping)):
    return joint_distribution_named.JointDistributionNamed(
        structure_of_distributions, validate_args=validate_args)
  return joint_distribution_sequential.JointDistributionSequential(
      structure_of_distributions, validate_args=validate_args)
def call(self, observations, step_type, network_state):
  """Maps observations to a nest of `tfp.distributions.Normal`.

  Only the first flattened observation is used (cast to float32) and passed
  through `self.layers`. The final activations are split in half along axis 1:
  the first half gives the Normal's first argument and the second half its
  second. Both halves are reshaped to `[-1] + spec.shape` of the first output
  spec.

  Returns:
    `(distribution, network_state)`, with `network_state` passed through.
  """
  del step_type
  hidden = tf.cast(tf.nest.flatten(observations)[0], tf.float32)
  for layer in self.layers:
    hidden = layer(hidden)

  first_spec = tf.nest.flatten(self._output_tensor_spec)[0]
  first_half, second_half = tf.split(hidden, 2, axis=1)
  first_half = tf.reshape(first_half, [-1] + first_spec.shape.as_list())
  second_half = tf.reshape(second_half, [-1] + first_spec.shape.as_list())

  locs = tf.nest.pack_sequence_as(self._output_tensor_spec, [first_half])
  scales = tf.nest.pack_sequence_as(self._output_tensor_spec, [second_half])
  distribution = nest.map_structure_up_to(
      self._output_tensor_spec, tfp.distributions.Normal, locs, scales)
  return distribution, network_state
def test_constrained_affine_from_distributions(self,
                                               dist_classes,
                                               event_shape,
                                               operators,
                                               initial_loc,
                                               implicit_batch_shape,
                                               bijector,
                                               dtype,
                                               is_static):
  """Builds an affine surrogate posterior from base distributions and checks it.

  Each `dist_classes` entry is instantiated at its `initial_loc` (unit scale)
  and wrapped in `tfd.Independent` using the rank of its `event_shape`. The
  surrogate's variables are initialized, then its shapes, gradients and dtype
  are tested.
  """
  if not tf.executing_eagerly() and not is_static:
    self.skipTest(
        'tfb.Reshape requires statically known shapes in graph'
        ' mode.')
  # pylint: disable=g-long-lambda
  initial_loc = tf.nest.map_structure(
      lambda s: self.maybe_static(np.array(s, dtype=dtype),
                                  is_static=is_static),
      initial_loc)
  distributions = nest.map_structure_up_to(
      dist_classes,
      lambda d, loc, s: tfd.Independent(
          d(loc=loc, scale=1.),
          reinterpreted_batch_ndims=ps.rank_from_shape(s)),
      dist_classes, initial_loc, event_shape)
  # pylint: enable=g-long-lambda
  surrogate_posterior = (
      tfp.experimental.vi.
      build_affine_surrogate_posterior_from_base_distribution(
          distributions,
          operators=operators,
          bijector=bijector,
          validate_args=True))

  # Expected event shape: the base event shapes, pushed through the bijector
  # (per part, skipping `None` entries) when one is given.
  event_shape = nest.map_structure(lambda d: d.event_shape_tensor(),
                                   distributions)
  if bijector is not None:
    event_shape = nest.map_structure(
        lambda b, s: s if b is None else b.forward_event_shape_tensor(s),
        bijector, event_shape)

  self.evaluate(
      [v.initializer for v in surrogate_posterior.trainable_variables])

  # Each `seed()` call draws the next seed, so the order of the checks below
  # determines which seed each receives.
  seed = test_util.test_seed_stream()
  self._test_shapes(surrogate_posterior,
                    batch_shape=implicit_batch_shape,
                    event_shape=event_shape,
                    seed=seed())
  self._test_gradients(surrogate_posterior, seed=seed())
  self._test_dtype(surrogate_posterior, dtype, seed())
def py_func(func,
            args=(),
            kwargs=None,
            output_types=None,
            output_shapes=None,
            stateful=True,
            name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  This function wraps `tf.compat.v1.py_func` and extends it with:
  - keyword arguments (`kwargs`);
  - nested output structures;
  - optional static output shapes (`output_shapes`).
  Some argument names also differ from `tf.compat.v1.py_func`:
  - its `inp` is split into `args` and `kwargs`;
  - its `Tout` is named `output_types`.

  Given a python function `func`, which takes numpy arrays as its inputs and
  returns numpy arrays as its outputs, wrap this function as an operation in
  a TensorFlow graph. Example:

  ```python
  def my_func(x, scale):
    # x and scale are numpy arrays holding the values of the tensors below.
    return np.sinh(x) * scale

  inp = tf.compat.v1.placeholder(tf.float32)
  scale = tf.constant(2.0)
  y = py_func(my_func, args=[inp], kwargs={'scale': scale},
              output_types=tf.float32)
  ```

  **N.B.** The underlying `tf.compat.v1.py_func()` operation has the
  following known limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls it. If you are using distributed TensorFlow, you must run a
    `tf.distribute.Server` in the same process as the program that creates
    the op, and you must pin the created operation to a device in that
    server (e.g. using `with tf.device():`).

  Args:
    func: A Python function, called as `func(*args, **kwargs)` with the
      NumPy values of the tensors in `args` and `kwargs`. It must return a
      structure of `ndarray`s matching `output_types` (or a single `ndarray`
      if `output_types` is a single dtype).
    args: A list or tuple of `Tensor` objects, passed positionally to
      `func`.
    kwargs: A dict with `Tensor` objects as values, passed as keyword
      arguments to `func`. Defaults to an empty dict.
    output_types: Either a nested structure of TensorFlow data types (or a
      single data type if there is only one), indicating what `func`
      returns, or a callable. A callable is invoked once, at op-construction
      time, as `output_types(*args, **kwargs)` with the input tensors and
      must return such a structure.
    output_shapes: Optional. Same structure as `output_types`, with each
      data type replaced by a shape (anything `tensor_shape.as_shape`
      accepts). It may also be a callable, invoked like `output_types`.
      When given, the static shapes of the outputs are set accordingly.
    stateful: (Boolean.) If True, the function should be considered
      stateful. If a function is stateless, when given the same input it
      will return the same output and have no observable side effects.
      Optimizations such as common subexpression elimination are only
      performed on stateless operations.
    name: A name for the operation (optional).

  Returns:
    The output tensors of the wrapped op, packed into the structure of
    `output_types`.

  Raises:
    TypeError: If `args` is not a list or tuple, or `kwargs` is not a dict.
      A mismatch between the structure returned by `func` and `output_types`
      is detected by `nest.assert_shallow_structure` only when the op runs.
  """
  if kwargs is None:
    kwargs = {}
  if not isinstance(args, (list, tuple)):
    raise TypeError('args must be list and not {}. args: {}'.format(
        type(args), args))
  if not isinstance(kwargs, dict):
    raise TypeError('kwargs must be dict and not {}. kwargs: {}'.format(
        type(kwargs), kwargs))

  # Dynamic type/shape inference: a callable is assumed to share `func`'s
  # signature and is called with the input tensors.
  if callable(output_types):
    output_types = output_types(*args, **kwargs)
  if callable(output_shapes):
    output_shapes = output_shapes(*args, **kwargs)

  flat_output_types = nest.flatten(output_types)
  # Positional and keyword inputs are flattened together so that the
  # underlying py_func sees a flat list of tensors. The wrapper below
  # restores the original (args, kwargs) structure.
  packed_inputs = (args, kwargs)
  flat_inputs = nest.flatten(packed_inputs)

  def python_function_wrapper(*flat_py_values):
    py_args, py_kwargs = nest.pack_sequence_as(packed_inputs, flat_py_values)
    ret = func(*py_args, **py_kwargs)
    # TODO(alextp): Catch exceptions and improve the message, because
    # TensorFlow is not able to preserve the traceback, i.e. the exception
    # does not contain any information about where it was raised.
    nest.assert_shallow_structure(output_types, ret)
    return nest.flatten(ret)

  flat_values = _py_func(
      python_function_wrapper,
      flat_inputs,
      flat_output_types,
      stateful=stateful,
      name=name)

  if output_shapes is not None:
    # Normalize every entry to a `TensorShape`. `map_structure_up_to` also
    # fails if `output_shapes` does not follow the structure of
    # `output_types`.
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
    flattened_shapes = nest.flatten(output_shapes)
    for ret_t, shape in zip(flat_values, flattened_shapes):
      ret_t.set_shape(shape)

  return nest.pack_sequence_as(output_types, flat_values)
def testNestMapStructureUpTo(self, s1, s2, expected):
  """Checks `nest.map_structure_up_to` with `expand_composites=True`.

  Each element of `s2` at a leaf position of `s1` is handled as follows:
  - if it is an `int`, it is increased by 10 (`isinstance` is used, so
    `bool` counts as `int`);
  - otherwise it is returned unchanged.
  The mapped result must equal `expected`.
  """

  def _add_ten_to_ints(leaf):
    if isinstance(leaf, int):
      return leaf + 10
    return leaf

  actual = nest.map_structure_up_to(
      s1, _add_ten_to_ints, s2, expand_composites=True)
  self.assertEqual(actual, expected)