def __call__(self, inputs, *args, **kwargs):
    """Invokes the layer on `inputs`, wrapping `call` with bookkeeping.

    Before delegating to `self.call` this sets the layer scope, checks the
    inputs against `self.graph`, builds the layer on first use and checks
    input compatibility. Afterwards it applies the activity regularizer (when
    one is set), adds `self.updates` to the `UPDATE_OPS` collection and marks
    the layer as built.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer; it is
        popped here and handed to `self._set_scope`.

    Returns:
      Output tensor(s), i.e. the value returned by `self.call`.

    Raises:
      ValueError: if the graph lookup for `inputs` against `self.graph`
        raises a `ValueError`; it is re-raised with a layer-specific message.
    """
    requested_scope = kwargs.pop('scope', None)
    self._set_scope(requested_scope)

    # A reused Layer must be fed inputs from the graph it was created in.
    try:
        ops._get_graph_from_inputs(  # pylint: disable=protected-access
            nest.flatten(inputs), graph=self.graph)
    except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    reuse_variables = self.built or self._reuse
    with vs.variable_scope(self._scope, reuse=reuse_variables) as var_scope:
        with ops.name_scope(var_scope.original_name_scope):
            if not self.built:
                # Check input assumptions set before layer building, e.g.
                # input rank.
                self._assert_input_compatibility(inputs)
                shapes = [
                    ops.convert_to_tensor(x, name='input').get_shape()
                    for x in nest.flatten(inputs)]
                # A single input is built from its bare shape, several
                # inputs from the list of shapes.
                self.build(shapes[0] if len(shapes) == 1 else shapes)

            # Only hand the variable scope to `call` when it declares a
            # `scope` argument.
            call_arg_names = tf_inspect.getargspec(self.call).args
            if 'scope' in call_arg_names:
                kwargs['scope'] = var_scope

            # Check input assumptions set after layer building, e.g. input
            # shape.
            self._assert_input_compatibility(inputs)
            outputs = self.call(inputs, *args, **kwargs)

            # Activity regularization is output-specific, so it is applied
            # every time the layer produces a new output.
            has_regularizer = hasattr(self, 'activity_regularizer')
            if has_regularizer and self.activity_regularizer:
                for single_output in _to_list(outputs):
                    with ops.name_scope('ActivityRegularizer'):
                        penalty = self.activity_regularizer(single_output)
                    self.add_loss(penalty)
                    _add_elements_to_collection(
                        penalty, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    self.built = True
    return outputs
def get_graph_from_inputs(op_input_list, graph=None):
    """Returns the appropriate graph to use for the given inputs.

    Public pass-through to `ops._get_graph_from_inputs`; both arguments are
    forwarded unchanged. The documented resolution order is:

    1. If `graph` is provided, all inputs in `op_input_list` are validated to
       be from that graph.
    2. Otherwise, a graph is selected from the first Operation- or
       Tensor-valued input in `op_input_list`, and all other such inputs are
       validated to be in the same graph.
    3. If the graph was not specified and could not be inferred from
       `op_input_list`, the default graph is used.

    Args:
      op_input_list: A list of inputs to an operation, which may include
        `Tensor`, `Operation`, and other objects that may be converted to a
        graph element.
      graph: (Optional) The explicit graph to use.

    Raises:
      TypeError: If `op_input_list` is not a list or tuple, or if graph is not
        a Graph.
      ValueError: If a graph is explicitly passed and not all inputs are from
        it, or if the inputs are from multiple graphs, or no graph could be
        found and there was no default graph.

    Returns:
      The appropriate graph to use for the given inputs.
    """
    # pylint: disable=protected-access
    resolved_graph = ops._get_graph_from_inputs(op_input_list, graph)
    return resolved_graph
def variable_op_scope(values, name, default_name, initializer=None,
                      regularizer=None):
    """Yields a variable scope for defining an op that creates variables.

    The graph that `values` belong to is resolved and made the default graph
    while the scope is active.

    If `name` is truthy it is used as is for the variable scope. Otherwise
    `default_name` is pushed as a name scope and the resulting scope name is
    used for the variable scope; per the original documentation, a name that
    has already been used in the same scope is made unique by appending `_N`
    to it.

    This is intended to be used when defining generic ops, so reuse is always
    inherited.

    NOTE(review): this block is a generator function; it is presumably
    wrapped by `contextlib.contextmanager` where it is defined so that it can
    be used in a `with` statement. Confirm against the full file.

    For example, to define a new Python op called `my_op_with_vars`:

    ```python
    def my_op_with_vars(a, b, name=None):
      with tf.variable_op_scope([a, b], name, "MyOp") as scope:
        a = tf.convert_to_tensor(a, name="a")
        b = tf.convert_to_tensor(b, name="b")
        c = tf.get_variable('c')
        # Define some computation that uses `a`, `b`, and `c`.
        return foo_op(..., name=scope)
    ```

    Args:
      values: The list of `Tensor` arguments that are passed to the op
        function.
      name: The name argument that is passed to the op function; this name is
        not uniquified in the variable scope.
      default_name: The default name to use if the `name` argument is `None`;
        this name will be uniquified. Must not be `None`.
      initializer: A default initializer to pass to the variable scope.
      regularizer: A default regularizer for variables within this scope.

    Yields:
      The variable scope object that was entered.

    Raises:
      TypeError: if `default_name` is `None`.
      ValueError: may propagate from the underlying variable scope (the
        original documentation mentions reusing within a create scope or
        creating within a reuse scope).
    """
    if default_name is None:
        raise TypeError("default_name cannot be None")

    graph = ops._get_graph_from_inputs(values)  # pylint: disable=protected-access
    with graph.as_default():
        if not name:
            with ops.name_scope(default_name) as name_scope:
                # Keep as many trailing components of the name scope as
                # `default_name` has, dropping the final (empty) piece that
                # follows the trailing "/".
                depth = default_name.count("/") + 1
                pieces = name_scope.split("/")
                scoped_name = "/".join(pieces[-depth - 1:-1])
                with _pure_variable_scope(
                        scoped_name, initializer=initializer,
                        regularizer=regularizer) as var_scope:
                    yield var_scope
        else:
            with variable_scope(
                    name, initializer=initializer,
                    regularizer=regularizer) as var_scope:
                yield var_scope
def graph_zeros_like(tensor):
    """Graph-only version of tf.zeros_like(), for internal use only.

    Resolves the graph that `tensor` belongs to, converts `tensor` to a
    tensor and creates a "ZerosLike" op in that graph whose input and output
    type is the tensor's base dtype.

    Args:
      tensor: the value to build the "ZerosLike" op for; anything accepted by
        `ops.convert_to_tensor`.

    Returns:
      The single output of the created "ZerosLike" op.
    """
    graph = ops._get_graph_from_inputs([tensor])  # pylint: disable=protected-access
    with graph.as_default():
        with ops.name_scope(None, "zeros_like", [tensor]) as name:
            converted = ops.convert_to_tensor(tensor, name="tensor")
            base_dtype = converted.dtype.base_dtype
            type_attr = attr_value_pb2.AttrValue(
                type=base_dtype.as_datatype_enum)
            zeros_op = graph.create_op(
                "ZerosLike",
                [converted],
                [base_dtype],
                input_types=[base_dtype],
                attrs={"T": type_attr},
                name=name)
    # Unpacking (rather than indexing) insists on exactly one output.
    (result,) = zeros_op.outputs
    return result
def _distributed_apply(self, distribution, grads_and_vars, name):
    """`apply_gradients` using a `DistributionStrategy`.

    Sums the gradients through `distribution.extended.batch_reduce_to`,
    applies every reduced gradient to its variable through
    `distribution.extended.update`, and finally increments
    `self._iterations`.

    Args:
      distribution: strategy object exposing `extended.batch_reduce_to` and
        `extended.update`.
      grads_and_vars: iterable of `(gradient, variable)` pairs. It is
        materialized into a list once, so single-pass iterables (for example
        a `zip` object) are handled correctly.
      name: optional name for the enclosing name scope; `self._name` is used
        when it is falsy.

    Returns:
      In graph mode, or when any collected update op is symbolic, the `.op`
      of `self._iterations.assign_add(1)` created under control dependencies
      on all update ops. Otherwise the result of
      `self._iterations.assign_add(1)`.

    Raises:
      NotImplementedError: if a variable to update is an `ops.Tensor`.
      RuntimeError: if a variable with a constraint receives an
        `ops.IndexedSlices` gradient.
    """
    # The pairs are read twice below (by the reduction and to extract the
    # variables); a one-shot iterator would be exhausted by the first read
    # and silently produce no updates.
    grads_and_vars = list(grads_and_vars)
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
        """Apply gradient to variable."""
        if isinstance(var, ops.Tensor):
            # Format the variable into the message instead of passing it as
            # a second exception argument (which renders as a tuple).
            raise NotImplementedError(
                "Trying to update a Tensor %s" % (var,))
        if isinstance(grad, ops.IndexedSlices):
            if var.constraint is not None:
                raise RuntimeError(
                    "Cannot use a constraint function on a sparse variable.")
            return self._resource_apply_sparse_duplicate_indices(
                grad.values, var, grad.indices)
        update_op = self._resource_apply_dense(grad, var)
        if var.constraint is not None:
            # Re-apply the constraint only after the dense update has run.
            with ops.control_dependencies([update_op]):
                return var.assign(var.constraint(var))
        return update_op

    update_ops = []
    with backend.name_scope(name or self._name):
        for grad, var in grads_and_vars:
            scope_name = ("" if ops.executing_eagerly_outside_functions()
                          else "_" + var.op.name)
            with backend.name_scope("update" + scope_name):
                update_ops.extend(
                    distribution.extended.update(
                        var, apply_grad_to_update_var, args=(grad,),
                        group=False))

        any_symbolic = any(
            isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i)
            for i in update_ops)
        if not context.executing_eagerly() or any_symbolic:
            # If the current context is graph mode or any of the update ops
            # are symbolic then the step update should be carried out under a
            # graph context. (eager updates execute immediately)
            with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
                with ops.control_dependencies(update_ops):
                    return self._iterations.assign_add(1).op

        return self._iterations.assign_add(1)
def __call__(self, inputs, *args, **kwargs):
    """Invokes the layer on `inputs`, wrapping `call` with bookkeeping.

    Lazily picks the layer's variable scope on first use, then, inside that
    scope, checks the inputs against `self.graph`, builds the layer if it has
    not been built yet and delegates to `self.call`. Afterwards the activity
    regularizer (when one is set) is applied to every output and
    `self.updates` is added to the `UPDATE_OPS` collection.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**, the kwarg 'scope' is reserved for use by the Layer.

    Returns:
      Output tensor(s), i.e. the value returned by `self.call`.

    Raises:
      ValueError: if the graph lookup for `inputs` against `self.graph`
        raises a `ValueError`; it is re-raised with a layer-specific message.
    """
    requested_scope = kwargs.pop('scope', None)

    # Custom getter that overrides tf.get_variable when creating layer
    # variables: creation is routed through `self._add_variable`. The current
    # custom getter is nested by the variable scope.
    def variable_getter(getter, name, shape, dtype=None, initializer=None,
                        regularizer=None, trainable=True, **getter_kwargs):
        bound_getter = functools.partial(getter, **getter_kwargs)
        return self._add_variable(
            name, shape, initializer=initializer, regularizer=regularizer,
            dtype=dtype, trainable=trainable, variable_getter=bound_getter)

    if not self._built and self._scope is None:
        # If constructed with _scope=None, lazy setting of scope.
        if self._reuse:
            if requested_scope is not None:
                scope_manager = vs.variable_scope(requested_scope)
            else:
                scope_manager = vs.variable_scope(self._base_name)
        else:
            scope_manager = vs.variable_scope(
                requested_scope, default_name=self._base_name)
        # Advance the manager's generator once to obtain the scope object.
        self._scope = next(scope_manager.gen)
        self._name = self._scope.name

    # Build (if necessary) and call the layer, inside a variable scope.
    if self._built:
        reuse_variables = True
    else:
        reuse_variables = self._reuse
    with vs.variable_scope(self._scope,
                           reuse=reuse_variables,
                           custom_getter=variable_getter) as layer_scope:
        # A reused Layer must be fed inputs from the graph it was created in.
        try:
            ops._get_graph_from_inputs(  # pylint: disable=protected-access
                nest.flatten(inputs), graph=self.graph)
        except ValueError as e:
            raise ValueError("Inputs' and Layer's graphs are not the same: %s" % e)

        with ops.name_scope(layer_scope.original_name_scope):
            if not self.built:
                shapes = [
                    ops.convert_to_tensor(x, name='input').get_shape()
                    for x in nest.flatten(inputs)]
                # A single input is built from its bare shape, several
                # inputs from the list of shapes.
                self.build(shapes[0] if len(shapes) == 1 else shapes)
                self._built = True
            outputs = self.call(inputs, *args, **kwargs)

            # Activity regularization is output-specific, so it is applied
            # every time the layer produces a new output.
            has_regularizer = hasattr(self, 'activity_regularizer')
            if has_regularizer and self.activity_regularizer:
                for single_output in _to_list(outputs):
                    with ops.name_scope('ActivityRegularizer'):
                        penalty = self.activity_regularizer(single_output)
                    self._losses.append(penalty)
                    _add_elements_to_collection(
                        penalty, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
def apply_op(self, op_type_name, g=None, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Config proto extensions must be provided via the 'ext' keyword argument.
    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If none of the inputs are Tensors and your session doesn't have a
       # default graph, you will have to specify the graph.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    The work is done in four passes over the OpDef: inputs are converted and
    used to infer attrs, the remaining attrs are taken from `keywords`, all
    attrs are converted to `AttrValue` protos, and the output types are
    derived from those protos before the op is created in the graph.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op; defaults to
        `op_type_name`.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.
        A name that clashes with a Python keyword or built-in may be passed
        with a trailing underscore.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: if `op_type_name` is not registered, or the graph lookup
        fails with an `AssertionError`.
      TypeError: if an input or attr argument is missing, a list was expected
        but not given, input types do not match, an inferred attr is also
        specified, an attr type is unrecognized, or unexpected keyword
        arguments remain.
      ValueError: if a list length or an attr value violates the OpDef's
        length, minimum or allowed-values constraints.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
        raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
        # Need to flatten all the arguments into a list.
        # pylint: disable=protected-access
        g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
        # pylint: enable=protected-access
    except AssertionError as e:
        # Format the exception itself: the `message` attribute does not exist
        # on Python 3 exceptions, so reading it would mask this error.
        raise RuntimeError(
            "Need to specify g=graph to Op '%s' (could not determine graph due "
            "to: %s)" % (op_type_name, e))

    # Default name if not specified.
    if name is None:
        name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

        # Perform input type inference
        inferred_from = {}
        for input_arg in op_def.input_arg:
            input_name = input_arg.name
            if input_name in keywords:
                values = keywords.pop(input_name)
            elif input_name + "_" in keywords:
                # Handle the case where the name is a keyword or built-in
                # for Python so we use the name + _ instead.
                input_name += "_"
                values = keywords.pop(input_name)
            else:
                raise TypeError("No argument for input " + input_name)

            # Goals:
            # * Convert values to Tensors if it contains constants.
            # * Verify that values is a list if that matches the input_arg's
            #   type.
            # * If the input_arg's type is determined by attrs, either set
            #   those attrs and validate those attr values are legal (if
            #   they have not yet been set) or validate the input matches
            #   the type indicated by the attrs (if they have already been
            #   inferred via an earlier input).
            # * If the input_arg has an explicit type, make sure the input
            #   conforms.

            if _IsListParameter(input_arg):
                if not _IsListValue(values):
                    raise TypeError(
                        "Expected list for '%s' argument to '%s' Op, not %s." %
                        (input_name, op_type_name, values))
                # In cases where we expect all elements of the list to have
                # the same dtype, try to cast non-Tensor elements to that
                # type.
                dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.number_attr:
                    if input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    else:
                        # Fall back to the dtype of the first Tensor element.
                        for t in values:
                            if isinstance(t, ops.Tensor):
                                dtype = t.dtype
                                break

                try:
                    values = ops.convert_n_to_tensor_or_indexed_slices(
                        values, name=input_arg.name,
                        dtype=(types_lib.as_dtype(dtype).base_dtype
                               if dtype else None))
                except (TypeError, ValueError):
                    assert dtype is not None, "Should not fail if dtype is None"
                    assert input_arg.number_attr, "Should be number_attr case"
                    # What types does the conversion function think values
                    # have?
                    values = ops.convert_n_to_tensor_or_indexed_slices(values)
                    observed = ", ".join(
                        v.dtype.base_dtype.name for v in values)

                    prefix = (
                        "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                        (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError(
                            "%s that do not match expected type %s." %
                            (prefix, types_lib.as_dtype(dtype).name))
                    elif input_arg.type_attr in attrs:
                        raise TypeError(
                            "%s that do not match type %s inferred from "
                            "earlier arguments." %
                            (prefix, types_lib.as_dtype(dtype).name))
                    else:
                        raise TypeError("%s that don't all match." % prefix)

                types = [x.dtype for x in values]
                inputs.extend(values)
            else:
                # In cases where we have an expected type, try to convert
                # non-Tensor arguments to that type.
                dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.type_attr in attrs:
                    dtype = attrs[input_arg.type_attr]

                try:
                    values = ops.convert_to_tensor(
                        values, name=input_arg.name, dtype=dtype)
                except ValueError:
                    # What type does convert_to_tensor think it has?
                    observed = ops.convert_to_tensor(values).dtype.name
                    prefix = (
                        "Input '%s' of '%s' Op has type %s that does not match" %
                        (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError(
                            "%s expected type of %s." %
                            (prefix, types_lib.as_dtype(input_arg.type).name))
                    else:
                        raise TypeError(
                            "%s type %s of argument '%s'." %
                            (prefix,
                             types_lib.as_dtype(
                                 attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))

                types = [values.dtype]
                inputs.append(values)
            base_types = [x.base_dtype for x in types]

            if input_arg.number_attr:
                # <number-attr> * <type> or <number-attr> * <type-attr>
                if input_arg.number_attr in attrs:
                    if len(values) != attrs[input_arg.number_attr]:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d must match "
                            "length %d of argument '%s'." %
                            (input_name, op_type_name, len(values),
                             attrs[input_arg.number_attr],
                             inferred_from[input_arg.number_attr]))
                else:
                    attrs[input_arg.number_attr] = len(values)
                    inferred_from[input_arg.number_attr] = input_name
                    num_attr = _Attr(op_def, input_arg.number_attr)
                    if num_attr.has_minimum and len(values) < num_attr.minimum:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d shorter "
                            "than minimum length %d." %
                            (input_name, op_type_name, len(values),
                             num_attr.minimum))
                # All tensors must have the same base type.
                if any(bt != base_types[0] for bt in base_types):
                    raise TypeError(
                        "All tensors passed to '%s' of '%s' Op "
                        "must have the same type." %
                        (input_name, op_type_name))
                if input_arg.type != types_pb2.DT_INVALID:
                    # <number-attr> * <type> case
                    if base_types and base_types[0] != input_arg.type:
                        assert False, "Unreachable"
                elif input_arg.type_attr in attrs:
                    # <number-attr> * <type-attr> case, where <type-attr>
                    # already has an inferred value.
                    if base_types and base_types[0] != attrs[input_arg.type_attr]:
                        assert False, "Unreachable"
                else:
                    # <number-attr> * <type-attr> case, where we are now
                    # setting the <type-attr> based on this input
                    if not base_types:
                        raise TypeError(
                            "Don't know how to infer type variable from empty input "
                            "list passed to input '%s' of '%s' Op." %
                            (input_name, op_type_name))
                    attrs[input_arg.type_attr] = base_types[0]
                    inferred_from[input_arg.type_attr] = input_name
                    type_attr = _Attr(op_def, input_arg.type_attr)
                    _SatisfiesTypeConstraint(base_types[0], type_attr)
            elif input_arg.type_attr:
                # <type-attr>
                attr_value = base_types[0]
                if input_arg.type_attr in attrs:
                    if attrs[input_arg.type_attr] != attr_value:
                        assert False, "Unreachable"
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type, _Attr(op_def, input_arg.type_attr))
                    attrs[input_arg.type_attr] = attr_value
                    inferred_from[input_arg.type_attr] = input_name
            elif input_arg.type_list_attr:
                # <type-list-attr>
                attr_value = base_types
                if input_arg.type_list_attr in attrs:
                    if attrs[input_arg.type_list_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type list of %s that does not "
                            "match type list %s of argument '%s'." %
                            (input_name, op_type_name,
                             ", ".join(types_lib.as_dtype(x).name
                                       for x in attr_value),
                             ", ".join(types_lib.as_dtype(x).name
                                       for x in attrs[input_arg.type_list_attr]),
                             inferred_from[input_arg.type_list_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type, _Attr(op_def, input_arg.type_list_attr))
                    attrs[input_arg.type_list_attr] = attr_value
                    inferred_from[input_arg.type_list_attr] = input_name
            else:
                # single Tensor with specified type
                if base_types[0] != input_arg.type:
                    assert False, "Unreachable"

            # Ref inputs keep their (ref) dtypes; other inputs are recorded
            # with base dtypes.
            if input_arg.is_ref:
                if not all(x.is_ref_dtype for x in types):
                    raise TypeError(
                        "Input '%s' of '%s' Op requires l-value input" %
                        (input_name, op_type_name))
                input_types.extend(types)
            else:
                input_types.extend(base_types)

        # Process remaining attrs
        for attr in op_def.attr:
            # Skip attrs that have already had their values inferred
            if attr.name in attrs:
                if attr.name in keywords:
                    raise TypeError(
                        "Should not specify value for inferred attr '%s'." %
                        attr.name)
                continue
            if attr.name in keywords:
                attrs[attr.name] = keywords.pop(attr.name)
            elif attr.name + "_" in keywords:
                # Attrs whose names match Python keywords have an extra '_'
                # appended, so we must check for that as well.
                attrs[attr.name] = keywords.pop(attr.name + "_")
            else:
                raise TypeError("No argument for attr " + attr.name)

        # Convert attr values to AttrValue protos.
        attr_protos = {}
        for attr_def in op_def.attr:
            key = attr_def.name
            value = attrs[key]
            attr_value = attr_value_pb2.AttrValue()
            if attr_def.HasField("default_value") and value is None:
                attr_value.CopyFrom(attr_def.default_value)
                attr_protos[key] = attr_value
                continue
            if attr_def.type.startswith("list("):
                if not _IsListValue(value):
                    raise TypeError("Expected list for attr " + key)
                if attr_def.has_minimum:
                    if len(value) < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed list of length %d "
                            "less than minimum %d." %
                            (key, op_type_name, len(value), attr_def.minimum))
            if attr_def.type == "string":
                attr_value.s = _MakeStr(value, key)
                if attr_def.HasField("allowed_values"):
                    if attr_value.s not in attr_def.allowed_values.list.s:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                            (key, op_type_name, attr_value.s,
                             '", "'.join(attr_def.allowed_values.list.s)))
            elif attr_def.type == "list(string)":
                attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                if attr_def.HasField("allowed_values"):
                    for x in attr_value.list.s:
                        if x not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                                (key, op_type_name, x,
                                 '", "'.join(attr_def.allowed_values.list.s)))
            elif attr_def.type == "int":
                attr_value.i = _MakeInt(value, key)
                if attr_def.has_minimum:
                    if attr_value.i < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                            (key, op_type_name, attr_value.i, attr_def.minimum))
            elif attr_def.type == "list(int)":
                attr_value.list.i.extend([_MakeInt(x, key) for x in value])
            elif attr_def.type == "float":
                attr_value.f = _MakeFloat(value, key)
            elif attr_def.type == "list(float)":
                attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
            elif attr_def.type == "bool":
                attr_value.b = _MakeBool(value, key)
            elif attr_def.type == "list(bool)":
                attr_value.list.b.extend([_MakeBool(x, key) for x in value])
            elif attr_def.type == "type":
                attr_value.type = _MakeType(value, attr_def)
            elif attr_def.type == "list(type)":
                attr_value.list.type.extend(
                    [_MakeType(x, attr_def) for x in value])
            elif attr_def.type == "shape":
                attr_value.shape.CopyFrom(_MakeShape(value, key))
            elif attr_def.type == "list(shape)":
                attr_value.list.shape.extend(
                    [_MakeShape(x, key) for x in value])
            elif attr_def.type == "tensor":
                attr_value.tensor.CopyFrom(_MakeTensor(value, key))
            elif attr_def.type == "list(tensor)":
                attr_value.list.tensor.extend(
                    [_MakeTensor(x, key) for x in value])
            else:
                raise TypeError("Unrecognized Attr type " + attr_def.type)

            attr_protos[key] = attr_value
        del attrs  # attrs is no longer authoritative, use attr_protos instead

        # Determine output types (possibly using attrs)
        output_types = []
        output_structure = []
        for arg in op_def.output_arg:
            types = []
            if arg.number_attr:
                n = _AttrValue(attr_protos, arg.number_attr).i
                if arg.type_attr:
                    types = [_AttrValue(attr_protos, arg.type_attr).type] * n
                else:
                    types = [arg.type] * n
                output_structure.append(n)
            elif arg.type_attr:
                t = _AttrValue(attr_protos, arg.type_attr)
                types = [t.type]
                output_structure.append(None)
            elif arg.type_list_attr:
                t = _AttrValue(attr_protos, arg.type_list_attr)
                types = t.list.type
                output_structure.append(len(t.list.type))
            else:
                types = [arg.type]
                output_structure.append(None)
            if arg.is_ref:
                types = [types_lib.as_dtype(x).as_ref for x in types]
            output_types.extend(types)

        # Every recognized input and attr has been popped by now; anything
        # left over was not declared by the OpDef.
        if keywords:
            raise TypeError("apply_op() got unexpected keyword arguments: " +
                            ", ".join(sorted(keywords)))

        # Add Op to graph
        if output_structure:
            op = g.create_op(op_type_name, inputs, output_types, name=scope,
                             input_types=input_types, attrs=attr_protos,
                             op_def=op_def)
            outputs = op.outputs
            return _Restructure(
                ops.convert_n_to_tensor_or_indexed_slices(outputs),
                output_structure)
        else:
            return g.create_op(op_type_name, inputs, output_types, name=scope,
                               input_types=input_types, attrs=attr_protos,
                               op_def=op_def)
def apply_op(self, op_type_name, name=None, **keywords): # pylint: disable=g-doc-args """Add a node invoking a registered Op to a graph. Config proto extensions must be provided via the 'ext' keyword argument. Example usage: # input1 and input2 can be Tensors or anything ops.convert_to_tensor() # will convert to a Tensor. op_def_library.apply_op("op", input1=input1, input2=input2) # Can specify a node name. op_def_library.apply_op("op", input1=input1, name="node_name") # Must use keyword arguments, with the names specified in the OpDef. op_def_library.apply_op("op", input_name=input, attr_name=attr) All attrs must either be inferred from an input or specified. (If inferred, the attr must not be specified.) If an attr has a default value specified in the Op's OpDef, then you may pass None as the value of that attr to get the default. Args: op_type_name: string. Must match the name field of a registered Op. name: string. Optional name of the created op. **keywords: input Tensor and attr arguments specified by name, and optional parameters to pass when constructing the Operation. Returns: The Tensor(s) representing the output of the operation, or the Operation itself if there are no outputs. Raises: RuntimeError: On some errors. TypeError: On some errors. ValueError: On some errors. """ op_info = self._ops.get(op_type_name, None) if op_info is None: raise RuntimeError("Unrecognized Op name " + op_type_name) op_def = op_info.op_def # Determine the graph context. try: # Need to flatten all the arguments into a list. # pylint: disable=protected-access g = ops._get_graph_from_inputs(_Flatten(keywords.values())) # pyline: enable=protected-access except AssertionError as e: raise RuntimeError( "Cannot determine graph for Op '%s' due to: %s" % (op_type_name, e.message)) # Default name if not specified. if name is None: name = op_type_name # Requires that op_def has passed validation (using the C++ # ValidateOpDef() from ../framework/op_def_util.h). 
attrs = {} inputs = [] input_types = [] with g.as_default(), ops.name_scope(name) as scope: # Perform input type inference inferred_from = {} for input_arg in op_def.input_arg: input_name = input_arg.name if input_name in keywords: values = keywords.pop(input_name) elif input_name + "_" in keywords: # Handle the case where the name is a keyword or built-in # for Python so we use the name + _ instead. input_name += "_" values = keywords.pop(input_name) else: raise TypeError("No argument for input " + input_name) # Goals: # * Convert values to Tensors if it contains constants. # * Verify that values is a list if that matches the input_arg's # type. # * If the input_arg's type is determined by attrs, either set # those attrs and validate those attr values are legal (if # they have not yet been set) or validate the input matches # the type indicated by the attrs (if they have already been # inferred via an earlier input). # * If the input_arg has an explicit type, make sure the input # conforms. if _IsListParameter(input_arg): if not _IsListValue(values): raise TypeError( "Expected list for '%s' argument to '%s' Op, not %s." % (input_name, op_type_name, values)) # In cases where we expect all elements of the list to have the # same dtype, try to cast non-Tensor elements to that type. dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.number_attr: if input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] else: for t in values: if isinstance(t, ops.Tensor): dtype = t.dtype break try: if not input_arg.is_ref and dtype: dtype = dtypes.as_dtype(dtype).base_dtype values = ops.convert_n_to_tensor_or_indexed_slices( values, name=input_arg.name, dtype=dtype if dtype else None, as_ref=input_arg.is_ref) except (TypeError, ValueError): assert dtype is not None, "Should not fail if dtype is None" assert input_arg.number_attr, "Should be number_attr case" # What types does the conversion function think values have? 
values = ops.convert_n_to_tensor_or_indexed_slices( values, as_ref=input_arg.is_ref) observed = ", ".join(v.dtype.base_dtype.name for v in values) prefix = ( "Tensors in list passed to '%s' of '%s' Op have types [%s]" % (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError( "%s that do not match expected type %s." % (prefix, dtype.name)) elif input_arg.type_attr in attrs: raise TypeError( "%s that do not match type %s inferred from " "earlier arguments." % (prefix, dtype.name)) else: raise TypeError("%s that don't all match." % prefix) types = [x.dtype for x in values] inputs.extend(values) else: # In cases where we have an expected type, try to convert non-Tensor # arguments to that type. dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] try: values = ops.convert_to_tensor(values, name=input_arg.name, dtype=dtype, as_ref=input_arg.is_ref) except ValueError: # What type does convert_to_tensor think it has? observed = ops.convert_to_tensor( values, as_ref=input_arg.is_ref).dtype.name prefix = ( "Input '%s' of '%s' Op has type %s that does not match" % (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError( "%s expected type of %s." % (prefix, dtypes.as_dtype(input_arg.type).name)) else: raise TypeError( "%s type %s of argument '%s'." % (prefix, dtypes.as_dtype( attrs[input_arg.type_attr]).name, inferred_from[input_arg.type_attr])) types = [values.dtype] inputs.append(values) base_types = [x.base_dtype for x in types] if input_arg.number_attr: # <number-attr> * <type> or <number-attr> * <type-attr> if input_arg.number_attr in attrs: if len(values) != attrs[input_arg.number_attr]: raise ValueError( "List argument '%s' to '%s' Op with length %d must match " "length %d of argument '%s'." 
% (input_name, op_type_name, len(values), attrs[input_arg.number_attr], inferred_from[input_arg.number_attr])) else: attrs[input_arg.number_attr] = len(values) inferred_from[input_arg.number_attr] = input_name num_attr = _Attr(op_def, input_arg.number_attr) if num_attr.has_minimum and len( values) < num_attr.minimum: raise ValueError( "List argument '%s' to '%s' Op with length %d shorter " "than minimum length %d." % (input_name, op_type_name, len(values), num_attr.minimum)) # All tensors must have the same base type. if any([bt != base_types[0] for bt in base_types]): raise TypeError( "All tensors passed to '%s' of '%s' Op " "must have the same type." % (input_name, op_type_name)) if input_arg.type != types_pb2.DT_INVALID: # <number-attr> * <type> case if base_types and base_types[0] != input_arg.type: assert False, "Unreachable" elif input_arg.type_attr in attrs: # <number-attr> * <type-attr> case, where <type-attr> already # has an inferred value. if base_types and base_types[0] != attrs[ input_arg.type_attr]: assert False, "Unreachable" else: # <number-attr> * <type-attr> case, where we are now setting # the <type-attr> based on this input if not base_types: raise TypeError( "Don't know how to infer type variable from empty input " "list passed to input '%s' of '%s' Op." 
% (input_name, op_type_name)) attrs[input_arg.type_attr] = base_types[0] inferred_from[input_arg.type_attr] = input_name type_attr = _Attr(op_def, input_arg.type_attr) _SatisfiesTypeConstraint(base_types[0], type_attr) elif input_arg.type_attr: # <type-attr> attr_value = base_types[0] if input_arg.type_attr in attrs: if attrs[input_arg.type_attr] != attr_value: assert False, "Unreachable" else: for base_type in base_types: _SatisfiesTypeConstraint( base_type, _Attr(op_def, input_arg.type_attr)) attrs[input_arg.type_attr] = attr_value inferred_from[input_arg.type_attr] = input_name elif input_arg.type_list_attr: # <type-list-attr> attr_value = base_types if input_arg.type_list_attr in attrs: if attrs[input_arg.type_list_attr] != attr_value: raise TypeError( "Input '%s' of '%s' Op has type list of %s that does not " "match type list %s of argument '%s'." % (input_name, op_type_name, ", ".join( dtypes.as_dtype(x).name for x in attr_value), ", ".join( dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr]), inferred_from[input_arg.type_list_attr])) else: for base_type in base_types: _SatisfiesTypeConstraint( base_type, _Attr(op_def, input_arg.type_list_attr)) attrs[input_arg.type_list_attr] = attr_value inferred_from[input_arg.type_list_attr] = input_name else: # single Tensor with specified type if base_types[0] != input_arg.type: assert False, "Unreachable" if input_arg.is_ref: if not all(x.is_ref_dtype for x in types): raise TypeError( "Input '%s' of '%s' Op requires l-value input" % (input_name, op_type_name)) input_types.extend(types) else: input_types.extend(base_types) # Process remaining attrs for attr in op_def.attr: # Skip attrs that have already had their values inferred if attr.name in attrs: if attr.name in keywords: raise TypeError( "Should not specify value for inferred attr '%s'." 
% attr.name) continue if attr.name in keywords: attrs[attr.name] = keywords.pop(attr.name) elif attr.name + "_" in keywords: # Attrs whose names match Python keywords have an extra '_' # appended, so we must check for that as well. attrs[attr.name] = keywords.pop(attr.name + "_") else: raise TypeError("No argument for attr " + attr.name) # Convert attr values to AttrValue protos. attr_protos = {} for attr_def in op_def.attr: key = attr_def.name value = attrs[key] attr_value = attr_value_pb2.AttrValue() if attr_def.HasField("default_value") and value is None: attr_value.CopyFrom(attr_def.default_value) attr_protos[key] = attr_value continue if attr_def.type.startswith("list("): if not _IsListValue(value): raise TypeError("Expected list for attr " + key) if attr_def.has_minimum: if len(value) < attr_def.minimum: raise ValueError( "Attr '%s' of '%s' Op passed list of length %d " "less than minimum %d." % (key, op_type_name, len(value), attr_def.minimum)) if attr_def.type == "string": attr_value.s = _MakeStr(value, key) if attr_def.HasField("allowed_values"): if attr_value.s not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % (key, op_type_name, compat.as_text(attr_value.s), '", "'.join( map(compat.as_text, attr_def.allowed_values.list.s)))) elif attr_def.type == "list(string)": attr_value.list.s.extend([_MakeStr(x, key) for x in value]) if attr_def.HasField("allowed_values"): for x in attr_value.list.s: if x not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % (key, op_type_name, compat.as_text(x), '", "'.join( map(compat.as_text, attr_def.allowed_values.list.s)))) elif attr_def.type == "int": attr_value.i = _MakeInt(value, key) if attr_def.has_minimum: if attr_value.i < attr_def.minimum: raise ValueError( "Attr '%s' of '%s' Op passed %d less than minimum %d." 
% (key, op_type_name, attr_value.i, attr_def.minimum)) elif attr_def.type == "list(int)": attr_value.list.i.extend([_MakeInt(x, key) for x in value]) elif attr_def.type == "float": attr_value.f = _MakeFloat(value, key) elif attr_def.type == "list(float)": attr_value.list.f.extend( [_MakeFloat(x, key) for x in value]) elif attr_def.type == "bool": attr_value.b = _MakeBool(value, key) elif attr_def.type == "list(bool)": attr_value.list.b.extend( [_MakeBool(x, key) for x in value]) elif attr_def.type == "type": attr_value.type = _MakeType(value, attr_def) elif attr_def.type == "list(type)": attr_value.list.type.extend( [_MakeType(x, attr_def) for x in value]) elif attr_def.type == "shape": attr_value.shape.CopyFrom(_MakeShape(value, key)) elif attr_def.type == "list(shape)": attr_value.list.shape.extend( [_MakeShape(x, key) for x in value]) elif attr_def.type == "tensor": attr_value.tensor.CopyFrom(_MakeTensor(value, key)) elif attr_def.type == "list(tensor)": attr_value.list.tensor.extend( [_MakeTensor(x, key) for x in value]) elif attr_def.type == "func": if not isinstance(value, compat.bytes_or_text_types): raise TypeError("Expects a string for the func name") attr_value.func.name = value else: raise TypeError("Unrecognized Attr type " + attr_def.type) attr_protos[key] = attr_value del attrs # attrs is no longer authoritative, use attr_protos instead # Determine output types (possibly using attrs) output_types = [] output_structure = [] for arg in op_def.output_arg: types = [] if arg.number_attr: n = _AttrValue(attr_protos, arg.number_attr).i if arg.type_attr: types = [_AttrValue(attr_protos, arg.type_attr).type ] * n else: types = [arg.type] * n output_structure.append(n) elif arg.type_attr: t = _AttrValue(attr_protos, arg.type_attr) types = [t.type] output_structure.append(None) elif arg.type_list_attr: t = _AttrValue(attr_protos, arg.type_list_attr) types = t.list.type output_structure.append(len(t.list.type)) else: types = [arg.type] 
output_structure.append(None) if arg.is_ref: types = [dtypes.as_dtype(x).as_ref for x in types] output_types.extend(types) if keywords: raise TypeError( "apply_op() got unexpected keyword arguments: " + ", ".join(sorted(keywords.keys()))) # Add Op to graph if output_structure: op = g.create_op(op_type_name, inputs, output_types, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) outputs = op.outputs return _Restructure( ops.convert_n_to_tensor_or_indexed_slices(outputs), output_structure) else: return g.create_op(op_type_name, inputs, output_types, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def)
def _apply_op_helper(self, op_type_name, name=None, **keywords):
  """Implementation of apply_op: builds the op and reports its output layout.

  Looks up the OpDef registered under `op_type_name`, converts the keyword
  arguments that correspond to the OpDef's inputs into tensors, infers or
  validates the attrs tied to those inputs, converts every attr into an
  `AttrValue` proto and finally creates the op in the graph.

  Args:
    op_type_name: Name of the op type; must be a key of `self._ops`.
    name: Name used for the op's name scope. Defaults to `op_type_name`.
    **keywords: One entry per OpDef input and per non-inferred attr. A
      trailing underscore on a keyword (`name_`) is accepted for inputs and
      attrs whose names clash with Python keywords/built-ins.

  Returns:
    A tuple `(output_structure, is_stateful, op)` where `output_structure`
    has one entry per OpDef output arg: the list length for outputs sized by
    a number attr or a type-list attr, and `None` for single-tensor outputs;
    `is_stateful` is `op_def.is_stateful`; `op` is the value returned by
    `g.create_op`.

  Raises:
    RuntimeError: If `op_type_name` is not registered or the graph cannot be
      determined from the inputs.
    NotImplementedError: If the op is deprecated at the graph's producer
      version.
    TypeError: On missing/unexpected arguments or mismatching input types.
    ValueError: On list-length, minimum or allowed-value violations.
  """
  op_info = self._ops.get(op_type_name, None)
  if op_info is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)
  op_def = op_info.op_def

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Format the exception itself: `e.message` does not exist on Python 3
    # exceptions and would turn this handler into an AttributeError.
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, e))

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.")
          % (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.internal_convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            # `dtype` may still be the raw enum from `input_arg.type` when
            # the input is a ref (the normalisation above is skipped), so
            # go through `as_dtype` before reading `.name`.
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtypes.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtypes.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError(
                "%s that are invalid. Tensors: %s" % (prefix, values))

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          values = ops.internal_convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                "type '%s' instead. Error: %s" %
                (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                 repr(values), type(values).__name__, err))
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.internal_convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                "Tried to convert '%s' to a tensor and failed. Error: %s" %
                (input_name, err))
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                "list passed to input '%s' of '%s' Op." %
                (input_name, op_type_name))
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(base_types[0], type_attr,
                                   param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            assert False, "Unreachable"
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      # Ref inputs keep their (ref) dtypes; all others record base dtypes.
      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              ("'%s' Op requires that input '%s' be a mutable tensor "
               "(e.g.: a tf.Variable)") % (op_type_name, input_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]
      attr_value = attr_value_pb2.AttrValue()
      if attr_def.HasField("default_value") and value is None:
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue
      if attr_def.type.startswith("list("):
        if not _IsListValue(value):
          raise TypeError("Expected list for attr " + key)
        if attr_def.has_minimum:
          if len(value) < attr_def.minimum:
            raise ValueError(
                "Attr '%s' of '%s' Op passed list of length %d "
                "less than minimum %d." %
                (key, op_type_name, len(value), attr_def.minimum))
        # Mark the list field as present even when `value` is empty.
        attr_value.list.SetInParent()
      if attr_def.type == "string":
        attr_value.s = _MakeStr(value, key)
        if attr_def.HasField("allowed_values"):
          if attr_value.s not in attr_def.allowed_values.list.s:
            raise ValueError(
                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                (key, op_type_name, compat.as_text(attr_value.s),
                 '", "'.join(map(compat.as_text,
                                 attr_def.allowed_values.list.s))))
      elif attr_def.type == "list(string)":
        attr_value.list.s.extend([_MakeStr(x, key) for x in value])
        if attr_def.HasField("allowed_values"):
          for x in attr_value.list.s:
            if x not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(x),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
      elif attr_def.type == "int":
        attr_value.i = _MakeInt(value, key)
        if attr_def.has_minimum:
          if attr_value.i < attr_def.minimum:
            raise ValueError(
                "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                (key, op_type_name, attr_value.i, attr_def.minimum))
      elif attr_def.type == "list(int)":
        attr_value.list.i.extend([_MakeInt(x, key) for x in value])
      elif attr_def.type == "float":
        attr_value.f = _MakeFloat(value, key)
      elif attr_def.type == "list(float)":
        attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
      elif attr_def.type == "bool":
        attr_value.b = _MakeBool(value, key)
      elif attr_def.type == "list(bool)":
        attr_value.list.b.extend([_MakeBool(x, key) for x in value])
      elif attr_def.type == "type":
        attr_value.type = _MakeType(value, attr_def)
      elif attr_def.type == "list(type)":
        attr_value.list.type.extend(
            [_MakeType(x, attr_def) for x in value])
      elif attr_def.type == "shape":
        attr_value.shape.CopyFrom(_MakeShape(value, key))
      elif attr_def.type == "list(shape)":
        attr_value.list.shape.extend(
            [_MakeShape(x, key) for x in value])
      elif attr_def.type == "tensor":
        attr_value.tensor.CopyFrom(_MakeTensor(value, key))
      elif attr_def.type == "list(tensor)":
        attr_value.list.tensor.extend(
            [_MakeTensor(x, key) for x in value])
      elif attr_def.type == "func":
        attr_value.func.CopyFrom(_MakeFunc(value, key))
      elif attr_def.type == "list(func)":
        attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
      else:
        raise TypeError("Unrecognized Attr type " + attr_def.type)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        output_structure.append(n)
      elif arg.type_attr:
        # The looked-up value is not needed here; the call is retained so
        # that any check `_AttrValue` performs still runs.
        # NOTE(review): confirm whether `_AttrValue` validates presence.
        _AttrValue(attr_protos, arg.type_attr)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords.keys())))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    # NOTE(review): `inputs` is the flattened tensor list, so this zip pairs
    # by position with `op_def.input_arg`; verify alignment for ops that mix
    # list inputs with ref inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
                       input_types=input_types, attrs=attr_protos,
                       op_def=op_def)
    return output_structure, op_def.is_stateful, op
def _distributed_apply(distribution, grads_and_vars, name, apply_state):
  """`apply_gradients` using a `DistributionStrategy`.

  Sum-reduces the gradients with `distribution.extended.batch_reduce_to`,
  applies each reduced gradient to its variable through
  `distribution.extended.update`, and finally increments `self._iterations`.

  NOTE(review): `self` is not a parameter of this function; it is presumably
  captured from an enclosing scope (e.g. this is defined as a closure over an
  optimizer instance) -- confirm against the surrounding file.

  Args:
    distribution: Strategy object; only its `extended` API is used here.
    grads_and_vars: (gradient, variable) pairs. It is passed to
      `batch_reduce_to` and then iterated again to collect the variables, so
      it presumably needs to be re-iterable (e.g. a list) -- TODO confirm.
    name: Name scope for the update ops; falls back to `self._name` when
      falsy.
    apply_state: Forwarded as the `apply_state` keyword to the
      `_resource_apply_*` methods, but only when "apply_state" is listed in
      `self._sparse_apply_args` / `self._dense_apply_args`.

  Returns:
    When not executing eagerly, or when any update op is symbolic: the `.op`
    of `self._iterations.assign_add(1)`, created under control dependencies
    on all update ops. Otherwise the result of
    `self._iterations.assign_add(1)`.

  Raises:
    NotImplementedError: If a variable to update is an `ops.Tensor`.
    RuntimeError: If a sparse (`IndexedSlices`) gradient targets a variable
      that has a constraint.
  """
  reduced_grads = distribution.extended.batch_reduce_to(
      ds_reduce_util.ReduceOp.SUM, grads_and_vars)
  var_list = [v for _, v in grads_and_vars]
  # Re-pair every variable with its reduced gradient.
  grads_and_vars = zip(reduced_grads, var_list)

  def apply_grad_to_update_var(var, grad):
    """Apply gradient to variable."""
    if isinstance(var, ops.Tensor):
      raise NotImplementedError("Trying to update a Tensor ", var)

    apply_kwargs = {}
    if not isinstance(var, de.TrainableWrapper):
      # Ordinary variable: plain sparse or dense apply.
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable."
          )
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        # Apply the constraint only after the update has run.
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op
    else:
      # `de.TrainableWrapper` variable: the apply step is sequenced between
      # `read_value()` calls (before) and `update_op()` calls (after) on the
      # variable and on each of its optimizer slots.
      with ops.colocate_with(None, ignore_existing=True):
        _slots = [
            self.get_slot(var, _s) for _s in self.get_slot_names()
        ]
        # Add the optimizer slots to restricting list.
        if var.params.restrict_policy is not None:
          var.params.restrict_policy._track_optimizer_slots(
              _slots)

        # Reads are ordered after the gradient.
        with ops.control_dependencies([grad]):
          _before = [var.read_value()
                    ] + [_s.read_value() for _s in _slots]
        if isinstance(grad, ops.IndexedSlices):
          if var.constraint is not None:
            raise RuntimeError(
                "Cannot use a constraint function on a sparse variable."
            )
          if "apply_state" in self._sparse_apply_args:
            apply_kwargs["apply_state"] = apply_state
          # Apply only after the reads, then group the `update_op()`s after
          # the apply.
          with ops.control_dependencies(_before):
            _apply_op = self._resource_apply_sparse_duplicate_indices(
                grad.values, var, grad.indices, **apply_kwargs)
          with ops.control_dependencies([_apply_op]):
            _after = control_flow_ops.group(
                [var.update_op()] + [_s.update_op() for _s in _slots])
            return _after

        if "apply_state" in self._dense_apply_args:
          apply_kwargs["apply_state"] = apply_state
        with ops.control_dependencies(_before):
          update_op = self._resource_apply_dense(
              grad, var, **apply_kwargs)
        if var.constraint is not None:
          # NOTE(review): this branch returns the constraint assign and does
          # not call `update_op()` on the variable/slots, unlike the branch
          # below -- confirm this is intended.
          with ops.control_dependencies([update_op]):
            return var.assign(var.constraint(var))
        else:
          with ops.control_dependencies([update_op]):
            _after = control_flow_ops.group(
                [var.update_op()] + [_s.update_op() for _s in _slots])
            return _after

  update_ops = []
  with backend.name_scope(name or self._name):
    for grad, var in grads_and_vars:
      scope_name = ("update"
                    if ops.executing_eagerly_outside_functions() else
                    "update_" + var.op.name)
      # Colocate the update with variables to avoid unnecessary communication
      # delays. See b/136304694.
      with backend.name_scope(
          scope_name), distribution.extended.colocate_vars_with(
              var):
        update_ops.extend(
            distribution.extended.update(var,
                                         apply_grad_to_update_var,
                                         args=(grad, ),
                                         group=False))

    any_symbolic = any(
        isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i)
        for i in update_ops)
    if not context.executing_eagerly() or any_symbolic:
      # If the current context is graph mode or any of the update ops are
      # symbolic then the step update should be carried out under a graph
      # context. (eager updates execute immediately)
      with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
        with ops.control_dependencies(update_ops):
          return self._iterations.assign_add(1).op

    return self._iterations.assign_add(1)
def _current_graph(op_input_list):
  """Resolve the graph associated with the given op inputs.

  Thin wrapper: forwards `op_input_list` unchanged to
  `ops._get_graph_from_inputs` and hands back whatever that returns (per the
  original docstring, the graph the inputs are members of, or the current
  graph).
  """
  graph = ops._get_graph_from_inputs(op_input_list)  # pylint: disable=protected-access
  return graph
def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=invalid-name """Implementation of apply_op that returns output_structure, op.""" op_def = op_def_registry.get(op_type_name) if op_def is None: raise RuntimeError(f"Unrecognized Op name {op_type_name}") # Determine the graph context. try: # Need to flatten all the arguments into a list. # pylint: disable=protected-access g = ops._get_graph_from_inputs(_Flatten(keywords.values())) # pylint: enable=protected-access except AssertionError as e: raise RuntimeError( f"Cannot determine graph for Op '{op_type_name}' due to: {e.message}") # Default name if not specified. if name is None: name = op_type_name # Check for deprecation deprecation_version = op_def.deprecation.version if deprecation_version: producer = g.graph_def_versions.producer if producer >= deprecation_version: raise NotImplementedError( f"Op {op_type_name} is not available in GraphDef version {producer}. " f"It has been removed in version {deprecation_version}. " f"{op_def.deprecation.explanation}.") # Fill in the list of default types for all "type" attrs. This # will be used to choose a preferred dtype to convert to in the # absence of input type information. # # TODO(b/31302892): Currently the defaults don't work in the right # way if you have two inputs, one of whose type resolution depends # on the other. Handling this will require restructuring this code # significantly. default_type_attr_map = {} allowed_list_attr_map = {} for attr_def in op_def.attr: if attr_def.type != "type": continue key = attr_def.name if attr_def.HasField("default_value"): default_type_attr_map[key] = dtypes.as_dtype( attr_def.default_value.type) if attr_def.HasField("allowed_values"): allowed_list_attr_map[key] = attr_def.allowed_values.list.type # Requires that op_def has passed validation (using the C++ # ValidateOpDef() from ../framework/op_def_util.h). 
attrs = {} inputs = [] input_types = [] with g.as_default(), ops.name_scope(name) as scope: # Perform input type inference inferred_from = {} for input_arg in op_def.input_arg: input_name = input_arg.name if input_name in keywords: values = keywords.pop(input_name) elif input_name + "_" in keywords: # Handle the case where the name is a keyword or built-in # for Python so we use the name + _ instead. input_name += "_" values = keywords.pop(input_name) else: raise TypeError(f"No argument for input {input_name} found in {op_def}") # Goals: # * Convert values to Tensors if it contains constants. # * Verify that values is a list if that matches the input_arg's # type. # * If the input_arg's type is determined by attrs, either set # those attrs and validate those attr values are legal (if # they have not yet been set) or validate the input matches # the type indicated by the attrs (if they have already been # inferred via an earlier input). # * If the input_arg has an explicit type, make sure the input # conforms. if _IsListParameter(input_arg): if not _IsListValue(values): raise TypeError( f"Expected list for '{input_name}' argument to '{op_type_name}' " f"Op, not {values}.") # In cases where we expect all elements of the list to have the # same dtype, try to cast non-Tensor elements to that type. dtype = None default_dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.number_attr: if input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] else: for t in values: if isinstance(t, ops.Tensor): dtype = t.dtype break # dtype still not found, prefer using the default dtype # from the attr. 
if dtype is None and input_arg.type_attr in default_type_attr_map: default_dtype = default_type_attr_map[input_arg.type_attr] try: if not input_arg.is_ref and dtype: dtype = dtypes.as_dtype(dtype).base_dtype values = ops.internal_convert_n_to_tensor( values, name=input_arg.name, dtype=dtype if dtype else None, preferred_dtype=default_dtype, as_ref=input_arg.is_ref) all_types = set(v.dtype.base_dtype for v in values) if input_arg.number_attr and len(all_types) > 1: # All types should match. raise TypeError(f"Not all types matched for {input_arg.name} for " f"{op_type_name}. Got {all_types}") except (TypeError, ValueError): # What types does the conversion function think values have? observed_types = [] for value in values: try: converted_value = ops.convert_to_tensor( value, as_ref=input_arg.is_ref) observed_types.append(converted_value.dtype.base_dtype.name) except (TypeError, ValueError): observed_types.append("<NOT CONVERTIBLE TO TENSOR>") observed = ", ".join(observed_types) prefix = ( "Tensors in list passed to '%s' of '%s' Op have types [%s]" % (input_name, op_type_name, observed)) if input_arg.number_attr: if input_arg.type != types_pb2.DT_INVALID: raise TypeError(f"{prefix} that do not match expected type " f"{dtype.name}.") elif input_arg.type_attr in attrs: raise TypeError(f"{prefix} that do not match type {dtype.name} " "inferred from earlier arguments.") else: raise TypeError(f"{prefix} that don't all match.") else: raise TypeError(f"{prefix} that are invalid. Tensors: {values}") types = [x.dtype for x in values] inputs.extend(values) else: # In cases where we have an expected type, try to convert non-Tensor # arguments to that type. 
dtype = None default_dtype = None allowed_list = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] elif input_arg.type_attr in default_type_attr_map: # The dtype could not be inferred solely from the inputs, # so we prefer the attr's default, so code that adds a new attr # with a default is backwards compatible. default_dtype = default_type_attr_map[input_arg.type_attr] allowed_list = allowed_list_attr_map.get(input_arg.type_attr) try: # First see if we can get a valid dtype with the default conversion # and see if it matches an allowed dtypes. Some ops like ConcatV2 may # not list allowed dtypes, in which case we should skip this. if dtype is None and allowed_list: inferred = None try: inferred = ops.convert_to_tensor( values, name=input_arg.name, as_ref=input_arg.is_ref) except TypeError as err: # When converting a python object such as a list of Dimensions, we # need a dtype to be specified, thus tensor conversion may throw # an exception which we will ignore and try again below. pass # If we did not match an allowed dtype, try again with the default # dtype. This could be because we have an empty tensor and thus we # picked the wrong type. if inferred is not None and inferred.dtype in allowed_list: values = inferred else: values = ops.convert_to_tensor( values, name=input_arg.name, as_ref=input_arg.is_ref, preferred_dtype=default_dtype) else: values = ops.convert_to_tensor( values, name=input_arg.name, dtype=dtype, as_ref=input_arg.is_ref, preferred_dtype=default_dtype) except TypeError as err: if dtype is None: raise err else: raise TypeError( f"Expected {dtypes.as_dtype(dtype).name} passed to parameter " f"'{input_arg.name}' of op '{op_type_name}', got " f"{repr(values)} of type '{type(values).__name__}' instead. " f"Error: {err}") except ValueError: # What type does convert_to_tensor think it has? 
try: observed = ops.convert_to_tensor( values, as_ref=input_arg.is_ref).dtype.name except ValueError as err: raise ValueError( f"Tried to convert '{input_name}' to a tensor and failed. " f"Error: {err}") prefix = ("Input '%s' of '%s' Op has type %s that does not match" % (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError(f"{prefix} expected type of " f"{dtypes.as_dtype(input_arg.type).name}.") else: # Update the maps with the default, if needed. k = input_arg.type_attr if k in default_type_attr_map: if k not in attrs: attrs[k] = default_type_attr_map[k] if k not in inferred_from: inferred_from[k] = "Default in OpDef" raise TypeError( f"{prefix} type " f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of " f"argument '{inferred_from[input_arg.type_attr]}'.") types = [values.dtype] inputs.append(values) base_types = [x.base_dtype for x in types] if input_arg.number_attr: # <number-attr> * <type> or <number-attr> * <type-attr> if input_arg.number_attr in attrs: if len(values) != attrs[input_arg.number_attr]: raise ValueError( f"List argument '{input_name}' to '{op_type_name}' Op with " f"length {len(values)} must match length " f"{attrs[input_arg.number_attr]} of argument " f"'{inferred_from[input_arg.number_attr]}'.") else: attrs[input_arg.number_attr] = len(values) inferred_from[input_arg.number_attr] = input_name num_attr = _Attr(op_def, input_arg.number_attr) if num_attr.has_minimum and len(values) < num_attr.minimum: raise ValueError( f"List argument '{input_name}' to '{op_type_name}' Op with " f"length {len(values)} shorter than minimum length " f"{num_attr.minimum}.") # All tensors must have the same base type. if any(bt != base_types[0] for bt in base_types): raise TypeError( f"All tensors passed to '{input_name}' of '{op_type_name}' Op " f"must have the same type. 
Got {base_types} instead.") if input_arg.type != types_pb2.DT_INVALID: # <number-attr> * <type> case if base_types and base_types[0] != input_arg.type: assert False, "Unreachable" elif input_arg.type_attr in attrs: # <number-attr> * <type-attr> case, where <type-attr> already # has an inferred value. if base_types and base_types[0] != attrs[input_arg.type_attr]: assert False, "Unreachable" else: # <number-attr> * <type-attr> case, where we are now setting # the <type-attr> based on this input if not base_types: # If it's in default_type_attr_map, then wait to set it # (in "process remaining attrs", below). if input_arg.type_attr not in default_type_attr_map: raise TypeError( "Don't know how to infer type variable from empty input " f"list passed to input '{input_name}' of '{op_type_name}' " "Op.") else: attrs[input_arg.type_attr] = base_types[0] inferred_from[input_arg.type_attr] = input_name type_attr = _Attr(op_def, input_arg.type_attr) _SatisfiesTypeConstraint(base_types[0], type_attr, param_name=input_name) elif input_arg.type_attr: # <type-attr> attr_value = base_types[0] if input_arg.type_attr in attrs: if attrs[input_arg.type_attr] != attr_value: raise TypeError( f"Input '{input_name}' of '{op_type_name}' Op has type " f"{dtypes.as_dtype(attr_value).name} that does not match type " f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of " f"argument '{inferred_from[input_arg.type_attr]}'.") else: for base_type in base_types: _SatisfiesTypeConstraint(base_type, _Attr(op_def, input_arg.type_attr), param_name=input_name) attrs[input_arg.type_attr] = attr_value inferred_from[input_arg.type_attr] = input_name elif input_arg.type_list_attr: # <type-list-attr> attr_value = base_types if input_arg.type_list_attr in attrs: if attrs[input_arg.type_list_attr] != attr_value: actual_types = ", ".join( dtypes.as_dtype(x).name for x in attr_value) expected_types = ", ".join( dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr]) raise TypeError( f"Input 
'{input_name}' of '{op_type_name}' Op has type list of " f"{actual_types} that does not match type list {expected_types}" f" of argument '{inferred_from[input_arg.type_list_attr]}'.") else: for base_type in base_types: _SatisfiesTypeConstraint(base_type, _Attr(op_def, input_arg.type_list_attr), param_name=input_name) attrs[input_arg.type_list_attr] = attr_value inferred_from[input_arg.type_list_attr] = input_name else: # single Tensor with specified type if base_types[0] != input_arg.type: assert False, "Unreachable" if input_arg.is_ref: if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access raise TypeError( f"'{op_type_name}' Op requires that input '{input_name}' be a " "mutable tensor (e.g.: a tf.Variable)") input_types.extend(types) else: input_types.extend(base_types) # Process remaining attrs for attr in op_def.attr: # Skip attrs that have already had their values inferred if attr.name in attrs: if attr.name in keywords: raise TypeError( f"Should not specify value for inferred attr '{attr.name}' for " f"{op_type_name}.") continue if attr.name in keywords: attrs[attr.name] = keywords.pop(attr.name) elif attr.name + "_" in keywords: # Attrs whose names match Python keywords have an extra '_' # appended, so we must check for that as well. attrs[attr.name] = keywords.pop(attr.name + "_") elif attr.name in default_type_attr_map: attrs[attr.name] = default_type_attr_map[attr.name] inferred_from.setdefault(attr.name, "Default in OpDef") else: raise TypeError(f"No argument found for attr {attr.name} for " f"{op_type_name}") # Convert attr values to AttrValue protos. 
attr_protos = {} for attr_def in op_def.attr: key = attr_def.name value = attrs[key] if attr_def.HasField("default_value") and value is None: attr_value = attr_value_pb2.AttrValue() attr_value.CopyFrom(attr_def.default_value) attr_protos[key] = attr_value continue attr_value = value_to_attr_value(value, attr_def.type, key) if attr_def.type.startswith("list("): _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name) if attr_def.HasField("allowed_values"): if attr_def.type == "string": _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key, op_type_name) elif attr_def.type == "list(string)": for value in attr_value.list.s: _SatisfiesAllowedStringsConstraint(value, attr_def, key, op_type_name) if attr_def.has_minimum and attr_def.type == "int": _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key, op_type_name) if attr_def.type == "type": _SatisfiesTypeConstraint(attr_value.type, attr_def, key) if attr_def.type == "list(type)": for value in attr_value.list.type: _SatisfiesTypeConstraint(value, attr_def, key) attr_protos[key] = attr_value del attrs # attrs is no longer authoritative, use attr_protos instead # Determine output types (possibly using attrs) output_structure = [] for arg in op_def.output_arg: if arg.number_attr: n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i output_structure.append(n) elif arg.type_attr: t = _AttrValue(attr_protos, arg.type_attr, op_type_name) output_structure.append(None) elif arg.type_list_attr: t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name) output_structure.append(len(t.list.type)) else: output_structure.append(None) if keywords: all_keywords = ", ".join(sorted(keywords.keys())) raise TypeError(f"{op_type_name} got unexpected keyword arguments: " f"{all_keywords}.") # NOTE(mrry): We add an explicit colocation constraint between # the newly created op and any of its reference-typed inputs. 
must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref] with _MaybeColocateWith(must_colocate_inputs): # Add Op to graph # pylint: disable=protected-access op = g._create_op_internal(op_type_name, inputs, dtypes=None, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) # `outputs` is returned as a separate return value so that the output # tensors can the `op` per se can be decoupled so that the # `op_callbacks` can function properly. See framework/op_callbacks.py # for more details. outputs = op.outputs # Conditionally invoke tfdbg v2's op callback(s). if op_callbacks.should_invoke_op_callbacks(): callback_outputs = op_callbacks.invoke_op_callbacks( op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs), op_name=op.name, graph=g) if callback_outputs is not None: outputs = callback_outputs return output_structure, op_def.is_stateful, op, outputs
def __init__(self, inputs, labels, name: str = None, params: dict = None):
    """Initialize the model.

    Registers the model name, enters the graph that `inputs` and `labels`
    belong to, stores the input/label tensors and converts the configurable
    hyper-parameters into tensors.

    Args:
        inputs: Object indexable by ``'x'``; ``inputs['x']`` is used as the
            model input.
        labels: Label tensor(s).
        name: Optional model name. When omitted, random first names are drawn
            until one is found that is not registered yet.
        params: Optional configuration dictionary. Keys read here are
            ``batch_size``, ``learning_rate``, ``learning_rate_decay`` (a
            dict, or a one-element list wrapping a dict, with ``decay_steps``
            and ``decay_rate``) and ``optimizer`` (attribute name looked up
            on ``tf.train``, default ``'AdamOptimizer'``).

    Raises:
        ValueError: If ``batch_size`` or ``learning_rate`` has no value,
            i.e. it is neither supplied by a parsed flag nor present (and
            not ``None``) in `params`.
    """
    # TODO: decide whether a `name` that is already registered should be
    # rejected (ValueError("Model name `%s` already exists." % name)) or
    # reused. For now it is accepted as-is.
    if name is None:
        # Generate some random name that is not taken yet.
        name = names.get_first_name(gender='female')
        while name in self.__model_names:
            name = names.get_first_name(gender='female')
    self._name = name
    self.__model_names.add(self._name)

    # noinspection PyProtectedMember
    self._graph = ops._get_graph_from_inputs([inputs, labels])  # pylint: disable=protected-access
    # The context manager is entered here without a matching __exit__ in
    # this method; it is kept on the instance.
    self._graph_context_manager = self._graph.as_default()
    self._graph_context_manager.__enter__()

    with tf.variable_scope('input_layer'):
        self._x = inputs['x']
    with tf.variable_scope('labels'):
        self._labels = labels

    self._layers = [self._x]
    # Dictionary of ([str] layer_name, [int] count).
    self._layer_name_dct = dict()

    if params is None:
        params = dict()

    # Configurable and directly accessible properties.
    flags = tf.app.flags.FLAGS
    for param in ('batch_size', 'learning_rate'):
        # Make sure that flags are privileged from yaml arguments.
        value = flags.get_flag_value(param, None) if flags.is_parsed() else None
        if value is None:
            # A missing key and an explicit None are both "no value"; fail
            # with the same clear error instead of handing None to
            # tf.constant.
            value = params.get(param)
        if value is None:
            raise ValueError("attribute `{}` cannot be of type {}.".format(param, type(None)))
        setattr(self, param, tf.constant(value, tf.float32))

    _decay_dct = params.get('learning_rate_decay', None)
    if isinstance(_decay_dct, list):
        # yaml will pass list here, need to unpack
        _decay_dct, = _decay_dct
    # bool: whether a learning-rate decay configuration was supplied.
    self.learning_rate_decay = _decay_dct is not None
    if self.learning_rate_decay:
        self._decay_steps = tf.constant(_decay_dct.get('decay_steps'), tf.float32)
        self._decay_rate = tf.constant(_decay_dct.get('decay_rate'), tf.float32)
        self.learning_rate = tf.train.exponential_decay(
            learning_rate=self.learning_rate,
            global_step=tf.train.get_global_step(),
            decay_steps=self._decay_steps,
            decay_rate=self._decay_rate,
        )

    # NOTE(review): the original indentation of this statement was lost; it
    # is assumed to run whether or not decay is configured - confirm.
    tf.summary.scalar(tensor=self.learning_rate, name='learning_rate')
    # Attribute of `tf.train` selected by name (presumably a
    # tf.train.Optimizer class, not an instance).
    self.optimizer = getattr(tf.train, params.get('optimizer', 'AdamOptimizer'))
def __call__(self, inputs, *args, **kwargs):
  """Runs the layer on `inputs`, building it on first use.

  The layer's variable scope is resolved lazily, the layer is built if it has
  not been built yet, `self.call` is invoked inside the layer's variable and
  name scopes, and activity-regularization losses as well as update ops are
  added to the graph collections.

  Arguments:
    inputs: input tensor(s).
    *args: additional positional arguments forwarded to `self.call`.
    **kwargs: additional keyword arguments forwarded to `self.call`.
      **Note**: kwarg `scope` is reserved for use by the layer.

  Returns:
    Output tensor(s).

  Raises:
    ValueError: if the graph lookup over `inputs` with `self.graph` fails.
  """
  requested_scope = kwargs.pop('scope', None)

  # Custom getter that overrides tf.get_variable while the layer creates its
  # variables. The current custom getter is nested by the variable scope.
  def variable_getter(getter, name, shape, dtype=None, initializer=None,
                      regularizer=None, trainable=True, **getter_kwargs):
    bound_getter = functools.partial(getter, **getter_kwargs)
    return self._add_variable(
        name, shape, initializer=initializer, regularizer=regularizer,
        dtype=dtype, trainable=trainable, variable_getter=bound_getter)

  if not self._built and self._scope is None:
    # If constructed with _scope=None, lazy setting of scope.
    if self._reuse:
      scope_name = (requested_scope if requested_scope is not None
                    else self._base_name)
      lazy_scope = vs.variable_scope(scope_name)
    else:
      lazy_scope = vs.variable_scope(
          requested_scope, default_name=self._base_name)
    self._scope = next(lazy_scope.gen)
    self._name = self._scope.name

  # Build (if necessary) and call the layer, inside a variable scope.
  with vs.variable_scope(
      self._scope,
      reuse=True if self._built else self._reuse,
      custom_getter=variable_getter) as layer_scope:
    # Ensure the Layer, if being reused, is working with inputs from
    # the same graph as where it was created.
    try:
      ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
    except ValueError as err:
      raise ValueError(
          "Inputs' and Layer's graphs are not the same: %s" % err)

    with ops.name_scope(layer_scope.original_name_scope):
      if not self.built:
        shapes = [
            ops.convert_to_tensor(item, name='input').get_shape()
            for item in nest.flatten(inputs)
        ]
        # A single input is passed to `build` as a bare shape, several
        # inputs as a list of shapes.
        self.build(shapes[0] if len(shapes) == 1 else shapes)
        self._built = True

      if 'scope' in tf_inspect.getargspec(self.call).args:
        kwargs['scope'] = layer_scope
      outputs = self.call(inputs, *args, **kwargs)

      # Apply activity regularization.
      # Note that it should be applied every time the layer creates a new
      # output, since it is output-specific.
      if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
        for single_output in _to_list(outputs):
          with ops.name_scope('ActivityRegularizer'):
            penalty = self.activity_regularizer(single_output)
          self._losses.append(penalty)
          _add_elements_to_collection(
              penalty, ops.GraphKeys.REGULARIZATION_LOSSES)

  # Update global default collections.
  _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
  return outputs
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
  """`apply_gradients` using a `DistributionStrategy`.

  Args:
    distribution: strategy object; its `extended` API is used to colocate
      and run the per-variable updates.
    grads_and_vars: iterable of `(gradient, variable)` pairs. It is iterated
      here after being handed to `self._aggregate_gradients`.
    name: name scope for the update ops; falls back to `self._name`.
    apply_state: forwarded as the `apply_state` keyword to the sparse/dense
      apply methods when they list it among their accepted arguments.

  Returns:
    In graph mode, or when any update is symbolic, the op of the iteration
    counter increment (run after all updates); otherwise the result of
    `self._iterations.assign_add(1)`.
  """
  aggregated = self._aggregate_gradients(distribution, grads_and_vars)
  variables = [pair[1] for pair in grads_and_vars]
  grads_and_vars = zip(aggregated, variables)

  def apply_grad_to_update_var(var, grad):
    """Apply gradient to variable."""
    if isinstance(var, ops.Tensor):
      raise NotImplementedError("Trying to update a Tensor ", var)

    extra = {}
    if isinstance(grad, ops.IndexedSlices):
      # Sparse gradient: constraints are not supported on this path.
      if var.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      if "apply_state" in self._sparse_apply_args:
        extra["apply_state"] = apply_state
      return self._resource_apply_sparse_duplicate_indices(
          grad.values, var, grad.indices, **extra)

    if "apply_state" in self._dense_apply_args:
      extra["apply_state"] = apply_state
    dense_update = self._resource_apply_dense(grad, var, **extra)
    if var.constraint is None:
      return dense_update
    # Re-project the variable through its constraint once the update ran.
    with ops.control_dependencies([dense_update]):
      return var.assign(var.constraint(var))

  eagerly_outside_functions = ops.executing_eagerly_outside_functions()
  update_ops = []
  with ops.name_scope(name or self._name, skip_on_eager=True):
    for grad, var in grads_and_vars:
      if eagerly_outside_functions:
        update_scope = "update"
      else:
        update_scope = "update_" + var.op.name
      # Colocate the update with variables to avoid unnecessary communication
      # delays. See b/136304694.
      with distribution.extended.colocate_vars_with(var), \
          ops.name_scope(update_scope, skip_on_eager=True):
        per_var_updates = distribution.extended.update(
            var, apply_grad_to_update_var, args=(grad,), group=False)
        update_ops.extend(per_var_updates)

    any_symbolic = any(
        isinstance(update, ops.Operation) or
        tf_utils.is_symbolic_tensor(update)
        for update in update_ops)
    if not context.executing_eagerly() or any_symbolic:
      # If the current context is graph mode or any of the update ops are
      # symbolic then the step update should be carried out under a graph
      # context. (eager updates execute immediately)
      update_graph = ops._get_graph_from_inputs(update_ops)  # pylint: disable=protected-access
      with update_graph.as_default(), ops.control_dependencies(update_ops):
        return self._iterations.assign_add(1).op

    return self._iterations.assign_add(1)
def __call__(self, inputs, *args, **kwargs):
  """Runs the layer on `inputs`, wrapping `call` with pre/post-processing.

  Sets the layer scope, checks the inputs against the layer's graph, builds
  the layer on first use, validates the inputs, invokes `self.call`, and
  registers activity-regularization losses and update ops.

  Arguments:
    inputs: input tensor(s).
    *args: additional positional arguments forwarded to `self.call`.
    **kwargs: additional keyword arguments forwarded to `self.call`.
      **Note**: kwarg `scope` is reserved for use by the layer.

  Returns:
    Output tensor(s).

  Raises:
    ValueError: if the graph lookup over `inputs` with `self.graph` fails.
  """
  self._set_scope(kwargs.pop('scope', None))

  # Ensure the Layer, if being reused, is working with inputs from
  # the same graph as where it was created.
  try:
    ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
  except ValueError as err:
    raise ValueError(
        'Input graph and Layer graph are not the same: %s' % err)

  with vs.variable_scope(
      self._scope, reuse=self.built or self._reuse) as layer_scope, \
      ops.name_scope(layer_scope.original_name_scope):
    if not self.built:
      # Check input assumptions set before layer building, e.g. input rank.
      self._assert_input_compatibility(inputs)
      shapes = [
          ops.convert_to_tensor(item, name='input').get_shape()
          for item in nest.flatten(inputs)
      ]
      # A single input is passed to `build` as a bare shape, several inputs
      # as a list of shapes.
      self.build(shapes[0] if len(shapes) == 1 else shapes)

    if 'scope' in getargspec(self.call).args:
      kwargs['scope'] = layer_scope
    # Check input assumptions set after layer building, e.g. input shape.
    self._assert_input_compatibility(inputs)
    outputs = self.call(inputs, *args, **kwargs)

    # Apply activity regularization.
    # Note that it should be applied every time the layer creates a new
    # output, since it is output-specific.
    if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
      for single_output in _to_list(outputs):
        with ops.name_scope('ActivityRegularizer'):
          penalty = self.activity_regularizer(single_output)
        self.add_loss(penalty)
        _add_elements_to_collection(
            penalty, ops.GraphKeys.REGULARIZATION_LOSSES)

  # Update global default collections.
  _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
  self.built = True
  return outputs
def __call__(self, inputs, *args, **kwargs):
  """Wraps `call`, applying pre- and post-processing steps.

  Arguments:
    inputs: input tensor(s).
    *args: additional positional arguments to be passed to `self.call`.
    **kwargs: additional keyword arguments to be passed to `self.call`.
      **Note**: kwarg `scope` is reserved for use by the layer.

  Returns:
    Output tensor(s).

  Note:
    - If the layer's `call` method takes a `scope` keyword argument,
      this argument will be automatically set to the current variable scope.
    - If the layer's `call` method takes a `mask` argument (as some Keras
      layers do), its default value will be set to the mask generated
      for `inputs` by the previous layer (if `input` did come from
      a layer that generated a corresponding mask, i.e. if it came from
      a Keras layer with masking support).

  Raises:
    ValueError: if the layer's `call` method returns None (an invalid value),
      if `scope` is passed while keras style layers are enabled, or if the
      graph lookup over `inputs` fails in graph mode.
  """
  requested_scope = kwargs.pop('scope', None)

  if self._keras_style:
    # Keras-style layers take no `scope`; hand over to the parent class.
    if requested_scope is not None:
      raise ValueError(
          'scope argument not allowed when keras style layers are enabled, '
          'but saw: {}'.format(requested_scope))
    return super(Layer, self).__call__(inputs, *args, **kwargs)

  self._set_scope(requested_scope)

  if not context.executing_eagerly():
    flat_inputs = nest.flatten(inputs)
    try:
      # Set layer's "graph" at build time
      self._graph = ops._get_graph_from_inputs(  # pylint: disable=protected-access
          flat_inputs, graph=self._graph)
    except ValueError as err:
      raise ValueError(
          'Input graph and Layer graph are not the same: %s' % err)

  if not self.built:
    scope_context_manager = vs.variable_scope(
        self._scope, reuse=self._reuse, auxiliary_name_scope=False)
  else:
    try:
      # Some classes which inherit from Layer do not use its constructor, so
      # rather than initializing to None we check for an AttributeError.
      scope_context_manager = self._always_reuse_variable_scope
    except AttributeError:
      # From this point we will always set reuse=True, so create a "final"
      # variable scope with this setting. We avoid re-creating variable scopes
      # after this point as an optimization.
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=True, auxiliary_name_scope=False)
      self._always_reuse_variable_scope = scope_context_manager

  with scope_context_manager as active_scope:
    self._current_scope = active_scope

    try:
      wants_scope = self._call_has_scope_arg
    except AttributeError:
      # First call: inspect `call` once and cache the result on the layer.
      self._call_fn_args = function_utils.fn_args(self.call)
      wants_scope = 'scope' in self._call_fn_args
      self._call_has_scope_arg = wants_scope
    if wants_scope:
      kwargs['scope'] = active_scope

    # Actually call layer
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

  if not context.executing_eagerly():
    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
  return outputs
def __call__(self, inputs, *args, **kwargs):
  """Wraps `call`, applying pre- and post-processing steps.

  Arguments:
    inputs: input tensor(s).
    *args: additional positional arguments to be passed to `self.call`.
    **kwargs: additional keyword arguments to be passed to `self.call`.
      **Note**: kwarg `scope` is reserved for use by the layer.

  Returns:
    Output tensor(s).

  Note:
    - If the layer's `call` method takes a `scope` keyword argument,
      this argument will be automatically set to the current variable scope.
    - If the layer's `call` method takes a `mask` argument (as some Keras
      layers do), its default value will be set to the mask generated
      for `inputs` by the previous layer (if `input` did come from
      a layer that generated a corresponding mask, i.e. if it came from
      a Keras layer with masking support).

  Raises:
    ValueError: if the layer's `call` method returns None (an invalid value),
      or if the graph lookup over `inputs` fails in graph mode.
  """
  self._set_scope(kwargs.pop('scope', None))

  if not context.executing_eagerly():
    flat_inputs = nest.flatten(inputs)
    try:
      # Set layer's "graph" at build time
      self._graph = ops._get_graph_from_inputs(  # pylint: disable=protected-access
          flat_inputs, graph=self._graph)
    except ValueError as err:
      raise ValueError(
          'Input graph and Layer graph are not the same: %s' % err)

  if not self.built:
    scope_context_manager = vs.variable_scope(
        self._scope, reuse=self._reuse, auxiliary_name_scope=False)
  else:
    try:
      # Some classes which inherit from Layer do not use its constructor, so
      # rather than initializing to None we check for an AttributeError.
      scope_context_manager = self._always_reuse_variable_scope
    except AttributeError:
      # From this point we will always set reuse=True, so create a "final"
      # variable scope with this setting. We avoid re-creating variable scopes
      # after this point as an optimization.
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=True, auxiliary_name_scope=False)
      self._always_reuse_variable_scope = scope_context_manager

  with scope_context_manager as active_scope:
    self._current_scope = active_scope

    try:
      wants_scope = self._call_has_scope_arg
    except AttributeError:
      # First call: inspect `call` once and cache the result on the layer.
      self._call_fn_args = function_utils.fn_args(self.call)
      wants_scope = 'scope' in self._call_fn_args
      self._call_has_scope_arg = wants_scope
    if wants_scope:
      kwargs['scope'] = active_scope

    # Actually call layer
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

  if not context.executing_eagerly():
    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
  return outputs
def apply_op(self, op_type_name, name=None, **keywords):
  # pylint: disable=g-doc-args
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.)  If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op; defaults to
      `op_type_name`.
    **keywords: input Tensor and attr arguments specified by name,
      and optional parameters to pass when constructing the Operation.
      An input or attr whose name clashes with a Python keyword may be
      passed with a trailing underscore.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs.

  Raises:
    RuntimeError: If `op_type_name` is not registered, or the graph cannot
      be determined from the arguments.
    NotImplementedError: If the Op is deprecated at the graph's producer
      version.
    TypeError: If an input or attr is missing, unexpected, specified although
      inferred, of the wrong kind, or has a mismatching type.
    ValueError: If a list length, a minimum or an allowed-values constraint
      is violated.
  """
  op_info = self._ops.get(op_type_name, None)
  if op_info is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)
  op_def = op_info.op_def

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Format the exception itself: `e.message` does not exist on Python 3
    # exceptions and would mask this error with an AttributeError.
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, e))

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.")
          % (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            # `dtype` may still be the raw enum from the OpDef here (it is
            # only normalized above for non-ref inputs), so go through
            # as_dtype before asking for its name.
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtypes.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtypes.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError("%s that are invalid." % prefix)

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
        except ValueError:
          # What type does convert_to_tensor think it has?
          observed = ops.convert_to_tensor(
              values, as_ref=input_arg.is_ref).dtype.name
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                "list passed to input '%s' of '%s' Op." %
                (input_name, op_type_name))
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(base_types[0], type_attr)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            assert False, "Unreachable"
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr))
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr))
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x.is_ref_dtype for x in types):
          raise TypeError(
              "Input '%s' of '%s' Op requires l-value input" %
              (input_name, op_type_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]
      attr_value = attr_value_pb2.AttrValue()
      if attr_def.HasField("default_value") and value is None:
        # None selects the default recorded in the OpDef.
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue
      if attr_def.type.startswith("list("):
        if not _IsListValue(value):
          raise TypeError("Expected list for attr " + key)
        if attr_def.has_minimum and len(value) < attr_def.minimum:
          raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                           "less than minimum %d." %
                           (key, op_type_name, len(value),
                            attr_def.minimum))
        attr_value.list.SetInParent()
      if attr_def.type == "string":
        attr_value.s = _MakeStr(value, key)
        if attr_def.HasField("allowed_values"):
          if attr_value.s not in attr_def.allowed_values.list.s:
            raise ValueError(
                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                (key, op_type_name, compat.as_text(attr_value.s),
                 '", "'.join(map(compat.as_text,
                                 attr_def.allowed_values.list.s))))
      elif attr_def.type == "list(string)":
        attr_value.list.s.extend([_MakeStr(x, key) for x in value])
        if attr_def.HasField("allowed_values"):
          for x in attr_value.list.s:
            if x not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(x),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
      elif attr_def.type == "int":
        attr_value.i = _MakeInt(value, key)
        if attr_def.has_minimum:
          if attr_value.i < attr_def.minimum:
            raise ValueError(
                "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                (key, op_type_name, attr_value.i, attr_def.minimum))
      elif attr_def.type == "list(int)":
        attr_value.list.i.extend([_MakeInt(x, key) for x in value])
      elif attr_def.type == "float":
        attr_value.f = _MakeFloat(value, key)
      elif attr_def.type == "list(float)":
        attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
      elif attr_def.type == "bool":
        attr_value.b = _MakeBool(value, key)
      elif attr_def.type == "list(bool)":
        attr_value.list.b.extend([_MakeBool(x, key) for x in value])
      elif attr_def.type == "type":
        attr_value.type = _MakeType(value, attr_def)
      elif attr_def.type == "list(type)":
        attr_value.list.type.extend(
            [_MakeType(x, attr_def) for x in value])
      elif attr_def.type == "shape":
        attr_value.shape.CopyFrom(_MakeShape(value, key))
      elif attr_def.type == "list(shape)":
        attr_value.list.shape.extend(
            [_MakeShape(x, key) for x in value])
      elif attr_def.type == "tensor":
        attr_value.tensor.CopyFrom(_MakeTensor(value, key))
      elif attr_def.type == "list(tensor)":
        attr_value.list.tensor.extend(
            [_MakeTensor(x, key) for x in value])
      elif attr_def.type == "func":
        if isinstance(value, compat.bytes_or_text_types):
          attr_value.func.name = value
        else:
          value.add_to_graph(ops.get_default_graph())
          attr_value.func.name = value.name
      else:
        raise TypeError("Unrecognized Attr type " + attr_def.type)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_types = []
    output_structure = []
    for arg in op_def.output_arg:
      types = []
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        if arg.type_attr:
          types = [_AttrValue(attr_protos, arg.type_attr).type] * n
        else:
          types = [arg.type] * n
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr)
        types = [t.type]
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr)
        types = t.list.type
        output_structure.append(len(types))
      else:
        types = [arg.type]
        output_structure.append(None)
      if arg.is_ref:
        types = [dtypes.as_dtype(x).as_ref for x in types]
      output_types.extend(types)

    # Anything left over was neither an input nor an attr of this Op.
    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords)))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      op = g.create_op(op_type_name, inputs, output_types, name=scope,
                       input_types=input_types, attrs=attr_protos,
                       op_def=op_def)
      if output_structure:
        outputs = op.outputs
        res = _Restructure(ops.convert_n_to_tensor(outputs),
                           output_structure)
        # A stateful op whose restructured result is an empty list is
        # returned as the Operation itself.
        if isinstance(res, list) and not res and op_def.is_stateful:
          return op
        else:
          return res
      else:
        return op
def variable_op_scope(values, name_or_scope, default_name, initializer=None,
                      regularizer=None, caching_device=None, reuse=None):
  """Opens a name scope and a variable scope for an op that creates variables.

  The graph is resolved from `values` and made the default graph for the
  duration of the scope; a variable scope is then entered and yielded.

  If `name_or_scope` is given it is used as is for the variable scope.
  Otherwise `default_name` is pushed as a name scope, and the name produced
  by that name scope (made unique by appending `_N` if the same name was
  already used in the enclosing scope) is used for the variable scope.

  This is intended to be used when defining generic ops, so reuse is
  inherited from the enclosing scope unless requested explicitly.

  For example, to define a new Python op called `my_op_with_vars`:

  ```python
  def my_op_with_vars(a, b, scope=None):
    with tf.variable_op_scope([a, b], scope, "MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.get_variable('c')
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  Args:
    values: The list of `Tensor` arguments that are passed to the op function.
    name_or_scope: The name argument that is passed to the op function; it
      is not uniquified in the variable scope.
    default_name: The name to use when `name_or_scope` is empty; this name
      is uniquified. Must not be `None`, even when `name_or_scope` is given.
    initializer: The default initializer to pass to the variable scope.
    regularizer: The default regularizer for variables within this scope.
    caching_device: The default caching device for variables within this
      scope.
    reuse: `True` or `None`; if `True`, reuse mode is requested for this
      scope (only allowed together with `name_or_scope`); if `None`, the
      parent scope's reuse setting is inherited.

  Yields:
    The variable scope object that was entered.

  Raises:
    ValueError: if `reuse` is set while `name_or_scope` is empty. The
      underlying variable scope may raise it as well (e.g. for conflicting
      or invalid reuse requests) -- see `variable_scope`.
    TypeError: if `default_name` is `None`.
  """
  if default_name is None:
    raise TypeError("default_name cannot be None")

  # Defaults forwarded unchanged to whichever variable scope is opened below.
  scope_defaults = dict(initializer=initializer,
                        regularizer=regularizer,
                        caching_device=caching_device)

  graph = ops._get_graph_from_inputs(values)  # pylint: disable=protected-access
  with graph.as_default():
    if not name_or_scope:
      # No explicit scope: derive the variable scope name from the name scope.
      if reuse:
        raise ValueError("reuse=True cannot be used without a name_or_scope")
      with ops.name_scope(default_name) as unique_name:
        # Keep as many trailing path components as `default_name` has; the
        # last element of the split is dropped (presumably the empty string
        # left by a trailing "/" in the name scope -- confirm in name_scope).
        depth = len(default_name.split("/"))
        trailing_parts = unique_name.split("/")[-depth - 1:-1]
        with _pure_variable_scope("/".join(trailing_parts),
                                  **scope_defaults) as var_scope:
          yield var_scope
    else:
      with variable_scope(name_or_scope, reuse=reuse,
                          **scope_defaults) as var_scope:
        yield var_scope