コード例 #1
ファイル: base.py プロジェクト: AlbertXiebnu/tensorflow
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    NOTE(review): several source lines are missing from this excerpt (see the
    inline notes), so the method does not parse as shown.

    Args:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Raises:
      ValueError: re-raised with a layer-specific message when the graph
        check of `inputs` against `self.graph` raises `ValueError`.
    """
    # `scope` is consumed here so it is not forwarded to `self.call` unless
    # re-added below.
    self._set_scope(kwargs.pop('scope', None))

    # Ensure the Layer, if being reused, is working with inputs from
    # the same graph as where it was created.
    # NOTE(review): the `try:` matching the `except` below is missing here.
      ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
    except ValueError as e:
      raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    # Variables are reused once the layer is built, or when `_reuse` is set.
    with vs.variable_scope(self._scope,
                           reuse=self.built or self._reuse) as scope:
      with ops.name_scope(scope.original_name_scope):
        if not self.built:
          # Check input assumptions set before layer building, e.g. input rank.
          input_list = [
              ops.convert_to_tensor(x, name='input')
              for x in nest.flatten(inputs)]
          input_shapes = [x.get_shape() for x in input_list]
          if len(input_shapes) == 1:
          # NOTE(review): the body of the `if` above is missing from this
          # excerpt.
        # Only pass `scope` through when `self.call` declares that argument.
        if 'scope' in tf_inspect.getargspec(self.call).args:
          kwargs['scope'] = scope
        # Check input assumptions set after layer building, e.g. input shape.
        outputs = self.call(inputs, *args, **kwargs)

        # Apply activity regularization.
        # Note that it should be applied every time the layer creates a new
        # output, since it is output-specific.
        if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
          output_list = _to_list(outputs)
          for output in output_list:
            with ops.name_scope('ActivityRegularizer'):
              activity_regularization = self.activity_regularizer(output)
              # NOTE(review): the line opening the call that the next line
              # continues is missing from this excerpt.
                activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    self.built = True
    return outputs
コード例 #2
ファイル: ops.py プロジェクト: 1000sprites/tensorflow
def get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  Thin public wrapper that forwards both arguments to
  `ops._get_graph_from_inputs`.

  1. If `graph` is provided, we validate that all inputs in `op_input_list` are
     from the same graph.
  2. Otherwise, we attempt to select a graph from the first Operation- or
     Tensor-valued input in `op_input_list`, and validate that all other
     such inputs are in the same graph.
  3. If the graph was not specified and it could not be inferred from
     `op_input_list`, we attempt to use the default graph.

  Args:
    op_input_list: A list of inputs to an operation, which may include `Tensor`,
      `Operation`, and other objects that may be converted to a graph element.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If `op_input_list` is not a list or tuple, or if `graph` is not
      of the expected type (the original sentence is cut off in the source
      text -- TODO confirm the exact type).
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs, or we could not find a graph
      and there was no default graph.

  Returns:
    The appropriate graph to use for the given inputs.
  """
  # pylint: disable=protected-access
  return ops._get_graph_from_inputs(op_input_list, graph)
コード例 #3
ファイル: variable_scope.py プロジェクト: 13331151/tensorflow
def variable_op_scope(values, name, default_name, initializer=None,
                      regularizer=None):
  """Returns a context manager for defining an op that creates variables.

  This context manager validates that the given `values` are from the
  same graph, ensures that that graph is the default graph, and pushes a
  name scope and a variable scope.

  If `name` is truthy, it is used as is in the variable scope. Otherwise
  `default_name` is used: a name scope is opened for it, and the (possibly
  uniquified, e.g. by an appended `_N`) name that scope reports is used for the
  variable scope.

  No `reuse` argument is passed to the scopes opened here.

  For example, to define a new Python op called `my_op_with_vars`:

  ```python
  def my_op_with_vars(a, b, name=None):
    with tf.variable_op_scope([a, b], name, "MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.get_variable('c')
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  NOTE(review): this function is a generator; it must be wrapped by a
  context-manager decorator to be usable in a `with` statement as shown above.
  No decorator is visible in this excerpt -- confirm it exists in the file.

  Args:
    values: The list of `Tensor` arguments that are passed to the op function.
    name: The name argument that is passed to the op function, this name is not
      uniquified in the variable scope.
    default_name: The default name to use if the `name` argument is `None`, this
      name will be uniquified.
    initializer: A default initializer to pass to variable scope.
    regularizer: default regularizer for variables within this scope.

  Yields:
    The variable scope opened for the op.

  Raises:
    ValueError: may be raised by the underlying graph/scope helpers, e.g. when
      `values` do not all belong to one graph.
    TypeError: when `default_name` is None, or when the types of some arguments
      are not appropriate.
  """
  if default_name is None:
    raise TypeError("default_name cannot be None")
  g = ops._get_graph_from_inputs(values)  # pylint: disable=protected-access
  with g.as_default():
    if name:
      with variable_scope(name, initializer=initializer,
                          regularizer=regularizer) as vs:
        yield vs
    else:
      # The two branches must be mutually exclusive: a context-manager
      # generator may yield exactly once.
      with ops.name_scope(default_name) as scope:
        # Keep only the trailing components of the name scope that correspond
        # to `default_name` (the scope string ends with "/", hence the -1).
        count = len(default_name.split("/"))
        scoped_name = "/".join(scope.split("/")[-count - 1:-1])
        with _pure_variable_scope(scoped_name, initializer=initializer,
                                  regularizer=regularizer) as vs:
          yield vs
コード例 #4
ファイル: graph_only_ops.py プロジェクト: Wajih-O/tensorflow
def graph_zeros_like(tensor):
  """Graph-only version of tf.zeros_like(), for internal use only.

  Creates a "ZerosLike" op in the graph resolved from `tensor` and returns the
  single output of that op.

  Args:
    tensor: value to mirror; it is converted with `ops.convert_to_tensor`.

  Returns:
    The sole output of the created "ZerosLike" op.
  """
  graph = ops._get_graph_from_inputs([tensor])  # pylint: disable=protected-access
  with graph.as_default():
    with ops.name_scope(None, "zeros_like", [tensor]) as scope_name:
      tensor = ops.convert_to_tensor(tensor, name="tensor")
      base_dtype = tensor.dtype.base_dtype
      type_attr = attr_value_pb2.AttrValue(type=base_dtype.as_datatype_enum)
      zeros_op = graph.create_op(
          "ZerosLike",
          [tensor],
          [base_dtype],
          input_types=[base_dtype],
          attrs={"T": type_attr},
          name=scope_name)
  # Unpacking enforces that the op produced exactly one output.
  (result,) = zeros_op.outputs
  return result
コード例 #5
ファイル: optimizer_v2.py プロジェクト: aritratony/tensorflow
  def _distributed_apply(self, distribution, grads_and_vars, name):
    """`apply_gradients` using a `DistributionStrategy`.

    Args:
      distribution: strategy object; its `extended.batch_reduce_to` and
        `extended.update` methods are used.
      grads_and_vars: (gradient, variable) pairs. It is traversed twice (once
        by `batch_reduce_to`, once to collect the variables).
      name: optional name for the enclosing name scope; `self._name` is used
        when it is falsy.

    Returns:
      In graph mode, or when any update op is symbolic: the op of
      `self._iterations.assign_add(1)`, created under control dependencies on
      all update ops. Otherwise the result of `self._iterations.assign_add(1)`.
    """
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices)
      update_op = self._resource_apply_dense(grad, var)
      if var.constraint is not None:
        # Re-project the variable only after the dense update has run.
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      # No constraint: the dense update itself is the result. (Previously this
      # return was nested after the constrained return and never executed.)
      return update_op

    update_ops = []
    with backend.name_scope(name or self._name):
      for grad, var in grads_and_vars:
        scope_name = ("" if ops.executing_eagerly_outside_functions() else
                      "_" + var.op.name)
        with backend.name_scope("update" + scope_name):
          # `group=False` yields the per-replica results ungrouped, so they are
          # appended individually.
          update_ops.extend(
              distribution.extended.update(
                  var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)
コード例 #6
ファイル: base.py プロジェクト: LUTAN/tensorflow
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    NOTE(review): several source lines are missing from this excerpt (see the
    inline notes), so the method does not parse as shown.

    Args:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**, the kwarg 'scope' is reserved for use by the Layer.

    Returns:
      Output tensor(s).

    Raises:
      ValueError: re-raised with a layer-specific message when the graph
        check of `inputs` against `self.graph` raises `ValueError`.
    """
    # `scope` is removed from kwargs, so it is never forwarded to `self.call`.
    scope = kwargs.pop('scope', None)

    # Define a custom getter to override tf.get_variable when creating layer
    # variables. The current custom getter is nested by the variable scope.
    def variable_getter(getter, name, shape, dtype=None, initializer=None,
                        regularizer=None, trainable=True, **getter_kwargs):
      # Route variable creation through `self._add_variable`, handing it the
      # wrapped getter with the remaining keyword arguments pre-bound.
      return self._add_variable(
          name, shape, initializer=initializer, regularizer=regularizer,
          dtype=dtype, trainable=trainable,
          variable_getter=functools.partial(getter, **getter_kwargs))

    if not self._built and self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        self._scope = next(vs.variable_scope(
            scope if scope is not None else self._base_name).gen)
        # NOTE(review): an `else:` appears to be missing here; as shown, the
        # assignment below overwrites the one above.
        self._scope = next(vs.variable_scope(
            scope, default_name=self._base_name).gen)
      self._name = self._scope.name

    # Build (if necessary) and call the layer, inside a variable
    # scope.
    with vs.variable_scope(self._scope,
                           reuse=True if self._built else self._reuse,
                           custom_getter=variable_getter) as scope:
      # Ensure the Layer, if being reused, is working with inputs from
      # the same graph as where it was created.
      # NOTE(review): the `try:` matching the `except` below is missing here.
        ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
      except ValueError as e:
        raise ValueError("Inputs' and Layer's graphs are not the same: %s" % e)

      with ops.name_scope(scope.original_name_scope):
        if not self.built:
          input_list = [
              ops.convert_to_tensor(x, name='input')
              for x in nest.flatten(inputs)]
          input_shapes = [x.get_shape() for x in input_list]
          if len(input_shapes) == 1:
          # NOTE(review): the body of the `if` above is missing from this
          # excerpt.
          self._built = True
        outputs = self.call(inputs, *args, **kwargs)

        # Apply activity regularization.
        # Note that it should be applied every time the layer creates a new
        # output, since it is output-specific.
        if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
          output_list = _to_list(outputs)
          for output in output_list:
            with ops.name_scope('ActivityRegularizer'):
              activity_regularization = self.activity_regularizer(output)
              # NOTE(review): the line opening the call that the next line
              # continues is missing from this excerpt.
                activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
コード例 #7
  def apply_op(self, op_type_name, g=None, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    NOTE(review): many source lines (`try:`, `else:`, call openers, the tail
    of the function) are missing from this excerpt; inline notes mark the
    gaps. The method does not parse as shown.

    Config proto extensions must be provided via the 'ext' keyword argument.
    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If none of the inputs are Tensors and your session doesn't have a
       # default graph, you will have to specify the graph.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    # NOTE(review): the `try:` matching the `except` below is missing here.
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
      # pylint: enable=protected-access
    except AssertionError as e:
      # NOTE(review): `e.message` is not available on Python 3 exceptions.
      raise RuntimeError(
          "Need to specify g=graph to Op '%s' (could not determine graph due "
          "to: %s)" % (op_type_name, e.message))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    # NOTE(review): nothing in the visible code appends to `inputs` or
    # `input_types`; the lines that fill them appear to be missing.
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
          # NOTE(review): an `else:` appears to be missing before this raise.
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
              # NOTE(review): an `else:` appears to be missing before this
              # loop.
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype

          # NOTE(review): the `try:` matching the `except` below is missing.
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, name=input_arg.name,
                dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
          except (TypeError, ValueError):
            assert dtype is not None, "Should not fail if dtype is None"
            assert input_arg.number_attr, "Should be number_attr case"
            # What types does the conversion function think values have?
            values = ops.convert_n_to_tensor_or_indexed_slices(values)
            observed = ", ".join(v.dtype.base_dtype.name for v in values)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, types_lib.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, types_lib.as_dtype(dtype).name))
              # NOTE(review): an `else:` appears to be missing before this
              # raise.
              raise TypeError("%s that don't all match." % prefix)

          types = [x.dtype for x in values]
          # NOTE(review): the `else:` of `if _IsListParameter(...)` appears to
          # be missing here; what follows handles a single (non-list) input.
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]

          # NOTE(review): the `try:` matching the `except` below is missing.
            values = ops.convert_to_tensor(
                values, name=input_arg.name, dtype=dtype)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, types_lib.as_dtype(input_arg.type).name))
              # NOTE(review): an `else:` appears to be missing before this
              # raise, and the end of its argument tuple is cut off.
              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name,

          types = [values.dtype]
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              # NOTE(review): the end of this raise's argument tuple and a
              # following `else:` appear to be missing.
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
            # NOTE(review): an `else:` appears to be missing here.
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
            # NOTE(review): an `else:` and the line opening the call inside
            # this loop appear to be missing.
            for base_type in base_types:
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              # NOTE(review): the end of this raise's argument tuple is cut
              # off, and an `else:` plus a call opener appear to be missing.
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(types_lib.as_dtype(x).name for x in attr_value),
                   ", ".join(types_lib.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
            for base_type in base_types:
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
          # NOTE(review): an `else:` appears to be missing here.
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          # NOTE(review): the statement that performs the "skip" described
          # above is not visible in this excerpt.
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
          # NOTE(review): an `else:` appears to be missing before this raise.
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        # A `None` value for an attr with a default stores an empty AttrValue.
        if attr_def.HasField("default_value") and value is None:
          attr_protos[key] = attr_value
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              # NOTE(review): the end of this raise's argument tuple is cut
              # off.
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, attr_value.s,
                   '", "'.join(attr_def.allowed_values.list.s)))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, x,
                     '", "'.join(attr_def.allowed_values.list.s)))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        # NOTE(review): in the three list(...) branches below, the line that
        # opens the call is missing from this excerpt.
        elif attr_def.type == "list(type)":
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
              [_MakeTensor(x, key) for x in value])
          # NOTE(review): an `else:` appears to be missing before this raise.
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      # NOTE(review): nothing in the visible code appends to `output_types` or
      # `output_structure`; the lines that fill them appear to be missing, as
      # do the `else:` lines before the two fallback assignments in the loop.
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
            types = [arg.type] * n
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          types = [arg.type]
        if arg.is_ref:
          types = [types_lib.as_dtype(x).as_ref for x in types]

      # Anything left in `keywords` matched neither an input nor an attr.
      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # Add Op to graph
      # NOTE(review): the calls below are cut off (unbalanced parentheses) and
      # the `else:` between the two return paths is missing.
      if output_structure:
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
        outputs = op.outputs
        return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
        return g.create_op(op_type_name, inputs, output_types, name=scope,
                           input_types=input_types, attrs=attr_protos,
コード例 #8
    def apply_op(self, op_type_name, name=None, **keywords):
        # pylint: disable=g-doc-args
        """Add a node invoking a registered Op to a graph.

    Config proto extensions must be provided via the 'ext' keyword argument.
    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
        op_info = self._ops.get(op_type_name, None)
        if op_info is None:
            raise RuntimeError("Unrecognized Op name " + op_type_name)
        op_def = op_info.op_def

        # Determine the graph context.
            # Need to flatten all the arguments into a list.
            # pylint: disable=protected-access
            g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
            # pyline: enable=protected-access
        except AssertionError as e:
            raise RuntimeError(
                "Cannot determine graph for Op '%s' due to: %s" %
                (op_type_name, e.message))

        # Default name if not specified.
        if name is None:
            name = op_type_name

        # Requires that op_def has passed validation (using the C++
        # ValidateOpDef() from ../framework/op_def_util.h).
        attrs = {}
        inputs = []
        input_types = []
        with g.as_default(), ops.name_scope(name) as scope:

            # Perform input type inference
            inferred_from = {}
            for input_arg in op_def.input_arg:
                input_name = input_arg.name
                if input_name in keywords:
                    values = keywords.pop(input_name)
                elif input_name + "_" in keywords:
                    # Handle the case where the name is a keyword or built-in
                    # for Python so we use the name + _ instead.
                    input_name += "_"
                    values = keywords.pop(input_name)
                    raise TypeError("No argument for input " + input_name)

                # Goals:
                # * Convert values to Tensors if it contains constants.
                # * Verify that values is a list if that matches the input_arg's
                #   type.
                # * If the input_arg's type is determined by attrs, either set
                #   those attrs and validate those attr values are legal (if
                #   they have not yet been set) or validate the input matches
                #   the type indicated by the attrs (if they have already been
                #   inferred via an earlier input).
                # * If the input_arg has an explicit type, make sure the input
                #   conforms.

                if _IsListParameter(input_arg):
                    if not _IsListValue(values):
                        raise TypeError(
                            "Expected list for '%s' argument to '%s' Op, not %s."
                            % (input_name, op_type_name, values))
                    # In cases where we expect all elements of the list to have the
                    # same dtype, try to cast non-Tensor elements to that type.
                    dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.number_attr:
                        if input_arg.type_attr in attrs:
                            dtype = attrs[input_arg.type_attr]
                            for t in values:
                                if isinstance(t, ops.Tensor):
                                    dtype = t.dtype

                        if not input_arg.is_ref and dtype:
                            dtype = dtypes.as_dtype(dtype).base_dtype
                        values = ops.convert_n_to_tensor_or_indexed_slices(
                            dtype=dtype if dtype else None,
                    except (TypeError, ValueError):
                        assert dtype is not None, "Should not fail if dtype is None"
                        assert input_arg.number_attr, "Should be number_attr case"
                        # What types does the conversion function think values have?
                        values = ops.convert_n_to_tensor_or_indexed_slices(
                            values, as_ref=input_arg.is_ref)
                        observed = ", ".join(v.dtype.base_dtype.name
                                             for v in values)

                        prefix = (
                            "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s that do not match expected type %s." %
                                (prefix, dtype.name))
                        elif input_arg.type_attr in attrs:
                            raise TypeError(
                                "%s that do not match type %s inferred from "
                                "earlier arguments." % (prefix, dtype.name))
                            raise TypeError("%s that don't all match." %

                    types = [x.dtype for x in values]
                    # In cases where we have an expected type, try to convert non-Tensor
                    # arguments to that type.
                    dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                        values = ops.convert_to_tensor(values,
                    except ValueError:
                        # What type does convert_to_tensor think it has?
                        observed = ops.convert_to_tensor(
                            values, as_ref=input_arg.is_ref).dtype.name
                        prefix = (
                            "Input '%s' of '%s' Op has type %s that does not match"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s expected type of %s." %
                                (prefix, dtypes.as_dtype(input_arg.type).name))
                            raise TypeError(
                                "%s type %s of argument '%s'." %

                    types = [values.dtype]
                base_types = [x.base_dtype for x in types]

                if input_arg.number_attr:
                    # <number-attr> * <type> or <number-attr> * <type-attr>
                    if input_arg.number_attr in attrs:
                        if len(values) != attrs[input_arg.number_attr]:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d must match "
                                "length %d of argument '%s'." %
                                (input_name, op_type_name, len(values),
                        attrs[input_arg.number_attr] = len(values)
                        inferred_from[input_arg.number_attr] = input_name
                        num_attr = _Attr(op_def, input_arg.number_attr)
                        if num_attr.has_minimum and len(
                                values) < num_attr.minimum:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d shorter "
                                "than minimum length %d." %
                                (input_name, op_type_name, len(values),
                    # All tensors must have the same base type.
                    if any([bt != base_types[0] for bt in base_types]):
                        raise TypeError(
                            "All tensors passed to '%s' of '%s' Op "
                            "must have the same type." %
                            (input_name, op_type_name))
                    if input_arg.type != types_pb2.DT_INVALID:
                        # <number-attr> * <type> case
                        if base_types and base_types[0] != input_arg.type:
                            assert False, "Unreachable"
                    elif input_arg.type_attr in attrs:
                        # <number-attr> * <type-attr> case, where <type-attr> already
                        # has an inferred value.
                        if base_types and base_types[0] != attrs[
                            assert False, "Unreachable"
                        # <number-attr> * <type-attr> case, where we are now setting
                        # the <type-attr> based on this input
                        if not base_types:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                        _SatisfiesTypeConstraint(base_types[0], type_attr)
                elif input_arg.type_attr:
                    # <type-attr>
                    attr_value = base_types[0]
                    if input_arg.type_attr in attrs:
                        if attrs[input_arg.type_attr] != attr_value:
                            assert False, "Unreachable"
                        for base_type in base_types:
                                base_type, _Attr(op_def, input_arg.type_attr))
                        attrs[input_arg.type_attr] = attr_value
                        inferred_from[input_arg.type_attr] = input_name
                elif input_arg.type_list_attr:
                    # <type-list-attr>
                    attr_value = base_types
                    if input_arg.type_list_attr in attrs:
                        if attrs[input_arg.type_list_attr] != attr_value:
                            raise TypeError(
                                "Input '%s' of '%s' Op has type list of %s that does not "
                                "match type list %s of argument '%s'." %
                                (input_name, op_type_name, ", ".join(
                                    for x in attr_value),
                                 ", ".join(
                                     for x in attrs[input_arg.type_list_attr]),
                        for base_type in base_types:
                                _Attr(op_def, input_arg.type_list_attr))
                        attrs[input_arg.type_list_attr] = attr_value
                        inferred_from[input_arg.type_list_attr] = input_name
                    # single Tensor with specified type
                    if base_types[0] != input_arg.type:
                        assert False, "Unreachable"

                if input_arg.is_ref:
                    if not all(x.is_ref_dtype for x in types):
                        raise TypeError(
                            "Input '%s' of '%s' Op requires l-value input" %
                            (input_name, op_type_name))

            # Process remaining attrs
            for attr in op_def.attr:
                # Skip attrs that have already had their values inferred
                if attr.name in attrs:
                    if attr.name in keywords:
                        raise TypeError(
                            "Should not specify value for inferred attr '%s'."
                            % attr.name)
                if attr.name in keywords:
                    attrs[attr.name] = keywords.pop(attr.name)
                elif attr.name + "_" in keywords:
                    # Attrs whose names match Python keywords have an extra '_'
                    # appended, so we must check for that as well.
                    attrs[attr.name] = keywords.pop(attr.name + "_")
                    raise TypeError("No argument for attr " + attr.name)

            # Convert attr values to AttrValue protos.
            attr_protos = {}
            for attr_def in op_def.attr:
                key = attr_def.name
                value = attrs[key]
                attr_value = attr_value_pb2.AttrValue()
                if attr_def.HasField("default_value") and value is None:
                    attr_protos[key] = attr_value
                if attr_def.type.startswith("list("):
                    if not _IsListValue(value):
                        raise TypeError("Expected list for attr " + key)
                    if attr_def.has_minimum:
                        if len(value) < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed list of length %d "
                                "less than minimum %d." %
                                (key, op_type_name, len(value),
                if attr_def.type == "string":
                    attr_value.s = _MakeStr(value, key)
                    if attr_def.HasField("allowed_values"):
                        if attr_value.s not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name,
                                   compat.as_text(attr_value.s), '", "'.join(
                elif attr_def.type == "list(string)":
                    attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                    if attr_def.HasField("allowed_values"):
                        for x in attr_value.list.s:
                            if x not in attr_def.allowed_values.list.s:
                                raise ValueError(
                                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                    (key, op_type_name, compat.as_text(x),
                                     '", "'.join(
                elif attr_def.type == "int":
                    attr_value.i = _MakeInt(value, key)
                    if attr_def.has_minimum:
                        if attr_value.i < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed %d less than minimum %d."
                                % (key, op_type_name, attr_value.i,
                elif attr_def.type == "list(int)":
                    attr_value.list.i.extend([_MakeInt(x, key) for x in value])
                elif attr_def.type == "float":
                    attr_value.f = _MakeFloat(value, key)
                elif attr_def.type == "list(float)":
                        [_MakeFloat(x, key) for x in value])
                elif attr_def.type == "bool":
                    attr_value.b = _MakeBool(value, key)
                elif attr_def.type == "list(bool)":
                        [_MakeBool(x, key) for x in value])
                elif attr_def.type == "type":
                    attr_value.type = _MakeType(value, attr_def)
                elif attr_def.type == "list(type)":
                        [_MakeType(x, attr_def) for x in value])
                elif attr_def.type == "shape":
                    attr_value.shape.CopyFrom(_MakeShape(value, key))
                elif attr_def.type == "list(shape)":
                        [_MakeShape(x, key) for x in value])
                elif attr_def.type == "tensor":
                    attr_value.tensor.CopyFrom(_MakeTensor(value, key))
                elif attr_def.type == "list(tensor)":
                        [_MakeTensor(x, key) for x in value])
                elif attr_def.type == "func":
                    if not isinstance(value, compat.bytes_or_text_types):
                        raise TypeError("Expects a string for the func name")
                    attr_value.func.name = value
                    raise TypeError("Unrecognized Attr type " + attr_def.type)

                attr_protos[key] = attr_value
            del attrs  # attrs is no longer authoritative, use attr_protos instead

            # Determine output types (possibly using attrs)
            output_types = []
            output_structure = []
            for arg in op_def.output_arg:
                types = []
                if arg.number_attr:
                    n = _AttrValue(attr_protos, arg.number_attr).i
                    if arg.type_attr:
                        types = [_AttrValue(attr_protos, arg.type_attr).type
                                 ] * n
                        types = [arg.type] * n
                elif arg.type_attr:
                    t = _AttrValue(attr_protos, arg.type_attr)
                    types = [t.type]
                elif arg.type_list_attr:
                    t = _AttrValue(attr_protos, arg.type_list_attr)
                    types = t.list.type
                    types = [arg.type]
                if arg.is_ref:
                    types = [dtypes.as_dtype(x).as_ref for x in types]

            if keywords:
                raise TypeError(
                    "apply_op() got unexpected keyword arguments: " +
                    ", ".join(sorted(keywords.keys())))

            # Add Op to graph
            if output_structure:
                op = g.create_op(op_type_name,
                outputs = op.outputs
                return _Restructure(
                return g.create_op(op_type_name,
コード例 #9
    def _apply_op_helper(self, op_type_name, name=None, **keywords):
        """Implementation of apply_op that returns output_structure, op."""
        #    print("KST_TEST_ophelper")
        #    print(op_type_name)
        op_info = self._ops.get(op_type_name, None)
        if op_info is None:
            raise RuntimeError("Unrecognized Op name " + op_type_name)
        op_def = op_info.op_def

        # Determine the graph context.
            # Need to flatten all the arguments into a list.
            # pylint: disable=protected-access
            g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
            # pylint: enable=protected-access
        except AssertionError as e:
            raise RuntimeError(
                "Cannot determine graph for Op '%s' due to: %s" %
                (op_type_name, e.message))

        # Default name if not specified.
        if name is None:
            name = op_type_name

        # Check for deprecation
        deprecation_version = op_def.deprecation.version
        if deprecation_version:
            producer = g.graph_def_versions.producer
            if producer >= deprecation_version:
                raise NotImplementedError(
                    ("Op %s is not available in GraphDef version %d. "
                     "It has been removed in version %d. %s.") %
                    (op_type_name, producer, deprecation_version,

        # Fill in the list of default types for all "type" attrs.  This
        # will be used to choose a preferred dtype to convert to in the
        # absence of input type information.
        # TODO(b/31302892): Currently the defaults don't work in the right
        # way if you have two inputs, one of whose type resolution depends
        # on the other.  Handling this will require restructuring this code
        # significantly.
        default_type_attr_map = {}
        for attr_def in op_def.attr:
            if attr_def.type != "type":
            key = attr_def.name
            if attr_def.HasField("default_value"):
                default_type_attr_map[key] = dtypes.as_dtype(

        # Requires that op_def has passed validation (using the C++
        # ValidateOpDef() from ../framework/op_def_util.h).
        attrs = {}
        inputs = []
        input_types = []
        with g.as_default(), ops.name_scope(name) as scope:

            # Perform input type inference
            inferred_from = {}
            for input_arg in op_def.input_arg:
                input_name = input_arg.name
                if input_name in keywords:
                    values = keywords.pop(input_name)
                elif input_name + "_" in keywords:
                    # Handle the case where the name is a keyword or built-in
                    # for Python so we use the name + _ instead.
                    input_name += "_"
                    values = keywords.pop(input_name)
                    raise TypeError("No argument for input " + input_name)

                # Goals:
                # * Convert values to Tensors if it contains constants.
                # * Verify that values is a list if that matches the input_arg's
                #   type.
                # * If the input_arg's type is determined by attrs, either set
                #   those attrs and validate those attr values are legal (if
                #   they have not yet been set) or validate the input matches
                #   the type indicated by the attrs (if they have already been
                #   inferred via an earlier input).
                # * If the input_arg has an explicit type, make sure the input
                #   conforms.

                if _IsListParameter(input_arg):
                    if not _IsListValue(values):
                        raise TypeError(
                            "Expected list for '%s' argument to '%s' Op, not %s."
                            % (input_name, op_type_name, values))
                    # In cases where we expect all elements of the list to have the
                    # same dtype, try to cast non-Tensor elements to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.number_attr:
                        if input_arg.type_attr in attrs:
                            dtype = attrs[input_arg.type_attr]
                            for t in values:
                                if isinstance(t, ops.Tensor):
                                    dtype = t.dtype

                        # dtype still not found, prefer using the default dtype
                        # from the attr.
                        if dtype is None and input_arg.type_attr in default_type_attr_map:
                            default_dtype = default_type_attr_map[

                        if not input_arg.is_ref and dtype:
                            dtype = dtypes.as_dtype(dtype).base_dtype
                        values = ops.internal_convert_n_to_tensor(
                            dtype=dtype if dtype else None,
                        if input_arg.number_attr and len(
                                set(v.dtype.base_dtype for v in values)) > 1:
                            raise TypeError()  # All types should match.
                    except (TypeError, ValueError):
                        # What types does the conversion function think values have?
                        observed_types = []
                        for value in values:
                                converted_value = ops.internal_convert_to_tensor(
                                    value, as_ref=input_arg.is_ref)
                            except (TypeError, ValueError):
                                    "<NOT CONVERTIBLE TO TENSOR>")
                        observed = ", ".join(observed_types)

                        prefix = (
                            "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                            % (input_name, op_type_name, observed))
                        if input_arg.number_attr:
                            if input_arg.type != types_pb2.DT_INVALID:
                                raise TypeError(
                                    "%s that do not match expected type %s." %
                                    (prefix, dtype.name))
                            elif input_arg.type_attr in attrs:
                                raise TypeError(
                                    "%s that do not match type %s inferred from "
                                    "earlier arguments." %
                                    (prefix, dtype.name))
                                raise TypeError("%s that don't all match." %
                            raise TypeError(
                                "%s that are invalid. Tensors: %s" %
                                (prefix, values))

                    types = [x.dtype for x in values]
                    # In cases where we have an expected type, try to convert non-Tensor
                    # arguments to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    elif input_arg.type_attr in default_type_attr_map:
                        # The dtype could not be inferred solely from the inputs,
                        # so we prefer the attr's default, so code that adds a new attr
                        # with a default is backwards compatible.
                        default_dtype = default_type_attr_map[

                        values = ops.internal_convert_to_tensor(
                    except TypeError as err:
                        if dtype is None:
                            raise err
                            raise TypeError(
                                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                                "type '%s' instead. Error: %s" %
                                 input_arg.name, op_type_name, repr(values),
                                 type(values).__name__, err))
                    except ValueError:
                        # What type does convert_to_tensor think it has?
                            observed = ops.internal_convert_to_tensor(
                                values, as_ref=input_arg.is_ref).dtype.name
                        except ValueError as err:
                            raise ValueError(
                                "Tried to convert '%s' to a tensor and failed. Error: %s"
                                % (input_name, err))
                        prefix = (
                            "Input '%s' of '%s' Op has type %s that does not match"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s expected type of %s." %
                                (prefix, dtypes.as_dtype(input_arg.type).name))
                            # Update the maps with the default, if needed.
                            k = input_arg.type_attr
                            if k in default_type_attr_map:
                                if k not in attrs:
                                    attrs[k] = default_type_attr_map[k]
                                    if k not in inferred_from:
                                        inferred_from[k] = "Default in OpDef"

                            raise TypeError(
                                "%s type %s of argument '%s'." %

                    types = [values.dtype]
                base_types = [x.base_dtype for x in types]

                if input_arg.number_attr:
                    # <number-attr> * <type> or <number-attr> * <type-attr>
                    if input_arg.number_attr in attrs:
                        if len(values) != attrs[input_arg.number_attr]:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d must match "
                                "length %d of argument '%s'." %
                                (input_name, op_type_name, len(values),
                        attrs[input_arg.number_attr] = len(values)
                        inferred_from[input_arg.number_attr] = input_name
                        num_attr = _Attr(op_def, input_arg.number_attr)
                        if num_attr.has_minimum and len(
                                values) < num_attr.minimum:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d shorter "
                                "than minimum length %d." %
                                (input_name, op_type_name, len(values),
                    # All tensors must have the same base type.
                    if any(bt != base_types[0] for bt in base_types):
                        raise TypeError(
                            "All tensors passed to '%s' of '%s' Op "
                            "must have the same type." %
                            (input_name, op_type_name))
                    if input_arg.type != types_pb2.DT_INVALID:
                        # <number-attr> * <type> case
                        if base_types and base_types[0] != input_arg.type:
                            assert False, "Unreachable"
                    elif input_arg.type_attr in attrs:
                        # <number-attr> * <type-attr> case, where <type-attr> already
                        # has an inferred value.
                        if base_types and base_types[0] != attrs[
                            assert False, "Unreachable"
                        # <number-attr> * <type-attr> case, where we are now setting
                        # the <type-attr> based on this input
                        if not base_types:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                elif input_arg.type_attr:
                    # <type-attr>
                    attr_value = base_types[0]
                    if input_arg.type_attr in attrs:
                        if attrs[input_arg.type_attr] != attr_value:
                            assert False, "Unreachable"
                        for base_type in base_types:
                        attrs[input_arg.type_attr] = attr_value
                        inferred_from[input_arg.type_attr] = input_name
                elif input_arg.type_list_attr:
                    # <type-list-attr>
                    attr_value = base_types
                    if input_arg.type_list_attr in attrs:
                        if attrs[input_arg.type_list_attr] != attr_value:
                            raise TypeError(
                                "Input '%s' of '%s' Op has type list of %s that does not "
                                "match type list %s of argument '%s'." %
                                (input_name, op_type_name, ", ".join(
                                    for x in attr_value),
                                 ", ".join(
                                     for x in attrs[input_arg.type_list_attr]),
                        for base_type in base_types:
                                _Attr(op_def, input_arg.type_list_attr),
                        attrs[input_arg.type_list_attr] = attr_value
                        inferred_from[input_arg.type_list_attr] = input_name
                    # single Tensor with specified type
                    if base_types[0] != input_arg.type:
                        assert False, "Unreachable"

                if input_arg.is_ref:
                    if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
                        raise TypeError((
                            "'%s' Op requires that input '%s' be a mutable tensor "
                            "(e.g.: a tf.Variable)") %
                                        (op_type_name, input_name))

            # Process remaining attrs
            for attr in op_def.attr:
                # Skip attrs that have already had their values inferred
                if attr.name in attrs:
                    if attr.name in keywords:
                        raise TypeError(
                            "Should not specify value for inferred attr '%s'."
                            % attr.name)
                if attr.name in keywords:
                    attrs[attr.name] = keywords.pop(attr.name)
                elif attr.name + "_" in keywords:
                    # Attrs whose names match Python keywords have an extra '_'
                    # appended, so we must check for that as well.
                    attrs[attr.name] = keywords.pop(attr.name + "_")
                    raise TypeError("No argument for attr " + attr.name)

            # Convert attr values to AttrValue protos.
            attr_protos = {}
            for attr_def in op_def.attr:
                key = attr_def.name
                value = attrs[key]
                attr_value = attr_value_pb2.AttrValue()
                if attr_def.HasField("default_value") and value is None:
                    attr_protos[key] = attr_value
                if attr_def.type.startswith("list("):
                    if not _IsListValue(value):
                        raise TypeError("Expected list for attr " + key)
                    if attr_def.has_minimum:
                        if len(value) < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed list of length %d "
                                "less than minimum %d." %
                                (key, op_type_name, len(value),
                if attr_def.type == "string":
                    attr_value.s = _MakeStr(value, key)
                    if attr_def.HasField("allowed_values"):
                        if attr_value.s not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name,
                                   compat.as_text(attr_value.s), '", "'.join(
                elif attr_def.type == "list(string)":
                    attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                    if attr_def.HasField("allowed_values"):
                        for x in attr_value.list.s:
                            if x not in attr_def.allowed_values.list.s:
                                raise ValueError(
                                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                    (key, op_type_name, compat.as_text(x),
                                     '", "'.join(
                elif attr_def.type == "int":
                    attr_value.i = _MakeInt(value, key)
                    if attr_def.has_minimum:
                        if attr_value.i < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed %d less than minimum %d."
                                % (key, op_type_name, attr_value.i,
                elif attr_def.type == "list(int)":
                    attr_value.list.i.extend([_MakeInt(x, key) for x in value])
                elif attr_def.type == "float":
                    attr_value.f = _MakeFloat(value, key)
                elif attr_def.type == "list(float)":
                        [_MakeFloat(x, key) for x in value])
                elif attr_def.type == "bool":
                    attr_value.b = _MakeBool(value, key)
                elif attr_def.type == "list(bool)":
                        [_MakeBool(x, key) for x in value])
                elif attr_def.type == "type":
                    attr_value.type = _MakeType(value, attr_def)
                elif attr_def.type == "list(type)":
                        [_MakeType(x, attr_def) for x in value])
                elif attr_def.type == "shape":
                    attr_value.shape.CopyFrom(_MakeShape(value, key))
                elif attr_def.type == "list(shape)":
                        [_MakeShape(x, key) for x in value])
                elif attr_def.type == "tensor":
                    attr_value.tensor.CopyFrom(_MakeTensor(value, key))
                elif attr_def.type == "list(tensor)":
                        [_MakeTensor(x, key) for x in value])
                elif attr_def.type == "func":
                    attr_value.func.CopyFrom(_MakeFunc(value, key))
                elif attr_def.type == "list(func)":
                        [_MakeFunc(x, key) for x in value])
                    raise TypeError("Unrecognized Attr type " + attr_def.type)

                attr_protos[key] = attr_value
            del attrs  # attrs is no longer authoritative, use attr_protos instead

            # Determine output types (possibly using attrs)
            output_structure = []
            for arg in op_def.output_arg:
                if arg.number_attr:
                    n = _AttrValue(attr_protos, arg.number_attr).i
                elif arg.type_attr:
                    t = _AttrValue(attr_protos, arg.type_attr)
                elif arg.type_list_attr:
                    t = _AttrValue(attr_protos, arg.type_list_attr)

            if keywords:
                raise TypeError(
                    "apply_op() got unexpected keyword arguments: " +
                    ", ".join(sorted(keywords.keys())))

            # NOTE(mrry): We add an explicit colocation constraint between
            # the newly created op and any of its reference-typed inputs.
            must_colocate_inputs = [
                val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref
            with _MaybeColocateWith(must_colocate_inputs):
                # Add Op to graph
                op = g.create_op(op_type_name,
            return output_structure, op_def.is_stateful, op
コード例 #10
    def _distributed_apply(distribution, grads_and_vars, name, apply_state):
        """`apply_gradients` using a `DistributionStrategy`."""
        # Overview: the gradients are reduced with `ReduceOp.SUM` through
        # `distribution.extended`, re-paired with their variables, handed to
        # `apply_grad_to_update_var`, and finally `self._iterations` is
        # incremented by one.
        #
        # NOTE(review): `self` is not a parameter of this function, so it is
        # presumably captured from an enclosing scope that is not part of
        # this excerpt -- confirm against the full source.
        # NOTE(review): several statements below are visibly cut short
        # (unclosed calls, branches whose `else:` is absent). This excerpt
        # appears to have lost lines, so do not treat it as runnable as-is.
        reduced_grads = distribution.extended.batch_reduce_to(
            ds_reduce_util.ReduceOp.SUM, grads_and_vars)
        var_list = [v for _, v in grads_and_vars]
        # Pair every reduced gradient with the variable it belongs to.
        grads_and_vars = zip(reduced_grads, var_list)

        def apply_grad_to_update_var(var, grad):
            """Apply gradient to variable."""
            if isinstance(var, ops.Tensor):
                raise NotImplementedError("Trying to update a Tensor ", var)

            # Filled with `apply_state` only when the optimizer's sparse or
            # dense apply method lists it among its arguments.
            apply_kwargs = {}
            # Variables that are not `de.TrainableWrapper` instances take the
            # plain sparse/dense update path.
            if not isinstance(var, de.TrainableWrapper):
                if isinstance(grad, ops.IndexedSlices):
                    if var.constraint is not None:
                        raise RuntimeError(
                            "Cannot use a constraint function on a sparse variable."
                    if "apply_state" in self._sparse_apply_args:
                        apply_kwargs["apply_state"] = apply_state
                    return self._resource_apply_sparse_duplicate_indices(
                        grad.values, var, grad.indices, **apply_kwargs)

                if "apply_state" in self._dense_apply_args:
                    apply_kwargs["apply_state"] = apply_state
                update_op = self._resource_apply_dense(grad, var,
                if var.constraint is not None:
                    with ops.control_dependencies([update_op]):
                        return var.assign(var.constraint(var))
                    return update_op
                # NOTE(review): what follows looks like the
                # `de.TrainableWrapper` branch; the `else:` that should
                # introduce it is not visible in this excerpt.
                with ops.colocate_with(None, ignore_existing=True):
                    _slots = [
                        self.get_slot(var, _s) for _s in self.get_slot_names()
                    # Add the optimizer slots to restricting list.
                    if var.params.restrict_policy is not None:

                    # Read the variable and each of its slots under a control
                    # dependency on `grad`.
                    with ops.control_dependencies([grad]):
                        _before = [var.read_value()
                                   ] + [_s.read_value() for _s in _slots]
                    if isinstance(grad, ops.IndexedSlices):
                        if var.constraint is not None:
                            raise RuntimeError(
                                "Cannot use a constraint function on a sparse variable."
                        if "apply_state" in self._sparse_apply_args:
                            apply_kwargs["apply_state"] = apply_state
                        # The sparse apply is created under a control
                        # dependency on the `_before` reads.
                        with ops.control_dependencies(_before):
                            _apply_op = self._resource_apply_sparse_duplicate_indices(
                                grad.values, var, grad.indices, **apply_kwargs)
                        # Group `update_op()` of the variable and of every
                        # slot behind the apply op.
                        with ops.control_dependencies([_apply_op]):
                            _after = control_flow_ops.group(
                                [var.update_op()] +
                                [_s.update_op() for _s in _slots])
                            return _after

                    if "apply_state" in self._dense_apply_args:
                        apply_kwargs["apply_state"] = apply_state
                    # Dense counterpart of the sequence above: reads, then
                    # apply, then the grouped `update_op()` calls.
                    with ops.control_dependencies(_before):
                        update_op = self._resource_apply_dense(
                            grad, var, **apply_kwargs)
                    if var.constraint is not None:
                        with ops.control_dependencies([update_op]):
                            return var.assign(var.constraint(var))
                        with ops.control_dependencies([update_op]):
                            _after = control_flow_ops.group(
                                [var.update_op()] +
                                [_s.update_op() for _s in _slots])
                        return _after

        update_ops = []
        with backend.name_scope(name or self._name):
            for grad, var in grads_and_vars:
                # The scope name carries the variable's op name except when
                # executing eagerly outside functions.
                scope_name = ("update"
                              if ops.executing_eagerly_outside_functions() else
                              "update_" + var.op.name)
                # Colocate the update with variables to avoid unnecessary communication
                # delays. See b/136304694.
                # NOTE(review): the `with` statement below is truncated in
                # this excerpt; the code that fills `update_ops` is not
                # visible here.
                with backend.name_scope(
                        scope_name), distribution.extended.colocate_vars_with(
                                                     args=(grad, ),

            any_symbolic = any(
                isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i)
                for i in update_ops)
            if not context.executing_eagerly() or any_symbolic:
                # If the current context is graph mode or any of the update ops are
                # symbolic then the step update should be carried out under a graph
                # context. (eager updates execute immediately)
                with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
                    with ops.control_dependencies(update_ops):
                        return self._iterations.assign_add(1).op

            # Otherwise the result of `assign_add` itself is returned
            # rather than its `.op`.
            return self._iterations.assign_add(1)
コード例 #11
def _current_graph(op_input_list):
    """Look up the graph that the elements of `op_input_list` belong to.

    Delegates entirely to `ops._get_graph_from_inputs`, which is expected to
    fall back to the current graph when the inputs do not determine one.
    """
    # pylint: disable=protected-access
    graph = ops._get_graph_from_inputs(op_input_list)
    return graph
コード例 #12
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op."""
  # Overview of the visible steps:
  #   1. Look up the OpDef for `op_type_name` in `op_def_registry`.
  #   2. Determine the graph from the flattened keyword argument values.
  #   3. Reject ops whose deprecation version has been reached by the
  #      graph's producer version.
  #   4. For every input arg in the OpDef, pop the matching keyword,
  #      convert it to tensor(s) and infer/validate the related attrs.
  #   5. Fill the remaining attrs from `keywords` or from OpDef defaults.
  #   6. Convert the attrs to AttrValue protos and create the op in the
  #      graph.
  #   7. Optionally run op callbacks, which may replace the outputs.
  #
  # Parameters:
  #   op_type_name: name of the op to look up in `op_def_registry`.
  #   name: name for the name scope; defaults to `op_type_name` when None.
  #   **keywords: inputs and attrs by name. Entries are popped as they are
  #     consumed; anything left over raises TypeError.
  #
  # NOTE(review): the docstring mentions two return values, but the
  # `return` statement at the end yields four:
  # `output_structure, op_def.is_stateful, op, outputs`.
  # NOTE(review): several statements in this excerpt are visibly incomplete
  # (`except` clauses without a `try:`, `if` statements without a body,
  # unclosed calls). Lines appear to have been lost; the notes below flag
  # some of them but are not exhaustive.
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  # NOTE(review): the `try:` that the `except AssertionError` below belongs
  # to is not visible in this excerpt.
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  # NOTE(review): Python 3 exceptions have no `message` attribute, so
  # `e.message` would itself raise AttributeError here -- confirm whether
  # `str(e)` / `{e}` was intended.
  except AssertionError as e:
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e.message}")

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          f"Op {op_type_name} is not available in GraphDef version {producer}. "
          f"It has been removed in version {deprecation_version}. "

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    # NOTE(review): the body of this `if` (presumably a `continue` that
    # skips non-"type" attrs) is not visible in this excerpt.
    if attr_def.type != "type":
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    # `inferred_from` records, per attr name, which input (or
    # "Default in OpDef") fixed its value; it is only used to build error
    # messages.
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
        raise TypeError(f"No argument for input {input_name} found in {op_def}")

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              f"Expected list for '{input_name}' argument to '{op_type_name}' "
              f"Op, not {values}.")
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              dtype=dtype if dtype else None,
          all_types = set(v.dtype.base_dtype for v in values)
          if input_arg.number_attr and len(all_types) > 1:
            # All types should match.
            raise TypeError(f"Not all types matched for {input_arg.name} for "
                            f"{op_type_name}. Got {all_types}")
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError(f"{prefix} that do not match expected type "
            elif input_arg.type_attr in attrs:
              raise TypeError(f"{prefix} that do not match type {dtype.name} "
                              "inferred from earlier arguments.")
              raise TypeError(f"{prefix} that don't all match.")
            raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

        types = [x.dtype for x in values]
        # NOTE(review): from here on this looks like the non-list (single
        # tensor) branch; the `else:` separating it from the list branch is
        # not visible in this excerpt.
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

          # First see if we can get a valid dtype with the default conversion
          # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
          # not list allowed dtypes, in which case we should skip this.
          if dtype is None and allowed_list:
            inferred = None
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError as err:
              # When converting a python object such as a list of Dimensions, we
              # need a dtype to be specified, thus tensor conversion may throw
              # an exception which we will ignore and try again below.

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
              values = ops.convert_to_tensor(
            values = ops.convert_to_tensor(
        except TypeError as err:
          if dtype is None:
            raise err
            raise TypeError(
                f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
                f"'{input_arg.name}' of op '{op_type_name}', got "
                f"{repr(values)} of type '{type(values).__name__}' instead. "
                f"Error: {err}")
        except ValueError:
          # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                f"Tried to convert '{input_name}' to a tensor and failed. "
                f"Error: {err}")
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} expected type of "
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                f"{prefix} type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")

        types = [values.dtype]
      base_types = [x.base_dtype for x in types]

      # The input arg's dtype is described in one of four ways, handled in
      # turn below: number_attr, type_attr, type_list_attr, or an explicit
      # type.
      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} must match length "
                f"{attrs[input_arg.number_attr]} of argument "
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} shorter than minimum length "
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
              f"must have the same type. Got {base_types} instead.")
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  f"list passed to input '{input_name}' of '{op_type_name}' "
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type "
                f"{dtypes.as_dtype(attr_value).name} that does not match type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")
          for base_type in base_types:
                                     _Attr(op_def, input_arg.type_attr),
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            actual_types = ", ".join(
                dtypes.as_dtype(x).name for x in attr_value)
            expected_types = ", ".join(
                for x in attrs[input_arg.type_list_attr])
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type list of "
                f"{actual_types} that does not match type list {expected_types}"
                f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
          for base_type in base_types:
                                     _Attr(op_def, input_arg.type_list_attr),
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      # Reference-typed inputs require every dtype in `types` to be a ref
      # dtype.
      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              f"'{op_type_name}' Op requires that input '{input_name}' be a "
              "mutable tensor (e.g.: a tf.Variable)")

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              f"Should not specify value for inferred attr '{attr.name}' for "
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
        raise TypeError(f"No argument found for attr {attr.name} for "

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      # An attr that has a default in the OpDef and was passed as None is
      # stored as an empty AttrValue.
      # NOTE(review): as shown, execution would fall through and overwrite
      # that entry below; a `continue` has presumably been lost -- confirm.
      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_protos[key] = attr_value

      attr_value = value_to_attr_value(value, attr_def.type, key)
      # Constraint checks that depend on the attr's declared type.
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
        elif attr_def.type == "list(string)":
          for value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(value, attr_def, key,
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for value in attr_value.list.type:
          _SatisfiesTypeConstraint(value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    # NOTE(review): nothing visible in this excerpt appends to
    # `output_structure`; the appending lines appear to be missing.
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr, op_type_name)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)

    # Every input and attr has been popped by now, so anything left in
    # `keywords` was not recognised.
    if keywords:
      all_keywords = ", ".join(sorted(keywords.keys()))
      raise TypeError(f"{op_type_name} got unexpected keyword arguments: "

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors can the `op` per se can be decoupled so that the
    # `op_callbacks` can function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      # A non-None callback result replaces the op's own outputs.
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
コード例 #13
ファイル: model.py プロジェクト: CermakM/intect
    def __init__(self,
                 name: str = None,
                 params: dict = None):
        """Initialize the model."""
        # Overview: chooses a model name, resolves the graph from `inputs`
        # and `labels`, stores the input and label tensors, sets
        # `batch_size` / `learning_rate` as float32 constants, optionally
        # configures learning-rate decay, and stores the optimizer looked up
        # on `tf.train`.
        #
        # NOTE(review): `inputs` and `labels` are used below but do not
        # appear in the visible signature; parameters seem to have been lost
        # from this excerpt. Other statements are also visibly incomplete
        # (an `if` with no body, a missing `else:` / `try:`, an unclosed
        # call).

        # NOTE(review): this `if` contains only comments, so it has no
        # statement as its body in this excerpt.
        if name in self.__model_names:
            # TODO: allow? reuse?
            # raise ValueError("Model name `%s` already exists." % name)

        if name is None:
            # Generate some random name
            name = names.get_first_name(gender='female')
            # Keep drawing until the name is not already in `__model_names`.
            while name in self.__model_names:
                name = names.get_first_name(gender='female')

        self._name = name

        # noinspection PyProtectedMember
        self._graph = ops._get_graph_from_inputs([inputs, labels])  # pylint: disable=protected-access
        # The context manager is only created here; it is not entered in
        # the visible code.
        self._graph_context_manager = self._graph.as_default()

        with tf.variable_scope('input_layer'):
            self._x = inputs['x']

        with tf.variable_scope('labels'):
            self._labels = labels

        # The layer list starts with the input tensor as its only entry.
        self._layers = [self._x]
        self._layer_name_dct = dict()
        "Dictionary of ([str]layer_name, [int]count)"

        if params is None:
            params = dict()

        # Configurable and directly accessible properties
        for param, default in [('batch_size', None), ('learning_rate', None)]:
            # Make sure that flags are privileged from yaml arguments
            if not tf.app.flags.FLAGS.is_parsed():
                value = default
                value = tf.app.flags.FLAGS.get_flag_value(param, default)
            # Fall back to `params` when no value was obtained above; a
            # missing key is reported as a ValueError.
            if value is None:
                    value = params[param]
                except KeyError:
                    raise ValueError("attribute `{}` cannot be of type {}.".format(param, type(None)))

            # Each property is stored on the instance as a float32 constant.
            self.__setattr__(param, tf.constant(value, tf.float32))

        _decay_dct = params.get('learning_rate_decay', None)
        if isinstance(_decay_dct, list):  # yaml will pass list here, need to unpack
            # Single-target unpacking: raises unless the list has exactly
            # one element.
            _decay_dct, = _decay_dct

        # Decay is enabled whenever a `learning_rate_decay` entry exists.
        self.learning_rate_decay = _decay_dct is not None

        if self.learning_rate_decay:
            self._decay_steps = tf.constant(_decay_dct.get('decay_steps'), tf.float32)
            self._decay_rate = tf.constant(_decay_dct.get('decay_rate'), tf.float32)

            self.learning_rate = tf.train.exponential_decay(learning_rate=self.learning_rate,

        tf.summary.scalar(tensor=self.learning_rate, name='learning_rate')

        # The attribute found on `tf.train` is stored as-is; it is not
        # called here. 'AdamOptimizer' is used when `params` has no
        # 'optimizer' entry.
        self.optimizer = getattr(tf.train, params.get('optimizer', 'AdamOptimizer'))
コード例 #14
ファイル: base.py プロジェクト: woodhaha/tensorflow-1
    # Overview: pops the reserved `scope` kwarg, lazily sets the layer's
    # variable scope and name on first use, then runs `self.call` inside
    # that variable scope and its original name scope. Afterwards
    # `self.updates` is added to the `UPDATE_OPS` collection and the
    # outputs of `self.call` are returned.
    #
    # NOTE(review): the docstring opened below is never closed in this
    # excerpt and its section headings are absent, so as shown everything
    # after it is swallowed by the string. Several later statements are
    # also truncated (`variable_getter`'s signature, the `try:` for the
    # graph check, the build step). Lines appear to have been lost; do not
    # treat this excerpt as runnable as-is.
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.
      Output tensor(s).
        scope = kwargs.pop('scope', None)

        # Define a custom getter to override tf.get_variable when creating layer
        # variables. The current custom getter is nested by the variable scope.
        def variable_getter(getter,
            return self._add_variable(name,
                                          getter, **getter_kwargs))

        if not self._built and self._scope is None:
            # If constructed with _scope=None, lazy setting of scope.
            if self._reuse:
                self._scope = next(
                        scope if scope is not None else self._base_name).gen)
                self._scope = next(
                    vs.variable_scope(scope, default_name=self._base_name).gen)
            self._name = self._scope.name

        # Build (if necessary) and call the layer, inside a variable
        # scope.
        with vs.variable_scope(self._scope,
                               reuse=True if self._built else self._reuse,
                               custom_getter=variable_getter) as scope:
            # Ensure the Layer, if being reused, is working with inputs from
            # the same graph as where it was created.
                                           graph=self.graph)  # pylint: disable=protected-access
            except ValueError as e:
                raise ValueError(
                    "Inputs' and Layer's graphs are not the same: %s" % e)

            with ops.name_scope(scope.original_name_scope):
                if not self.built:
                    input_list = [
                        ops.convert_to_tensor(x, name='input')
                        for x in nest.flatten(inputs)
                    input_shapes = [x.get_shape() for x in input_list]
                    if len(input_shapes) == 1:
                    self._built = True
                # `scope` is forwarded only when `self.call` declares an
                # argument with that name.
                if 'scope' in tf_inspect.getargspec(self.call).args:
                    kwargs['scope'] = scope
                outputs = self.call(inputs, *args, **kwargs)

                # Apply activity regularization.
                # Note that it should be applied every time the layer creates a new
                # output, since it is output-specific.
                if hasattr(
                        'activity_regularizer') and self.activity_regularizer:
                    output_list = _to_list(outputs)
                    for output in output_list:
                        with ops.name_scope('ActivityRegularizer'):
                            activity_regularization = self.activity_regularizer(

        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
        return outputs
コード例 #15
    def _distributed_apply(self, distribution, grads_and_vars, name,
                           apply_state):
        """`apply_gradients` using a `DistributionStrategy`.

        Args:
          distribution: strategy object; its `extended.colocate_vars_with`
            and `extended.update` are used to place and run each update.
          grads_and_vars: (gradient, variable) pairs. Traversed more than
            once here, so it must be re-iterable (e.g. a list).
          name: name scope for the update ops; `self._name` is used when
            this is falsy.
          apply_state: forwarded as the `apply_state` keyword to the
            dense/sparse apply methods when they declare that argument.

        Returns:
          The result of incrementing `self._iterations`: the `.op` of the
          assignment when running in graph mode or when any update is
          symbolic, otherwise the value returned by `assign_add`.
        """
        reduced_grads = self._aggregate_gradients(distribution, grads_and_vars)
        var_list = [v for _, v in grads_and_vars]
        grads_and_vars = zip(reduced_grads, var_list)

        def apply_grad_to_update_var(var, grad):
            """Apply gradient to variable."""
            if isinstance(var, ops.Tensor):
                raise NotImplementedError("Trying to update a Tensor ", var)

            apply_kwargs = {}
            if isinstance(grad, ops.IndexedSlices):
                if var.constraint is not None:
                    raise RuntimeError(
                        "Cannot use a constraint function on a sparse variable."
                    )
                if "apply_state" in self._sparse_apply_args:
                    apply_kwargs["apply_state"] = apply_state
                return self._resource_apply_sparse_duplicate_indices(
                    grad.values, var, grad.indices, **apply_kwargs)

            if "apply_state" in self._dense_apply_args:
                apply_kwargs["apply_state"] = apply_state
            update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
            if var.constraint is not None:
                # Re-project the variable only after the dense update ran.
                with ops.control_dependencies([update_op]):
                    return var.assign(var.constraint(var))
            else:
                return update_op

        eagerly_outside_functions = ops.executing_eagerly_outside_functions()
        update_ops = []
        with ops.name_scope(name or self._name, skip_on_eager=True):
            for grad, var in grads_and_vars:
                # Colocate the update with variables to avoid unnecessary communication
                # delays. See b/136304694.
                with distribution.extended.colocate_vars_with(var):
                    with ops.name_scope("update" if eagerly_outside_functions
                                        else "update_" + var.op.name,
                                        skip_on_eager=True):
                        update_ops.extend(
                            distribution.extended.update(
                                var,
                                apply_grad_to_update_var,
                                args=(grad, ),
                                group=False))

            any_symbolic = any(
                isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i)
                for i in update_ops)
            if not context.executing_eagerly() or any_symbolic:
                # If the current context is graph mode or any of the update ops are
                # symbolic then the step update should be carried out under a graph
                # context. (eager updates execute immediately)
                with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
                    with ops.control_dependencies(update_ops):
                        return self._iterations.assign_add(1).op

            return self._iterations.assign_add(1)
コード例 #16
ファイル: base.py プロジェクト: troyddddd/project
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

        Args:
          inputs: input tensor(s).
          *args: additional positional arguments to be passed to `self.call`.
          **kwargs: additional keyword arguments to be passed to `self.call`.
            **Note**: kwarg `scope` is reserved for use by the layer.

        Returns:
          Output tensor(s).

        Raises:
          ValueError: if the inputs do not belong to the layer's graph.
        """
        self._set_scope(kwargs.pop('scope', None))

        # Ensure the Layer, if being reused, is working with inputs from
        # the same graph as where it was created.
        try:
            ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
        except ValueError as e:
            raise ValueError(
                'Input graph and Layer graph are not the same: %s' % e)

        with vs.variable_scope(self._scope, reuse=self.built
                               or self._reuse) as scope:
            with ops.name_scope(scope.original_name_scope):
                if not self.built:
                    # Check input assumptions set before layer building, e.g. input rank.
                    self._assert_input_compatibility(inputs)
                    input_list = [
                        ops.convert_to_tensor(x, name='input')
                        for x in nest.flatten(inputs)
                    ]
                    input_shapes = [x.get_shape() for x in input_list]
                    # A single input is passed to `build` as a bare shape,
                    # several inputs as a list of shapes.
                    if len(input_shapes) == 1:
                        self.build(input_shapes[0])
                    else:
                        self.build(input_shapes)
                if 'scope' in getargspec(self.call).args:
                    kwargs['scope'] = scope
                # Check input assumptions set after layer building, e.g. input shape.
                self._assert_input_compatibility(inputs)
                outputs = self.call(inputs, *args, **kwargs)

                # Apply activity regularization.
                # Note that it should be applied every time the layer creates a new
                # output, since it is output-specific.
                if hasattr(
                        self,
                        'activity_regularizer') and self.activity_regularizer:
                    output_list = _to_list(outputs)
                    for output in output_list:
                        with ops.name_scope('ActivityRegularizer'):
                            activity_regularization = self.activity_regularizer(
                                output)
                        self.add_loss(activity_regularization)
                        _add_elements_to_collection(
                            activity_regularization,
                            ops.GraphKeys.REGULARIZATION_LOSSES)

        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
        self.built = True
        return outputs
コード例 #17
ファイル: base.py プロジェクト: adit-chandra/tensorflow
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `input` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support.

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
      ValueError: if `scope` is passed while keras-style layers are enabled,
        or if the inputs do not belong to the layer's graph.
    """
    scope = kwargs.pop('scope', None)

    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        # Cached on first use; see the AttributeError note above.
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = function_utils.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
コード例 #18
ファイル: base.py プロジェクト: ashuven63/tf_audio
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

        Args:
          inputs: input tensor(s).
          *args: additional positional arguments to be passed to `self.call`.
          **kwargs: additional keyword arguments to be passed to `self.call`.
            **Note**: kwarg `scope` is reserved for use by the layer.

        Returns:
          Output tensor(s).

        Note:
          - If the layer's `call` method takes a `scope` keyword argument,
            this argument will be automatically set to the current variable scope.
          - If the layer's `call` method takes a `mask` argument (as some Keras
            layers do), its default value will be set to the mask generated
            for `inputs` by the previous layer (if `input` did come from
            a layer that generated a corresponding mask, i.e. if it came from
            a Keras layer with masking support.

        Raises:
          ValueError: if the layer's `call` method returns None (an invalid value).
          ValueError: if the inputs do not belong to the layer's graph.
        """
        self._set_scope(kwargs.pop('scope', None))

        if not context.executing_eagerly():
            try:
                # Set layer's "graph" at build time
                self._graph = ops._get_graph_from_inputs(
                    nest.flatten(inputs),  # pylint: disable=protected-access
                    graph=self._graph)
            except ValueError as e:
                raise ValueError(
                    'Input graph and Layer graph are not the same: %s' % e)

        if self.built:
            try:
                # Some classes which inherit from Layer do not use its constructor, so
                # rather than initializing to None we check for an AttributeError.
                scope_context_manager = self._always_reuse_variable_scope
            except AttributeError:
                # From this point we will always set reuse=True, so create a "final"
                # variable scope with this setting. We avoid re-creating variable scopes
                # after this point as an optimization.
                self._always_reuse_variable_scope = vs.variable_scope(
                    self._scope, reuse=True, auxiliary_name_scope=False)
                scope_context_manager = self._always_reuse_variable_scope
        else:
            scope_context_manager = vs.variable_scope(
                self._scope, reuse=self._reuse, auxiliary_name_scope=False)

        with scope_context_manager as scope:
            self._current_scope = scope

            try:
                # Cached on first use; see the AttributeError note above.
                call_has_scope_arg = self._call_has_scope_arg
            except AttributeError:
                self._call_fn_args = function_utils.fn_args(self.call)
                self._call_has_scope_arg = 'scope' in self._call_fn_args
                call_has_scope_arg = self._call_has_scope_arg
            if call_has_scope_arg:
                kwargs['scope'] = scope

            # Actually call layer
            outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

        if not context.executing_eagerly():
            # Update global default collections.
            _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
        return outputs
コード例 #19
  def apply_op(self, op_type_name, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pylint: enable=protected-access
    except AssertionError as e:
      # Format the exception itself: `e.message` only exists on Python 2.
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              # Take the dtype of the first element that is already a Tensor.
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtype.name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError("%s that are invalid." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values,
                                             as_ref=input_arg.is_ref).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          # None means "use the default declared in the OpDef".
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          # Mark the list field as present so that an empty list is kept.
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if isinstance(value, compat.bytes_or_text_types):
            attr_value.func.name = value
          else:
            # A function object: register it in the graph, then refer to it
            # by name.
            value.add_to_graph(ops.get_default_graph())
            attr_value.func.name = value.name
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      # One entry per output arg: a length for list outputs, None for a
      # single tensor. Consumed by _Restructure below.
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(types))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x).as_ref for x in types]
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
        if output_structure:
          outputs = op.outputs
          res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
          if isinstance(res, list) and not res and op_def.is_stateful:
            return op
          else:
            return res
        else:
          return op
コード例 #20
ファイル: variable_scope.py プロジェクト: u4lr451/tensorflow
def variable_op_scope(values,
    """Returns a context manager for defining an op that creates variables.

  This context manager validates that the given `values` are from the
  same graph, ensures that that graph is the default graph, and pushes a
  name scope and a variable scope.

  If `name_or_scope` is not None, it is used as is in the variable scope. If
  `name_or_scope` is None, then `default_name` is used.  In that case, if the
  same name has been previously used in the same scope, it will be made unique
  by appending `_N` to it.

  This is intended to be used when defining generic ops and so reuse is always
  inherited.

  For example, to define a new Python op called `my_op_with_vars`:

  def my_op_with_vars(a, b, scope=None):
    with tf.variable_op_scope([a, b], scope, "MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.get_variable('c')
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)

  Args:
    values: The list of `Tensor` arguments that are passed to the op function.
    name_or_scope: The name argument that is passed to the op function;
      this name_or_scope is not uniquified in the variable scope.
    default_name: The default name to use if the `name_or_scope` argument is
      `None`; this name will be uniquified.
    initializer: The default initializer to pass to variable scope.
    regularizer: The default regularizer for variables within this scope.
    caching_device: The default caching device for variables within this scope.
    reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as
      well as all sub-scopes; if `None`, we just inherit the parent scope reuse.

  Returns:
    A context manager for use in defining a Python op.

  Raises:
    ValueError: when trying to reuse within a create scope, or create within
      a reuse scope, or if reuse is not `None` or `True`.
    TypeError: when the types of some arguments are not appropriate.
    if default_name is None:
        raise TypeError("default_name cannot be None")
    g = ops._get_graph_from_inputs(values)  # pylint: disable=protected-access
    with g.as_default():
        if name_or_scope:
            with variable_scope(name_or_scope,
                                caching_device=caching_device) as vs:
                yield vs
            if reuse:
                raise ValueError(
                    "reuse=True cannot be used without a name_or_scope")
            with ops.name_scope(default_name) as scope:
                count = len(default_name.split("/"))
                scoped_name = "/".join(scope.split("/")[-count - 1:-1])
                with _pure_variable_scope(scoped_name,
                                          caching_device=caching_device) as vs:
                    yield vs