def floating_tensor_to_f64(x):
  """Casts `x` to `tf.float64` when it is a floating-point Tensor.

  Intended for `tf.nest.map_structure` and similar call sites where a leaf may
  or may not be a Tensor, and may or may not have a floating dtype.

  Args:
    x: Any object.

  Returns:
    `tf.cast(x, tf.float64)` if `x` is a Tensor whose dtype
    `dtype_util.is_floating` accepts; otherwise `x` unchanged.
  """
  is_float_tensor = tf.is_tensor(x) and dtype_util.is_floating(x.dtype)
  return tf.cast(x, dtype=tf.float64) if is_float_tensor else x
def _mark_as_return(tensor):
  """Marks `tensor` as the return value for automatic control deps.

  Non-Tensor inputs are returned untouched. For a Tensor, its `_keras_mask`
  (if set) is marked as a return value too, and any `_tfp_distribution`
  metadata is copied onto the marked Tensor.
  """
  if not tf.is_tensor(tensor):
    return tensor

  # pylint: disable=protected-access
  marked = acd.mark_as_return(tensor)

  keras_mask = getattr(tensor, '_keras_mask', None)
  marked._keras_mask = (
      None if keras_mask is None else acd.mark_as_return(keras_mask))

  # Handle TensorFlow Probability attached metadata.
  # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
  tfp_distribution = getattr(tensor, '_tfp_distribution', None)
  if tfp_distribution is not None:
    marked._tfp_distribution = tfp_distribution
  # pylint: enable=protected-access
  return marked
def _all_shapes(thing):
  """Collects every shape reachable from `thing`.

  For a Distribution or Bijector, the shapes of all of its `parameters` are
  gathered recursively. For a `TransformedDistribution` those parameter shapes
  are prefixed with its `batch_shape`. Every Distribution also contributes
  `batch_shape + event_shape`. A `MixtureSameFamily` additionally contributes
  `batch_shape + [num_components] + event_shape`. A Tensor contributes its
  `shape`, and any other value contributes nothing.

  Args:
    thing: A Distribution, Bijector, Tensor, or arbitrary Python value.

  Returns:
    A list of shapes.
  """
  if not isinstance(thing, (tfd.Distribution, tfb.Bijector)):
    # Non-Tensor leaves (strings, bools, ...) have no shape to contribute.
    return [thing.shape] if tf.is_tensor(thing) else []

  shapes = []
  for param in thing.parameters.values():
    shapes.extend(_all_shapes(param))

  if isinstance(thing, tfd.TransformedDistribution):
    shapes = [thing.batch_shape + s for s in shapes]
  if isinstance(thing, tfd.Distribution):
    shapes.append(thing.batch_shape + thing.event_shape)
  if isinstance(thing, tfd.MixtureSameFamily):
    num_components = thing.mixture_distribution.logits_parameter().shape[-1]
    shapes.append(thing.batch_shape + [num_components] + thing.event_shape)
  return shapes
def stringify_slices(slices):
  """Returns a list of strings describing the items in `slices`.

  Args:
    slices: A single slice item or a tuple of them. Supported items are
      `Ellipsis`, `slice` objects, integers (Python or NumPy), Tensors, and
      `tf.newaxis`.

  Returns:
    A list with one string per item, e.g. `'...'`, `'1:5:'`, `'3'` or
    `'tf.newaxis'`.

  Raises:
    ValueError: If an item has an unsupported type.
  """
  import numbers  # pylint: disable=g-import-not-at-top
  pretty_slices = []
  slices = slices if isinstance(slices, tuple) else (slices,)
  for slc in slices:
    # Identity check: `slc == Ellipsis` would call a Tensor's/array's `__eq__`,
    # which may attempt an element-wise comparison (or raise) instead of
    # producing a plain bool.
    if slc is Ellipsis:
      pretty_slices.append('...')
    elif isinstance(slc, slice):
      pretty_slices.append('{}:{}:{}'.format(*[
          '' if s is None else s for s in (slc.start, slc.stop, slc.step)
      ]))
    elif isinstance(slc, numbers.Integral) or tf.is_tensor(slc):
      # `numbers.Integral` covers Python ints as well as NumPy integer scalars.
      pretty_slices.append(str(slc))
    elif slc is tf.newaxis:
      pretty_slices.append('tf.newaxis')
    else:
      raise ValueError('Unexpected slice type: {}'.format(type(slc)))
  return pretty_slices
def _get_static_value(pred):
  """Helper function for getting static values from maybe-tensor objects.

  In JAX mode this returns `np.asarray(pred)`, or `None` if that raises.
  Otherwise, for a Tensor it returns `tf.get_static_value(pred)`. When that is
  `None` and `pred` is a graph Tensor, it falls back to TF's constant
  evaluator, whose result may also be `None`. Non-Tensor inputs are returned
  as is.
  """
  if JAX_MODE:
    try:
      return np.asarray(pred)
    except:  # JAX sometimes raises raw Exception in __array__. # pylint: disable=bare-except
      return None
  if not tf.is_tensor(pred):
    return pred
  pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
  # Only graph Tensors expose `.graph` / `._as_tf_output()`. Symbolic
  # KerasTensors do not, and touching `.graph` on them raises AttributeError.
  if (pred_value is None and hasattr(pred, 'graph') and
      hasattr(pred, '_as_tf_output')):
    # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
    # pylint: disable=protected-access
    pred_value = c_api.TF_TryEvaluateConstant_wrapper(
        pred.graph._c_graph, pred._as_tf_output())
    # pylint: enable=protected-access
  return pred_value
def convert_to_ndarray(test_obj, a):
  """Converts `a` into a NumPy ndarray.

  Args:
    test_obj: Object with an `evaluate` method, used to evaluate `a` when it is
      a Tensor.
    a: Object to convert.

  Returns:
    `a` (or its evaluated value) if it already is an `np.ndarray`; otherwise
    `np.array` of it.
  """
  # TODO(b/177990397): This function should be independent of the test framework
  value = test_obj.evaluate(a) if tf.is_tensor(a) else a
  return value if isinstance(value, np.ndarray) else np.array(value)
def call(self, inputs, training=None, mask=None):  # pylint: disable=redefined-outer-name
  """Runs `inputs` through this Sequential model's layers.

  Args:
    inputs: Normally a single input Tensor. A non-Tensor structure is tolerated
      for backwards compatibility and switches the model to legacy deferred
      behavior.
    training: Forwarded to every layer whose call signature has `training`.
    mask: Forwarded to every layer whose call signature has `mask`. After each
      layer it is replaced by that layer's output `_keras_mask` (or `None`).

  Returns:
    `super().call(...)` when a graph network has been initialized. Otherwise
    the output of the last layer, or `inputs` itself if there are no layers.

  Raises:
    ValueError: If any layer returns more than one (flattened) output.
  """
  # If applicable, update the static input shape of the model.
  if not self._has_explicit_input_shape:
    # Note: `tf.is_tensor` already returns True for `tf.Tensor` instances, so
    # the `isinstance` clause is redundant but harmless.
    if not tf.is_tensor(inputs) and not isinstance(
        inputs, tf.Tensor):
      # This is a Sequential with multiple inputs. This is technically an
      # invalid use case of Sequential, but we tolerate it for backwards
      # compatibility.
      self._use_legacy_deferred_behavior = True
      self._build_input_shape = tf.nest.map_structure(
          _get_shape_tuple, inputs)
      if tf.__internal__.tf2.enabled():
        logging.warning('Layers in a Sequential model should only have a '
                        f'single input tensor. Received: inputs={inputs}. '
                        'Consider rewriting this model with the Functional '
                        'API.')
    else:
      self._build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype)

  # Once a graph network exists, delegate to the Functional implementation.
  if self._graph_initialized:
    if not self.built:
      self._init_graph_network(self.inputs, self.outputs)
    return super().call(inputs, training=training, mask=mask)

  outputs = inputs  # handle the corner case where self.layers is empty
  for layer in self.layers:
    # During each iteration, `inputs` are the inputs to `layer`, and `outputs`
    # are the outputs of `layer` applied to `inputs`. At the end of each
    # iteration `inputs` is set to `outputs` to prepare for the next layer.
    kwargs = {}
    # Only pass `mask` / `training` to layers whose call signature has them.
    argspec = self._layer_call_argspecs[layer].args
    if 'mask' in argspec:
      kwargs['mask'] = mask
    if 'training' in argspec:
      kwargs['training'] = training

    outputs = layer(inputs, **kwargs)

    if len(tf.nest.flatten(outputs)) != 1:
      raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
    # `outputs` will be the inputs to the next layer.
    inputs = outputs
    # Thread the mask produced by this layer (if any) into the next layer.
    mask = getattr(outputs, '_keras_mask', None)
  return outputs
def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None,
                         prefer_static_value=()):
  """Extract constructor kwargs to reconstruct `obj`.

  Each `__init__` parameter of `obj` is extracted, except `self`, names in
  `omit_kwargs`, and (when `limit_to` is given) names not in `limit_to`. Its
  value is taken from the first source that exists, in this order:
    1. attribute `obj.<name>`,
    2. attribute `obj._<name>`,
    3. entry `obj.parameters[<name>]`.

  Args:
    obj: The object whose constructor arguments to extract.
    omit_kwargs: Parameter names to skip.
    limit_to: Optional collection of names; if given, only these are extracted.
    prefer_static_value: Names whose Tensor values are replaced by their static
      value when one is known. NumPy arrays/scalars among these are converted
      to Python values with `.tolist()`.

  Returns:
    A dict mapping parameter names to values.

  Raises:
    ValueError: If `__init__` takes `*args` / `**kwargs`, or if no value can be
      found for some parameter.
  """
  sig = tf_inspect.signature(obj.__init__)
  variadic_kinds = (tf_inspect.Parameter.VAR_KEYWORD,
                    tf_inspect.Parameter.VAR_POSITIONAL)
  if any(v.kind in variadic_kinds for v in sig.parameters.values()):
    raise ValueError(
        '*args and **kwargs are not supported. Found `{}`'.format(sig))

  keys = [p for p in sig.parameters if p != 'self' and p not in omit_kwargs]
  if limit_to is not None:
    keys = [k for k in keys if k in limit_to]

  kwargs = {}
  not_found = object()
  for k in keys:
    srcs = [
        getattr(obj, k, not_found),
        getattr(obj, '_' + k, not_found),
        getattr(obj, 'parameters', {}).get(k, not_found),
    ]
    value = next((v for v in srcs if v is not not_found), not_found)
    if value is not_found:
      # Every piece must be an f-string for `{k}` / `{obj}` to be interpolated.
      raise ValueError(
          f'Could not determine an appropriate value for field `{k}` in object '
          f'`{obj}`. Looked for \n'
          f' 1. an attr called `{k}`,\n'
          f' 2. an attr called `_{k}`,\n'
          f' 3. an entry in `obj.parameters` with key "{k}".')
    kwargs[k] = value

    if k in prefer_static_value and kwargs[k] is not None:
      if tf.is_tensor(kwargs[k]):
        static_val = tf.get_static_value(kwargs[k])
        if static_val is not None:
          kwargs[k] = static_val
    if isinstance(kwargs[k], (np.ndarray, np.generic)):
      # Generally, these are shapes or int, but may be other parameters such as
      # `power` for `tfb.PowerTransform`.
      kwargs[k] = kwargs[k].tolist()
  return kwargs
def reduce(v):
  """Replaces Tensors / CompositeTensors in `v` with their TypeSpecs.

  Lists and tuples are processed element-wise and rebuilt with `type(v)`.

  Args:
    v: A Tensor, CompositeTensor, list/tuple of such, or any other value.

  Returns:
    A pair `(reduced, has_tensors)`. `reduced` is `v` with Tensor parts
    replaced by specs. `has_tensors` says whether any Tensor part was found.

  Raises:
    NotImplementedError: If a list/tuple mixes Tensor and non-Tensor parts.
  """
  has_tensors = False
  if tf.is_tensor(v):
    v = tf.TensorSpec.from_tensor(v)
    has_tensors = True
  if isinstance(v, CompositeTensor):
    v = v._type_spec  # pylint: disable=protected-access
    has_tensors = True
  if isinstance(v, (list, tuple)):
    reduced = [reduce(v_) for v_ in v]
    part_has_tensors = [ht for (_, ht) in reduced]
    has_tensors = any(part_has_tensors)
    # Only a sequence with at least one Tensor part can be "mixed". An empty
    # sequence has `any(...) == False` but `all(...) == True`, and must not be
    # reported as mixed.
    if has_tensors and not all(part_has_tensors):
      raise NotImplementedError(
          _mk_err_msg(
              clsid, self,
              'Found `{}` with both Tensor and non-Tensor parts: {}'
              .format(type(v), v)))
    v = type(v)([spec for (spec, _) in reduced])
  return v, has_tensors
def check_consistency(tf_fn, np_fn, args):
  """Asserts that `tf_fn` and `np_fn` agree when applied to `args`.

  When the enclosing `assert_shape_only` is set, only output shapes are
  compared. Otherwise values are compared with
  `assertAllCloseAccordingToType` using the enclosing `atol` / `rtol`.
  """
  # A lone array/Tensor argument is wrapped so it can be splatted below.
  if tf.is_tensor(args) or isinstance(args, np.ndarray):
    args = (args,)

  tf_result = self.evaluate(tf_fn(*_maybe_convert_to_tensors(args)))
  np_kwargs = jax_kwargs() if MODE_JAX else {}
  np_result = np_fn(*args, **np_kwargs)

  if not assert_shape_only:
    self.assertAllCloseAccordingToType(tf_result, np_result,
                                       atol=atol, rtol=rtol)
    return

  def assert_same_shape(x, y):
    self.assertAllEqual(x.shape, y.shape)

  tf.nest.map_structure(assert_same_shape, tf_result, np_result)
def __init__(self, value):
  """Creates a new `_DeferredTensorInput`.

  Args:
    value: either a `Tensor`-like object, or a nullary function returning such
      an object.

  Raises:
    TypeError: If `value` is a `helpers.RateObject`.
  """
  if isinstance(value, helpers.RateObject):
    raise TypeError(
        "a DeferredTensor may only be created from a Tensor-like object, "
        "or a nullary function returning such")
  self._value = value
  # Values that are neither callables nor Tensors are deep-copied, to be
  # extra-certain they are immutable from our side (they will be hashed).
  is_callable_or_tensor = callable(value) or tf.is_tensor(value)
  if not is_callable_or_tensor:
    self._value = copy.deepcopy(value)
  # The hash is computed lazily and memoized, since it can be expensive.
  self._hash = None
def split_seed(seed, n=2, salt=None, name=None):
  """Deterministically derives `n` seeds from `seed`.

  Args:
    seed: Seed to split; sanitized with `sanitize_seed(seed, salt=salt)`.
    n: Python `int` or int Tensor number of seeds to derive.
    salt: Optional salt mixed into the seed.
    name: Optional name scope (defaults to `'split'`).

  Returns:
    In JAX mode, `jax.random.split(seed, n)`. Otherwise a `[n, 2]` Tensor of
    seeds, unstacked into a list when `n` is a Python integer.

  Raises:
    TypeError: If `n` is neither a Python `int` nor a Tensor.
  """
  # Validate `n` up front to avoid confusing it with `salt`.
  if not isinstance(n, int) and not tf.is_tensor(n):
    raise TypeError(
        '`n` must be a python `int` or an int Tensor, got {}'.format(
            repr(n)))
  with tf.name_scope(name or 'split'):
    seed = sanitize_seed(seed, salt=salt)
    if JAX_MODE:
      from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
      return jaxrand.split(seed, n)
    seeds = tf.random.stateless_uniform(
        [n, 2], seed=seed, minval=None, maxval=None, dtype=SEED_DTYPE)
    return tf.unstack(seeds) if isinstance(n, six.integer_types) else seeds
def _assert_batch_shape_matches_weights(distribution, weights_shape, diststr):
  """Checks that all parts of a distribution have the expected batch shape.

  Args:
    distribution: Distribution whose (possibly structured)
      `batch_shape_tensor()` parts must all equal `weights_shape`.
    weights_shape: Shape of the particle weights.
    diststr: Name of the distribution, used in error messages.

  Returns:
    A list of runtime assertion ops. It is non-empty only when
    `distribution.validate_args` is set and some shape is not statically known.

  Raises:
    ValueError: If statically known shapes or ranks disagree.
  """
  shapes = [weights_shape] + tf.nest.flatten(
      distribution.batch_shape_tensor())

  static_shapes = [
      tf.get_static_value(ps.convert_to_shape_tensor(s)) for s in shapes
  ]
  static_shapes_not_none = [s for s in static_shapes if s is not None]
  # `np.array_equal` returns False when the shape vectors differ in length.
  # `np.all(a == b)` would instead broadcast (possibly to a spurious True) or,
  # in newer NumPy, raise before our descriptive error can be reported.
  static_shapes_match = all(
      np.array_equal(a, b)
      for (a, b) in zip(static_shapes_not_none[1:],
                        static_shapes_not_none[:-1]))

  # Build a separate list of static ranks, since rank is often static even when
  # shape is not.
  ranks = [ps.rank_from_shape(s) for s in shapes]
  static_ranks = [int(r) for r in ranks if not tf.is_tensor(r)]
  static_ranks_match = all(
      a == b for (a, b) in zip(static_ranks[1:], static_ranks[:-1]))

  msg = (
      "The {diststr} distribution's batch shape does not match the particle "
      "weights; a correct {diststr} distribution must return an independent "
      "log-density for each particle. You may be "
      "creating a joint distribution in which some parts do not depend on the "
      "previous particles, and/or you are creating an autobatched joint "
      "distribution without setting `batch_ndims`.".format(diststr=diststr))
  if not (static_ranks_match and static_shapes_match):
    raise ValueError(
        msg + ' ' +
        'Weights have shape {}, but the distribution has batch '
        'shape {}.'.format(weights_shape, distribution.batch_shape))

  assertions = []
  if distribution.validate_args and any(s is None for s in static_shapes):
    assertions = [
        assert_util.assert_equal(a, b, message=msg)
        for a, b in zip(shapes[1:], shapes[:-1])
    ]
  return assertions
def _type_spec(self):
  """Returns a `_TFPTypeSpec` describing this object.

  On Python 3, only constructor kwargs whose values are not (by identity) the
  constructor's defaults are recorded. Tensor / CompositeTensor-valued kwargs
  become entries of `param_specs`; the remaining kwargs are kept as is.

  Raises:
    NotImplementedError: If a remaining kwarg is callable.
  """

  def get_default_args(obj):
    # Defaults are read from the signature of the object's class. (The
    # previous `isinstance(obj, object)` test was always True, so the class
    # was always used.)
    return {
        k: v.default
        for k, v in inspect.signature(type(obj)).parameters.items()
        if v.default is not inspect.Parameter.empty
    }

  if six.PY3:
    default_kwargs = get_default_args(self)
    missing = object()
    kwargs = {
        k: v for k, v in self.parameters.items()
        if default_kwargs.get(k, missing) is not v
    }  # non-default kwargs only
  else:
    kwargs = dict(self.parameters)

  param_specs = {}
  try:
    composite_tensor_params = self._composite_tensor_params  # pylint: disable=protected-access
  except NotImplementedError:
    composite_tensor_params = ()

  for k in composite_tensor_params:
    if k in kwargs and kwargs[k] is not None:
      v = kwargs.pop(k)
      # NOTE(review): a value here that is neither a CompositeTensor nor a
      # Tensor is popped but recorded nowhere, i.e. dropped from the spec.
      # Confirm such values cannot occur.
      if isinstance(v, CompositeTensor):
        param_specs[k] = v._type_spec  # pylint: disable=protected-access
      elif tf.is_tensor(v):
        param_specs[k] = tf.TensorSpec.from_tensor(v)

  for k, v in list(kwargs.items()):
    if isinstance(v, CompositeTensor):
      param_specs[k] = v._type_spec  # pylint: disable=protected-access
      kwargs.pop(k)
    elif callable(v):
      raise NotImplementedError(
          'Unable to make CompositeTensor including callable argument `{}`.'
          .format(k))
  return _TFPTypeSpec(clsid, param_specs=param_specs, kwargs=kwargs)
def convert_fn(path, value, dtype, dtype_hint, name=None):
  """Converts the value found at `path` to a (shape) Tensor.

  Raises:
    NotImplementedError: If packing is disallowed and `value` is a structure
      containing Tensors or NumPy arrays.
  """

  def _is_array_like(x):
    # Treat arrays like Tensors for full parity in JAX backend.
    return tf.is_tensor(x) or isinstance(x, np.ndarray)

  packing_needed = nest.is_nested(value) and any(
      _is_array_like(x) for x in nest.flatten(value))
  if not allow_packing and packing_needed:
    raise NotImplementedError(
        'Cannot convert a structure of tensors to a single tensor. '
        'Saw {} at path {}.'.format(value, path))

  if as_shape_tensor:
    return ps.convert_to_shape_tensor(value, dtype, dtype_hint, name=name)
  if 'KerasTensor' in str(type(value)):
    # This is a hack to detect symbolic Keras tensors to work around
    # b/206660667. The issue was that symbolic Keras tensors would
    # break the Bijector cache on forward/inverse log det jacobian,
    # because tf.convert_to_tensor is not a no-op thereon.
    return value
  return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
def check_consistency(tf_fn, np_fn, args):
  """Asserts that `tf_fn` and `np_fn` agree when applied to `args`.

  With the enclosing `assert_shape_only` set, only shapes are compared.
  Otherwise the flattened outputs are compared leaf by leaf. Each comparison
  is labelled `'output <i>'`, and a missing leaf on either side is padded with
  `None`.
  """
  # A lone array/Tensor argument is wrapped so it can be splatted below.
  if tf.is_tensor(args) or isinstance(args, np.ndarray):
    args = (args,)

  tf_result = self.evaluate(tf_fn(*_maybe_convert_to_tensors(args)))
  np_result = np_fn(*args, **(jax_kwargs() if JAX_MODE else {}))

  if assert_shape_only:

    def assert_same_shape(x, y):
      self.assertAllEqual(x.shape, y.shape)

    tf.nest.map_structure(assert_same_shape, tf_result, np_result)
    return

  leaf_pairs = six.moves.zip_longest(
      tf.nest.flatten(tf_result), tf.nest.flatten(np_result))
  for idx, (tf_leaf, np_leaf) in enumerate(leaf_pairs):
    self.assertAllCloseAccordingToType(
        tf_leaf, np_leaf, atol=atol, rtol=rtol, msg='output {}'.format(idx))
def _get_static_predicate(pred):
  """Helper function for statically evaluating predicates in `cond`.

  Args:
    pred: A Tensor, a Python bool, or the integers 0 / 1.

  Returns:
    A Python bool when the value of `pred` is statically known to be one of
    0 / 1 / True / False. For a Tensor whose value cannot be determined
    statically, the (non-bool) result of the static evaluation, typically
    `None`.

  Raises:
    TypeError: If `pred` is not a Tensor and not one of 0 / 1 / True / False.
  """
  if tf.is_tensor(pred):
    pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
    # Only graph Tensors expose `.graph` / `._as_tf_output()`. Symbolic
    # KerasTensors do not, and touching `.graph` on them raises AttributeError.
    if (pred_value is None and hasattr(pred, 'graph') and
        hasattr(pred, '_as_tf_output')):
      # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
      # pylint: disable=protected-access
      pred_value = c_api.TF_TryEvaluateConstant_wrapper(
          pred.graph._c_graph, pred._as_tf_output())
      # pylint: enable=protected-access
    if pred_value in (0, 1, True, False):
      pred_value = bool(pred_value)
  elif pred in (0, 1, True, False):
    # Accept 1/0 as valid boolean values
    # This branch also casts np.array(False), tf.EagerTensor(True), etc.
    pred_value = bool(pred)
  else:
    raise TypeError('`pred` must be a Tensor, or a Python bool, or 1 or 0. '
                    'Found instead: {}'.format(pred))
  return pred_value
def get_discriminator_focal_loss(learner_agent_output, env_output,
                                 actor_agent_output, actor_action,
                                 reward_clipping, discounting, baseline_cost,
                                 entropy_cost, num_steps):
  """Discriminator focal loss.

  If `learner_agent_output.policy_logits` is a dict with a `'labels'` entry
  and `learner_agent_output.baseline` is a Tensor, those are used directly as
  labels and logits. Otherwise labels and logits are derived with
  `_get_discriminator_logits`. Loss and label statistics are written as
  scalar summaries at `step=num_steps`.

  Returns:
    The focal loss `fl` returned by `focal_loss_from_logits`.
  """
  # Check if learner_agent_output has logits and labels prepared already,
  # otherwise filter and prepare the logits and labels.
  if (isinstance(learner_agent_output.policy_logits, dict) and
      'labels' in learner_agent_output.policy_logits and
      tf.is_tensor(learner_agent_output.baseline)):
    # Shape: [batch]
    labels = learner_agent_output.policy_logits['labels']
    output_logits = learner_agent_output.baseline
    tf.debugging.assert_equal(tf.shape(labels), tf.shape(output_logits))
    loss_tag = 'loss/focal_loss'
  else:
    # labels and output_logits have shape: [batch].
    labels, output_logits = _get_discriminator_logits(
        learner_agent_output, env_output, actor_agent_output, actor_action,
        reward_clipping, discounting, baseline_cost, entropy_cost, num_steps)
    loss_tag = 'loss/focal_loss (w/ softmin_softmax)'

  # Shape = [batch]
  fl, ce = focal_loss_from_logits(
      output_logits,
      labels,
      alpha=FLAGS.focal_loss_alpha,
      gamma=FLAGS.focal_loss_gamma,
      normalizer=FLAGS.focal_loss_normalizer)
  tf.summary.scalar(loss_tag, tf.reduce_mean(fl), step=num_steps)
  tf.summary.scalar(
      'loss/CE (reference only)', tf.reduce_mean(ce), step=num_steps)
  tf.summary.scalar(
      'labels/num_labels_per_batch', tf.size(labels), step=num_steps)
  # NOTE(review): `labels` may be integer-typed (TODO confirm). Casting keeps
  # the reported mean from being integer-truncated to 0 or 1.
  tf.summary.scalar(
      'labels/mean_labels',
      tf.reduce_mean(tf.cast(labels, tf.float32)),
      step=num_steps)
  return fl
def create_timbre_spectrogram(audio, hparams):
  """Create either a CQT or mel spectrogram.

  Args:
    audio: WAV bytes, an array of samples, or a Tensor holding either. A Tensor
      is converted with `.numpy()`, so eager execution is assumed.
    hparams: Hyperparameters. Reads `sample_rate`, `timbre_spec_type`,
      `timbre_hop_length`, `spec_mel_htk` and `timbre_spec_log_amplitude`.

  Returns:
    The (transposed, frames-first) magnitude spectrogram. When
    `hparams.timbre_spec_log_amplitude` is set, it is instead converted to dB,
    offset so that 1e-9 maps to 0, and divided by its maximum.
  """
  if tf.is_tensor(audio):
    audio = audio.numpy()
  if isinstance(audio, bytes):
    # Get samples from wav data.
    samples = audio_io.wav_data_to_samples(audio, hparams.sample_rate)
  else:
    samples = audio

  # `y` and `sr` are passed by keyword: librosa >= 0.10 makes them
  # keyword-only, and earlier releases accept them by keyword too.
  if hparams.timbre_spec_type == 'mel':
    spec = np.abs(
        librosa.feature.melspectrogram(
            y=samples,
            sr=hparams.sample_rate,
            hop_length=hparams.timbre_hop_length,
            fmin=librosa.midi_to_hz(constants.MIN_TIMBRE_PITCH),
            fmax=librosa.midi_to_hz(constants.MAX_TIMBRE_PITCH),
            n_mels=constants.TIMBRE_SPEC_BANDS,
            pad_mode='symmetric',
            htk=hparams.spec_mel_htk,
            power=2)).T
  else:
    spec = np.abs(
        librosa.cqt(
            y=samples,
            sr=hparams.sample_rate,
            hop_length=hparams.timbre_hop_length,
            fmin=librosa.midi_to_hz(constants.MIN_TIMBRE_PITCH),
            n_bins=constants.TIMBRE_SPEC_BANDS,
            bins_per_octave=constants.BINS_PER_OCTAVE,
            pad_mode='symmetric')).T

  if hparams.timbre_spec_log_amplitude:
    # Convert to decibels, shifted so that a value of 1e-9 maps to 0 dB, then
    # normalize by the maximum.
    spec = librosa.power_to_db(spec) - librosa.power_to_db(np.array([1e-9]))[0]
    spec = spec / np.max(spec)
  return spec
def _get_static_value(pred):
  """Helper function for getting static values from maybe-tensor objects.

  In JAX mode this returns `np.asarray(pred)`, or `None` if that raises.
  Non-Tensors are returned as is. For Tensors, `tf.get_static_value` is tried
  first. If that yields `None` and `pred` is an `ops.Tensor`, TF's constant
  evaluator is used as a fallback.
  """
  if JAX_MODE:
    # JAX sometimes raises raw Exception in __array__.
    try:
      return np.asarray(pred)
    except:  # pylint: disable=bare-except
      return None

  if not tf.is_tensor(pred):
    return pred

  static_value = tf.get_static_value(tf.convert_to_tensor(pred))
  # Explicitly check for ops.Tensor, to avoid an AttributeError
  # when requesting `KerasTensor.graph`.
  if static_value is not None or not isinstance(pred, ops.Tensor):
    return static_value
  if hasattr(tensor_util, 'try_evaluate_constant'):
    return tensor_util.try_evaluate_constant(pred)
  # TODO(feyu): remove this branch after try_evaluate_constant is in
  # tf-nightly.
  return c_api.TF_TryEvaluateConstant_wrapper(
      pred.graph._c_graph, pred._as_tf_output())  # pylint: disable=protected-access
def split_seed(seed, n=2, salt=None, name=None):
  """Splits a seed into `n` derived seeds.

  See https://github.com/tensorflow/probability/blob/main/PRNGS.md
  for details.

  Args:
    seed: The seed to split; may be an `int`, an `(int, int) tuple`, or a
      `Tensor`. `int` seeds are converted to `Tensor` seeds using
      `tf.random.uniform` stateful sampling. Tuples are converted to `Tensor`.
    n: The number of splits to return: a Python `int`, a NumPy array, or an int
      `Tensor`. In TensorFlow, if `n` is a Python integer this function
      returns a list of seeds, and otherwise a `Tensor` of seeds. In JAX,
      `n` is converted with `int(n)` and this function always returns an
      array of seeds.
    salt: Optional `str` salt to mix with the seed.
    name: Optional name to scope related ops.

  Returns:
    seeds: If `n` is a Python `int`, a `list` of `n` seed Tensors (from
      `tf.unstack`). Otherwise a single `Tensor` of shape `[n, 2]`. (JAX: an
      array from `jax.random.split`.) A single such seed is suitable to pass as
      the `seed` argument of the `tf.random.stateless_*` ops.

  Raises:
    TypeError: If `n` is not a Python `int`, NumPy array, or Tensor.
  """
  if not (isinstance(n, int)
          or isinstance(n, np.ndarray)
          or tf.is_tensor(n)):  # avoid confusion with salt.
    raise TypeError(
        '`n` must be a python `int` or an int Tensor, got {}'.format(
            repr(n)))
  with tf.name_scope(name or 'split_seed'):
    seed = sanitize_seed(seed, salt=salt)
    if JAX_MODE:
      from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
      return jaxrand.split(seed, int(n))
    seeds = tf.random.stateless_uniform(
        [n, 2], seed=seed, minval=None, maxval=None, dtype=SEED_DTYPE)
    if isinstance(n, six.integer_types):
      seeds = tf.unstack(seeds)
    return seeds
def validate_per_replica_inputs(distribution_strategy, x):
  """Validates PerReplica dataset input list.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`, `evaluate` and `predict`.
    x: A (possibly nested) structure of PerReplica objects that represent the
      input or target values.

  Returns:
    List containing the first unwrapped value of each PerReplica object, in
    flattened order.

  Raises:
    ValueError: If any unwrapped value is not a tensor.
  """
  first_values = []
  # Convert the inputs and targets into a list of PerReplica objects.
  for per_replica in tf.nest.flatten(x):
    # At this point each entry should contain only tensors.
    replica_values = distribution_strategy.unwrap(per_replica)
    non_tensors = [v for v in replica_values if not tf.is_tensor(v)]
    if non_tensors:
      raise ValueError(
          "Dataset input to the model should be tensors instead "
          "they are of type {}".format(type(non_tensors[0])))

    if not tf.executing_eagerly():
      # Validate that the shape and dtype of all the elements are the same.
      validate_all_tensor_shapes(per_replica, replica_values)
      validate_all_tensor_types(per_replica, replica_values)

    first_values.append(replica_values[0])
  return first_values
def _extract_type_spec_recursively(value):
  """Return (collection of) TypeSpec(s) for `value` if it includes `Tensor`s.

  If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If it
  is a `tf.Variable`, return a matching `VariableSpec`. If `value` is a
  collection containing `Tensor` values, recursively supplant them with their
  respective `TypeSpec`s in a collection of parallel structure. If `value` is
  none of the above, return it unchanged.

  Args:
    value: a Python `object` to (possibly) turn into a (collection of)
      `tf.TypeSpec`(s).

  Returns:
    spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
    or `value`, if no `Tensor`s are found.

  Raises:
    NotImplementedError: If a collection mixes Tensor and non-Tensor parts.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  if isinstance(value, tf.Variable):
    return resource_variable_ops.VariableSpec(
        value.shape, dtype=value.dtype, trainable=value.trainable)
  if tf.is_tensor(value):
    return tf.TensorSpec(value.shape, value.dtype)
  if tf.nest.is_nested(value):
    specs = tf.nest.map_structure(_extract_type_spec_recursively, value)
    # A leaf "was a tensor" iff recursion replaced it with a different object.
    was_tensor = tf.nest.flatten(
        tf.nest.map_structure(lambda a, b: a is not b, value, specs))
    if any(was_tensor):
      if not all(was_tensor):
        raise NotImplementedError(
            'Found `{}` with both Tensor and non-Tensor parts: {}'.format(
                type(value), value))
      return specs
  return value
def _update_loop_variables(step,
                           current_step_results,
                           accumulated_step_results,
                           state_history):
  """Update the loop state to reflect a step of filtering.

  Writes `current_step_results` at index `step` of the accumulated arrays. If
  `state_history` is not an empty Tensor, re-gathers it by the current
  `parent_indices` and appends the current particles. Returns the next
  `ParticleFilterLoopVariables` with `step + 1`.
  """
  # Write particles, indices, and likelihoods to their respective arrays.
  accumulated_step_results = tf.nest.map_structure(
      lambda array, value: array.write(step, value),
      accumulated_step_results, current_step_results)

  history_is_empty = (tf.is_tensor(state_history) and
                      state_history.shape[0] == 0)
  if not history_is_empty:
    parent_indices = current_step_results.parent_indices
    batch_shape = prefer_static.shape(parent_indices)[1:-1]
    batch_rank = prefer_static.rank_from_shape(batch_shape)

    # Permute the particles from previous steps to match the current resampled
    # indices, so that the state history reflects coherent trajectories.
    # NOTE(review): only the newest history entry (`x[-1:]`) is re-gathered
    # and carried forward. Confirm this bounded-history behavior is intended.
    def regather(x):
      return tf.gather(x[-1:], parent_indices,
                       axis=batch_rank + 1, batch_dims=batch_rank)

    resampled_state_history = tf.nest.map_structure(regather, state_history)

    # Update the history by concat'ing the carried-forward elements with the
    # most recent state.
    state_history = tf.nest.map_structure(
        lambda past, now: tf.concat([past, now[tf.newaxis, ...]], axis=0),
        resampled_state_history, current_step_results.particles)

  return ParticleFilterLoopVariables(
      step=step + 1,
      previous_step_results=current_step_results,
      accumulated_step_results=accumulated_step_results,
      state_history=state_history)
def _to_components(self, obj):
  """Decomposes `obj` into its components according to `self._param_specs`.

  Tensor-valued params are kept as is. Other params are converted through
  their spec's `_to_components` (recursing through list/tuple specs).

  Raises:
    NotImplementedError: If a dependent entry cannot be converted.
  """

  def to_components(spec, value):
    # Recurse through list/tuple specs; convert non-Tensor leaves via the spec.
    if isinstance(spec, (list, tuple)):
      return type(spec)(
          [to_components(sp, v_) for sp, v_ in zip(spec, value)])
    if tf.is_tensor(value):
      return value
    return spec._to_components(value)  # pylint: disable=protected-access

  params = _kwargs_from(self._clsid, obj, limit_to=list(self._param_specs))
  for k, v in params.items():
    if tf.is_tensor(v):
      continue
    try:
      params[k] = to_components(self._param_specs[k], v)
    except TypeError as e:
      raise NotImplementedError(
          _mk_err_msg(
              self._clsid, obj,
              '(Unable to convert dependent entry \'{}\': {})'.format(
                  k, str(e))))
  return params
def __init__(self, seed, salt):
  """Initializes a `SeedStream`.

  Args:
    seed: Any Python object convertible to string, with the exception of
      Tensor objects, which are reserved for stateless sampling semantics.
      The seed supplies the initial entropy. If another `SeedStream` is given,
      its `original_seed` is used. If `None`, operations seeded with seeds
      drawn from this `SeedStream` will follow TensorFlow semantics for not
      being seeded.
    salt: Any Python object convertible to string, supplying
      auxiliary entropy. Must be unique across the Distributions
      and TensorFlow Probability code base. See class docstring for
      rationale.

  Raises:
    TypeError: Outside JAX mode, if `seed` is a Tensor.
  """
  base_seed = seed.original_seed if isinstance(seed, SeedStream) else seed
  if JAX_MODE and isinstance(base_seed, int):
    # In JAX, integer seeds are turned into PRNG keys right away.
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    base_seed = jaxrand.PRNGKey(base_seed)
  self._seed = base_seed
  if not JAX_MODE and tf.is_tensor(self._seed):
    raise TypeError('{}: {}'.format(TENSOR_SEED_MSG_PREFIX, self._seed))
  self._salt = salt
  self._counter = 0
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
  """Emulates tf.convert_to_tensor.

  Conversion rules, in order:
    * Values accepted by this module's `is_tensor` are returned as is, or
      cast with `.astype(dtype)` when `dtype` is given.
    * `Dimension` / `TensorShape` values are unwrapped to Python values.
    * Nested structures are converted leaf by leaf (recursively).
    * With only `dtype_hint` given, the value becomes `np.array(value)`. A
      non-empty complex / floating / integer array keeps NumPy's inferred
      dtype unless the hint is of the same kind or wider
      (integer -> floating -> complex). Otherwise, including empty and bool
      arrays, it is cast to the hint.
    * Otherwise the result is `np.array(value, dtype=dtype or dtype_hint)`.

  Args:
    value: Object to convert. Must not be a real `tf.Tensor` (asserted).
    dtype: Optional dtype to cast to.
    dtype_hint: Optional preferred dtype, consulted only when `dtype` is None.
    name: Ignored; accepted for signature compatibility.

  Returns:
    A NumPy array, or the input itself if `is_tensor(value)` is true and no
    `dtype` is given.
  """
  # NOTE(review): `tf` here appears to be the real TensorFlow module (distinct
  # from this module's `is_tensor`); a genuine TF Tensor is a programming
  # error. Confirm against the imports.
  assert not tf.is_tensor(value), value
  if is_tensor(value):
    if dtype is not None:
      dtype = utils.numpy_dtype(dtype)
      # if np.result_type(value, dtype) != dtype:
      #   raise ValueError('Expected dtype {} but got {} with dtype {}.'.format(
      #       dtype, value, value.dtype))
      return value.astype(dtype)
    return value
  if isinstance(value, Dimension):
    value = _dimension_value(value)
  if isinstance(value, TensorShape):
    value = [_dimension_value(d) for d in value.as_list()]
  if tf.nest.is_nested(value):
    value = tf.nest.map_structure(_convert_to_tensor, value)
  if dtype is None and dtype_hint is not None:
    dtype_hint = utils.numpy_dtype(dtype_hint)
    value = np.array(value)
    if np.size(value):
      # Match TF behavior, which won't downcast e.g. float to int.
      if np.issubdtype(value.dtype, np.complexfloating):
        if not np.issubdtype(dtype_hint, np.complexfloating):
          return value
      if np.issubdtype(value.dtype, np.floating):
        if not (np.issubdtype(dtype_hint, np.floating)
                or np.issubdtype(dtype_hint, np.complexfloating)):
          return value
      if np.issubdtype(value.dtype, np.integer):
        if not (np.issubdtype(dtype_hint, np.integer)
                or np.issubdtype(dtype_hint, np.floating)
                or np.issubdtype(dtype_hint, np.complexfloating)):
          return value
    # Empty arrays, bool/other dtypes, and compatible hints are cast.
    return value.astype(dtype_hint)
  return np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))
def _is_scalar_type(rep):
  """Local replacement for `np.issctype`, which NumPy 2.0 removed.

  Returns True for NumPy scalar classes other than `np.object_`, for
  `np.dtype` instances whose scalar type is not `np.object_`, and for Python
  types that NumPy maps to a non-object dtype (e.g. `float`, `int`, `bool`).
  """
  if isinstance(rep, type) and issubclass(rep, np.generic):
    return rep is not np.object_
  if not isinstance(rep, (type, np.dtype)):
    return False
  try:
    return np.dtype(rep).type is not np.object_
  except (TypeError, ValueError):
    return False


def convert_to_dtype(tensor_or_dtype, dtype=None, dtype_hint=None):
  """Get a dtype from a list/tensor/dtype using convert_to_tensor semantics.

  Args:
    tensor_or_dtype: A Tensor, `tf.DType`, NumPy array, NumPy/Python scalar
      type, arbitrary Python value, or `None`.
    dtype: Optional required dtype. If given, it must match the result.
    dtype_hint: Optional preferred dtype.

  Returns:
    The resulting dtype, or `dtype or dtype_hint` if `tensor_or_dtype` is None.

  Raises:
    TypeError: If `dtype` is given and incompatible with the resolved dtype
      (unless `SKIP_DTYPE_CHECKS`).
  """
  if tensor_or_dtype is None:
    return dtype or dtype_hint

  # Tensorflow dtypes need to be typechecked
  if tf.is_tensor(tensor_or_dtype):
    dt = base_dtype(tensor_or_dtype.dtype)
  elif isinstance(tensor_or_dtype, tf.DType):
    dt = base_dtype(tensor_or_dtype)
  # Numpy dtypes defer to dtype/dtype_hint
  elif isinstance(tensor_or_dtype, np.ndarray):
    dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
  elif _is_scalar_type(tensor_or_dtype):
    dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
  else:
    # If this is a Python object, call `convert_to_tensor` and grab the dtype.
    # Note that this will add ops in graph-mode; we may want to consider
    # other ways to handle this case.
    dt = tf.convert_to_tensor(tensor_or_dtype, dtype, dtype_hint).dtype

  if not SKIP_DTYPE_CHECKS and dtype and not base_equal(dtype, dt):
    raise TypeError('Found incompatible dtypes, {} and {}.'.format(dtype, dt))
  return dt
def is_tensor_or_variable(x):
  """Returns True if `x` is a Tensor (per `tf.is_tensor`) or a `tf.Variable`."""
  if tf.is_tensor(x):
    return True
  return isinstance(x, tf.Variable)
def tensor_to_f64(x):
  """Casts a floating-point Tensor to `tf.float64`; returns anything else as is.

  Floating-ness is judged by `x.dtype.is_floating`.
  """
  if not tf.is_tensor(x) or not x.dtype.is_floating:
    return x
  return tf.cast(x, dtype=tf.float64)