Example #1
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).
    """
    self._set_scope(kwargs.pop('scope', None))

    # Ensure the Layer, if being reused, is working with inputs from
    # the same graph as where it was created.
    try:
      ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
    except ValueError as e:
      raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    with vs.variable_scope(self._scope,
                           reuse=self.built or self._reuse) as scope:
      with ops.name_scope(scope.original_name_scope):
        if not self.built:
          # Check input assumptions set before layer building, e.g. input rank.
          self._assert_input_compatibility(inputs)
          input_list = [
              ops.convert_to_tensor(x, name='input')
              for x in nest.flatten(inputs)]
          input_shapes = [x.get_shape() for x in input_list]
          if len(input_shapes) == 1:
            self.build(input_shapes[0])
          else:
            self.build(input_shapes)
        if 'scope' in tf_inspect.getargspec(self.call).args:
          kwargs['scope'] = scope
        # Check input assumptions set after layer building, e.g. input shape.
        self._assert_input_compatibility(inputs)
        outputs = self.call(inputs, *args, **kwargs)

        # Apply activity regularization.
        # Note that it should be applied every time the layer creates a new
        # output, since it is output-specific.
        if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
          output_list = _to_list(outputs)
          for output in output_list:
            with ops.name_scope('ActivityRegularizer'):
              activity_regularization = self.activity_regularizer(output)
            self.add_loss(activity_regularization)
            _add_elements_to_collection(
                activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    self.built = True
    return outputs
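The wrapper above implements a build-once pattern: on the first call it validates inputs, infers shapes, builds the layer, and only then delegates to `call`. A minimal, framework-free sketch of that control flow (illustrative names only, not TF code):

```python
# Minimal, framework-free sketch (an illustration, not TF code) of the
# build-once control flow that __call__ implements above.
class MiniLayer:
  def __init__(self):
    self.built = False

  def build(self, input_len):
    # Weights are created here, once the input "shape" is known.
    self.scale = 2.0

  def call(self, inputs):
    return [self.scale * x for x in inputs]

  def __call__(self, inputs):
    if not self.built:          # pre-processing: deferred build
      self.build(len(inputs))
      self.built = True
    return self.call(inputs)    # post-processing (e.g. regularization) here

layer = MiniLayer()
print(layer([1, 2, 3]))  # [2.0, 4.0, 6.0]; build() ran exactly once
```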
Example #2
def get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  1. If `graph` is provided, we validate that all inputs in `op_input_list` are
     from the same graph.
  2. Otherwise, we attempt to select a graph from the first Operation- or
     Tensor-valued input in `op_input_list`, and validate that all other
     such inputs are in the same graph.
  3. If the graph was not specified and it could not be inferred from
     `op_input_list`, we attempt to use the default graph.

  Args:
    op_input_list: A list of inputs to an operation, which may include `Tensor`,
      `Operation`, and other objects that may be converted to a graph element.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If `op_input_list` is not a list or tuple, or if graph is not a
      Graph.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs, or we could not find a graph
      and there was no default graph.

  Returns:
    The appropriate graph to use for the given inputs.
  """
  # pylint: disable=protected-access
  return ops._get_graph_from_inputs(op_input_list, graph)
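A hedged usage sketch of the three resolution rules, assuming a TF 1.x-style graph API (`tensorflow.compat.v1`) and the `get_graph_from_inputs` wrapper defined above:

```python
# Hedged sketch, assuming tensorflow.compat.v1 and the wrapper above.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

g1, g2 = tf.Graph(), tf.Graph()
with g1.as_default():
  a = tf.constant(1.0)
with g2.as_default():
  b = tf.constant(2.0)

assert get_graph_from_inputs([a]) is g1           # rule 2: inferred from `a`
assert get_graph_from_inputs([], graph=g2) is g2  # rule 1: explicit graph
try:
  get_graph_from_inputs([a, b])                   # inputs span two graphs
except ValueError as e:
  print("mixed graphs rejected:", e)
```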
Example #3
@contextlib.contextmanager
def variable_op_scope(values, name, default_name, initializer=None,
                      regularizer=None):
  """Returns a context manager for defining an op that creates variables.

  This context manager validates that the given `values` are from the
  same graph, ensures that that graph is the default graph, and pushes a
  name scope and a variable scope.

  If `name` is not None, it is used as is in the variable scope. If `name`
  is None, then `default_name` is used.  In that case, if the same name has
  been previously used in the same scope, it will be made unique by appending
  `_N` to it.

  This is intended to be used when defining generic ops and so reuse is always
  inherited.

  For example, to define a new Python op called `my_op_with_vars`:

  ```python
  def my_op_with_vars(a, b, name=None):
    with tf.variable_op_scope([a, b], name, "MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.get_variable('c')
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  Args:
    values: The list of `Tensor` arguments that are passed to the op function.
    name: The name argument that is passed to the op function; this name is
      not uniquified in the variable scope.
    default_name: The default name to use if the `name` argument is `None`;
      this name will be uniquified.
    initializer: A default initializer to pass to variable scope.
    regularizer: A default regularizer for variables within this scope.

  Returns:
    A context manager for use in defining a Python op.

  Raises:
    ValueError: when trying to reuse within a create scope, or create within
      a reuse scope, or if reuse is not `None` or `True`.
    TypeError: when the types of some arguments are not appropriate.
  """
  if default_name is None:
    raise TypeError("default_name cannot be None")
  g = ops._get_graph_from_inputs(values)  # pylint: disable=protected-access
  with g.as_default():
    if name:
      with variable_scope(name, initializer=initializer,
                          regularizer=regularizer) as vs:
        yield vs
    else:
      with ops.name_scope(default_name) as scope:
        count = len(default_name.split("/"))
        scoped_name = "/".join(scope.split("/")[-count - 1:-1])
        with _pure_variable_scope(scoped_name, initializer=initializer,
                                  regularizer=regularizer) as vs:
          yield vs
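A hedged sketch of the `_N` uniquification described in the docstring, assuming the surrounding TF 1.x module context (graph mode) and the `contextlib.contextmanager` decoration above:

```python
# Hedged sketch of the `_N` uniquification, assuming the surrounding
# TF 1.x module context (graph mode).
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

a = tf.constant(1.0)
with variable_op_scope([a], None, "MyOp") as vs1:
  pass
with variable_op_scope([a], None, "MyOp") as vs2:
  pass
print(vs1.name, vs2.name)  # expected: MyOp MyOp_1
```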
Example #4
def graph_zeros_like(tensor):
  """Graph-only version of tf.zeros_like(), for internal use only."""
  g = ops._get_graph_from_inputs([tensor])  # pylint: disable=protected-access
  with g.as_default(), ops.name_scope(None, "zeros_like", [tensor]) as name:
    tensor = ops.convert_to_tensor(tensor, name="tensor")
    dtype = tensor.dtype.base_dtype
    dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum)
    op = g.create_op("ZerosLike", [tensor], [dtype], input_types=[dtype],
                     attrs={"T": dtype_value}, name=name)
  result, = op.outputs
  return result
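A hedged usage sketch (TF 1.x graph mode assumed): because of `_get_graph_from_inputs` plus `g.as_default()`, the zeros op is created in the input's graph even when a different graph is the current default:

```python
# Hedged sketch, assuming TF 1.x graph mode and graph_zeros_like above.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

g1 = tf.Graph()
with g1.as_default():
  x = tf.constant([1.0, 2.0])

with tf.Graph().as_default():  # an unrelated default graph
  z = graph_zeros_like(x)

assert z.graph is g1           # op landed in x's graph, not the default
assert z.dtype == tf.float32
```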
Example #5
  def _distributed_apply(self, distribution, grads_and_vars, name):
    """`apply_gradients` using a `DistributionStrategy`."""
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor %s" % (var,))
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices)
      update_op = self._resource_apply_dense(grad, var)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    update_ops = []
    with backend.name_scope(name or self._name):
      for grad, var in grads_and_vars:
        scope_name = ("" if ops.executing_eagerly_outside_functions() else
                      "_" + var.op.name)
        with backend.name_scope("update" + scope_name):
          update_ops.extend(
              distribution.extended.update(
                  var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)
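`_distributed_apply` is reached indirectly: `apply_gradients` merges per-replica calls into it. A hedged sketch of the typical entry path, assuming TF 2.x with `tf.distribute`:

```python
# Hedged sketch (TF 2.x with tf.distribute assumed) of the path into
# _distributed_apply: apply_gradients inside a replica context.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0)
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

def step():
  with tf.GradientTape() as tape:
    loss = v * v
  grads = tape.gradient(loss, [v])
  # Under a strategy, this merges into _distributed_apply, which
  # batch-reduces the gradients and updates each replica's copy of v.
  opt.apply_gradients(zip(grads, [v]))

strategy.run(step)
print(v.numpy())  # 0.8 after one SGD step (1.0 - 0.1 * 2.0)
```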
Example #6
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: the kwarg `scope` is reserved for use by the Layer.

    Returns:
      Output tensor(s).
    """
    scope = kwargs.pop('scope', None)

    # Define a custom getter to override tf.get_variable when creating layer
    # variables. The variable scope below installs this getter, nesting it
    # inside any custom getter already set on an enclosing scope.
    def variable_getter(getter, name, shape, dtype=None, initializer=None,
                        regularizer=None, trainable=True, **getter_kwargs):
      return self._add_variable(
          name, shape, initializer=initializer, regularizer=regularizer,
          dtype=dtype, trainable=trainable,
          variable_getter=functools.partial(getter, **getter_kwargs))

    if not self._built and self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        self._scope = next(vs.variable_scope(
            scope if scope is not None else self._base_name).gen)
      else:
        self._scope = next(vs.variable_scope(
            scope, default_name=self._base_name).gen)
      self._name = self._scope.name

    # Build (if necessary) and call the layer, inside a variable
    # scope.
    with vs.variable_scope(self._scope,
                           reuse=True if self._built else self._reuse,
                           custom_getter=variable_getter) as scope:
      # Ensure the Layer, if being reused, is working with inputs from
      # the same graph as where it was created.
      try:
        ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
      except ValueError as e:
        raise ValueError("Inputs' and Layer's graphs are not the same: %s" % e)

      with ops.name_scope(scope.original_name_scope):
        if not self.built:
          input_list = [
              ops.convert_to_tensor(x, name='input')
              for x in nest.flatten(inputs)]
          input_shapes = [x.get_shape() for x in input_list]
          if len(input_shapes) == 1:
            self.build(input_shapes[0])
          else:
            self.build(input_shapes)
          self._built = True
        outputs = self.call(inputs, *args, **kwargs)

        # Apply activity regularization.
        # Note that it should be applied every time the layer creates a new
        # output, since it is output-specific.
        if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
          output_list = _to_list(outputs)
          for output in output_list:
            with ops.name_scope('ActivityRegularizer'):
              activity_regularization = self.activity_regularizer(output)
            self._losses.append(activity_regularization)
            _add_elements_to_collection(
                activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
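The `variable_getter` above relies on the `custom_getter` mechanism of variable scopes: the getter intercepts every `tf.get_variable` call made inside the scope. A hedged, self-contained sketch of that mechanism (TF 1.x API via `tensorflow.compat.v1`):

```python
# Hedged sketch of the custom_getter mechanism used above
# (tensorflow.compat.v1 assumed).
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

created = []

def tracking_getter(getter, name, *args, **kwargs):
  created.append(name)               # observe every variable creation
  return getter(name, *args, **kwargs)

with tf.variable_scope("demo", custom_getter=tracking_getter):
  w = tf.get_variable("w", shape=[2, 2])

print(created)  # ['demo/w']
```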
Example #7
  def apply_op(self, op_type_name, g=None, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Config proto extensions must be provided via the 'ext' keyword argument.
    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If none of the inputs are Tensors and your session doesn't have a
       # default graph, you will have to specify the graph.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
      # pylint: enable=protected-access
    except AssertionError as e:
      raise RuntimeError(
          "Need to specify g=graph to Op '%s' (could not determine graph due "
          "to: %s)" % (op_type_name, e))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

          try:
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, name=input_arg.name,
                dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
          except (TypeError, ValueError):
            assert dtype is not None, "Should not fail if dtype is None"
            assert input_arg.number_attr, "Should be number_attr case"
            # What types does the conversion function think values have?
            values = ops.convert_n_to_tensor_or_indexed_slices(values)
            observed = ", ".join(v.dtype.base_dtype.name for v in values)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, types_lib.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, types_lib.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]

          try:
            values = ops.convert_to_tensor(
                values, name=input_arg.name, dtype=dtype)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, types_lib.as_dtype(input_arg.type).name))
            else:
              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(types_lib.as_dtype(x).name for x in attr_value),
                   ", ".join(types_lib.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, attr_value.s,
                   '", "'.join(attr_def.allowed_values.list.s)))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, x,
                     '", "'.join(attr_def.allowed_values.list.s)))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(t.list.type))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [types_lib.as_dtype(x).as_ref for x in types]
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # Add Op to graph
      if output_structure:
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
        outputs = op.outputs
        return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
                            output_structure)
      else:
        return g.create_op(op_type_name, inputs, output_types, name=scope,
                           input_types=input_types, attrs=attr_protos,
                           op_def=op_def)
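One subtlety in the input/attr lookup above: OpDef argument names that collide with Python keywords or built-ins are passed with a trailing underscore. A self-contained mimic of that lookup rule (illustrative, not TF code):

```python
# Self-contained mimic (illustrative, not TF code) of apply_op's keyword
# lookup: `input` is a Python built-in, so callers pass `input_`.
def lookup_input(input_name, keywords):
  if input_name in keywords:
    return keywords.pop(input_name)
  if input_name + "_" in keywords:
    return keywords.pop(input_name + "_")
  raise TypeError("No argument for input " + input_name)

kw = {"input_": 3}
print(lookup_input("input", kw))  # 3; kw is now empty
```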
Example #8
    def apply_op(self, op_type_name, name=None, **keywords):
        # pylint: disable=g-doc-args
        """Add a node invoking a registered Op to a graph.

    Config proto extensions must be provided via the 'ext' keyword argument.
    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
        op_info = self._ops.get(op_type_name, None)
        if op_info is None:
            raise RuntimeError("Unrecognized Op name " + op_type_name)
        op_def = op_info.op_def

        # Determine the graph context.
        try:
            # Need to flatten all the arguments into a list.
            # pylint: disable=protected-access
            g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
            # pylint: enable=protected-access
        except AssertionError as e:
            raise RuntimeError(
                "Cannot determine graph for Op '%s' due to: %s" %
                (op_type_name, e))

        # Default name if not specified.
        if name is None:
            name = op_type_name

        # Requires that op_def has passed validation (using the C++
        # ValidateOpDef() from ../framework/op_def_util.h).
        attrs = {}
        inputs = []
        input_types = []
        with g.as_default(), ops.name_scope(name) as scope:

            # Perform input type inference
            inferred_from = {}
            for input_arg in op_def.input_arg:
                input_name = input_arg.name
                if input_name in keywords:
                    values = keywords.pop(input_name)
                elif input_name + "_" in keywords:
                    # Handle the case where the name is a keyword or built-in
                    # for Python so we use the name + _ instead.
                    input_name += "_"
                    values = keywords.pop(input_name)
                else:
                    raise TypeError("No argument for input " + input_name)

                # Goals:
                # * Convert values to Tensors if it contains constants.
                # * Verify that values is a list if that matches the input_arg's
                #   type.
                # * If the input_arg's type is determined by attrs, either set
                #   those attrs and validate those attr values are legal (if
                #   they have not yet been set) or validate the input matches
                #   the type indicated by the attrs (if they have already been
                #   inferred via an earlier input).
                # * If the input_arg has an explicit type, make sure the input
                #   conforms.

                if _IsListParameter(input_arg):
                    if not _IsListValue(values):
                        raise TypeError(
                            "Expected list for '%s' argument to '%s' Op, not %s."
                            % (input_name, op_type_name, values))
                    # In cases where we expect all elements of the list to have the
                    # same dtype, try to cast non-Tensor elements to that type.
                    dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.number_attr:
                        if input_arg.type_attr in attrs:
                            dtype = attrs[input_arg.type_attr]
                        else:
                            for t in values:
                                if isinstance(t, ops.Tensor):
                                    dtype = t.dtype
                                    break

                    try:
                        if not input_arg.is_ref and dtype:
                            dtype = dtypes.as_dtype(dtype).base_dtype
                        values = ops.convert_n_to_tensor_or_indexed_slices(
                            values,
                            name=input_arg.name,
                            dtype=dtype if dtype else None,
                            as_ref=input_arg.is_ref)
                    except (TypeError, ValueError):
                        assert dtype is not None, "Should not fail if dtype is None"
                        assert input_arg.number_attr, "Should be number_attr case"
                        # What types does the conversion function think values have?
                        values = ops.convert_n_to_tensor_or_indexed_slices(
                            values, as_ref=input_arg.is_ref)
                        observed = ", ".join(v.dtype.base_dtype.name
                                             for v in values)

                        prefix = (
                            "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s that do not match expected type %s." %
                                (prefix, dtype.name))
                        elif input_arg.type_attr in attrs:
                            raise TypeError(
                                "%s that do not match type %s inferred from "
                                "earlier arguments." % (prefix, dtype.name))
                        else:
                            raise TypeError("%s that don't all match." %
                                            prefix)

                    types = [x.dtype for x in values]
                    inputs.extend(values)
                else:
                    # In cases where we have an expected type, try to convert non-Tensor
                    # arguments to that type.
                    dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    try:
                        values = ops.convert_to_tensor(values,
                                                       name=input_arg.name,
                                                       dtype=dtype,
                                                       as_ref=input_arg.is_ref)
                    except ValueError:
                        # What type does convert_to_tensor think it has?
                        observed = ops.convert_to_tensor(
                            values, as_ref=input_arg.is_ref).dtype.name
                        prefix = (
                            "Input '%s' of '%s' Op has type %s that does not match"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s expected type of %s." %
                                (prefix, dtypes.as_dtype(input_arg.type).name))
                        else:
                            raise TypeError(
                                "%s type %s of argument '%s'." %
                                (prefix,
                                 dtypes.as_dtype(
                                     attrs[input_arg.type_attr]).name,
                                 inferred_from[input_arg.type_attr]))

                    types = [values.dtype]
                    inputs.append(values)
                base_types = [x.base_dtype for x in types]

                if input_arg.number_attr:
                    # <number-attr> * <type> or <number-attr> * <type-attr>
                    if input_arg.number_attr in attrs:
                        if len(values) != attrs[input_arg.number_attr]:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d must match "
                                "length %d of argument '%s'." %
                                (input_name, op_type_name, len(values),
                                 attrs[input_arg.number_attr],
                                 inferred_from[input_arg.number_attr]))
                    else:
                        attrs[input_arg.number_attr] = len(values)
                        inferred_from[input_arg.number_attr] = input_name
                        num_attr = _Attr(op_def, input_arg.number_attr)
                        if num_attr.has_minimum and len(
                                values) < num_attr.minimum:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d shorter "
                                "than minimum length %d." %
                                (input_name, op_type_name, len(values),
                                 num_attr.minimum))
                    # All tensors must have the same base type.
                    if any([bt != base_types[0] for bt in base_types]):
                        raise TypeError(
                            "All tensors passed to '%s' of '%s' Op "
                            "must have the same type." %
                            (input_name, op_type_name))
                    if input_arg.type != types_pb2.DT_INVALID:
                        # <number-attr> * <type> case
                        if base_types and base_types[0] != input_arg.type:
                            assert False, "Unreachable"
                    elif input_arg.type_attr in attrs:
                        # <number-attr> * <type-attr> case, where <type-attr> already
                        # has an inferred value.
                        if base_types and base_types[0] != attrs[
                                input_arg.type_attr]:
                            assert False, "Unreachable"
                    else:
                        # <number-attr> * <type-attr> case, where we are now setting
                        # the <type-attr> based on this input
                        if not base_types:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                        _SatisfiesTypeConstraint(base_types[0], type_attr)
                elif input_arg.type_attr:
                    # <type-attr>
                    attr_value = base_types[0]
                    if input_arg.type_attr in attrs:
                        if attrs[input_arg.type_attr] != attr_value:
                            assert False, "Unreachable"
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type, _Attr(op_def, input_arg.type_attr))
                        attrs[input_arg.type_attr] = attr_value
                        inferred_from[input_arg.type_attr] = input_name
                elif input_arg.type_list_attr:
                    # <type-list-attr>
                    attr_value = base_types
                    if input_arg.type_list_attr in attrs:
                        if attrs[input_arg.type_list_attr] != attr_value:
                            raise TypeError(
                                "Input '%s' of '%s' Op has type list of %s that does not "
                                "match type list %s of argument '%s'." %
                                (input_name, op_type_name, ", ".join(
                                    dtypes.as_dtype(x).name
                                    for x in attr_value),
                                 ", ".join(
                                     dtypes.as_dtype(x).name
                                     for x in attrs[input_arg.type_list_attr]),
                                 inferred_from[input_arg.type_list_attr]))
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type,
                                _Attr(op_def, input_arg.type_list_attr))
                        attrs[input_arg.type_list_attr] = attr_value
                        inferred_from[input_arg.type_list_attr] = input_name
                else:
                    # single Tensor with specified type
                    if base_types[0] != input_arg.type:
                        assert False, "Unreachable"

                if input_arg.is_ref:
                    if not all(x.is_ref_dtype for x in types):
                        raise TypeError(
                            "Input '%s' of '%s' Op requires l-value input" %
                            (input_name, op_type_name))
                    input_types.extend(types)
                else:
                    input_types.extend(base_types)

            # Process remaining attrs
            for attr in op_def.attr:
                # Skip attrs that have already had their values inferred
                if attr.name in attrs:
                    if attr.name in keywords:
                        raise TypeError(
                            "Should not specify value for inferred attr '%s'."
                            % attr.name)
                    continue
                if attr.name in keywords:
                    attrs[attr.name] = keywords.pop(attr.name)
                elif attr.name + "_" in keywords:
                    # Attrs whose names match Python keywords have an extra '_'
                    # appended, so we must check for that as well.
                    attrs[attr.name] = keywords.pop(attr.name + "_")
                else:
                    raise TypeError("No argument for attr " + attr.name)

            # Convert attr values to AttrValue protos.
            attr_protos = {}
            for attr_def in op_def.attr:
                key = attr_def.name
                value = attrs[key]
                attr_value = attr_value_pb2.AttrValue()
                if attr_def.HasField("default_value") and value is None:
                    attr_value.CopyFrom(attr_def.default_value)
                    attr_protos[key] = attr_value
                    continue
                if attr_def.type.startswith("list("):
                    if not _IsListValue(value):
                        raise TypeError("Expected list for attr " + key)
                    if attr_def.has_minimum:
                        if len(value) < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed list of length %d "
                                "less than minimum %d." %
                                (key, op_type_name, len(value),
                                 attr_def.minimum))
                if attr_def.type == "string":
                    attr_value.s = _MakeStr(value, key)
                    if attr_def.HasField("allowed_values"):
                        if attr_value.s not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name,
                                   compat.as_text(attr_value.s), '", "'.join(
                                       map(compat.as_text,
                                           attr_def.allowed_values.list.s))))
                elif attr_def.type == "list(string)":
                    attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                    if attr_def.HasField("allowed_values"):
                        for x in attr_value.list.s:
                            if x not in attr_def.allowed_values.list.s:
                                raise ValueError(
                                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                    %
                                    (key, op_type_name, compat.as_text(x),
                                     '", "'.join(
                                         map(compat.as_text,
                                             attr_def.allowed_values.list.s))))
                elif attr_def.type == "int":
                    attr_value.i = _MakeInt(value, key)
                    if attr_def.has_minimum:
                        if attr_value.i < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed %d less than minimum %d."
                                % (key, op_type_name, attr_value.i,
                                   attr_def.minimum))
                elif attr_def.type == "list(int)":
                    attr_value.list.i.extend([_MakeInt(x, key) for x in value])
                elif attr_def.type == "float":
                    attr_value.f = _MakeFloat(value, key)
                elif attr_def.type == "list(float)":
                    attr_value.list.f.extend(
                        [_MakeFloat(x, key) for x in value])
                elif attr_def.type == "bool":
                    attr_value.b = _MakeBool(value, key)
                elif attr_def.type == "list(bool)":
                    attr_value.list.b.extend(
                        [_MakeBool(x, key) for x in value])
                elif attr_def.type == "type":
                    attr_value.type = _MakeType(value, attr_def)
                elif attr_def.type == "list(type)":
                    attr_value.list.type.extend(
                        [_MakeType(x, attr_def) for x in value])
                elif attr_def.type == "shape":
                    attr_value.shape.CopyFrom(_MakeShape(value, key))
                elif attr_def.type == "list(shape)":
                    attr_value.list.shape.extend(
                        [_MakeShape(x, key) for x in value])
                elif attr_def.type == "tensor":
                    attr_value.tensor.CopyFrom(_MakeTensor(value, key))
                elif attr_def.type == "list(tensor)":
                    attr_value.list.tensor.extend(
                        [_MakeTensor(x, key) for x in value])
                elif attr_def.type == "func":
                    if not isinstance(value, compat.bytes_or_text_types):
                        raise TypeError("Expects a string for the func name")
                    attr_value.func.name = value
                else:
                    raise TypeError("Unrecognized Attr type " + attr_def.type)

                attr_protos[key] = attr_value
            del attrs  # attrs is no longer authoritative, use attr_protos instead

            # Determine output types (possibly using attrs)
            output_types = []
            output_structure = []
            for arg in op_def.output_arg:
                types = []
                if arg.number_attr:
                    n = _AttrValue(attr_protos, arg.number_attr).i
                    if arg.type_attr:
                        types = [_AttrValue(attr_protos, arg.type_attr).type
                                 ] * n
                    else:
                        types = [arg.type] * n
                    output_structure.append(n)
                elif arg.type_attr:
                    t = _AttrValue(attr_protos, arg.type_attr)
                    types = [t.type]
                    output_structure.append(None)
                elif arg.type_list_attr:
                    t = _AttrValue(attr_protos, arg.type_list_attr)
                    types = t.list.type
                    output_structure.append(len(t.list.type))
                else:
                    types = [arg.type]
                    output_structure.append(None)
                if arg.is_ref:
                    types = [dtypes.as_dtype(x).as_ref for x in types]
                output_types.extend(types)

            if keywords:
                raise TypeError(
                    "apply_op() got unexpected keyword arguments: " +
                    ", ".join(sorted(keywords.keys())))

            # Add Op to graph
            if output_structure:
                op = g.create_op(op_type_name,
                                 inputs,
                                 output_types,
                                 name=scope,
                                 input_types=input_types,
                                 attrs=attr_protos,
                                 op_def=op_def)
                outputs = op.outputs
                return _Restructure(
                    ops.convert_n_to_tensor_or_indexed_slices(outputs),
                    output_structure)
            else:
                return g.create_op(op_type_name,
                                   inputs,
                                   output_types,
                                   name=scope,
                                   input_types=input_types,
                                   attrs=attr_protos,
                                   op_def=op_def)
Example #9
    def _apply_op_helper(self, op_type_name, name=None, **keywords):
        """Implementation of apply_op that returns output_structure, op."""
        #    print("KST_TEST_ophelper")
        #    print(op_type_name)
        op_info = self._ops.get(op_type_name, None)
        if op_info is None:
            raise RuntimeError("Unrecognized Op name " + op_type_name)
        op_def = op_info.op_def

        # Determine the graph context.
        try:
            # Need to flatten all the arguments into a list.
            # pylint: disable=protected-access
            g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
            # pylint: enable=protected-access
        except AssertionError as e:
            raise RuntimeError(
                "Cannot determine graph for Op '%s' due to: %s" %
                (op_type_name, e))

        # Default name if not specified.
        if name is None:
            name = op_type_name

        # Check for deprecation
        deprecation_version = op_def.deprecation.version
        if deprecation_version:
            producer = g.graph_def_versions.producer
            if producer >= deprecation_version:
                raise NotImplementedError(
                    ("Op %s is not available in GraphDef version %d. "
                     "It has been removed in version %d. %s.") %
                    (op_type_name, producer, deprecation_version,
                     op_def.deprecation.explanation))

        # Fill in the list of default types for all "type" attrs.  This
        # will be used to choose a preferred dtype to convert to in the
        # absence of input type information.
        #
        # TODO(b/31302892): Currently the defaults don't work in the right
        # way if you have two inputs, one of whose type resolution depends
        # on the other.  Handling this will require restructuring this code
        # significantly.
        default_type_attr_map = {}
        for attr_def in op_def.attr:
            if attr_def.type != "type":
                continue
            key = attr_def.name
            if attr_def.HasField("default_value"):
                default_type_attr_map[key] = dtypes.as_dtype(
                    attr_def.default_value.type)

        # Requires that op_def has passed validation (using the C++
        # ValidateOpDef() from ../framework/op_def_util.h).
        attrs = {}
        inputs = []
        input_types = []
        with g.as_default(), ops.name_scope(name) as scope:

            # Perform input type inference
            inferred_from = {}
            for input_arg in op_def.input_arg:
                input_name = input_arg.name
                if input_name in keywords:
                    values = keywords.pop(input_name)
                elif input_name + "_" in keywords:
                    # Handle the case where the name is a keyword or built-in
                    # for Python so we use the name + _ instead.
                    input_name += "_"
                    values = keywords.pop(input_name)
                else:
                    raise TypeError("No argument for input " + input_name)

                # Goals:
                # * Convert values to Tensors if it contains constants.
                # * Verify that values is a list if that matches the input_arg's
                #   type.
                # * If the input_arg's type is determined by attrs, either set
                #   those attrs and validate those attr values are legal (if
                #   they have not yet been set) or validate the input matches
                #   the type indicated by the attrs (if they have already been
                #   inferred via an earlier input).
                # * If the input_arg has an explicit type, make sure the input
                #   conforms.

                if _IsListParameter(input_arg):
                    if not _IsListValue(values):
                        raise TypeError(
                            "Expected list for '%s' argument to '%s' Op, not %s."
                            % (input_name, op_type_name, values))
                    # In cases where we expect all elements of the list to have the
                    # same dtype, try to cast non-Tensor elements to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.number_attr:
                        if input_arg.type_attr in attrs:
                            dtype = attrs[input_arg.type_attr]
                        else:
                            for t in values:
                                if isinstance(t, ops.Tensor):
                                    dtype = t.dtype
                                    break

                        # dtype still not found, prefer using the default dtype
                        # from the attr.
                        if dtype is None and input_arg.type_attr in default_type_attr_map:
                            default_dtype = default_type_attr_map[
                                input_arg.type_attr]

                    try:
                        if not input_arg.is_ref and dtype:
                            dtype = dtypes.as_dtype(dtype).base_dtype
                        values = ops.internal_convert_n_to_tensor(
                            values,
                            name=input_arg.name,
                            dtype=dtype if dtype else None,
                            preferred_dtype=default_dtype,
                            as_ref=input_arg.is_ref)
                        if input_arg.number_attr and len(
                                set(v.dtype.base_dtype for v in values)) > 1:
                            raise TypeError()  # All types should match.
                    except (TypeError, ValueError):
                        # What types does the conversion function think values have?
                        observed_types = []
                        for value in values:
                            try:
                                converted_value = ops.internal_convert_to_tensor(
                                    value, as_ref=input_arg.is_ref)
                                observed_types.append(
                                    converted_value.dtype.base_dtype.name)
                            except (TypeError, ValueError):
                                observed_types.append(
                                    "<NOT CONVERTIBLE TO TENSOR>")
                        observed = ", ".join(observed_types)

                        prefix = (
                            "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                            % (input_name, op_type_name, observed))
                        if input_arg.number_attr:
                            if input_arg.type != types_pb2.DT_INVALID:
                                raise TypeError(
                                    "%s that do not match expected type %s." %
                                    (prefix, dtype.name))
                            elif input_arg.type_attr in attrs:
                                raise TypeError(
                                    "%s that do not match type %s inferred from "
                                    "earlier arguments." %
                                    (prefix, dtype.name))
                            else:
                                raise TypeError("%s that don't all match." %
                                                prefix)
                        else:
                            raise TypeError(
                                "%s that are invalid. Tensors: %s" %
                                (prefix, values))

                    types = [x.dtype for x in values]
                    inputs.extend(values)
                else:
                    # In cases where we have an expected type, try to convert non-Tensor
                    # arguments to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    elif input_arg.type_attr in default_type_attr_map:
                        # The dtype could not be inferred solely from the inputs,
                        # so we prefer the attr's default; that way, code that adds
                        # a new attr with a default stays backwards compatible.
                        default_dtype = default_type_attr_map[
                            input_arg.type_attr]

                    try:
                        values = ops.internal_convert_to_tensor(
                            values,
                            name=input_arg.name,
                            dtype=dtype,
                            as_ref=input_arg.is_ref,
                            preferred_dtype=default_dtype)
                    except TypeError as err:
                        if dtype is None:
                            raise err
                        else:
                            raise TypeError(
                                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                                "type '%s' instead. Error: %s" %
                                (dtypes.as_dtype(dtype).name,
                                 input_arg.name, op_type_name, repr(values),
                                 type(values).__name__, err))
                    except ValueError:
                        # What type does convert_to_tensor think it has?
                        try:
                            observed = ops.internal_convert_to_tensor(
                                values, as_ref=input_arg.is_ref).dtype.name
                        except ValueError as err:
                            raise ValueError(
                                "Tried to convert '%s' to a tensor and failed. Error: %s"
                                % (input_name, err))
                        prefix = (
                            "Input '%s' of '%s' Op has type %s that does not match"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s expected type of %s." %
                                (prefix, dtypes.as_dtype(input_arg.type).name))
                        else:
                            # Update the maps with the default, if needed.
                            k = input_arg.type_attr
                            if k in default_type_attr_map:
                                if k not in attrs:
                                    attrs[k] = default_type_attr_map[k]
                                    if k not in inferred_from:
                                        inferred_from[k] = "Default in OpDef"

                            raise TypeError(
                                "%s type %s of argument '%s'." %
                                (prefix,
                                 dtypes.as_dtype(
                                     attrs[input_arg.type_attr]).name,
                                 inferred_from[input_arg.type_attr]))

                    types = [values.dtype]
                    inputs.append(values)
                base_types = [x.base_dtype for x in types]

                if input_arg.number_attr:
                    # <number-attr> * <type> or <number-attr> * <type-attr>
                    if input_arg.number_attr in attrs:
                        if len(values) != attrs[input_arg.number_attr]:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d must match "
                                "length %d of argument '%s'." %
                                (input_name, op_type_name, len(values),
                                 attrs[input_arg.number_attr],
                                 inferred_from[input_arg.number_attr]))
                    else:
                        attrs[input_arg.number_attr] = len(values)
                        inferred_from[input_arg.number_attr] = input_name
                        num_attr = _Attr(op_def, input_arg.number_attr)
                        if num_attr.has_minimum and len(
                                values) < num_attr.minimum:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d shorter "
                                "than minimum length %d." %
                                (input_name, op_type_name, len(values),
                                 num_attr.minimum))
                    # All tensors must have the same base type.
                    if any(bt != base_types[0] for bt in base_types):
                        raise TypeError(
                            "All tensors passed to '%s' of '%s' Op "
                            "must have the same type." %
                            (input_name, op_type_name))
                    if input_arg.type != types_pb2.DT_INVALID:
                        # <number-attr> * <type> case
                        if base_types and base_types[0] != input_arg.type:
                            assert False, "Unreachable"
                    elif input_arg.type_attr in attrs:
                        # <number-attr> * <type-attr> case, where <type-attr> already
                        # has an inferred value.
                        if base_types and base_types[0] != attrs[
                                input_arg.type_attr]:
                            assert False, "Unreachable"
                    else:
                        # <number-attr> * <type-attr> case, where we are now setting
                        # the <type-attr> based on this input
                        if not base_types:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                        _SatisfiesTypeConstraint(base_types[0],
                                                 type_attr,
                                                 param_name=input_name)
                elif input_arg.type_attr:
                    # <type-attr>
                    attr_value = base_types[0]
                    if input_arg.type_attr in attrs:
                        if attrs[input_arg.type_attr] != attr_value:
                            assert False, "Unreachable"
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(base_type,
                                                     _Attr(
                                                         op_def,
                                                         input_arg.type_attr),
                                                     param_name=input_name)
                        attrs[input_arg.type_attr] = attr_value
                        inferred_from[input_arg.type_attr] = input_name
                elif input_arg.type_list_attr:
                    # <type-list-attr>
                    attr_value = base_types
                    if input_arg.type_list_attr in attrs:
                        if attrs[input_arg.type_list_attr] != attr_value:
                            raise TypeError(
                                "Input '%s' of '%s' Op has type list of %s that does not "
                                "match type list %s of argument '%s'." %
                                (input_name, op_type_name, ", ".join(
                                    dtypes.as_dtype(x).name
                                    for x in attr_value),
                                 ", ".join(
                                     dtypes.as_dtype(x).name
                                     for x in attrs[input_arg.type_list_attr]),
                                 inferred_from[input_arg.type_list_attr]))
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type,
                                _Attr(op_def, input_arg.type_list_attr),
                                param_name=input_name)
                        attrs[input_arg.type_list_attr] = attr_value
                        inferred_from[input_arg.type_list_attr] = input_name
                else:
                    # single Tensor with specified type
                    if base_types[0] != input_arg.type:
                        assert False, "Unreachable"

                if input_arg.is_ref:
                    if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
                        raise TypeError((
                            "'%s' Op requires that input '%s' be a mutable tensor "
                            "(e.g.: a tf.Variable)") %
                                        (op_type_name, input_name))
                    input_types.extend(types)
                else:
                    input_types.extend(base_types)

            # Process remaining attrs
            for attr in op_def.attr:
                # Skip attrs that have already had their values inferred
                if attr.name in attrs:
                    if attr.name in keywords:
                        raise TypeError(
                            "Should not specify value for inferred attr '%s'."
                            % attr.name)
                    continue
                if attr.name in keywords:
                    attrs[attr.name] = keywords.pop(attr.name)
                elif attr.name + "_" in keywords:
                    # Attrs whose names match Python keywords have an extra '_'
                    # appended, so we must check for that as well.
                    attrs[attr.name] = keywords.pop(attr.name + "_")
                else:
                    raise TypeError("No argument for attr " + attr.name)

            # Convert attr values to AttrValue protos.
            attr_protos = {}
            for attr_def in op_def.attr:
                key = attr_def.name
                value = attrs[key]
                attr_value = attr_value_pb2.AttrValue()
                if attr_def.HasField("default_value") and value is None:
                    attr_value.CopyFrom(attr_def.default_value)
                    attr_protos[key] = attr_value
                    continue
                if attr_def.type.startswith("list("):
                    if not _IsListValue(value):
                        raise TypeError("Expected list for attr " + key)
                    if attr_def.has_minimum:
                        if len(value) < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed list of length %d "
                                "less than minimum %d." %
                                (key, op_type_name, len(value),
                                 attr_def.minimum))
                    attr_value.list.SetInParent()
                if attr_def.type == "string":
                    attr_value.s = _MakeStr(value, key)
                    if attr_def.HasField("allowed_values"):
                        if attr_value.s not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name,
                                   compat.as_text(attr_value.s), '", "'.join(
                                       map(compat.as_text,
                                           attr_def.allowed_values.list.s))))
                elif attr_def.type == "list(string)":
                    attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                    if attr_def.HasField("allowed_values"):
                        for x in attr_value.list.s:
                            if x not in attr_def.allowed_values.list.s:
                                raise ValueError(
                                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                    %
                                    (key, op_type_name, compat.as_text(x),
                                     '", "'.join(
                                         map(compat.as_text,
                                             attr_def.allowed_values.list.s))))
                elif attr_def.type == "int":
                    attr_value.i = _MakeInt(value, key)
                    if attr_def.has_minimum:
                        if attr_value.i < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed %d less than minimum %d."
                                % (key, op_type_name, attr_value.i,
                                   attr_def.minimum))
                elif attr_def.type == "list(int)":
                    attr_value.list.i.extend([_MakeInt(x, key) for x in value])
                elif attr_def.type == "float":
                    attr_value.f = _MakeFloat(value, key)
                elif attr_def.type == "list(float)":
                    attr_value.list.f.extend(
                        [_MakeFloat(x, key) for x in value])
                elif attr_def.type == "bool":
                    attr_value.b = _MakeBool(value, key)
                elif attr_def.type == "list(bool)":
                    attr_value.list.b.extend(
                        [_MakeBool(x, key) for x in value])
                elif attr_def.type == "type":
                    attr_value.type = _MakeType(value, attr_def)
                elif attr_def.type == "list(type)":
                    attr_value.list.type.extend(
                        [_MakeType(x, attr_def) for x in value])
                elif attr_def.type == "shape":
                    attr_value.shape.CopyFrom(_MakeShape(value, key))
                elif attr_def.type == "list(shape)":
                    attr_value.list.shape.extend(
                        [_MakeShape(x, key) for x in value])
                elif attr_def.type == "tensor":
                    attr_value.tensor.CopyFrom(_MakeTensor(value, key))
                elif attr_def.type == "list(tensor)":
                    attr_value.list.tensor.extend(
                        [_MakeTensor(x, key) for x in value])
                elif attr_def.type == "func":
                    attr_value.func.CopyFrom(_MakeFunc(value, key))
                elif attr_def.type == "list(func)":
                    attr_value.list.func.extend(
                        [_MakeFunc(x, key) for x in value])
                else:
                    raise TypeError("Unrecognized Attr type " + attr_def.type)

                attr_protos[key] = attr_value
            del attrs  # attrs is no longer authoritative, use attr_protos instead

            # Determine output types (possibly using attrs)
            output_structure = []
            for arg in op_def.output_arg:
                if arg.number_attr:
                    n = _AttrValue(attr_protos, arg.number_attr).i
                    output_structure.append(n)
                elif arg.type_attr:
                    t = _AttrValue(attr_protos, arg.type_attr)
                    output_structure.append(None)
                elif arg.type_list_attr:
                    t = _AttrValue(attr_protos, arg.type_list_attr)
                    output_structure.append(len(t.list.type))
                else:
                    output_structure.append(None)

            if keywords:
                raise TypeError(
                    "apply_op() got unexpected keyword arguments: " +
                    ", ".join(sorted(keywords.keys())))

            # NOTE(mrry): We add an explicit colocation constraint between
            # the newly created op and any of its reference-typed inputs.
            must_colocate_inputs = [
                val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref
            ]
            with _MaybeColocateWith(must_colocate_inputs):
                # Add Op to graph
                op = g.create_op(op_type_name,
                                 inputs,
                                 dtypes=None,
                                 name=scope,
                                 input_types=input_types,
                                 attrs=attr_protos,
                                 op_def=op_def)
            return output_structure, op_def.is_stateful, op
Example #10
    def _distributed_apply(distribution, grads_and_vars, name, apply_state):
        """`apply_gradients` using a `DistributionStrategy`."""
        reduced_grads = distribution.extended.batch_reduce_to(
            ds_reduce_util.ReduceOp.SUM, grads_and_vars)
        var_list = [v for _, v in grads_and_vars]
        grads_and_vars = zip(reduced_grads, var_list)

        def apply_grad_to_update_var(var, grad):
            """Apply gradient to variable."""
            if isinstance(var, ops.Tensor):
                raise NotImplementedError("Trying to update a Tensor ", var)

            apply_kwargs = {}
            if not isinstance(var, de.TrainableWrapper):
                if isinstance(grad, ops.IndexedSlices):
                    if var.constraint is not None:
                        raise RuntimeError(
                            "Cannot use a constraint function on a sparse variable."
                        )
                    if "apply_state" in self._sparse_apply_args:
                        apply_kwargs["apply_state"] = apply_state
                    return self._resource_apply_sparse_duplicate_indices(
                        grad.values, var, grad.indices, **apply_kwargs)

                if "apply_state" in self._dense_apply_args:
                    apply_kwargs["apply_state"] = apply_state
                update_op = self._resource_apply_dense(grad, var,
                                                       **apply_kwargs)
                if var.constraint is not None:
                    with ops.control_dependencies([update_op]):
                        return var.assign(var.constraint(var))
                else:
                    return update_op
            else:
                with ops.colocate_with(None, ignore_existing=True):
                    _slots = [
                        self.get_slot(var, _s) for _s in self.get_slot_names()
                    ]
                    # Add the optimizer slots to restricting list.
                    if var.params.restrict_policy is not None:
                        var.params.restrict_policy._track_optimizer_slots(
                            _slots)

                    with ops.control_dependencies([grad]):
                        _before = [var.read_value()
                                   ] + [_s.read_value() for _s in _slots]
                    if isinstance(grad, ops.IndexedSlices):
                        if var.constraint is not None:
                            raise RuntimeError(
                                "Cannot use a constraint function on a sparse variable."
                            )
                        if "apply_state" in self._sparse_apply_args:
                            apply_kwargs["apply_state"] = apply_state
                        with ops.control_dependencies(_before):
                            _apply_op = self._resource_apply_sparse_duplicate_indices(
                                grad.values, var, grad.indices, **apply_kwargs)
                        with ops.control_dependencies([_apply_op]):
                            _after = control_flow_ops.group(
                                [var.update_op()] +
                                [_s.update_op() for _s in _slots])
                            return _after

                    if "apply_state" in self._dense_apply_args:
                        apply_kwargs["apply_state"] = apply_state
                    with ops.control_dependencies(_before):
                        update_op = self._resource_apply_dense(
                            grad, var, **apply_kwargs)
                    if var.constraint is not None:
                        with ops.control_dependencies([update_op]):
                            return var.assign(var.constraint(var))
                    else:
                        with ops.control_dependencies([update_op]):
                            _after = control_flow_ops.group(
                                [var.update_op()] +
                                [_s.update_op() for _s in _slots])
                        return _after

        update_ops = []
        with backend.name_scope(name or self._name):
            for grad, var in grads_and_vars:
                scope_name = ("update"
                              if ops.executing_eagerly_outside_functions() else
                              "update_" + var.op.name)
                # Colocate the update with variables to avoid unnecessary communication
                # delays. See b/136304694.
                with backend.name_scope(
                        scope_name), distribution.extended.colocate_vars_with(
                            var):
                    update_ops.extend(
                        distribution.extended.update(var,
                                                     apply_grad_to_update_var,
                                                     args=(grad, ),
                                                     group=False))

            any_symbolic = any(
                isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i)
                for i in update_ops)
            if not context.executing_eagerly() or any_symbolic:
                # If the current context is graph mode or any of the update ops are
                # symbolic then the step update should be carried out under a graph
                # context. (eager updates execute immediately)
                with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
                    with ops.control_dependencies(update_ops):
                        return self._iterations.assign_add(1).op

            return self._iterations.assign_add(1)
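
For orientation, here is a minimal sketch of how this code path is normally reached. It uses only public TF 2.x APIs; the choice of MirroredStrategy and SGD is an assumption for illustration, and the de.TrainableWrapper branch above is not exercised.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    var = tf.Variable(1.0)
    opt = tf.keras.optimizers.SGD(learning_rate=0.1)

def train_step():
    with tf.GradientTape() as tape:
        loss = var * var
    grads = tape.gradient(loss, [var])
    # apply_gradients dispatches to _distributed_apply, which reduces the
    # gradients across replicas and updates each variable in colocation.
    opt.apply_gradients(zip(grads, [var]))

strategy.run(train_step)
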
Example #11
def _current_graph(op_input_list):
    """Return the graph members of `op_input_list`, or the current graph."""
    # pylint: disable=protected-access
    return ops._get_graph_from_inputs(op_input_list)
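
A small sketch of the rule this helper relies on, using the TF 1.x compat layer: when all graph-valued inputs come from one graph, that graph is selected; with no graph-valued inputs, the default graph is used. The constant values are assumptions for illustration.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0)
    b = a + 1.0

# Both tensors live in `g`, so any op built from them must use `g`;
# mixing tensors from different graphs raises a ValueError instead.
assert a.graph is g and b.graph is g
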
Example #12
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op."""
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e.message}")

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          f"Op {op_type_name} is not available in GraphDef version {producer}. "
          f"It has been removed in version {deprecation_version}. "
          f"{op_def.deprecation.explanation}.")

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python, so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError(f"No argument for input {input_name} found in {op_def}")

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              f"Expected list for '{input_name}' argument to '{op_type_name}' "
              f"Op, not {values}.")
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          all_types = set(v.dtype.base_dtype for v in values)
          if input_arg.number_attr and len(all_types) > 1:
            # All types should match.
            raise TypeError(f"Not all types matched for {input_arg.name} for "
                            f"{op_type_name}. Got {all_types}")
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError(f"{prefix} that do not match expected type "
                              f"{dtype.name}.")
            elif input_arg.type_attr in attrs:
              raise TypeError(f"{prefix} that do not match type {dtype.name} "
                              "inferred from earlier arguments.")
            else:
              raise TypeError(f"{prefix} that don't all match.")
          else:
            raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default; that way, code that adds a new
          # attr with a default stays backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
          # not list allowed dtypes, in which case we should skip this.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError as err:
              # When converting a python object such as a list of Dimensions, we
              # need a dtype to be specified, thus tensor conversion may throw
              # an exception which we will ignore and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
                f"'{input_arg.name}' of op '{op_type_name}', got "
                f"{repr(values)} of type '{type(values).__name__}' instead. "
                f"Error: {err}")
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                f"Tried to convert '{input_name}' to a tensor and failed. "
                f"Error: {err}")
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} expected type of "
                            f"{dtypes.as_dtype(input_arg.type).name}.")
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                f"{prefix} type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} must match length "
                f"{attrs[input_arg.number_attr]} of argument "
                f"'{inferred_from[input_arg.number_attr]}'.")
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} shorter than minimum length "
                f"{num_attr.minimum}.")
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
              f"must have the same type. Got {base_types} instead.")
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  f"list passed to input '{input_name}' of '{op_type_name}' "
                  "Op.")
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type "
                f"{dtypes.as_dtype(attr_value).name} that does not match type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            actual_types = ", ".join(
                dtypes.as_dtype(x).name for x in attr_value)
            expected_types = ", ".join(
                dtypes.as_dtype(x).name
                for x in attrs[input_arg.type_list_attr])
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type list of "
                f"{actual_types} that does not match type list {expected_types}"
                f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              f"'{op_type_name}' Op requires that input '{input_name}' be a "
              "mutable tensor (e.g.: a tf.Variable)")
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              f"Should not specify value for inferred attr '{attr.name}' for "
              f"{op_type_name}.")
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError(f"No argument found for attr {attr.name} for "
                        f"{op_type_name}")

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          for value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(value, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for value in attr_value.list.type:
          _SatisfiesTypeConstraint(value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr, op_type_name)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      all_keywords = ", ".join(sorted(keywords.keys()))
      raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                      f"{all_keywords}.")

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` per se can be decoupled, which lets the
    # `op_callbacks` machinery function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
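
The usual callers of _apply_op_helper are the generated op wrappers; tf.raw_ops exposes the same keyword-only calling convention publicly. A hedged sketch (TF 2.x; AddV2 is chosen only as an example op):

import tensorflow as tf

x = tf.constant([1.0, 2.0])
y = tf.constant([3.0, 4.0])

# Inputs are passed by their OpDef argument names, and type attrs such as
# T are inferred from the inputs, mirroring the inference loop above.
z = tf.raw_ops.AddV2(x=x, y=y)
print(z.numpy())  # [4. 6.]
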
Example #13
    def __init__(self,
                 inputs,
                 labels,
                 name: str = None,
                 params: dict = None):
        """Initialize the model."""

        if name in self.__model_names:
            # TODO: allow? reuse?
            # raise ValueError("Model name `%s` already exists." % name)
            pass

        if name is None:
            # Generate some random name
            name = names.get_first_name(gender='female')
            while name in self.__model_names:
                name = names.get_first_name(gender='female')

        self._name = name
        self.__model_names.add(self._name)

        # noinspection PyProtectedMember
        self._graph = ops._get_graph_from_inputs([inputs, labels])  # pylint: disable=protected-access
        self._graph_context_manager = self._graph.as_default()
        self._graph_context_manager.__enter__()

        with tf.variable_scope('input_layer'):
            self._x = inputs['x']

        with tf.variable_scope('labels'):
            self._labels = labels

        self._layers = [self._x]
        self._layer_name_dct = dict()
        "Dictionary of ([str]layer_name, [int]count)"

        if params is None:
            params = dict()

        # Configurable and directly accessible properties
        for param, default in [('batch_size', None), ('learning_rate', None)]:
            # Make sure flags take precedence over YAML arguments.
            if not tf.app.flags.FLAGS.is_parsed():
                value = default
            else:
                value = tf.app.flags.FLAGS.get_flag_value(param, default)
            if value is None:
                try:
                    value = params[param]
                except KeyError:
                    raise ValueError("attribute `{}` cannot be of type {}.".format(param, type(None)))

            self.__setattr__(param, tf.constant(value, tf.float32))

        _decay_dct = params.get('learning_rate_decay', None)
        if isinstance(_decay_dct, list):  # yaml will pass list here, need to unpack
            _decay_dct, = _decay_dct

        self.learning_rate_decay = _decay_dct is not None
        "bool"

        if self.learning_rate_decay:
            self._decay_steps = tf.constant(_decay_dct.get('decay_steps'), tf.float32)
            self._decay_rate = tf.constant(_decay_dct.get('decay_rate'), tf.float32)

            self.learning_rate = tf.train.exponential_decay(learning_rate=self.learning_rate,
                                                            global_step=tf.train.get_global_step(),
                                                            decay_steps=self._decay_steps,
                                                            decay_rate=self._decay_rate,
                                                            )

        tf.summary.scalar(tensor=self.learning_rate, name='learning_rate')

        self.optimizer = getattr(tf.train, params.get('optimizer', 'AdamOptimizer'))
        "tf.train.Optimizer"
Example #14
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.
    Returns:
      Output tensor(s).
    """
        scope = kwargs.pop('scope', None)

        # Define a custom getter to override tf.get_variable when creating layer
        # variables. The custom getter is attached to the variable scope opened
        # below, so it composes with any getter already set on that scope.
        def variable_getter(getter,
                            name,
                            shape,
                            dtype=None,
                            initializer=None,
                            regularizer=None,
                            trainable=True,
                            **getter_kwargs):
            return self._add_variable(name,
                                      shape,
                                      initializer=initializer,
                                      regularizer=regularizer,
                                      dtype=dtype,
                                      trainable=trainable,
                                      variable_getter=functools.partial(
                                          getter, **getter_kwargs))

        if not self._built and self._scope is None:
            # If constructed with _scope=None, lazy setting of scope.
            if self._reuse:
                self._scope = next(
                    vs.variable_scope(
                        scope if scope is not None else self._base_name).gen)
            else:
                self._scope = next(
                    vs.variable_scope(scope, default_name=self._base_name).gen)
            self._name = self._scope.name

        # Build (if necessary) and call the layer, inside a variable
        # scope.
        with vs.variable_scope(self._scope,
                               reuse=True if self._built else self._reuse,
                               custom_getter=variable_getter) as scope:
            # Ensure the Layer, if being reused, is working with inputs from
            # the same graph as where it was created.
            try:
                ops._get_graph_from_inputs(nest.flatten(inputs),
                                           graph=self.graph)  # pylint: disable=protected-access
            except ValueError as e:
                raise ValueError(
                    "Inputs' and Layer's graphs are not the same: %s" % e)

            with ops.name_scope(scope.original_name_scope):
                if not self.built:
                    input_list = [
                        ops.convert_to_tensor(x, name='input')
                        for x in nest.flatten(inputs)
                    ]
                    input_shapes = [x.get_shape() for x in input_list]
                    if len(input_shapes) == 1:
                        self.build(input_shapes[0])
                    else:
                        self.build(input_shapes)
                    self._built = True
                if 'scope' in tf_inspect.getargspec(self.call).args:
                    kwargs['scope'] = scope
                outputs = self.call(inputs, *args, **kwargs)

                # Apply activity regularization.
                # Note that it should be applied every time the layer creates a new
                # output, since it is output-specific.
                if hasattr(
                        self,
                        'activity_regularizer') and self.activity_regularizer:
                    output_list = _to_list(outputs)
                    for output in output_list:
                        with ops.name_scope('ActivityRegularizer'):
                            activity_regularization = self.activity_regularizer(
                                output)
                        self._losses.append(activity_regularization)
                        _add_elements_to_collection(
                            activity_regularization,
                            ops.GraphKeys.REGULARIZATION_LOSSES)

        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
        return outputs
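
A hedged sketch of the wrapped call path, using the TF 1.x layers API through the compat shim; the Dense layer and shapes are assumptions for illustration:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, 8])
layer = tf.layers.Dense(units=4, name='proj')
y1 = layer(x)  # first call: a scope is chosen lazily and variables are built
y2 = layer(x)  # later calls: reuse=True, the same variables are returned
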
Example #15
    def _distributed_apply(self, distribution, grads_and_vars, name,
                           apply_state):
        """`apply_gradients` using a `DistributionStrategy`."""
        reduced_grads = self._aggregate_gradients(distribution, grads_and_vars)
        var_list = [v for _, v in grads_and_vars]
        grads_and_vars = zip(reduced_grads, var_list)

        def apply_grad_to_update_var(var, grad):
            """Apply gradient to variable."""
            if isinstance(var, ops.Tensor):
                raise NotImplementedError("Trying to update a Tensor ", var)

            apply_kwargs = {}
            if isinstance(grad, ops.IndexedSlices):
                if var.constraint is not None:
                    raise RuntimeError(
                        "Cannot use a constraint function on a sparse variable."
                    )
                if "apply_state" in self._sparse_apply_args:
                    apply_kwargs["apply_state"] = apply_state
                return self._resource_apply_sparse_duplicate_indices(
                    grad.values, var, grad.indices, **apply_kwargs)

            if "apply_state" in self._dense_apply_args:
                apply_kwargs["apply_state"] = apply_state
            update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
            if var.constraint is not None:
                with ops.control_dependencies([update_op]):
                    return var.assign(var.constraint(var))
            else:
                return update_op

        eagerly_outside_functions = ops.executing_eagerly_outside_functions()
        update_ops = []
        with ops.name_scope(name or self._name, skip_on_eager=True):
            for grad, var in grads_and_vars:
                # Colocate the update with variables to avoid unnecessary communication
                # delays. See b/136304694.
                with distribution.extended.colocate_vars_with(var):
                    with ops.name_scope("update" if eagerly_outside_functions
                                        else "update_" + var.op.name,
                                        skip_on_eager=True):
                        update_ops.extend(
                            distribution.extended.update(
                                var,
                                apply_grad_to_update_var,
                                args=(grad, ),
                                group=False))

            any_symbolic = any(
                isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i)
                for i in update_ops)
            if not context.executing_eagerly() or any_symbolic:
                # If the current context is graph mode or any of the update ops are
                # symbolic then the step update should be carried out under a graph
                # context. (eager updates execute immediately)
                with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
                    with ops.control_dependencies(update_ops):
                        return self._iterations.assign_add(1).op

            return self._iterations.assign_add(1)
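
The reduction hidden behind _aggregate_gradients can be illustrated with the public replica-context API; a hedged sketch (TF 2.x; the exact call site inside Keras may differ):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step():
    grad = tf.constant(0.5)
    ctx = tf.distribute.get_replica_context()
    # SUM-reduce the per-replica gradient, as the optimizer does before
    # applying updates to each variable.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, grad)

summed = strategy.run(step)
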
Example No. 16
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

        Arguments:
          inputs: input tensor(s).
          *args: additional positional arguments to be passed to `self.call`.
          **kwargs: additional keyword arguments to be passed to `self.call`.
            **Note**: kwarg `scope` is reserved for use by the layer.
        Returns:
          Output tensor(s).
        """
        self._set_scope(kwargs.pop('scope', None))

        # Ensure the Layer, if being reused, is working with inputs from
        # the same graph as where it was created.
        try:
            ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
        except ValueError as e:
            raise ValueError(
                'Input graph and Layer graph are not the same: %s' % e)

        with vs.variable_scope(
                self._scope, reuse=self.built or self._reuse) as scope:
            with ops.name_scope(scope.original_name_scope):
                if not self.built:
                    # Check input assumptions set before layer building, e.g. input rank.
                    self._assert_input_compatibility(inputs)
                    input_list = [
                        ops.convert_to_tensor(x, name='input')
                        for x in nest.flatten(inputs)
                    ]
                    input_shapes = [x.get_shape() for x in input_list]
                    if len(input_shapes) == 1:
                        self.build(input_shapes[0])
                    else:
                        self.build(input_shapes)
                if 'scope' in getargspec(self.call).args:
                    kwargs['scope'] = scope
                # Check input assumptions set after layer building, e.g. input shape.
                self._assert_input_compatibility(inputs)
                outputs = self.call(inputs, *args, **kwargs)

                # Apply activity regularization.
                # Note that it should be applied every time the layer creates a new
                # output, since it is output-specific.
                if (hasattr(self, 'activity_regularizer')
                        and self.activity_regularizer):
                    output_list = _to_list(outputs)
                    for output in output_list:
                        with ops.name_scope('ActivityRegularizer'):
                            activity_regularization = self.activity_regularizer(
                                output)
                        self.add_loss(activity_regularization)
                        _add_elements_to_collection(
                            activity_regularization,
                            ops.GraphKeys.REGULARIZATION_LOSSES)

        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
        self.built = True
        return outputs
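The reserved `scope` kwarg can be seen from a subclass: because `call` declares a `scope` parameter, the wrapper injects the active variable scope automatically. A minimal sketch, with the hypothetical `MyScaler` class assuming the TF1-style `tf.compat.v1.layers` API:

```python
# A hypothetical subclass (MyScaler) sketching the reserved `scope` kwarg:
# since call() declares `scope`, __call__ injects the active variable scope.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

class MyScaler(tf.layers.Layer):
    def build(self, input_shape):
        self.gain = self.add_weight('gain', [], initializer=tf.ones_initializer())

    def call(self, inputs, scope=None):
        # `scope` arrives from the wrapper; useful for scope-aware naming.
        print('running under scope:', scope.name)
        return inputs * self.gain

x = tf.placeholder(tf.float32, [None, 3])
y = MyScaler(name='scaler')(x)  # builds, then calls call(inputs, scope=<'scaler'>)
```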
Example No. 17
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` came from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
    """
    scope = kwargs.pop('scope', None)

    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = function_utils.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
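The `_keras_style` guard above is observable from user code. A sketch, assuming a v1 build where the experimental `tf.compat.v1.layers.experimental.keras_style_scope` is available to flip layers into Keras style:

```python
# Sketch of the keras-style guard, assuming the v1 experimental API
# tf.compat.v1.layers.experimental.keras_style_scope is available.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.layers.experimental.keras_style_scope():
    layer = tf.layers.Dense(2)  # constructed with _keras_style = True

x = tf.placeholder(tf.float32, [None, 3])
try:
    layer(x, scope='explicit')  # rejected: keras-style layers own their scoping
except ValueError as e:
    print(e)  # "scope argument not allowed when keras style layers are enabled..."
```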
Example No. 18
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

        Arguments:
          inputs: input tensor(s).
          *args: additional positional arguments to be passed to `self.call`.
          **kwargs: additional keyword arguments to be passed to `self.call`.
            **Note**: kwarg `scope` is reserved for use by the layer.

        Returns:
          Output tensor(s).

        Note:
          - If the layer's `call` method takes a `scope` keyword argument,
            this argument will be automatically set to the current variable
            scope.
          - If the layer's `call` method takes a `mask` argument (as some Keras
            layers do), its default value will be set to the mask generated
            for `inputs` by the previous layer (if `inputs` came from a layer
            that generated a corresponding mask, i.e. if it came from a Keras
            layer with masking support).

        Raises:
          ValueError: if the layer's `call` method returns None (an invalid
            value).
        """
        self._set_scope(kwargs.pop('scope', None))

        if not context.executing_eagerly():
            try:
                # Set layer's "graph" at build time
                self._graph = ops._get_graph_from_inputs(  # pylint: disable=protected-access
                    nest.flatten(inputs), graph=self._graph)
            except ValueError as e:
                raise ValueError(
                    'Input graph and Layer graph are not the same: %s' % e)

        if self.built:
            try:
                # Some classes which inherit from Layer do not use its constructor, so
                # rather than initializing to None we check for an AttributeError.
                scope_context_manager = self._always_reuse_variable_scope
            except AttributeError:
                # From this point we will always set reuse=True, so create a "final"
                # variable scope with this setting. We avoid re-creating variable scopes
                # after this point as an optimization.
                self._always_reuse_variable_scope = vs.variable_scope(
                    self._scope, reuse=True, auxiliary_name_scope=False)
                scope_context_manager = self._always_reuse_variable_scope
        else:
            scope_context_manager = vs.variable_scope(
                self._scope, reuse=self._reuse, auxiliary_name_scope=False)

        with scope_context_manager as scope:
            self._current_scope = scope

            try:
                call_has_scope_arg = self._call_has_scope_arg
            except AttributeError:
                self._call_fn_args = function_utils.fn_args(self.call)
                self._call_has_scope_arg = 'scope' in self._call_fn_args
                call_has_scope_arg = self._call_has_scope_arg
            if call_has_scope_arg:
                kwargs['scope'] = scope

            # Actually call layer
            outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

        if not context.executing_eagerly():
            # Update global default collections.
            _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
        return outputs
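The cached reuse scope (`_always_reuse_variable_scope`) is what makes repeated calls share variables. A short sketch, assuming the TF1-style `tf.compat.v1.layers` API:

```python
# Sketch of the cached reuse scope in action: the second call re-enters the
# layer's scope with reuse=True, so both outputs share one kernel/bias pair.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

dense = tf.layers.Dense(4, name='shared')
a = tf.placeholder(tf.float32, [None, 3])
b = tf.placeholder(tf.float32, [None, 3])
y1 = dense(a)  # first call: builds variables under scope 'shared'
y2 = dense(b)  # second call: reuses them via _always_reuse_variable_scope
print([v.name for v in dense.trainable_variables])
# ['shared/kernel:0', 'shared/bias:0']
```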
Example No. 19
  def apply_op(self, op_type_name, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pylint: enable=protected-access
    except AssertionError as e:
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtype.name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError("%s that are invalid." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values,
                                             as_ref=input_arg.is_ref).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if isinstance(value, compat.bytes_or_text_types):
            attr_value.func.name = value
          else:
            value.add_to_graph(ops.get_default_graph())
            attr_value.func.name = value.name
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(types))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x).as_ref for x in types]
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
        if output_structure:
          outputs = op.outputs
          res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
          if isinstance(res, list) and not res and op_def.is_stateful:
            return op
          else:
            return res
        else:
          return op
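For orientation, a sketch of how `apply_op` was typically reached: the generated `gen_*_ops` modules in 1.x-era sources built a module-level `OpDefLibrary` from the registered `OpDef`s and routed every Python wrapper through it. The names below (`add`, `_op_def_lib`) follow that convention but are illustrative.

```python
# Illustrative sketch of a generated wrapper driving apply_op; _op_def_lib
# stands for the module-level OpDefLibrary that 1.x-era gen_*_ops files
# constructed from the registered OpDefs.
def add(x, y, name=None):
    # Inputs are matched to the OpDef's input_arg names; attrs such as T
    # are inferred from x and y by the type-inference loop above.
    return _op_def_lib.apply_op("Add", x=x, y=y, name=name)
```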
Example No. 20
@contextlib.contextmanager
def variable_op_scope(values,
                      name_or_scope,
                      default_name,
                      initializer=None,
                      regularizer=None,
                      caching_device=None,
                      reuse=None):
    """Returns a context manager for defining an op that creates variables.

    This context manager validates that the given `values` are from the
    same graph, ensures that that graph is the default graph, and pushes a
    name scope and a variable scope.

    If `name_or_scope` is not None, it is used as-is in the variable scope. If
    `name_or_scope` is None, then `default_name` is used.  In that case, if the
    same name has been previously used in the same scope, it is made unique by
    appending `_N` to it.

    This is intended to be used when defining generic ops and so reuse is
    always inherited.

    For example, to define a new Python op called `my_op_with_vars`:

    ```python
    def my_op_with_vars(a, b, scope=None):
      with tf.variable_op_scope([a, b], scope, "MyOp") as scope:
        a = tf.convert_to_tensor(a, name="a")
        b = tf.convert_to_tensor(b, name="b")
        c = tf.get_variable('c')
        # Define some computation that uses `a`, `b`, and `c`.
        return foo_op(..., name=scope)
    ```

    Args:
      values: The list of `Tensor` arguments that are passed to the op function.
      name_or_scope: The name argument that is passed to the op function;
        this name_or_scope is not uniquified in the variable scope.
      default_name: The default name to use if the `name_or_scope` argument is
        `None`; this name will be uniquified.
      initializer: The default initializer to pass to the variable scope.
      regularizer: The default regularizer for variables within this scope.
      caching_device: The default caching device for variables within this scope.
      reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as
        well as all sub-scopes; if `None`, we just inherit the parent scope reuse.

    Returns:
      A context manager for use in defining a Python op.

    Raises:
      ValueError: when trying to reuse within a create scope, or create within
        a reuse scope, or if reuse is not `None` or `True`.
      TypeError: when the types of some arguments are not appropriate.
    """
    if default_name is None:
        raise TypeError("default_name cannot be None")
    g = ops._get_graph_from_inputs(values)  # pylint: disable=protected-access
    with g.as_default():
        if name_or_scope:
            with variable_scope(name_or_scope,
                                reuse=reuse,
                                initializer=initializer,
                                regularizer=regularizer,
                                caching_device=caching_device) as vs:
                yield vs
        else:
            if reuse:
                raise ValueError(
                    "reuse=True cannot be used without a name_or_scope")
            with ops.name_scope(default_name) as scope:
                count = len(default_name.split("/"))
                scoped_name = "/".join(scope.split("/")[-count - 1:-1])
                with _pure_variable_scope(scoped_name,
                                          initializer=initializer,
                                          regularizer=regularizer,
                                          caching_device=caching_device) as vs:
                    yield vs
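A concrete sketch of the `default_name` uniquification described in the docstring, assuming the 0.x-era API where `tf.variable_op_scope` exists; `scaled_sum` is an illustrative op in the style of the docstring's `my_op_with_vars`:

```python
# A concrete sketch of default_name uniquification, assuming the 0.x-era
# API where tf.variable_op_scope exists; scaled_sum is illustrative.
import tensorflow as tf  # 0.x-era API assumed

def scaled_sum(a, b, scope=None):
    with tf.variable_op_scope([a, b], scope, "ScaledSum"):
        gain = tf.get_variable("gain", [],
                               initializer=tf.constant_initializer(1.0))
        return gain * (tf.convert_to_tensor(a) + tf.convert_to_tensor(b))

a, b = tf.constant(1.0), tf.constant(2.0)
out1 = scaled_sum(a, b)                 # "gain" lives under scope "ScaledSum"
out2 = scaled_sum(a, b)                 # fresh "gain" under "ScaledSum_1"
out3 = scaled_sum(a, b, scope="Fixed")  # explicit name, not uniquified
```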