def read_summaries(event_dir, event_file_pattern="events.out.tfevents.*"):
  """Collects the tensor summaries stored in TensorFlow event files.

  Args:
    event_dir: Directory that holds the event files.
    event_file_pattern: Glob pattern, joined to `event_dir`, selecting the
      event files to read.

  Returns:
    A list of `(step, {tag: value})` tuples ordered by increasing step. The
    list is empty when `event_dir` does not exist.
  """
  if not tf.io.gfile.exists(event_dir):
    return []

  by_step = collections.defaultdict(dict)
  glob_pattern = os.path.join(event_dir, event_file_pattern)
  for path in tf.io.gfile.glob(glob_pattern):
    for record in tf.compat.v1.train.summary_iterator(path):
      # Events without a `summary` field carry no values to collect.
      if not record.HasField("summary"):
        continue
      for entry in record.summary.value:
        proto = entry.tensor
        serialized = proto.SerializeToString()
        parsed = tf.io.parse_tensor(serialized, tf.as_dtype(proto.dtype))
        # A later value for the same (step, tag) pair replaces the earlier one.
        by_step[record.step][entry.tag] = tf.get_static_value(parsed)

  return sorted(by_step.items(), key=lambda item: item[0])
def testShapes(self):
  """Checks batch shape, event shape and `num_categories` of a categorical."""
  batch_shapes = ([], [1], [2, 3, 4])

  # Number of categories supplied as a Python integer.
  for expected_batch in batch_shapes:
    categorical = make_categorical(expected_batch, 10)
    self.assertAllEqual(expected_batch, categorical.batch_shape)
    self.assertAllEqual(
        expected_batch, self.evaluate(categorical.batch_shape_tensor()))
    self.assertAllEqual([], categorical.event_shape)
    self.assertAllEqual([], self.evaluate(categorical.event_shape_tensor()))
    self.assertEqual(10, self.evaluate(categorical.num_categories))
    # num_categories is available as a constant because the shape is
    # known at graph build time.
    self.assertEqual(10, tf.get_static_value(categorical.num_categories))

  # Number of categories supplied as an int32 constant tensor; here only the
  # rank of the static batch shape is checked.
  for expected_batch in batch_shapes:
    categorical = make_categorical(
        expected_batch, tf.constant(10, dtype=tf.int32))
    self.assertAllEqual(
        len(expected_batch), tensorshape_util.rank(categorical.batch_shape))
    self.assertAllEqual(
        expected_batch, self.evaluate(categorical.batch_shape_tensor()))
    self.assertAllEqual([], categorical.event_shape)
    self.assertAllEqual([], self.evaluate(categorical.event_shape_tensor()))
    self.assertEqual(10, self.evaluate(categorical.num_categories))
def potential_scale_reduction(chains_states,
                              independent_chain_ndims=1,
                              name=None):
  """Gelman and Rubin (1992)'s potential scale reduction for chain convergence.

  Given `N > 1` states from each of `C > 1` independent chains, the potential
  scale reduction factor, commonly referred to as R-hat, measures convergence of
  the chains (to the same target) by testing for equality of means.
  Specifically, R-hat measures the degree to which variance (of the means)
  between chains exceeds what one would expect if the chains were identically
  distributed. See [Gelman and Rubin (1992)][1]; [Brooks and Gelman (1998)][2].

  Some guidelines:

  * The initial state of the chains should be drawn from a distribution
    overdispersed with respect to the target.
  * If all chains converge to the target, then as `N --> infinity`,
    R-hat --> 1. Before that, R-hat > 1 (except in pathological cases, e.g. if
    the chain paths were identical).
  * The above holds for any number of chains `C > 1`. Increasing `C` improves
    the effectiveness of the diagnostic.
  * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of
    course this is problem dependent. See [Brooks and Gelman (1998)][2].
  * R-hat only measures non-convergence of the mean. If higher moments, or
    other statistics are desired, a different diagnostic should be used. See
    [Brooks and Gelman (1998)][2].

  Args:
    chains_states: `Tensor` or Python `list` of `Tensor`s representing the
      state(s) of a Markov Chain at each result step. The `ith` state is
      assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`.
      Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain.
      Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent
      chains to be tested for convergence to the same target.
      The remaining dimensions, `A`, can have any shape (even empty).
    independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the
      number of dimensions, from `dim = 1` to `dim = D`, holding independent
      chain results to be tested for convergence.
    name: `String` name to prepend to created ops. Default:
      `potential_scale_reduction`.

  Returns:
    `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for
    the state(s). Same `dtype` as `state`, and shape equal to
    `state.shape[1 + independent_chain_ndims:]`. A list is returned exactly
    when `chains_states` was list-like.

  Raises:
    ValueError: If `independent_chain_ndims < 1`. The check is only performed
      when the value can be determined statically.

  #### Examples

  Diagnosing convergence by monitoring 10 chains that each attempt to
  sample from a 2-variate normal.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])

  # Get 10 (2x) overdispersed initial states.
  initial_state = target.sample(10) * 2.
  ==> (10, 2)

  # Get 1000 samples from the 10 independent chains.
  chains_states, _ = tfp.mcmc.sample_chain(
      num_burnin_steps=200,
      num_results=1000,
      current_state=initial_state,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=target.log_prob,
          step_size=0.05,
          num_leapfrog_steps=20))
  chains_states.shape
  ==> (1000, 10, 2)

  rhat = tfp.mcmc.diagnostic.potential_scale_reduction(
      chains_states, independent_chain_ndims=1)

  # The second dimension needed a longer burn-in.
  rhat.eval()
  ==> [1.05, 1.3]
  ```

  To see why R-hat is reasonable, let `X` be a random variable drawn uniformly
  from the combined states (combined over all chains). Then, in the limit
  `N, C --> infinity`, with `E`, `Var` denoting expectation and variance,

  ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].```

  Using the law of total variance, the numerator is the variance of the combined
  states, and the denominator is the total variance minus the variance of the
  individual chain means. If the chains are all drawing from the same
  distribution, they will have the same mean, and thus the ratio should be one.

  #### References

  [1]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation
       Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992.

  [2]: Stephen P. Brooks and Andrew Gelman. General Methods for Monitoring
       Convergence of Iterative Simulations. _Journal of Computational and
       Graphical Statistics_, 7(4), 1998.
  """
  input_was_list = _is_list_like(chains_states)
  states = chains_states if input_was_list else [chains_states]

  # `tf.get_static_value` returns `None` when no constant value (as a numpy
  # array) can be computed efficiently, so the range check below only runs
  # when the argument is statically known.
  static_ndims = tf.get_static_value(
      tf.convert_to_tensor(value=independent_chain_ndims))
  if static_ndims is not None:
    independent_chain_ndims = static_ndims
    if static_ndims < 1:
      raise ValueError(
          'Argument `independent_chain_ndims` must be `>= 1`, found: {}'.format(
              independent_chain_ndims))

  with tf.compat.v1.name_scope(name, 'potential_scale_reduction'):
    rhats = [
        _potential_scale_reduction_single_state(state, independent_chain_ndims)
        for state in states
    ]

  # Mirror the container type of the input.
  return rhats if input_was_list else rhats[0]
def __init__(self,
             perm=None,
             rightmost_transposed_ndims=None,
             validate_args=False,
             name='transpose'):
  """Instantiates the `Transpose` bijector.

  Args:
    perm: Non-negative `int32` vector-shaped `Tensor` representing permutation
      of rightmost dims (for forward transformation). Note that the `0`th
      index represents the first of the rightmost dims and the largest value
      must be `rightmost_transposed_ndims - 1` and corresponds to
      `tf.rank(x) - 1`.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.
      Default value:
      `tf.range(start=rightmost_transposed_ndims - 1, limit=-1, delta=-1)`.
    rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
      representing the number of rightmost dimensions to permute.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.
      Default value: `tf.size(perm)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
      specified.
    NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
      graph execution.
  """
  # Use the `compat.v1` scope: it accepts the `values=` argument regardless of
  # whether TF1 or TF2 behavior is active.
  with tf.compat.v1.name_scope(
      name, values=[perm, rightmost_transposed_ndims]):
    if (rightmost_transposed_ndims is None) == (perm is None):
      raise ValueError('Must specify exactly one of '
                       '`rightmost_transposed_ndims` and `perm`.')
    if rightmost_transposed_ndims is not None:
      rightmost_transposed_ndims = tf.convert_to_tensor(
          value=rightmost_transposed_ndims,
          dtype=np.int32,
          name='rightmost_transposed_ndims')
      # Capture the static value before wrapping the tensor in `tf.identity`.
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      with tf.control_dependencies(
          _maybe_validate_rightmost_transposed_ndims(
              rightmost_transposed_ndims, validate_args)):
        rightmost_transposed_ndims = tf.identity(rightmost_transposed_ndims)
      # Default permutation: reverse the rightmost dims.
      perm = tf.range(
          start=rightmost_transposed_ndims - 1,
          limit=-1,
          delta=-1,
          name='perm')
    else:  # perm is not None:
      perm = tf.convert_to_tensor(value=perm, dtype=np.int32, name='perm')
      rightmost_transposed_ndims = tf.size(
          input=perm, name='rightmost_transposed_ndims')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      with tf.control_dependencies(_maybe_validate_perm(perm, validate_args)):
        perm = tf.identity(perm)

    # TODO(b/110828604): If bijector base class ever supports dynamic
    # `min_event_ndims`, then this class already works dynamically and the
    # following static-value check can be removed.
    if rightmost_transposed_ndims_ is None:
      raise NotImplementedError('`rightmost_transposed_ndims` must be '
                                'known prior to graph execution.')
    rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

    self._perm = perm
    self._rightmost_transposed_ndims = rightmost_transposed_ndims
    super(Transpose, self).__init__(
        forward_min_event_ndims=rightmost_transposed_ndims_,
        graph_parents=[perm, rightmost_transposed_ndims],
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)
def _sample_n(self, n, seed=None):
  """Draws `n` samples from the mixture.

  Two code paths exist. When `self._use_static_graph` is set, every component
  is sampled `n` times and a one-hot mask built from the categorical draws
  selects one component per sample. Otherwise each component is sampled only
  as many times as the categorical selected it, and the results are stitched
  back into their original positions.

  Args:
    n: Number of samples to draw. In the non-static path it is converted to a
      `Tensor` and replaced by a Python `int` when its value is statically
      known.
    seed: Seed passed to `self.cat.sample` and used to build the `SeedStream`
      (salt "Mixture") that seeds the per-component draws.

  Returns:
    A `Tensor` of samples. Per the inline shape notes, the static path yields
    shape `[n, B, E]`; the non-static path reshapes to the shape of the
    categorical samples concatenated with the event shape.
  """
  if self._use_static_graph:
    with tf.control_dependencies(self._assertions):
      # This sampling approach is almost the same as the approach used by
      # `MixtureSameFamily`. The differences are due to having a list of
      # `Distribution` objects rather than a single object, and maintaining
      # random seed management that is consistent with the non-static code
      # path.
      samples = []
      cat_samples = self.cat.sample(n, seed=seed)
      stream = seed_stream.SeedStream(seed, salt="Mixture")

      for c in range(self.num_components):
        samples.append(self.components[c].sample(n, seed=stream()))
      # Stack the per-component samples along a new axis placed just before
      # the event dimensions.
      x = tf.stack(samples, -self._static_event_shape.ndims - 1)  # [n, B, k, E]
      npdt = x.dtype.as_numpy_dtype
      mask = tf.one_hot(
          indices=cat_samples,  # [n, B]
          depth=self._num_components,  # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))  # [n, B, k]
      mask = distribution_util.pad_mixture_dimensions(
          mask, self, self._cat,
          self._static_event_shape.ndims)  # [n, B, k, [1]*e]
      # Multiplying by the one-hot mask and summing over the component axis
      # keeps only the selected component's sample.
      return tf.reduce_sum(
          input_tensor=x * mask,
          axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

  with tf.control_dependencies(self._assertions):
    n = tf.convert_to_tensor(value=n, name="n")
    static_n = tf.get_static_value(n)
    # Prefer a Python int when the sample count is statically known.
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    # For each of the sample, batch and event shapes, use the static shape
    # when it is fully defined and fall back to a dynamic tensor otherwise.
    static_samples_shape = cat_samples.shape
    if static_samples_shape.is_fully_defined():
      samples_shape = static_samples_shape.as_list()
      samples_size = static_samples_shape.num_elements()
    else:
      samples_shape = tf.shape(input=cat_samples)
      samples_size = tf.size(input=cat_samples)
    static_batch_shape = self.batch_shape
    if static_batch_shape.is_fully_defined():
      batch_shape = static_batch_shape.as_list()
      batch_size = static_batch_shape.num_elements()
    else:
      batch_shape = self.batch_shape_tensor()
      batch_size = tf.reduce_prod(input_tensor=batch_shape)
    static_event_shape = self.event_shape
    if static_event_shape.is_fully_defined():
      event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
    else:
      event_shape = self.event_shape_tensor()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = tf.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = tf.reshape(
        tf.tile(tf.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #     [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #     [1 1] and [0 0 0 0]
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 0, 0, 0
    #   (samples 0, 1, 2, 3).
    partitioned_batch_indices = tf.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    stream = seed_stream.SeedStream(seed, salt="Mixture")

    for c in range(self.num_components):
      # Number of draws the categorical assigned to component `c`.
      n_class = tf.size(input=partitioned_samples_indices[c])
      samples_class_c = self.components[c].sample(n_class, seed=stream())

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.

      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * tf.range(n_class) + partitioned_batch_indices[c])
      # Flatten the (sample, batch) dimensions so a single gather can pick
      # the rows computed above.
      samples_class_c = tf.reshape(
          samples_class_c, tf.concat([[n_class * batch_size], event_shape], 0))
      samples_class_c = tf.gather(
          samples_class_c,
          lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = tf.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = tf.reshape(
        lhs_flat_ret, tf.concat(
            [samples_shape, self.event_shape_tensor()], 0))
    # Re-attach whatever static shape information is available.
    ret.set_shape(
        tf.TensorShape(static_samples_shape).concatenate(self.event_shape))
    return ret
def __init__(self,
             component_ssms,
             observation_noise_scale=None,
             initial_state_prior=None,
             initial_step=0,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Build a state space model representing the sum of component models.

  Args:
    component_ssms: Python `list` containing one or more
      `tfd.LinearGaussianStateSpaceModel` instances. The components
      will in general implement different time-series models, with possibly
      different `latent_size`, but they must have the same `dtype`, event
      shape (`num_timesteps` and `observation_size`), and their batch shapes
      must broadcast to a compatible batch shape.
    observation_noise_scale: Optional scalar `float` `Tensor` indicating the
      standard deviation of the observation noise. May contain additional
      batch dimensions, which must broadcast with the batch shape of elements
      in `component_ssms`. If `observation_noise_scale` is specified for the
      `AdditiveStateSpaceModel`, the observation noise scales of component
      models are ignored. If `None`, the observation noise scale is derived
      by summing the noise variances of the component models, i.e.,
      `observation_noise_scale = sqrt(sum(
      [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
    initial_state_prior: Optional instance of `tfd.MultivariateNormal`
      representing a prior distribution on the latent state at time
      `initial_step`. If `None`, defaults to the independent priors from
      component models, i.e.,
      `[component.initial_state_prior for component in component_ssms]`.
      Default value: `None`.
    initial_step: Optional scalar `int` `Tensor` specifying the starting
      timestep.
      Default value: 0.
    validate_args: Python `bool`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    allow_nan_stats: Python `bool`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
      Default value: `True`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AdditiveStateSpaceModel".

  Raises:
    ValueError: if components have different statically known
      `num_timesteps`. Values that are only known at run time are checked
      with assertions, and only when `validate_args` is `True`.
  """
  with tf.compat.v1.name_scope(
      name,
      'AdditiveStateSpaceModel',
      values=[observation_noise_scale, initial_step]) as name:

    assertions = []

    # Check that all components have the same dtype
    tf.debugging.assert_same_float_dtype(component_ssms)

    # Construct an initial state prior as a block-diagonal combination
    # of the component state priors.
    if initial_state_prior is None:
      initial_state_prior = sts_util.factored_joint_mvn(
          [ssm.initial_state_prior for ssm in component_ssms])
    dtype = initial_state_prior.dtype

    # Evaluate each component's static `num_timesteps` once and keep only the
    # values that are actually known at graph-construction time.
    static_num_timesteps = [
        static_value
        for static_value in (tf.get_static_value(ssm.num_timesteps)
                             for ssm in component_ssms)
        if static_value is not None
    ]

    # If any components have a static value for `num_timesteps`, use that
    # value for the additive model. (and check that all other static values
    # match it).
    if static_num_timesteps:
      num_timesteps = static_num_timesteps[0]
      if not all(component_timesteps == num_timesteps
                 for component_timesteps in static_num_timesteps):
        raise ValueError('Additive model components must all have the same '
                         'number of timesteps '
                         '(saw: {})'.format(static_num_timesteps))
    else:
      num_timesteps = component_ssms[0].num_timesteps
    # Some components could not be checked statically; defer the comparison
    # to run time when validation was requested.
    if validate_args and len(static_num_timesteps) != len(component_ssms):
      assertions += [
          tf.compat.v1.assert_equal(
              num_timesteps,
              ssm.num_timesteps,
              message='Additive model components must all have '
              'the same number of timesteps.')
          for ssm in component_ssms
      ]

    # Define the transition and observation models for the additive SSM.
    # See the "mathematical details" section of the class docstring for
    # further information. Note that we define these as callables to
    # handle the fully general case in which some components have time-
    # varying dynamics.
    def transition_matrix_fn(t):
      return tfl.LinearOperatorBlockDiag(
          [ssm.get_transition_matrix_for_timestep(t)
           for ssm in component_ssms])

    def transition_noise_fn(t):
      return sts_util.factored_joint_mvn(
          [ssm.get_transition_noise_for_timestep(t)
           for ssm in component_ssms])

    # Build the observation matrix, concatenating (broadcast) observation
    # matrices from components. We also take this as an opportunity to enforce
    # any dynamic assertions we may have generated above.
    broadcast_batch_shape = tf.convert_to_tensor(
        value=sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
    broadcast_obs_matrix = tf.ones(
        tf.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
    if assertions:
      with tf.control_dependencies(assertions):
        broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

    def observation_matrix_fn(t):
      return tfl.LinearOperatorFullMatrix(
          tf.concat(
              [ssm.get_observation_matrix_for_timestep(t).to_dense() *
               broadcast_obs_matrix for ssm in component_ssms],
              axis=-1))

    if observation_noise_scale is not None:
      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)

      # An explicit scale overrides the component scales; the component noise
      # means are still summed.
      def observation_noise_fn(t):
        return tfd.MultivariateNormalDiag(
            loc=sum([ssm.get_observation_noise_for_timestep(t).mean()
                     for ssm in component_ssms]),
            scale_diag=observation_noise_scale[..., tf.newaxis])
    else:

      def observation_noise_fn(t):
        return sts_util.sum_mvns(
            [ssm.get_observation_noise_for_timestep(t)
             for ssm in component_ssms])

    super(AdditiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix_fn,
        transition_noise=transition_noise_fn,
        observation_matrix=observation_matrix_fn,
        observation_noise=observation_noise_fn,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
def _model_fn(features, labels, mode, params):
  """Model definition for the Mask-RCNN model based on ResNet.

  Args:
    features: the input image tensor and auxiliary information, such as
      `image_info` and `source_ids`. The image tensor has a shape of
      [batch_size, height, width, 3]. The height and width are fixed and
      equal.
    labels: the input labels in a dictionary. The labels include score targets
      and box targets which are dense label maps. The labels are generated
      from get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.

  Returns:
    tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction.
  """
  # Set up training loss and learning rate.
  global_step = tf.compat.v1.train.get_or_create_global_step()

  mode_keys = tf.estimator.ModeKeys
  is_training = mode == mode_keys.TRAIN
  is_eval = mode == mode_keys.EVAL
  is_predict = mode == mode_keys.PREDICT

  if is_predict:
    # Ground truth may be carried inside the features so it can be used for
    # evaluation.
    if params['include_groundtruth_in_features'] and 'labels' in features:
      labels = features['labels']
    else:
      labels = None

    # Otherwise, it is in export mode and the features are passed in directly.
    if 'features' in features:
      features = features['features']

  model_outputs = build_model_graph(features, labels, is_training, params)
  model_outputs.update(
      source_id=features['source_ids'],
      image_info=features['image_info'])

  if is_predict and 'orig_images' in features:
    model_outputs['orig_images'] = features['orig_images']

  # First check if it is in PREDICT mode or EVAL mode to fill out predictions.
  # Predictions are used during the eval step to generate metrics.
  if is_predict or is_eval:
    predictions = {}

    try:
      model_outputs['orig_images'] = features['orig_images']
    except KeyError:
      pass

    if labels and params['include_groundtruth_in_features']:
      # Labels can only be embedded in predictions. The prediction cannot
      # output dictionary as a value.
      predictions.update(labels)

    model_outputs.pop('fpn_features', None)
    predictions.update(model_outputs)

    if is_predict:
      # If we are doing PREDICT, we can return here.
      return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # score_loss and box_loss are for logging. only total_loss is optimized.
  total_rpn_loss, rpn_score_loss, rpn_box_loss = losses.rpn_loss(
      score_outputs=model_outputs['rpn_score_outputs'],
      box_outputs=model_outputs['rpn_box_outputs'],
      labels=labels,
      params=params)

  fast_rcnn_losses = losses.fast_rcnn_loss(
      class_outputs=model_outputs['class_outputs'],
      box_outputs=model_outputs['box_outputs'],
      class_targets=model_outputs['class_targets'],
      box_targets=model_outputs['box_targets'],
      params=params)
  total_fast_rcnn_loss, fast_rcnn_class_loss, fast_rcnn_box_loss = (
      fast_rcnn_losses)

  # Only training has the mask loss.
  # Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/model_builder.py
  if is_training and params['include_mask']:
    mask_loss = losses.mask_rcnn_loss(
        mask_outputs=model_outputs['mask_outputs'],
        mask_targets=model_outputs['mask_targets'],
        select_class_targets=model_outputs['selected_class_targets'],
        params=params)
  else:
    mask_loss = 0.

  trainable_variables = []
  for model in MODELS.values():
    trainable_variables.extend(model.trainable_variables)

  # Variables whose name contains any of these patterns are excluded from the
  # L2 penalty.
  no_decay_patterns = ("batch_normalization", "bias", "beta")
  decayed_variables = [
      v for v in trainable_variables
      if not any(pattern in v.name for pattern in no_decay_patterns)
  ]
  l2_regularization_loss = params['l2_weight_decay'] * tf.add_n(
      [tf.nn.l2_loss(v) for v in decayed_variables])

  total_loss = (
      total_rpn_loss + total_fast_rcnn_loss + mask_loss +
      l2_regularization_loss)

  if is_eval:
    # Predictions can only contain a dict of tensors, not a dict of dict of
    # tensors. These outputs are not used for eval purposes.
    del predictions['rpn_score_outputs']
    del predictions['rpn_box_outputs']

    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=total_loss)

  if is_training:
    learning_rate = learning_rates.step_learning_rate_with_linear_warmup(
        global_step=global_step,
        init_learning_rate=params['init_learning_rate'],
        warmup_learning_rate=params['warmup_learning_rate'],
        warmup_steps=params['warmup_steps'],
        learning_rate_levels=params['learning_rate_levels'],
        learning_rate_steps=params['learning_rate_steps'])

    optimizer = create_optimizer(learning_rate, params)

    gradients, variables = zip(*optimizer.compute_gradients(
        total_loss, trainable_variables, colocate_gradients_with_ops=True))

    # Special treatment for biases (beta is named as bias in reference model)
    # Reference: https://github.com/ddkang/Detectron/blob/80f3295308/lib/modeling/optimizer.py#L109
    bias_patterns = ("bias", "beta")
    grads_and_vars = []
    for grad, var in zip(gradients, variables):
      is_bias = grad is not None and any(
          pattern in var.name for pattern in bias_patterns)
      grads_and_vars.append((2.0 * grad if is_bias else grad, var))

    train_op = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

  else:
    train_op = None
    learning_rate = None

  replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
  on_first_replica = (
      not isinstance(replica_id, tf.Tensor) or
      tf.get_static_value(replica_id) == 0)

  if on_first_replica:
    # Registered in this fixed order, each with its own aggregator instance.
    logged_tensors = (
        ("L2 loss", l2_regularization_loss),
        ("Mask loss", mask_loss),
        ("Total loss", total_loss),
        ("RPN box loss", rpn_box_loss),
        ("RPN score loss", rpn_score_loss),
        ("RPN total loss", total_rpn_loss),
        ("FastRCNN class loss", fast_rcnn_class_loss),
        ("FastRCNN box loss", fast_rcnn_box_loss),
        ("FastRCNN total loss", total_fast_rcnn_loss),
        ("Learning rate", learning_rate),
    )
    for metric_name, metric_tensor in logged_tensors:
      register_metric(
          name=metric_name, tensor=metric_tensor, aggregator=StandardMeter())

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=total_loss,
      train_op=train_op,
  )
def testTrainWithSparseTensorAndDenseFeaturesLayer(self, agent_class):
  """Trains on mixed dense/sparse observations fed through `DenseFeatures`.

  Builds a Q-network whose preprocessing combiner is a `DenseFeatures` layer
  over a numeric column and an embedded, weighted categorical column, runs one
  training step of `agent_class`, and checks that the loss is positive.
  """
  fc = tf.compat.v2.feature_column

  obs_spec = {
      'dense': tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=[3], minimum=-10.0, maximum=10.0),
      'sparse_terms': tf.SparseTensorSpec(dtype=tf.string, shape=[4]),
      'sparse_frequencies': tf.SparseTensorSpec(dtype=tf.float32, shape=[4]),
  }

  terms_column = fc.categorical_column_with_hash_bucket(
      'sparse_terms', hash_bucket_size=5)
  weighted_terms_column = fc.weighted_categorical_column(
      terms_column, weight_feature_key='sparse_frequencies')
  columns = [
      fc.numeric_column('dense', [3]),
      fc.embedding_column(weighted_terms_column, 3),
  ]
  combiner = tf.compat.v2.keras.layers.DenseFeatures(columns)

  time_step_spec = ts.time_step_spec(obs_spec)
  q_net = q_network.QNetwork(
      time_step_spec.observation,
      self._action_spec,
      preprocessing_combiner=combiner)
  agent = agent_class(
      time_step_spec,
      self._action_spec,
      q_network=q_net,
      optimizer=tf.compat.v1.train.AdamOptimizer())

  outer_dims = [5, 2]
  observations = tensor_spec.sample_spec_nest(obs_spec, outer_dims=[5, 2])

  # sparse_terms and sparse_frequencies must be defined on matching indices,
  # so the terms are derived from the sampled frequencies.
  frequencies = observations['sparse_frequencies']
  observations['sparse_terms'] = tf.SparseTensor(
      indices=frequencies.indices,
      values=tf.as_string(tf.math.round(frequencies.values)),
      dense_shape=frequencies.dense_shape)

  if not tf.executing_eagerly():
    # Mimic unknown inner dims on the SparseTensor
    def _hide_static_dense_shape(t):
      if isinstance(t, tf.SparseTensor):
        return tf.SparseTensor(
            indices=t.indices,
            values=t.values,
            dense_shape=tf.compat.v1.placeholder_with_default(
                t.dense_shape, shape=t.dense_shape.shape))
      return t

    observations = tf.nest.map_structure(
        _hide_static_dense_shape, observations)
    self.assertIsNone(
        tf.get_static_value(observations['sparse_terms'].dense_shape))

  time_step = ts.restart(observations, batch_size=outer_dims)
  action_step = tensor_spec.sample_spec_nest(
      self._action_spec, outer_dims=[5, 2])
  p_step = policy_step.PolicyStep(action=action_step, state=(), info=())
  experience = trajectory.from_transition(time_step, p_step, time_step)

  loss_info = agent.train(experience)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  loss_info = self.evaluate(loss_info)
  self.assertGreater(loss_info.loss, 0)
def concat_vectors(*args):
  """Concatenates input vectors, statically if possible.

  Args:
    *args: Vectors to concatenate.

  Returns:
    A flat Python list of the elements when every input has a static value;
    otherwise the result of `tf.concat(args, axis=0)`.
  """
  static_vectors = [tf.get_static_value(vector) for vector in args]

  all_static = all(vector is not None for vector in static_vectors)
  if not all_static:
    # At least one input is only known at run time: concatenate dynamically.
    return tf.concat(args, axis=0)

  flattened = []
  for vector in static_vectors:
    flattened.extend(vector)
  return flattened
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.  If `None` (the default),
      `x.shape[axis] - 1` is used.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(value=x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(
          input_tensor=x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.), tf.math.ceil(
            tf.math.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError('Argument x must have either float or complex dtype'
                        ' found: {}'.format(dtype))
      # The FFT below is applied to a complex tensor, so give real input a
      # zero imaginary part.
      x_rotated_pad = tf.complex(x_rotated_pad,
                                 dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(value=max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide by the lag-0 term so that the lag-0 entry of the result is 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
def _maybe_get_static_args(args):
  """Replaces every leaf of `args` with its static value, if it has one.

  Args:
    args: Arbitrarily nested structure of values.

  Returns:
    A pair `(static_args, all_static)`. `static_args` has the same structure as
    `args`, with each leaf replaced by `tf.get_static_value(leaf)` (`None`
    where no static value is available). `all_static` is `True` iff every leaf
    that was not `None` to begin with produced a non-`None` static value.
  """
  nest = tf.compat.v2.nest
  static_leaves = []
  all_static = True
  for leaf in nest.flatten(args):
    static_leaf = tf.get_static_value(leaf)
    static_leaves.append(static_leaf)
    if leaf is not None and static_leaf is None:
      all_static = False
  return nest.pack_sequence_as(args, static_leaves), all_static
def sample_shape(self):
  """Sample shape of random variable as a `TensorShape`.

  When the stored sample shape is a tensor, its static value (possibly `None`)
  is what gets wrapped in the returned `TensorShape`.
  """
  raw_shape = self._sample_shape
  if tf.is_tensor(raw_shape):
    raw_shape = tf.get_static_value(raw_shape)
  return tf.TensorShape(raw_shape)
def features_map_fn(features, local_radius, relative_pos_max_distance,
                    use_hard_g2l_mask, padding_id, eos_id, null_id, cls_id,
                    sep_id, sequence_length, global_sequence_length):
  """Builds dense token features plus attention masks and relative ids.

  Abbreviations used below: `l2l` long-to-long, `l2g` long-to-global,
  `g2g` global-to-global, `g2l` global-to-long.

  Args:
    features: Mapping with keys `'token_ids'`, `'global_token_ids'` and
      `'segment_ids'`. Each value must support `.to_tensor(shape=...)`
      (presumably `tf.RaggedTensor` -- confirm against the caller).
    local_radius: The long-to-long mask has `2 * local_radius + 1` entries per
      long token; also forwarded to `make_local_relative_att_ids`.
    relative_pos_max_distance: Passed as `max_distance` to
      `RelativePositionGenerator`.
    use_hard_g2l_mask: If truthy, the global-to-long mask additionally keeps
      only entries whose global token equals `cls_id` or whose global position
      index equals the long token's segment id.
    padding_id: Token id compared against to mask out padding.
    eos_id: Token id identifying "sentence" global tokens.
    null_id: Token id masked out of long-to-long attention.
    cls_id: Token id of the CLS global token.
    sep_id: Token id whose first position determines `question_lengths`.
    sequence_length: Width the long inputs are densified to.
    global_sequence_length: Width the global input is densified to.

  Returns:
    A dict with int32 `token_ids`, `global_token_ids`, `segment_ids`, the
    `{l2l,l2g,g2g,g2l}_att_mask` and `{l2l,l2g,g2g,g2l}_relative_att_ids`
    tensors, and `question_lengths`.
  """
  # NOTE(review): the batch size is read from the static shape; it is
  # presumably `None` when the batch dimension is unknown -- TODO confirm
  # callers always batch with a fixed size.
  batch_size = tf.get_static_value(features['token_ids'].shape[0])
  # sequence_lengths = features['token_ids'].row_lengths()
  # Index of the first `sep_id` along the last axis, plus one.
  # NOTE(review): `token_ids` is densified to `global_sequence_length` here
  # (not `sequence_length`) -- confirm this is intended.
  question_lengths = tf.argmax(
      tf.equal(
          features['token_ids'].to_tensor(
              shape=(batch_size, global_sequence_length)), sep_id), -1) + 1
  mapped_features = dict(
      token_ids=tf.cast(
          features['token_ids'].to_tensor(shape=(batch_size, sequence_length)),
          tf.int32),
      global_token_ids=tf.cast(
          features['global_token_ids'].to_tensor(
              shape=(batch_size, global_sequence_length)), tf.int32),
      segment_ids=tf.cast(
          features['segment_ids'].to_tensor(shape=(batch_size, sequence_length)),
          tf.int32),
  )
  relative_pos_generator = RelativePositionGenerator(
      max_distance=relative_pos_max_distance)
  # Only do long-to-long attention for non-null tokens.
  # Let the null token attend to itself.
  l2l_att_mask = tf.ones((batch_size, sequence_length, 2 * local_radius + 1),
                         tf.int32)
  l2l_att_mask *= 1 - tf.cast(
      tf.logical_or(tf.equal(mapped_features['token_ids'], padding_id),
                    tf.equal(mapped_features['token_ids'], null_id)),
      tf.int32)[:, :, tf.newaxis]
  l2l_relative_att_ids = relative_pos_generator.make_local_relative_att_ids(
      seq_len=sequence_length, local_radius=local_radius, batch_size=batch_size)
  # Long-to-global: an entry is kept only when neither the long token nor the
  # global token is padding.
  l2g_att_mask = tf.ones(
      (batch_size, sequence_length, global_sequence_length), tf.int32)
  l2g_att_mask *= tf.cast(
      tf.not_equal(mapped_features['token_ids'], padding_id),
      tf.int32)[:, :, tf.newaxis]
  l2g_att_mask *= tf.cast(
      tf.not_equal(mapped_features['global_token_ids'], padding_id),
      tf.int32)[:, tf.newaxis, :]
  # Every long-to-global pair gets the same relative id, one past the
  # generator's vocabulary size.
  l2g_relative_att_ids = tf.fill(
      (batch_size, sequence_length, global_sequence_length),
      relative_pos_generator.relative_vocab_size + 1)
  # Global-to-global: rows belonging to padding global tokens are zeroed.
  g2g_att_mask = tf.ones(
      (batch_size, global_sequence_length, global_sequence_length), tf.int32)
  g2g_att_mask *= tf.cast(
      tf.not_equal(mapped_features['global_token_ids'], padding_id),
      tf.int32)[:, :, tf.newaxis]
  g2g_relative_att_ids = relative_pos_generator.make_relative_att_ids(
      seq_len=global_sequence_length, batch_size=batch_size)
  # "Sentence" global tokens are those equal to `eos_id`; "question" global
  # tokens are those that are none of CLS, EOS, or padding.
  global_sentence_mask = tf.equal(mapped_features['global_token_ids'], eos_id)
  global_question_mask = tf.logical_not(
      tf.logical_or(
          tf.logical_or(
              tf.equal(mapped_features['global_token_ids'], cls_id),
              tf.equal(mapped_features['global_token_ids'], eos_id)),
          tf.equal(mapped_features['global_token_ids'], padding_id)))
  g2g_question_mask = tf.logical_and(global_question_mask[:, tf.newaxis, :],
                                     global_question_mask[:, :, tf.newaxis])
  g2g_sentence_mask = tf.logical_and(global_sentence_mask[:, tf.newaxis, :],
                                     global_sentence_mask[:, :, tf.newaxis])
  # Keep the generated relative ids only for question-question and
  # sentence-sentence pairs; all other pairs share one fixed id.
  g2g_local_mask = tf.cast(
      tf.logical_or(g2g_question_mask, g2g_sentence_mask), tf.int32)
  g2g_relative_att_ids *= g2g_local_mask
  g2g_relative_att_ids += (1 - g2g_local_mask) * (
      relative_pos_generator.relative_vocab_size + 2)
  # Global-to-long starts as the transpose of long-to-global.
  g2l_att_mask = tf.transpose(l2g_att_mask, [0, 2, 1])
  if use_hard_g2l_mask:
    global_range = tf.range(
        global_sequence_length, dtype=mapped_features['global_token_ids'].dtype)
    g2l_att_mask *= tf.cast(
        tf.logical_or(
            tf.equal(mapped_features['global_token_ids'],
                     cls_id)[:, :, tf.newaxis],
            tf.equal(global_range[tf.newaxis, :, tf.newaxis],
                     mapped_features['segment_ids'][:, tf.newaxis, :])),
        tf.int32)
  g2l_relative_att_ids = tf.transpose(l2g_relative_att_ids, [0, 2, 1])
  mapped_features.update(
      dict(
          l2l_att_mask=l2l_att_mask,
          l2l_relative_att_ids=l2l_relative_att_ids,
          l2g_att_mask=l2g_att_mask,
          l2g_relative_att_ids=l2g_relative_att_ids,
          g2g_att_mask=g2g_att_mask,
          g2g_relative_att_ids=g2g_relative_att_ids,
          g2l_att_mask=g2l_att_mask,
          g2l_relative_att_ids=g2l_relative_att_ids,
          question_lengths=question_lengths,
      ))
  return mapped_features
def _used_weight(weights_list):
  """Returns the static value of the first weight that is not `None`.

  Args:
    weights_list: Iterable of weights, any of which may be `None`.

  Returns:
    `tf.get_static_value` of the first non-`None` entry converted to a tensor,
    or `None` when every entry is `None` (or the iterable is empty).
  """
  first_weight = next((w for w in weights_list if w is not None), None)
  if first_weight is None:
    return None
  return tf.get_static_value(tf.convert_to_tensor(value=first_weight))
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] of new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
  with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`. Scalar `Tensor` or
      "constant_extension" ==> Extend as constant function.
      Default value: `"constant_extension"`
    name:  A name to prepend to created ops.
      Default value: `"interp_regular_nd_grid"`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `fill_value` is a string other than `"constant_extension"`.
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `x_ref_min.shape[-1]` is not known statically.
    ValueError:  If `axis` is not a scalar is determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```

  """
  with tf.compat.v1.name_scope(
      name,
      default_name='interp_regular_nd_grid',
      values=[x, x_ref_min, x_ref_max, y_ref, fill_value]):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    preferred_dtype=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          value=fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., D, nd].
    x = tf.convert_to_tensor(value=x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(value=y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [..., nd]
    x_ref_min = tf.convert_to_tensor(
        value=x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        value=x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table, it's the
    # "nd" in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    x_ref_max.shape[-1:].assert_is_compatible_with(x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(value=axis, dtype=tf.int32, name='axis')
    axis = distribution_util.make_non_negative_axis(axis, tf.rank(y_ref))
    axis.shape.assert_has_rank(0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    # Batch dims are everything left of the trailing [D, nd] / [nd] dims of the
    # x arguments, and everything left of `axis` in y_ref.
    x_batch_shape = tf.shape(input=x)[:-2]
    x_ref_min_batch_shape = tf.shape(input=x_ref_min)[:-1]
    x_ref_max_batch_shape = tf.shape(input=x_ref_max)[:-1]
    y_ref_batch_shape = tf.shape(input=y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [
        x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape
    ]:
      batch_shape = tf.broadcast_dynamic_shape(batch_shape, tensor)

    def _batch_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = tf.ones(shape=[n_singletons], dtype=tf.int32)
      return tf.zeros(shape=tf.concat([batch_shape, ones], axis=0), dtype=dtype)

    # Adding zeros of shape `batch_shape + [1, ..., 1]` broadcasts each
    # argument's batch dims to the common batch shape.
    x += _batch_of_zeros_with_rightmost_singletons(n_singletons=2)
    x_ref_min += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
    x_ref_max += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
    y_ref += _batch_of_zeros_with_rightmost_singletons(
        n_singletons=tf.rank(y_ref) - axis)

    # NOTE(review): `tf.get_static_value` yields None when the rank of `x` is
    # not statically known, in which case `None - 2` raises TypeError. The
    # check on `x` above does not pass `expect_static=True` -- confirm whether
    # unknown-rank `x` is meant to be unsupported.
    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=tf.get_static_value(tf.rank(x)) - 2)
def to_n_step_transition(
    trajectory: Trajectory,
    gamma: types.Float
) -> Transition:
  """Create an n-step transition from a trajectory with `T=N + 1` frames.

  **NOTE** Tensors of `trajectory` are sliced along their *second* (`time`)
  dimension, to pull out the appropriate fields for the n-step transitions.

  The output transition's `next_time_step.{reward, discount}` will contain
  N-step discounted reward and discount values calculated as:

  ```
  next_time_step.reward = r_t +
                          g^{1} * d_t * r_{t+1} +
                          g^{2} * d_t * d_{t+1} * r_{t+2} +
                          g^{3} * d_t * d_{t+1} * d_{t+2} * r_{t+3} +
                          ...
                          g^{N-1} * d_t * ... * d_{t+N-2} * r_{t+N-1}
  next_time_step.discount = g^{N-1} * d_t * d_{t+1} * ... * d_{t+N-1}
  ```

  In python notation:

  ```python
  discount = gamma**(N-1) * reduce_prod(trajectory.discount[:, :-1])
  reward = discounted_return(
      rewards=trajectory.reward[:, :-1],
      discounts=gamma * trajectory.discount[:, :-1])
  ```

  When `trajectory.discount[:, :-1]` is an all-ones tensor, this is equivalent
  to:

  ```python
  next_time_step.discount = (
      gamma**(N-1) * tf.ones_like(trajectory.discount[:, 0]))
  next_time_step.reward = (
      sum_{n=0}^{N-1} gamma**n * trajectory.reward[:, n])
  ```

  Args:
    trajectory: An instance of `Trajectory`. The tensors in Trajectory must have
      shape `[B, T, ...]`.  `discount` is assumed to be a scalar float,
      hence the shape of `trajectory.discount` must be `[B, T]`.
    gamma: A floating point scalar; the discount factor.

  Returns:
    An N-step `Transition` where `N = T - 1`.  The reward and discount in
    `time_step.{reward, discount}` are NaN.  The n-step discounted reward
    and final discount are stored in `next_time_step.{reward, discount}`.
    All tensors in the `Transition` have shape `[B, ...]` (no time dimension).

  Raises:
    ValueError: if `discount.shape.rank != 2`.
    ValueError: if `discount.shape[1] < 2`.
  """
  _validate_rank(trajectory.discount, min_rank=2, max_rank=2)

  # Use static values when available, so that we can use XLA when the time
  # dimension is fixed.
  #
  # The static dimension is compared against None explicitly: a statically
  # known time dimension of 0 is falsy, and treating it as "unknown" would
  # fall back to the dynamic shape and could skip the frame-count check below.
  static_time_dim = tf.compat.dimension_value(trajectory.discount.shape[1])
  if static_time_dim is not None:
    time_dim = static_time_dim
  else:
    time_dim = tf.shape(trajectory.discount)[1]
    static_time_dim = tf.get_static_value(time_dim)
  if static_time_dim in (0, 1):
    raise ValueError(
        'Trajectory frame count must be at least 2, but saw {}.  Shape of '
        'trajectory.discount: {}'.format(static_time_dim,
                                         trajectory.discount.shape))

  n = time_dim - 1

  # Use composite calculations to ensure we properly handle SparseTensor etc in
  # the observations.

  # pylint: disable=g-long-lambda

  # Pull out x[:,0] for x in trajectory
  first_frame = tf.nest.map_structure(
      lambda t: composite.squeeze(
          composite.slice_to(t, axis=1, end=1),
          axis=1),
      trajectory)

  # Pull out x[:,-1] for x in trajectory
  final_frame = tf.nest.map_structure(
      lambda t: composite.squeeze(
          composite.slice_from(t, axis=1, start=-1),
          axis=1),
      trajectory)
  # pylint: enable=g-long-lambda

  # When computing discounted return, we need to throw out the last time
  # index of both reward and discount, which are filled with dummy values
  # to match the dimensions of the observation.
  reward = trajectory.reward[:, :-1]
  discount = trajectory.discount[:, :-1]

  policy_steps = policy_step.PolicyStep(
      action=first_frame.action, state=(), info=first_frame.policy_info)

  discounted_reward = value_ops.discounted_return(
      rewards=reward,
      discounts=gamma * discount,
      time_major=False,
      provide_all_returns=False)

  # NOTE: `final_discount` will have one less discount than `discount`.
  # This is so that when the learner/update uses an additional
  # discount (e.g. gamma) we don't apply it twice.
  final_discount = gamma**(n-1) * tf.math.reduce_prod(discount, axis=1)

  time_steps = ts.TimeStep(
      first_frame.step_type,
      # unknown
      reward=tf.nest.map_structure(
          lambda r: np.nan * tf.ones_like(r), first_frame.reward),
      # unknown
      discount=np.nan * tf.ones_like(first_frame.discount),
      observation=first_frame.observation)
  next_time_steps = ts.TimeStep(
      step_type=final_frame.step_type,
      reward=discounted_reward,
      discount=final_discount,
      observation=final_frame.observation)
  return Transition(time_steps, policy_steps, next_time_steps)
def _maybe_tensor_shape_from_tensor(shape):
  """Converts a `tf.Tensor` shape to a static shape; passes others through.

  Args:
    shape: Either a `tf.Tensor` holding a shape, or any other shape
      representation.

  Returns:
    `tensor_shape.as_shape` of the tensor's static value when `shape` is a
    `tf.Tensor`; otherwise `shape` itself, unchanged.
  """
  if not isinstance(shape, tf.Tensor):
    return shape
  return tensor_shape.as_shape(tf.get_static_value(shape))
def _dynamic_or_static_shape(tensor):
  """Returns the shape of `tensor`, as a static value whenever one exists.

  Args:
    tensor: Value whose shape is requested.

  Returns:
    The static value of `tf.shape(tensor)` if it can be determined, otherwise
    the dynamic shape tensor itself.
  """
  dynamic_shape = tf.shape(input=tensor)
  maybe_static_shape = tf.get_static_value(dynamic_shape)
  if maybe_static_shape is None:
    return dynamic_shape
  return maybe_static_shape
def _single_deterministic_pass_dataset(self,
                                       sample_batch_size=None,
                                       num_steps=None,
                                       num_parallel_calls=None):
  """Creates a dataset that returns entries from the buffer in fixed order.

  Args:
    sample_batch_size: (Optional.) An optional batch_size to specify the
      number of items to return.  See as_dataset() documentation.
    num_steps: (Optional.)  Optional way to specify that sub-episodes are
      desired.  See as_dataset() documentation.
    num_parallel_calls: (Optional.) Number elements to process in parallel.
      See as_dataset() documentation.

  Returns:
    A dataset of type tf.data.Dataset, elements of which are 2-tuples of:
      - An item or sequence of items or batch thereof
      - Auxiliary info for the items (i.e. ids, probs).

  Raises:
    ValueError: If `dataset_drop_remainder` is set, and
      `sample_batch_size > self.batch_size`.  In this case all data will
      be dropped.
    ValueError: If `dataset_drop_remainder` is set, and
      `num_steps > self.max_length`.  In this case all data will be dropped.
  """
  # Both checks below only fire when the compared values are statically known.
  static_size = tf.get_static_value(sample_batch_size)
  static_num_steps = tf.get_static_value(num_steps)
  static_self_batch_size = tf.get_static_value(self._batch_size)
  static_self_max_length = tf.get_static_value(self._max_length)
  if (self._dataset_drop_remainder
      and static_size is not None
      and static_self_batch_size is not None
      and static_size > static_self_batch_size):
    raise ValueError(
        'sample_batch_size ({}) > self.batch_size ({}) and '
        'dataset_drop_remainder is True. In '
        'this case, ALL data will be dropped by the deterministic dataset.'
        .format(static_size, static_self_batch_size))
  if (self._dataset_drop_remainder
      and static_num_steps is not None
      and static_self_max_length is not None
      and static_num_steps > static_self_max_length):
    raise ValueError(
        'num_steps_size ({}) > self.max_length ({}) and '
        'dataset_drop_remainder is True. In '
        'this case, ALL data will be dropped by the deterministic dataset.'
        .format(static_num_steps, static_self_max_length))

  def get_row_ids(_):
    """Passed to Dataset.range(1).flat_map(.), returns a dataset of row ids."""
    with tf.device(self._device), tf.name_scope(self._scope):
      with tf.name_scope('single_deterministic_pass_dataset'):
        # Here we pass num_steps=None because _valid_range_ids uses
        # num_steps to determine a hard stop when sampling num_steps starting
        # from the returned indices.  But in our case, we want all the indices
        # and we'll use TF dataset's window() mechanism to get
        # num_steps-length blocks.  The window mechanism handles this stuff
        # for us.
        min_frame_offset, max_frame_offset = _valid_range_ids(
            self._get_last_id(), self._max_length, num_steps=None)
        # NOTE(review): the op returned by assert_less is not captured or
        # attached as a control dependency here -- confirm it actually runs
        # in graph mode.
        tf.compat.v1.assert_less(
            min_frame_offset,
            max_frame_offset,
            message=
            'TFUniformReplayBuffer is empty. Make sure to add items '
            'before asking the buffer for data.')

        min_max_frame_range = tf.range(min_frame_offset, max_frame_offset)

        drop_remainder = self._dataset_drop_remainder
        window_shift = self._dataset_window_shift

        def group_windows(ds_):
          # Turns one window (itself a dataset) into a single batched element.
          return ds_.batch(num_steps, drop_remainder=drop_remainder)

        if sample_batch_size is None:
          def row_ids(b):
            # Create a vector of shape [num_frames] and slice it along each
            # frame.
            ids = tf.data.Dataset.from_tensor_slices(
                b * self._max_length + min_max_frame_range)
            if num_steps is not None:
              ids = (ids.window(
                  num_steps, shift=window_shift).flat_map(group_windows)
                    )
            return ids
          return tf.data.Dataset.range(
              self._batch_size).flat_map(row_ids)
        else:
          def batched_row_ids(batch):
            # Create a matrix of indices shaped [num_frames, batch_size]
            # and slice it along each frame row to get groups of batches
            # for frame 0, frame 1, ...
            return tf.data.Dataset.from_tensor_slices(
                (min_max_frame_range[:, tf.newaxis]
                 + batch * self._max_length))

          indices_ds = (tf.data.Dataset.range(
              self._batch_size).batch(
                  sample_batch_size,
                  drop_remainder=drop_remainder).flat_map(
                      batched_row_ids))

          if num_steps is not None:
            # We have sequences of num_frames rows shaped [sample_batch_size].
            # Window and group these to rows of shape
            # [num_steps, sample_batch_size], then
            # transpose them to get index tensors of shape
            # [sample_batch_size, num_steps].
            indices_ds = (indices_ds.window(
                num_steps, shift=window_shift).flat_map(
                    group_windows).map(tf.transpose))

          return indices_ds

  # Get our indices as a dataset; each time we reinitialize the iterator we
  # update our min/max id bounds from the state of the replay buffer.
  ds = tf.data.Dataset.range(1).flat_map(get_row_ids)

  def get_data(id_):
    """Reads the stored item(s) for `id_` and pairs them with buffer info."""
    with tf.device(self._device), tf.name_scope(self._scope):
      with tf.name_scope('single_deterministic_pass_dataset'):
        # Ids are mapped to table rows modulo the capacity.
        data = self._data_table.read(id_ % self._capacity)
    buffer_info = BufferInfo(ids=id_, probabilities=())
    return (data, buffer_info)

  # Deterministic even though num_parallel_calls > 1.  Operations are
  # run in parallel but then the results are returned in original stream
  # order.
  ds = ds.map(get_data, num_parallel_calls=num_parallel_calls)

  return ds
def stateless_dropout(x: tf.Tensor,
                      rate: Union[float, tf.Tensor],
                      seed: tf.Tensor,
                      noise_shape: Optional[Union[Sequence[int],
                                                  tf.TensorShape]] = None,
                      name: Optional[Text] = None) -> tf.Tensor:
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  See https://www.tensorflow.org/api_docs/python/tf/nn/dropout.
  This version differs in that the seed is required if the rate is nonzero.

  Args:
    x: A floating point tensor.
    rate: A Python float or a scalar `Tensor` with the same type as x. The
      probability that each element is dropped. For example, setting rate=0.1
      would drop 10% of input elements.
    seed: A shape [2] integer Tensor of seeds to the random number generator.
      Must have dtype `tf.int32` when compiling to XLA.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for
      randomly generated keep/drop flags.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` of the same shape of `x`. When `rate` is statically known to be
    0, `x` itself is returned.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended.
  """
  with tf.name_scope(name or 'stateless_dropout') as name:
    x = tf.convert_to_tensor(x, name='x')
    if not x.dtype.is_floating:
      raise ValueError(
          'x has to be a floating point tensor since it\'s going '
          ' to be scaled. Got a %s tensor instead.' % x.dtype)
    # The range can only be validated eagerly for plain Python numbers.
    if isinstance(rate, numbers.Real):
      if not (rate >= 0 and rate < 1):
        raise ValueError(
            'rate must be a scalar tensor or a float in the '
            'range [0, 1), got %g' % rate)
      if rate > 0.5:
        logging.log_first_n(
            logging.WARN, 'Large dropout rate: %g (>0.5). In TensorFlow '
            '2.x, dropout() uses dropout rate instead of keep_prob. '
            'Please ensure that this is intended.', 5, rate)

    # Early return if nothing needs to be dropped.
    if tf.get_static_value(rate) == 0:
      return x

    rate = tf.convert_to_tensor(rate, dtype=x.dtype, name='rate')
    rate.shape.assert_has_rank(0)
    noise_shape = _get_noise_shape(x, noise_shape)
    # Sample a uniform distribution on [0.0, 1.0) and select values larger than
    # rate.
    #
    # NOTE: Random uniform actually can only generate 2^23 floats on [1.0, 2.0)
    # and subtract 1.0.
    random_tensor = tf.random.stateless_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    keep_prob = 1 - rate
    # Kept elements are scaled up by 1 / keep_prob.
    scale = 1 / keep_prob
    # NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider that
    # float to be selected, hence we use a >= comparison.
    keep_mask = random_tensor >= rate
    ret = x * scale * tf.cast(keep_mask, x.dtype)
    if not tf.executing_eagerly():
      ret.set_shape(x.get_shape())
    return ret
# Test images #testLoad( dsTrain, info ) #testLoad( dsValid, info ) # Load data into a dictionary trainDict = { 'image': [], 'label': [], 'mask': [] } validDict = { 'image': [], 'label': [] } for example in dsTrain: image_name = example["file_name"] image = example["image"] label = example["label"] # Get name of file name = tf.get_static_value( image_name ).decode( 'utf-8' ) name = os.path.splitext( name )[0] xml_name = name + ".xml" if( xml_name in missingXMLs ): continue # Add bounding boxes xml_path = os.path.join( xmls_dir_path, xml_name ) sizeInfo = [] for _, elem in ElementTree.iterparse( xml_path ): unmarshal = unmarshallers.get(elem.tag) if unmarshal: data = unmarshal(elem) elem.clear()
def __init__(self,
             df,
             scale_operator,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` tensor, the degrees of freedom of the
      distribution(s). `df` must be greater than or equal to `k`.
    scale_operator: `float` or `double` instance of `LinearOperator`.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output from
      `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if scale is not floating-type
    TypeError: if scale.dtype != df.dtype
    ValueError: if `scale_operator` is not square.
    ValueError: if df < k, where scale operator event shape is
      `(k, k)`
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())
  self._input_output_cholesky = input_output_cholesky
  with tf.compat.v1.name_scope(name) as name:
    with tf.compat.v1.name_scope("init", values=[df, scale_operator]):
      if not scale_operator.dtype.is_floating:
        raise TypeError(
            "scale_operator.dtype=%s is not a floating-point type" %
            scale_operator.dtype)
      if not scale_operator.is_square:
        raise ValueError("scale_operator must be square.")

      self._scale_operator = scale_operator
      self._df = tf.convert_to_tensor(
          value=df, dtype=scale_operator.dtype, name="df")
      tf.debugging.assert_same_float_dtype(
          [self._df, self._scale_operator])
      # Prefer the statically known dimension; fall back to the dynamic one.
      if tf.compat.dimension_value(
          self._scale_operator.shape[-1]) is None:
        self._dimension = tf.cast(
            self._scale_operator.domain_dimension_tensor(),
            dtype=self._scale_operator.dtype,
            name="dimension")
      else:
        self._dimension = tf.convert_to_tensor(
            value=tf.compat.dimension_value(
                self._scale_operator.shape[-1]),
            dtype=self._scale_operator.dtype,
            name="dimension")
      df_val = tf.get_static_value(self._df)
      dim_val = tf.get_static_value(self._dimension)
      if df_val is not None and dim_val is not None:
        # Both values are known now, so validate immediately.
        df_val = np.asarray(df_val)
        if not df_val.shape:
          df_val = [df_val]
        if np.any(df_val < dim_val):
          raise ValueError(
              "Degrees of freedom (df = %s) cannot be less than "
              "dimension of scale matrix (scale.dimension = %s)" %
              (df_val, dim_val))
      elif validate_args:
        # Otherwise defer the check to run time, only when requested.
        assertions = tf.compat.v1.assert_less_equal(
            self._dimension,
            self._df,
            message=("Degrees of freedom (df = %s) cannot be "
                     "less than dimension of scale matrix "
                     "(scale.dimension = %s)" %
                     (self._df, self._dimension)))
        self._df = distribution_util.with_dependencies(
            [assertions], self._df)
  super(_WishartLinearOperator, self).__init__(
      dtype=self._scale_operator.dtype,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
      parameters=parameters,
      graph_parents=([self._df, self._dimension] +
                     self._scale_operator.graph_parents),
      name=name)
def sample_fn(spec):
    """Return a composite tensor sample given `spec`.

    Args:
      spec: A TensorSpec, SparseTensorSpec, etc.

    Returns:
      A tensor or SparseTensor.

    Raises:
      NotImplementedError: If `outer_dims` is not statically known and a
        SparseTensor is requested.
      TypeError: If `spec` is not one of the supported spec types.
    """
    if isinstance(spec, tf.SparseTensorSpec):
        outer_shape = tf.get_static_value(outer_dims)
        if outer_dims is not None and outer_shape is None:
            raise NotImplementedError(
                "outer_dims must be statically known, got: {}".format(
                    outer_dims))
        # Compare against None explicitly. `outer_shape or []` would take the
        # truth value of a NumPy array, which raises for arrays with more
        # than one element and silently discards an outer shape of `[0]`.
        shape = tf.TensorShape(
            [] if outer_shape is None else outer_shape).concatenate(
                spec.shape)
        if (shape.num_elements() == 0
                or tf.compat.dimension_value(shape[0]) == 0):
            # Nothing can be sampled; return an empty SparseTensor.
            return tf.SparseTensor(
                indices=tf.zeros([0, shape.rank], dtype=tf.int64),
                values=tf.zeros([0], dtype=spec.dtype),
                dense_shape=shape)
        # Seven index rows are sampled, each bounded by the dense shape.
        indices_spec = BoundedTensorSpec(
            dtype=tf.int64,
            shape=[7, shape.rank],
            minimum=[0] * shape.rank,
            maximum=[x - 1 for x in shape.as_list()])
        # String values are sampled as integers and converted afterwards.
        values_dtype = tf.int32 if spec.dtype == tf.string else spec.dtype
        values_spec = BoundedTensorSpec(
            dtype=values_dtype,
            shape=[7],
            minimum=0,
            maximum=shape.as_list()[-1] - 1)
        values_sample = sample_bounded_spec(values_spec, seed=seed_stream())
        if spec.dtype == tf.string:
            values_sample = tf.as_string(values_sample)
        return tf.sparse.reorder(
            tf.SparseTensor(
                indices=sample_bounded_spec(indices_spec, seed=seed_stream()),
                values=values_sample,
                dense_shape=shape))
    elif isinstance(spec, (TensorSpec, BoundedTensorSpec)):
        if spec.dtype == tf.string:
            # Strings are produced by stringifying integers in [0, 10].
            sample_spec = BoundedTensorSpec(
                spec.shape, tf.int32, minimum=0, maximum=10)
            return tf.as_string(
                sample_bounded_spec(
                    sample_spec, outer_dims=outer_dims, seed=seed_stream()))
        else:
            return sample_bounded_spec(
                BoundedTensorSpec.from_spec(spec),
                outer_dims=outer_dims,
                seed=seed_stream())
    else:
        raise TypeError("Spec type not supported: '{}'".format(spec))
def train_step(reals, prev_rec, noise_amp, scale, step, g_opt, d_opt):
    """Run six update iterations for the networks at `scale`.

    The first three iterations apply discriminator gradients through `d_opt`;
    the last three apply generator gradients through `g_opt`.

    Args:
      reals: Indexable collection of real images; `reals[scale]` is used.
      prev_rec: Reconstruction input carried over from the caller. It is
        recomputed on the first iteration when the static value of `step`
        is 0.
      noise_amp: Noise amplitude; recomputed together with `prev_rec`.
      scale: Index of the scale being trained (0 is the coarsest).
      step: Training step; compared to 0 via `tf.get_static_value`.
      g_opt: Optimizer applied to the generator variables.
      d_opt: Optimizer applied to the discriminator variables.

    Returns:
      A tuple `(z_rec, prev_rec, noise_amp, metrics)` where `metrics` is
      `(dis_loss, gen_loss, rec_loss)` from the last respective updates.
    """
    real = reals[scale]
    z_rand = tf.random.normal(real.shape)
    z_rec = tf.random.normal(real.shape) if scale == 0 else tf.zeros_like(real)

    for i in range(6):
        is_initial = i == 0 and tf.get_static_value(step) == 0
        if not is_initial:
            prev_rand = self.generate_from_coarsest(scale, reals, 'rand')
        elif scale == 0:
            # Coarsest scale is purely generative.
            prev_rand = tf.zeros_like(real)
            prev_rec = tf.zeros_like(real)
            noise_amp = 1.0
        else:
            # Finer scale takes noise and image generated from previous scale
            # as input.
            prev_rand = self.generate_from_coarsest(scale, reals, 'rand')
            prev_rec = self.generate_from_coarsest(scale, reals, 'rec')
            # Compute the standard deviation of noise.
            rmse = tf.sqrt(tf.reduce_mean(tf.square(real - prev_rec)))
            noise_amp = self.noise_amp_init * rmse

        if scale == 0:
            scaled_z_rand = z_rand
        else:
            scaled_z_rand = noise_amp * z_rand
        scaled_z_rec = noise_amp * z_rec

        if i < 3:
            # Discriminator update.
            with tf.GradientTape() as tape:
                # Only record the training variables.
                fake_rand = self.generators[scale](prev_rand, scaled_z_rand)
                dis_loss = self.dicriminator_wgan_loss(
                    self.discriminators[scale], real, fake_rand, 1)
            dis_gradients = tape.gradient(
                dis_loss, self.discriminators[scale].trainable_variables)
            d_opt.apply_gradients(
                zip(dis_gradients,
                    self.discriminators[scale].trainable_variables))
        else:
            # Generator update: adversarial loss plus weighted reconstruction.
            with tf.GradientTape() as tape:
                # Only record the training variables.
                fake_rand = self.generators[scale](prev_rand, scaled_z_rand)
                fake_rec = self.generators[scale](prev_rec, scaled_z_rec)
                gen_loss = self.generator_wgan_loss(
                    self.discriminators[scale], fake_rand)
                rec_loss = self.reconstruction_loss(real, fake_rec)
                gen_loss = gen_loss + 10 * rec_loss
            gen_gradients = tape.gradient(
                gen_loss, self.generators[scale].trainable_variables)
            g_opt.apply_gradients(
                zip(gen_gradients,
                    self.generators[scale].trainable_variables))

    metrics = (dis_loss, gen_loss, rec_loss)
    return z_rec, prev_rec, noise_amp, metrics
def __init__(self,
             cat,
             components,
             validate_args=False,
             allow_nan_stats=True,
             use_static_graph=False,
             name="Mixture"):
    """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph construction
    time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the
        probabilities of `components`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a runtime
        error if batch or event ranks are inconsistent between cat and any of
        the distributions. This is only checked if the ranks cannot be
        determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations,
        but at the expense of sampling all underlying distributions in the
        mixture. (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
    # Must stay the first statement so that only the arguments are captured.
    parameters = dict(locals())

    # TODO(b/117098119): Remove tf.distribution references once they're gone.
    if not isinstance(
            cat,
            (categorical.Categorical, tf.compat.v1.distributions.Categorical)):
        raise TypeError(
            "cat must be a Categorical distribution, but saw: %s" % cat)
    if not components:
        raise ValueError("components must be a non-empty list or tuple")
    if not isinstance(components, (list, tuple)):
        raise TypeError("components must be a list or tuple, but saw: %s" %
                        components)
    # TODO(b/117098119): Remove tf.distribution references once they're gone.
    # Each component itself (not `cat`) must be a Distribution.
    if not all(
            isinstance(c, (distribution.Distribution,
                           tf.compat.v1.distributions.Distribution))
            for c in components):
        raise TypeError(
            "all entries in components must be Distribution instances"
            " but saw: %s" % components)

    dtype = components[0].dtype
    if not all(d.dtype == dtype for d in components):
        raise TypeError("All components must have the same dtype, but saw "
                        "dtypes: %s" % [(d.name, d.dtype) for d in components])

    # Merge the static shapes of all components (and of `cat` for the batch
    # shape) so that the most specific known shape is kept.
    static_event_shape = components[0].event_shape
    static_batch_shape = cat.batch_shape
    for di, d in enumerate(components):
        if not static_batch_shape.is_compatible_with(d.batch_shape):
            raise ValueError(
                "components[{}] batch shape must be compatible with cat "
                "shape and other component batch shapes".format(di))
        static_event_shape = static_event_shape.merge_with(d.event_shape)
        static_batch_shape = static_batch_shape.merge_with(d.batch_shape)
    if static_event_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(event_shape) from components, but "
            "none of the components provide a static number of ndims")

    # Ensure that all batch and event ndims are consistent.
    with tf.name_scope(name, values=[cat.logits]) as name:
        num_components = cat.event_size
        static_num_components = tf.get_static_value(num_components)
        if static_num_components is None:
            raise ValueError(
                "Could not infer number of classes from cat and unable "
                "to compare this value to the number of components passed in."
            )
        # Possibly convert from numpy 0-D array.
        static_num_components = int(static_num_components)
        if static_num_components != len(components):
            raise ValueError(
                "cat.num_classes != len(components): %d vs. %d" %
                (static_num_components, len(components)))

        cat_batch_shape = cat.batch_shape_tensor()
        cat_batch_rank = tf.size(input=cat_batch_shape)
        if validate_args:
            batch_shapes = [d.batch_shape_tensor() for d in components]
            batch_ranks = [tf.size(input=bs) for bs in batch_shapes]
            check_message = ("components[%d] batch shape must match cat "
                             "batch shape")
            self._assertions = [
                tf.compat.v1.assert_equal(
                    cat_batch_rank, rank, message=check_message % di)
                for di, rank in enumerate(batch_ranks)
            ]
            self._assertions += [
                tf.compat.v1.assert_equal(
                    cat_batch_shape, shape, message=check_message % di)
                for di, shape in enumerate(batch_shapes)
            ]
        else:
            self._assertions = []

        self._cat = cat
        self._components = list(components)
        # The number of components is always statically known at this point;
        # construction fails above otherwise.
        self._num_components = static_num_components
        self._static_event_shape = static_event_shape
        self._static_batch_shape = static_batch_shape

        self._use_static_graph = use_static_graph

    # We let the Mixture distribution access _graph_parents since its arguably
    # more like a baseclass.
    # Copy first: extending the list in place would also modify the list
    # owned by `cat`.
    graph_parents = list(self._cat._graph_parents)  # pylint: disable=protected-access
    for c in self._components:
        graph_parents.extend(c._graph_parents)  # pylint: disable=protected-access

    super(Mixture, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)
def __init__(self,
             distribution,
             bijector,
             batch_shape=None,
             event_shape=None,
             validate_args=False,
             name=None):
    """Construct a Transformed Distribution.

    Args:
      distribution: Base distribution to be transformed; typically a
        `Distribution` instance.
      bijector: Object that computes the transformation; typically a
        `Bijector` instance.
      batch_shape: `integer` vector `Tensor` overriding the `batch_shape` of
        `distribution`. Only valid when `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` overriding the `event_shape` of
        `distribution`. Only valid when `distribution.is_scalar_event()`.
      validate_args: Python `bool`, default `False`. When `True`, distribution
        parameters are checked for validity at a possible cost in runtime
        performance. When `False`, invalid inputs may silently produce
        incorrect outputs.
      name: Python `str` name prefixed to Ops created by this class.
        Default: `bijector.name + distribution.name`.
    """
    parameters = dict(locals())
    if not name:
        bijector_prefix = "" if bijector is None else bijector.name
        name = bijector_prefix + distribution.name
    with tf.compat.v1.name_scope(
            name, values=[event_shape, batch_shape]) as name:
        # Handy constants reused below.
        self._zero = tf.constant(0, dtype=tf.int32, name="zero")
        self._empty = tf.constant([], dtype=tf.int32, name="empty")

        # Both a dynamic and a static ("maybe") flag are kept for each
        # override so that as much as possible can be decided before graph
        # execution, including raising Python exceptions.
        self._override_batch_shape = self._maybe_validate_shape_override(
            batch_shape, distribution.is_scalar_batch(), validate_args,
            "batch_shape")
        batch_override_ndims = _ndims_from_shape(self._override_batch_shape)
        batch_is_unset = _logical_equal(batch_override_ndims, self._zero)
        self._is_batch_override = _logical_not(batch_is_unset)
        static_batch_override = tf.get_static_value(
            self._override_batch_shape)
        self._is_maybe_batch_override = bool(
            static_batch_override is None or static_batch_override.size != 0)

        self._override_event_shape = self._maybe_validate_shape_override(
            event_shape, distribution.is_scalar_event(), validate_args,
            "event_shape")
        event_override_ndims = _ndims_from_shape(self._override_event_shape)
        event_is_unset = _logical_equal(event_override_ndims, self._zero)
        self._is_event_override = _logical_not(event_is_unset)
        static_event_override = tf.get_static_value(
            self._override_event_shape)
        self._is_maybe_event_override = bool(
            static_event_override is None or static_event_override.size != 0)

        # Turning a scalar distribution into a multivariate one draws the new
        # dims from the (otherwise iid) sample dims. That is straightforward
        # unless the base distribution has batch dims and the event shape is
        # overridden: the event dims would then end up to the left of the
        # batch dims, so the new dims are cyclically permuted left.
        batch_not_overridden = _logical_not(self._is_batch_override)
        base_has_batch_dims = _logical_not(distribution.is_scalar_batch())
        self._needs_rotation = _logical_and(
            self._is_event_override, batch_not_overridden,
            base_has_batch_dims)
        override_event_ndims = _ndims_from_shape(self._override_event_shape)
        self._rotate_ndims = _pick_scalar_condition(
            self._needs_rotation, override_event_ndims, 0)
        # Head dims to reduce over (if any); this is [] when no reduction is
        # needed.
        self._reduce_event_indices = tf.range(
            self._rotate_ndims - override_event_ndims, self._rotate_ndims)

    self._distribution = distribution
    self._bijector = bijector

    # TransformedDistribution reads `_graph_parents` directly because this
    # class is more like a baseclass than a derived one.
    combined_graph_parents = (
        distribution._graph_parents +  # pylint: disable=protected-access
        bijector.graph_parents)
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        graph_parents=combined_graph_parents,
        name=name)
def _event_shape(self):
    """Return the static event shape: sample shape followed by base event shape.

    An unknown `TensorShape` is returned when the rank of either the sample
    shape or the wrapped distribution's event shape is not statically known.
    """
    static_sample_shape = tf.get_static_value(self.sample_shape)
    sample_shape = tf.TensorShape(static_sample_shape)
    if sample_shape.ndims is not None:
        base_event_shape = self.distribution.event_shape
        if base_event_shape.ndims is not None:
            return sample_shape.concatenate(base_event_shape)
    return tf.TensorShape(None)
def _static_value(x):
    """Return the statically known value of `x`, or `None` if unavailable.

    `x` is converted with `tf.convert_to_tensor` before the lookup.
    """
    as_tensor = tf.convert_to_tensor(value=x)
    return tf.get_static_value(as_tensor)
def tensor_and_const_value(v):
    """Convert `v` to a tensor and pair it with its static value.

    Returns:
      A tuple `(tensor, const_value)`; `const_value` is whatever
      `tf.get_static_value` reports for the converted tensor.
    """
    converted = tf.convert_to_tensor(v)
    return (converted, tf.get_static_value(converted))
def __init__(
    self,
    filenames,
    features,
    reader_schema,
    batch_size,
    drop_remainder,
    num_parallel_calls,
    input_stream_buffer_size,
    avro_data_buffer_size,
):
    """Build the Avro dataset and its element spec.

    Args:
      filenames: Converted to a `tf.string` tensor of input file names.
      features: Feature configuration; preprocessed with
        `_build_keys_for_sparse_features`.
      reader_schema: Forwarded unchanged to the dataset op.
      batch_size: Converted to a `tf.int64` tensor.
      drop_remainder: Converted to a `tf.bool` tensor. When it is statically
        true, the static batch size is used as the batch dimension of the
        element spec; otherwise the batch dimension is unknown.
      num_parallel_calls: Stored on the instance.
      input_stream_buffer_size: Forwarded unchanged to the dataset op.
      avro_data_buffer_size: Forwarded unchanged to the dataset op.
    """
    self._filenames = tf.ops.convert_to_tensor(
        filenames, tf.string, name="filenames")
    self._features = _AvroDataset._build_keys_for_sparse_features(features)
    self._reader_schema = reader_schema
    self._batch_size = tf.ops.convert_to_tensor(
        batch_size, dtype=tf.int64, name="batch_size")
    self._drop_remainder = tf.ops.convert_to_tensor(
        drop_remainder, dtype=tf.bool, name="drop_remainder")
    self._num_parallel_calls = num_parallel_calls
    self._input_stream_buffer_size = input_stream_buffer_size
    self._avro_data_buffer_size = avro_data_buffer_size

    # Copied from _ParseExampleDataset from data/experimental/ops/parsing_ops.py
    supported_feature_types = [
        tf.io.VarLenFeature,
        tf.io.SparseFeature,
        tf.io.FixedLenFeature,
    ]
    raw_params = _AvroDataset._features_to_raw_params(
        self._features, supported_feature_types)
    (
        sparse_keys,
        sparse_types,
        sparse_dense_shapes,
        dense_keys,
        dense_types,
        dense_defaults,
        dense_shapes,
    ) = raw_params

    processed_params = _AvroDataset._process_raw_parameters(
        None,
        dense_defaults,
        sparse_keys,
        sparse_types,
        dense_keys,
        dense_types,
        dense_shapes,
    )
    (
        _,
        dense_defaults_vec,
        sparse_keys,
        sparse_types,
        dense_keys,
        dense_shapes,
        _,
    ) = processed_params

    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_types = dense_types

    # Dense entries come first, then sparse ones, in every per-key mapping.
    all_keys = self._dense_keys + self._sparse_keys
    output_shapes = dict(zip(all_keys, dense_shapes + sparse_dense_shapes))
    output_types = dict(
        zip(all_keys, self._dense_types + self._sparse_types))
    output_classes = dict(
        zip(
            all_keys,
            [tf.ops.Tensor] * len(self._dense_defaults)
            + [tf.sparse.SparseTensor] * len(self._sparse_keys),
        ))

    self._element_spec = _AvroDataset._convert_legacy_structure(
        output_types, output_shapes, output_classes)

    constant_drop_remainder = tf.get_static_value(self._drop_remainder)
    # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
    # or `False` (explicitly retaining the remainder); in both cases the batch
    # dimension is left unknown.
    if constant_drop_remainder:
        static_batch_dim = tf.get_static_value(self._batch_size)
    else:
        static_batch_dim = None
    # pylint: disable=protected-access
    self._element_spec = tf.nest.map_structure(
        lambda component_spec: component_spec._batch(static_batch_dim),
        self._element_spec,
    )

    # With batch dimension
    self._dense_shapes = [
        spec.shape
        for spec in tf.nest.flatten(self._element_spec)
        if isinstance(spec, tf.TensorSpec)
    ]

    variant_tensor = core_ops.io_avro_dataset(
        filenames=self._filenames,  # pylint: disable=protected-access
        batch_size=self._batch_size,
        drop_remainder=self._drop_remainder,
        dense_defaults=self._dense_defaults,
        input_stream_buffer_size=self._input_stream_buffer_size,
        avro_data_buffer_size=self._avro_data_buffer_size,
        reader_schema=self._reader_schema,
        sparse_keys=self._sparse_keys,
        dense_keys=self._dense_keys,
        sparse_types=self._sparse_types,
        dense_shapes=self._dense_shapes,
        **self._flat_structure)
    super().__init__(variant_tensor)