Пример #1
0
def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If one of the grad_ys is invalid.
  """
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  for i in xrange(len(grad_ys)):
    grad_y = grad_ys[i]
    y = ys[i]
    if grad_y is None:
      with ops.device(_GetGradsDevice(y.op, colocate_gradients_with_ops)):
        grad_ys[i] = array_ops.fill(array_ops.shape(y),
                                    constant_op.constant(1, dtype=y.dtype))
    else:
      if grad_y.dtype != y.dtype:
        raise ValueError("Y and ys_grad must be of the same type, "
                         "not y: %s, ys_grad: %s " %
                         (types.as_dtype(y.dtype).name,
                          types.as_dtype(grad_y.dtype).name))
  return grad_ys
Пример #2
0
def _ComputeGradient(x,
                     x_shape,
                     dx,
                     y,
                     y_shape,
                     dy,
                     x_init_value=None,
                     delta=1e-3):
    """Computes the theoretical and numerical Jacobians of `y` w.r.t. `x`.

    Args:
      x: Input tensor; its base dtype must be float32 or float64.
      x_shape: Dimensions of `x`.
      dx: Forwarded to `_ComputeTheoricalJacobian`.
      y: Output tensor; its base dtype must be float32 or float64.
      y_shape: Dimensions of `y`.
      dy: Forwarded to `_ComputeTheoricalJacobian`.
      x_init_value: Optional numpy array used as the value of `x`; its shape
        must equal `x_shape`. If None, values are drawn from [0, 1) with the
        numpy float type matching the base dtype of `x`.
      delta: Perturbation forwarded to `_ComputeNumericJacobian`.

    Returns:
      A `(jacob_t, jacob_n)` pair: the theoretical and the numeric Jacobian.

    Raises:
      AssertionError: If a dtype is unsupported or `x_init_value` has a shape
        different from `x_shape`.
    """
    t = types.as_dtype(x.dtype)
    allowed_types = [types.float32, types.float64]
    assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name
    t2 = types.as_dtype(y.dtype)
    assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name

    if x_init_value is not None:
        i_shape = list(x_init_value.shape)
        assert (list(x_shape) == i_shape
                ), "x_shape = %s, init_data shape = %s" % (x_shape, i_shape)
        x_data = x_init_value
    else:
        # Compare the *base* dtype, like the validation above: a reference
        # dtype (e.g. float32_ref) passes that check but is not equal to
        # types.float32, and would otherwise silently get float64 data.
        if t.base_dtype == types.float32:
            dtype = np.float32
        else:
            dtype = np.float64
        # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit
        # float dtype performs the same conversion.
        x_data = np.asarray(np.random.random_sample(x_shape), dtype=dtype)

    jacob_t = _ComputeTheoricalJacobian(x, x_shape, x_data, dy, y_shape, dx)
    jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta)
    return jacob_t, jacob_n
Пример #3
0
def _SatisfiesTypeConstraint(dtype, attr_def):
    if attr_def.HasField("allowed_values"):
        allowed_list = attr_def.allowed_values.list.type
        if dtype not in allowed_list:
            raise TypeError(
                "DataType %s for attr '%s' not in list of allowed values: %s" %
                (types_lib.as_dtype(dtype).name, attr_def.name, ", ".join(
                    types_lib.as_dtype(x).name for x in allowed_list)))
Пример #4
0
def _SatisfiesTypeConstraint(dtype, attr_def):
  if attr_def.HasField("allowed_values"):
    allowed_list = attr_def.allowed_values.list.type
    if dtype not in allowed_list:
      raise TypeError(
          "DataType %s for attr '%s' not in list of allowed values: %s" %
          (types_lib.as_dtype(dtype).name, attr_def.name,
           ", ".join(types_lib.as_dtype(x).name for x in allowed_list)))
Пример #5
0
 def testIsFloating(self):
   """Only float32 and float64 report `is_floating`."""
   cases = (
       ("int8", False),
       ("int16", False),
       ("int32", False),
       ("int64", False),
       ("uint8", False),
       ("complex64", False),
       ("float32", True),
       ("float64", True),
       ("string", False),
       ("bool", False),
   )
   for type_name, floating in cases:
     self.assertEqual(types.as_dtype(type_name).is_floating, floating)
Пример #6
0
 def testIsInteger(self):
   """Signed and unsigned integer dtypes report `is_integer`; others do not."""
   cases = (
       ("int8", True),
       ("int16", True),
       ("int32", True),
       ("int64", True),
       ("uint8", True),
       ("complex64", False),
       ("float", False),
       ("double", False),
       ("string", False),
       ("bool", False),
   )
   for type_name, integral in cases:
     self.assertEqual(types.as_dtype(type_name).is_integer, integral)
Пример #7
0
 def testIsFloating(self):
   """Only float32 and float64 report `is_floating`."""
   cases = (
       ("int8", False),
       ("int16", False),
       ("int32", False),
       ("int64", False),
       ("uint8", False),
       ("complex64", False),
       ("float32", True),
       ("float64", True),
       ("string", False),
       ("bool", False),
   )
   for type_name, floating in cases:
     self.assertEqual(types.as_dtype(type_name).is_floating, floating)
Пример #8
0
 def testIsInteger(self):
   """Signed and unsigned integer dtypes report `is_integer`; others do not."""
   cases = (
       ("int8", True),
       ("int16", True),
       ("int32", True),
       ("int64", True),
       ("uint8", True),
       ("complex64", False),
       ("float", False),
       ("double", False),
       ("string", False),
       ("bool", False),
   )
   for type_name, integral in cases:
     self.assertEqual(types.as_dtype(type_name).is_integer, integral)
Пример #9
0
 def testAllTypesConvertibleToNumpyDtype(self):
   """Each valid DataType has a numpy dtype that maps back to its base dtype."""
   valid_enums = [enum_value for enum_value in types_pb2.DataType.values()
                  if enum_value != types_pb2.DT_INVALID]
   for enum_value in valid_enums:
     tf_dtype = types.as_dtype(enum_value)
     np_type = tf_dtype.as_numpy_dtype
     # Allocating an array proves numpy accepts the dtype.
     np.empty((1, 1, 1, 1), dtype=np_type)
     # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
     if tf_dtype.base_dtype == types.bfloat16:
       continue
     self.assertEqual(tf_dtype.base_dtype, types.as_dtype(np_type))
Пример #10
0
 def testAllTypesConvertibleToNumpyDtype(self):
   """Each valid DataType has a numpy dtype that maps back to its base dtype."""
   valid_enums = [enum_value for enum_value in types_pb2.DataType.values()
                  if enum_value != types_pb2.DT_INVALID]
   for enum_value in valid_enums:
     tf_dtype = types.as_dtype(enum_value)
     np_type = tf_dtype.as_numpy_dtype
     # Allocating an array proves numpy accepts the dtype.
     np.empty((1, 1, 1, 1), dtype=np_type)
     # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
     if tf_dtype.base_dtype == types.bfloat16:
       continue
     self.assertEqual(tf_dtype.base_dtype, types.as_dtype(np_type))
Пример #11
0
def ones(shape, dtype=types.float32, name=None):
    """Creates a tensor with all elements set to 1.

    The result has element type `dtype` and shape `shape`.

    For example:

    ```python
    tf.ones([2, 3], tf.int32) ==> [[1, 1, 1], [1, 1, 1]]
    ```

    Args:
      shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
      dtype: The type of an element in the resulting `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` with all elements set to 1.
    """
    with ops.op_scope([shape], name, "ones") as scope_name:
        if not isinstance(shape, list):
            # Non-list shape: convert it to a tensor and fill with ones.
            shape_tensor = ops.convert_to_tensor(shape, name="shape")
            result = fill(shape_tensor, constant(1, dtype=dtype), name=scope_name)
        else:
            # Shape given as a Python list: emit a single constant.
            result = constant(1, shape=shape, dtype=dtype, name=scope_name)
    assert result.dtype.base_dtype == types.as_dtype(dtype).base_dtype
    return result
Пример #12
0
def _restore_slice(file_pattern,
                   tensor_name,
                   shape_and_slice,
                   tensor_type,
                   name="restore_slice",
                   preferred_shard=-1):
    """Restores one tensor slice from the checkpoint files matching a pattern.

    Example usage:
      RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

    Args:
      file_pattern: the file pattern used to match a set of checkpoint files.
      tensor_name: the name of the tensor to restore.
      shape_and_slice: the shape-and-slice spec of the slice.
      tensor_type: the type of the tensor to restore; only its base dtype is
        handed to the op.
      name: string.  Optional name for the op.
      preferred_shard: Int. Optional shard to open first in the checkpoint file.

    Returns:
      A tensor of type "tensor_type".
    """
    restore_args = (file_pattern, tensor_name, shape_and_slice,
                    types.as_dtype(tensor_type).base_dtype, preferred_shard)
    return gen_io_ops._restore_slice(*restore_args, name=name)
Пример #13
0
    def __init__(self, key_dtype, value_dtype, default_value, table_ref):
        """Wraps an existing lookup-table reference.

        Args:
          key_dtype: The table key type.
          value_dtype: The table value type.
          default_value: The value to use if a key is missing in the table;
            it is converted to `value_dtype` and must be compatible with a
            scalar shape.
          table_ref: The table reference, i.e. the output of the lookup table ops.
        """
        self._key_dtype = types.as_dtype(key_dtype)
        self._value_dtype = types.as_dtype(value_dtype)
        self._shapes = [tensor_shape.TensorShape([1])]
        self._table_ref = table_ref
        # The short name is the final "/"-separated component of the op name.
        self._name = table_ref.op.name.rsplit("/", 1)[-1]
        default_tensor = ops.convert_to_tensor(default_value,
                                               dtype=self._value_dtype)
        self._default_value = default_tensor
        # Check the default's static shape against a scalar.
        default_tensor.get_shape().merge_with(tensor_shape.scalar())
Пример #14
0
def ones(shape, dtype=types.float32, name=None):
    """Creates a tensor with all elements set to 1.

    The result has element type `dtype` and shape `shape`.

    For example:

    ```python
    tf.ones([2, 3], tf.int32) ==> [[1, 1, 1], [1, 1, 1]]
    ```

    Args:
      shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
      dtype: The type of an element in the resulting `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` with all elements set to 1.
    """
    with ops.op_scope([shape], name, "ones") as scope_name:
        if not isinstance(shape, list):
            # Non-list shape: convert it to a tensor and fill with ones.
            shape_tensor = ops.convert_to_tensor(shape, name="shape")
            result = fill(shape_tensor, constant(1, dtype=dtype), name=scope_name)
        else:
            # Shape given as a Python list: emit a single constant.
            result = constant(1, shape=shape, dtype=dtype, name=scope_name)
    assert result.dtype.base_dtype == types.as_dtype(dtype).base_dtype
    return result
Пример #15
0
  def __init__(self, key_dtype, value_dtype, default_value, table_ref):
    """Wraps an existing lookup-table reference.

    Args:
      key_dtype: The table key type.
      value_dtype: The table value type.
      default_value: The value to use if a key is missing in the table; it is
        converted to `value_dtype` and must be compatible with a scalar shape.
      table_ref: The table reference, i.e. the output of the lookup table ops.
    """
    self._key_dtype = types.as_dtype(key_dtype)
    self._value_dtype = types.as_dtype(value_dtype)
    self._shapes = [tensor_shape.TensorShape([1])]
    self._table_ref = table_ref
    # The short name is the final "/"-separated component of the op name.
    self._name = table_ref.op.name.rsplit("/", 1)[-1]
    default_tensor = ops.convert_to_tensor(default_value,
                                           dtype=self._value_dtype)
    self._default_value = default_tensor
    # Check the default's static shape against a scalar.
    default_tensor.get_shape().merge_with(tensor_shape.scalar())
Пример #16
0
def _MakeType(v, attr_def):
    """Converts `v` to a DataType enum value permitted by `attr_def`.

    Args:
      v: Anything `types_lib.as_dtype` accepts.
      attr_def: The attr definition whose type constraint must be satisfied.

    Returns:
      The DataType enum value (`as_datatype_enum`) for `v`.

    Raises:
      TypeError: If `v` is not convertible to a DataType, or the resulting
        type is not allowed by `attr_def`.
    """
    try:
        converted = types_lib.as_dtype(v)
    except TypeError:
        raise TypeError("Expected DataType for argument '%s' not %s." %
                        (attr_def.name, repr(v)))
    enum_value = converted.as_datatype_enum
    _SatisfiesTypeConstraint(enum_value, attr_def)
    return enum_value
Пример #17
0
def _MakeType(v, attr_def):
  """Converts `v` to a DataType enum value permitted by `attr_def`.

  Args:
    v: Anything `types_lib.as_dtype` accepts.
    attr_def: The attr definition whose type constraint must be satisfied.

  Returns:
    The DataType enum value (`as_datatype_enum`) for `v`.

  Raises:
    TypeError: If `v` is not convertible to a DataType, or the resulting
      type is not allowed by `attr_def`.
  """
  try:
    converted = types_lib.as_dtype(v)
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  enum_value = converted.as_datatype_enum
  _SatisfiesTypeConstraint(enum_value, attr_def)
  return enum_value
Пример #18
0
 def testDTypesHaveUniqueNames(self):
   """No two valid DataType enums produce dtypes with the same name."""
   all_dtypes = [types.as_dtype(enum_value)
                 for enum_value in types_pb2.DataType.values()
                 if enum_value != types_pb2.DT_INVALID]
   distinct_names = set(d.name for d in all_dtypes)
   self.assertEqual(len(all_dtypes), len(distinct_names))
Пример #19
0
 def testDTypesHaveUniqueNames(self):
   """No two valid DataType enums produce dtypes with the same name."""
   all_dtypes = [types.as_dtype(enum_value)
                 for enum_value in types_pb2.DataType.values()
                 if enum_value != types_pb2.DT_INVALID]
   distinct_names = set(d.name for d in all_dtypes)
   self.assertEqual(len(all_dtypes), len(distinct_names))
Пример #20
0
def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients where generated.

  Raises:
    ValueError: if the gradients are invalid.
  """
  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
  for i in xrange(len(grads)):
    grad = grads[i]
    inp = op.inputs[i]
    if grad is not None:
      if not grad.dtype.is_compatible_with(inp.dtype):
        raise ValueError(
            "Gradient type %s generated for op %s does "
            "not match input type %s" %
            (types.as_dtype(grad.dtype).name, op.node_def,
             types.as_dtype(inp.dtype).name))
Пример #21
0
def _ComputeGradient(x, x_shape, dx, y, y_shape, dy,
                     x_init_value=None, delta=1e-3):
  """Computes the theoretical and numerical Jacobians of `y` w.r.t. `x`.

  Args:
    x: Input tensor; its base dtype must be float32 or float64.
    x_shape: Dimensions of `x`.
    dx: Forwarded to `_ComputeTheoricalJacobian`.
    y: Output tensor; its base dtype must be float32 or float64.
    y_shape: Dimensions of `y`.
    dy: Forwarded to `_ComputeTheoricalJacobian`.
    x_init_value: Optional numpy array used as the value of `x`; its shape
      must equal `x_shape`. If None, values are drawn from [0, 1) with the
      numpy float type matching the base dtype of `x`.
    delta: Perturbation forwarded to `_ComputeNumericJacobian`.

  Returns:
    A `(jacob_t, jacob_n)` pair: the theoretical and the numeric Jacobian.

  Raises:
    AssertionError: If a dtype is unsupported or `x_init_value` has a shape
      different from `x_shape`.
  """
  t = types.as_dtype(x.dtype)
  allowed_types = [types.float32, types.float64]
  assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name
  t2 = types.as_dtype(y.dtype)
  assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name

  if x_init_value is not None:
    i_shape = list(x_init_value.shape)
    assert(list(x_shape) == i_shape), "x_shape = %s, init_data shape = %s" % (
        x_shape, i_shape)
    x_data = x_init_value
  else:
    # Compare the *base* dtype, like the validation above: a reference dtype
    # (e.g. float32_ref) passes that check but is not equal to types.float32,
    # and would otherwise silently get float64 data.
    if t.base_dtype == types.float32:
      dtype = np.float32
    else:
      dtype = np.float64
    # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit float
    # dtype performs the same conversion.
    x_data = np.asarray(np.random.random_sample(x_shape), dtype=dtype)

  jacob_t = _ComputeTheoricalJacobian(x, x_shape, x_data, dy, y_shape, dx)
  jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta)
  return jacob_t, jacob_n
Пример #22
0
  def testMinMax(self):
    """`min`/`max` evaluate, and match known bounds, for every bounded dtype."""
    # Exact (min, max) for integer numpy types: 2**(n-1) bounds for signed
    # n-bit types, [0, 2**n - 1] for unsigned ones.
    integer_bounds = (
        (np.int8, -128, 127),
        (np.int16, -32768, 32767),
        (np.int32, -2147483648, 2147483647),
        (np.int64, -9223372036854775808, 9223372036854775807),
        (np.uint8, 0, 255),
        (np.uint16, 0, 65535),
        (np.uint32, 0, 4294967295),
    )
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = types.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype

      # ignore types for which there are no minimum/maximum (or we cannot
      # compute it, such as for the q* types). bool is skipped here, so no
      # np.bool_ check is needed below.
      if (dtype.is_quantized or
          dtype.base_dtype == types.bool or
          dtype.base_dtype == types.string or
          dtype.base_dtype == types.complex64):
        continue

      print("%s: %s - %s" % (dtype, dtype.min, dtype.max))

      # check some values that are known
      for np_type, expected_min, expected_max in integer_bounds:
        if numpy_dtype == np_type:
          self.assertEqual(dtype.min, expected_min)
          self.assertEqual(dtype.max, expected_max)
      if numpy_dtype in (np.float16, np.float32, np.float64):
        self.assertEqual(dtype.min, np.finfo(numpy_dtype).min)
        self.assertEqual(dtype.max, np.finfo(numpy_dtype).max)
Пример #23
0
  def testMinMax(self):
    """`min`/`max` evaluate, and match known bounds, for every bounded dtype."""
    # Exact (min, max) for integer numpy types: 2**(n-1) bounds for signed
    # n-bit types, [0, 2**n - 1] for unsigned ones.
    integer_bounds = (
        (np.int8, -128, 127),
        (np.int16, -32768, 32767),
        (np.int32, -2147483648, 2147483647),
        (np.int64, -9223372036854775808, 9223372036854775807),
        (np.uint8, 0, 255),
        (np.uint16, 0, 65535),
        (np.uint32, 0, 4294967295),
    )
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = types.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype

      # ignore types for which there are no minimum/maximum (or we cannot
      # compute it, such as for the q* types). bool is skipped here, so no
      # np.bool_ check is needed below.
      if (dtype.is_quantized or
          dtype.base_dtype == types.bool or
          dtype.base_dtype == types.string or
          dtype.base_dtype == types.complex64):
        continue

      # Parenthesized single argument: valid under both Python 2 and 3.
      print("%s: %s - %s" % (dtype, dtype.min, dtype.max))

      # check some values that are known
      for np_type, expected_min, expected_max in integer_bounds:
        if numpy_dtype == np_type:
          self.assertEqual(dtype.min, expected_min)
          self.assertEqual(dtype.max, expected_max)
      if numpy_dtype in (np.float16, np.float32, np.float64):
        self.assertEqual(dtype.min, np.finfo(numpy_dtype).min)
        self.assertEqual(dtype.max, np.finfo(numpy_dtype).max)
Пример #24
0
  def __init__(self, op, value_index, dtype):
    """Creates a new `Tensor` produced by an output of `op`.

    Args:
      op: An `Operation`. `Operation` that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that produces
        this tensor.
      dtype: A `types.DType`. Type of data stored in this tensor.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if not isinstance(op, Operation):
      raise TypeError("op needs to be an Operation: %s" % op)
    self._op, self._value_index = op, value_index
    self._dtype = types.as_dtype(dtype)
    # The static shape starts out fully unknown.
    self._shape = tensor_shape.unknown_shape()
    # Operations that take this Tensor as an input; kept so the computation
    # graph can be navigated from producers to consumers.
    self._consumers = []
Пример #25
0
def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restores one tensor slice from the checkpoint files matching a pattern.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore; only its base dtype is
      handed to the op.
    name: string.  Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.

  Returns:
    A tensor of type "tensor_type".
  """
  restore_args = (file_pattern, tensor_name, shape_and_slice,
                  types.as_dtype(tensor_type).base_dtype, preferred_shard)
  return gen_io_ops._restore_slice(*restore_args, name=name)
Пример #26
0
def MakeNdarray(tensor):
    """Create a numpy ndarray from a tensor.

    Create a numpy ndarray with the same shape and data as the tensor.

    Args:
      tensor: A TensorProto.

    Returns:
      A numpy array with the tensor contents.

    Raises:
      TypeError: if tensor has unsupported type.
    """
    shape = [d.size for d in tensor.tensor_shape.dim]
    # np.prod([]) is the float 1.0 for a rank-0 shape, and np.repeat rejects
    # non-integer repeat counts, so force an integer element count.
    num_elements = int(np.prod(shape, dtype=np.int64))
    tensor_dtype = types.as_dtype(tensor.dtype)
    dtype = tensor_dtype.as_numpy_dtype

    def _from_values(values):
        """Builds the array from a repeated proto field.

        A single stored value is broadcast to every element; otherwise the
        field is read as one value per element.
        """
        if len(values) == 1:
            return np.repeat(np.array(values[0], dtype=dtype),
                             num_elements).reshape(shape)
        return np.fromiter(values, dtype=dtype).reshape(shape)

    if tensor.tensor_content:
        # np.fromstring's binary mode is deprecated; frombuffer + copy gives
        # the same data as a writable array that owns its memory.
        return np.frombuffer(tensor.tensor_content,
                             dtype=dtype).copy().reshape(shape)
    elif tensor_dtype == types.float32:
        return _from_values(tensor.float_val)
    elif tensor_dtype == types.float64:
        return _from_values(tensor.double_val)
    elif tensor_dtype in [
            types.int32, types.uint8, types.int16, types.int8, types.qint32,
            types.quint8, types.qint8, types.bfloat16
    ]:
        return _from_values(tensor.int_val)
    elif tensor_dtype == types.int64:
        return _from_values(tensor.int64_val)
    elif tensor_dtype == types.string:
        if len(tensor.string_val) == 1:
            return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype),
                             num_elements).reshape(shape)
        else:
            return np.array([str(x) for x in tensor.string_val],
                            dtype=dtype).reshape(shape)
    elif tensor_dtype == types.complex64:
        # scomplex_val stores (real, imag) pairs flattened into one list.
        if len(tensor.scomplex_val) == 2:
            return np.repeat(
                np.array(complex(tensor.scomplex_val[0],
                                 tensor.scomplex_val[1]),
                         dtype=dtype), num_elements).reshape(shape)
        else:
            it = iter(tensor.scomplex_val)
            return np.array([complex(x[0], x[1]) for x in zip(it, it)],
                            dtype=dtype).reshape(shape)
    elif tensor_dtype == types.bool:
        return _from_values(tensor.bool_val)
    else:
        raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
Пример #27
0
def make_tensor_proto(values, dtype=None, shape=None):
    """Create a TensorProto.

    Args:
      values:    Values to put in the TensorProto.
      dtype:     Optional tensor_pb2 DataType value.
      shape:     List of integers representing the dimensions of tensor.

    Returns:
      A TensorProto. Depending on the type, it may contain data in the
      "tensor_content" attribute, which is not directly useful to Python programs.
      To access the values you should convert the proto back to a numpy ndarray
      with tensor_util.MakeNdarray(proto).

    Raises:
      TypeError:  if unsupported types are provided.
      ValueError: if arguments have inappropriate values.

    make_tensor_proto accepts "values" of a python scalar, a python list, a
    numpy ndarray, or a numpy scalar.

    If "values" is a python scalar or a python list, make_tensor_proto
    first converts it to a numpy ndarray. If dtype is None, the
    conversion tries its best to infer the right numpy data type, narrowing
    float64 to float32 and int64 to int32. Otherwise, the resulting numpy
    array has a data type compatible with the given dtype.

    In either case above, the numpy ndarray (either the caller provided
    or the auto converted) must have a type compatible with dtype.

    make_tensor_proto then converts the numpy array to a tensor proto.

    If "shape" is None, the resulting tensor proto represents the numpy
    array precisely.

    Otherwise, "shape" specifies the tensor's shape and the numpy array
    cannot have more elements than what "shape" specifies.
    """
    if dtype:
        dtype = types.as_dtype(dtype)

    # We first convert value to a numpy array or scalar.
    if isinstance(values, (np.ndarray, np.generic)):
        if dtype:
            nparray = values.astype(dtype.as_numpy_dtype)
        else:
            nparray = values
    else:
        if values is None:
            raise ValueError("None values not supported.")
        # if dtype is provided, forces numpy array to be the type
        # provided if possible.
        np_dt = dtype.as_numpy_dtype if dtype else None
        if np.prod(shape) == 0:
            nparray = np.empty(shape, dtype=np_dt)
        else:
            _AssertCompatible(values, dtype)
            nparray = np.array(values, dtype=np_dt)
            if list(nparray.shape) != _GetDenseDimensions(values):
                raise ValueError("Argument must be a dense tensor: %s" %
                                 values)
        # python/numpy default float type is float64. We prefer float32 instead.
        if (nparray.dtype == np.float64) and dtype is None:
            nparray = nparray.astype(np.float32)
        # python/numpy default int type is int64. We prefer int32 instead.
        elif (nparray.dtype == np.int64) and dtype is None:
            nparray = nparray.astype(np.int32)

    # if dtype is provided, it must be compatible with what numpy
    # conversion says.
    numpy_dtype = types.as_dtype(nparray.dtype)
    if numpy_dtype is None:
        raise TypeError("Unrecognized data type: %s" % nparray.dtype)

    # If dtype was specified and is a quantized type, we convert
    # numpy_dtype back into the quantized version.
    if dtype in [types.qint8, types.quint8, types.qint32]:
        numpy_dtype = dtype

    if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype:
        raise TypeError("Incompatible types: %s vs. %s" %
                        (dtype, nparray.dtype))

    # If shape is not given, get the shape from the numpy array.
    if shape is None:
        shape = nparray.shape
        is_same_size = True
        shape_size = nparray.size
    else:
        shape = [int(dim) for dim in shape]
        shape_size = np.prod(shape)
        is_same_size = shape_size == nparray.size

        if nparray.size > shape_size:
            raise ValueError(
                "Too many elements provided. Needed at most %d, but received %d"
                % (shape_size, nparray.size))

    tensor_proto = tensor_pb2.TensorProto(
        dtype=numpy_dtype.as_datatype_enum,
        tensor_shape=MakeTensorShapeProto(shape))

    if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
        # ndarray.tostring() is deprecated (NumPy 1.19) and removed in recent
        # releases; tobytes() returns the identical raw C-order bytes.
        tensor_proto.tensor_content = nparray.tobytes()
        return tensor_proto

    # If we were not given values as a numpy array, compute the proto_values
    # from the given values directly, to avoid numpy trimming nulls from the
    # strings. Since values could be a list of strings, or a multi-dimensional
    # list of lists that might or might not correspond to the given shape,
    # we flatten it conservatively.
    if numpy_dtype == types.string and not isinstance(values, np.ndarray):
        proto_values = _FlattenToStrings(values)
        tensor_proto.string_val.extend([str(x) for x in proto_values])
        return tensor_proto

    # TensorFlow expects C order (a.k.a., eigen row major).
    proto_values = nparray.ravel()

    append_fn = GetNumpyAppendFn(proto_values.dtype)
    if append_fn is None:
        raise TypeError("Element type not supported in TensorProto: %s" %
                        numpy_dtype.name)
    append_fn(tensor_proto, proto_values)

    return tensor_proto
Пример #28
0
 def testAllTypesConvertibleToDType(self):
   """Round-tripping a valid enum through as_dtype yields the same enum."""
   for enum_value in types_pb2.DataType.values():
     if enum_value != types_pb2.DT_INVALID:
       round_tripped = types.as_dtype(enum_value).as_datatype_enum
       self.assertEqual(enum_value, round_tripped)
Пример #29
0
def convert_to_tensor(value, dtype=None, name=None):
  """Converts the given `value` to a `Tensor`.

  Python objects of various types are turned into `Tensor` objects by
  dispatching to registered conversion functions. `Tensor` objects, numpy
  arrays, Python lists, and Python scalars are accepted. For example:

  ```python
  import numpy as np
  array = np.random.rand(32, 100, 100)

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Conversion functions are tried in ascending order of the registry's
  priority keys; the first one whose registered base type matches `value`
  via `isinstance` decides the result.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value`.
    RuntimeError: If a registered conversion function returns an invalid value.
  """
  error_prefix = "%s: " % name if name is not None else ""
  if dtype is not None:
    dtype = types.as_dtype(dtype)
  for _priority, candidates in sorted(_tensor_conversion_func_registry.items()):
    for base_type, conversion_func in candidates:
      if not isinstance(value, base_type):
        continue
      ret = conversion_func(value, dtype=dtype, name=name)
      # The first matching converter is authoritative; validate its result.
      if not isinstance(ret, Tensor):
        raise RuntimeError(
            "%sConversion function %r for type %s returned non-Tensor: %r"
            % (error_prefix, conversion_func, base_type, ret))
      if dtype and not dtype.is_compatible_with(ret.dtype):
        raise RuntimeError(
            "%sConversion function %r for type %s returned incompatible "
            "dtype: requested = %s, actual = %s"
            % (error_prefix, conversion_func, base_type,
               dtype.name, ret.dtype.name))
      return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered."
                  % (error_prefix, value, type(value)))
Пример #30
0
def gradients(ys, xs, grad_ys=None, name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None):
  """Constructs symbolic partial derivatives of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the partial
  derivatives of `ys` with respect to `xs`.  It returns a list of
  `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
  for y in `ys`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`.  When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.

  """
  ys = _AsList(ys)
  xs = _AsList(xs)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)
  with ops.op_scope(ys + xs + grad_ys, name, "gradients"):
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = ops.convert_n_to_tensor_or_indexed_slices(xs, name="x")
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    pending_count, has_control_flow = _PendingCount(
        ops.get_default_graph(), to_ops, from_ops)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into a Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue, de-duplicated by op id so an op
    # that produces several of the ys is only visited once.
    to_ops_set = set()
    for op in to_ops:
      if op._id not in to_ops_set:
        to_ops_set.add(op._id)
        queue.append(op)
    # The set of 'from_ops'.
    stop_ops = _StopOps(from_ops, pending_count)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
        if has_control_flow:
          control_flow_ops.EnterGradWhileContext(op)
        out_grads = _AggregatedGrads(grads, op, has_control_flow,
                                     aggregation_method)
        grad_fn = None
        if any(out_grads) and op._id not in stop_ops:
          # A grad_fn must be defined, either as a function or as None
          # for ops that do not have gradients.
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            raise LookupError(
                "No gradient defined for operation '%s' (op type: %s)" %
                (op.name, op.type))
        if grad_fn and any(out_grads):
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not out_grad
                and types.as_dtype(op.outputs[i].dtype).base_dtype in (
                    types.float32, types.float64)):
              # Only floating-point outputs get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              out_grads[i] = array_ops.zeros_like(op.outputs[i])
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with ops.get_default_graph()._original_op(op):
            # pylint: enable=protected-access
              op_wrapper = op
              if has_control_flow:
                op_wrapper = control_flow_ops.MakeWrapper(op)
              in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len(in_grads) > 1:
                in_grads = control_flow_ops.tuple(in_grads)
          logging.vlog(1, "Gradient for '" + op.name + "'")
          logging.vlog(1, "  in  --> %s",
                       ", ".join([g.name for g in out_grads if g]))
          logging.vlog(1, "  out --> %s",
                       ", ".join([g.name for g in in_grads if g]))
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagates a list of None backwards.
          in_grads = [None] * len(op.inputs)
        for t_in, in_grad in zip(op.inputs, in_grads):
          if in_grad:
            _SetGrad(grads, t_in, in_grad)
        if has_control_flow:
          control_flow_ops.ExitGradWhileContext(op)

      # update pending count for the inputs of op.
      for x in op.inputs:
        pending_count[x.op._id] -= 1
        ready = (pending_count[x.op._id] == 0)
        if has_control_flow and not ready:
          # Loop switches are released before all of their consumers have
          # been processed so that gradient loops can be constructed.
          ready = (pending_count[x.op._id] > 0 and
                   control_flow_ops.IsLoopSwitch(x.op))
        if ready:
          queue.append(x.op)
      for x in op.control_inputs:
        pending_count[x._id] -= 1
        # Compare by value: `is 0` relied on CPython's small-int caching.
        if pending_count[x._id] == 0:
          queue.append(x)
  return [_GetGrad(grads, x) for x in xs]
Пример #31
0
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None):
    """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

    This function provides a way to import a serialized TensorFlow
    [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
    protocol buffer, and extract individual objects in the `GraphDef` as
    [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
    [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
    `GraphDef` proto.

    Args:
      graph_def: A `GraphDef` proto containing operations to be imported into
        the default graph.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      return_elements: A list of strings containing operation names in
        `graph_def` that will be returned as `Operation` objects; and/or
        tensor names in `graph_def` that will be returned as `Tensor` objects.
      name: (Optional.) A prefix that will be prepended to the names in
        `graph_def`. Defaults to `"import"`.
      op_dict: (Optional.) A dictionary mapping op type names to `OpDef`
        protos. Must contain an `OpDef` proto for each op type named in
        `graph_def`. If omitted, uses the `OpDef` protos registered in the
        global registry.

    Returns:
      A list of `Operation` and/or `Tensor` objects from the imported graph,
      corresponding to the names in `return_elements`, or None if
      `return_elements` is None.

    Raises:
      TypeError: If `graph_def` is not a `GraphDef` proto,
        `input_map` is not a dictionary mapping strings to `Tensor` objects,
        or `return_elements` is not a list of strings.
      ValueError: If `input_map`, or `return_elements` contains names that
        do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
        it refers to an unknown tensor).
    """
    # Type checks for inputs.
    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError('graph_def must be a GraphDef proto.')
    if input_map is None:
        input_map = {}
    else:
        if not (isinstance(input_map, dict) and all(
                isinstance(k, six.string_types) for k in input_map.keys())):
            raise TypeError(
                'input_map must be a dictionary mapping strings to '
                'Tensor objects.')
    if (return_elements is not None
            and not (isinstance(return_elements, (list, tuple)) and all(
                isinstance(x, six.string_types) for x in return_elements))):
        raise TypeError('return_elements must be a list of strings.')

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()

    name_to_op = {}

    if op_dict is None:
        op_dict = op_def_registry.get_registered_ops()

    with ops.op_scope(input_map.values(), name, 'import'):
        g = ops.get_default_graph()

        with ops.name_scope('_inputs'):
            input_map = {
                k: ops.convert_to_tensor(v)
                for k, v in input_map.items()
            }

        # NOTE(mrry): We do this in two passes, because there may be a cycle in
        # `graph_def`.

        # 1. Add operations without their inputs.
        for node in graph_def.node:
            output_types = _OutputTypes(node, op_dict)
            with _MaybeDevice(node.device):
                name_to_op[node.name] = g.create_op(node.op, [],
                                                    output_types,
                                                    name=node.name,
                                                    attrs=node.attr,
                                                    compute_shapes=False)

        # 2. Add inputs to the operations.
        for node in graph_def.node:
            op = name_to_op[node.name]
            input_types = _InputTypes(node, op_dict)

            # NOTE(mrry): We cannot use zip here because control inputs do not appear
            # in the list of input_types.
            for i, input_name in enumerate(
                [_CanonicalInputName(x) for x in node.input]):

                if _IsControlInput(input_name):
                    # (a) Input is a control input that should be taken from an op
                    #     in "graph_def".
                    try:
                        source_op = name_to_op[input_name[1:]]
                    except KeyError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'Control input %r not found in graph_def.' %
                                (input_name, )))
                    # pylint: disable=protected-access
                    op._add_control_input(source_op)
                    # pylint: enable=protected-access

                else:
                    try:
                        input_type = input_types[i]
                    except IndexError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'More inputs specified (%r) than the op expects.'
                                % (input_name, )))

                    if input_name in input_map:
                        # (b) Input should be replaced by a tensor from the caller.
                        source_tensor = input_map[input_name]
                        used_input_keys.add(input_name)

                    else:
                        # (c) Input should be taken from an op in `graph_def`.
                        operation_name, output_index = _ParseTensorName(
                            input_name)
                        try:
                            source_op = name_to_op[operation_name]
                            source_tensor = list(
                                source_op.values())[output_index]
                        except (KeyError, IndexError):
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r not found in graph_def.' %
                                    (input_name, )))

                    try:
                        # pylint: disable=protected-access
                        op._add_input(source_tensor, dtype=input_type)
                        # pylint: enable=protected-access
                    except TypeError as te:
                        # `str(te)` works on Python 2 and 3; exceptions have
                        # no `.message` attribute on Python 3.
                        raise ValueError(
                            _InvalidNodeMessage(
                                node, 'Input tensor %r %s' %
                                (input_name, str(te))))

            # pylint: disable=protected-access
            if op._input_dtypes != input_types:
                raise ValueError(
                    _InvalidNodeMessage(
                        node, 'Input types mismatch (expected %r but got %r)' %
                        (", ".join(
                            types_lib.as_dtype(x).name
                            for x in input_types), ", ".join(
                                x.name for x in op._input_dtypes))))
            # pylint: enable=protected-access

            # Execute shape inference for this op.
            # NOTE(mrry): If the graph contains a cycle, the full shape information
            # may not be available for this op's inputs.
            ops.set_shapes_for_outputs(op)

        # Treat unused input mappings as an error, because they are likely to be
        # due to a typo.
        unused_input_keys = frozenset(
            input_map.keys()).difference(used_input_keys)
        if unused_input_keys:
            # Sorted so the error message is deterministic.
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(sorted(unused_input_keys)))

        if return_elements is None:
            return None
        else:
            ret = []
            # Distinct loop name so the `name` parameter is not shadowed.
            for element_name in return_elements:
                if ':' in element_name:
                    try:
                        operation_name, output_index = _ParseTensorName(
                            element_name)
                        ret.append(
                            name_to_op[operation_name].outputs[output_index])
                    except (ValueError, KeyError, IndexError):
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % element_name)
                else:
                    try:
                        ret.append(name_to_op[element_name])
                    except KeyError:
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % element_name)
            return ret
Пример #32
0
def make_tensor_proto(values, dtype=None, shape=None):
    """Create a TensorProto.

    Args:
      values: Values to put in the TensorProto.
      dtype: Optional tensor_pb2 DataType value.
      shape: List of integers representing the dimensions of tensor.

    Returns:
      A TensorProto. Depending on the type, it may contain data in the
      "tensor_content" attribute, which is not directly useful to Python
      programs. To access the values you should convert the proto back to a
      numpy ndarray with tensor_util.MakeNdarray(proto).

    Raises:
      TypeError: if unsupported types are provided.
      ValueError: if arguments have inappropriate values.

    make_tensor_proto accepts "values" of a python scalar, a python list, a
    numpy ndarray, or a numpy scalar.

    If "values" is a python scalar or a python list, make_tensor_proto
    first converts it to a numpy ndarray. If dtype is None, the
    conversion tries its best to infer the right numpy data
    type. Otherwise, the resulting numpy array has a compatible data
    type with the given dtype.

    In either case above, the numpy ndarray (either the caller provided
    or the auto converted) must have a type compatible with dtype.

    make_tensor_proto then converts the numpy array to a tensor proto.

    If "shape" is None, the resulting tensor proto represents the numpy
    array precisely.

    Otherwise, "shape" specifies the tensor's shape and the numpy array
    can not have more elements than what "shape" specifies.
    """
    if dtype:
        dtype = types.as_dtype(dtype)

    # We first convert value to a numpy array or scalar.
    if isinstance(values, (np.ndarray, np.generic)):
        if dtype:
            nparray = values.astype(dtype.as_numpy_dtype)
        else:
            nparray = values
    else:
        if values is None:
            raise ValueError("None values not supported.")
        # if dtype is provided, forces numpy array to be the type
        # provided if possible.
        np_dt = dtype.as_numpy_dtype if dtype else None
        # Only an explicitly requested zero-element shape takes the
        # empty-array path; don't rely on np.prod(None) happening to be != 0.
        if shape is not None and np.prod(shape) == 0:
            nparray = np.empty(shape, dtype=np_dt)
        else:
            _AssertCompatible(values, dtype)
            nparray = np.array(values, dtype=np_dt)
            if list(nparray.shape) != _GetDenseDimensions(values):
                raise ValueError("Argument must be a dense tensor: %s" % values)
        # python/numpy default float type is float64. We prefer float32 instead.
        if (nparray.dtype == np.float64) and dtype is None:
            nparray = nparray.astype(np.float32)
        # python/numpy default int type is int64. We prefer int32 instead.
        elif (nparray.dtype == np.int64) and dtype is None:
            nparray = nparray.astype(np.int32)

    # if dtype is provided, it must be compatible with what numpy
    # conversion says.
    numpy_dtype = types.as_dtype(nparray.dtype)
    if numpy_dtype is None:
        raise TypeError("Unrecognized data type: %s" % nparray.dtype)

    # If dtype was specified and is a quantized type, we convert
    # numpy_dtype back into the quantized version.
    if dtype in [types.qint8, types.quint8, types.qint32]:
        numpy_dtype = dtype

    if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype:
        raise TypeError("Incompatible types: %s vs. %s" % (dtype, nparray.dtype))

    # If shape is not given, get the shape from the numpy array.
    if shape is None:
        shape = nparray.shape
        is_same_size = True
        shape_size = nparray.size
    else:
        shape = [int(dim) for dim in shape]
        shape_size = np.prod(shape)
        is_same_size = shape_size == nparray.size

        if nparray.size > shape_size:
            raise ValueError(
                "Too many elements provided. Needed at most %d, but received %d"
                % (shape_size, nparray.size))

    tensor_proto = tensor_pb2.TensorProto(
        dtype=numpy_dtype.as_datatype_enum,
        tensor_shape=MakeTensorShapeProto(shape))

    if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
        # `tobytes()` is the same raw C-order dump as the deprecated
        # `tostring()`, which has been removed from recent NumPy releases.
        tensor_proto.tensor_content = nparray.tobytes()
        return tensor_proto

    # If we were not given values as a numpy array, compute the proto_values
    # from the given values directly, to avoid numpy trimming nulls from the
    # strings. Since values could be a list of strings, or a multi-dimensional
    # list of lists that might or might not correspond to the given shape,
    # we flatten it conservatively.
    if numpy_dtype == types.string and not isinstance(values, np.ndarray):
        proto_values = _FlattenToStrings(values)
        tensor_proto.string_val.extend([str(x) for x in proto_values])
        return tensor_proto

    # TensorFlow expects C order (a.k.a., eigen row major).
    proto_values = nparray.ravel()

    append_fn = GetNumpyAppendFn(proto_values.dtype)
    if append_fn is None:
        raise TypeError(
            "Element type not supported in TensorProto: %s" % numpy_dtype.name)
    append_fn(tensor_proto, proto_values)

    return tensor_proto
Пример #33
0
  def get_variable(self, name, shape=None, dtype=types.float32,
                   initializer=None, reuse=None, trainable=True,
                   collections=None):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored in `self._vars`, the
    stored variable is returned after checking that the requested shape and
    dtype are compatible with it. Otherwise, a new variable is created and
    stored.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    If `reuse` is `None` (the default), both new and existing variables are
    returned.

    If `initializer` is `None` (the default), a new
    `uniform_unit_scaling_initializer()` is used for a newly created variable.

    Args:
      name: the name of the new or existing variable.
      shape: shape of the new or existing variable.
      dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: initializer for the variable.
      reuse: a Boolean or `None`. Controls reuse or creation of variables.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).

    Returns:
      The created or existing variable.

    Raises:
      ValueError: when creating a new variable and shape is not fully defined,
        when reusing a variable and specifying a conflicting shape or dtype,
        or when violating reuse during variable creation.
    """
    # reuse=None means "no constraint"; True/False are enforced below.
    should_check = reuse is not None
    dtype = types.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)
    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if should_check and not reuse:
        raise ValueError("Over-sharing: Variable %s already exists, disallowed."
                         " Did you mean to set reuse=True in VarScope?" % name)
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." % (name, shape,
                                                   found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." % (name, dtype_str,
                                                   found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if should_check and reuse:
      raise ValueError("Under-sharing: Variable %s does not exist, disallowed."
                       " Did you mean to set reuse=None in VarScope?" % name)
    if not shape.is_fully_defined():
      raise ValueError("Shape of a new variable (%s) must be fully defined, "
                       "but instead was %s." % (name, shape))
    if initializer is None:
      initializer = init_ops.uniform_unit_scaling_initializer()
    with ops.name_scope(name + "/Initializer/"):
      init_val = initializer(shape.as_list(), dtype=dtype)
    v = variables.Variable(init_val, name=name, trainable=trainable,
                           collections=collections)
    self._vars[name] = v
    logging.info("Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)
    return v
Пример #34
0
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or None if
    `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, six.string_types) for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if (return_elements is not None
      and not (isinstance(return_elements, (list, tuple))
               and all(isinstance(x, six.string_types)
                       for x in return_elements))):
    raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      output_types = _OutputTypes(node, op_dict)
      with _MaybeDevice(node.device):
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            # `str(te)` works on Python 2 and 3; exceptions have no
            # `.message` attribute on Python 3.
            raise ValueError(
                _InvalidNodeMessage(node, 'Input tensor %r %s'
                                    % (input_name, str(te))))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      # Sorted so the error message is deterministic.
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(sorted(unused_input_keys)))

    if return_elements is None:
      return None
    else:
      ret = []
      # Distinct loop name so the `name` parameter is not shadowed.
      for element_name in return_elements:
        if ':' in element_name:
          try:
            operation_name, output_index = _ParseTensorName(element_name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % element_name)
        else:
          try:
            ret.append(name_to_op[element_name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % element_name)
      return ret
Пример #35
0
def _SingleArgToTypes(node_def, arg_def):
  """Return the dtype list for one input/output arg of `node_def`.

  The non-ref dtypes come from `_ArgToTypesNoRef`. When `arg_def.is_ref` is
  set, each one is converted to its reference variant and returned as a
  datatype enum.
  """
  arg_types = _ArgToTypesNoRef(node_def, arg_def)
  if not arg_def.is_ref:
    return arg_types
  return [types_lib.as_dtype(t).as_ref.as_datatype_enum for t in arg_types]
Пример #36
0
 def testStringConversion(self):
   # Every dtype name, first the plain ones and then their "_ref" variants,
   # must resolve to the module-level DType singleton of the same name.
   base_names = ("float32", "float64", "int32", "uint8", "int16", "int8",
                 "string", "complex64", "int64", "bool", "qint8", "quint8",
                 "qint32", "bfloat16")
   for suffix in ("", "_ref"):
     for base_name in base_names:
       type_name = base_name + suffix
       self.assertIs(getattr(types, type_name), types.as_dtype(type_name))
   # Unknown names are rejected.
   with self.assertRaises(TypeError):
     types.as_dtype("not_a_type")
Пример #37
0
 def testNumpyConversion(self):
   # NumPy scalar types and dtypes map to the matching TensorFlow DType.
   # NOTE: `np.object` / `np.bool` were plain aliases of the builtins `object`
   # and `bool`; NumPy deprecated them in 1.20 and removed them in 1.24, so the
   # builtins are used directly (same values, no AttributeError on new NumPy).
   self.assertIs(types.float32, types.as_dtype(np.float32))
   self.assertIs(types.float64, types.as_dtype(np.float64))
   self.assertIs(types.int32, types.as_dtype(np.int32))
   self.assertIs(types.int64, types.as_dtype(np.int64))
   self.assertIs(types.uint8, types.as_dtype(np.uint8))
   self.assertIs(types.int16, types.as_dtype(np.int16))
   self.assertIs(types.int8, types.as_dtype(np.int8))
   self.assertIs(types.complex64, types.as_dtype(np.complex64))
   self.assertIs(types.string, types.as_dtype(object))
   self.assertIs(types.string, types.as_dtype(np.array(["foo", "bar"]).dtype))
   self.assertIs(types.bool, types.as_dtype(bool))
   # Structured (record) dtypes have no DType equivalent.
   with self.assertRaises(TypeError):
     types.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))
Пример #38
0
 def testInvalid(self):
   # DT_INVALID must be rejected both by the DType constructor and as_dtype().
   invalid_enum = types_pb2.DT_INVALID
   for make_dtype in (types.DType, types.as_dtype):
     with self.assertRaises(TypeError):
       make_dtype(invalid_enum)
Пример #39
0
def _SingleArgToTypes(node_def, arg_def):
    """Return the dtype list for one input/output arg of `node_def`.

    Dtypes are taken from `_ArgToTypesNoRef`; for ref args
    (`arg_def.is_ref`) each is mapped to its reference variant's enum.
    """
    arg_types = _ArgToTypesNoRef(node_def, arg_def)
    if not arg_def.is_ref:
        return arg_types
    return [types_lib.as_dtype(t).as_ref.as_datatype_enum for t in arg_types]
Пример #40
0
  def get_variable(self, name, shape=None, dtype=types.float32,
                   initializer=None, reuse=None, trainable=True,
                   collections=None):
    """Return the variable stored under `name`, creating it if needed.

    Variables are looked up by `name` in this store's `_vars` dict. An
    existing variable is returned if its shape and dtype are compatible with
    the requested ones. Otherwise a new `variables.Variable` is built,
    stored, and returned.

    `reuse` selects the sharing policy:
      * `True`: the variable must already exist.
      * `False`: the variable must not already exist.
      * `None` (default): either case is accepted.

    When a new variable is created and `initializer` is `None`, a fresh
    `init_ops.uniform_unit_scaling_initializer()` is used. NOTE(review): older
    docs mentioned a constructor-level default initializer; nothing in this
    method reads one, so presumably callers pass it in. Confirm this.

    Args:
      name: the name of the new or existing variable.
      shape: shape of the new or existing variable (converted with
        `tensor_shape.as_shape`). Must be fully defined for a new variable.
      dtype: type of the new or existing variable (converted with
        `types.as_dtype`; defaults to `types.float32`).
      initializer: callable invoked as `initializer(shape_list, dtype=dtype)`
        to produce a new variable's initial value.
      reuse: `True`, `False` or `None`; see above.
      trainable: forwarded to `variables.Variable`.
      collections: forwarded to `variables.Variable`.

    Returns:
      The created or existing variable.

    Raises:
      ValueError: when creating a new variable and shape is not fully defined,
        when reusing a variable and specifying an incompatible shape or dtype,
        or when the `reuse` policy is violated.
    """
    reuse_is_set = reuse is not None
    dtype = types.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name not in self._vars:
      # Creation path: only allowed when reuse is not forced to True.
      if reuse_is_set and reuse:
        raise ValueError("Under-sharing: Variable %s does not exist, disallowed."
                         " Did you mean to set reuse=None in VarScope?" % name)
      if not shape.is_fully_defined():
        raise ValueError("Shape of a new variable (%s) must be fully defined, "
                         "but instead was %s." % (name, shape))
      init_fn = (initializer if initializer is not None
                 else init_ops.uniform_unit_scaling_initializer())
      with ops.name_scope(name + "/Initializer/"):
        initial_value = init_fn(shape.as_list(), dtype=dtype)
      new_var = variables.Variable(initial_value, name=name,
                                   trainable=trainable,
                                   collections=collections)
      self._vars[name] = new_var
      logging.info("Created variable %s with shape %s and init %s",
                   new_var.name, format(shape), str(init_fn))
      return new_var

    # Sharing path: only allowed when reuse is not forced to False, and the
    # stored variable must agree with the requested shape and dtype.
    if reuse_is_set and not reuse:
      raise ValueError("Over-sharing: Variable %s already exists, disallowed."
                       " Did you mean to set reuse=True in VarScope?" % name)
    existing = self._vars[name]
    if not shape.is_compatible_with(existing.get_shape()):
      raise ValueError("Trying to share variable %s, but specified shape %s"
                       " and found shape %s." % (name, str(shape),
                                                 str(existing.get_shape())))
    if not dtype.is_compatible_with(existing.dtype):
      requested_name = dtype.name
      existing_name = existing.dtype.name
      raise ValueError("Trying to share variable %s, but specified dtype %s"
                       " and found dtype %s." % (name, str(requested_name),
                                                 str(existing_name)))
    return existing
Пример #41
0
 def testStringConversion(self):
   # Each dtype name, first the plain ones and then their "_ref" variants,
   # must resolve to the module-level DType singleton of the same name.
   base_names = ("float32", "float64", "int32", "uint8", "int16", "int8",
                 "string", "complex64", "int64", "bool", "qint8", "quint8",
                 "qint32", "bfloat16")
   for suffix in ("", "_ref"):
     for base_name in base_names:
       type_name = base_name + suffix
       self.assertIs(getattr(types, type_name), types.as_dtype(type_name))
   # Unknown names are rejected.
   with self.assertRaises(TypeError):
     types.as_dtype("not_a_type")
Пример #42
0
def MakeNdarray(tensor):
    """Create a numpy ndarray from a tensor.

    Create a numpy ndarray with the same shape and data as the tensor.

    Args:
      tensor: A TensorProto.

    Returns:
      A numpy array with the tensor contents. If `tensor_content` is set it is
      decoded directly. Otherwise the typed `*_val` field is used; when that
      field holds exactly one value (one real/imag pair for complex64), the
      value is repeated to fill the whole shape.

    Raises:
      TypeError: if tensor has unsupported type.
    """
    shape = [d.size for d in tensor.tensor_shape.dim]
    # np.prod([]) (a scalar's shape) is the float 1.0, and np.repeat rejects a
    # float repeat count on current NumPy, so keep the count an integer.
    num_elements = int(np.prod(shape, dtype=np.int64))
    tensor_dtype = types.as_dtype(tensor.dtype)
    dtype = tensor_dtype.as_numpy_dtype

    def _fill(value):
        # Broadcast a single stored value across every element of the tensor.
        return np.repeat(np.array(value, dtype=dtype),
                         num_elements).reshape(shape)

    def _from_values(values):
        # A single stored value means "fill"; otherwise values are the flat
        # contents in row-major order.
        if len(values) == 1:
            return _fill(values[0])
        return np.fromiter(values, dtype=dtype).reshape(shape)

    if tensor.tensor_content:
        # np.fromstring is deprecated for binary data. frombuffer returns a
        # read-only view of the bytes, so copy to keep the result writable.
        return np.frombuffer(tensor.tensor_content,
                             dtype=dtype).copy().reshape(shape)
    elif tensor_dtype == types.float32:
        return _from_values(tensor.float_val)
    elif tensor_dtype == types.float64:
        return _from_values(tensor.double_val)
    elif tensor_dtype in [
        types.int32,
        types.uint8,
        types.int16,
        types.int8,
        types.qint32,
        types.quint8,
        types.qint8,
        types.bfloat16,
    ]:
        # All of these narrower types share the proto's int_val field.
        return _from_values(tensor.int_val)
    elif tensor_dtype == types.int64:
        return _from_values(tensor.int64_val)
    elif tensor_dtype == types.string:
        # Keep the raw elements. str() on bytes yields "b'...'" on Python 3
        # (and is a no-op on Python 2). np.fromiter cannot build object
        # arrays, so np.array is used for the multi-value case.
        if len(tensor.string_val) == 1:
            return _fill(tensor.string_val[0])
        return np.array(list(tensor.string_val), dtype=dtype).reshape(shape)
    elif tensor_dtype == types.complex64:
        # scomplex_val holds interleaved (real, imag) float pairs.
        if len(tensor.scomplex_val) == 2:
            return _fill(complex(tensor.scomplex_val[0],
                                 tensor.scomplex_val[1]))
        it = iter(tensor.scomplex_val)
        return np.array([complex(real, imag) for real, imag in zip(it, it)],
                        dtype=dtype).reshape(shape)
    elif tensor_dtype == types.bool:
        return _from_values(tensor.bool_val)
    else:
        raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
Пример #43
0
 def testNumpyConversion(self):
   # NumPy scalar types and dtypes map to the matching TensorFlow DType.
   # NOTE: `np.object` / `np.bool` were plain aliases of the builtins `object`
   # and `bool`; NumPy deprecated them in 1.20 and removed them in 1.24, so the
   # builtins are used directly (same values, no AttributeError on new NumPy).
   self.assertIs(types.float32, types.as_dtype(np.float32))
   self.assertIs(types.float64, types.as_dtype(np.float64))
   self.assertIs(types.int32, types.as_dtype(np.int32))
   self.assertIs(types.int64, types.as_dtype(np.int64))
   self.assertIs(types.uint8, types.as_dtype(np.uint8))
   self.assertIs(types.int16, types.as_dtype(np.int16))
   self.assertIs(types.int8, types.as_dtype(np.int8))
   self.assertIs(types.complex64, types.as_dtype(np.complex64))
   self.assertIs(types.string, types.as_dtype(object))
   self.assertIs(types.string, types.as_dtype(np.array(["foo", "bar"]).dtype))
   self.assertIs(types.bool, types.as_dtype(bool))
   # Structured (record) dtypes have no DType equivalent.
   with self.assertRaises(TypeError):
     types.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))
Пример #44
0
  def apply_op(self, op_type_name, g=None, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If none of the inputs are Tensors and your session doesn't have a
       # default graph, you will have to specify the graph.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Inputs and attrs whose names clash with Python keywords/built-ins may be
    passed with a trailing underscore (e.g. `lambda_=...`). Any keyword left
    over after all inputs and attrs are consumed raises TypeError.
    NOTE(review): earlier docs said config proto extensions go in an 'ext'
    keyword; this body has no special handling for 'ext', so unless it names
    an attr it is rejected as an unexpected keyword. Confirm the intent.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
      # pylint: enable=protected-access
    except AssertionError as e:
      # `e.message` exists only on Python 2 exceptions; str(e) gives the same
      # text there and also works on Python 3.
      raise RuntimeError(
          "Need to specify g=graph to Op '%s' (could not determine graph due "
          "to: %s)" % (op_type_name, str(e)))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      # Maps an inferred attr name to the input name it was inferred from,
      # used only to make error messages point at the right argument.
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              # Use the first Tensor's dtype as the target for the others.
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

          try:
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, name=input_arg.name,
                dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
          except (TypeError, ValueError):
            assert dtype is not None, "Should not fail if dtype is None"
            assert input_arg.number_attr, "Should be number_attr case"
            # What types does the conversion function think values have?
            values = ops.convert_n_to_tensor_or_indexed_slices(values)
            observed = ", ".join(v.dtype.base_dtype.name for v in values)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, types_lib.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, types_lib.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]

          try:
            values = ops.convert_to_tensor(
                values, name=input_arg.name, dtype=dtype)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, types_lib.as_dtype(input_arg.type).name))
            else:
              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        # The "Unreachable" asserts below rely on the conversions above having
        # already forced the values to the expected dtype.
        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any(bt != base_types[0] for bt in base_types):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(types_lib.as_dtype(x).name for x in attr_value),
                   ", ".join(types_lib.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        # Ref inputs keep their ref dtypes; everything else is recorded by
        # base dtype.
        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        # An explicit None selects the OpDef's default value, if any.
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, attr_value.s,
                   '", "'.join(attr_def.allowed_values.list.s)))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, x,
                     '", "'.join(attr_def.allowed_values.list.s)))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      # One entry per output arg: list length n for list outputs, None for a
      # single tensor. Used by _Restructure to regroup the flat outputs.
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(t.list.type))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [types_lib.as_dtype(x).as_ref for x in types]
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # Add Op to graph
      if output_structure:
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
        outputs = op.outputs
        return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
                            output_structure)
      else:
        return g.create_op(op_type_name, inputs, output_types, name=scope,
                           input_types=input_types, attrs=attr_protos,
                           op_def=op_def)
Пример #45
0
 def testInvalid(self):
   # DT_INVALID must be rejected both by the DType constructor and as_dtype().
   invalid_enum = types_pb2.DT_INVALID
   for make_dtype in (types.DType, types.as_dtype):
     with self.assertRaises(TypeError):
       make_dtype(invalid_enum)
Пример #46
0
    def apply_op(self, op_type_name, g=None, name=None, **keywords):
        # pylint: disable=g-doc-args
        """Add a node invoking a registered Op to a graph.

        Example usage:
           # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
           # will convert to a Tensor.
           op_def_library.apply_op("op", input1=input1, input2=input2)
           # If none of the inputs are Tensors and your session doesn't have a
           # default graph, you will have to specify the graph.
           op_def_library.apply_op("op", input1=input1, g=g)
           # Can specify a node name.
           op_def_library.apply_op("op", input1=input1, name="node_name")
           # Must use keyword arguments, with the names specified in the OpDef.
           op_def_library.apply_op("op", input_name=input, attr_name=attr)

        An input or attr whose OpDef name clashes with a Python keyword or
        built-in may be passed with a trailing underscore instead (for an OpDef
        name ``lambda`` pass ``lambda_=...``).

        All attrs must either be inferred from an input or specified.
        (If inferred, the attr must not be specified.)  If an attr has a default
        value specified in the Op's OpDef, then you may pass None as the value
        of that attr to get the default.

        Args:
          op_type_name: string. Must match the name field of a registered Op.
          g: The graph context (optional). Passed, together with the input
            values, to ops._get_graph_from_inputs to pick the target graph.
          name: string. Optional name of the created op; defaults to
            op_type_name.
          **keywords: input Tensor and attr arguments specified by name. Any
            keyword that names neither an input nor an attr of the OpDef is
            rejected with a TypeError.

        Returns:
          The Tensor(s) representing the output of the operation, or the
          Operation itself if the OpDef declares no output args.

        Raises:
          RuntimeError: If op_type_name is not registered, or the graph cannot
            be determined and g was not given.
          TypeError: If an input or attr is missing, has the wrong type or
            kind (list vs. single value), an inferred attr is also given
            explicitly, or unexpected keyword arguments remain.
          ValueError: If a list input or list/int attr violates a length or
            minimum constraint, or a string attr is not among its allowed
            values. Errors raised by tensor conversion may also propagate.
        """
        op_info = self._ops.get(op_type_name, None)
        if op_info is None:
            raise RuntimeError("Unrecognized Op name " + op_type_name)
        op_def = op_info.op_def

        # Determine the graph context.
        try:
            # Need to flatten all the arguments into a list.
            # pylint: disable=protected-access
            g = ops._get_graph_from_inputs(_Flatten(keywords.values()),
                                           graph=g)
            # pylint: enable=protected-access
        except AssertionError as e:
            # Format the exception itself: exceptions have no `.message`
            # attribute on Python 3, so `e.message` would replace this error
            # with an unrelated AttributeError.
            raise RuntimeError(
                "Need to specify g=graph to Op '%s' (could not determine graph due "
                "to: %s)" % (op_type_name, e))

        # Default name if not specified.
        if name is None:
            name = op_type_name

        # Requires that op_def has passed validation (using the C++
        # ValidateOpDef() from ../framework/op_def_util.h).
        attrs = {}  # attr name -> Python value, inferred or caller-supplied
        inputs = []  # converted input values, flattened in OpDef order
        input_types = []  # dtype reported to create_op for each entry of inputs
        with g.as_default(), ops.name_scope(name) as scope:

            # Perform input type inference
            # attr name -> name of the input argument the attr was inferred from
            # (used only to produce better error messages).
            inferred_from = {}
            for input_arg in op_def.input_arg:
                input_name = input_arg.name
                if input_name in keywords:
                    values = keywords.pop(input_name)
                elif input_name + "_" in keywords:
                    # Handle the case where the name is a keyword or built-in
                    # for Python so we use the name + _ instead.
                    input_name += "_"
                    values = keywords.pop(input_name)
                else:
                    raise TypeError("No argument for input " + input_name)

                # Goals:
                # * Convert values to Tensors if it contains constants.
                # * Verify that values is a list if that matches the input_arg's
                #   type.
                # * If the input_arg's type is determined by attrs, either set
                #   those attrs and validate those attr values are legal (if
                #   they have not yet been set) or validate the input matches
                #   the type indicated by the attrs (if they have already been
                #   inferred via an earlier input).
                # * If the input_arg has an explicit type, make sure the input
                #   conforms.

                if _IsListParameter(input_arg):
                    if not _IsListValue(values):
                        raise TypeError(
                            "Expected list for '%s' argument to '%s' Op, not %s."
                            % (input_name, op_type_name, values))
                    # In cases where we expect all elements of the list to have the
                    # same dtype, try to cast non-Tensor elements to that type.
                    dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.number_attr:
                        if input_arg.type_attr in attrs:
                            dtype = attrs[input_arg.type_attr]
                        else:
                            # Take the dtype from the first element that is
                            # already a Tensor, if any.
                            for t in values:
                                if isinstance(t, ops.Tensor):
                                    dtype = t.dtype
                                    break

                    try:
                        values = ops.convert_n_to_tensor_or_indexed_slices(
                            values,
                            name=input_arg.name,
                            dtype=types_lib.as_dtype(dtype).base_dtype
                            if dtype else None)
                    except (TypeError, ValueError):
                        # Without an expected dtype there is no type mismatch to
                        # explain: surface the conversion error itself instead of
                        # masking it behind an AssertionError.
                        if dtype is None:
                            raise
                        assert input_arg.number_attr, "Should be number_attr case"
                        # What types does the conversion function think values have?
                        values = ops.convert_n_to_tensor_or_indexed_slices(
                            values)
                        observed = ", ".join(v.dtype.base_dtype.name
                                             for v in values)

                        prefix = (
                            "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s that do not match expected type %s." %
                                (prefix, types_lib.as_dtype(dtype).name))
                        elif input_arg.type_attr in attrs:
                            raise TypeError(
                                "%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, types_lib.as_dtype(dtype).name))
                        else:
                            raise TypeError("%s that don't all match." %
                                            prefix)

                    types = [x.dtype for x in values]
                    inputs.extend(values)
                else:
                    # In cases where we have an expected type, try to convert non-Tensor
                    # arguments to that type.
                    dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]

                    try:
                        values = ops.convert_to_tensor(values,
                                                       name=input_arg.name,
                                                       dtype=dtype)
                    except ValueError:
                        # With no expected dtype, re-converting below would fail
                        # the same way, and the type_attr lookup further down
                        # would raise KeyError; re-raise the original error.
                        if dtype is None:
                            raise
                        # What type does convert_to_tensor think it has?
                        observed = ops.convert_to_tensor(values).dtype.name
                        prefix = (
                            "Input '%s' of '%s' Op has type %s that does not match"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s expected type of %s." %
                                (prefix, types_lib.as_dtype(
                                    input_arg.type).name))
                        else:
                            # dtype is not None here, so it came from an
                            # already-inferred type_attr.
                            raise TypeError(
                                "%s type %s of argument '%s'." %
                                (prefix,
                                 types_lib.as_dtype(
                                     attrs[input_arg.type_attr]).name,
                                 inferred_from[input_arg.type_attr]))

                    types = [values.dtype]
                    inputs.append(values)
                base_types = [x.base_dtype for x in types]

                if input_arg.number_attr:
                    # <number-attr> * <type> or <number-attr> * <type-attr>
                    if input_arg.number_attr in attrs:
                        if len(values) != attrs[input_arg.number_attr]:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d must match "
                                "length %d of argument '%s'." %
                                (input_name, op_type_name, len(values),
                                 attrs[input_arg.number_attr],
                                 inferred_from[input_arg.number_attr]))
                    else:
                        attrs[input_arg.number_attr] = len(values)
                        inferred_from[input_arg.number_attr] = input_name
                        num_attr = _Attr(op_def, input_arg.number_attr)
                        if num_attr.has_minimum and len(
                                values) < num_attr.minimum:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d shorter "
                                "than minimum length %d." %
                                (input_name, op_type_name, len(values),
                                 num_attr.minimum))
                    # All tensors must have the same base type.
                    if any(bt != base_types[0] for bt in base_types):
                        raise TypeError(
                            "All tensors passed to '%s' of '%s' Op "
                            "must have the same type." %
                            (input_name, op_type_name))
                    if input_arg.type != types_pb2.DT_INVALID:
                        # <number-attr> * <type> case
                        if base_types and base_types[0] != input_arg.type:
                            assert False, "Unreachable"
                    elif input_arg.type_attr in attrs:
                        # <number-attr> * <type-attr> case, where <type-attr> already
                        # has an inferred value.
                        if base_types and base_types[0] != attrs[
                                input_arg.type_attr]:
                            assert False, "Unreachable"
                    else:
                        # <number-attr> * <type-attr> case, where we are now setting
                        # the <type-attr> based on this input
                        if not base_types:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                        _SatisfiesTypeConstraint(base_types[0], type_attr)
                elif input_arg.type_attr:
                    # <type-attr>
                    attr_value = base_types[0]
                    if input_arg.type_attr in attrs:
                        if attrs[input_arg.type_attr] != attr_value:
                            assert False, "Unreachable"
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type, _Attr(op_def, input_arg.type_attr))
                        attrs[input_arg.type_attr] = attr_value
                        inferred_from[input_arg.type_attr] = input_name
                elif input_arg.type_list_attr:
                    # <type-list-attr>
                    attr_value = base_types
                    if input_arg.type_list_attr in attrs:
                        if attrs[input_arg.type_list_attr] != attr_value:
                            raise TypeError(
                                "Input '%s' of '%s' Op has type list of %s that does not "
                                "match type list %s of argument '%s'." %
                                (input_name, op_type_name, ", ".join(
                                    types_lib.as_dtype(x).name
                                    for x in attr_value),
                                 ", ".join(
                                     types_lib.as_dtype(x).name
                                     for x in attrs[input_arg.type_list_attr]),
                                 inferred_from[input_arg.type_list_attr]))
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type,
                                _Attr(op_def, input_arg.type_list_attr))
                        attrs[input_arg.type_list_attr] = attr_value
                        inferred_from[input_arg.type_list_attr] = input_name
                else:
                    # single Tensor with specified type
                    if base_types[0] != input_arg.type:
                        assert False, "Unreachable"

                if input_arg.is_ref:
                    if not all(x.is_ref_dtype for x in types):
                        raise TypeError(
                            "Input '%s' of '%s' Op requires l-value input" %
                            (input_name, op_type_name))
                    input_types.extend(types)
                else:
                    input_types.extend(base_types)

            # Process remaining attrs
            for attr in op_def.attr:
                # Skip attrs that have already had their values inferred
                if attr.name in attrs:
                    if attr.name in keywords:
                        raise TypeError(
                            "Should not specify value for inferred attr '%s'."
                            % attr.name)
                    continue
                if attr.name in keywords:
                    attrs[attr.name] = keywords.pop(attr.name)
                elif attr.name + "_" in keywords:
                    # Attrs whose names match Python keywords have an extra '_'
                    # appended, so we must check for that as well.
                    attrs[attr.name] = keywords.pop(attr.name + "_")
                else:
                    raise TypeError("No argument for attr " + attr.name)

            # Convert attr values to AttrValue protos.
            attr_protos = {}
            for attr_def in op_def.attr:
                key = attr_def.name
                value = attrs[key]
                attr_value = attr_value_pb2.AttrValue()
                # An explicit None selects the OpDef's default value, if any.
                if attr_def.HasField("default_value") and value is None:
                    attr_value.CopyFrom(attr_def.default_value)
                    attr_protos[key] = attr_value
                    continue
                if attr_def.type.startswith("list("):
                    if not _IsListValue(value):
                        raise TypeError("Expected list for attr " + key)
                    if attr_def.has_minimum:
                        if len(value) < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed list of length %d "
                                "less than minimum %d." %
                                (key, op_type_name, len(value),
                                 attr_def.minimum))
                if attr_def.type == "string":
                    attr_value.s = _MakeStr(value, key)
                    if attr_def.HasField("allowed_values"):
                        if attr_value.s not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                %
                                (key, op_type_name, attr_value.s, '", "'.join(
                                    attr_def.allowed_values.list.s)))
                elif attr_def.type == "list(string)":
                    attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                    if attr_def.HasField("allowed_values"):
                        for x in attr_value.list.s:
                            if x not in attr_def.allowed_values.list.s:
                                raise ValueError(
                                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                    % (key, op_type_name, x, '", "'.join(
                                        attr_def.allowed_values.list.s)))
                elif attr_def.type == "int":
                    attr_value.i = _MakeInt(value, key)
                    if attr_def.has_minimum:
                        if attr_value.i < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed %d less than minimum %d."
                                % (key, op_type_name, attr_value.i,
                                   attr_def.minimum))
                elif attr_def.type == "list(int)":
                    attr_value.list.i.extend([_MakeInt(x, key) for x in value])
                elif attr_def.type == "float":
                    attr_value.f = _MakeFloat(value, key)
                elif attr_def.type == "list(float)":
                    attr_value.list.f.extend(
                        [_MakeFloat(x, key) for x in value])
                elif attr_def.type == "bool":
                    attr_value.b = _MakeBool(value, key)
                elif attr_def.type == "list(bool)":
                    attr_value.list.b.extend(
                        [_MakeBool(x, key) for x in value])
                elif attr_def.type == "type":
                    attr_value.type = _MakeType(value, attr_def)
                elif attr_def.type == "list(type)":
                    attr_value.list.type.extend(
                        [_MakeType(x, attr_def) for x in value])
                elif attr_def.type == "shape":
                    attr_value.shape.CopyFrom(_MakeShape(value, key))
                elif attr_def.type == "list(shape)":
                    attr_value.list.shape.extend(
                        [_MakeShape(x, key) for x in value])
                elif attr_def.type == "tensor":
                    attr_value.tensor.CopyFrom(_MakeTensor(value, key))
                elif attr_def.type == "list(tensor)":
                    attr_value.list.tensor.extend(
                        [_MakeTensor(x, key) for x in value])
                else:
                    raise TypeError("Unrecognized Attr type " + attr_def.type)

                attr_protos[key] = attr_value
            del attrs  # attrs is no longer authoritative, use attr_protos instead

            # Determine output types (possibly using attrs)
            output_types = []
            # One entry per output arg: the list length for list outputs,
            # or None for a single-tensor output. Consumed by _Restructure.
            output_structure = []
            for arg in op_def.output_arg:
                if arg.number_attr:
                    n = _AttrValue(attr_protos, arg.number_attr).i
                    if arg.type_attr:
                        types = [_AttrValue(attr_protos, arg.type_attr).type
                                 ] * n
                    else:
                        types = [arg.type] * n
                    output_structure.append(n)
                elif arg.type_attr:
                    t = _AttrValue(attr_protos, arg.type_attr)
                    types = [t.type]
                    output_structure.append(None)
                elif arg.type_list_attr:
                    t = _AttrValue(attr_protos, arg.type_list_attr)
                    types = t.list.type
                    output_structure.append(len(t.list.type))
                else:
                    types = [arg.type]
                    output_structure.append(None)
                if arg.is_ref:
                    types = [types_lib.as_dtype(x).as_ref for x in types]
                output_types.extend(types)

            # Every recognized input/attr keyword has been popped by now.
            if keywords:
                raise TypeError(
                    "apply_op() got unexpected keyword arguments: " +
                    ", ".join(sorted(keywords.keys())))

            # Add Op to graph
            op = g.create_op(op_type_name,
                             inputs,
                             output_types,
                             name=scope,
                             input_types=input_types,
                             attrs=attr_protos,
                             op_def=op_def)
            if output_structure:
                return _Restructure(
                    ops.convert_n_to_tensor_or_indexed_slices(op.outputs),
                    output_structure)
            # The OpDef declares no outputs: hand back the Operation itself.
            return op
Пример #47
0
 def testAllTypesConvertibleToDType(self):
   """Every DataType enum value except DT_INVALID round-trips via as_dtype."""
   for datatype_enum in types_pb2.DataType.values():
     # DT_INVALID is deliberately not convertible; it is skipped here.
     if datatype_enum == types_pb2.DT_INVALID:
       continue
     # Converting the enum to a DType and back must yield the same enum.
     self.assertEqual(
         datatype_enum, types.as_dtype(datatype_enum).as_datatype_enum)