def _build(self, y_pred, y_true):
  """Performs the one-time construction of this container's metric objects.

  Args:
    y_pred: Structure of model outputs. The configured metrics are conformed
      to it and finally flattened up to it.
    y_true: Structure of targets, passed along to `_get_metric_objects`.
  """
  super(MetricsContainer, self)._build(y_pred)

  self._metrics = self._conform_to_outputs(
      y_pred, self._maybe_broadcast_to_outputs(y_pred, self._metrics))
  self._weighted_metrics = self._conform_to_outputs(
      y_pred,
      self._maybe_broadcast_to_outputs(y_pred, self._weighted_metrics))

  # `tf.data` turns lists into `Tensor`s, so every structure is standardized
  # on tuples before it is used as a nest.
  outputs = nest.list_to_tuple(y_pred)
  targets = nest.list_to_tuple(y_true)
  self._metrics = nest.list_to_tuple(self._metrics)
  self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

  # Turn every entry into `Metric` objects, potentially disambiguating based
  # on output properties.
  self._metrics = nest.map_structure_up_to(
      outputs, self._get_metric_objects, self._metrics, targets, outputs)
  self._weighted_metrics = nest.map_structure_up_to(
      outputs, self._get_metric_objects, self._weighted_metrics, targets,
      outputs)

  self._metrics = nest.flatten_up_to(
      outputs, self._metrics, check_types=False)
  self._weighted_metrics = nest.flatten_up_to(
      outputs, self._weighted_metrics, check_types=False)

  # The two calls below assume metrics and weighted metrics have already been
  # flattened up to the outputs.
  self._set_metric_names()
  self._create_ordered_metrics()
  self._built = True
def _build(self, y_pred, y_true):
  """Performs the one-time construction of this container's metric objects.

  Args:
    y_pred: Structure of model outputs; metrics are mapped and flattened up
      to this structure.
    y_true: Structure of targets, passed along to `_get_metric_objects`.
  """
  if self._output_names is None:
    # Pseudo output names such as 'output_1' end up in the `Metric` names.
    self._output_names = create_pseudo_output_names(y_pred)

  # A single metric, or a flat list of metrics, applies to all outputs.
  self._metrics = self._maybe_broadcast(self._metrics, y_pred)
  self._weighted_metrics = self._maybe_broadcast(
      self._weighted_metrics, y_pred)

  # A dict keyed by output name is accepted when outputs are a flat list.
  self._metrics = map_to_output_names(
      y_pred, self._output_names, self._metrics)
  self._weighted_metrics = map_to_output_names(
      y_pred, self._output_names, self._weighted_metrics)

  # `tf.data` turns lists into `Tensor`s, so standardize on tuples.
  # pylint: disable=protected-access
  outputs = nest._list_to_tuple(y_pred)
  targets = nest._list_to_tuple(y_true)
  self._metrics = nest._list_to_tuple(self._metrics)
  self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics)
  # pylint: enable=protected-access

  # Turn every entry into `Metric` objects, potentially disambiguating based
  # on output properties.
  self._metrics = nest.map_structure_up_to(
      outputs, self._get_metric_objects, self._metrics, targets, outputs)
  self._weighted_metrics = nest.map_structure_up_to(
      outputs, self._get_metric_objects, self._weighted_metrics, targets,
      outputs)

  self._metrics = nest.flatten_up_to(
      outputs, self._metrics, check_types=False)
  self._weighted_metrics = nest.flatten_up_to(
      outputs, self._weighted_metrics, check_types=False)

  # Naming assumes both lists have been flattened up to the outputs.
  self._set_metric_names()

  # Cache the flat order used when returning metrics (backwards compat): for
  # each output, its metrics followed by its weighted metrics, skipping None.
  ordered = []
  for per_output, per_output_weighted in zip(self._metrics,
                                             self._weighted_metrics):
    ordered.extend(m for m in nest.flatten(per_output) if m is not None)
    ordered.extend(
        wm for wm in nest.flatten(per_output_weighted) if wm is not None)
  self._metrics_in_order = ordered
  self._built = True
def testNestFlattenUpTo(self):
  """Checks `nest.flatten_up_to` with and without composite expansion."""
  shallow = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(5, 6)},
  ]
  deep = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
  ]

  # With expansion, composites present in the shallow structure are opened.
  expanded = nest.flatten_up_to(shallow, deep, expand_composites=True)
  self.assertEqual(expanded, [1, 2, 3, 100, TestCompositeTensor(4, 5), 6])

  # Without expansion, every composite is returned as a single leaf.
  unexpanded = nest.flatten_up_to(shallow, deep, expand_composites=False)
  self.assertEqual(unexpanded, [
      TestCompositeTensor(1, 2, 3),
      100,
      TestCompositeTensor(TestCompositeTensor(4, 5), 6),
  ])
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Args:
    function: Concrete function whose `graph.structured_input_signature` is
      the expected structure.
    inputs: Structured candidate arguments.
    allow_conversion: If true, arguments expected to be `TensorSpec`s are
      first passed through `_try_convert_to_tensor_spec`.

  Returns:
    True only if `inputs` matches the expected structure and *every* flattened
    argument is compatible with its expected counterpart; False otherwise.
  """
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False

  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      # Only bail out on a mismatch. Returning the compatibility result
      # directly would end the loop here and leave all remaining arguments
      # unchecked.
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      # Tensors outside the signature must be the very same object.
      if id(arg) != id(expected):
        return False
    else:
      if arg != expected:
        return False
  return True
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  Unlike `function.__call__`, inputs and outputs are structured and inputs are
  cast to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be done
  beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None when the call yields an
    `ops.Operation`.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs, expand_composites=True)
  flat_specs = nest.flatten(signature, expand_composites=True)

  # Only positions whose expected entry is a `TensorSpec` are fed to the
  # function; everything else is dropped here.
  tensor_inputs = [
      ops.convert_to_tensor(arg, dtype_hint=spec.dtype)
      for arg, spec in zip(flat_args, flat_specs)
      if isinstance(spec, tensor_spec.TensorSpec)
  ]
  # pylint: disable=protected-access
  result = function._call_flat(tensor_inputs, function._captured_inputs)
  # pylint: enable=protected-access
  return None if isinstance(result, ops.Operation) else result
def batch_shape_tensor(self, name='batch_shape_tensor'):
  """Shape of a single sample from a single event index as a 1-D `Tensor`.

  The batch dimensions are indexes into independent, non-identical
  parameterizations of this distribution.

  Args:
    name: name to give to the op

  Returns:
    batch_shape: `Tensor`.
  """
  with self._name_and_control_scope(name):
    # A joint distribution may carry either a structured batch shape or one
    # batch shape shared by all components, while simple distributions always
    # have a single one. A `tf.TensorShape` instance is taken to mean
    # "unstructured".
    if isinstance(self.batch_shape, tf.TensorShape):
      shallow_structure = None
    else:
      shallow_structure = self.dtype

    component_shapes = nest.flatten_up_to(
        shallow_structure, self.batch_shape, check_types=False)
    statically_known = all(
        [tensorshape_util.is_fully_defined(s) for s in component_shapes])

    if statically_known:
      batch_shape = nest.map_structure_up_to(
          shallow_structure, tensorshape_util.as_list, self.batch_shape,
          check_types=False)
    else:
      batch_shape = self._batch_shape_tensor()

    def _as_named_int32_tensor(s):
      return tf.identity(
          tf.convert_to_tensor(s, dtype=tf.int32), name='batch_shape')

    return nest.map_structure_up_to(
        shallow_structure, _as_named_int32_tensor, batch_shape,
        check_types=False)
def call(self, observation, step_type=None, network_state=()):
  """Applies preprocessing, the optional combiner and postprocessing layers.

  Args:
    observation: Nest of observations matching `self.input_tensor_spec`.
    step_type: Ignored.
    network_state: Returned unchanged.

  Returns:
    A `(states, network_state)` tuple.
  """
  del step_type  # unused.

  squash = self._batch_squash
  if squash:
    outer_rank = nest_utils.get_outer_rank(observation,
                                           self.input_tensor_spec)
    batch_squash = utils.BatchSquash(outer_rank)
    observation = tf.nest.map_structure(batch_squash.flatten, observation)

  if self._preprocessing_layers is None:
    states = observation
  else:
    # Pair each flattened observation with its preprocessing layer.
    flat_observations = nest.flatten_up_to(self.input_tensor_spec,
                                           observation)
    states = [
        layer(obs)
        for obs, layer in zip(flat_observations, self._preprocessing_layers)
    ]

  if self._preprocessing_combiner is not None:
    states = self._preprocessing_combiner(states)

  for layer in self._postprocessing_layers:
    states = layer(states)

  if squash:
    states = tf.nest.map_structure(batch_squash.unflatten, states)

  return states, network_state
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Each flattened input is checked against the matching entry of
  `function.graph.structured_input_signature`; the first mismatch yields
  False.
  """
  signature = function.graph.structured_input_signature
  try:
    flat_args = nest.flatten_up_to(signature, inputs)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flat_args, nest.flatten(signature)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not (_is_tensor(arg) or isinstance(arg, tensor_spec.TensorSpec)):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
      continue

    if isinstance(expected, type_spec.TypeSpec):
      if not expected.is_compatible_with(arg):
        return False
      continue

    if _is_tensor(arg):
      # Tensors outside the signature are compared by identity.
      if id(arg) != id(expected):
        return False
    elif arg != expected:
      return False

  return True
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  The inputs must flatten up to the function's structured input signature
  without losing elements, and each flattened input must match its expected
  entry.
  """
  signature = function.graph.structured_input_signature
  try:
    flat_args = nest.flatten_up_to(signature, inputs)
    # Repacking and comparing with the original inputs verifies that no input
    # elements were dropped during flattening.
    repacked = nest.pack_sequence_as(signature, flat_args)
    nest.assert_same_structure(inputs, repacked)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flat_args, nest.flatten(signature)):
    if not isinstance(expected, tensor_spec.TensorSpec):
      if arg != expected:
        return False
      continue

    if allow_conversion:
      arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
    if not (_is_tensor(arg) or isinstance(arg, tensor_spec.TensorSpec)):
      return False
    if arg.dtype != expected.dtype:
      return False
    if not expected.shape.is_compatible_with(arg.shape):
      return False

  return True
def call(self, observation, step_type=None, network_state=(), training=False):
  """Applies preprocessing, the optional combiner and postprocessing layers.

  Args:
    observation: Nest of observations matching `self.input_tensor_spec`.
    step_type: Ignored.
    network_state: Returned unchanged.
    training: Forwarded as `training=` to every pre/postprocessing layer.

  Returns:
    A `(states, network_state)` tuple.
  """
  del step_type  # unused.

  squash = self._batch_squash
  if squash:
    outer_rank = nest_utils.get_outer_rank(
        observation, self.input_tensor_spec)
    batch_squash = utils.BatchSquash(outer_rank)
    observation = tf.nest.map_structure(batch_squash.flatten, observation)

  combiner = self._preprocessing_combiner
  if self._flat_preprocessing_layers is None:
    states = observation
  else:
    flat_observations = nest.flatten_up_to(self._preprocessing_nest,
                                           observation)
    states = [
        layer(obs, training=training)
        for obs, layer in zip(flat_observations,
                              self._flat_preprocessing_layers)
    ]
    if len(states) == 1 and combiner is None:
      # With a single observation and no combiner, the lone preprocessed
      # observation is used directly rather than a one-element list.
      states = states[0]

  if combiner is not None:
    states = combiner(states)

  for layer in self._postprocessing_layers:
    states = layer(states, training=training)

  if squash:
    states = tf.nest.map_structure(batch_squash.unflatten, states)

  return states, network_state
def sample_distributions(self, sample_shape=(), seed=None, value=None,
                         name='sample_distributions', **kwargs):
  """Returns the model's distributions together with samples from them.

  Args:
    sample_shape: Shape of the samples to draw.
    seed: Seed forwarded to `_call_flat_sample_distributions`.
    value: Optional (possibly partially specified) value, resolved through
      `_resolve_value` together with `**kwargs`.
    name: Name of the scope the ops are created in.
    **kwargs: Forwarded to `_resolve_value`.

  Returns:
    A `(distributions, samples)` pair, each passed through
    `_model_unflatten`.

  Raises:
    NotImplementedError: if `use_vectorized_map` is set and either the sample
      shape might have nonzero size or `value` might have sample dimensions.
  """
  with self._name_and_control_scope(name):
    value = self._resolve_value(
        value=value, allow_partially_specified=True, **kwargs)

    if value is None:
      value_might_have_sample_dims = False
    else:
      value_might_have_sample_dims = _might_have_excess_ndims(
          # Double-flatten in case any components have structured events.
          flat_value=nest.flatten_up_to(
              self._single_sample_ndims,
              self._model_flatten(value),
              check_types=False),
          flat_core_ndims=tf.nest.flatten(self._single_sample_ndims))

    # TODO(b/157953455): Return distributions as CompositeTensors once
    # vectorized_map supports this.
    if self.use_vectorized_map and (
        _might_have_nonzero_size(sample_shape) or
        value_might_have_sample_dims):
      raise NotImplementedError(
          '`sample_distributions` with nontrivial '
          'sample shape is not yet supported '
          'for autovectorized JointDistributions.')

    ds, xs = self._call_flat_sample_distributions(
        sample_shape=sample_shape, seed=seed, value=value)
    return self._model_unflatten(ds), self._model_unflatten(xs)
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  The inputs must flatten up to the function's structured input signature
  without losing elements, and each flattened input must match its expected
  entry.
  """
  signature = function.graph.structured_input_signature
  try:
    flat_args = nest.flatten_up_to(signature, inputs)
    # Repacking and comparing with the original inputs verifies that no input
    # elements were dropped during flattening.
    repacked = nest.pack_sequence_as(signature, flat_args)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flat_args, nest.flatten(signature)):
    if not isinstance(expected, tensor_spec.TensorSpec):
      if arg != expected:
        return False
      continue

    if allow_conversion:
      arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
    if not (_is_tensor(arg) or isinstance(arg, tensor_spec.TensorSpec)):
      return False
    if arg.dtype != expected.dtype:
      return False
    if not expected.shape.is_compatible_with(arg.shape):
      return False

  return True
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  Unlike `function.__call__`, inputs and outputs are structured and inputs are
  cast to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be done
  beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None when the call yields an
    `ops.Operation`.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs)

  # Only positions whose expected entry is a `TensorSpec` are fed to the
  # function; everything else is dropped here.
  tensor_inputs = [
      ops.convert_to_tensor(arg, dtype_hint=spec.dtype)
      for arg, spec in zip(flat_args, nest.flatten(signature))
      if isinstance(spec, tensor_spec.TensorSpec)
  ]
  result = function._call_flat(tensor_inputs)  # pylint: disable=protected-access
  return None if isinstance(result, ops.Operation) else result
def _call_execute_model(self, sample_shape, seed, value=None,
                        sample_and_trace_fn=None):
  """Wraps the base `_call_execute_model` with vectorized_map.

  Falls through to `JointDistribution._call_execute_model` unless
  `use_vectorized_map` is set and either the sample shape may be nontrivial or
  `value` might carry sample dimensions.
  """
  if value is None:
    value_might_have_sample_dims = False
  else:
    value_might_have_sample_dims = _might_have_excess_ndims(
        # Double-flatten in case any components have structured events.
        flat_value=nest.flatten_up_to(
            self._single_sample_ndims,
            self._model_flatten(value),
            check_types=False),
        flat_core_ndims=tf.nest.flatten(self._single_sample_ndims))
  sample_shape_may_be_nontrivial = (
      distribution_util.shape_may_be_nontrivial(sample_shape))

  needs_vectorization = self.use_vectorized_map and (
      sample_shape_may_be_nontrivial or value_might_have_sample_dims)
  if not needs_vectorization:
    # No need to auto-vectorize.
    return joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
        self,
        sample_shape=sample_shape,
        seed=seed,
        value=value,
        sample_and_trace_fn=sample_and_trace_fn)

  # Set up for autovectorized sampling. To support the `value` arg, first work
  # out which dims come from the model itself, then wrap the base
  # `_call_execute_model` so it batches over all remaining dims.
  if value is None:
    value_core_ndims = None
  else:
    value_core_ndims = tf.nest.map_structure(
        lambda v, nd: None if v is None else nd,
        value,
        self._model_unflatten(self._single_sample_ndims),
        check_types=False)

  def _execute_unbatched(v, seed):
    return joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
        self,
        sample_shape=(),
        seed=seed,
        value=v,
        sample_and_trace_fn=sample_and_trace_fn)

  vectorized_execute_model_helper = vectorization_util.make_rank_polymorphic(
      _execute_unbatched,
      core_ndims=[value_core_ndims, None],
      validate_args=self.validate_args)

  # `make_rank_polymorphic` does not currently support keyword args, but the
  # `iid_sample` wrapper below expects to pass through a `seed` kwarg; this
  # thin wrapper bridges the two.
  def vectorized_execute_model(v, seed):
    return vectorized_execute_model_helper(v, seed)

  if sample_shape_may_be_nontrivial:
    vectorized_execute_model = vectorization_util.iid_sample(
        vectorized_execute_model, sample_shape)

  return vectorized_execute_model(value, seed=seed)
def structured_event_to_vector(self, x):
  """Converts an event from the wrapped model to this model's event."""
  leading_part = tf.nest.flatten(x)[0]
  leading_event_shape = nest.flatten_up_to(
      self.model.dtype, self.model.event_shape)[0]

  # Whatever precedes the event dims of the first part is the batch shape.
  num_batch_dims = len(leading_part.shape) - len(leading_event_shape)
  batch_shape = leading_part.shape[:num_batch_dims]
  return _flatten_and_concat(x, batch_shape, self.dtype)
def _flat_fn(*args):
  """Evaluates `_flat_fn_index` once per flattened output and repacks."""

  def _compute_output(index, out_axis):
    # Broadcast the inputs for this particular output axis before the call.
    local_args = nest.map_structure_up_to(
        args,
        functools.partial(_pbroadcast_input, out_axis),
        args,
        map_in_axes)
    return _flat_fn_index(index, *local_args)

  flat_out_axes = nest.flatten_up_to(out_dtype, map_out_axes)
  outputs = [
      _compute_output(index, out_axis)
      for index, out_axis in enumerate(flat_out_axes)
  ]
  return tf.nest.pack_sequence_as(out_dtype, outputs)
def testNestFlattenUpTo(self, s1, s2, expected, paths, expand_composites=True):
  """Checks `flatten_up_to` and its tuple-path variant on one parameter set."""
  flattened = nest.flatten_up_to(
      s1, s2, expand_composites=expand_composites)
  self.assertEqual(expected, flattened)

  flattened_with_paths = nest.flatten_with_tuple_paths_up_to(
      s1, s2, expand_composites=expand_composites)
  expected_with_paths = list(zip(paths, expected))
  self.assertEqual(flattened_with_paths, expected_with_paths)
def testNestFlattenUpTo(self):
  """Checks `nest.flatten_up_to` with and without composite expansion."""
  shallow_tree = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(5, 6)},
  ]
  input_tree = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
  ]

  # With expansion, composites present in the shallow tree are opened.
  with_expansion = nest.flatten_up_to(
      shallow_tree, input_tree, expand_composites=True)
  self.assertEqual(
      with_expansion, [1, 2, 3, 100, TestCompositeTensor(4, 5), 6])

  # Without expansion, every composite is returned as a single leaf.
  without_expansion = nest.flatten_up_to(
      shallow_tree, input_tree, expand_composites=False)
  self.assertEqual(without_expansion, [
      TestCompositeTensor(1, 2, 3),
      100,
      TestCompositeTensor(TestCompositeTensor(4, 5), 6),
  ])
def build(self, y_pred, y_true):
  """Performs the one-time construction of this container's metric objects.

  Args:
    y_pred: Structure of model outputs. The configured metrics are conformed
      to it and finally flattened up to it.
    y_true: Structure of targets, passed along to `_get_metric_objects`.
  """
  super(MetricsContainer, self).build(y_pred)

  self._metrics = self._conform_to_outputs(
      y_pred, self._maybe_broadcast_to_outputs(y_pred, self._metrics))
  self._weighted_metrics = self._conform_to_outputs(
      y_pred,
      self._maybe_broadcast_to_outputs(y_pred, self._weighted_metrics))

  # `tf.data` turns lists into `Tensor`s, so every structure is standardized
  # on tuples before it is used as a nest.
  outputs = nest.list_to_tuple(y_pred)
  targets = nest.list_to_tuple(y_true)
  self._metrics = nest.list_to_tuple(self._metrics)
  self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

  # Turn every entry into `Metric` objects, potentially disambiguating based
  # on output properties.
  self._metrics = nest.map_structure_up_to(
      outputs, self._get_metric_objects, self._metrics, targets, outputs)
  self._weighted_metrics = nest.map_structure_up_to(
      outputs, self._get_metric_objects, self._weighted_metrics, targets,
      outputs)

  self._metrics = nest.flatten_up_to(
      outputs, self._metrics, check_types=False)
  self._weighted_metrics = nest.flatten_up_to(
      outputs, self._weighted_metrics, check_types=False)

  # The steps below assume metrics and weighted metrics have been flattened up
  # to the outputs.
  #
  # A model loaded from its serialized form must not have the pre-processing
  # metric renaming steps applied a second time.
  if not self._from_serialized:
    self._set_metric_names()
  self._create_ordered_metrics()
  self._built = True
def metropolis_hastings_step(current_state: State,
                             proposed_state: State,
                             energy_change: FloatTensor,
                             log_uniform: FloatTensor = None,
                             seed=None) -> Tuple[State, tf.Tensor, tf.Tensor]:
  """Metropolis-Hastings step.

  Probabilistically picks between `current_state` and `proposed_state` based
  on `energy_change`, so as to preserve detailed balance.

  Energy change is the negative of `log_accept_ratio`.

  Args:
    current_state: Current state. `None` leaves are imputed from the matching
      leaves of `proposed_state`.
    proposed_state: Proposed state.
    energy_change: E(proposed_state) - E(previous_state).
    log_uniform: Optional logarithm of a uniformly distributed random sample
      in [0, 1]. It is used to accept/reject the current and proposed state.
    seed: For reproducibility.

  Returns:
    new_state: The chosen state.
    is_accepted: Whether the proposed state was accepted.
    log_uniform: The random number that was used to select between the two
      states.
  """
  current_parts = tf.nest.flatten(current_state)
  proposed_parts = nest.flatten_up_to(current_state, proposed_state)

  # Impute the None's in the current state.
  imputed_parts = []
  for proposed_part, current_part in zip(proposed_parts, current_parts):
    if current_part is None:
      imputed_parts.append(proposed_part)
    else:
      imputed_parts.append(current_part)
  current_state = tf.nest.pack_sequence_as(current_state, imputed_parts)

  current_state = tf.nest.map_structure(tf.convert_to_tensor, current_state)
  proposed_state = tf.nest.map_structure(tf.convert_to_tensor, proposed_state)
  energy_change = tf.convert_to_tensor(value=energy_change)

  log_accept_ratio = -energy_change

  if log_uniform is None:
    uniform_sample = tf.random.uniform(
        shape=tf.shape(input=log_accept_ratio),
        dtype=log_accept_ratio.dtype.base_dtype,
        seed=seed)
    log_uniform = tf.math.log(uniform_sample)

  is_accepted = log_uniform < log_accept_ratio

  next_state = mcmc_util.choose(
      is_accepted, proposed_state, current_state, name='choose_next_state')
  return next_state, is_accepted, log_uniform
def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Converts inputs to pass into a function with an explicit signature.

  Args:
    inputs: Python function inputs; only the first `len(input_signature)`
      entries are matched against the signature.
    input_signature: Structured signature the inputs must conform to.
    flat_input_signature: Flattened specs, checked pairwise against the
      flattened inputs.

  Returns:
    A tuple `(inputs, flat_inputs, filtered_flat_inputs)` where the last
    element keeps only `ops.Tensor` and `BaseResourceVariable` entries.

  Raises:
    ValueError: if the structure does not match, an input cannot be converted
      to a tensor, or an input is incompatible with its spec.
  """

  def format_error_message(inputs, input_signature):
    rendered_inputs = ",\n ".join(str(i) for i in inputs)
    rendered_signature = ",\n ".join(str(i) for i in input_signature)
    return (" inputs: (\n" + " " + rendered_inputs + ")\n" +
            " input_signature: (\n" + " " + rendered_signature + ")")

  try:
    # check_types=False because lists are converted to tuples for `tf.data`.
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:len(input_signature)],
        expand_composites=True,
        check_types=False)
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature:\n"
                     f"{format_error_message(inputs, input_signature)}.")

  need_packing = False
  for index, (value, spec) in enumerate(
      zip(flatten_inputs, flat_input_signature)):
    wants_tensor = isinstance(spec, tensor_spec.TensorSpec)
    if not wants_tensor or _pywrap_utils.IsTensor(value):
      continue
    try:
      flatten_inputs[index] = ops.convert_to_tensor(
          value, dtype_hint=spec.dtype)
      need_packing = True
    except ValueError:
      raise ValueError(
          "When input_signature is provided, all inputs to "
          "the Python function must be convertible to "
          "tensors:\n"
          f"{format_error_message(inputs, input_signature)}.")

  all_compatible = all(
      spec.is_compatible_with(other)
      for spec, other in zip(flat_input_signature, flatten_inputs))
  if not all_compatible:
    raise ValueError("Python inputs incompatible with input_signature:\n"
                     f"{format_error_message(inputs, input_signature)}.")

  if need_packing:
    # At least one value was replaced by a tensor, so rebuild the structure.
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs,
        expand_composites=True)

  flat_inputs = nest.flatten(inputs, expand_composites=True)
  tensor_like = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
  filtered_flat_inputs = [t for t in flat_inputs if isinstance(t, tensor_like)]
  return (inputs, flat_inputs, filtered_flat_inputs)
def _build(self, y_pred, y_true):
  """Performs the one-time construction of this container's metric objects.

  Args:
    y_pred: Structure of model outputs; metrics are mapped and flattened up
      to this structure.
    y_true: Structure of targets, passed along to `_get_metric_objects`.
  """
  if self._output_names is None:
    # Generated output names such as 'output_1' end up in the `Metric` names.
    self._output_names = create_output_names(y_pred)

  names = self._output_names

  # A dict keyed by output name is accepted when outputs are a flat list.
  self._metrics = map_to_output_names(y_pred, names, self._metrics)
  self._weighted_metrics = map_to_output_names(
      y_pred, names, self._weighted_metrics)

  # A single supplied metric applies to all outputs.
  self._metrics = self._maybe_broadcast(self._metrics, y_pred)
  self._weighted_metrics = self._maybe_broadcast(
      self._weighted_metrics, y_pred)

  # Turn every entry into `Metric` objects, potentially disambiguating based
  # on output properties.
  self._metrics = nest.map_structure_up_to(
      y_pred, self._get_metric_objects, self._metrics, y_true, y_pred)
  self._weighted_metrics = nest.map_structure_up_to(
      y_pred, self._get_metric_objects, self._weighted_metrics, y_true,
      y_pred)

  self._metrics = nest.flatten_up_to(
      y_pred, self._metrics, check_types=False)
  self._weighted_metrics = nest.flatten_up_to(
      y_pred, self._weighted_metrics, check_types=False)

  # Naming assumes both lists have been flattened up to the outputs.
  self._set_metric_names()

  self._built = True
def _sample_n(self, sample_shape, seed, value=None):
  """Draws samples, auto-vectorizing over sample dims when required.

  Falls through to a plain `_call_flat_sample_distributions` call unless
  `use_vectorized_map` is set and either the sample shape might have nonzero
  size or `value` might carry sample dimensions.
  """
  if value is None:
    value_might_have_sample_dims = False
  else:
    value_might_have_sample_dims = _might_have_excess_ndims(
        # Double-flatten in case any components have structured events.
        flat_value=nest.flatten_up_to(
            self._single_sample_ndims,
            self._model_flatten(value),
            check_types=False),
        flat_core_ndims=tf.nest.flatten(self._single_sample_ndims))

  needs_vectorization = self.use_vectorized_map and (
      _might_have_nonzero_size(sample_shape) or value_might_have_sample_dims)
  if not needs_vectorization:
    # No need to auto-vectorize.
    _, xs = self._call_flat_sample_distributions(
        sample_shape=sample_shape, seed=seed, value=value)
    return self._model_unflatten(xs)

  # Set up for autovectorized sampling. To support the `value` arg, first work
  # out which dims come from the model itself, then wrap
  # `_call_flat_sample_distributions` so it batches over all remaining dims.
  if value is None:
    value_core_ndims = None
  else:
    value_core_ndims = tf.nest.map_structure(
        lambda v, nd: None if v is None else nd,
        value,
        self._model_unflatten(self._single_sample_ndims),
        check_types=False)

  def _flat_sample_unbatched(v, seed):
    return self._call_flat_sample_distributions(
        sample_shape=(), seed=seed, value=v)[1]

  batch_flat_sample = vectorization_util.make_rank_polymorphic(
      _flat_sample_unbatched,
      core_ndims=[value_core_ndims, None],
      validate_args=self.validate_args)

  # `make_rank_polymorphic` does not currently support keyword args, so route
  # the `seed` kwarg through a positional wrapper.
  def _flat_sample_with_seed(v, seed):
    return batch_flat_sample(v, seed)

  # Draw samples.
  vectorized_flat_sample = vectorization_util.iid_sample(
      _flat_sample_with_seed, sample_shape)
  xs = vectorized_flat_sample(value, seed=seed)
  return self._model_unflatten(xs)
def generator_py_func(iterator_id):
  """Fetches, converts and validates the next element of one iterator.

  Args:
    iterator_id: Key identifying the iterator inside `generator_state`.

  Returns:
    The flattened element as converted arrays.

  Raises:
    StopIteration: when the underlying iterator is exhausted.
    TypeError: if a component's dtype differs from the expected one.
    ValueError: if a component's shape is incompatible with the expected one.
  """
  iterator = generator_state.get_iterator(iterator_id)
  try:
    values = next(iterator)
  except StopIteration:
    generator_state.iterator_completed(iterator_id)
    raise StopIteration("Iteration finished.")

  ret_arrays = []
  for ret in nest.flatten_up_to(output_types, values):
    ret_arrays.append(script_ops.FuncRegistry._convert(ret))

  expectations = zip(ret_arrays, flattened_types, flattened_shapes)
  for ret_array, expected_dtype, expected_shape in expectations:
    if ret_array.dtype != expected_dtype.as_numpy_dtype:
      raise TypeError(
          "`generator` yielded an element of type %s where an element "
          "of type %s was expected." % (
              ret_array.dtype, expected_dtype.as_numpy_dtype))
    if not expected_shape.is_compatible_with(ret_array.shape):
      raise ValueError(
          "`generator` yielded an element of shape %s where an element "
          "of shape %s was expected." % (ret_array.shape, expected_shape))

  return ret_arrays
def _to_obs_space_dtype(self, observation):
  """Casts an observation to the dtypes declared by the observation spec.

  Gym observation spaces lacked a dtype for a long time, and many environments
  still return observations that disagree with the dtype in their space
  definition. The space definition is what the tensorflow graph is built from,
  so observations are converted to the expected dtypes here.

  Args:
    observation: Observation to match the dtype on.

  Returns:
    The observation with a dtype matching the observation spec.
  """
  # Flattening up to the spec also handles observations provided as a list.
  flat_observations = nest.flatten_up_to(self._observation_spec, observation)
  matched_observations = [
      np.asarray(obs, dtype=spec.dtype)
      for spec, obs in zip(self._flat_obs_spec, flat_observations)
  ]
  return tf.nest.pack_sequence_as(self._observation_spec,
                                  matched_observations)
def unpack_structs_like(template, packed):
  """Converts a structure of tuples like `template` to a tuple of structures."""
  # One tuple per leaf of `template`.
  leaf_tuples = nest.flatten_up_to(template, packed, check_types=False)
  # Transposing groups the i-th entry of every leaf tuple together.
  per_struct_leaves = zip(*leaf_tuples)
  return tuple(
      nest.pack_sequence_as(template, leaves) for leaves in per_struct_leaves)
def testBijector(self, bijector_name, data):
  """Exercises gradients, monotonicity, batch shape and dtypes of a bijector.

  For the bijector drawn for `bijector_name`, the test differentiates
  `forward`, `forward_log_det_jacobian`, `inverse` and
  `inverse_log_det_jacobian` with respect to the input and the floating
  trainable variables and asserts that no gradient is None. It then checks
  `_internal_is_increasing` for scalar bijectors, the zero log-det-Jacobian of
  permutations, the reported batch shape, and the reported output dtypes.

  Args:
    bijector_name: Name of the bijector to draw.
    data: Source of drawn values (`data.draw`), also handed to the
      `_draw_*` helpers.
  """
  tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')

  bijector, event_dim = self._draw_bijector(bijector_name, data)

  # Forward mapping: Check differentiation through forward mapping with
  # respect to the input and parameter variables.  Also check that any
  # variables are not referenced overmuch.
  xs = self._draw_domain_tensor(bijector, data, event_dim)
  wrt_vars = [xs] + [
      v for v in bijector.trainable_variables if v.dtype.is_floating
  ]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
  grads = tape.gradient(ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

  # For scalar bijectors, verify correctness of the _is_increasing method.
  # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
  # of numerical problem.
  def exception(bijector):
    # True only under Guitar for Softfloor, including Softfloor wrapped in
    # (possibly nested) inverted bijectors.
    if not tfp_hps.running_under_guitar():
      return False
    if isinstance(bijector, tfb.Softfloor):
      return True
    if is_invert(bijector):
      return exception(bijector.bijector)
    return False

  if (bijector.forward_min_event_ndims == 0 and
      bijector.inverse_min_event_ndims == 0 and
      not exception(bijector)):
    # `grads[0]` is the gradient w.r.t. `xs`, the first entry of `wrt_vars`.
    dydx = grads[0]
    hp.note('dydx: {}'.format(dydx))
    isfinite = tf.math.is_finite(dydx)
    incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
        dydx, 0)  # pylint: disable=protected-access
    self.assertAllEqual(
        isfinite & incr_or_slope_eq0,
        isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

  # FLDJ: Check differentiation through forward log det jacobian with
  # respect to the input and parameter variables.  Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.forward_min_event_ndims,
                   max_value=xs.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = _ldj_tensor_conversions_allowed(bijector, is_forward=True)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.forward_log_det_jacobian(
          xs + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

  # Inverse mapping: Check differentiation through inverse mapping with
  # respect to the codomain "input" and parameter variables.  Also check that
  # any variables are not referenced overmuch.
  ys = self._draw_codomain_tensor(bijector, data, event_dim)
  wrt_vars = [ys] + [
      v for v in bijector.trainable_variables if v.dtype.is_floating
  ]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse(ys + 0)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

  # ILDJ: Check differentiation through inverse log det jacobian with respect
  # to the codomain "input" and parameter variables.  Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.inverse_min_event_ndims,
                   max_value=ys.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = _ldj_tensor_conversions_allowed(bijector, is_forward=False)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.inverse_log_det_jacobian(
          ys + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)

  # Verify that `_is_permutation` implies constant zero Jacobian.
  # NOTE: `ldj` here is the inverse log det jacobian computed just above.
  if bijector._is_permutation:
    self.assertTrue(bijector._is_constant_jacobian)
    self.assertAllEqual(ldj, 0.)

  # Verify correctness of batch shape.
  # NOTE: `xs` is now the output of `bijector.inverse(ys + 0)` and
  # `event_ndims` is the value drawn for the ILDJ check.
  xs_batch_shapes = tf.nest.map_structure(
      lambda x, nd: ps.shape(x)[:ps.rank(x) - nd],
      xs,
      bijector.inverse_event_ndims(event_ndims))
  empirical_batch_shape = functools.reduce(
      ps.broadcast_shape,
      nest.flatten_up_to(bijector.forward_min_event_ndims, xs_batch_shapes))
  batch_shape = bijector.experimental_batch_shape(
      y_event_ndims=event_ndims)
  # The static batch shape is only compared when it is fully defined; the
  # tensor version is always compared.
  if tensorshape_util.is_fully_defined(batch_shape):
    self.assertAllEqual(empirical_batch_shape, batch_shape)
  self.assertAllEqual(
      empirical_batch_shape,
      bijector.experimental_batch_shape_tensor(
          y_event_ndims=event_ndims))

  # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
  # of the outputs of forward and inverse.
  self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
  self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
def generate_enqueue_ops(self, per_host_sharded_inputs):
  """Generates the host-side Ops to enqueue the partitioned inputs.

  per_host_sharded_inputs is a list, one for each replica, of lists of
  Tensors. per_host_sharded_inputs[i] is the tuple of Tensors to use to feed
  replica i.
  per_host_sharded_inputs[i][j] is partitioned by
  self._input_partition_dims[j].

  For example, if per_host_sharded_inputs[i][j] is a 2-D Tensor:
  [[A, B, C, D],
   [E, F, G, H]]
  self._input_partition_dims[j] is [2, 4].
  per_host_sharded_inputs[i][j] will be partitioned and flattened into:
  [A, B, C, D, E, F, G, H] and fed into the logical core ids:
  [0, 1, 2, 3, 4, 5, 6, 7] respectively.

  Args:
    per_host_sharded_inputs: a list of lists of Tensors. The length of the
      outer list determines the number of shards. Each inner list indicates
      the types and shapes of the tuples in the corresponding shard.

  Returns:
    A list of host-side Ops, one for each shard, that when executed together
    will enqueue a full-size element of infeed.

  Raises:
    ValueError: if the queue configuration has previously been frozen and the
      shapes of the elements of per_host_sharded_inputs are not compatible
      with the frozen configuration; or if the shapes of the elements of
      per_host_sharded_inputs don't form a consistent unsharded tuple; or if
      the elements of a tuple have different device constraints; or if the
      partition dims are invalid.
    TypeError: if the queue configuration has previously been frozen and the
      types of the elements of per_host_sharded_inputs are not compatible
      with the frozen configuration; or if the types of the elements of
      per_host_sharded_inputs don't form a consistent unsharded tuple.
  """
  self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
  number_of_replicas_per_host = len(per_host_sharded_inputs)
  number_of_tuple_elements = len(per_host_sharded_inputs[0])

  # One partition spec is expected per tuple element.
  assert len(self._input_partition_dims) == number_of_tuple_elements
  per_host_enqueue_ops = []

  for replica_index in range(number_of_replicas_per_host):
    flattened_inputs = per_host_sharded_inputs[replica_index]
    inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
                                               self._input_partition_dims)
    # One iterator per tuple element; each yields that element's pieces in
    # logical-core order.
    inputs_parted_iters = [
        iter(self._partition_or_replicate_on_host(x, dims))
        for x, dims in zip(flattened_inputs, inputs_part_dims_flat)
    ]

    # `range` (not the Python-2 `xrange`) for consistency with the outer loop.
    for core_index in range(self._device_assignment.num_cores_per_replica):
      # Places different partitions to different logic cores.
      logical_core = self._get_logical_core(core_index)
      replica_id = self._device_assignment.lookup_replicas(
          self._host_id, logical_core)[replica_index]
      ordinal = self._device_assignment.tpu_ordinal(
          replica=replica_id, logical_core=logical_core)

      # Advance every iterator by exactly one piece; exhausted iterators (and
      # `None` pieces) contribute nothing to this core.
      infeed_inputs = []
      for it in inputs_parted_iters:
        input_for_device = next(it, None)
        if input_for_device is not None:
          infeed_inputs.append(input_for_device)

      if infeed_inputs:
        per_host_enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=infeed_inputs,
                shapes=[x.shape for x in infeed_inputs],
                name="enqueue/replica_{0}/input_{1}".format(
                    replica_index, core_index),
                device_ordinal=ordinal))
  return per_host_enqueue_ops
def testFlattenUpTo(self):
  """Exercises `nest.flatten_up_to` on nested, degenerate and invalid input.

  Covers lists, tuples, dicts, OrderedDicts and namedtuples, the cases where
  the shallow tree and/or the input tree is not a sequence, and the
  `TypeError` raised when only the shallow tree is a sequence.
  """
  # Shallow tree ends at scalar.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  # Shallow tree ends at string.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  input_tree_flattened = nest.flatten(input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Make sure dicts are correctly flattened, yielding values, not keys.
  input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [1, {"c": 2}, 3, (4, 5)])

  # Namedtuples.
  ab_tuple = NestTest.ABTuple
  input_tree = ab_tuple(a=[0, 1], b=2)
  shallow_tree = ab_tuple(a=0, b=1)
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [[0, 1], 2])

  # Nested dicts, OrderedDicts and namedtuples.
  input_tree = collections.OrderedDict(
      [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
       ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
  shallow_tree = input_tree
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
  shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    3,
                    collections.OrderedDict([("f", 4)])])
  shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

  # Whenever the shallow tree is not a sequence, each tree is returned
  # wrapped in a single-element list.
  passthrough_cases = (
      ## Shallow non-list edge-case.
      # Using iterable elements.
      (["input_tree"], "shallow_tree"),
      (["input_tree_0", "input_tree_1"], "shallow_tree"),
      # Using non-iterable elements.
      ([0], 9),
      ([0, 1], 9),
      ## Both non-list edge-case.
      # Using iterable elements.
      ("input_tree", "shallow_tree"),
      # Using non-iterable elements.
      (0, 0),
  )
  for input_tree, shallow_tree in passthrough_cases:
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

  ## Input non-list edge-case: a sequence shallow tree with a non-sequence
  ## input must raise, while the shallow tree still flattens against itself.
  message_template = ("If shallow structure is a sequence, input must also "
                      "be a sequence. Input has type: <(type|class) '{}'>.")
  invalid_cases = (
      # Using iterable elements.
      ("input_tree", ["shallow_tree"], "str"),
      ("input_tree", ["shallow_tree_9", "shallow_tree_8"], "str"),
      # Using non-iterable elements.
      (0, [9], "int"),
      (0, [9, 8], "int"),
  )
  for input_tree, shallow_tree, type_name in invalid_cases:
    expected_message = message_template.format(type_name)
    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)
def testFlattenUpTo(self):
  """Exercises `nest.flatten_up_to` on nested, degenerate and invalid input.

  Covers nested lists/tuples, the cases where the shallow tree and/or the
  input tree is not a sequence, and the `TypeError` raised when only the
  shallow tree is a sequence.
  """
  # Shallow tree leaves are scalars.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  # Shallow tree leaves are strings.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
      shallow_tree, input_tree)
  input_tree_flattened = nest.flatten(input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Whenever the shallow tree is not a sequence, each tree is returned
  # wrapped in a single-element list.
  passthrough_cases = (
      ## Shallow non-list edge-case.
      # Using iterable elements.
      (["input_tree"], "shallow_tree"),
      (["input_tree_0", "input_tree_1"], "shallow_tree"),
      # Using non-iterable elements.
      ([0], 9),
      ([0, 1], 9),
      ## Both non-list edge-case.
      # Using iterable elements.
      ("input_tree", "shallow_tree"),
      # Using non-iterable elements.
      (0, 0),
  )
  for input_tree, shallow_tree in passthrough_cases:
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

  ## Input non-list edge-case: a sequence shallow tree with a non-sequence
  ## input must raise, while the shallow tree still flattens against itself.
  message_template = (
      "If shallow structure is a sequence, input must also "
      "be a sequence. Input has type: <(type|class) '{}'>.")
  invalid_cases = (
      # Using iterable elements.
      ("input_tree", ["shallow_tree"], "str"),
      ("input_tree", ["shallow_tree_9", "shallow_tree_8"], "str"),
      # Using non-iterable elements.
      (0, [9], "int"),
      (0, [9, 8], "int"),
  )
  for input_tree, shallow_tree, type_name in invalid_cases:
    expected_message = message_template.format(type_name)
    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)
def _wrap_for_composites(func, inp, Tout):
  """Adapts `func`, `inp` and `Tout` so `py_function` can use composites.

  The inputs are flattened into a plain list of Tensors (composite tensors
  are expanded into their components), and `func` is wrapped by a function
  that takes those flat inputs and:

  - repacks them into the structure that `func` expects;
  - calls `func` with the repacked inputs;
  - checks each returned element against the matching entry of `Tout`;
  - flattens the result into a list of Tensors (expanding composite
    tensors).

  Args:
    func: The function to wrap (`func` argument to `py_function`).
    inp: The input arguments for func (`inp` argument to `py_function`).
    Tout: The expected output types for func (`Tout` argument to
      `py_function`).

  Returns:
    A tuple `(func, inp, Tout, out_structure)`, where `func` is the wrapped
    function, `inp` is the flattened inputs, `Tout` is the list of expected
    dtypes for the flattened outputs, and `out_structure` is the expected
    output structure (which can be used to pack the output tensors).
  """
  # Composite inputs keep themselves as the structure entry; every other
  # input is represented by the placeholder leaf `1`.
  in_structure = [
      arg if isinstance(arg, composite_tensor.CompositeTensor) else 1
      for arg in inp
  ]
  inp = nest.flatten_up_to(in_structure, inp, expand_composites=True)
  out_structure = Tout
  Tout = [
      spec.dtype if isinstance(spec, tensor_spec.TensorSpec) else spec
      for spec in nest.flatten(Tout, expand_composites=True)
  ]

  def wrapped_func(*flat_inp):
    packed_args = nest.pack_sequence_as(
        in_structure, flat_inp, expand_composites=True)
    out = func(*packed_args)
    if not out_structure:
      return []  # Ignore return value if none is requested/expected.
    if not isinstance(out, (list, tuple)):
      out = [out]  # func may return a single value instead of a list.

    def mismatch_error(detail):
      # Shared prefix for both kinds of output/`Tout` disagreement.
      return ValueError(
          f"py_function: func={func} returned {out!r}, "
          f"which did not match Tout={out_structure!r}.\nIn particular, "
          f"{detail}")

    flat_out = []
    for value, spec in zip(out, out_structure):
      expects_composite = (
          isinstance(spec, type_spec.TypeSpec) and
          not isinstance(spec, tensor_spec.TensorSpec))
      if expects_composite:
        if not spec.is_compatible_with(value):
          raise mismatch_error(
              f"{value!r} is not compatible with {spec!r}.")
        flat_out.extend(nest.flatten(value, expand_composites=True))
        continue
      # Pro-actively reject a composite tensor where a plain Tensor is
      # expected. It would be caught later (at convert_to_tensor), but
      # checking here gives a clearer error message.
      if isinstance(value, composite_tensor.CompositeTensor):
        raise mismatch_error(f"{value!r} is not a Tensor.")
      flat_out.append(value)
    return flat_out

  return wrapped_func, inp, Tout, out_structure
def build_factored_surrogate_posterior(
    event_shape=None,
    constraining_bijectors=None,
    initial_unconstrained_loc=_sample_uniform_initial_loc,
    initial_unconstrained_scale=1e-2,
    trainable_distribution_fn=_build_trainable_normal_dist,
    seed=None,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    constraining_bijectors: Optional `tfb.Bijector` instance, or nested
      structure of such instances, defining support(s) of the posterior
      variables. The structure must match that of `event_shape` and may
      contain `None` values. A posterior variable will
      be modeled as `tfd.TransformedDistribution(underlying_dist,
      constraining_bijector)` if a corresponding constraining bijector is
      specified, otherwise it is modeled as supported on the
      unconstrained real line.
    initial_unconstrained_loc: Optional Python `callable` with signature
      `tensor = initial_unconstrained_loc(shape, seed)` used to sample
      real-valued initializations for the unconstrained representation of each
      variable. May alternately be a nested structure of
      `Tensor`s, giving specific initial locations for each variable; these
      must have structure matching `event_shape` and shapes determined by the
      inverse image of `event_shape` under `constraining_bijectors`, which
      may optionally be prefixed with a common batch shape.
      Default value: `functools.partial(tf.random.uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    initial_unconstrained_scale: Optional scalar float `Tensor` initial
      scale for the unconstrained distributions, or a nested structure of
      `Tensor` initial scales for each variable.
      Default value: `1e-2`.
    trainable_distribution_fn: Optional Python `callable` with signature
      `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
      event_ndims, validate_args)`. This is called for each model variable to
      build the corresponding factor in the surrogate posterior. It is expected
      that the distribution returned is supported on unconstrained real values.
      Default value: `functools.partial(
        tfp.experimental.vi.build_trainable_location_scale_distribution,
        distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
    seed: Python integer to seed the random number generator. This is used
      only when `initial_unconstrained_loc` is a callable, i.e. when initial
      locations are sampled rather than provided.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `concentration` and `rate` parameters, and that both are constrained to be
  positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    constraining_bijectors=[tfb.Softplus(),   # Concentration is positive.
                            tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify one when we build the surrogate posterior. This function requires the
  initial location to be specified in *unconstrained* space; we do this by
  inverting the constraining bijectors (note this section also demonstrates the
  creation of a dict-structured model).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  constraining_bijectors = {
    'concentration': tfb.Softplus(),  # Concentration is positive.
    'rate': tfb.Softplus()}           # Rate is positive.
  initial_unconstrained_loc = tf.nest.map_structure(
    lambda b, x: b.inverse(x) if b is not None else x,
    constraining_bijectors, initial_loc)
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    constraining_bijectors=constraining_bijectors,
    initial_unconstrained_loc=initial_unconstrained_loc,
    initial_unconstrained_scale=1e-4)
  ```

  """

  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior')

    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)
    flat_event_shapes = tf.nest.flatten(event_shape)

    # For simplicity, we'll work with flattened lists of state parts and
    # repack the structure at the end.
    if constraining_bijectors is not None:
      flat_bijectors = tf.nest.flatten(constraining_bijectors)
    else:
      flat_bijectors = [None for _ in flat_event_shapes]
    # A `None` bijector means the variable is already unconstrained, so its
    # event shape is used unchanged.
    flat_unconstrained_event_shapes = [
        b.inverse_event_shape_tensor(s) if b is not None else s
        for s, b in zip(flat_event_shapes, flat_bijectors)]

    # Construct initial locations for the internal unconstrained dists.
    if callable(initial_unconstrained_loc):  # Sample random initialization.
      flat_unconstrained_locs = [
          initial_unconstrained_loc(shape=s, seed=seed())
          for s in flat_unconstrained_event_shapes]
    else:  # Use provided initialization.
      flat_unconstrained_locs = nest.flatten_up_to(
          shallow_structure, initial_unconstrained_loc, check_types=False)

    # A non-nested scale is shared by every variable.
    if nest.is_nested(initial_unconstrained_scale):
      flat_unconstrained_scales = nest.flatten_up_to(
          shallow_structure, initial_unconstrained_scale, check_types=False)
    else:
      flat_unconstrained_scales = [
          initial_unconstrained_scale for _ in flat_unconstrained_locs]

    # Extract the rank of each event, so that we build distributions with the
    # correct event shapes.
    flat_unconstrained_event_ndims = [
        prefer_static.rank_from_shape(s)
        for s in flat_unconstrained_event_shapes]

    # Build the component surrogate posteriors.
    flat_component_dists = []
    for initial_loc, initial_scale, event_ndims, bijector in zip(
        flat_unconstrained_locs,
        flat_unconstrained_scales,
        flat_unconstrained_event_ndims,
        flat_bijectors):
      unconstrained_dist = trainable_distribution_fn(
          initial_loc=initial_loc,
          initial_scale=initial_scale,
          event_ndims=event_ndims,
          validate_args=validate_args)
      flat_component_dists.append(
          bijector(unconstrained_dist) if bijector is not None
          else unconstrained_dist)
    component_distributions = tf.nest.pack_sequence_as(
        event_shape, flat_component_dists)

    # Return a `Distribution` object whose events have the specified structure.
    return (
        joint_distribution_util.independent_joint_distribution_from_structure(
            component_distributions, validate_args=validate_args))
def _model_flatten(self, xs):
  """Flattens `xs` into the model's flat sequence of parts.

  When `self._sample_dtype` is set, `xs` is flattened up to that structure
  with `nest.flatten_up_to`; otherwise `xs` is simply converted to a tuple.
  """
  sample_dtype = self._sample_dtype
  if sample_dtype is not None:
    return nest.flatten_up_to(sample_dtype, xs)
  return tuple(xs)
def testFlattenUpTo(self):
  """Exercises `nest.flatten_up_to` on nested, degenerate and invalid input.

  Covers nested lists/tuples, the cases where the shallow tree and/or the
  input tree is not a sequence, and the `TypeError` raised when only the
  shallow tree is a sequence.
  """
  # Shallow tree leaves are scalars.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  # Shallow tree leaves are strings.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  input_tree_flattened = nest.flatten(input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Whenever the shallow tree is not a sequence, each tree is returned
  # wrapped in a single-element list.
  passthrough_cases = (
      ## Shallow non-list edge-case.
      # Using iterable elements.
      (["input_tree"], "shallow_tree"),
      (["input_tree_0", "input_tree_1"], "shallow_tree"),
      # Using non-iterable elements.
      ([0], 9),
      ([0, 1], 9),
      ## Both non-list edge-case.
      # Using iterable elements.
      ("input_tree", "shallow_tree"),
      # Using non-iterable elements.
      (0, 0),
  )
  for input_tree, shallow_tree in passthrough_cases:
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

  ## Input non-list edge-case: a sequence shallow tree with a non-sequence
  ## input must raise, while the shallow tree still flattens against itself.
  message_template = ("If shallow structure is a sequence, input must also "
                      "be a sequence. Input has type: <(type|class) '{}'>.")
  invalid_cases = (
      # Using iterable elements.
      ("input_tree", ["shallow_tree"], "str"),
      ("input_tree", ["shallow_tree_9", "shallow_tree_8"], "str"),
      # Using non-iterable elements.
      (0, [9], "int"),
      (0, [9, 8], "int"),
  )
  for input_tree, shallow_tree, type_name in invalid_cases:
    expected_message = message_template.format(type_name)
    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)
def _model_flatten(self, xs):
  """Flattens `xs` into the model's flat sequence of parts.

  Args:
    xs: The structure to flatten. When `self._sample_dtype` is `None` this
      may be a mapping (values are looked up by the names returned from
      `self._flat_resolve_names()`, in that order) or any other iterable.

  Returns:
    A tuple of parts when `self._sample_dtype` is `None`; otherwise the
    result of `nest.flatten_up_to(self._sample_dtype, xs)`.
  """
  # Function-scope import: the `collections.Mapping` alias was removed in
  # Python 3.10, and the ABC now lives only in `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  if self._sample_dtype is None:
    if isinstance(xs, collections_abc.Mapping):
      return tuple(xs[k] for k in self._flat_resolve_names())
    return tuple(xs)
  return nest.flatten_up_to(self._sample_dtype, xs)
def _setup_mcmc(model, n_chains, *, init_position=None, seed=None, **pins):
  """Builds the bijector, initial state and target needed for windowed MCMC.

  The model is pinned (when pins are given), a bijector that unconstrains and
  flattens the pinned model is obtained, an initial position is drawn unless
  one was supplied, and a log probability on the transformed space is
  assembled from the pinned model's unnormalized log probability plus the
  bijector's inverse log-det-Jacobian.

  Note that we must manually construct this target log probability instead of
  using a transformed transition kernel because the TTK assumes the shape
  in is the same as the shape out.

  Args:
    model: `tfd.JointDistribution`
      The model to sample from.
    n_chains: list of ints
      Number of chains (independent examples) to run. Used as the
      `sample_shape` when the initial position is sampled.
    init_position: Optional
      Structure of tensors at which to initialize sampling. Should have the
      same shape and structure as
      `model.experimental_pin(**pins).sample_unpinned(n_chains)`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    **pins: Values passed to `model.experimental_pin`.

  Returns:
    target_log_prob_fn: Callable on the transformed space.
    initial_transformed_position: `init_position` (sampled when not given)
      mapped through `bijector.forward`.
    bijector: `tfb.Bijector` instance, which unconstrains and flattens.
    step_broadcast_fn: Callable to broadcast step size over latent structure.
    batch_shape: Batch shape of the model, as a shape tensor.
    shard_axis_names: Shard axis names for the model, or `None` when no shard
      axis names are active.
  """
  if pins:
    pinned_model = model.experimental_pin(**pins)
  else:
    pinned_model = model
  bijector, step_bijector = _get_flat_unconstraining_bijector(pinned_model)

  if init_position is None:
    raw_init_dist = initialization.init_near_unconstrained_zero(pinned_model)
    init_position = initialization.retry_init(
        raw_init_dist.sample,
        target_fn=pinned_model.unnormalized_log_prob,
        sample_shape=n_chains,
        seed=seed)

  initial_transformed_position = tf.nest.map_structure(
      tf.identity, bijector.forward(init_position))

  def _broadcast_if_nested(shape, broadcast_fn):
    # Collapse a nested structure of shapes into a single broadcast shape.
    if tf.nest.is_nested(shape):
      return functools.reduce(broadcast_fn, tf.nest.flatten(shape))
    return shape

  # Prefer the static batch shape; fall back to the dynamic one only when the
  # static shape is not fully defined.
  batch_shape = _broadcast_if_nested(
      pinned_model.batch_shape, tf.broadcast_static_shape)
  if not tensorshape_util.is_fully_defined(batch_shape):
    batch_shape = _broadcast_if_nested(
        pinned_model.batch_shape_tensor(), tf.broadcast_dynamic_shape)

  # This tf.function is not redundant with the ones on _fast_window
  # and _slow_window because the various kernels (like HMC) may invoke
  # `target_log_prob_fn` multiple times within one window.
  @tf.function(autograph=False)
  def target_log_prob_fn(*args):
    ldj = bijector.inverse_log_det_jacobian(
        args, event_ndims=[1 for _ in initial_transformed_position])
    lp = pinned_model.unnormalized_log_prob(bijector.inverse(args))
    return lp + ldj

  def step_broadcast(step_size):
    # Only apply the bijector to nested step sizes or non-scalar batches.
    if not tf.nest.is_nested(step_size):
      return step_size
    broadcast_step = nest_util.broadcast_structure(
        pinned_model.event_shape_tensor(), step_size)
    return step_bijector(broadcast_step)

  shard_axis_names = pinned_model.experimental_shard_axis_names
  if not any(tf.nest.flatten(shard_axis_names)):
    # No active shard axis names
    shard_axis_names = None
  else:
    flat_axis_names = list(pinned_model._model_flatten(shard_axis_names))  # pylint: disable=protected-access
    shard_axis_names = nest.flatten_up_to(
        initial_transformed_position, flat_axis_names)

  batch_shape_tensor = ps.convert_to_shape_tensor(
      batch_shape, name='batch_shape')
  return (target_log_prob_fn,
          initial_transformed_position,
          bijector,
          step_broadcast,
          batch_shape_tensor,
          shard_axis_names)