def floating_tensor_to_f64(x):
  """Return `x` cast to `tf.float64` when it is a floating-point Tensor.

  Anything that is not a Tensor, or is a Tensor of non-floating dtype, is
  returned unchanged. This makes the function safe to use as the leaf function
  of `tf.nest.map_structure` over structures that mix Tensors with other
  Python objects.

  Args:
    x: Arbitrary object; only floating-point Tensors are converted.

  Returns:
    The float64 cast of `x`, or `x` itself.
  """
  needs_cast = tf.is_tensor(x) and dtype_util.is_floating(x.dtype)
  return tf.cast(x, dtype=tf.float64) if needs_cast else x
Esempio n. 2
0
    def _mark_as_return(tensor):
        """Mark `tensor` as a return value for automatic control dependencies.

        Non-Tensor inputs pass through untouched. For Tensors, the Keras mask
        attached to `tensor` (if any) is marked too, and TFP distribution
        metadata is copied onto the returned tensor.
        """
        if not tf.is_tensor(tensor):
            return tensor

        # pylint: disable=protected-access
        marked = acd.mark_as_return(tensor)
        keras_mask = getattr(tensor, '_keras_mask', None)
        marked._keras_mask = (
            None if keras_mask is None else acd.mark_as_return(keras_mask))

        # Carry over TensorFlow Probability attached metadata.
        # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
        tfp_distribution = getattr(tensor, '_tfp_distribution', None)
        if tfp_distribution is not None:
            marked._tfp_distribution = tfp_distribution

        return marked
def _all_shapes(thing):
  """Collect the static shapes reachable from `thing`.

  Distributions and bijectors are walked recursively through their
  `parameters`. Distributions also contribute `batch_shape + event_shape`
  (plus a per-component shape for `MixtureSameFamily`). Tensors contribute
  their own shape, and anything else contributes nothing.

  Args:
    thing: A distribution, bijector, Tensor, or other Python value.

  Returns:
    A list of `TensorShape`s.
  """
  if not isinstance(thing, (tfd.Distribution, tfb.Bijector)):
    # Tensors report their own shape. Anything else is assumed to be some
    # Python constant like a string or a boolean, which has no shape.
    return [thing.shape] if tf.is_tensor(thing) else []

  # pylint: disable=g-complex-comprehension
  shapes = [shape
            for param in thing.parameters.values()
            for shape in _all_shapes(param)]
  if isinstance(thing, tfd.TransformedDistribution):
    shapes = [thing.batch_shape + shape for shape in shapes]
  if isinstance(thing, tfd.Distribution):
    shapes.append(thing.batch_shape + thing.event_shape)
  if isinstance(thing, tfd.MixtureSameFamily):
    logits = thing.mixture_distribution.logits_parameter()
    shapes.append(thing.batch_shape + [logits.shape[-1]] + thing.event_shape)
  return shapes
def stringify_slices(slices):
    """Returns a list of strings describing the items in `slices`.

    Args:
      slices: A single slice item or a tuple of them. Each item may be
        `Ellipsis`, a `slice`, a Python `int`, a Tensor, or `tf.newaxis`.

    Returns:
      A list of strings, one per item, e.g. `['...', '1::2', '3']`.

    Raises:
      ValueError: If an item has an unsupported type.
    """
    pretty_slices = []
    slices = slices if isinstance(slices, tuple) else (slices,)
    for slc in slices:
        # `Ellipsis` is a singleton, so compare by identity. `slc == Ellipsis`
        # would dispatch to the item's own `__eq__`, which for Tensors/arrays can
        # raise or return a non-boolean value.
        if slc is Ellipsis:
            pretty_slices.append('...')
        elif isinstance(slc, slice):
            pretty_slices.append('{}:{}:{}'.format(*[
                '' if s is None else s for s in (slc.start, slc.stop, slc.step)
            ]))
        elif isinstance(slc, int) or tf.is_tensor(slc):
            pretty_slices.append(str(slc))
        elif slc is tf.newaxis:
            pretty_slices.append('tf.newaxis')
        else:
            raise ValueError('Unexpected slice type: {}'.format(type(slc)))
    return pretty_slices
Esempio n. 5
0
def _get_static_value(pred):
    """Return the statically known value of `pred`, or None if unknown.

    In JAX mode the value is read via `np.asarray`, and None is returned if that
    raises. Otherwise, Tensors are evaluated with `tf.get_static_value` and,
    failing that, with the C API constant evaluator. Non-Tensor inputs are
    returned as-is.
    """
    if JAX_MODE:
        try:
            return np.asarray(pred)
        except:  # JAX sometimes raises raw Exception in __array__.  # pylint: disable=bare-except
            return None
    if not tf.is_tensor(pred):
        return pred

    static_value = tf.get_static_value(tf.convert_to_tensor(pred))
    if static_value is not None:
        return static_value

    # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
    # pylint: disable=protected-access
    folded = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
                                                  pred._as_tf_output())
    # pylint: enable=protected-access
    return folded
def convert_to_ndarray(test_obj, a):
    """Converts the input `a` into an ndarray.

    Args:
      test_obj: An object which has the `evaluate` method. Used to evaluate `a`
        if `a` is a Tensor.
      a: Object to be converted to an ndarray.

    Returns:
      An ndarray containing the values of `a`. An input that already is (or
      evaluates to) an ndarray is returned without copying.
    """
    # TODO(b/177990397): This function should be independent of the test framework
    # Tensors are materialized through the test object first.
    value = test_obj.evaluate(a) if tf.is_tensor(a) else a
    return value if isinstance(value, np.ndarray) else np.array(value)
Esempio n. 7
0
  def call(self, inputs, training=None, mask=None):  # pylint: disable=redefined-outer-name
    """Runs `inputs` through the model.

    When no explicit input shape was given, a non-Tensor `inputs` switches the
    model to legacy deferred behavior and records its shape structure. A
    Tensor `inputs` instead triggers
    `_build_graph_network_for_inferred_shape` with its shape and dtype. If
    `self._graph_initialized` is set, the call is delegated to `super().call`.
    Otherwise the layers are applied one after another.

    Args:
      inputs: Model input. A single Tensor is expected; other structures are
        tolerated for backwards compatibility.
      training: Passed to each layer whose call argspec includes `training`.
      mask: Passed to the first layer if its call argspec includes `mask`.
        After each layer, `mask` is replaced by the `_keras_mask` attribute of
        that layer's output (or None).

    Returns:
      The output of the last layer, or `inputs` if there are no layers.

    Raises:
      ValueError: If a layer returns more than one (flattened) output.
    """
    # If applicable, update the static input shape of the model.
    if not self._has_explicit_input_shape:
      # NOTE(review): the `isinstance(inputs, tf.Tensor)` test appears to be
      # implied by `tf.is_tensor(inputs)`, so this reduces to the first check.
      if not tf.is_tensor(inputs) and not isinstance(
          inputs, tf.Tensor):
        # This is a Sequential with multiple inputs. This is technically an
        # invalid use case of Sequential, but we tolerate it for backwards
        # compatibility.
        self._use_legacy_deferred_behavior = True
        self._build_input_shape = tf.nest.map_structure(
            _get_shape_tuple, inputs)
        if tf.__internal__.tf2.enabled():
          logging.warning('Layers in a Sequential model should only have a '
                          f'single input tensor. Received: inputs={inputs}. '
                          'Consider rewriting this model with the Functional '
                          'API.')
      else:
        self._build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype)

    # With an initialized graph network, defer to the Functional `call`.
    if self._graph_initialized:
      if not self.built:
        self._init_graph_network(self.inputs, self.outputs)
      return super().call(inputs, training=training, mask=mask)

    outputs = inputs  # handle the corner case where self.layers is empty
    for layer in self.layers:
      # During each iteration, `inputs` are the inputs to `layer`, and `outputs`
      # are the outputs of `layer` applied to `inputs`. At the end of each
      # iteration `inputs` is set to `outputs` to prepare for the next layer.
      kwargs = {}
      # Only forward `mask` / `training` to layers whose call signature
      # declares them.
      argspec = self._layer_call_argspecs[layer].args
      if 'mask' in argspec:
        kwargs['mask'] = mask
      if 'training' in argspec:
        kwargs['training'] = training

      outputs = layer(inputs, **kwargs)

      if len(tf.nest.flatten(outputs)) != 1:
        raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
      # `outputs` will be the inputs to the next layer.
      inputs = outputs
      mask = getattr(outputs, '_keras_mask', None)
    return outputs
Esempio n. 8
0
def _extract_init_kwargs(obj,
                         omit_kwargs=(),
                         limit_to=None,
                         prefer_static_value=()):
    """Extract constructor kwargs to reconstruct `obj`.

    For each named parameter of `obj.__init__` (except `self` and names in
    `omit_kwargs`, restricted to `limit_to` when given), the value is taken from
    the first of these that exists: attribute `k`, attribute `_k`, or
    `obj.parameters[k]`.

    Args:
      obj: The object whose constructor arguments are recovered.
      omit_kwargs: Names of constructor parameters to skip.
      limit_to: Optional collection. If not None, only these names are kept.
      prefer_static_value: Names whose Tensor values are replaced by their
        static value (via `tf.get_static_value`) when one is available.

    Returns:
      A dict mapping parameter names to recovered values. NumPy arrays and
      NumPy scalars are converted with `.tolist()`.

    Raises:
      ValueError: If `__init__` takes `*args`/`**kwargs`, or if no value can be
        found for some parameter.
    """
    sig = tf_inspect.signature(obj.__init__)
    if any(v.kind in (tf_inspect.Parameter.VAR_KEYWORD,
                      tf_inspect.Parameter.VAR_POSITIONAL)
           for v in sig.parameters.values()):
        raise ValueError(
            '*args and **kwargs are not supported. Found `{}`'.format(sig))

    keys = [p for p in sig.parameters if p != 'self' and p not in omit_kwargs]
    if limit_to is not None:
        keys = [k for k in keys if k in limit_to]

    kwargs = {}
    not_found = object()
    for k in keys:
        srcs = [
            getattr(obj, k, not_found),
            getattr(obj, '_' + k, not_found),
            getattr(obj, 'parameters', {}).get(k, not_found),
        ]
        found = [v for v in srcs if v is not not_found]
        if not found:
            # Every fragment is an f-string so `{k}` and `{obj}` are actually
            # interpolated (previously only the first line was).
            raise ValueError(
                f'Could not determine an appropriate value for field `{k}` in object '
                f'`{obj}`. Looked for \n'
                f' 1. an attr called `{k}`,\n'
                f' 2. an attr called `_{k}`,\n'
                f' 3. an entry in `obj.parameters` with key "{k}".')
        kwargs[k] = found[0]
        if k in prefer_static_value and kwargs[k] is not None:
            if tf.is_tensor(kwargs[k]):
                static_val = tf.get_static_value(kwargs[k])
                if static_val is not None:
                    kwargs[k] = static_val
        if isinstance(kwargs[k], (np.ndarray, np.generic)):
            # Generally, these are shapes or int, but may be other parameters such as
            # `power` for `tfb.PowerTransform`.
            kwargs[k] = kwargs[k].tolist()
    return kwargs
Esempio n. 9
0
 def reduce(v):
     """Replace Tensors/CompositeTensors in `v` with their TypeSpecs.

     Returns a pair `(reduced, has_tensors)`. Lists and tuples are reduced
     element-wise and must be either all-Tensor or all-non-Tensor. An empty
     list/tuple counts as having no Tensors.
     """
     has_tensors = False
     if tf.is_tensor(v):
         v = tf.TensorSpec.from_tensor(v)
         has_tensors = True
     if isinstance(v, CompositeTensor):
         v = v._type_spec  # pylint: disable=protected-access
         has_tensors = True
     if isinstance(v, (list, tuple)):
         reduced = [reduce(v_) for v_ in v]
         has_tensors = any(ht for (_, ht) in reduced)
         # Only a genuine mix of Tensor and non-Tensor parts is an error.
         # Comparing `any` against `all` directly would also reject empty
         # sequences, since `any([])` is False while `all([])` is True.
         if has_tensors and not all(ht for (_, ht) in reduced):
             raise NotImplementedError(
                 _mk_err_msg(
                     clsid, self,
                     'Found `{}` with both Tensor and non-Tensor parts: {}'
                     .format(type(v), v)))
         v = type(v)([spec for (spec, _) in reduced])
     return v, has_tensors
Esempio n. 10
0
            def check_consistency(tf_fn, np_fn, args):
                """Assert that `tf_fn` and `np_fn` agree on `args`.

                `tf_fn` receives `args` converted to Tensors and its result is
                evaluated. `np_fn` receives the raw `args`. Outputs are compared
                shape-only when `assert_shape_only` is set, and otherwise
                leaf-by-leaf with `assertAllCloseAccordingToType`.
                """
                # If `args` is a single item, put it in a tuple
                if isinstance(args, np.ndarray) or tf.is_tensor(args):
                    args = (args, )
                tensorflow_value = self.evaluate(
                    tf_fn(*_maybe_convert_to_tensors(args)))
                kwargs = jax_kwargs() if MODE_JAX else {}
                numpy_value = np_fn(*args, **kwargs)
                if assert_shape_only:

                    def assert_same_shape(x, y):
                        self.assertAllEqual(x.shape, y.shape)

                    tf.nest.map_structure(assert_same_shape, tensorflow_value,
                                          numpy_value)
                else:
                    # Compare flattened leaves pairwise. Multi-output functions
                    # may return parts of different shapes, which cannot be
                    # packed into one array. `zip_longest` pads with None, so a
                    # differing number of outputs fails rather than truncating.
                    for i, (tf_val, np_val) in enumerate(six.moves.zip_longest(
                            tf.nest.flatten(tensorflow_value),
                            tf.nest.flatten(numpy_value))):
                        self.assertAllCloseAccordingToType(
                            tf_val, np_val, atol=atol, rtol=rtol,
                            msg='output {}'.format(i))
Esempio n. 11
0
    def __init__(self, value):
        """Creates a new `_DeferredTensorInput`.

        Args:
          value: either a `Tensor`-like object, or a nullary function returning
            such an object.

        Raises:
          TypeError: If `value` is a `helpers.RateObject`.
        """
        if isinstance(value, helpers.RateObject):
            raise TypeError(
                "a DeferredTensor may only be created from a Tensor-like object, "
                "or a nullary function returning such")

        # Callables and Tensors are stored as-is. Any other value is deep-copied
        # to make extra-certain that it is immutable (since we'll hash it).
        keep_as_is = callable(value) or tf.is_tensor(value)
        self._value = value if keep_as_is else copy.deepcopy(value)

        # Slot for the memoized hash, which can be expensive to compute.
        self._hash = None
Esempio n. 12
0
def split_seed(seed, n=2, salt=None, name=None):
    """Splits a seed deterministically into derived seeds.

    Args:
      seed: The seed to split. It is sanitized via `sanitize_seed` with `salt`.
      n: Python `int` or int Tensor giving the number of seeds to derive.
      salt: Optional salt mixed into the seed.
      name: Optional name scope; defaults to 'split'.

    Returns:
      In JAX mode, the result of `jax.random.split`. Otherwise an `[n, 2]` Tensor
      of seeds, unstacked into a list when `n` is a Python integer.

    Raises:
      TypeError: If `n` is neither a Python `int` nor a Tensor.
    """
    n_is_valid = isinstance(n, int) or tf.is_tensor(n)
    if not n_is_valid:  # avoid confusion with salt.
        raise TypeError(
            '`n` must be a python `int` or an int Tensor, got {}'.format(
                repr(n)))
    with tf.name_scope(name or 'split'):
        seed = sanitize_seed(seed, salt=salt)
        if JAX_MODE:
            from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
            return jaxrand.split(seed, n)
        seeds = tf.random.stateless_uniform(
            [n, 2], seed=seed, minval=None, maxval=None, dtype=SEED_DTYPE)
        return tf.unstack(seeds) if isinstance(n, six.integer_types) else seeds
Esempio n. 13
0
def _assert_batch_shape_matches_weights(distribution, weights_shape, diststr):
    """Checks that all parts of a distribution have the expected batch shape.

    Statically known mismatches (in shape or in rank) raise immediately. If
    some shapes are only known at runtime and `distribution.validate_args` is
    set, runtime assertions comparing adjacent shapes are returned instead.

    Args:
      distribution: Distribution whose `batch_shape_tensor()` (possibly a
        structure of shapes) is compared against `weights_shape`.
      weights_shape: Shape of the particle weights.
      diststr: Name of the distribution used in the error message.

    Returns:
      A (possibly empty) list of assertion ops.

    Raises:
      ValueError: If the statically known shapes or ranks disagree.
    """
    shapes = [weights_shape] + tf.nest.flatten(
        distribution.batch_shape_tensor())
    static_shapes = [
        tf.get_static_value(ps.convert_to_shape_tensor(s)) for s in shapes
    ]
    static_shapes_not_none = [s for s in static_shapes if s is not None]
    # `np.array_equal` returns False for shape vectors of different lengths
    # (i.e. different ranks). `np.all(a == b)` would instead broadcast or raise
    # a NumPy broadcasting error before the informative ValueError below.
    static_shapes_match = all(
        np.array_equal(a, b)
        for a, b in zip(static_shapes_not_none[1:], static_shapes_not_none[:-1]))

    # Build a separate list of static ranks, since rank is often static even when
    # shape is not.
    ranks = [ps.rank_from_shape(s) for s in shapes]
    static_ranks = [int(r) for r in ranks if not tf.is_tensor(r)]
    static_ranks_match = all(
        a == b for a, b in zip(static_ranks[1:], static_ranks[:-1]))

    msg = (
        "The {diststr} distribution's batch shape does not match the particle "
        "weights; a correct {diststr} distribution must return an independent "
        "log-density for each particle. You may be "
        "creating a joint distribution in which some parts do not depend on the "
        "previous particles, and/or you are creating an autobatched joint "
        "distribution without setting `batch_ndims`.".format(diststr=diststr))
    if not (static_ranks_match and static_shapes_match):
        raise ValueError(
            msg + ' ' +
            'Weights have shape {}, but the distribution has batch '
            'shape {}.'.format(weights_shape, distribution.batch_shape))

    assertions = []
    # Runtime checks are only needed when some shape is not statically known.
    if distribution.validate_args and any(s is None for s in static_shapes):
        assertions = [
            assert_util.assert_equal(a, b, message=msg)
            for a, b in zip(shapes[1:], shapes[:-1])
        ]
    return assertions
Esempio n. 14
0
        def _type_spec(self):
            """Build the `_TFPTypeSpec` describing this object.

            Constructor kwargs that differ from their defaults are split into
            `param_specs` (TypeSpecs of Tensor / CompositeTensor values) and
            plain `kwargs`.

            Raises:
              NotImplementedError: If a remaining kwarg is callable.
            """
            def get_default_args(obj):
                # Defaults are read from the signature of `type(obj)`, i.e. its
                # constructor. (Every Python value is an `object`, so the
                # former "object vs. function" branch always chose the class.)
                return {
                    k: v.default
                    for k, v in inspect.signature(type(obj)).parameters.items()
                    if v.default is not inspect.Parameter.empty
                }

            if six.PY3:
                default_kwargs = get_default_args(self)
                missing = object()
                # Keep only kwargs that are not the very default object.
                kwargs = {
                    k: v
                    for k, v in self.parameters.items()
                    if default_kwargs.get(k, missing) is not v
                }  # non-default kwargs only
            else:
                kwargs = dict(self.parameters)
            param_specs = {}
            try:
                composite_tensor_params = self._composite_tensor_params  # pylint: disable=protected-access
            except NotImplementedError:
                composite_tensor_params = ()
            for k in composite_tensor_params:
                if k in kwargs and kwargs[k] is not None:
                    v = kwargs.pop(k)
                    if isinstance(v, CompositeTensor):
                        param_specs[k] = v._type_spec  # pylint: disable=protected-access
                    elif tf.is_tensor(v):
                        param_specs[k] = tf.TensorSpec.from_tensor(v)
                    # NOTE(review): a value that is neither a Tensor nor a
                    # CompositeTensor has been popped from `kwargs` but is not
                    # recorded in `param_specs`; confirm dropping it is intended.
            for k, v in list(kwargs.items()):
                if isinstance(v, CompositeTensor):
                    param_specs[k] = v._type_spec  # pylint: disable=protected-access
                    kwargs.pop(k)
                elif callable(v):
                    raise NotImplementedError(
                        'Unable to make CompositeTensor including callable '
                        'argument `{}`.'.format(k))
            return _TFPTypeSpec(clsid, param_specs=param_specs, kwargs=kwargs)
Esempio n. 15
0
 def convert_fn(path, value, dtype, dtype_hint, name=None):
     """Convert `value`, found at `path`, to a Tensor or shape Tensor.

     Unless `allow_packing` is set, a nested structure containing Tensors or
     NumPy arrays is rejected. Values whose type name contains 'KerasTensor'
     are returned unchanged.
     """
     if not allow_packing and nest.is_nested(value):
         # Treat arrays like Tensors for full parity in JAX backend.
         holds_tensors = any(
             tf.is_tensor(x) or isinstance(x, np.ndarray)
             for x in nest.flatten(value))
         if holds_tensors:
             raise NotImplementedError(
                 ('Cannot convert a structure of tensors to a '
                  'single tensor. Saw {} at path {}.').format(value, path))

     if as_shape_tensor:
         return ps.convert_to_shape_tensor(value, dtype, dtype_hint, name=name)

     if 'KerasTensor' in str(type(value)):
         # Hack to detect symbolic Keras tensors and work around b/206660667:
         # `tf.convert_to_tensor` is not a no-op on them, which broke the
         # Bijector cache on forward/inverse log det jacobian.
         return value

     return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
Esempio n. 16
0
      def check_consistency(tf_fn, np_fn, args):
        """Assert that `tf_fn` and `np_fn` produce matching outputs on `args`."""
        # A lone array/Tensor argument is wrapped into a 1-tuple.
        if isinstance(args, np.ndarray) or tf.is_tensor(args):
          args = (args,)
        tf_args = _maybe_convert_to_tensors(args)
        tensorflow_value = self.evaluate(tf_fn(*tf_args))
        np_kwargs = jax_kwargs() if JAX_MODE else {}
        numpy_value = np_fn(*args, **np_kwargs)

        if assert_shape_only:
          tf.nest.map_structure(
              lambda x, y: self.assertAllEqual(x.shape, y.shape),
              tensorflow_value, numpy_value)
          return

        # Compare leaves pairwise; a missing leaf on either side pairs with None.
        tf_leaves = tf.nest.flatten(tensorflow_value)
        np_leaves = tf.nest.flatten(numpy_value)
        paired = six.moves.zip_longest(tf_leaves, np_leaves)
        for index, (tf_val, np_val) in enumerate(paired):
          self.assertAllCloseAccordingToType(
              tf_val, np_val, atol=atol, rtol=rtol,
              msg='output {}'.format(index))
Esempio n. 17
0
def _get_static_predicate(pred):
  """Statically evaluate the predicate passed to `cond`, if possible.

  Args:
    pred: A Tensor, a Python bool, or 1 / 0.

  Returns:
    A Python bool when `pred` (or its statically evaluated value) is one of
    0, 1, True, False. For Tensors whose value is something else or unknown,
    that value (possibly None) is returned unchanged.

  Raises:
    TypeError: If `pred` is not a Tensor and not one of 0, 1, True, False.
  """
  boolean_like = (0, 1, True, False)

  if not tf.is_tensor(pred):
    # Accept 1/0 as valid boolean values. This branch also casts
    # np.array(False), tf.EagerTensor(True), etc.
    if pred not in boolean_like:
      raise TypeError('`pred` must be a Tensor, or a Python bool, or 1 or 0. '
                      'Found instead: {}'.format(pred))
    return bool(pred)

  static_value = tf.get_static_value(tf.convert_to_tensor(pred))
  # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
  if static_value is None:
    # pylint: disable=protected-access
    static_value = c_api.TF_TryEvaluateConstant_wrapper(
        pred.graph._c_graph, pred._as_tf_output())
    # pylint: enable=protected-access
  return bool(static_value) if static_value in boolean_like else static_value
Esempio n. 18
0
def get_discriminator_focal_loss(learner_agent_output, env_output,
                                 actor_agent_output, actor_action,
                                 reward_clipping, discounting, baseline_cost,
                                 entropy_cost, num_steps):
    """Discriminator focal loss.

    Precomputed labels/logits are used when `learner_agent_output` provides them
    (a `policy_logits` dict with a 'labels' key and a Tensor `baseline`).
    Otherwise they are derived via `_get_discriminator_logits`. Focal-loss
    hyper-parameters come from `FLAGS`, and scalar summaries are written at step
    `num_steps`.

    Returns:
      The focal loss returned by `focal_loss_from_logits`.
    """
    policy_logits = learner_agent_output.policy_logits
    has_precomputed_labels = (isinstance(policy_logits, dict)
                              and 'labels' in policy_logits
                              and tf.is_tensor(learner_agent_output.baseline))

    if not has_precomputed_labels:
        # labels and output_logits have shape: [batch].
        labels, output_logits = _get_discriminator_logits(
            learner_agent_output, env_output, actor_agent_output, actor_action,
            reward_clipping, discounting, baseline_cost, entropy_cost,
            num_steps)
        loss_tag = 'loss/focal_loss (w/ softmin_softmax)'
    else:
        # Shape: [batch]
        labels = policy_logits['labels']
        output_logits = learner_agent_output.baseline
        tf.debugging.assert_equal(tf.shape(labels), tf.shape(output_logits))
        loss_tag = 'loss/focal_loss'

    # Shape = [batch]
    fl, ce = focal_loss_from_logits(
        output_logits,
        labels,
        alpha=FLAGS.focal_loss_alpha,
        gamma=FLAGS.focal_loss_gamma,
        normalizer=FLAGS.focal_loss_normalizer)

    tf.summary.scalar(loss_tag, tf.reduce_mean(fl), step=num_steps)
    tf.summary.scalar('loss/CE (reference only)', tf.reduce_mean(ce),
                      step=num_steps)
    tf.summary.scalar('labels/num_labels_per_batch', tf.size(labels),
                      step=num_steps)
    # NOTE(review): if `labels` is an integer Tensor, `tf.reduce_mean` averages
    # with integer arithmetic; confirm the labels' dtype.
    tf.summary.scalar('labels/mean_labels', tf.reduce_mean(labels),
                      step=num_steps)
    return fl
Esempio n. 19
0
def create_timbre_spectrogram(audio, hparams):
    """Create either a CQT or mel spectrogram.

    Args:
      audio: WAV bytes, an array of samples, or a Tensor holding either. Tensors
        are converted with `.numpy()`, so eager execution is required.
      hparams: Hyper-parameters. Reads `sample_rate`, `timbre_spec_type`,
        `timbre_hop_length`, `spec_mel_htk` and `timbre_spec_log_amplitude`.

    Returns:
      The absolute-valued spectrogram, transposed so that time is the first
      axis. If `hparams.timbre_spec_log_amplitude` is set, it is converted to
      dB relative to 1e-9 and divided by its maximum.
    """
    if tf.is_tensor(audio):
        audio = audio.numpy()
    if isinstance(audio, bytes):
        # Get samples from wav data.
        samples = audio_io.wav_data_to_samples(audio, hparams.sample_rate)
    else:
        samples = audio

    # `y` and `sr` are passed by keyword: librosa >= 0.10 made them keyword-only,
    # so passing the sample rate positionally raises a TypeError there.
    if hparams.timbre_spec_type == 'mel':
        spec = np.abs(
            librosa.feature.melspectrogram(
                y=samples,
                sr=hparams.sample_rate,
                hop_length=hparams.timbre_hop_length,
                fmin=librosa.midi_to_hz(constants.MIN_TIMBRE_PITCH),
                fmax=librosa.midi_to_hz(constants.MAX_TIMBRE_PITCH),
                n_mels=constants.TIMBRE_SPEC_BANDS,
                pad_mode='symmetric',
                htk=hparams.spec_mel_htk,
                power=2)).T

    else:
        spec = np.abs(
            librosa.core.cqt(y=samples,
                             sr=hparams.sample_rate,
                             hop_length=hparams.timbre_hop_length,
                             fmin=librosa.midi_to_hz(
                                 constants.MIN_TIMBRE_PITCH),
                             n_bins=constants.TIMBRE_SPEC_BANDS,
                             bins_per_octave=constants.BINS_PER_OCTAVE,
                             pad_mode='symmetric')).T

    # Convert to dB relative to a 1e-9 floor, then normalize by the maximum.
    # NOTE(review): the CQT branch produces magnitudes rather than power, yet
    # `power_to_db` is applied to both branches; confirm this is intended.
    if hparams.timbre_spec_log_amplitude:
        spec = librosa.power_to_db(spec) - librosa.power_to_db(np.array([1e-9
                                                                         ]))[0]
        spec = spec / np.max(spec)
    return spec
Esempio n. 20
0
def _get_static_value(pred):
    """Return the statically known value of `pred`, or None if unknown.

    In JAX mode the value is read via `np.asarray`, and None is returned if that
    raises. Otherwise Tensors are evaluated with `tf.get_static_value`. For
    `ops.Tensor` instances, constant evaluation is also attempted. Non-Tensor
    inputs are returned as-is.
    """
    if JAX_MODE:
        try:
            return np.asarray(pred)
        except:  # JAX sometimes raises raw Exception in __array__.  # pylint: disable=bare-except
            return None
    if not tf.is_tensor(pred):
        return pred

    static_value = tf.get_static_value(tf.convert_to_tensor(pred))
    # Explicitly check for ops.Tensor, to avoid an AttributeError
    # when requesting `KerasTensor.graph`.
    if static_value is not None or not isinstance(pred, ops.Tensor):
        return static_value

    if hasattr(tensor_util, 'try_evaluate_constant'):
        return tensor_util.try_evaluate_constant(pred)
    # TODO(feyu): remove this branch after try_evaluate_constant is in
    # tf-nightly.
    return c_api.TF_TryEvaluateConstant_wrapper(
        pred.graph._c_graph, pred._as_tf_output())  # pylint: disable=protected-access
Esempio n. 21
0
def split_seed(seed, n=2, salt=None, name=None):
    """Splits a seed into `n` derived seeds.

    See https://github.com/tensorflow/probability/blob/main/PRNGS.md
    for details.

    Args:
      seed: The seed to split; may be an `int`, an `(int, int) tuple`, or a
        `Tensor`. `int` seeds are converted to `Tensor` seeds using
        `tf.random.uniform` stateful sampling. Tuples are converted to `Tensor`.
      n: The number of splits to return: a Python `int`, a NumPy array, or an
        int `Tensor`. In TensorFlow, if `n` is a Python integer, this function
        returns a list of seeds and otherwise returns a `Tensor` of seeds. In
        JAX, `n` is converted with `int(n)` and this function always returns an
        array of seeds.
      salt: Optional `str` salt to mix with the seed.
      name: Optional name to scope related ops.

    Returns:
      seeds: If `n` is a Python `int`, a `list` of seed values (from
        `tf.unstack`) is returned. Otherwise a single `Tensor` of shape `[n, 2]`
        is returned. A single such seed is suitable to pass as the `seed`
        argument of the `tf.random.stateless_*` ops.

    Raises:
      TypeError: If `n` is not a Python `int`, a NumPy array, or a Tensor.
    """
    if not (isinstance(n, (int, np.ndarray))
            or tf.is_tensor(n)):  # avoid confusion with salt.
        raise TypeError(
            '`n` must be a python `int` or an int Tensor, got {}'.format(
                repr(n)))
    with tf.name_scope(name or 'split_seed'):
        seed = sanitize_seed(seed, salt=salt)
        if JAX_MODE:
            from jax import random as jaxrand  # pylint: disable=g-import-not-at-top
            return jaxrand.split(seed, int(n))
        seeds = tf.random.stateless_uniform([n, 2],
                                            seed=seed,
                                            minval=None,
                                            maxval=None,
                                            dtype=SEED_DTYPE)
        if isinstance(n, six.integer_types):
            seeds = tf.unstack(seeds)
        return seeds
def validate_per_replica_inputs(distribution_strategy, x):
    """Validates PerReplica dataset input list.

    Each flattened entry of `x` is unwrapped by the distribution strategy. Every
    unwrapped value must be a Tensor, and dtypes (plus shapes, when not
    executing eagerly) are validated across replicas.

    Args:
      distribution_strategy: The current DistributionStrategy used to call
        `fit`, `evaluate` and `predict`.
      x: A list of PerReplica objects that represent the input or
        target values.

    Returns:
      List containing the first element of each of the PerReplica objects in
      the input list.

    Raises:
      ValueError: If any unwrapped value is not a tensor.
    """
    first_values = []
    # Convert the inputs and targets into a list of PerReplica objects.
    for per_replica in tf.nest.flatten(x):
        # At this point each entry should contain only tensors.
        replica_values = distribution_strategy.unwrap(per_replica)
        for candidate in replica_values:
            if not tf.is_tensor(candidate):
                raise ValueError(
                    "Dataset input to the model should be tensors instead "
                    "they are of type {}".format(type(candidate)))

        if not tf.executing_eagerly():
            # Shapes can only be compared outside eager execution.
            validate_all_tensor_shapes(per_replica, replica_values)
        validate_all_tensor_types(per_replica, replica_values)

        first_values.append(replica_values[0])
    return first_values
def _extract_type_spec_recursively(value):
    """Return (collection of) TypeSpec(s) for `value` if it includes `Tensor`s.

    If `value` is a `CompositeTensor`, a `tf.Variable` or a `Tensor`, return its
    `TypeSpec`. If `value` is a collection containing `Tensor` values,
    recursively replace them with their respective `TypeSpec`s in a collection
    of parallel structure.

    If `value` is none of the above, return it unchanged.

    Args:
      value: a Python `object` to (possibly) turn into a (collection of)
        `tf.TypeSpec`(s).

    Returns:
      spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to
        `value`, or `value` itself if no `Tensor`s are found.

    Raises:
      NotImplementedError: If a collection mixes Tensor and non-Tensor leaves.
    """
    # Check order matters: CompositeTensor, then Variable, then plain Tensor.
    if isinstance(value, composite_tensor.CompositeTensor):
        return value._type_spec  # pylint: disable=protected-access
    if isinstance(value, tf.Variable):
        return resource_variable_ops.VariableSpec(value.shape,
                                                  dtype=value.dtype,
                                                  trainable=value.trainable)
    if tf.is_tensor(value):
        return tf.TensorSpec(value.shape, value.dtype)
    if not tf.nest.is_nested(value):
        return value

    specs = tf.nest.map_structure(_extract_type_spec_recursively, value)
    # A leaf was converted iff its spec is a different object from the leaf.
    converted = tf.nest.flatten(
        tf.nest.map_structure(lambda leaf, spec: leaf is not spec, value, specs))
    if not any(converted):
        return value
    if not all(converted):
        raise NotImplementedError(
            'Found `{}` with both Tensor and non-Tensor parts: {}'.
            format(type(value), value))
    return specs
Esempio n. 24
0
def _update_loop_variables(step, current_step_results,
                           accumulated_step_results, state_history):
    """Update the loop state to reflect a step of filtering.

    Args:
      step: Index of the current step; the returned state holds `step + 1`.
      current_step_results: Results of this step. Its `parent_indices` and
        `particles` fields are read here, and every field is written into the
        matching accumulator.
      accumulated_step_results: Structure matching `current_step_results` whose
        leaves expose `write(step, value)` (presumably `tf.TensorArray`s;
        TODO confirm).
      state_history: Tensor or structure of Tensors with previous states
        stacked along axis 0. If it is a Tensor whose static leading dimension
        is 0, it is treated as empty and returned unchanged.

    Returns:
      A `ParticleFilterLoopVariables` holding the advanced step, the current
      results as `previous_step_results`, and the updated accumulators and
      state history.
    """

    # Write particles, indices, and likelihoods to their respective arrays.
    accumulated_step_results = tf.nest.map_structure(
        lambda x, y: x.write(step, y), accumulated_step_results,
        current_step_results)

    # Only a Tensor with a statically zero leading dimension counts as empty;
    # structures of Tensors are always updated.
    history_is_empty = (tf.is_tensor(state_history)
                        and state_history.shape[0] == 0)
    if not history_is_empty:
        # Batch dims are everything between the first and last axes of
        # `parent_indices`.
        batch_shape = prefer_static.shape(
            current_step_results.parent_indices)[1:-1]
        batch_rank = prefer_static.rank_from_shape(batch_shape)

        # Permute the particles from previous steps to match the current resampled
        # indices, so that the state history reflects coherent trajectories.
        # NOTE(review): only the last history entry (`x[-1:]`) is gathered, so
        # after the concat below the history holds two entries (last + current);
        # confirm this window size is intended.
        def update_state_history_to_use_current_indices(x):
            return tf.gather(x[-1:],
                             current_step_results.parent_indices,
                             axis=batch_rank + 1,
                             batch_dims=batch_rank)

        resampled_state_history = tf.nest.map_structure(
            update_state_history_to_use_current_indices, state_history)

        # Update the history by concat'ing the carried-forward elements with the
        # most recent state.
        state_history = tf.nest.map_structure(
            lambda h, s: tf.concat([h, s[tf.newaxis, ...]], axis=0),
            resampled_state_history, current_step_results.particles)

    return ParticleFilterLoopVariables(
        step=step + 1,
        previous_step_results=current_step_results,
        accumulated_step_results=accumulated_step_results,
        state_history=state_history)
Esempio n. 25
0
    def _to_components(self, obj):
        """Decompose `obj` into components matching `self._param_specs`.

        Parameters of `obj` (restricted to the keys of `self._param_specs`) that
        are Tensors are kept as-is. Others are converted with their spec's
        `_to_components`, element-wise for list/tuple specs.

        Raises:
          NotImplementedError: If a parameter cannot be converted (a TypeError
            was raised during conversion).
        """
        def to_component(spec, value):
            if isinstance(spec, (list, tuple)):
                return type(spec)(
                    [to_component(sub_spec, item)
                     for sub_spec, item in zip(spec, value)])
            if tf.is_tensor(value):
                return value
            return spec._to_components(value)  # pylint: disable=protected-access

        params = _kwargs_from(self._clsid,
                              obj,
                              limit_to=list(self._param_specs))
        for k, v in params.items():
            if tf.is_tensor(v):
                continue
            try:
                params[k] = to_component(self._param_specs[k], v)
            except TypeError as e:
                raise NotImplementedError(
                    _mk_err_msg(
                        self._clsid, obj,
                        '(Unable to convert dependent entry \'{}\': {})'.
                        format(k, str(e))))
        return params
  def __init__(self, seed, salt):
    """Initializes a `SeedStream`.

    Args:
      seed: Any Python object convertible to string, with the exception of
        Tensor objects, which are reserved for stateless sampling semantics.
        The seed supplies the initial entropy.  If `None`, operations seeded
        with seeds drawn from this `SeedStream` will follow TensorFlow semantics
        for not being seeded. Another `SeedStream` may be passed, in which case
        its `original_seed` is used. In JAX mode, `int` seeds are turned into
        a `jax.random.PRNGKey`.
      salt: Any Python object convertible to string, supplying
        auxiliary entropy.  Must be unique across the Distributions
        and TensorFlow Probability code base.  See class docstring for
        rationale.

    Raises:
      TypeError: Outside JAX mode, if `seed` is a Tensor.
    """
    if isinstance(seed, SeedStream):
      seed = seed.original_seed
    if JAX_MODE:
      if isinstance(seed, int):
        import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
        seed = jaxrand.PRNGKey(seed)
    elif tf.is_tensor(seed):
      raise TypeError('{}: {}'.format(TENSOR_SEED_MSG_PREFIX, seed))
    self._seed = seed
    self._salt = salt
    self._counter = 0
Esempio n. 27
0
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.convert_to_tensor.

    Values already accepted by `is_tensor` are only cast when `dtype` is given.
    `Dimension`/`TensorShape` values become Python ints/lists and nested
    structures are converted leaf-wise. A `dtype_hint` (used only without
    `dtype`) is ignored when honoring it would downcast a non-empty value, e.g.
    float to int, matching TF behavior.
    """
    assert not tf.is_tensor(value), value
    if is_tensor(value):
        if dtype is None:
            return value
        return value.astype(utils.numpy_dtype(dtype))

    if isinstance(value, Dimension):
        value = _dimension_value(value)
    if isinstance(value, TensorShape):
        value = [_dimension_value(d) for d in value.as_list()]
    if tf.nest.is_nested(value):
        value = tf.nest.map_structure(_convert_to_tensor, value)

    if dtype is not None or dtype_hint is None:
        return np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))

    hint = utils.numpy_dtype(dtype_hint)
    arr = np.array(value)
    if np.size(arr):
        # Match TF behavior, which won't downcast e.g. float to int: the hint
        # is only honored if it belongs to a family `arr` may be cast into.
        if np.issubdtype(arr.dtype, np.complexfloating):
            castable_to = (np.complexfloating,)
        elif np.issubdtype(arr.dtype, np.floating):
            castable_to = (np.floating, np.complexfloating)
        elif np.issubdtype(arr.dtype, np.integer):
            castable_to = (np.integer, np.floating, np.complexfloating)
        else:
            castable_to = None
        if castable_to is not None and not any(
                np.issubdtype(hint, kind) for kind in castable_to):
            return arr
    return arr.astype(hint)
Esempio n. 28
0
def _is_numpy_scalar_type(obj):
  """Return True if `obj` denotes a NumPy scalar type other than `np.object_`.

  Replacement for `np.issctype`, which was removed in NumPy 2.0. Like the old
  function, it accepts classes (Python or NumPy) and `np.dtype` instances, and
  rejects everything else, including strings and object-typed values.
  """
  if not isinstance(obj, (type, np.dtype)):
    return False
  if isinstance(obj, type) and issubclass(obj, np.generic):
    return obj is not np.object_
  try:
    return np.dtype(obj).type is not np.object_
  except (TypeError, ValueError):
    # Objects NumPy cannot interpret as a dtype are not scalar types.
    return False


def convert_to_dtype(tensor_or_dtype, dtype=None, dtype_hint=None):
  """Get a dtype from a list/tensor/dtype using convert_to_tensor semantics.

  Args:
    tensor_or_dtype: A Tensor, TF dtype, NumPy array, NumPy/Python scalar type,
      other Python value, or None.
    dtype: Optional required dtype. If given, it must be compatible with the
      inferred dtype (unless `SKIP_DTYPE_CHECKS` is set).
    dtype_hint: Optional preferred dtype when none can be inferred.

  Returns:
    The resulting dtype. If `tensor_or_dtype` is None, `dtype or dtype_hint`.

  Raises:
    TypeError: If `dtype` is incompatible with the inferred dtype.
  """
  if tensor_or_dtype is None:
    return dtype or dtype_hint

  # Tensorflow dtypes need to be typechecked
  if tf.is_tensor(tensor_or_dtype):
    dt = base_dtype(tensor_or_dtype.dtype)
  elif isinstance(tensor_or_dtype, tf.DType):
    dt = base_dtype(tensor_or_dtype)
  # Numpy dtypes defer to dtype/dtype_hint
  elif isinstance(tensor_or_dtype, np.ndarray):
    dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
  elif _is_numpy_scalar_type(tensor_or_dtype):
    dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
  else:
    # If this is a Python object, call `convert_to_tensor` and grab the dtype.
    # Note that this will add ops in graph-mode; we may want to consider
    # other ways to handle this case.
    dt = tf.convert_to_tensor(tensor_or_dtype, dtype, dtype_hint).dtype

  if not SKIP_DTYPE_CHECKS and dtype and not base_equal(dtype, dt):
    raise TypeError('Found incompatible dtypes, {} and {}.'.format(dtype, dt))
  return dt
Esempio n. 29
0
def is_tensor_or_variable(x):
    """Return True if `x` is a Tensor (per `tf.is_tensor`) or a `tf.Variable`."""
    if tf.is_tensor(x):
        return True
    return isinstance(x, tf.Variable)
def tensor_to_f64(x):
    """Cast a floating-point Tensor to `tf.float64`; return anything else as-is."""
    if not (tf.is_tensor(x) and x.dtype.is_floating):
        return x
    return tf.cast(x, dtype=tf.float64)