Example #1
 def test_list(self):
   a = [np.ones(10), np.ones(20)]
   model_inputs = training_utils_v1.ModelInputs(a)
   self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
   vals = model_inputs.get_symbolic_inputs()
   self.assertTrue(tensor_util.is_tf_type(vals[0]))
   self.assertTrue(tensor_util.is_tf_type(vals[1]))
Example #2
def verify_single_cond_var(name, body_var, orelse_var):
    """Verifies whether body_var and orelse_var are consistent."""
    if body_var is None:
        raise ValueError(
            "'{}' is None at the end of the main branch.".format(name))
    if orelse_var is None:
        raise ValueError(
            "'{}' is None at the end of the else branch.".format(name))

    if isinstance(body_var, (bool, int, float, str, np.ndarray)):
        body_var = ops.convert_to_tensor_v2(body_var)

    if isinstance(orelse_var, (bool, int, float, str, np.ndarray)):
        orelse_var = ops.convert_to_tensor_v2(orelse_var)

    if (not tensor_util.is_tf_type(body_var)
            or not tensor_util.is_tf_type(orelse_var)):
        return

    # TODO(mdan): Properly account for CompositeTensors.
    if (not hasattr(body_var, 'dtype') or not hasattr(orelse_var, 'dtype')):
        return

    if body_var.dtype != orelse_var.dtype:
        raise TypeError(
            "'{}' has dtype {} in the main branch, but dtype {} in the else"
            ' branch'.format(name, body_var.dtype.name, orelse_var.dtype.name))
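A minimal usage sketch, assuming the module-style imports used in the example above (`ops` from TensorFlow's framework package): branch values with mismatched dtypes trip the TypeError.

body_var = ops.convert_to_tensor_v2(1.0)   # float32 in the main branch
orelse_var = ops.convert_to_tensor_v2(1)   # int32 in the else branch
verify_single_cond_var('x', body_var, orelse_var)  # raises TypeError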
Example #3
 def test_dict(self):
     a = {'b': np.ones(10), 'a': np.ones(20)}
     model_inputs = training_utils_v1.ModelInputs(a)
     self.assertEqual(['a', 'b'], model_inputs.get_input_names())
     vals = model_inputs.get_symbolic_inputs()
     self.assertTrue(tensor_util.is_tf_type(vals['a']))
     self.assertTrue(tensor_util.is_tf_type(vals['b']))
Example #4
 def testIndexedSlices(self):
   x = indexed_slices.IndexedSlices(
       constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30]))
   x_value = indexed_slices.IndexedSlicesValue(
       np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100]))
   self.assertTrue(tensor_util.is_tf_type(x))
   self.assertFalse(tensor_util.is_tf_type(x_value))
Example #5
 def test_single_thing(self):
     a = np.ones(10)
     model_inputs = training_utils_v1.ModelInputs(a)
     self.assertEqual(['input_1'], model_inputs.get_input_names())
     vals = model_inputs.get_symbolic_inputs()
     self.assertTrue(tensor_util.is_tf_type(vals))
     vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
     self.assertEqual(1, len(vals))
     self.assertTrue(tensor_util.is_tf_type(vals[0]))
     self.assertEqual(backend.floatx(), vals[0].dtype)
Example #6
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent."""
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            "'{}' has shape {} before the loop, but shape {} after one"
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} before the loop, which does not conform with"
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} after one iteration, which does not conform with"
            ' the shape invariant {}.'.format(
                name, exit_shape, shape_invariant))
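A minimal sketch of the shape check, assuming the same module imports and the module's `_is_subshape` helper: a loop variable that grows between iterations raises the ValueError above when no shape invariant is set.

init = ops.convert_to_tensor_v2([1, 2, 3])      # shape (3,) before the loop
exit_ = ops.convert_to_tensor_v2([1, 2, 3, 4])  # shape (4,) after one iteration
_verify_single_loop_var('x', True, init, init, exit_, None)  # raises ValueError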
Example #7
def is_tensor_list(t):
    # TODO(mdan): This is just a heuristic.
    # With TF lacking support for templated types, this is unfortunately the
    # closest we can get right now. A dedicated op ought to be possible to
    # construct.
    return (tensor_util.is_tf_type(t) and t.dtype == dtypes.variant
            and not t.shape.ndims)
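A short sketch of what the heuristic matches, assuming the module-style imports used elsewhere in these examples (`constant_op`, `list_ops`): a TensorList handle is a scalar tensor of dtype variant.

l = list_ops.tensor_list_from_tensor(
    constant_op.constant([1, 2, 3]), element_shape=[])
is_tensor_list(l)                             # True: variant dtype, ndims == 0
is_tensor_list(constant_op.constant([1, 2]))  # False: not a variant tensor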
Example #8
def assert_stmt(expression1, expression2):
  """Functional form of an assert statement.

  This follows the semantics of the Python assert statement, however the
  concrete implementations may deviate from it. See the respective
  implementation for details.

  In general, the assert statement should not be used for control flow.
  Furthermore, assertion expressions are encouraged to be free of side
  effects.

  Args:
    expression1: Any
    expression2: Callable[[], Any], returns the expression to include in the
        error message when expression1 evaluates to False. When expression1 is
        True, the result of expression2 will not be evaluated; however,
        expression2 itself may still be evaluated in some implementations.

  Returns:
    Any, implementation-dependent.

  Raises:
    ValueError: if any arguments are illegal.
  """
  if not callable(expression2):
    raise ValueError('{} must be a callable'.format(expression2))
  args, _, keywords, _ = tf_inspect.getargspec(expression2)
  if args or keywords:
    raise ValueError('{} may not have any arguments'.format(expression2))

  if tensor_util.is_tf_type(expression1):
    return _tf_assert_stmt(expression1, expression2)
  else:
    return _py_assert_stmt(expression1, expression2)
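A hedged usage sketch (assuming `constant_op` is imported as in the other examples): a Tensor condition takes the staged `_tf_assert_stmt` path, a plain Python value the `_py_assert_stmt` path; in both cases the message expression is a zero-argument callable that is only evaluated lazily.

assert_stmt(constant_op.constant(True), lambda: 'must hold')  # TF path
assert_stmt(1 + 1 == 2, lambda: 'must hold')                  # Python path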
Example #9
def string_format(
        template: str,
        inputs: typing.Union[ragged_tensor.Ragged,
                             typing.List[ragged_tensor.RaggedOrDense]],
        placeholder="{}",
        summarize=3,
        name=None):
    """Version of tf.strings.format that handles RaggedTensors."""
    if tensor_util.is_tf_type(inputs) or ragged_tensor.is_ragged(inputs):
        inputs = [inputs]

    split_template = template.split(placeholder)
    if len(inputs) != len(split_template) - 1:
        raise ValueError(
            "num placeholders in template and num inputs must match"
            ": {} vs {}".format(len(split_template) - 1, len(inputs)))

    with ops.name_scope(name, "StringFormat", [inputs]):
        output_pieces = [constant_op.constant(split_template[0])]
        for i, inp in enumerate(inputs):
            if ragged_tensor.is_ragged(inp):
                output_pieces.append(ragged_tensor_to_string(inp, summarize))
            else:
                output_pieces.append(
                    string_ops.string_format("{}", [inp],
                                             summarize=summarize))
            output_pieces.append(constant_op.constant(split_template[i + 1]))
        if len(output_pieces) == 1:
            return output_pieces[0]
        else:
            return string_ops.reduce_join(output_pieces)
Example #10
def list_pop(list_, i, opts):
  """The list pop function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will typically not be mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated, the return value
  should point to the original entity.

  Args:
    list_: An entity that supports pop semantics.
    i: Optional index to pop from. May be None.
    opts: A ListPopOpts.

  Returns:
    Tuple (out_list_, x):
      out_list_: same as list_, after the removal was performed.
      x: the removed element value.

  Raises:
    ValueError: if list_ is not of a known list-like type or the operation is
    not supported for that type.
  """
  assert isinstance(opts, ListPopOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    raise ValueError('TensorArray does not support item removal')
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_pop(list_, i, opts)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_pop(list_, i)
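A short sketch, assuming `ListPopOpts` is the module's options namedtuple (the field names are an assumption): a plain Python list takes the `_py_list_pop` path and is mutated in place, as the note above describes.

l = [1, 2, 3]
out, x = list_pop(l, None, ListPopOpts(element_dtype=None, element_shape=None))
# out is l itself, now mutated to [1, 2]; x == 3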
Example #11
def list_append(list_, x):
  """The list append function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will typically not be mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated, the return value
  should point to the original entity.

  Args:
    list_: An entity that supports append semantics.
    x: The element to append.

  Returns:
    Same as list_, after the append was performed.

  Raises:
    ValueError: if list_ is not of a known list-like type.
  """
  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_append(list_, x)
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_append(list_, x)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_append(list_, x)
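A short sketch (module imports assumed): a TensorArray takes the `_tf_tensorarray_append` path and a new handle is returned, while a plain list is mutated in place.

ta = tensor_array_ops.TensorArray(dtypes.float32, size=0, dynamic_size=True)
ta = list_append(ta, constant_op.constant(1.0))  # TensorArray path
l = list_append([1, 2], 3)                       # Python path: l == [1, 2, 3]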
Example #12
def _zeros(shape, dtype):
  """Helper to return (possibly cached) zero tensors in eager mode."""
  # Note: variants will use _zeros_like
  if dtype == dtypes.string or dtype == dtypes.resource:
    return None

  ctx = context.context()
  if not ctx.executing_eagerly():
    return array_ops.zeros(shape, dtype)

  device = ctx.device_name

  if tensor_util.is_tf_type(shape):
    shape_key = shape.ref()
  else:
    shape_key = shape
  cache_key = shape_key, dtype, device
  cached = ctx.zeros_cache().get(cache_key)
  if cached is None:
    if dtypes.as_dtype(dtype).is_bool:
      value = False
    else:
      value = 0
    cached = _fast_fill(value, shape, dtype)
    ctx.zeros_cache().put(cache_key, cached)
  return cached
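A small sketch of the caching behaviour, assuming eager execution and the module's `dtypes` import: two requests with the same (shape, dtype, device) key should come back as the same cached tensor.

a = _zeros((2, 3), dtypes.float32)
b = _zeros((2, 3), dtypes.float32)
assert a is b  # the second call is served from ctx.zeros_cache()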
Example #13
def list_stack(list_, opts):
  """The list stack function.

  This does not have a direct correspondent in Python. The closest idiom to
  this is tf.stack or np.stack. It's different from those in the sense that it
  accepts a Tensor list, rather than a list of tensors. It can also accept
  TensorArray. When the target is anything else, the dispatcher will rely on
  ctx.original_call for fallback.

  Args:
    list_: An entity that supports append semantics.
    opts: A ListStackOpts object.

  Returns:
    The output of the stack operation, typically a Tensor.
  """
  assert isinstance(opts, ListStackOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_stack(list_)
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_stack(list_, opts)
    else:
      # No-op for primitive Tensor arguments.
      return list_
  else:
    return _py_list_stack(list_, opts)
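A hedged sketch (module imports assumed; the `ListStackOpts` field names are an assumption): a variant Tensor list is stacked into a dense Tensor, while a primitive Tensor passes through unchanged.

opts = ListStackOpts(element_dtype=dtypes.int32, original_call=None)
l = list_ops.tensor_list_from_tensor(
    constant_op.constant([[1], [2]]), element_shape=[1])
list_stack(l, opts)                             # dense Tensor [[1], [2]]
list_stack(constant_op.constant([1, 2]), opts)  # no-op: returned unchanged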
Example #14
def set_item(target, i, x):
    """The slice write operator (i.e. __setitem__).

  Note: it is unspecified whether target will be mutated or not. In general,
  if target is mutable (like Python lists), it will be mutated.

  Args:
    target: An entity that supports setitem semantics.
    i: Index to modify.
    x: The new element value.

  Returns:
    Same as target, after the update was performed.

  Raises:
    ValueError: if target is not of a supported type.
  """
    if isinstance(target, tensor_array_ops.TensorArray):
        return _tf_tensorarray_set_item(target, i, x)
    elif tensor_util.is_tf_type(target):
        if target.dtype == dtypes.variant:
            return _tf_tensor_list_set_item(target, i, x)
        else:
            return _tf_tensor_set_item(target, i, x)
    else:
        return _py_set_item(target, i, x)
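A short sketch (module imports assumed): TF tensors are immutable, so the dense-Tensor path must return an updated tensor, whereas a plain list is mutated in place.

t = constant_op.constant([1, 2, 3])
t2 = set_item(t, 1, 10)         # a new Tensor equal to [1, 10, 3]
l = set_item([1, 2, 3], 1, 10)  # the same list object, now [1, 10, 3]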
Example #15
def slice_arrays(arrays, indices, contiguous=True):
    """Slices batches out of provided arrays (workaround for eager tensors).

  Unfortunately eager tensors don't have the same slicing behavior as
  Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
  hence we cannot use `generic_utils.slice_arrays` directly
  and we have to implement this workaround based on `concat`. This has a
  performance cost.

  Args:
    arrays: Single array or list of arrays.
    indices: List of indices in the array that should be included in the output
      batch.
    contiguous: Boolean flag indicating whether the indices are contiguous.

  Returns:
    Slice of data (either single array or list of arrays).
  """
    converted_to_list = False
    if not isinstance(arrays, list):
        converted_to_list = True
        arrays = [arrays]
    if any(tensor_util.is_tf_type(x) for x in arrays):
        if not contiguous:
            entries = [[x[i:i + 1] for i in indices] for x in arrays]
            slices = [array_ops.concat(x, axis=0) for x in entries]
        else:
            slices = [x[indices[0]:indices[-1] + 1] for x in arrays]
    else:
        slices = generic_utils.slice_arrays(arrays, indices)

    if converted_to_list:
        slices = slices[0]
    return slices
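A short sketch (module imports assumed): eager tensors reject a list of fancy indices, so non-contiguous batches are assembled by concatenating single-row slices, exactly as the branch above does.

x = constant_op.constant([[0.], [1.], [2.], [3.]])
slice_arrays(x, [1, 2])                    # one contiguous slice: rows 1..2
slice_arrays(x, [0, 3], contiguous=False)  # concat of x[0:1] and x[3:4]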
Example #16
 def print_wrapper(*vals):
   vals = tuple(v.numpy() if tensor_util.is_tf_type(v) else v for v in vals)
   # TensorFlow doesn't seem to generate Unicode when passing strings to
   # py_func. This causes the print to add a "b'" wrapper to the output,
   # which is probably never what you want.
   vals = tuple(v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
   six.print_(*vals, **override_kwargs)
Example #17
    def _fetch_preprocessing_callback(fetch):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a Tensor.
      """
      if isinstance(fetch, ops.Operation):
        operation_fetches.append(fetch)
        return fetch
      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
        tensor_infos.append(fetch)
        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
        if (tensor_util.is_tf_type(decoded) or
            isinstance(decoded, composite_tensor.CompositeTensor)):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
        tensor_fetches.append(fetch)
        return fetch
      else:
        graph_element = self.graph.as_graph_element(fetch)
        return _fetch_preprocessing_callback(graph_element)
Example #18
def get_item(target, i, opts):
    """The slice read operator (i.e. __getitem__).

  Note: it is unspecified whether target will be mutated or not. In general,
  if target is mutable (like Python lists), it will be mutated.

  Args:
    target: An entity that supports getitem semantics.
    i: Index to read from.
    opts: A GetItemOpts object.

  Returns:
    The read element.

  Raises:
    ValueError: if target is not of a supported type.
  """
    assert isinstance(opts, GetItemOpts)

    if isinstance(target, tensor_array_ops.TensorArray):
        return _tf_tensorarray_get_item(target, i)
    elif tensor_util.is_tf_type(target):
        if target.dtype == dtypes.variant:
            return _tf_tensor_list_get_item(target, i, opts)
        elif target.dtype == dtypes.string and target.shape.ndims == 0:
            return _tf_tensor_string_get_item(target, i)
        else:
            return _tf_tensor_get_item(target, i)
    else:
        return _py_get_item(target, i)
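A short sketch (module imports assumed; the `GetItemOpts` field name is an assumption): dispatch depends on the target's type, dtype, and rank.

opts = GetItemOpts(element_dtype=None)
get_item(constant_op.constant([10, 20, 30]), 1, opts)  # dense path: 20
get_item(constant_op.constant('abc'), 0, opts)         # scalar-string path: b'a'
get_item([10, 20, 30], 1, opts)                        # plain Python path: 20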
Example #19
    def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None):
        call_args = [] if call_args is None else call_args
        call_kwargs = {} if call_kwargs is None else call_kwargs
        outputs = [] if outputs is None else outputs

        self.layer = layer
        self.is_input = not call_args and not call_kwargs

        # These arguments are user-provided. Copy the structures here so that
        # future user modifications do not affect the node's metadata.
        # We copy using map_structure rather than python's shallow or deep copy,
        # because the args can be data structures (so shallow copy is
        # insufficient), but individual values might not support copy.copy
        # or be too expensive to deep copy.
        call_args = nest.map_structure(lambda t: t, call_args)
        call_kwargs = nest.map_structure(lambda t: t, call_kwargs)
        self.outputs = nest.map_structure(lambda t: t, outputs)
        self.call_args = call_args
        self.call_kwargs = call_kwargs

        # Cached for performance.
        self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs))
        # Used to avoid expensive `nest` operations in the most common case.
        self._single_positional_tensor_passed = (not self.call_kwargs and len(
            self.call_args) == 1 and tensor_util.is_tf_type(self.call_args[0]))

        if not ops.executing_eagerly_outside_functions():
            # Create TensorFlowOpLayers if needed (in TF1)
            for obj in self._flat_arguments:
                if (isinstance(obj, ops.Tensor)
                        and base_layer_utils.needs_keras_history(
                            obj, ignore_call_context=True)):
                    base_layer_utils.create_keras_history(obj)

        self._keras_inputs = []
        self._keras_inputs_ids_and_indices = []
        for i, ele in enumerate(self._flat_arguments):
            if is_keras_tensor(ele):
                self._keras_inputs.append(ele)
                kt_id = str(id(ele))
                kt_index = i
                self._keras_inputs_ids_and_indices.append((kt_id, kt_index))

        # Wire up Node to Layers.
        self.layer._inbound_nodes.append(self)
        for kt in self.keras_inputs:
            inbound_layer = kt._keras_history.layer
            if inbound_layer is not None:  # `None` for `Input` tensors.
                inbound_layer._outbound_nodes.append(self)

        # Set metadata on outputs.
        node_index = len(self.layer._inbound_nodes) - 1
        for i, tensor in enumerate(nest.flatten(outputs)):
            tensor._keras_history = KerasHistory(layer=layer,
                                                 node_index=node_index,
                                                 tensor_index=i)

        # Cached for performance.
        self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
        self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]
Example #20
def get_tensor_from_node(node):
    """Resolves a saved model graph node into a tensor to be captured.

  Args:
    node: a tensor, variable, or resource to be resolved into a capturable
      tensor

  Returns:
    The tensor (or capturable handle, e.g. a variable's handle) that `node`
    resolves to.

  Raises:
    ValueError: if the node cannot be converted into a tensor.
  """
    with ops.init_scope():
        # TODO(b/210144904): Use __tf_tensor__ instead of `is_[...]` checks
        if getattr(node, "is_distributed_variable", False):
            return node
        elif getattr(node, "is_distributed_table", False):
            return node
        elif getattr(node, "is_sharded_variable", False):
            return node
        elif resource_variable_ops.is_resource_variable(node):
            return node.handle
        elif isinstance(node, tracking.Asset):
            return node.asset_path
        elif tensor_util.is_tf_type(node):
            return node
        elif isinstance(node, tracking.CapturableResource):
            # Note: this executes restored functions in the CapturableResource.
            return node.resource_handle
        raise ValueError(f"Cannot convert node {node} to tensor.")
Example #21
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
    r"""Formats a string template using a list of tensors.

  Formats a string template using a list of tensors, abbreviating tensors by
  only printing the first and last `summarize` elements of each dimension
  (recursively). If formatting only one tensor into a template, the tensor does
  not have to be wrapped in a list.

  Example:
    Formatting a single-tensor template:

    >>> tensor = tf.range(5)
    >>> tf.strings.format("tensor: {}, suffix", tensor)
    <tf.Tensor: shape=(), dtype=string, numpy=b'tensor: [0 1 2 3 4], suffix'>

    Formatting a multi-tensor template:

    >>> tensor_a = tf.range(2)
    >>> tensor_b = tf.range(1, 4, 2)
    >>> tf.strings.format("a: {}, b: {}, suffix", (tensor_a, tensor_b))
    <tf.Tensor: shape=(), dtype=string, numpy=b'a: [0 1], b: [1 3], suffix'>


  Args:
    template: A string template to format tensor values into.
    inputs: A list of `Tensor` objects, or a single Tensor.
      The list of tensors to format into the template string. If a solitary
      tensor is passed in, the input tensor will automatically be wrapped as a
      list.
    placeholder: An optional `string`. Defaults to `{}`.
      At each placeholder occurring in the template, a subsequent tensor
      will be inserted.
    summarize: An optional `int`. Defaults to `3`.
      When formatting the tensors, show the first and last `summarize`
      entries of each tensor dimension (recursively). If set to -1, all
      elements of the tensor will be shown.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`.

  Raises:
    ValueError: if the number of placeholders does not match the number of
      inputs.
  """
    # If there is only one tensor to format, we will automatically wrap it in a
    # list to simplify the user experience
    if tensor_util.is_tf_type(inputs):
        inputs = [inputs]
    if template.count(placeholder) != len(inputs):
        raise ValueError(
            "%s placeholder(s) in template does not match %s tensor(s)"
            " provided as input" % (template.count(placeholder), len(inputs)))

    return gen_string_ops.string_format(inputs,
                                        template=template,
                                        placeholder=placeholder,
                                        summarize=summarize,
                                        name=name)
Example #22
 def test_match_staging_level(self):
     some_tensor = constant_op.constant(0)
     tensor_one = special_functions.match_staging_level(1, some_tensor)
     python_one = special_functions.match_staging_level(1, 1)
     with self.cached_session() as sess:
         self.assertTrue(tensor_util.is_tf_type(tensor_one))
         self.assertAllEqual(self.evaluate(tensor_one), 1)
         self.assertEqual(python_one, 1)
Example #23
 def _find_any_tensor(batch_features):
     tensors = [
         x for x in nest.flatten(batch_features)
         if tensor_util.is_tf_type(x)
     ]
     if not tensors:
         raise ValueError('Cannot find any Tensor in features dict.')
     return tensors[0]
Example #24
 def _preprocess_inputs(self, inputs):
     if isinstance(inputs, (tuple, list)):
         # If any of them is tensor or ndarray, then treat as list
         if any(
                 tensor_util.is_tf_type(inp) or isinstance(inp, np.ndarray)
                 for inp in inputs):
             return [self._preprocess_single_input(inp) for inp in inputs]
     return self._preprocess_single_input(inputs)
Example #25
    def _wrap_and_check_metrics(self, metrics):
        """Handle the saving of metrics.

    Metrics is either a tuple of (value, update_op), or a dict of such tuples.
    Here, we separate out the tuples and create a dict with names to tensors.

    Args:
      metrics: Dict of metric results keyed by name.
        The values of the dict can be one of the following:
        (1) instance of `Metric` class.
        (2) (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Returns:
      dict of output_names to tensors

    Raises:
      ValueError: if the dict key is not a string, or the metric values or ops
        are not tensors.
    """
        if not isinstance(metrics, dict):
            metrics = {self.METRICS_NAME: metrics}

        outputs = {}
        for key, value in metrics.items():
            if isinstance(value, tuple):
                metric_val, metric_op = value
            else:  # value is a keras.Metrics object
                metric_val = value.result()
                assert len(value.updates) == 1  # We expect only one update op.
                metric_op = value.updates[0]
            key = self._check_output_key(key, self.METRICS_NAME)
            key = self._prefix_key(key, self.METRICS_NAME)

            val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX
            op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX
            if not isinstance(metric_val, ops.Tensor):
                raise ValueError(
                    '{} output value must be a Tensor; got {}.'.format(
                        key, metric_val))
            if not (tensor_util.is_tf_type(metric_op)
                    or isinstance(metric_op, ops.Operation)):
                raise ValueError(
                    '{} update_op must be a Tensor or Operation; got {}.'.
                    format(key, metric_op))

            # We must wrap any ops (or variables) in a Tensor before export, as the
            # SignatureDef proto expects tensors only. See b/109740581
            metric_op_tensor = metric_op
            if not isinstance(metric_op, ops.Tensor):
                with ops.control_dependencies([metric_op]):
                    metric_op_tensor = constant_op.constant(
                        [], name='metric_op_wrapper')

            outputs[val_name] = metric_val
            outputs[op_name] = metric_op_tensor

        return outputs
Example #26
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation."""
  if tensor_util.is_tf_type(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    element_shape = array_ops.shape(elements)[1:]
    l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
    return l

  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype = tuple(all_dtypes)[0]
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif all_dtypes:
    # Heterogeneous lists are ok.
    if element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant
  else:
    inferred_dtype = dtypes.variant

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape = array_ops.shape(elements[0])
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  elif all_shapes:
    # Heterogeneous lists are ok.
    if element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention
  else:
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l
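A short sketch (module imports assumed): a homogeneous Python list gets its dtype and shape inferred, while a dense Tensor argument takes the `tensor_list_from_tensor` shortcut at the top.

l1 = tf_tensor_list_new([1, 2, 3])  # element_dtype=int32, scalar shape inferred
l2 = tf_tensor_list_new(constant_op.constant([[1], [2]]))  # from a dense Tensor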
Example #27
def len_(s):
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  elif tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  elif tensor_util.is_tf_type(s):
    return _tf_tensor_len(s)
  if isinstance(s, dataset_ops.DatasetV2):
    return _tf_dataset_len(s)
  return _py_len(s)
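A short sketch (module imports assumed): dispatch mirrors the checks above.

len_([1, 2, 3])                       # plain Python len(): 3
len_(constant_op.constant([[1, 2]]))  # _tf_tensor_len: leading dim, as a Tensor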
Example #28
def _placeholder_value(like, original=None):
    if isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
        return original
    if isinstance(like, (int, float, bool)):
        return type(like)(0)
    if tensor_util.is_tf_type(like):
        return array_ops.zeros(like.shape, like.dtype)
    elif isinstance(like, (list, tuple, dict)):
        return nest.map_structure(_placeholder_value, like)
    return original
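A short sketch (module imports assumed): placeholders mirror the structure and dtype of `like`.

_placeholder_value(3)            # -> 0
_placeholder_value([True, 2.5])  # -> [False, 0.0] via nest.map_structure
_placeholder_value(constant_op.constant([1., 2.]))  # -> float32 zeros, shape (2,)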
Example #29
    def __init__(self,
                 dtype,
                 reparameterization_type,
                 validate_args,
                 allow_nan_stats,
                 parameters=None,
                 graph_parents=None,
                 name=None):
        """Constructs the `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        If `distributions.FULLY_REPARAMETERIZED`, this
        `Distribution` can be reparameterized in terms of some standard
        distribution with a function whose Jacobian is constant for the support
        of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
        then no such reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name.

    Raises:
      ValueError: if any member of graph_parents is `None` or not a `Tensor`.
    """
        graph_parents = [] if graph_parents is None else graph_parents
        for i, t in enumerate(graph_parents):
            if t is None or not tensor_util.is_tf_type(t):
                raise ValueError("Graph parent item %d is not a Tensor; %s." %
                                 (i, t))
        if not name or name[-1] != "/":  # `name` is not a name scope
            non_unique_name = name or type(self).__name__
            with ops.name_scope(non_unique_name) as name:
                pass
        self._dtype = dtype
        self._reparameterization_type = reparameterization_type
        self._allow_nan_stats = allow_nan_stats
        self._validate_args = validate_args
        self._parameters = parameters or {}
        self._graph_parents = graph_parents
        self._name = name
Example #30
def IsTrainable(tensor_or_dtype):
    """Determines whether a tensor or dtype supports infinitesimal changes."""
    if tensor_util.is_tf_type(tensor_or_dtype):
        dtype = _DTypeFromTensor(tensor_or_dtype)
    else:
        dtype = tensor_or_dtype
    dtype = dtypes.as_dtype(dtype)
    return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
                                dtypes.complex64, dtypes.complex128,
                                dtypes.resource, dtypes.variant,
                                dtypes.bfloat16)
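A short sketch (module imports assumed): floating and complex dtypes count as trainable, integer dtypes do not, and a Tensor argument has its dtype extracted first.

IsTrainable(dtypes.float32)             # True
IsTrainable(dtypes.int32)               # False
IsTrainable(constant_op.constant(1.0))  # True: dtype taken from the tensor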