Пример #1
0
def _check_same_dtype_and_shape(tensor, tensor_info, name):
  """Checks `tensor` against the dtype and shape recorded in `tensor_info`.

  The dtype must match exactly, while the shape only has to be compatible
  (`TensorShape.is_compatible_with`), not identical.

  Args:
    tensor: the `Tensor` to validate.
    tensor_info: the `TensorInfo` proto describing the expected signature.
    name: input name, used only to identify the tensor in the error message.

  Raises:
    ValueError: if the dtype differs or the shape is incompatible. The message
      lists every mismatch that was detected.
  """
  expected_dtype = dtypes.DType(tensor_info.dtype)
  dtype_mismatch = tensor.dtype != expected_dtype
  shape_mismatch = not tensor.shape.is_compatible_with(
      tensor_info.tensor_shape)
  if not (dtype_mismatch or shape_mismatch):
    return

  parts = ['Tensor shape and/or dtype validation failed for input %s:' % name]
  if dtype_mismatch:
    parts.append('\n\tExpected dtype: %s, Got: %s'
                 % (expected_dtype, tensor.dtype))
  if shape_mismatch:
    parts.append('\n\tExpected shape: %s, Got: %s'
                 % (tensor_shape.TensorShape(tensor_info.tensor_shape),
                    tensor.shape))
  raise ValueError(''.join(parts))
  def __init__(self, method_name='runTest'):
    """Sets up per-device dtype lists and the optional disabled-test regex.

    Args:
      method_name: name of the test method to run; forwarded to the parent
        test-case constructor.
    """
    super(XLATestCase, self).__init__(method_name)
    self.device = FLAGS.test_device
    # Only the 'XLA_CPU' device is treated as supporting custom calls.
    self.has_custom_call = (self.device == 'XLA_CPU')
    # FLAGS.types is a comma-separated list of DataType enum names.
    self.all_tf_types = [
        dtypes.DType(types_pb2.DataType.Value(name))
        for name in FLAGS.types.split(',')
    ]
    self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
    self.int_types = [
        dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer
    ]
    self.float_types = [
        dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating
    ]
    self.numeric_types = self.int_types + self.float_types

    # Parse the manifest file, if any, into a regex identifying tests to
    # disable. Each non-empty line (after stripping '#' comments) becomes one
    # alternative of the regex.
    self.disabled_regex = None
    if FLAGS.disabled_manifest is not None:
      comments_re = re.compile('#.*$')
      with open(FLAGS.disabled_manifest, 'r') as manifest_file:
        lines = manifest_file.read().splitlines()
      patterns = [comments_re.sub('', line).strip() for line in lines]
      # Blank and comment-only lines must be dropped: an empty alternative in
      # the joined regex matches every test name.
      patterns = [pattern for pattern in patterns if pattern]
      if patterns:
        self.disabled_regex = re.compile('|'.join(patterns))
Пример #3
0
    def __call__(self, *args):
        """Calls the wrapped function on the Tensors found in `args`.

        If the tape is recording any of those Tensors or the extra inputs, the
        call is delegated to `_backprop_call`. Otherwise the function runs
        through `execute.execute` outside graph mode, or becomes a
        "FunctionCall" op in the default graph in graph mode.
        """
        flat_tensors = [
            arg for arg in nest.flatten(args) if isinstance(arg, ops.Tensor)
        ]
        needs_gradient = (tape.should_record(flat_tensors)
                          or tape.should_record(self._extra_inputs))
        if needs_gradient:
            if not self._has_backprop:
                self._compute_backprop()
            return self._backprop_call(flat_tensors)

        if not context.in_graph_mode():
            outputs = execute.execute(str(self._func_name),
                                      num_outputs=self._num_outputs,
                                      inputs=flat_tensors + self._extra_inputs)
            return self._build_call_outputs(self._returns, outputs)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        if self._fdef.name not in graph._functions:
            graph._add_function(self._fdef)
        # pylint: enable=protected-access
        signature = self._fdef.definition.signature
        call_inputs = list(flat_tensors) + self._extra_inputs
        op = graph.create_op(
            signature.name,
            [ops.convert_to_tensor(value) for value in call_inputs],
            [dtypes.DType(arg.type) for arg in signature.output_arg],
            op_def=signature,
            name="FunctionCall",
            compute_shapes=False)
        outputs = op.outputs
        # Shape inference is disabled on the call op, so restore known shapes.
        for index, shape in enumerate(self._output_shapes):
            outputs[index].set_shape(shape)
        return self._build_call_outputs(self._returns, outputs)
Пример #4
0
def dense_shape_and_type(matrix):
    """Get dense shape and dtype of the tf.Tensor containing the matrix.

    Args:
      matrix: A `tf.Tensor` of type `tf.variant` storing a sparse matrix.

    Returns:
      An instance of `ShapeAndType` with properties `shape` (a `tf.TensorShape`)
      and `dtype` (a `tf.DType`).

    Raises:
      TypeError: if `matrix` is not a tensor or its dtype is not variant.
      ValueError: if `matrix` lacks static handle data containing the dense
        shape and dtype, or that data has other than exactly one entry.
    """
    if not isinstance(matrix, ops.Tensor):
        raise TypeError("matrix should be a tensor, but saw: %s" % (matrix, ))
    if matrix.dtype != dtypes.variant:
        raise TypeError("expected matrix to be type tf.variant, but saw: %s" %
                        (matrix.dtype, ))

    handle_data = _get_handle_data(matrix)
    if not (handle_data and handle_data.is_set):
        raise ValueError("matrix has missing handle data: %s" % (matrix, ))

    entries = handle_data.shape_and_type
    if len(entries) != 1:
        raise ValueError("len(matrix.handle_data.shape_and_type) != 1: '%s'" %
                         (entries, ))
    (entry, ) = entries
    dense_shape = tensor_shape.TensorShape(entry.shape)
    dense_dtype = dtypes.DType(entry.dtype)
    return DenseShapeAndType(dense_shape, dense_dtype)
Пример #5
0
 def _instantiate(self, input_types):
     """Returns the `_DefinedFunction` specialized for `input_types`.

     Specializations are cached in `self._overload`, keyed by the string form
     of the type list. A new one is built only when no truthy cached entry
     exists for that key.
     """
     key = _type_list_to_str(input_types)
     cached = self._overload.get(key)
     if cached:
         return cached

     func_name = self._func_name
     if func_name is not None:
         func_name = "_".join([func_name, key])
     specialized = _DefinedFunction(self._func, self._argnames, input_types,
                                    func_name, None, self._python_grad_func,
                                    **self._extra_kwargs)
     if self._grad_func:
         # The gradient is itself an _OverloadedFunction; specialize it for
         # the forward input types followed by the forward output types.
         output_types = [
             dtypes.DType(arg.type)
             for arg in specialized.definition.signature.output_arg
         ]
         # pylint: disable=protected-access
         specialized._grad_func = self._grad_func._instantiate(
             input_types + output_types)
         # pylint: enable=protected-access
     self._overload[key] = specialized
     return specialized
Пример #6
0
 def testReduce(self):
     """Checks that `DType.__reduce__` rebuilds an equal dtype from its name."""
     for enum in dtypes._TYPE_TO_STRING:  # pylint: disable=protected-access
         dtype = dtypes.DType(enum)
         ctor, args = dtype.__reduce__()
         # assertEqual, not the deprecated alias assertEquals, which was
         # removed from unittest in Python 3.12.
         self.assertEqual(ctor, dtypes.as_dtype)
         self.assertEqual(args, (dtype.name, ))
         reconstructed = ctor(*args)
         self.assertEqual(reconstructed, dtype)
Пример #7
0
def call_function(func_def, *inputs, **kwargs):
    """Calls the function described by `func_def`.

    This adds a `call` op to the default graph that calls the function
    described by `func_def` with the tensors listed in `inputs` as arguments.

    `func_def` is a
    [`FunctionDef`](
    https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/function.proto)
    protocol buffer describing a TensorFlow function.  See
    [`define_function()`](#define_function) for an easy way to create one from
    a Python function.

    You can pass an optional keyword parameter `name=string` to name the
    added operation.

    `func_def` is passed to the default graph's `_add_function` on every call,
    so it is registered in the graph's function library.

    Args:
      func_def: A `FunctionDef` protocol buffer.
      *inputs: The argument tensors, exactly one per
        `func_def.signature.input_arg`.
      **kwargs: Optional keyword arguments.  Can only contain 'name'.

    Returns:
      The single output `Tensor` if the function has one output, a tuple of
      `Tensor`s if it has several, or the created `Operation` if it has none.

    Raises:
      ValueError: if an unknown keyword argument is passed or the number of
        inputs does not match the signature.
    """
    name = kwargs.pop("name", None)
    if kwargs:
        raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
    signature = func_def.signature
    func_name = signature.name
    with ops.op_scope(inputs, name, func_name) as name:
        expected_count = len(signature.input_arg)
        if len(inputs) != expected_count:
            # Report both counts so the mismatch is easy to diagnose.
            raise ValueError("Expected number of arguments: %d, received: %d" %
                             (expected_count, len(inputs)))
        output_types = [dtypes.DType(x.type) for x in signature.output_arg]
        g = ops.get_default_graph()
        g._add_function(func_def)  # pylint: disable=protected-access
        # TODO(touts): Pass compute_shapes as "try if function exists"
        op = g.create_op(func_name,
                         list(inputs),
                         output_types,
                         name=name,
                         compute_shapes=False)
        outputs = op.outputs
        if not outputs:
            return op
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)
Пример #8
0
def _call(sig, *inputs, **kwargs):
  """Adds a node to the default graph that calls the function `sig` describes.

  Args:
    sig: the function's signature. It is passed to `create_op` as `op_def` and
      must provide `name`, `input_arg` and `output_arg`.
    *inputs: the arguments to the function, exactly one per `sig.input_arg`.
    **kwargs: Optional keyword arguments. `name` names the created op
      (defaulting to `sig.name`). Any remaining keywords (documented as
      `noinline`, which asks the runtime not to inline the function body into
      the call site) are converted by `_parse_kwargs_as_attrs`.

  Returns:
    A 2-element tuple `(ret, op)`. `ret` is the single output `Tensor` if the
    function has one output, a tuple of `Tensor`s if it has several, or `op`
    itself if it has none. `op` is the created `Operation`.

  Raises:
    ValueError: if the number of inputs does not match `sig.input_arg`.
  """
  expected_count = len(sig.input_arg)
  if len(inputs) != expected_count:
    raise ValueError("Expected number of arguments: %d, received: %d" %
                     (expected_count, len(inputs)))
  func_name = sig.name
  op_name = kwargs.pop("name", None)
  if op_name is None:
    op_name = func_name
  graph = ops.get_default_graph()
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(arg.type) for arg in sig.output_arg]
  op = graph.create_op(
      func_name,
      list(inputs),
      output_types,
      name=op_name,
      attrs=attrs,
      op_def=sig,
      compute_shapes=False)
  outputs = op.outputs
  if not outputs:
    ret = op
  elif len(outputs) == 1:
    ret = outputs[0]
  else:
    ret = tuple(outputs)
  return ret, op
Пример #9
0
 def testRepr(self):
     """Checks that `repr(dtype)` is `tf.<name>` and evals back to the dtype."""
     # Imported once, before the loop, so that `tf` is a local name visible
     # to the `eval` below.
     import tensorflow as tf  # pylint: disable=g-import-not-at-top,unused-variable
     for enum, name in dtypes._TYPE_TO_STRING.items():  # pylint: disable=protected-access
         if enum > 100:
             # NOTE(review): presumably skips the reference ("_ref") dtypes,
             # whose enums are offset by 100 - confirm.
             continue
         dtype = dtypes.DType(enum)
         # assertEqual, not the deprecated alias assertEquals, which was
         # removed from unittest in Python 3.12.
         self.assertEqual(repr(dtype), "tf." + name)
         dtype2 = eval(repr(dtype))  # pylint: disable=eval-used
         self.assertIs(type(dtype2), dtypes.DType)
         self.assertEqual(dtype, dtype2)
Пример #10
0
    def _add_eightbit_prologue_nodes(self, original_node):
        """Quantizes every data input of `original_node`.

        Args:
            original_node: name of the node being quantized. It is used as a
                key into `self.node_name_mapping` and as the prefix of the
                newly created nodes.

        Returns:
            The input list for the quantized node: the quantized data inputs,
            then the min and max tensor names (min, max per data input, in
            input order), then the original control inputs (names starting
            with '^').
        """
        namespace_prefix = original_node + "_eightbit"
        node_def = self.node_name_mapping[original_node].node
        reshape_dims_name, reduction_dims_name = self._add_common_quantization_nodes(
            namespace_prefix, helper.node_name_from_input(node_def.input[0]))
        input_names = []
        min_max_names = []
        for each_input_name in node_def.input:
            if each_input_name.startswith('^'):
                # Control inputs are not quantized; they are appended at the end.
                continue
            input_node_name = helper.node_name_from_input(each_input_name)
            if (self.intel_cpu_eightbitize
                    and input_node_name in self.output_node_maps):
                producer = self.output_node_maps[input_node_name]
                if producer.op == "Dequantize":
                    # Use the producing Dequantize node's `T` attribute as the
                    # quantized input dtype.
                    dtype = dtypes.DType(producer.attr["T"].type)
                elif self._find_relu_node(node_def):
                    dtype = dtypes.quint8
                elif node_def.op == "MatMul":
                    # mkl ops _MklQuantizedMatMulWithBiasAndRelu|AndRequantize
                    # requires the T1 data type as quint8
                    dtype = dtypes.quint8
                else:
                    dtype = dtypes.qint8
            else:
                dtype = (dtypes.quint8 if self._find_relu_node(node_def)
                         else dtypes.qint8)

            quantize_input_name, min_input_name, max_input_name = (
                self._eightbitize_input_to_node(namespace_prefix,
                                                each_input_name,
                                                reshape_dims_name,
                                                reduction_dims_name,
                                                dtype=dtype))
            input_names.append(quantize_input_name)
            min_max_names.append(min_input_name)
            min_max_names.append(max_input_name)

        all_input_names = input_names + min_max_names
        # Control dependencies are carried over unchanged after the data inputs.
        all_input_names.extend(
            input_name for input_name in node_def.input
            if input_name.startswith('^'))
        return all_input_names
def _get_types_from_tensor_info_dict(tensor_info_dict):
  """Maps each key of `tensor_info_dict` to the `DType` of its TensorInfo.

  Args:
    tensor_info_dict: mapping whose values are `TensorInfo` protos.

  Returns:
    A new dict with the same keys, in the same iteration order, whose values
    are `DType` objects built from each proto's `dtype` field.
  """
  dtype_map = {}
  for key, tensor_info in tensor_info_dict.items():
    dtype_map[key] = dtypes.DType(tensor_info.dtype)
  return dtype_map
Пример #12
0
    def _backprop_call(self, args):
        """Calls the wrapped function and records the result on a tape.

        Args:
            args: list of Tensor inputs; `self._extra_inputs` are appended to
                it for the call.

        Returns:
            The call outputs, structured by `self._build_call_outputs`.
        """
        all_args = args + self._extra_inputs
        signature = self._forward_fdef.definition.signature
        if context.in_graph_mode():
            g = ops.get_default_graph()
            g._add_function(self._forward_fdef)  # pylint: disable=protected-access
            unwrapped_args = [ag_core.getval(x) for x in all_args]
            op = g.create_op(
                signature.name,
                [ops.convert_to_tensor(x) for x in unwrapped_args],
                [dtypes.DType(x.type) for x in signature.output_arg],
                op_def=signature,
                name="FunctionCall",
                compute_shapes=False)
            outputs = op.outputs
            outputs = [outputs] if isinstance(outputs,
                                              (tensor.Tensor, ops.Tensor,
                                               type(None))) else list(outputs)
            for i, s in enumerate(self._output_shapes):
                outputs[i].set_shape(s)
        else:
            outputs = execute.execute(str(signature.name),
                                      num_outputs=len(signature.output_arg),
                                      inputs=all_args)
        real_outputs = outputs[:len(self._returns)]
        side_outputs = outputs[len(self._returns):]
        watched_extra_inputs = []
        for t in self._extra_inputs:
            tid = ops.tensor_id(t)
            # Prefer the version of `t` watched by an active tape. The inner
            # loop must not rebind `t`: the for-else below appends `t` itself
            # when no tape has it.
            for tape_entry in tape._tape_stack.stack:  # pylint: disable=protected-access
                w = tape_entry.value.tensors.get(tid, None)
                if w is not None:
                    watched_extra_inputs.append(w)
                    break
            else:  # Note: for-else here done on purpose
                watched_extra_inputs.append(t)

        def backward_function_wrapper(*outputs):
            # `real_outputs` is looked up when this runs, i.e. after it has
            # been rebound to the result of `tape.record_operation` below.
            outputs = outputs[len(real_outputs):]
            return self._backward_function(*outputs)

        real_outputs = tape.record_operation(real_outputs,
                                             (args + watched_extra_inputs),
                                             side_outputs,
                                             backward_function_wrapper)

        return self._build_call_outputs(self._returns, real_outputs)
    def _readAndCheckGraphExecutionTracesFile(self, context_ids):
        """Read & verify the content of the .graph_execution_trace debug-event file.

        Args:
          context_ids: Op-creation context IDs from _readAndCheckGraphsFile().

        Returns:
          op_names: Names of the ops that are executed, as a `list` of `str`s.
          device_names: Names of the devices that the ops belong to,
            respectively. A `list` of `str`s of the same length as `op_name`s.
          output_slots: Output slots, as a `list` of `int`s, with the same
            length as `op_names`. In other words, for an executed op with N
            output tensors, there will be N entries in this `list` and in
            `op_names`, at corresponding indices.
          tensor_values: Tensor values or their concise summaries, depending
            on TensorDebugMode. `None` for dtypes that are not converted.
        """
        op_names = []
        device_names = []
        output_slots = []
        tensor_values = []
        with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
            for debug_event in reader.graph_execution_traces_iterator():
                self.assertGreaterEqual(debug_event.wall_time, 0)
                trace = debug_event.graph_execution_trace
                op_names.append(trace.op_name)
                self.assertTrue(trace.device_name)
                device_names.append(trace.device_name)
                self.assertTrue(trace.tfdbg_context_id)
                self.assertIn(trace.tfdbg_context_id, context_ids)
                # All the ops in the graph have only one output.
                output_slots.append(trace.output_slot)

                dtype = dtypes.DType(trace.tensor_proto.dtype)
                # TODO(cais): Figure out how to properly convert string tensor
                # proto to numpy representation.
                convertible = (dtype.is_numpy_compatible
                               and dtype.as_datatype_enum != types_pb2.DT_STRING)
                if convertible:
                    tensor_values.append(
                        tensor_util.MakeNdarray(trace.tensor_proto))
                else:
                    tensor_values.append(None)
        return op_names, device_names, output_slots, tensor_values
Пример #14
0
    def __call__(self, *args):
        """Calls the wrapped function on the Tensors found in `args`.

        Trainable captured variables are watched on the tape first. If the
        tape records any input, the call goes through `_backprop_call`.
        Otherwise the function is executed directly, or, in graph mode, added
        to the default graph as a "FunctionCall" op.
        """
        for variable in self._variables:
            if variable._trainable:  # pylint: disable=protected-access
                tape.watch_variable(variable)

        tensor_inputs = [
            arg for arg in nest.flatten(args) if isinstance(arg, ops.Tensor)
        ]
        if (tape.should_record(tensor_inputs)
                or tape.should_record(self._extra_inputs)):
            if not self._has_backprop:
                self._compute_backprop()
            return self._backprop_call(tensor_inputs)

        ctx = context.context()
        if not ctx.in_graph_mode():
            outputs = execute.execute(str(self._func_name),
                                      num_outputs=self._num_outputs,
                                      inputs=tensor_inputs + self._extra_inputs,
                                      attrs=None,
                                      ctx=ctx)
            return self._build_call_outputs(outputs)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        if self._function_def.name not in graph._functions:
            graph._add_function(self._function_def)
        # Also register every function from `self._graph`'s library that the
        # default graph does not have yet.
        for library_fn in self._graph._functions.values():
            if library_fn.name not in graph._functions:
                graph._add_function(library_fn)
        # pylint: enable=protected-access
        signature = self._function_def.definition.signature
        call_args = list(tensor_inputs) + self._extra_inputs
        op = graph.create_op(
            signature.name,
            [ops.internal_convert_to_tensor(arg, ctx=ctx) for arg in call_args],
            tuple(
                dtypes_module.DType(out.type) for out in signature.output_arg),
            op_def=signature,
            name="FunctionCall",
            compute_shapes=False)
        outputs = op.outputs
        if not outputs:
            return op
        for index, shape in enumerate(self._output_shapes):
            outputs[index].set_shape(shape)
        return self._build_call_outputs(outputs)
Пример #15
0
def _call_function(sig, *inputs, **kwargs):
    """Adds a node calling a function.

    Args:
      sig: The signature of the function.
      *inputs: arguments to the function.
      **kwargs: Optional keyword arguments.  Can only contain 'name' or
          'noinline'.

    Returns:
       A Tensor if the function returns a single value; a tuple of Tensors
       if the function returns multiple values; the Operation if the function
       returns no values.

    Raises:
      ValueError: if the arguments are invalid.
    """
    name = kwargs.pop("name", None)
    noinline = kwargs.pop("noinline", None)
    # `_noinline` asks the runtime not to inline the function body.
    attrs = (None if noinline is None else
             {"_noinline": attr_value_pb2.AttrValue(b=bool(noinline))})
    if kwargs:
        raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())

    graph = ops.get_default_graph()
    func_name = sig.name
    output_types = [dtypes.DType(arg.type) for arg in sig.output_arg]
    with ops.name_scope(name, func_name, inputs) as scope_name:
        op = graph.create_op(func_name,
                             list(inputs),
                             output_types,
                             name=scope_name,
                             attrs=attrs,
                             compute_shapes=False)
    # Attach the signature to the created op as `_sig`.
    op._sig = sig  # pylint: disable=protected-access

    outputs = op.outputs
    if not outputs:
        return op
    return outputs[0] if len(outputs) == 1 else tuple(outputs)
Пример #16
0
    def _backprop_call(self, args):
        """Runs the forward function and records the call on the tape.

        Args:
            args: list of Tensor inputs; `self._extra_inputs` are appended to
                it for the call.

        Returns:
            The real (non-side) outputs, structured by
            `self._build_call_outputs`.
        """
        all_args = args + self._extra_inputs
        signature = self._forward_fdef.signature
        ctx = context.context()
        if not ctx.in_graph_mode():
            outputs = execute.execute(str(signature.name),
                                      num_outputs=len(signature.output_arg),
                                      inputs=all_args,
                                      attrs=None,
                                      ctx=ctx)
        else:
            graph = ops.get_default_graph()
            graph._add_function(self._forward_fdef)  # pylint: disable=protected-access

            def to_tensor(value):
                # Tensors pass through untouched; anything else is converted.
                if isinstance(value, ops.Tensor):
                    return value
                return ops.internal_convert_to_tensor(value, ctx=ctx)

            op = graph.create_op(
                signature.name,
                [to_tensor(value) for value in all_args],
                tuple(dtypes_module.DType(out.type)
                      for out in signature.output_arg),
                op_def=signature,
                name="FunctionCall",
                compute_shapes=False)
            raw_outputs = op.outputs
            if isinstance(raw_outputs, (ops.Tensor, type(None))):
                outputs = [raw_outputs]
            else:
                outputs = list(raw_outputs)
            for index, shape in enumerate(self._output_shapes):
                outputs[index].set_shape(shape)

        num_returns = len(self._returns)
        real_outputs = outputs[:num_returns]
        side_outputs = outputs[num_returns:]

        def backward_function(*grads):
            # The side outputs are passed after the incoming gradients.
            return self._backward_function(*(list(grads) + side_outputs))

        tape.record_operation(signature.name, real_outputs,
                              (args + self._extra_inputs), backward_function)

        return self._build_call_outputs(real_outputs)
Пример #17
0
  def __call__(self, *args):
    """Calls the wrapped function on the Tensors found in `args`.

    Trainable captured variables are watched on the tape first. If the tape
    records any input, the call goes through `_backprop_call`. Otherwise the
    function is executed eagerly, or added to the default graph as a
    "FunctionCall" op when not executing eagerly.
    """
    for variable in self._variables:
      if variable.trainable:
        tape.watch_variable(variable)

    tensor_inputs = [
        arg for arg in nest.flatten(args) if isinstance(arg, ops.Tensor)
    ]
    if (tape.should_record(tensor_inputs)
        or tape.should_record(self._extra_inputs)):
      if self._backward_function is None:
        self._construct_backprop_function()
      return self._backprop_call(tensor_inputs)

    ctx = context.context()
    if not ctx.executing_eagerly():
      graph = ops.get_default_graph()
      self.add_to_graph(graph)
      signature = self._function_def.definition.signature
      call_args = list(tensor_inputs) + self._extra_inputs
      op = graph.create_op(
          signature.name,
          [ops.internal_convert_to_tensor(arg, ctx=ctx) for arg in call_args],
          tuple(dtypes_module.DType(out.type) for out in signature.output_arg),
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      outputs = op.outputs
      if not outputs:
        return op
      for index, shape in enumerate(self._output_shapes):
        outputs[index].set_shape(shape)
      return self._build_call_outputs(outputs)

    outputs = execute.execute(
        str(self._func_name),
        num_outputs=self._num_outputs,
        inputs=tensor_inputs + self._extra_inputs,
        attrs=None,
        ctx=ctx)
    return self._build_call_outputs(outputs)
Пример #18
0
 def testInvalid(self):
     """DT_INVALID must be rejected by both `DType` and `as_dtype`."""
     for factory in (dtypes.DType, dtypes.as_dtype):
         with self.assertRaises(TypeError):
             factory(types_pb2.DT_INVALID)
Пример #19
0
 def testAllTypesConstructible(self):
     """Every DataType enum except DT_INVALID round-trips through `DType`."""
     valid_enums = [
         enum_value for enum_value in types_pb2.DataType.values()
         if enum_value != types_pb2.DT_INVALID
     ]
     for datatype_enum in valid_enums:
         constructed = dtypes.DType(datatype_enum)
         self.assertEqual(datatype_enum, constructed.as_datatype_enum)
Пример #20
0
def call_function(func, *inputs, **kwargs):
    """Calls the function described by `func`.

    This adds a `call` op to the default graph that calls the function
    described by `func` with the tensors listed in `inputs` as arguments.

    `func` is a `_DefinedFunction` object. See
    [`define_function()`](#define_function) for an easy way to create
    one from a Python function.

    You can pass an optional keyword parameter `name=string` to name the
    added operation.

    You can pass an optional keyword parameter `noinline=True|False` to
    instruct the runtime not to inline the function body into the call site.

    This helper only creates the call op; it does not itself add `func` to
    the graph's function library. NOTE(review): presumably callers register
    `func` beforehand - confirm.

    Args:
      func: A `_DefinedFunction` object; only its `definition` (a
        `FunctionDef`) is read here.
      *inputs: The argument tensors, exactly one per input of the signature.
      **kwargs: Optional keyword arguments.  Can only contain 'name' or
          'noinline'.

    Returns:
      The single output `Tensor` if the function has one output, a tuple of
      `Tensor`s if it has several, or the created `Operation` if it has none.

    Raises:
      ValueError: if an unknown keyword argument is passed or the number of
        inputs does not match the signature.
    """
    func_def = func.definition
    name = kwargs.pop("name", None)
    noinline = kwargs.pop("noinline", None)
    if noinline is None:
        attrs = None
    else:
        # `_noinline` asks the runtime not to inline the function body.
        attrs = {"_noinline": attr_value_pb2.AttrValue(b=bool(noinline))}
    if kwargs:
        raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
    signature = func_def.signature
    func_name = signature.name
    with ops.name_scope(name, func_name, inputs) as name:
        if len(inputs) != len(signature.input_arg):
            raise ValueError("Expected number of arguments: %d, received: %d" %
                             (len(signature.input_arg), len(inputs)))
        output_types = [dtypes.DType(x.type) for x in signature.output_arg]
        # TODO(touts): Pass compute_shapes as "try if function exists"
        g = ops.get_default_graph()
        op = g.create_op(func_name,
                         list(inputs),
                         output_types,
                         name=name,
                         attrs=attrs,
                         compute_shapes=False)
        outputs = op.outputs
        if not outputs:
            return op
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)
Пример #21
0
 def outputs_type_map(self):
   """Returns a map from tensor alias to tensor type.

   Returns:
     A dict mapping each alias in `self._client.signature.outputs` to the
     `DType` built from that entry's `dtype` field.
   """
   # `.items()` rather than the Python-2-only `.iteritems()`, which does not
   # exist on Python 3 dicts or mapping types.
   return {alias: dtypes.DType(info.dtype)
           for alias, info in self._client.signature.outputs.items()}
def get_variable_to_dtype_map(self):
  """Returns a dict from name to `DType`.

  The entries come from `self._GetVariableToDataTypeMap()`, whose values are
  DataType enum values that are wrapped in `DType` objects here.
  """
  enum_by_name = self._GetVariableToDataTypeMap()  # pylint: disable=protected-access
  dtype_by_name = {}
  for name, type_enum in enum_by_name.items():
    dtype_by_name[name] = dtypes.DType(type_enum)
  return dtype_by_name
Пример #23
0
 def do_decode(self, value, decode_fn):
   """Builds a `DType` from the enum stored in `value.tensor_dtype_value`."""
   del decode_fn  # Unused here.
   dtype_enum = value.tensor_dtype_value
   return dtypes.DType(dtype_enum)
Пример #24
0
 def testAsDtypeReturnsInternedVersion(self):
     """`as_dtype` on a fresh variant DType returns the `dtypes.variant` object."""
     fresh_dtype = dtypes.DType(types_pb2.DT_VARIANT)
     interned = dtypes.as_dtype(fresh_dtype)
     self.assertIs(interned, dtypes.variant)