def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """Creates a constant tensor. The resulting tensor is populated with values of type `dtype`, as specified by arguments `value` and (optionally) `shape` (see examples below). The argument `value` can be a constant value, or a list of values of type `dtype`. If `value` is a list, then the length of the list must be less than or equal to the number of elements implied by the `shape` argument (if specified). In the case where the list length is less than the number of elements specified by `shape`, the last element in the list will be used to fill the remaining entries. The argument `shape` is optional. If present, it specifies the dimensions of the resulting tensor. If not present, the shape of `value` is used. If the argument `dtype` is not specified, then the type is inferred from the type of `value`. For example: ```python # Constant 1-D Tensor populated with value list. tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7] # Constant 2-D tensor populated with scalar value -1. tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.] [-1. -1. -1.]] ``` `tf.constant` differs from `tf.fill` in a few ways: * `tf.constant` supports arbitrary constants, not just uniform scalar Tensors like `tf.fill`. * `tf.constant` creates a `Const` node in the computation graph with the exact value at graph construction time. On the other hand, `tf.fill` creates an Op in the graph that is expanded at runtime. * Because `tf.constant` only embeds constant values in the graph, it does not support dynamic shapes based on other runtime Tensors, whereas `tf.fill` does. Args: value: A constant value (or list) of output type `dtype`. dtype: The type of the elements of the resulting tensor. shape: Optional dimensions of resulting tensor. name: Optional name for the tensor. verify_shape: Boolean that enables verification of a shape of values. Returns: A Constant Tensor. Raises: TypeError: if shape is incorrectly specified or unsupported. """ ctx = context.context() if ctx.executing_eagerly(): t = convert_to_eager_tensor(value, ctx, dtype) if shape is None: return t shape = tensor_shape.as_shape(shape) if shape == t.shape: return t if verify_shape: raise TypeError("Expected Tensor's shape: %s, got %s." % (tuple(shape), tuple(t.shape))) num_t = t.shape.num_elements() # TODO(josh11b): Implement shape -> eager tensor conversion. if num_t == shape.num_elements(): return _eager_reshape(t, shape.as_list(), ctx) if num_t == 1: if t.dtype == dtypes.bool: # We don't have a Fill kernel for bool dtype on GPU. So we first run # Fill on CPU and then copy to GPU if needed. with ops.device("/device:CPU:0"): x = _eager_fill(shape.as_list(), t.cpu(), ctx) return _eager_identity(x, ctx) else: return _eager_fill(shape.as_list(), t, ctx) raise TypeError("Eager execution of tf.constant with unsupported shape " "(value has %d elements, shape is %s with %d elements)." % (num_t, shape, shape.num_elements())) g = ops.get_default_graph() tensor_value = attr_value_pb2.AttrValue() tensor_value.tensor.CopyFrom( tensor_util.make_tensor_proto( value, dtype=dtype, shape=shape, verify_shape=verify_shape)) dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype) const_tensor = g.create_op( "Const", [], [dtype_value.type], attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0] return const_tensor
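# A minimal usage sketch of the behaviors the docstring above describes, using
# only the public `tf.constant` / `tf.fill` API. The "pad a short value list
# with its last element" case is a graph-construction-time feature (it goes
# through make_tensor_proto), so it is shown under a TF1-style graph; under
# eager execution the same call hits the TypeError path in the code above.
import tensorflow as tf

# Scalar value broadcast to the requested shape (works eagerly or in a graph).
a = tf.constant(-1.0, shape=[2, 3])   # [[-1. -1. -1.], [-1. -1. -1.]]
b = tf.fill([2, 3], -1.0)             # same values, but materialized at runtime

# A value list whose length matches the shape is simply reshaped.
c = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])

# Per the docstring, a shorter list is padded with its last element when the
# constant is embedded at graph-construction time (TF1-style graphs only).
g = tf.Graph()
with g.as_default():
  d = tf.compat.v1.constant([1, 2], shape=[2, 3])  # [[1 2 2], [2 2 2]]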
def __exit__(self, unused_type, unused_value, unused_traceback): # pylint: disable=protected-access if context.executing_eagerly(): return if self._graph is not ops.get_default_graph(): raise RuntimeError( "Within the automatic control dependency context, the default graph" f" cannot change. Upon entry it was {self._graph}, but on exit it" f" changed to {ops.get_default_graph()}") outer_graph = getattr(self._graph, "outer_graph", None) if outer_graph is not None: self._graph._add_control_dependencies = outer_graph._add_control_dependencies else: self._graph._add_control_dependencies = False self._graph.experimental_acd_manager = None # map from resource tensor to the last op which wrote to it last_write_to_resource = {} # map from resource tensor to the list of reads from it since the last # write or since the beginning of the function. reads_since_last_write_to_resource = collections.defaultdict(list) # CollectiveManager manager_ids within a particular function call should not # be needed outside of that function call. So we keep them separate (though # the general idea of the maps is the same, in the future, we'll need to # correctly thread the control output outside). # Map from collective manager scope to the last op which used it collective_manager_scopes_opened = {} collective_manager_scopes_used = {} # set of conditional and loop exits ops_which_must_run = set() # merge which must depend on ops which use this resource merge_for_resource = {} new_operations = self._graph.get_operations()[self._n_operations:] first_use_for_res = {} resources_by_op = {} # Ensures that uses of resource tensors get serialized properly and all # execute. This is done by keeping a map from resource tensor to the last op # in graph-construction order which used it (last_write_to_resource). # # Conditionals are written in TensorFlow such that every external tensor # accessed in the conditional goes through a switch op and every return # tensor (it's guaranteed that there will be at least one) goes through a # merge op. # # To handle conditionals, switches are handled in a special way (see # comments for _process_switch). Merge nodes created by TF's conditional # logic (as opposed to by _process_switch) are forced to run and also get a # control dependency added to them to ensure all stateful ops inside their # control flow context run. # # We also ensure that if an op is using a resource output by a switch node # (that is, a resource tensor for which there's a value in # merge_for_resource) this op will run before the merge for that resource. # # We try to add control inputs to nodes respecting their control flow # contexts to avoid dead nodes propagating everywhere and leading to # "retval[0] doesn't have value" errors. If a node gets a control dependency # on a dead node (i.e. a note from an untaken control flow branch) that node # will be marked as dead unless it's a merge node. # # TODO(apassos): serialize non-resource-taking stateful ops as well, and # test that it works. Support while loops. Support init_scope escaping from # this. for op in new_operations: # TODO(apassos) make this code safely support while loops. if control_flow_util.IsInWhileLoop(op): continue control_inputs = set() if op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS: # This will add it to self._independent_ops, but also mark it with an # attribute. self.run_independently(op) if op in self._independent_ops: ops_which_must_run.add(op) continue # Ensure stateful ops run. 
# Read-only ops are added to control outputs if the read value is # consumed. This covers the case when the read value is returned from # the function since that goes through a tf.identity in mark_as_return. if ((op_def_registry.get(op.type) is None) or (op_is_stateful(op) and (op.type not in utils.RESOURCE_READ_OPS or any(output.consumers() for output in op.outputs)))): ops_which_must_run.add(op) # Make a note of all opened manager_ids. if op.type == "NoOp": try: collective_manager_scopes_opened[op.get_attr( "_collective_manager_id")] = op except ValueError: pass # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[ 0].dtype == dtypes_module.resource: continue # Make merges trigger all other computation which must run # TODO(mdan): Don't do this. Write a transform to chains instead. # See core/common_runtime/control_flow_deps_to_chains.cc. if op.type == "Merge": for o in ops_which_must_run: op._add_control_input(o) for inp in o.inputs: input_id = ops.tensor_id(inp) if input_id in last_write_to_resource: last_write_to_resource[input_id] = op ops_which_must_run = set([op]) continue resource_inputs = set() # Check for any resource inputs. If we find any, we update control_inputs # and last_write_to_resource. for inp, resource_type in _get_resource_inputs(op): is_read = resource_type == ResourceType.READ_ONLY input_id = ops.tensor_id(inp) # If the op receives the same resource tensor twice as an input, we skip # to avoid the op getting a control dependency on itself. if input_id in resource_inputs: continue resource_inputs.add(input_id) # Deal with switches, finally. if inp.op.type == "Switch": self._process_switch(inp.op, ops_which_must_run, last_write_to_resource, merge_for_resource) is_building_function = op.graph.building_function # Ensure uses of resources are serialized if input_id in last_write_to_resource: if is_building_function or ( last_write_to_resource[input_id]. _control_flow_context is op._control_flow_context): control_inputs.add(last_write_to_resource[input_id]) # Ensure merges happen after the closing of a cond block if input_id in merge_for_resource: merge_for_resource[input_id]._add_control_input(op) do_record = (self.record_initial_resource_uses and input_id not in first_use_for_res) if is_read: reads_list = reads_since_last_write_to_resource[input_id] reads_list.append(op) if do_record: # Note: this will track the entire list that # reads_since_last_write_to_resource maintains. Updates to it will # and should be tracked, until the first write is encountered. At # that point, reads_since_last_write_to_resource will contain a new # empty list. This logic relies on that behavior. 
first_use_for_res[input_id] = reads_list else: control_inputs.update( reads_since_last_write_to_resource[input_id]) reads_since_last_write_to_resource[input_id] = [] last_write_to_resource[input_id] = op if do_record: first_use_for_res[input_id] = [op] if self.record_initial_resource_uses and op_is_stateful(op): if resource_inputs: resources_by_op[op] = tuple(resource_inputs) else: if None not in first_use_for_res: first_use_for_res[None] = [op] resources_by_op[op] = (None, ) if (op_is_stateful(op) and not resource_inputs and op._control_flow_context is None): if None in last_write_to_resource: op._add_control_input(last_write_to_resource[None]) last_write_to_resource[None] = op # Ensure ordering of collective ops manager_ids = collective_manager_ids_from_op(op) for manager_id in manager_ids: if manager_id in collective_manager_scopes_opened: # Chain this function call if the scope was opened. op._add_control_input( collective_manager_scopes_opened[manager_id]) collective_manager_scopes_opened[manager_id] = op else: # If this op is in a scope not created here, create a chain starting # at this op. if manager_id in collective_manager_scopes_used: op._add_control_input( collective_manager_scopes_used[manager_id]) collective_manager_scopes_used[manager_id] = op if control_inputs and not is_building_function: control_inputs = [ c for c in control_inputs if c._control_flow_context is op._control_flow_context ] op._add_control_inputs(control_inputs) # Record the ops which first use resources touched by "ops which must run". if self.record_initial_resource_uses: first_uses_by_output_ops = {} for op in ops_which_must_run: if op not in resources_by_op: # This may happen with Merge/Switch nodes which are special cased # above. continue for r in resources_by_op[op]: if op not in first_uses_by_output_ops: first_uses_by_output_ops[op] = set() first_uses_by_output_ops[op].update(first_use_for_res[r]) # For each "op which must run", set a private attr indicating the ops that # used the same resources it did. for op in first_uses_by_output_ops: others = [ other.name.encode() for other in first_uses_by_output_ops[op] ] l = attr_value_pb2.AttrValue.ListValue(s=others) # TODO(mdan): Is there a way which doesn't use anonymous attrs? op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l)) # Ensure all ops which must run do run self.ops_which_must_run.update(ops_which_must_run) control_output_op = None for idx, r in enumerate( nest.flatten(list(self._returned_tensors), expand_composites=True)): if self.ops_which_must_run: updated_ops_which_must_run = [] if r.graph.building_function: # There may be many stateful ops in the graph. Adding them as # control inputs to each function output could create excessive # control edges in the graph. Thus we create an intermediate No-op # to chain the control dependencies between stateful ops and # function outputs. if idx == 0: control_output_op = control_flow_ops.no_op() control_output_op._set_attr( "_acd_function_control_output", attr_value_pb2.AttrValue(b=True)) control_output_op._add_control_inputs( self.ops_which_must_run) updated_ops_which_must_run = [control_output_op] else: updated_ops_which_must_run = [ o for o in self.ops_which_must_run if o._control_flow_context is r.op._control_flow_context ] r.op._add_control_inputs(updated_ops_which_must_run) self.collective_manager_ids_used = collective_manager_scopes_used
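# A small, self-contained illustration (public tf.function / tf.Variable API)
# of the guarantee the pass above provides: stateful ops that touch the same
# resource inside a traced function are chained with control edges in program
# order, so the read below always observes both assignments.
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def update_and_read(x):
  v.assign(x)            # write 1
  v.assign_add(1.0)      # write 2: forced to run after write 1
  return v.read_value()  # read: forced to run after write 2

print(update_and_read(tf.constant(3.0)).numpy())  # 4.0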
def split_compile_and_replicate(computation, inputs=None, infeed_queue=None, device_assignment=None, name=None, use_tpu=True): """Builds graph operators that runs compilation and replicated computation. This is a lower level interface than replicate that returns a separate compile and execute output tensor. In the generated graph the compile op feeds into the execute op and no additional compilation is incurred when running the compile op before the execute op. The compile op returns additional information about the compilation but does not return the compiled program. Args: computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must have the same number of inputs. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the mapping between logical cores in the computation with physical cores in the TPU topology. Uses a default device assignment if `None`. The `DeviceAssignment` may be omitted if each replica of the computation uses only one core, and there is either only one replica, or the number of replicas is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU backends. Currently, only supports a default placement (computation is placed on GPU if one is available, and on CPU if not). Returns: A list of lists with the first list corresponding to the compile op and the second a list of output tensors, indexed by `[replica_num][output_num]`. Raises: ValueError: If all replicas do not have equal numbers of input tensors. ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. """ del name inputs = [[]] if inputs is None else inputs metadata_kwargs = {} if device_assignment is not None: # Turn the Numpy array into a flattened list so we can pass it as an # operator attribute. metadata_kwargs = { "topology": device_assignment.topology.serialized(), "device_assignment": device_assignment.core_assignment.flatten().tolist(), "computation_shape": device_assignment.computation_shape.tolist() } if ((not isinstance(inputs, list)) or any(not isinstance(inp, (list, tuple)) for inp in inputs)): raise TypeError( "tpu.replicate() inputs must be a list of lists/tuples") num_replicas = len(inputs) # No replicas? Nothing to do. if num_replicas == 0: return [] # Converts inputs to Tensors. inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs] # Verifies that all replicas have matching numbers and types of inputs input_types = [x.dtype for x in inputs[0]] input_arity = len(input_types) for i in range(num_replicas): if len(inputs[i]) != input_arity: raise ValueError("Replicas must have the same number of inputs. " "Replica 0 had {} inputs, replica {} had {} " "inputs.".format(input_arity, i, len(inputs[i]))) types = [x.dtype for x in inputs[i]] if types != input_types: raise ValueError( "Replicas must have matching input types. Replica 0 had " "input types {}, replica {} had input types {}".format( input_types, i, types)) arg_error = tpu_function.check_function_argument_count( computation, input_arity, infeed_queue) if arg_error is not None: if infeed_queue is None: raise TypeError( "Supplied computation cannot be called with the specified inputs. 
" "You specified %d inputs: %s, but the computation needs %s" % (input_arity, str([i.name for i in inputs[0]]), arg_error)) else: raise TypeError( "Supplied computation cannot be called with the specified inputs. " "You specified %d inputs: %s and %d additional inputs from infeed," " but the computation needs %s" % (input_arity, str([i.name for i in inputs[0]]), infeed_queue.number_of_tuple_elements, arg_error)) graph = ops.get_default_graph() # Fan-in: Builds a TPUReplicatedInput node for each input. computation_inputs = [] for i in range(0, input_arity): replicas = [inputs[replica][i] for replica in xrange(num_replicas)] computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas, pivot=pivot) try: context.Enter() metadata = tpu_ops.tpu_replicate_metadata(num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): # Add identity ops so even unused inputs are "consumed" by the # computation. This is to avoid orphaned TPUReplicatedInput nodes. # TODO(phawkins): consider instead pruning unused TPUReplicatedInput # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. computation_inputs = [ array_ops.identity(x, name="replicated_input_{}".format(i)) for i, x in enumerate(computation_inputs) ] # If there is an infeed queue, adds the dequeued values to the # computation's inputs. if infeed_queue is not None: infeed_queue.set_number_of_shards(num_replicas) for t in infeed_queue.generate_dequeue_op(): computation_inputs.append(t) # Only resource variables work inside a TPU computation, so turn on # resource variables for the computation. # TODO(phawkins): consider removing this code. It will # be less confusing to clients if they knowingly choose to use resource # variables. vscope = variable_scope.get_variable_scope() saved_use_resource = vscope.use_resource vscope.set_use_resource(True) outputs = computation(*computation_inputs) vscope.set_use_resource(saved_use_resource) # If the computation returns `None`, make it an empty tuple. if outputs is None: outputs = tuple() # If the computation only returned one value, makes it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs, ) # Append `no_op` here so that fetching any return value of this function # will trigger TPUExecute node. outputs += (control_flow_ops.no_op(), ) try: with ops.device(core(0)): outputs = [ o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) for o in outputs ] except Exception as e: raise ValueError( "TPU function return values must all either be Operations or " "convertible to Tensors. Got '%s'" % str(e)) # Separates the returned Operations and Tensors. output_operations = [ o for o in outputs if isinstance(o, ops.Operation) ] output_tensors = [ o for o in outputs if not isinstance(o, ops.Operation) ] if outputs != output_tensors + output_operations: raise ValueError( "TPU functions must return zero-or more Tensor values followed by " "zero or more Operations.") output_arity = len(output_tensors) # Wraps outputs in Identity ops. Otherwise a replicated input copied # straight to an output would bypass the replicate(). This would be bad # because the TPUReplicatedInput/TPUReplicatedOutput operator would not # be rewritten away, leading to a runtime error. 
# TODO(phawkins): extend the rewrite to elide these nodes instead. new_output_tensors = [] for t in output_tensors: with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() host_compute_core = context.HostComputeCore() if host_compute_core: attr_value = attr_value_pb2.AttrValue() attr_value.list.s.extend( [compat.as_bytes(x) for x in host_compute_core]) metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access # Fan-out: Builds a TPUReplicatedOutput node for each output. outputs = [ tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, name="output{}".format(i)) for i in xrange(output_arity) ] with ops.control_dependencies([metadata]): if use_tpu: compile_status = tpu_ops.tpu_compilation_result() op = compile_status.op attr_value = attr_value_pb2.AttrValue( s=compat.as_bytes(cluster_name)) op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access else: compile_status = control_flow_ops.no_op(name="compilation_status") with ops.control_dependencies(output_operations): if output_arity == 0: # Returns a list of NoOps dependent on the replication Op, indexed by # [replica_num]. return [ compile_status, [ control_flow_ops.no_op(name="shard_%d" % i) for i in range(num_replicas) ] ] else: # Wraps the outputs in identity operators so the names of any possible # `fetch` nodes are preserved by the replication rewrite. return [ compile_status, [[ array_ops.identity(outputs[out][replica], name="output_%d_shard_%d" % (out, replica)) for out in xrange(output_arity) ] for replica in xrange(num_replicas)] ]
def patch_dtype(input_node, field_name, output_node):
  # Rewrite a float32 dtype attr to float16 when fp16 export is enabled
  # (`use_fp16` is a flag from the enclosing scope).
  if use_fp16 and (field_name in input_node.attr) and (
      input_node.attr[field_name].type == types_pb2.DT_FLOAT):
    output_node.attr[field_name].CopyFrom(
        attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))
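# Usage sketch of the dtype-patching pattern above on a toy NodeDef. The node
# and attr names here are illustrative only.
from tensorflow.core.framework import attr_value_pb2, node_def_pb2, types_pb2

src = node_def_pb2.NodeDef(name="example/MatMul", op="MatMul")
src.attr["T"].type = types_pb2.DT_FLOAT

dst = node_def_pb2.NodeDef()
dst.CopyFrom(src)
# The same rewrite patch_dtype performs for a single attr field.
if dst.attr["T"].type == types_pb2.DT_FLOAT:
  dst.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))

assert dst.attr["T"].type == types_pb2.DT_HALF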
def set_attr_int_list(node, key, value):
  """Sets a node attr whose value is a list of ints."""
  list_value = attr_value_pb2.AttrValue.ListValue(i=value)
  node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
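# Usage sketch: setting and reading back an int-list attr with the helper
# above (the node and attr names are illustrative).
from tensorflow.core.framework import node_def_pb2

conv_node = node_def_pb2.NodeDef(name="conv", op="Conv2D")
set_attr_int_list(conv_node, "strides", [1, 2, 2, 1])
assert list(conv_node.attr["strides"].list.i) == [1, 2, 2, 1]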
def generate_output_graph(self, input_graph_def, input_node_map, fuse_op_list):
  output_graph_def = graph_pb2.GraphDef()
  skip_list = []
  skip_node_name = []
  float32_type = dtypes.float32.as_datatype_enum
  for index, node in enumerate(input_graph_def.node):
    if index in fuse_op_list:
      input_node = input_node_map[node.input[0]]
      if input_node.op == 'QuantizeV2':
        new_node = node_def_pb2.NodeDef()
        new_node.op = node.op + "AndDequantize"
        for _, value in enumerate(node.input):
          new_node.input.append(value)
        dequantize_node = input_graph_def.node[index + 4]
        frozen_max_node = input_graph_def.node[index + 2]
        frozen_min_node = input_graph_def.node[index + 1]
        new_node.name = dequantize_node.name
        new_node.input.append(frozen_min_node.name)
        new_node.input.append(frozen_max_node.name)
        new_node.attr["T1"].CopyFrom(node.attr['T1'])
        new_node.attr["T2"].CopyFrom(node.attr['T2'])
        new_node.attr["Tbias"].CopyFrom(
            attr_value_pb2.AttrValue(type=float32_type))
        new_node.attr["Toutput"].CopyFrom(
            attr_value_pb2.AttrValue(type=float32_type))
        skip_list.append(index + 1)
        skip_list.append(index + 2)
        skip_list.append(index + 3)
        skip_list.append(index + 4)
        output_graph_def.node.extend(
            [new_node, frozen_max_node, frozen_min_node])
      elif input_node.op == "Requantize":
        new_node = node_def_pb2.NodeDef()
        new_node.op = node.op + "AndDequantize"
        for _, value in enumerate(node.input):
          new_node.input.append(value)
        dequantize_node = input_graph_def.node[index + 4]
        frozen_max_node = input_graph_def.node[index + 2]
        frozen_min_node = input_graph_def.node[index + 1]
        new_node.name = dequantize_node.name
        skip_list.append(index + 1)
        skip_list.append(index + 2)
        skip_list.append(index + 3)
        skip_list.append(index + 4)
        new_node.input.append(frozen_min_node.name)
        new_node.input.append(frozen_max_node.name)
        new_node.attr["T1"].CopyFrom(node.attr['T1'])
        new_node.attr["T2"].CopyFrom(node.attr['T2'])
        new_node.attr["Tbias"].CopyFrom(
            attr_value_pb2.AttrValue(type=float32_type))
        new_node.attr["Toutput"].CopyFrom(
            attr_value_pb2.AttrValue(type=float32_type))
        output_graph_def.node.extend(
            [new_node, frozen_max_node, frozen_min_node])
      else:
        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        output_graph_def.node.extend([new_node])
    elif index in skip_list or node.name in skip_node_name:
      continue
    else:
      new_node = node_def_pb2.NodeDef()
      new_node.CopyFrom(node)
      output_graph_def.node.extend([new_node])
  return output_graph_def
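# `input_node_map` above is assumed to map node names to their NodeDefs. One
# plausible way to build such a map from a GraphDef (the helper name is
# hypothetical; callers may need to strip a trailing ":<port>" from entries of
# node.input before looking them up):
def build_node_name_map(graph_def):
  node_map = {}
  for node in graph_def.node:
    if node.name in node_map:
      raise ValueError("Duplicate node name: %s" % node.name)
    node_map[node.name] = node
  return node_map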
def _get_defun_inputs(args, names, structure, flat_shapes=None): """Maps python function args to graph-construction inputs. Args: args: A flat list of user-specified arguments. names: A list of strings with user-specified argument names, same length as `args`. May be `None`, in which case a generic name is used. structure: The original argument list or dictionary. flat_shapes: A flat list of values that are either `None` or instances of `TensorShape`. If provided, then length must match that of `nest.flatten(args)`; and locations where `args` are instances of `Tensor` must have a corresponding `TensorShape` in `flat_shapes`. May be `None`, in which case exact shapes are read directly from the args. Returns: Placeholders with the same structure as `structure`. Raises: RuntimeError: if `flat_shapes` is provided, but `len(flat_shapes) != len(nest.flatten(args))`. RuntimeError: if a shape from `flat_shapes` is not None for an argument that is not a `Tensor`, `TensorSpec`, or `ResourceVariable`. """ func_graph = ops.get_default_graph() function_inputs = [] if names is None: names = [None] * len(args) if flat_shapes is None: shapes_iter = itertools.repeat(None) else: len_flat_args = len(nest.flatten(args)) if len_flat_args != len(flat_shapes): raise RuntimeError( "Length of fully flat shapes (%d) must match that of " "flatten(args) (%d). args: %s, flat_shapes: %s" % (len(flat_shapes), len_flat_args, args, flat_shapes)) shapes_iter = iter(flat_shapes) for arg_value, name in zip(args, names): flattened = nest.flatten(arg_value) tensor_specs = [ arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec) ] specified_names = [arg.name for arg in tensor_specs if arg.name] if specified_names and len(specified_names) < len(tensor_specs): raise ValueError( "If specifying TensorSpec names for nested structures, " "either zero or all names have to be specified.") for arg in flattened: # We have a shape entry for each arg, regadless of whether it's a real # Tensor or not. For non-tensor entries it should be None. shape = next(shapes_iter) if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)): if isinstance(arg, tensor_spec.TensorSpec) and arg.name: requested_name = arg.name else: requested_name = name placeholder_shape = shape if shape is not None else arg.shape try: placeholder = graph_placeholder(arg.dtype, placeholder_shape, name=requested_name) except ValueError: # Sometimes parameter names are not valid op names, so fall back to # unnamed placeholders. placeholder = graph_placeholder(arg.dtype, placeholder_shape) if name is not None: # Record the requested/user-specified name in case it's different than # the uniquified name, for validation when exporting signatures. placeholder.op._set_attr( # pylint: disable=protected-access "_user_specified_name", attr_value_pb2.AttrValue( s=compat.as_bytes(requested_name))) function_inputs.append(placeholder) elif isinstance(arg, resource_variable_ops.ResourceVariable): # Capture arg variables to create placeholders for them. These will be # removed as captures after the function is traced (since otherwise we'd # just add it back with a new placeholder when the variable was # referenced). 
placeholder = func_graph.capture(arg.handle, name=name) placeholder.op._set_attr( # pylint: disable=protected-access "_user_specified_name", attr_value_pb2.AttrValue(s=compat.as_bytes(name))) function_inputs.append(arg) else: if shape is not None: raise RuntimeError( "Expected provided shape override to be None for arg that isn't " "a Tensor, but saw arg: '%s', shape: '%s'. args: %s" % (arg, shape, args)) function_inputs.append(arg) return nest.pack_sequence_as(structure, function_inputs)
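# Sketch of where the "_user_specified_name" attr recorded above surfaces,
# using only the public tf.function / TensorSpec API (recent TF 2.x).
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="scores")])
def normalize(scores):
  return scores / tf.reduce_sum(scores)

concrete = normalize.get_concrete_function()
placeholder_op = concrete.graph.inputs[0].op
# The name requested through the TensorSpec is kept on the placeholder op.
print(placeholder_op.get_attr("_user_specified_name"))  # b'scores'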
def replicate(computation, inputs=None, infeed_queue=None, device_assignment=None, name=None): """Builds a graph operator that runs a replicated TPU computation. Args: computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must have the same number of inputs. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the mapping between logical cores in the computation with physical cores in the TPU topology. Uses a default device assignment if `None`. The `DeviceAssignment` may be omitted if each replica of the computation uses only one core, and there is either only one replica, or the number of replicas is equal to the number of cores in the TPU system. name: The name of the operator. Returns: A list of lists of output tensors, indexed by `[replica_num][output_num]`. Raises: ValueError: If all replicas do not have equal numbers of input tensors. ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. """ if name is None: name = "TPUReplicate" inputs = [[]] if inputs is None else inputs metadata_kwargs = {} if device_assignment is not None: # Turn the Numpy array into a flattened list so we can pass it as an # operator attribute. metadata_kwargs = { "topology": device_assignment.topology.serialized(), "device_assignment": device_assignment.core_assignment.flatten().tolist(), "computation_shape": device_assignment.computation_shape.tolist() } if ((not isinstance(inputs, list)) or any(not isinstance(inp, (list, tuple)) for inp in inputs)): raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") num_replicas = len(inputs) # No replicas? Nothing to do. if num_replicas == 0: return [] # Converts inputs to Tensors. inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs] # Verifies that all replicas have matching numbers and types of inputs input_types = [x.dtype for x in inputs[0]] input_arity = len(input_types) for i in range(num_replicas): if len(inputs[i]) != input_arity: raise ValueError("Replicas must have the same number of inputs. " "Replica 0 had {} inputs, replica {} had {} " "inputs.".format(input_arity, i, len(inputs[i]))) types = [x.dtype for x in inputs[i]] if types != input_types: raise ValueError( "Replicas must have matching input types. Replica 0 had " "input types {}, replica {} had input types {}".format( input_types, i, types)) arg_error = tpu_function.check_function_argument_count( computation, input_arity, infeed_queue) if arg_error is not None: if infeed_queue is None: raise TypeError( "Supplied computation cannot be called with the specified inputs. " "You specified %d inputs: %s, but the computation needs %s" % ( input_arity, str([i.name for i in inputs[0]]), arg_error)) else: raise TypeError( "Supplied computation cannot be called with the specified inputs. " "You specified %d inputs: %s and %d additional inputs from infeed," " but the computation needs %s" % (input_arity, str( [i.name for i in inputs[0]]), infeed_queue.number_of_tuple_elements, arg_error)) graph = ops.get_default_graph() with ops.name_scope(name, "replicate"): # Fan-in: Builds a TPUReplicatedInput node for each input. 
computation_inputs = [] for i in range(0, input_arity): replicas = [inputs[replica][i] for replica in xrange(num_replicas)] computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) context = TPUReplicateContext(name=graph.unique_name("cluster")) try: context.Enter() metadata = tpu_ops.tpu_replicate_metadata( num_replicas=num_replicas, **metadata_kwargs) with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): # The EncapsulateTPUComputations rewrite needs to identify the # replicated arguments inside each computation. Adds identity operators # tagged with an attribute _tpu_replicated_input to identify the # replicated inputs. # pylint: disable=protected-access with graph._attr_scope({"_tpu_replicated_input": attr_value_pb2.AttrValue(b=True)}): computation_inputs = [ array_ops.identity(x, name="replicated_input_{}".format(i)) for i, x in enumerate(computation_inputs)] # pylint: enable=protected-access # If there is an infeed queue, adds the dequeued values to the # computation's inputs. if infeed_queue is not None: infeed_queue.set_number_of_shards(num_replicas) for t in infeed_queue.generate_dequeue_op(): computation_inputs.append(t) # Only resource variables work inside a TPU computation, so turn on # resource variables for the computation. # TODO(phawkins): consider removing this code. It will # be less confusing to clients if they knowingly choose to use resource # variables. vscope = variable_scope.get_variable_scope() saved_use_resource = vscope.use_resource vscope.set_use_resource(True) outputs = computation(*computation_inputs) vscope.set_use_resource(saved_use_resource) # If the computation only returned one value, makes it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs,) try: with ops.device(core(0)): outputs = [ o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) for o in outputs ] except Exception as e: raise ValueError( "TPU function return values must all either be Operations or " "convertible to Tensors. Got '%s'" % str(e)) # Separates the returned Operations and Tensors. output_operations = [o for o in outputs if isinstance(o, ops.Operation)] output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] if outputs != output_tensors + output_operations: raise ValueError( "TPU functions must return zero-or more Tensor values followed by " "zero or more Operations.") output_arity = len(output_tensors) # Wraps outputs in Identity ops. Otherwise a replicated input copied # straight to an output would bypass the replicate(). This would be bad # because the TPUReplicatedInput/TPUReplicatedOutput operator would not # be rewritten away, leading to a runtime error. # TODO(phawkins): extend the rewrite to elide these nodes instead. new_output_tensors = [] for t in output_tensors: with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors finally: context.Exit() # Fan-out: Builds a TPUReplicatedOutput node for each output. outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, name="output{}".format(i)) for i in xrange(output_arity)] with ops.control_dependencies(output_operations): if output_arity == 0: # Returns a list of NoOps dependent on the replication Op, indexed by # [replica_num]. 
return [ control_flow_ops.no_op(name="%s_shard_%d" % (name, i)) for i in range(num_replicas) ] else: # Wraps the outputs in identity operators so the names of any possible # `fetch` nodes are preserved by the replication rewrite. return [ [array_ops.identity(outputs[out][replica], name="output_%d_shard_%d" % (out, replica)) for out in xrange(output_arity)] for replica in xrange(num_replicas) ]
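# Shape of the call that replicate() above expects: `inputs` is indexed by
# [replica_num][input_num]. A hedged sketch only; building and running this
# for real requires a configured TPU system and a TF1-style graph/session.
import tensorflow as tf

def computation(x, y):
  return x + y, x * y

inputs = [
    [tf.constant(1.0), tf.constant(2.0)],  # replica 0
    [tf.constant(3.0), tf.constant(4.0)],  # replica 1
]

# outputs is indexed by [replica_num][output_num], e.g. outputs[1][0] == 7.0.
outputs = tf.compat.v1.tpu.replicate(computation, inputs)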
def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" # Note that op is not always the same as while_op because the gradient tape, # for eager mode compatibility, forgets information about the proper op. Since # the loop cannot run in eager mode, however, we can safely introspect into # the graph here. while_op = op.outputs[0].op cond_graph = _get_graph(while_op, "cond") body_graph = _get_graph(while_op, "body") orig_num_params = len(body_graph.outputs) maximum_iterations = op.inputs[1] parallel_iterations = op.get_attr("parallel_iterations") try: num_original_outputs = while_op.get_attr("_num_original_outputs") except: # pylint: disable=bare-except num_original_outputs = len(while_op.outputs) num_intermediates = len(while_op.outputs) - num_original_outputs grads = [ _preprocess_grad(grad, body_out, while_out) # pylint: disable=g-complex-comprehension for grad, body_out, while_out in zip( grads[:num_original_outputs], body_graph.outputs[:num_original_outputs], while_op.outputs[:num_original_outputs]) ] + [None] * num_intermediates # We compute the gradient for the sub-graph between trainable ys and xs # with non-None incoming gradients. We later pad the None's to the list of # outputs. ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip( body_graph.outputs, body_graph.inputs, grads) if grad is not None]) body_grad_graph, args = _create_grad_func( ys, xs, non_none_grads, cond_graph, body_graph, util.unique_grad_fn_name(body_graph.name), op, maximum_iterations) if body_grad_graph.while_op_needs_rewrite: # Modify 'op' to output the intermediate accumulators needed by the grad # function. # NOTE(skyewm): if there are any active sessions, this modification to `op` # may make them unrunnable! cond_graph.name += "_rewritten" body_graph.name += "_rewritten" new_inputs = body_grad_graph.empty_tensor_lists new_outputs = body_graph.outputs[orig_num_params:] while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) while_op._set_func_attr("body", util.create_new_tf_function(body_graph)) while_op._set_type_list_attr("T", body_graph.output_types) while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes) while_op._add_while_inputs(new_inputs) while_op._add_outputs([t.dtype for t in new_outputs], [t.shape for t in new_outputs]) _copy_handle_data(new_outputs, op.outputs[orig_num_params:]) # Do not ingore grads wrt extra outputs when computing higher order # derivatives. while_op._set_attr("_num_original_outputs", attr_value_pb2.AttrValue(i=len(while_op.outputs))) captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, while_op) loop_vars = args + captured_inputs # This modifies body_grad_graph. 
loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices( grads, body_grad_graph, loop_vars, while_op.inputs) def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters, *unused_args): return counter < forward_loop_iters grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name) cond_grad_graph = func_graph_module.func_graph_from_py_func( grad_cond_name, grad_cond, loop_vars, {}, func_graph=util.WhileCondFuncGraph(grad_cond_name)) _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) outputs = gen_functional_ops._while( loop_vars, util.create_new_tf_function(cond_grad_graph), util.create_new_tf_function(body_grad_graph), output_shapes=[t.shape for t in body_grad_graph.outputs], parallel_iterations=parallel_iterations, name="%s_grad" % while_op.name) grad_op = outputs[0].op _copy_handle_data(body_grad_graph.outputs, outputs) util.maybe_set_lowering_attr(grad_op) util.maybe_propagate_compile_time_consts_in_xla(grad_op) # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] return _get_structured_grad_output(outputs, grads, body_grad_graph)
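# End-to-end sketch of the gradient path implemented above, driven through the
# public tf.while_loop API (inside tf.function the loop lowers to the While op
# whose gradient _WhileGrad builds).
import tensorflow as tf

@tf.function
def cube_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    # acc is multiplied by x three times, so y == x**3.
    _, y = tf.while_loop(
        cond=lambda i, acc: i < 3,
        body=lambda i, acc: (i + 1, acc * x),
        loop_vars=(tf.constant(0), tf.constant(1.0)))
  return tape.gradient(y, x)

print(cube_grad(tf.constant(2.0)).numpy())  # 12.0 == 3 * 2.0 ** 2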
def _GetLayerMatch(match_result): """Populates a layer match object containing ops/tensors for folding BNs. Args: match_result: Matched result from graph matcher Returns: layer_op: Matching conv/fc op prior to batch norm BatchNormMatch: _BatchNormMatch containing all required batch norm parameters. """ moving_mean_tensor = None moving_variance_tensor = None bn_decay_mean_tensor = None bn_decay_var_tensor = None batch_to_space_op = None layer_op = match_result.get_op(layer_pattern) layer_tensor = match_result.get_tensor(layer_pattern) bn_id_op = match_result.get_op(batch_norm_identity_pattern) bn_op = match_result.get_op(batch_norm_pattern) if bn_id_op is None: bn_id_op = bn_op batch_epsilon = bn_op.get_attr('epsilon') # In the MatMul case, the output of batch norm is reshaped back into a # 2D tensor, so the output_tensor is the output of the Reshape op. output_tensor = bn_op.outputs[0] if layer_op.type == 'MatMul': output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) # If the matcher didn't match matmul_bn_output_reshape, there will be # another match for this 'MatMul' later, so we can skip this one. if output_reshape_op is None: return None, None output_tensor = output_reshape_op.outputs[0] # Ensure that the output tensor has consumers, otherwise this is a dangling # node and not a match. if not output_tensor.consumers(): return None, None batch_to_space_op = match_result.get_op(batch_to_space_pattern) input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) beta_tensor = match_result.get_tensor(beta_pattern) # FusedBatchNorm in training is different from that in inference. It takes # empty 'mean' and empty 'variance', and produces the mean and the variance # of the batch. Therefore, when is_training is true, mean_tensor and # variance_tensor point to 1st and 2nd (0-based) output of bn_op, # respectively; when is_training is false, they point to bn_op's inputs. is_training = bn_op.get_attr('is_training') if is_training: # FusedBatchNormGrad doesn't compute gradients of the batch_mean and # batch_variance outputs, so we need to substitute our own custom # gradient. # TODO(suharshs, raghuramank): Find a way to avoid needing this hack. # pylint: disable=protected-access bn_op._set_attr( '_gradient_op_type', attr_value_pb2.AttrValue(s=compat.as_bytes('FoldFusedBatchNormGrad'))) # pylint: enable=protected-access mean_tensor = bn_op.outputs[1] # The batch variance used during forward and backward prop is biased, # i.e it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average # calculation, the variance is corrected by the term N/N-1 (Bessel's # correction). The variance tensor read from FuseBatchNorm has Bessel's # correction applied, so we undo it here. scope, sep, _ = bn_op.name.rpartition('/') g = ops.get_default_graph() with g.as_default(), g.name_scope(scope + sep): n = math_ops.cast( array_ops.size(layer_tensor) / array_ops.size(mean_tensor), dtypes.float32) variance_tensor = math_ops.multiply( bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction') # TODO(suharshs): Find a way to get rid of this inner match. 
for mul_match_result in moving_avg_mul_matcher.match_graph(graph): sub_op = mul_match_result.get_op(moving_average_sub_pattern) if sub_op.inputs[1].name == bn_op.outputs[1].name: # During training: Batch Mean is bn_op.outputs[1] moving_mean_tensor = sub_op.inputs[0] bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern) if sub_op.inputs[1].name == bn_op.outputs[2].name or sub_op.inputs[1].name.split("/")[0] + "/FusedBatchNorm:2" == bn_op.outputs[2].name : # During training: Batch Var is bn_op.outputs[2] moving_variance_tensor = sub_op.inputs[0] bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern) else: mean_tensor = match_result.get_tensor(mean_pattern) variance_tensor = match_result.get_tensor(variance_pattern) return layer_op, _BatchNormMatch( layer_op=layer_op, bn_op=bn_op, output_tensor=output_tensor, input_tensor=input_tensor, weight_tensor=weight_tensor, gamma_tensor=gamma_tensor, beta_tensor=beta_tensor, mean_tensor=mean_tensor, variance_tensor=variance_tensor, moving_mean_tensor=moving_mean_tensor, moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, batch_epsilon=batch_epsilon, batch_to_space_op=batch_to_space_op)
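# Numeric sanity check of the Bessel-correction undo in the code above: the
# variance FusedBatchNorm reports is the unbiased (n - 1) estimate, while the
# folding needs the biased batch variance, so it is scaled by (n - 1) / n.
import numpy as np

x = np.random.randn(1000).astype(np.float32)
n = x.size
biased_var = x.var()          # divides by n     (what forward prop used)
unbiased_var = x.var(ddof=1)  # divides by n - 1 (Bessel-corrected output)
assert np.isclose(unbiased_var * (n - 1) / n, biased_var)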
def append_postprocessing_op(frozen_graph_def, max_detections, max_classes_per_detection, nms_score_threshold, nms_iou_threshold, num_classes, scale_values, detections_per_class=100, use_regular_nms=False): """Appends postprocessing custom op. Args: frozen_graph_def: Frozen GraphDef for SSD model after freezing the checkpoint max_detections: Maximum number of detections (boxes) to show max_classes_per_detection: Number of classes to display per detection nms_score_threshold: Score threshold used in Non-maximal suppression in post-processing nms_iou_threshold: Intersection-over-union threshold used in Non-maximal suppression in post-processing num_classes: number of classes in SSD detector scale_values: scale values is a dict with following key-value pairs {y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used in decode centersize boxes detections_per_class: In regular NonMaxSuppression, number of anchors used for NonMaxSuppression per class use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of Fast NMS. Returns: transformed_graph_def: Frozen GraphDef with postprocessing custom op appended TFLite_Detection_PostProcess custom op node has four outputs: detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box locations detection_classes: a float32 tensor of shape [1, num_boxes] with class indices detection_scores: a float32 tensor of shape [1, num_boxes] with class scores num_boxes: a float32 tensor of size 1 containing the number of detected boxes """ new_output = frozen_graph_def.node.add() new_output.op = 'TFLite_Detection_PostProcess' new_output.name = 'TFLite_Detection_PostProcess' new_output.attr['_output_quantized'].CopyFrom( attr_value_pb2.AttrValue(b=True)) new_output.attr['_output_types'].list.type.extend([ types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT ]) new_output.attr['_support_output_type_float_in_quantized_op'].CopyFrom( attr_value_pb2.AttrValue(b=True)) new_output.attr['max_detections'].CopyFrom( attr_value_pb2.AttrValue(i=max_detections)) new_output.attr['max_classes_per_detection'].CopyFrom( attr_value_pb2.AttrValue(i=max_classes_per_detection)) new_output.attr['nms_score_threshold'].CopyFrom( attr_value_pb2.AttrValue(f=nms_score_threshold.pop())) new_output.attr['nms_iou_threshold'].CopyFrom( attr_value_pb2.AttrValue(f=nms_iou_threshold.pop())) new_output.attr['num_classes'].CopyFrom( attr_value_pb2.AttrValue(i=num_classes)) new_output.attr['y_scale'].CopyFrom( attr_value_pb2.AttrValue(f=scale_values['y_scale'].pop())) new_output.attr['x_scale'].CopyFrom( attr_value_pb2.AttrValue(f=scale_values['x_scale'].pop())) new_output.attr['h_scale'].CopyFrom( attr_value_pb2.AttrValue(f=scale_values['h_scale'].pop())) new_output.attr['w_scale'].CopyFrom( attr_value_pb2.AttrValue(f=scale_values['w_scale'].pop())) new_output.attr['detections_per_class'].CopyFrom( attr_value_pb2.AttrValue(i=detections_per_class)) new_output.attr['use_regular_nms'].CopyFrom( attr_value_pb2.AttrValue(b=use_regular_nms)) new_output.input.extend([ 'raw_outputs/box_encodings', 'raw_outputs/class_predictions', 'anchors' ]) # Transform the graph to append new postprocessing op input_names = [] output_names = ['TFLite_Detection_PostProcess'] transforms = ['strip_unused_nodes'] transformed_graph_def = TransformGraph(frozen_graph_def, input_names, output_names, transforms) return transformed_graph_def
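# The attr-building pattern used above, in isolation: scalar attrs use the
# b/i/f fields of AttrValue, while list-valued attrs go through the nested
# ListValue message. The op and attr names below are illustrative only, not a
# registered kernel.
from tensorflow.core.framework import attr_value_pb2, graph_pb2, types_pb2

graph_def = graph_pb2.GraphDef()
node = graph_def.node.add()
node.op = 'Example_PostProcess'
node.name = 'Example_PostProcess'

node.attr['use_regular_nms'].CopyFrom(attr_value_pb2.AttrValue(b=False))
node.attr['max_detections'].CopyFrom(attr_value_pb2.AttrValue(i=10))
node.attr['nms_iou_threshold'].CopyFrom(attr_value_pb2.AttrValue(f=0.6))
node.attr['_output_types'].list.type.extend(
    [types_pb2.DT_FLOAT, types_pb2.DT_FLOAT])
node.input.extend(['raw_boxes', 'raw_scores', 'anchors'])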
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects. A list of input tensors whose types
      are T.
    cond: A function that takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean type, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical value,
      non-zero means True and zero means False; if the scalar is a string,
      non-empty means True and empty means False. If the tensor is not a
      scalar, being non-empty means True and empty means False.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a host
      memory tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and
      `body` have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if cond.captured_inputs:
    raise ValueError("While op 'cond' argument must be a function "
                     "without implicitly captured inputs.")

  if cond.declared_input_types != body.declared_input_types:
    raise ValueError(
        "While op 'cond' and 'body' signatures do not match. %r vs %r" %
        (cond.declared_input_types, body.declared_input_types))

  if body.captured_inputs:
    cond_dtypes = list(body.declared_input_types) + [
        t.dtype for t in body.captured_inputs
    ]

    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """A wrapper that handles loop-carried captured inputs."""
      return cond(*args[:len(body.declared_input_types)])

    ret = gen_functional_ops._while(input_ + body.captured_inputs,
                                    CondWrapper,
                                    _LoopBodyCaptureWrapper(body),
                                    name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._while(input_, cond, body, name=name)
  if hostmem:
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
def partitioned_call(args, f, tout=None, executing_eagerly=None, config=None, executor_type=None): """Executes a function while respecting device annotations. Currently, only those functions that execute within the same address space can be executed. Args: args: The arguments of the function, including captured inputs. f: The function to execute; an instance of `_DefinedFunction` or `_EagerDefinedFunction`. tout: a list containing the output dtypes enums; if `None`, inferred from the signature of `f`. executing_eagerly: (Optional) A boolean indicating whether the context is executing eagerly. If `None`, fetched from the global context. config: (Optional) A tensorflow::RewriterConfig proto, serialized. If `None`, all optimizations are disabled. Currently only handled for eager defined functions. executor_type: (Optional) A string for the name of the executor to be used in the function call. If not set, or set to an empty string, the default tensorflow executor will be used. Returns: The list of `Tensor`s returned by invoking `f(args)`. If the function does not return anything, then returns `None` if eager execution is enabled, or the `Operation` if not. """ if tout is None: tout = tuple(x.type for x in f.definition.signature.output_arg) if executing_eagerly is None: executing_eagerly = context.executing_eagerly() if config is None: config = _get_disabled_rewriter_config() if executor_type is None: executor_type = "" if executing_eagerly or len(tout): if f.stateful_ops: outputs = gen_functional_ops.stateful_partitioned_call( args=args, Tout=tout, f=f, config=config, executor_type=executor_type) else: outputs = gen_functional_ops.partitioned_call( args=args, Tout=tout, f=f, config=config, executor_type=executor_type) return outputs if outputs else None # The generated binding returns an empty list for functions that don't # return any Tensors, hence the need to use `create_op` directly. args = [ops.internal_convert_to_tensor(x) for x in args] tin_attr = attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue( type=[x.dtype.as_datatype_enum for x in args])) tout_attr = attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(type=tout)) func_attr = attr_value_pb2.AttrValue(func=attr_value_pb2.NameAttrList( name=f.name)) executor_type_attr = attr_value_pb2.AttrValue( s=compat.as_bytes(executor_type)) # When running in graph mode, the graph and function graphs are optimized # (i.e. run through grappler) per the session options, so we can disable any # eager-specific rewriting. rewriter_config = attr_value_pb2.AttrValue( s=_get_disabled_rewriter_config()) graph = ops.get_default_graph() f.add_to_graph(graph) op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall" op = graph.create_op(op_name, args, tout, compute_shapes=False, name="PartitionedFunctionCall", attrs={ "Tin": tin_attr, "Tout": tout_attr, "f": func_attr, "config": rewriter_config, "executor_type": executor_type_attr, }) outputs = op.outputs return outputs if outputs else op
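# In TF 2.x the same call op is emitted whenever one tf.function invokes
# another during tracing. A sketch using only public APIs; the exact op type
# depends on whether the callee is stateful.
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def scale(x):
  return x * v  # reading a variable makes the callee stateful

@tf.function
def outer(x):
  return scale(x)

graph = outer.get_concrete_function(tf.TensorSpec([], tf.float32)).graph
print([op.type for op in graph.get_operations()
       if "PartitionedCall" in op.type])  # e.g. ['StatefulPartitionedCall']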
class TestFoldConstant(unittest.TestCase): x_node = node_def_pb2.NodeDef() x_node.name = "placeholder" x_node.op = "Placeholder" input0_node = node_def_pb2.NodeDef() input0_node.name = "input0" input0_node.op = "Const" input0_value = np.float32(np.abs(np.random.randn(4, 3, 2))) input0_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input0_value, input0_value.dtype.type, input0_value.shape))) input1_node = node_def_pb2.NodeDef() input1_node.name = "input1" input1_node.op = "Const" input1_value = np.float32(np.abs(np.random.randn(4, 1, 1))) input1_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input1_value, input1_value.dtype.type, input1_value.shape))) add_node = node_def_pb2.NodeDef() add_node.op = "Add" add_node.name = "add" add_node.input.extend([input0_node.name, input1_node.name]) input2_node = node_def_pb2.NodeDef() input2_node.name = "input2" input2_node.op = "Const" input2_value = np.float32(np.abs(np.random.randn(1))) input2_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input2_value, input2_value.dtype.type, input2_value.shape))) input3_node = node_def_pb2.NodeDef() input3_node.name = "input3" input3_node.op = "Const" input3_value = np.float32(np.abs(np.random.randn(1))) input3_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input3_value, input3_value.dtype.type, input3_value.shape))) switch_node = node_def_pb2.NodeDef() switch_node.name = "switch" switch_node.op = "Switch" input4_node = node_def_pb2.NodeDef() input4_node.name = "input4" input4_node.op = "Const" input4_value = np.float32(np.abs(np.random.randn(1))) input4_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input4_value, input4_value.dtype.type, input4_value.shape))) input4_node.input.extend([switch_node.name]) input5_node = node_def_pb2.NodeDef() input5_node.name = "input5" input5_node.op = "Const" input5_value = np.float32(np.abs(np.random.randn(1))) input5_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input5_value, input5_value.dtype.type, input5_value.shape))) input5_node.input.extend([switch_node.name]) cond_end = node_def_pb2.NodeDef() cond_end.name = "cond" cond_end.op = "Add" cond_end.input.extend([input4_node.name, input5_node.name]) mul_node = node_def_pb2.NodeDef() mul_node.op = "Mul" mul_node.name = "mul" mul_node.input.extend([add_node.name, input3_node.name]) sqrt_node = node_def_pb2.NodeDef() sqrt_node.name = "rsqrt" sqrt_node.op = "Rsqrt" sqrt_node.input.extend([mul_node.name]) relu_node = node_def_pb2.NodeDef() relu_node.op = "Relu" relu_node.name = "relu" relu_node.input.extend([sqrt_node.name]) block_node = node_def_pb2.NodeDef() block_node.name = "block_output" block_node.op = "Add" block_node.input.extend([x_node.name, relu_node.name]) res_node = node_def_pb2.NodeDef() res_node.name = "res_add" res_node.op = "Add" res_node.input.extend([sqrt_node.name, input2_node.name]) end_node = node_def_pb2.NodeDef() end_node.name = "end" end_node.op = "Add" end_node.input.extend([block_node.name, res_node.name]) graph_def = graph_pb2.GraphDef() graph_def.node.extend([ x_node, input0_node, input1_node, input2_node, input3_node, add_node, mul_node, sqrt_node, relu_node, block_node, res_node, end_node ]) def test_fold_constant(self): graph = self.graph_def rewriter = GraphFoldConstantOptimizer(graph) new_graph = rewriter.do_transformation() for node 
in new_graph.node: assert node.name in [ "placeholder", "block_output", "rsqrt_const", "relu", "res_add_const", "end" ] def test_condition_fold_constant(self): graph_def = graph_pb2.GraphDef() graph_def.node.extend([ self.cond_end, self.input4_node, self.input5_node, self.switch_node ]) rewriter = GraphFoldConstantOptimizer(graph_def) new_graph = rewriter.do_transformation() for node in new_graph.node: assert node.name in ["switch", "cond", "input4", "input5"] def test_slice_int_input(self): graph_def = graph_pb2.GraphDef() index0_node = node_def_pb2.NodeDef() index0_node.name = "index0" index0_node.op = "Const" index0_value = np.array(3).astype(np.int32).reshape(()) index0_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( index0_value, index0_value.dtype.type, index0_value.shape))) index1_node = node_def_pb2.NodeDef() index1_node.name = "index1" index1_node.op = "Const" index1_value = np.array(1).astype(np.int32).reshape(()) index1_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( index1_value, index1_value.dtype.type, index1_value.shape))) minus_node = node_def_pb2.NodeDef() minus_node.name = "sub" minus_node.op = "Sub" minus_node.input.extend([index0_node.name, index1_node.name]) graph_def.node.extend([index0_node, index1_node, minus_node]) rewriter = GraphFoldConstantOptimizer(graph_def) new_graph = rewriter.do_transformation() with tf.compat.v1.Session() as sess: tf.compat.v1.import_graph_def(new_graph)
def convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None): """Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables. Args: sess: Active TensorFlow session containing the variables. input_graph_def: GraphDef object holding the network. output_node_names: List of name strings for the result nodes of the graph. variable_names_whitelist: The set of variable names to convert (by default, all variables are converted). Returns: GraphDef containing a simplified version of the original. """ # This graph only includes the nodes needed to evaluate the output nodes, and # removes unneeded nodes like those involved in saving and assignment. inference_graph = extract_sub_graph(input_graph_def, output_node_names) found_variables = {} variable_names = [] variable_dict_names = [] for node in inference_graph.node: if node.op == "Variable": variable_name = node.name if (variable_names_whitelist is not None and variable_name not in variable_names_whitelist): continue variable_dict_names.append(variable_name) variable_names.append(variable_name + ":0") if variable_names: returned_variables = sess.run(variable_names) else: returned_variables = [] found_variables = dict(zip(variable_dict_names, returned_variables)) logging.info("Froze %d variables." % len(returned_variables)) output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 for input_node in inference_graph.node: output_node = node_def_pb2.NodeDef() if input_node.name in found_variables: output_node.op = "Const" output_node.name = input_node.name dtype = input_node.attr["dtype"] data = found_variables[input_node.name] output_node.attr["dtype"].CopyFrom(dtype) output_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 else: output_node.CopyFrom(input_node) output_graph_def.node.extend([output_node]) print("Converted %d variables to const ops." % how_many_converted) return output_graph_def
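# A hedged, self-contained sketch of the per-node rewrite performed above: given the
# evaluated value of a variable, build the Const node that replaces the original
# Variable node. The variable name and value below are illustrative stand-ins for
# what sess.run() would return.
import numpy as np
from tensorflow.core.framework import attr_value_pb2, node_def_pb2
from tensorflow.python.framework import dtypes, tensor_util

var_name = "layer1/weights"
var_value = np.ones((3, 3), dtype=np.float32)  # stand-in for sess.run(var_name + ":0")

const_node = node_def_pb2.NodeDef()
const_node.op = "Const"
const_node.name = var_name
const_node.attr["dtype"].CopyFrom(
    attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
const_node.attr["value"].CopyFrom(
    attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
        var_value, dtype=dtypes.float32, shape=var_value.shape)))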
def while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, maximum_iterations=None, name=None, return_same_structure=True, back_prop=True): """Like tf.while_loop, except emits a single While op.""" # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. len_orig_loop_vars = len(orig_loop_vars) # Convert TensorArrays to their flow variables. These get converted back to # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and # `wrapped_body` below. loop_vars = list(_tensor_array_to_flow(orig_loop_vars)) loop_vars = nest.map_structure( ops.internal_convert_to_tensor_or_indexed_slices, loop_vars, expand_composites=True) if shape_invariants is not None: nest.assert_same_structure(orig_loop_vars, shape_invariants, expand_composites=False) signature = nest.map_structure( control_flow_ops._shape_invariant_to_type_spec, loop_vars, list(shape_invariants), expand_composites=False) shape_invariants = nest.map_structure( control_flow_ops._get_shape_invariant, loop_vars, list(shape_invariants), expand_composites=False) else: signature = nest.map_structure( type_spec.type_spec_from_value, loop_vars, expand_composites=False) shape_invariants = nest.map_structure( control_flow_ops._get_shape_invariant, loop_vars, expand_composites=False) if not name: name = "while" with ops.name_scope(name) as scope: with ops.name_scope(None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") maximum_iterations_loop_var = _build_maximum_iterations_loop_var( maximum_iterations) loop_counter = constant_op.constant( 0, dtype=maximum_iterations_loop_var.dtype if maximum_iterations is not None else None, name="loop_counter") # Add loop counter needed for computing gradients. loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants signature = ( [tensor_spec.TensorSpec.from_tensor(loop_counter), tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] + signature) # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = ops.get_default_graph()._add_control_dependencies def wrapped_cond(loop_counter, maximum_iterations_arg, *args): """Extra `cond` wrapper that can handle the extra counter loop_var.""" # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. pred = cond(*_pack_sequence_as(orig_loop_vars, args)) if (tensor_util.is_tensor(pred) and (pred.shape.dims is None or pred.shape.dims)): pred = array_ops.squeeze_v2(pred) if maximum_iterations is None: return pred else: return math_ops.logical_and( loop_counter < maximum_iterations_arg, pred) # NOTE(skyewm): we set collections to the outer graph's collections for # compatibility with TPUEstimator. cond_graph = func_graph_module.func_graph_from_py_func( cond_name, wrapped_cond, [], # We provide signature instead of args. 
{}, signature=signature, func_graph=util.WhileCondFuncGraph( cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) def wrapped_body(loop_counter, maximum_iterations_arg, *args): """Loop body augmented with counter update. Args: loop_counter: Loop counter which needs to be incremented in the body. maximum_iterations_arg: Maximum iterations of the loop. *args: List of args Returns: A list of tensors the same length as args. """ # Capture the tensors already captured in cond_graph so that they appear # in the same order in body_graph.external_captures. for t in cond_graph.external_captures: ops.get_default_graph().capture(t) # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. outputs = body(*_pack_sequence_as(orig_loop_vars, args)) if not nest.is_sequence_or_composite(outputs): outputs = [outputs] # Compare the structure of input and output of body converting the # top-level tuples to list to be compatible with legacy while_loop. nest.assert_same_structure(list(outputs), list(orig_loop_vars), expand_composites=True) outputs = _tensor_array_to_flow(outputs) # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1, maximum_iterations_arg] + list(outputs) body_graph = func_graph_module.func_graph_from_py_func( body_name, wrapped_body, [], # We provide signature instead of args. {}, signature=signature, func_graph=util.WhileBodyFuncGraph( body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + body_graph.external_captures # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. body_graph.outputs.extend(body_graph.internal_captures) # Capture the extra `external_captures` of `body_graph` in `cond_graph` so # that it expects to receive those as arguments. with cond_graph.as_default(): num_cond_captures = len(cond_graph.external_captures) assert (cond_graph.external_captures == body_graph.external_captures[:num_cond_captures]) cond_graph_captures = object_identity.ObjectIdentitySet( cond_graph.external_captures) for body_capture in body_graph.external_captures[num_cond_captures:]: assert body_capture not in cond_graph_captures cond_graph.capture(body_capture) # Make sure that the shapes of the loop outputs are compatible with the # shape invariants, or the shapes of the loop vars if the invariants are not # specified. num_flattened_outputs = len(nest.flatten(orig_loop_vars, expand_composites=True)) # First var is loop counter and second var is maximum_iterations. 
first_loop_var_index = 2 _check_shapes_compat( body_graph.outputs[first_loop_var_index:first_loop_var_index + num_flattened_outputs], nest.flatten( shape_invariants[first_loop_var_index:first_loop_var_index + len_orig_loop_vars], expand_composites=True), nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index + len_orig_loop_vars], expand_composites=True)) num_original_outputs = len(body_graph.outputs) if back_prop and util.output_all_intermediates(): # Export all tensors in the loop body that may be needed for gradient # computation. We do this by accumulating the intermediate values in # TensorLists. intermediate_tensors = _get_intermediates(body_graph) for intermediate_tensor in intermediate_tensors: tensor_list = list_ops.empty_tensor_list( element_dtype=intermediate_tensor.dtype, element_shape=intermediate_tensor.shape, max_num_elements=maximum_iterations) loop_vars.append(tensor_list) with cond_graph.as_default(): # Add a placeholder to cond_graph's inputs corresponding to the # tensor_list. cond_graph.capture(tensor_list) with body_graph.as_default(): # Push the intermediate tensor to the tensor list. This captures the # `tensor_list` as well. appended_tensor_list = list_ops.tensor_list_push_back( tensor_list, intermediate_tensor) # Add this modified tensor list to the list of outputs. body_graph.outputs.append(appended_tensor_list) flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True) _check_num_inputs_outputs(cond_graph, body_graph, len(flattened_loop_vars)) _check_inputs_outputs_types_match(body_graph, flattened_loop_vars) with ops.control_dependencies( list(cond_graph.control_captures) + list(body_graph.control_captures)): output_shapes = [t.shape for t in body_graph.outputs] orig_loop_vars_range = slice(first_loop_var_index, first_loop_var_index + num_flattened_outputs) output_shapes[orig_loop_vars_range] = nest.flatten( shape_invariants, expand_composites=True)[orig_loop_vars_range] cond_stateful_ops = [ op for op in cond_graph.get_operations() if op._is_stateful ] body_stateful_ops = [ op for op in body_graph.get_operations() if op._is_stateful ] if (cond_stateful_ops or body_stateful_ops): op_fn = gen_functional_ops._while else: op_fn = gen_functional_ops.stateless_while outputs = op_fn( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=output_shapes, parallel_iterations=parallel_iterations, name=scope) # This is needed so we do not compute derivative wrt these extra outputs. outputs[0].op._set_attr("_num_original_outputs", attr_value_pb2.AttrValue(i=num_original_outputs)) _copy_handle_data(body_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op) # Return identities for each output of the While op, rather than the output # of the While op directly. This makes pruning work if the output of # while_loop() is fetched: the lowering pass converts the While outputs into # IdentityN outputs, which if fetched will cause all ops in the body to be # run (since it takes all exit ops as input). After lowering, each output # identity op will end up with only the appropriate exit op as input. 
outputs = tuple(array_ops.identity(t) for t in outputs) outputs = _pack_sequence_as( orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index + num_flattened_outputs]) if return_same_structure: return outputs flattened_outputs = nest.flatten(outputs, expand_composites=True) if len(flattened_outputs) == 1: return flattened_outputs[0] else: return outputs
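# A hedged usage sketch of the loop builder above via the public tf.while_loop API
# (which is backed by this single-While-op implementation when control flow v2 is
# enabled). The counter/accumulator loop vars are illustrative.
import tensorflow as tf

i0 = tf.constant(0)
acc0 = tf.constant(0.0)
final_i, final_acc = tf.while_loop(
    cond=lambda i, acc: i < 5,
    body=lambda i, acc: (i + 1, acc + tf.cast(i, tf.float32)),
    loop_vars=(i0, acc0),
    maximum_iterations=10)
# final_i == 5 and final_acc == 0 + 1 + 2 + 3 + 4 == 10.0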
def _init_from_args(self, initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, dtype=None, constraint=None): """Creates a variable. Args: initial_value: A `Tensor`, or Python object convertible to a `Tensor`, which is the initial value for the Variable. The initial value must have a shape specified unless `validate_shape` is set to False. Can also be a callable with no argument that returns the initial value when called. (Note that initializer functions from init_ops.py must first be bound to a shape before being used here.) trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. collections: List of graph collections keys. The new variable is added to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. validate_shape: Ignored. Provided for compatibility with tf.Variable. caching_device: Optional device string or function describing where the Variable should be cached for reading. Defaults to the Variable's device. If not `None`, caches on another device. Typical use is to cache on the device where the Ops using the Variable reside, to deduplicate copying through `Switch` and other conditional statements. name: Optional name for the variable. Defaults to `'Variable'` and gets uniquified automatically. dtype: If set, initial_value will be converted to the given type. If None, either the datatype will be kept (if initial_value is a Tensor) or float32 will be used (if it is a Python object convertible to a Tensor). constraint: An optional projection function to be applied to the variable after being updated by an `Optimizer` (e.g. used to implement norm constraints or value constraints for layer weights). The function must take as input the unprojected Tensor representing the value of the variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. Raises: ValueError: If the initial value is not specified, or does not have a shape and `validate_shape` is `True`. @compatibility(eager) When Eager Execution is enabled, variables are never added to collections. It is not implicitly added to the GLOBAL_VARIABLES or TRAINABLE_VARIABLES collections, and the `collections` argument is ignored. @end_compatibility """ if initial_value is None: raise ValueError("initial_value must be specified.") init_from_fn = callable(initial_value) if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] if not isinstance(collections, (list, tuple, set)): raise ValueError( "collections argument to Variable constructor must be a list, tuple, " "or set. 
Got %s of type %s" % (collections, type(collections))) if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") self._trainable = trainable if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ ops.GraphKeys.TRAINABLE_VARIABLES ] self._save_slice_info = None self._in_graph_mode = context.in_graph_mode() with ops.control_dependencies(None): with ops.name_scope( name, "Variable", [] if init_from_fn else [initial_value]) as name: # pylint: disable=protected-access handle_name = ops._name_from_scope_name(name) if init_from_fn: # Use attr_scope and device(None) to simulate the behavior of # colocate_with when the variable we want to colocate with doesn't # yet exist. if self._in_graph_mode: attr = attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue( s=[compat.as_bytes("loc:@%s" % handle_name)])) with ops.get_default_graph()._attr_scope( {"_class": attr}): with ops.name_scope("Initializer"), ops.device( None): initial_value = ops.convert_to_tensor( initial_value(), name="initial_value", dtype=dtype) self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, graph_mode=self._in_graph_mode) self._handle_device = ( self._handle.device if self._in_graph_mode else context.get_default_context().device_name) self._graph_shape = initial_value.get_shape() else: initial_value = initial_value() with ops.name_scope("Initializer"): initial_value = ops.convert_to_tensor( initial_value, name="initial_value", dtype=dtype) self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, graph_mode=False) self._handle_device = ( self._handle.device if self._in_graph_mode else context.get_default_context().device_name) self._graph_shape = initial_value.get_shape() # pylint: enable=protected-access # Or get the initial value from a Tensor or Python object. else: with ops.name_scope("Initializer"): initial_value = ops.convert_to_tensor( initial_value, name="initial_value", dtype=dtype) # pylint: disable=protected-access if (self._in_graph_mode and initial_value is not None and initial_value.op._get_control_flow_context() is not None): raise ValueError( "Initializer for variable %s is from inside a control-flow " "construct, such as a loop or conditional. When creating a " "variable inside a loop or conditional, use a lambda as the " "initializer." 
% name) # pylint: enable=protected-access self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, graph_mode=self._in_graph_mode) self._handle_device = ( self._handle.device if self._in_graph_mode else context.get_default_context().device_name) self._graph_shape = initial_value.get_shape() self._initial_value = initial_value if self._in_graph_mode else None self._handle_name = handle_name + ":0" self._dtype = initial_value.dtype.base_dtype self._constraint = constraint if self._in_graph_mode: with ops.name_scope("IsInitialized"): self._is_initialized_op = ( gen_resource_variable_ops.var_is_initialized_op( self._handle)) if initial_value is not None: with ops.name_scope("Assign") as n, ops.colocate_with( self._handle): self._initializer_op = ( gen_resource_variable_ops.assign_variable_op( self._handle, self._build_initializer_expr( initial_value), name=n)) with ops.name_scope("Read"), ops.colocate_with( self._handle): # Manually assign reads to the handle's device to avoid log # messages. with ops.device(self._handle_device): value = self._read_variable_op() self._graph_element = value if caching_device is not None: # Variables may be created in a tf.device() or ops.colocate_with() # context. At the same time, users would expect caching device to # be independent of this context, and/or would not expect the # current device context to be merged with the caching device # spec. Therefore we reset the colocation stack before creating # the cached value. Note that resetting the colocation stack will # also reset the device stack. with ops.colocate_with(None, ignore_existing=True): with ops.device(caching_device): self._cached_value = array_ops.identity( value) else: self._cached_value = None else: gen_resource_variable_ops.assign_variable_op( self._handle, initial_value) self._is_initialized_op = None self._initializer_op = None self._graph_element = None if caching_device: with ops.device(caching_device): self._cached_value = self._read_variable_op() else: self._cached_value = None if context.in_graph_mode(): ops.add_to_collections(collections, self) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
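# A hedged sketch of the callable-initializer path handled above (init_from_fn):
# passing a zero-argument callable as initial_value defers building the initial
# tensor, which is what the ValueError above asks for when the variable is created
# inside a control-flow construct. The variable name is illustrative.
import tensorflow as tf

v = tf.Variable(lambda: tf.ones([2, 3]), name="deferred_init_var")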
def do_transformation(self): """Removes batch normalization ops by folding them into convolutions. Batch normalization during training has multiple dynamic parameters that are updated, but once the graph is finalized these become constants. That means there's an opportunity to reduce the computations down to a scale and addition, rather than the more expensive multiple ops, and even bake the scaling into the convolution weights. This function identifies the typical pattern of batch normalization subgraphs, and performs the transformation to fold the computations down into a simpler form. It currently only spots batch normalization that's performed by the BatchNormWithGlobalNormalization and FusedBatchNorm ops, and will need to be extended in the future to handle the newer style. Returns: Modified graph with BN ops removed, and modified weights. Raises: ValueError: If the graph is badly formed with duplicate node names. """ cur_graph = GraphAnalyzer() cur_graph.graph = self.model graph_info = cur_graph.parse_graph() target_nodes = cur_graph.query_fusion_pattern_nodes( [["Conv2D", "DepthwiseConv2dNative"], ("BiasAdd", "Add", "AddV2"), ["BatchNormWithGlobalNormalization", "FusedBatchNorm", "FusedBatchNormV3"]]) for node_combination in target_nodes: matched_node = node_combination[:-1] has_add_op = True if len(node_combination[-1]) == 3 else False conv_node = graph_info[Helper.node_name_from_input(matched_node[0])].node weights_node_name = graph_info[Helper.node_name_from_input( matched_node[0])].node.input[1] weights_node = graph_info[Helper.node_name_from_input(weights_node_name)].node bn_node = graph_info[Helper.node_name_from_input(matched_node[-1])].node if weights_node.op != "Const": self.logger.warning("Didn't find expected conv Constant input to '%s'," " found %s instead. Maybe because freeze_graph wasn't" " run first?" 
% (bn_node.name, weights_node_name)) continue weights = Helper.values_from_const(weights_node) if conv_node.op == "Conv2D": channel_count = weights.shape[3] elif conv_node.op == "DepthwiseConv2dNative": channel_count = weights.shape[2] * weights.shape[3] mean_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("mean_op")]) mean_node = graph_info[mean_node_name].node if mean_node.op != "Const": continue mean_value = Helper.values_from_const(mean_node) if has_add_op: bias_node_name = graph_info[Helper.node_name_from_input( matched_node[1])].node.input[1] bias_node = graph_info[Helper.node_name_from_input(bias_node_name)].node if bias_node.op != "Const": continue if mean_value.shape != (channel_count, ): continue mean_value = mean_value - Helper.values_from_const(bias_node) cur_graph.remove_node(bias_node.name) cur_graph.remove_node(matched_node[1]) if mean_value.shape != (channel_count, ): self.logger.warning("Incorrect shape for mean, found %s, expected %s," " for node %s" % (str(mean_value.shape), str( (channel_count, )), conv_node.name)) continue var_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("var_op")]) var_node = graph_info[var_node_name].node if var_node.op != "Const": continue var_value = Helper.values_from_const(var_node) if var_value.shape != (channel_count, ): continue beta_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("beta_op")]) beta_node = graph_info[beta_node_name].node if beta_node.op != "Const": continue beta_value = Helper.values_from_const(beta_node) if beta_value.shape != (channel_count, ): continue gamma_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("gamma_op")]) gamma_node = graph_info[gamma_node_name].node if gamma_node.op != "Const": continue gamma_value = Helper.values_from_const(gamma_node) if gamma_value.shape != (channel_count, ): continue variance_epsilon_value = bn_node.attr[self.EPSILON_ATTR[bn_node.op]].f if self.scale_after_normalization(bn_node): scale_value = ( (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) * gamma_value) else: scale_value = (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) offset_value = (-mean_value * scale_value) + beta_value if conv_node.op == "Conv2D": original_shape =weights.shape tmp_shape = (original_shape[-1], int(weights.size/original_shape[-1])) tmp_order = [weights.ndim - 1] + [i for i in range(weights.ndim - 1)] scaled_weights = np.copy(weights).transpose(tmp_order).ravel().reshape(tmp_shape) reshape_scale = np.array(scale_value).reshape(len(scale_value), 1) scaled_weights = np.multiply( scaled_weights, reshape_scale).transpose().reshape(original_shape) elif conv_node.op == "DepthwiseConv2dNative": scaled_weights = np.copy(weights) it = np.nditer(scaled_weights, flags=["multi_index"], op_flags=["readwrite"]) channel_multiplier = weights.shape[3] while not it.finished: current_scale = scale_value[it.multi_index[2] * channel_multiplier + it.multi_index[3]] it[0] *= current_scale it.iternext() scaled_weights_node = node_def_pb2.NodeDef() scaled_weights_node.op = "Const" scaled_weights_node.name = weights_node_name + "_bn_offset" scaled_weights_node.attr["dtype"].CopyFrom(weights_node.attr["dtype"]) scaled_weights_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( scaled_weights, weights.dtype.type, weights.shape))) cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], 
weights_node_name) offset_node = node_def_pb2.NodeDef() offset_node.op = "Const" offset_node.name = conv_node.name + "_bn_offset" offset_node.attr["dtype"].CopyFrom(mean_node.attr["dtype"]) offset_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( offset_value, mean_value.dtype.type, offset_value.shape))) bias_add_node = node_def_pb2.NodeDef() bias_add_node.op = "BiasAdd" bias_add_node.name = bn_node.name bias_add_node.attr["T"].CopyFrom(conv_node.attr["T"]) bias_add_node.attr["data_format"].CopyFrom(conv_node.attr["data_format"]) bias_add_node.input.extend([conv_node.name, offset_node.name]) cur_graph.add_node(offset_node, [], [bias_add_node.name]) cur_graph.add_node(bias_add_node, conv_node.name, graph_info[Helper.node_name_from_input(matched_node[-1])].outputs) cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], weights_node_name) cur_graph.remove_node(weights_node_name) cur_graph.remove_node(mean_node_name) cur_graph.remove_node(var_node_name) cur_graph.remove_node(beta_node_name) cur_graph.remove_node(gamma_node_name) return cur_graph.dump_graph()
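# A numpy sketch of the arithmetic folded above (assuming scale-after-normalization):
# scale = gamma / sqrt(variance + epsilon), offset = beta - mean * scale, and the
# Conv2D weights are multiplied by scale along their output-channel axis. Shapes are
# illustrative (3x3 kernel, 8 input channels, 16 output channels).
import numpy as np

channels = 16
weights = np.random.randn(3, 3, 8, channels).astype(np.float32)
gamma = np.random.rand(channels).astype(np.float32)
beta = np.random.rand(channels).astype(np.float32)
mean = np.random.rand(channels).astype(np.float32)
variance = np.random.rand(channels).astype(np.float32)
epsilon = 1e-3

scale = gamma / np.sqrt(variance + epsilon)
offset = beta - mean * scale
folded_weights = weights * scale  # broadcasts over the last (output-channel) axis
# For inference, Conv2D(x, folded_weights) + offset matches the original
# Conv2D -> BatchNorm output that the transformation above removes.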
def AddOp(self, op): """Create op in XLACompileContext and notifies outer context recursively.""" # pylint: disable=protected-access if op.type in _BLACKLISTED_OPS: logging.error( 'Operation of type %s (%s) is not supported in XLA. Execution will ' 'fail if this op is used in the graph. ', op.type, op.name) # TODO(ycao): Automatically disable summaries instead of reporting them. if op.type in _UNSUPPORTED_OPS: self._unsupported_ops.append(op) if any(x.dtype._is_ref_dtype for x in op.inputs): raise NotImplementedError( 'Non-resource Variables are not supported inside XLA computations ' '(operator name: %s)' % op.name) if _XLA_COMPILE_ATTR in op.node_def.attr: raise ValueError( 'XLA compiled computations cannot be nested, (operator ' 'name: %s)' % op.name) op._set_attr(_XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) # Remove any control edges from outer control flow contexts. These may cause # mismatched frame errors. An example is when one of op's inputs is # generated in a different While control flow context. (internal_control_inputs, external_control_inputs) = self._RemoveExternalControlEdges(op) if not op.inputs: # Add a control edge from the control pivot to this op. if not internal_control_inputs: # pylint: disable=protected-access op._add_control_input(self._pivot) # pylint: enable=protected-access else: for index in xrange(len(op.inputs)): x = op.inputs[index] real_x = self.AddValue(x) if real_x != x: op._update_input(index, real_x) # pylint: disable=protected-access if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. with ops.control_dependencies(None): self.Enter() external_control_inputs = [ array_ops.identity(x.outputs[0]).op for x in external_control_inputs if x.outputs ] self.Exit() # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access # Mark op's outputs as seen by this context and any outer contexts. output_names = [x.name for x in op.outputs] context = self while context is not None: # pylint: disable=protected-access context._values.update(output_names) context = context._outer_context # pylint: enable=protected-access if self._outer_context: self._outer_context.AddInnerOp(op)
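# A minimal sketch of the marker the context attaches to each op above: the compile
# attribute is just a bytes-valued AttrValue keyed by _XLA_COMPILE_ATTR. Only the
# proto construction is shown; the cluster name is illustrative.
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.util import compat

xla_marker = attr_value_pb2.AttrValue(s=compat.as_bytes("xla_cluster_0"))
# The context then calls op._set_attr(_XLA_COMPILE_ATTR, xla_marker) and marks the
# op as neither feedable nor fetchable, as in the method above.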
def strip_unused(input_graph_def, input_tensor_names, output_tensor_names, placeholder_type_enum): """Removes unused nodes from a GraphDef. Args: input_graph_def: A graph with nodes we want to prune. input_tensor_names: A list of the tensor names we use as inputs. output_tensor_names: A list of the output tensor names. placeholder_type_enum: The AttrValue enum for the placeholder data type, or a list that specifies one value per input node name. Returns: A `GraphDef` with all unnecessary ops removed, and a map from the old input names to the new input names. Raises: ValueError: If any element in `input_tensor_names` refers to an operation instead of a tensor. KeyError: If any element in `input_tensor_names` is not found in the graph. """ for name in input_tensor_names: if ":" not in name: raise ValueError("Input '%s' appears to refer to an Operation, " "not a Tensor." % name) old2new = {} # Here we replace the nodes we're going to override as inputs with # placeholders so that any unused nodes that are inputs to them are # automatically stripped out by extract_sub_graph(). not_found = {name for name in input_tensor_names} input_node_names = {name.split(":")[0] for name in input_tensor_names} output_node_names = list( {name.split(":")[0] for name in output_tensor_names}) inputs_replaced_graph_def = graph_pb2.GraphDef() for node in input_graph_def.node: if node.name not in input_node_names: for i in range(len(node.input)): if _append_port(node.input[i]) in input_tensor_names: old_name = _append_port(node.input[i]) not_found.remove(old_name) new_input_name = node.input[i].replace(":", "_") placeholder_node = node_def_pb2.NodeDef() placeholder_node.op = "Placeholder" placeholder_node.name = new_input_name if isinstance(placeholder_type_enum, list): input_node_index = input_tensor_names.index(old_name) placeholder_node.attr["dtype"].CopyFrom( attr_value_pb2.AttrValue( type=placeholder_type_enum[input_node_index])) else: placeholder_node.attr["dtype"].CopyFrom( attr_value_pb2.AttrValue( type=placeholder_type_enum)) if "_output_shapes" in node.attr: placeholder_node.attr["_output_shapes"].CopyFrom( node.attr["_output_shapes"]) node.input[i] = new_input_name old2new[old_name] = new_input_name + ":0" inputs_replaced_graph_def.node.extend([placeholder_node]) inputs_replaced_graph_def.node.extend([copy.deepcopy(node)]) if not_found: raise KeyError("The following input nodes were not found: %s\n" % not_found) output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def, output_node_names) return output_graph_def, old2new
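# A hedged usage sketch for strip_unused above, on a hand-built three-node graph:
# the input is given as a tensor name ("cast:0"), outputs may carry a port as well,
# and the placeholder dtype is passed as a DataType enum (or one enum per input).
# It assumes the helpers this module relies on (_append_port, graph_util, copy) are
# importable; the node names are illustrative.
from tensorflow.core.framework import graph_pb2, node_def_pb2, types_pb2

graph_def = graph_pb2.GraphDef()
for name, op, inputs in [("raw_input", "Placeholder", []),
                         ("cast", "Identity", ["raw_input"]),
                         ("output", "Identity", ["cast"])]:
    node = node_def_pb2.NodeDef()
    node.name = name
    node.op = op
    node.input.extend(inputs)
    graph_def.node.extend([node])

stripped_def, old2new = strip_unused(
    graph_def,
    input_tensor_names=["cast:0"],
    output_tensor_names=["output:0"],
    placeholder_type_enum=types_pb2.DT_FLOAT)
# "cast" is now a Placeholder feeding "output", and "raw_input" has been stripped.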
def convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None, use_fp16=False): from tensorflow.python.framework.graph_util_impl import extract_sub_graph from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python.framework import tensor_util def patch_dtype(input_node, field_name, output_node): if use_fp16 and (field_name in input_node.attr) and (input_node.attr[field_name].type == types_pb2.DT_FLOAT): output_node.attr[field_name].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_HALF)) inference_graph = extract_sub_graph(input_graph_def, output_node_names) variable_names = [] variable_dict_names = [] for node in inference_graph.node: if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name if ((variable_names_whitelist is not None and variable_name not in variable_names_whitelist) or (variable_names_blacklist is not None and variable_name in variable_names_blacklist)): continue variable_dict_names.append(variable_name) if node.op == "VarHandleOp": variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") if variable_names: returned_variables = sess.run(variable_names) else: returned_variables = [] found_variables = dict(zip(variable_dict_names, returned_variables)) output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 for input_node in inference_graph.node: output_node = node_def_pb2.NodeDef() if input_node.name in found_variables: output_node.op = "Const" output_node.name = input_node.name dtype = input_node.attr["dtype"] data = found_variables[input_node.name] if use_fp16 and dtype.type == types_pb2.DT_FLOAT: output_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto(data.astype('float16'), dtype=types_pb2.DT_HALF, shape=data.shape))) else: output_node.attr["dtype"].CopyFrom(dtype) output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 elif input_node.op == "ReadVariableOp" and (input_node.input[0] in found_variables): # placeholder nodes # print('- %s | %s ' % (input_node.name, input_node.attr["dtype"])) output_node.op = "Identity" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) else: # mostly op nodes output_node.CopyFrom(input_node) patch_dtype(input_node, 'dtype', output_node) patch_dtype(input_node, 'T', output_node) patch_dtype(input_node, 'DstT', output_node) patch_dtype(input_node, 'SrcT', output_node) patch_dtype(input_node, 'Tparams', output_node) if use_fp16 and ('value' in output_node.attr) and ( output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT): # hard-coded value need to be converted as well output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto( output_node.attr['value'].tensor.float_val[0], dtype=types_pb2.DT_HALF))) output_graph_def.node.extend([output_node]) output_graph_def.library.CopyFrom(inference_graph.library) return output_graph_def
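# A hedged usage sketch for the fp16-aware freezing variant above, assuming a
# v1-style session graph with resource variables (op "VarHandleOp"); with
# use_fp16=True the variable is baked into the result as a DT_HALF Const. The
# variable and output names are illustrative.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
with tf.compat.v1.Session() as sess:
    w = tf.compat.v1.get_variable("w", initializer=tf.ones([2, 2]))
    y = tf.identity(tf.matmul(w, w), name="output")
    sess.run(tf.compat.v1.global_variables_initializer())
    frozen_fp16 = convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ["output"], use_fp16=True)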
def set_attr_dtype(node, key, value): """Set the dtype-valued attr `key` on `node`; errors from missing keys are ignored.""" try: node.attr[key].CopyFrom( attr_value_pb2.AttrValue(type=value.as_datatype_enum)) except KeyError: pass
def set_attr_dtype(node, key, value): """Set the attribute data type.""" node.attr[key].CopyFrom(attr_value_pb2.AttrValue(type=value.as_datatype_enum))
def set_attr_int_list(node, key, value): """Set the int-list attr `key` on `node`; errors from missing keys are ignored.""" list_value = attr_value_pb2.AttrValue.ListValue(i=value) try: node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value)) except KeyError: pass
def set_attr_bool(node, key, value): """Set a bool-valued attr `key` on `node`.""" node.attr[key].CopyFrom(attr_value_pb2.AttrValue(b=value))
def set_attr_float(node, key, value): """Set a float-valued attr `key` on `node`; errors from missing keys are ignored.""" try: node.attr[key].CopyFrom(attr_value_pb2.AttrValue(f=value)) except KeyError: pass
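# A small sketch showing the helpers above used together to populate a NodeDef by
# hand (e.g. when synthesizing a fused op). The op and attr names here are
# illustrative, not a real TensorFlow op definition.
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import dtypes

node = node_def_pb2.NodeDef()
node.op = "ExampleFusedOp"
node.name = "example_fused_op"
set_attr_dtype(node, "T", dtypes.float32)
set_attr_int_list(node, "strides", [1, 2, 2, 1])
set_attr_bool(node, "fused", True)
set_attr_float(node, "epsilon", 0.001)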
def AddOp(self, op): # pylint: disable=protected-access if op.type in _BLACKLISTED_OPS: logging.error( "Operation of type %s (%s) is not supported on the TPU. " "Execution will fail if this op is used in the graph. " % (op.type, op.name)) if op.type in _UNSUPPORTED_OPS: self._unsupported_ops.append(op) if any(x.dtype._is_ref_dtype for x in op.inputs): raise NotImplementedError( "Non-resource Variables are not supported inside TPU computations " "(operator name: %s)" % op.name) if _TPU_REPLICATE_ATTR in op.node_def.attr: raise ValueError("TPU computations cannot be nested") op._set_attr(_TPU_REPLICATE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) if self._outside_compilation_cluster: op._set_attr( _OUTSIDE_COMPILATION_ATTR, attr_value_pb2.AttrValue( s=compat.as_bytes(self._outside_compilation_cluster))) if self._num_replicas > 1 or not self._outside_compilation_cluster: # Prevent feeding or fetching anything that is being compiled, # and any replicated outside_compilation Op. op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) # Remove any control edges from outer control flow contexts. These may cause # mismatched frame errors. (internal_control_inputs, external_control_inputs) = self._RemoveExternalControlEdges(op) if not op.inputs: # Add a control edge from the control pivot to this op. if not internal_control_inputs: # pylint: disable=protected-access op._add_control_input(self.GetControlPivot()) # pylint: enable=protected-access else: for index in xrange(len(op.inputs)): x = op.inputs[index] real_x = self.AddValue(x) if real_x != x: op._update_input(index, real_x) # pylint: disable=protected-access if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. with ops.control_dependencies(None): self.Enter() external_control_inputs = [ array_ops.identity(x.outputs[0]).op for x in external_control_inputs if x.outputs ] self.Exit() # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access # Mark op's outputs as seen by this context and any outer contexts. output_names = [x.name for x in op.outputs] context = self while context is not None: # pylint: disable=protected-access context._values.update(output_names) context = context._outer_context # pylint: enable=protected-access if self._outer_context: self._outer_context.AddInnerOp(op)
def _init_from_args( self, initial_value=None, trainable=None, collections=None, caching_device=None, name=None, dtype=None, constraint=None, synchronization=None, aggregation=None, distribute_strategy=None, shape=None, ): """Creates a variable. Args: initial_value: A `Tensor`, or Python object convertible to a `Tensor`, which is the initial value for the Variable. The initial value must have a shape specified unless `validate_shape` is set to False. Can also be a callable with no argument that returns the initial value when called. (Note that initializer functions from init_ops.py must first be bound to a shape before being used here.) trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. Defaults to `True`, unless `synchronization` is set to `ON_READ`, in which case it defaults to `False`. collections: List of graph collections keys. The new variable is added to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. caching_device: Optional device string or function describing where the Variable should be cached for reading. Defaults to the Variable's device. If not `None`, caches on another device. Typical use is to cache on the device where the Ops using the Variable reside, to deduplicate copying through `Switch` and other conditional statements. name: Optional name for the variable. Defaults to `'Variable'` and gets uniquified automatically. dtype: If set, initial_value will be converted to the given type. If None, either the datatype will be kept (if initial_value is a Tensor) or float32 will be used (if it is a Python object convertible to a Tensor). constraint: An optional projection function to be applied to the variable after being updated by an `Optimizer` (e.g. used to implement norm constraints or value constraints for layer weights). The function must take as input the unprojected Tensor representing the value of the variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. synchronization: Indicates when a distributed a variable will be aggregated. Accepted values are constants defined in the class `tf.VariableSynchronization`. By default the synchronization is set to `AUTO` and the current `DistributionStrategy` chooses when to synchronize. aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class `tf.VariableAggregation`. distribute_strategy: DistributionStrategy under which this variable was created. shape: (optional) The shape of this variable. If None, the shape of `initial_value` will be used. When setting this argument to `tf.TensorShape(None)` (representing an unspecified shape), the variable can be assigned with values of different shapes. Raises: ValueError: If the initial value is not specified, or does not have a shape and `validate_shape` is `True`. @compatibility(eager) When Eager Execution is enabled, variables are never added to collections. It is not implicitly added to the `GLOBAL_VARIABLES` or `TRAINABLE_VARIABLES` collections, and the `collections` argument is ignored. 
@end_compatibility """ ( synchronization, aggregation, trainable, ) = variables.validate_synchronization_aggregation_trainable( synchronization, aggregation, trainable, name) if initial_value is None: raise ValueError("initial_value must be specified.") init_from_fn = callable(initial_value) if (isinstance(initial_value, ops.Tensor) and hasattr(initial_value, "graph") and initial_value.graph.building_function): raise ValueError( "Tensor-typed variable initializers must either be " "wrapped in an init_scope or callable " "(e.g., `tf.Variable(lambda : " "tf.truncated_normal([10, 40]))`) when building " "functions. Please file a feature request if this " "restriction inconveniences you.") if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] if not isinstance(collections, (list, tuple, set)): raise ValueError( "collections argument to Variable constructor must be a list, tuple, " "or set. Got %s of type %s" % (collections, type(collections))) if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") if isinstance(initial_value, trackable.CheckpointInitialValue): self._maybe_initialize_trackable() self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ ops.GraphKeys.TRAINABLE_VARIABLES ] with ops.init_scope(): self._in_graph_mode = not context.executing_eagerly() with ops.name_scope( name, "TrainableWrapper", [] if init_from_fn else [initial_value]) as name: # pylint: disable=protected-access handle_name = ops.name_from_scope_name(name) handle_name = handle_name or "TrainableWrapperHandle" if self._in_graph_mode: shared_name = handle_name unique_id = shared_name else: # When in eager mode use a uid for the shared_name, to prevent # accidental sharing. unique_id = "%s_%d" % (handle_name, ops.uid()) shared_name = None # Use attr_scope and device(None) to simulate the behavior of # colocate_with when the variable we want to colocate with doesn't # yet exist. device_context_manager = (ops.device if self._in_graph_mode else ops.NullContextmanager) attr = attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue( s=[compat.as_bytes("loc:@%s" % handle_name)])) with ops.get_default_graph()._attr_scope({"_class": attr}): with ops.name_scope("Initializer"), device_context_manager( None): initial_value = ops.convert_to_tensor( initial_value() if init_from_fn else initial_value, name="initial_value", dtype=dtype, ) if shape is None: shape = initial_value.shape handle = resource_variable_ops.eager_safe_variable_handle( initial_value=initial_value, shape=None, # shape, shared_name=shared_name, name=name, graph_mode=self._in_graph_mode, ) # pylint: disable=protected-access if (self._in_graph_mode and initial_value is not None and initial_value.op._get_control_flow_context() is not None): raise ValueError( "Initializer for variable %s is from inside a control-flow " "construct, such as a loop or conditional. When creating a " "variable inside a loop or conditional, use a lambda as the " "initializer." % name) # pylint: enable=protected-access dtype = initial_value.dtype.base_dtype if self._in_graph_mode: with ops.name_scope("IsInitialized"): is_initialized_op = (gen_resource_variable_ops. 
var_is_initialized_op(handle)) if initial_value is not None: # pylint: disable=g-backslash-continuation with ops.name_scope("Assign") as n, ops.colocate_with( None, ignore_existing=True), ops.device( handle.device): # pylint: disable=protected-access initializer_op = gen_resource_variable_ops.assign_variable_op( handle, variables. _try_guard_against_uninitialized_dependencies( name, initial_value), name=n, ) # pylint: enable=protected-access # pylint: enable=g-backslash-continuation with ops.name_scope("Read"): # Manually assign reads to the handle's device to avoid log # messages. with ops.device(handle.device): with ops.control_dependencies([ gen_resource_variable_ops. assign_variable_op( handle, self.prefetch_values(), name="AssignBeforeInitRead", ) ]): value = gen_resource_variable_ops.read_variable_op( handle, dtype) graph_element = value if caching_device is not None: # Variables may be created in a tf.device() or ops.colocate_with() # context. At the same time, users would expect caching device to # be independent of this context, and/or would not expect the # current device context to be merged with the caching device # spec. Therefore we reset the colocation stack before creating # the cached value. Note that resetting the colocation stack will # also reset the device stack. with ops.colocate_with(None, ignore_existing=True): with ops.device(caching_device): cached_value = array_ops.identity(value) else: cached_value = None else: gen_resource_variable_ops.assign_variable_op( handle, initial_value) is_initialized_op = None initializer_op = None graph_element = None if caching_device: with ops.device(caching_device): with ops.control_dependencies([ gen_resource_variable_ops. assign_variable_op( handle, self.prefetch_values(), name="AssignBeforeInitRead", ) ]): cached_value = ( gen_resource_variable_ops.read_variable_op( handle, dtype)) else: cached_value = None if not context.executing_eagerly(): # Eager variables are only added to collections if they are part of an # eager variable store (otherwise in an interactive session they would # hog memory and cause OOM). This is done in ops/variable_scope.py. ops.add_to_collections(collections, self) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) initial_value = initial_value if self._in_graph_mode else None super(resource_variable_ops.ResourceVariable, self).__init__( trainable=trainable, shape=shape, dtype=dtype, handle=handle, synchronization=synchronization, constraint=constraint, aggregation=aggregation, distribute_strategy=distribute_strategy, name=name, unique_id=unique_id, handle_name=handle_name, graph_element=graph_element, initial_value=initial_value, initializer_op=initializer_op, is_initialized_op=is_initialized_op, cached_value=cached_value, )
def generate_output_graph(input_graph_def, input_node_map, output_node_map, fuse_op_list, fuse_op_deq_list): output_graph_def = graph_pb2.GraphDef() skip_list = [] skip_node_name = [] int8_type = dtypes.qint8.as_datatype_enum uint8_type = dtypes.quint8.as_datatype_enum float32_type = dtypes.float32.as_datatype_enum qint32_type = dtypes.qint32.as_datatype_enum for index, node in enumerate(input_graph_def.node): if index in fuse_op_list: const_node_1 = input_graph_def.node[index + 1] const_node_2 = input_graph_def.node[index + 2] requantize_node = input_graph_def.node[index + 3] new_node = node_def_pb2.NodeDef() new_node.op = node.op + "AndRequantize" new_node.name = requantize_node.name for _, value in enumerate(node.input): new_node.input.append(value) new_node.input.append(const_node_1.name) new_node.input.append(const_node_2.name) new_node.attr["Tinput"].CopyFrom(node.attr['Tinput']) new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter']) new_node.attr["strides"].CopyFrom(node.attr['strides']) new_node.attr["padding"].CopyFrom(node.attr['padding']) if input_node_map[new_node.input[0]].op.find("Requantize") != -1: bias_node = input_node_map[new_node.input[2]] last_node = input_node_map[new_node.input[0]] max_input_node = (input_node_map[last_node.input[4][:-2]]) min_input_node = (input_node_map[last_node.input[3][:-2]]) max_filter = input_node_map[new_node.input[6]] min_filter = input_node_map[new_node.input[5]] min_input = (min_input_node.attr['value'].tensor.float_val)[0] max_input = (max_input_node.attr['value'].tensor.float_val)[0] if 'Depthwise' in node.op or "RequantizePerChannel" in [ node.op for node in output_node_map[node.name] ]: channel_size = max_filter.attr[ 'value'].tensor.tensor_shape.dim[0].size max_filter_tensor = tensor_util.MakeNdarray( max_filter.attr['value'].tensor) min_filter_tensor = tensor_util.MakeNdarray( min_filter.attr['value'].tensor) else: channel_size = 1 max_filter_tensor = [] min_filter_tensor = [] max_filter_tensor.append( (max_filter.attr['value'].tensor.float_val)[0]) min_filter_tensor.append( (min_filter.attr['value'].tensor.float_val)[0]) bias_tensor = tensor_util.MakeNdarray( input_node_map[new_node.input[2]].attr['value'].tensor) bias_length = bias_tensor.shape[0] scales = [] for i in range(channel_size): scales.append(255.0 * 127.0 / (max(abs(max_input), abs(min_input)) * max(abs(max_filter_tensor[i]), abs(min_filter_tensor[i])))) int32_bias = [] if channel_size > 1: for i in range(bias_length): int32_bias.append((int)(bias_tensor[i] * scales[i])) else: for i in range(bias_length): int32_bias.append((int)(bias_tensor[i] * scales[0])) bias_node.attr['dtype'].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) bias_node.attr['value'].CopyFrom( attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto( int32_bias, dtypes.int32, bias_tensor.shape))) bias_node.attr['value'].tensor.dtype = qint32_type skip_node_name.append(bias_node.name) output_graph_def.node.extend([bias_node]) new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) else: new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) if "padding_list" in node.attr: new_node.attr["padding_list"].CopyFrom( node.attr['padding_list']) if "dilations" in node.attr: new_node.attr["dilations"].CopyFrom(node.attr['dilations']) if node.op == "QuantizedConv2D" or node.op == "QuantizedConv2DWithBias": new_node.attr["out_type"].CopyFrom( attr_value_pb2.AttrValue(type=int8_type)) else: new_node.attr["out_type"].CopyFrom( 
attr_value_pb2.AttrValue(type=uint8_type)) skip_list.append(index + 1) skip_list.append(index + 2) skip_list.append(index + 3) output_graph_def.node.extend( [new_node, const_node_1, const_node_2]) elif index in skip_list or node.name in skip_node_name: continue elif node.op == "Dequantize": new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) new_node.attr["mode"].s = b"SCALED" p_node = input_node_map[new_node.input[0]] pp_node = input_node_map[p_node.name].input[0] if input_node_map[pp_node].op.find("Relu") != -1 or p_node.op in ( "QuantizedAvgPool", "QuantizedMaxPool", "QuantizedConcatV2"): new_node.attr["T"].CopyFrom( attr_value_pb2.AttrValue(type=uint8_type)) elif input_node_map[pp_node].op.find( "QuantizedMatMulWithBias") != -1 and p_node.op.find( "Requantize") != -1: new_node.attr["mode"].s = node.attr["mode"].s new_node.attr["T"].CopyFrom( attr_value_pb2.AttrValue(type=node.attr["T"].type)) else: new_node.attr["T"].CopyFrom( attr_value_pb2.AttrValue(type=int8_type)) output_graph_def.node.extend([new_node]) elif index in fuse_op_deq_list: original_summand_node = input_node_map[ input_graph_def.node[index].input[-1]] sum_const_node_1 = input_graph_def.node[index + 1] sum_const_node_2 = input_graph_def.node[index + 2] sum_requantize_node = input_graph_def.node[index + 3] new_node = node_def_pb2.NodeDef() new_node.op = node.op + "AndRequantize" new_node.name = sum_requantize_node.name for _, value in enumerate(node.input[:-1]): new_node.input.append(value) new_node.input.append(sum_const_node_1.name) new_node.input.append(sum_const_node_2.name) new_node.input.append( input_node_map[original_summand_node.name].input[0]) new_node.input.append( input_node_map[original_summand_node.name].input[0] + ":1") new_node.input.append( input_node_map[original_summand_node.name].input[0] + ":2") # skip_list.append(index + 1) # skip_list.append(index + 2) skip_list.append(index + 3) new_node.attr["Tinput"].CopyFrom(node.attr['Tinput']) new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter']) new_node.attr["strides"].CopyFrom(node.attr['strides']) new_node.attr["padding"].CopyFrom(node.attr['padding']) if input_node_map[new_node.input[0]].op.find("Requantize") != -1: bias_node = input_node_map[new_node.input[2]] last_node = input_node_map[new_node.input[0]] max_input_node = (input_node_map[last_node.input[4][:-2]]) min_input_node = (input_node_map[last_node.input[3][:-2]]) max_filter = input_node_map[new_node.input[6]] min_filter = input_node_map[new_node.input[5]] min_input = (min_input_node.attr['value'].tensor.float_val)[0] max_input = (max_input_node.attr['value'].tensor.float_val)[0] if "RequantizePerChannel" in [ node.op for node in output_node_map[node.name] ]: channel_size = max_filter.attr[ 'value'].tensor.tensor_shape.dim[0].size max_filter_tensor = tensor_util.MakeNdarray( max_filter.attr['value'].tensor) min_filter_tensor = tensor_util.MakeNdarray( min_filter.attr['value'].tensor) else: channel_size = 1 max_filter_tensor = [] min_filter_tensor = [] max_filter_tensor.append( (max_filter.attr['value'].tensor.float_val)[0]) min_filter_tensor.append( (min_filter.attr['value'].tensor.float_val)[0]) bias_tensor = (tensor_util.MakeNdarray( input_node_map[new_node.input[2]].attr['value'].tensor)) bias_length = bias_tensor.shape[0] scales = [] for i in range(channel_size): scales.append(255.0 * 127.0 / (max(abs(max_input), abs(min_input)) * max(abs(max_filter_tensor[i]), abs(min_filter_tensor[i])))) int32_bias = [] if channel_size > 1: for i in range(bias_length): 
int32_bias.append(int(bias_tensor[i] * scales[i])) else: for i in range(bias_length): int32_bias.append(int(bias_tensor[i] * scales[0])) bias_node.attr['dtype'].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) bias_node.attr['value'].CopyFrom( attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto( int32_bias, dtypes.int32, bias_tensor.shape))) bias_node.attr['value'].tensor.dtype = qint32_type new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) skip_node_name.append(bias_node.name) output_graph_def.node.extend([bias_node]) else: new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) if "padding_list" in node.attr: new_node.attr["padding_list"].CopyFrom( node.attr['padding_list']) if "dilations" in node.attr: new_node.attr["dilations"].CopyFrom(node.attr['dilations']) new_node.attr["out_type"].CopyFrom( attr_value_pb2.AttrValue(type=uint8_type)) summand_op_type = uint8_type if dtypes.as_dtype( original_summand_node.attr["T"].type ) == uint8_type else int8_type if summand_op_type == int8_type: new_node.op = "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" new_node.attr["Tsummand"].CopyFrom( attr_value_pb2.AttrValue(type=summand_op_type)) output_graph_def.node.extend([new_node]) else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) output_graph_def.node.extend([new_node]) return output_graph_def
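# A numpy sketch of the bias requantization arithmetic used in both branches above:
# each channel's float32 bias is rescaled to qint32 with
# scale_c = (255 * 127) / (max(|min_input|, |max_input|) * max(|min_filter_c|, |max_filter_c|)).
# The ranges and bias values below are illustrative.
import numpy as np

min_input, max_input = -2.0, 6.0
min_filter = np.array([-0.5, -0.25, -1.0], dtype=np.float32)  # per output channel
max_filter = np.array([0.75, 0.5, 0.9], dtype=np.float32)
bias = np.array([0.1, -0.2, 0.05], dtype=np.float32)

input_range = max(abs(min_input), abs(max_input))
filter_range = np.maximum(np.abs(min_filter), np.abs(max_filter))
scales = (255.0 * 127.0) / (input_range * filter_range)
int32_bias = (bias * scales).astype(np.int32)  # truncation matches int() above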
def apply_op(self, op_type_name, name=None, **keywords): # pylint: disable=g-doc-args """Add a node invoking a registered Op to a graph. Example usage: # input1 and input2 can be Tensors or anything ops.convert_to_tensor() # will convert to a Tensor. op_def_library.apply_op("op", input1=input1, input2=input2) # Can specify a node name. op_def_library.apply_op("op", input1=input1, name="node_name") # Must use keyword arguments, with the names specified in the OpDef. op_def_library.apply_op("op", input_name=input, attr_name=attr) All attrs must either be inferred from an input or specified. (If inferred, the attr must not be specified.) If an attr has a default value specified in the Op's OpDef, then you may pass None as the value of that attr to get the default. Args: op_type_name: string. Must match the name field of a registered Op. name: string. Optional name of the created op. **keywords: input Tensor and attr arguments specified by name, and optional parameters to pass when constructing the Operation. Returns: The Tensor(s) representing the output of the operation, or the Operation itself if there are no outputs. Raises: RuntimeError: On some errors. TypeError: On some errors. ValueError: On some errors. """ op_info = self._ops.get(op_type_name, None) if op_info is None: raise RuntimeError("Unrecognized Op name " + op_type_name) op_def = op_info.op_def # Determine the graph context. try: # Need to flatten all the arguments into a list. # pylint: disable=protected-access g = ops._get_graph_from_inputs(_Flatten(keywords.values())) # pyline: enable=protected-access except AssertionError as e: raise RuntimeError( "Cannot determine graph for Op '%s' due to: %s" % (op_type_name, e.message)) # Default name if not specified. if name is None: name = op_type_name # Check for deprecation deprecation_version = op_def.deprecation.version if deprecation_version: producer = g.graph_def_versions.producer if producer >= deprecation_version: raise NotImplementedError( ("Op %s is not available in GraphDef version %d. " "It has been removed in version %d. %s.") % (op_type_name, producer, deprecation_version, op_def.deprecation.explanation)) # Fill in the list of default types for all "type" attrs. This # will be used to choose a preferred dtype to convert to in the # absence of input type information. # # TODO(b/31302892): Currently the defaults don't work in the right # way if you have two inputs, one of whose type resolution depends # on the other. Handling this will require restructuring this code # significantly. default_type_attr_map = {} for attr_def in op_def.attr: if attr_def.type != "type": continue key = attr_def.name if attr_def.HasField("default_value"): default_type_attr_map[key] = dtypes.as_dtype( attr_def.default_value.type) # Requires that op_def has passed validation (using the C++ # ValidateOpDef() from ../framework/op_def_util.h). attrs = {} inputs = [] input_types = [] with g.as_default(), ops.name_scope(name) as scope: # Perform input type inference inferred_from = {} for input_arg in op_def.input_arg: input_name = input_arg.name if input_name in keywords: values = keywords.pop(input_name) elif input_name + "_" in keywords: # Handle the case where the name is a keyword or built-in # for Python so we use the name + _ instead. input_name += "_" values = keywords.pop(input_name) else: raise TypeError("No argument for input " + input_name) # Goals: # * Convert values to Tensors if it contains constants. # * Verify that values is a list if that matches the input_arg's # type. 
# * If the input_arg's type is determined by attrs, either set # those attrs and validate those attr values are legal (if # they have not yet been set) or validate the input matches # the type indicated by the attrs (if they have already been # inferred via an earlier input). # * If the input_arg has an explicit type, make sure the input # conforms. if _IsListParameter(input_arg): if not _IsListValue(values): raise TypeError( "Expected list for '%s' argument to '%s' Op, not %s." % (input_name, op_type_name, values)) # In cases where we expect all elements of the list to have the # same dtype, try to cast non-Tensor elements to that type. dtype = None default_dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.number_attr: if input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] else: for t in values: if isinstance(t, ops.Tensor): dtype = t.dtype break # dtype still not found, prefer using the default dtype # from the attr. if dtype is None and input_arg.type_attr in default_type_attr_map: default_dtype = default_type_attr_map[ input_arg.type_attr] try: if not input_arg.is_ref and dtype: dtype = dtypes.as_dtype(dtype).base_dtype values = ops.convert_n_to_tensor( values, name=input_arg.name, dtype=dtype if dtype else None, preferred_dtype=default_dtype, as_ref=input_arg.is_ref) if input_arg.number_attr and len( set(v.dtype.base_dtype for v in values)) > 1: raise TypeError() # All types should match. except (TypeError, ValueError): # What types does the conversion function think values have? observed_types = [] for value in values: try: converted_value = ops.convert_to_tensor( value, as_ref=input_arg.is_ref) observed_types.append( converted_value.dtype.base_dtype.name) except (TypeError, ValueError): observed_types.append( "<NOT CONVERTIBLE TO TENSOR>") observed = ", ".join(observed_types) prefix = ( "Tensors in list passed to '%s' of '%s' Op have types [%s]" % (input_name, op_type_name, observed)) if input_arg.number_attr: if input_arg.type != types_pb2.DT_INVALID: raise TypeError( "%s that do not match expected type %s." % (prefix, dtype.name)) elif input_arg.type_attr in attrs: raise TypeError( "%s that do not match type %s inferred from " "earlier arguments." % (prefix, dtype.name)) else: raise TypeError("%s that don't all match." % prefix) else: raise TypeError("%s that are invalid." % prefix) types = [x.dtype for x in values] inputs.extend(values) else: # In cases where we have an expected type, try to convert non-Tensor # arguments to that type. dtype = None default_dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] elif input_arg.type_attr in default_type_attr_map: # The dtype could not be inferred solely from the inputs, # so we prefer the attr's default, so code that adds a new attr # with a default is backwards compatible. default_dtype = default_type_attr_map[ input_arg.type_attr] try: values = ops.convert_to_tensor( values, name=input_arg.name, dtype=dtype, as_ref=input_arg.is_ref, preferred_dtype=default_dtype) except TypeError as err: if dtype is None: raise err else: raise TypeError( "Expected %s passed to parameter '%s' of op '%s', got %s of " "type '%s' instead." % (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name, repr(values), type(values).__name__)) except ValueError: # What type does convert_to_tensor think it has? 
observed = ops.convert_to_tensor( values, as_ref=input_arg.is_ref).dtype.name prefix = ( "Input '%s' of '%s' Op has type %s that does not match" % (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError( "%s expected type of %s." % (prefix, dtypes.as_dtype(input_arg.type).name)) else: # Update the maps with the default, if needed. k = input_arg.type_attr if k in default_type_attr_map: if k not in attrs: attrs[k] = default_type_attr_map[k] if k not in inferred_from: inferred_from[k] = "Default in OpDef" raise TypeError( "%s type %s of argument '%s'." % (prefix, dtypes.as_dtype( attrs[input_arg.type_attr]).name, inferred_from[input_arg.type_attr])) types = [values.dtype] inputs.append(values) base_types = [x.base_dtype for x in types] if input_arg.number_attr: # <number-attr> * <type> or <number-attr> * <type-attr> if input_arg.number_attr in attrs: if len(values) != attrs[input_arg.number_attr]: raise ValueError( "List argument '%s' to '%s' Op with length %d must match " "length %d of argument '%s'." % (input_name, op_type_name, len(values), attrs[input_arg.number_attr], inferred_from[input_arg.number_attr])) else: attrs[input_arg.number_attr] = len(values) inferred_from[input_arg.number_attr] = input_name num_attr = _Attr(op_def, input_arg.number_attr) if num_attr.has_minimum and len( values) < num_attr.minimum: raise ValueError( "List argument '%s' to '%s' Op with length %d shorter " "than minimum length %d." % (input_name, op_type_name, len(values), num_attr.minimum)) # All tensors must have the same base type. if any([bt != base_types[0] for bt in base_types]): raise TypeError( "All tensors passed to '%s' of '%s' Op " "must have the same type." % (input_name, op_type_name)) if input_arg.type != types_pb2.DT_INVALID: # <number-attr> * <type> case if base_types and base_types[0] != input_arg.type: assert False, "Unreachable" elif input_arg.type_attr in attrs: # <number-attr> * <type-attr> case, where <type-attr> already # has an inferred value. if base_types and base_types[0] != attrs[ input_arg.type_attr]: assert False, "Unreachable" else: # <number-attr> * <type-attr> case, where we are now setting # the <type-attr> based on this input if not base_types: raise TypeError( "Don't know how to infer type variable from empty input " "list passed to input '%s' of '%s' Op." % (input_name, op_type_name)) attrs[input_arg.type_attr] = base_types[0] inferred_from[input_arg.type_attr] = input_name type_attr = _Attr(op_def, input_arg.type_attr) _SatisfiesTypeConstraint(base_types[0], type_attr) elif input_arg.type_attr: # <type-attr> attr_value = base_types[0] if input_arg.type_attr in attrs: if attrs[input_arg.type_attr] != attr_value: assert False, "Unreachable" else: for base_type in base_types: _SatisfiesTypeConstraint( base_type, _Attr(op_def, input_arg.type_attr)) attrs[input_arg.type_attr] = attr_value inferred_from[input_arg.type_attr] = input_name elif input_arg.type_list_attr: # <type-list-attr> attr_value = base_types if input_arg.type_list_attr in attrs: if attrs[input_arg.type_list_attr] != attr_value: raise TypeError( "Input '%s' of '%s' Op has type list of %s that does not " "match type list %s of argument '%s'." 
% (input_name, op_type_name, ", ".join( dtypes.as_dtype(x).name for x in attr_value), ", ".join( dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr]), inferred_from[input_arg.type_list_attr])) else: for base_type in base_types: _SatisfiesTypeConstraint( base_type, _Attr(op_def, input_arg.type_list_attr)) attrs[input_arg.type_list_attr] = attr_value inferred_from[input_arg.type_list_attr] = input_name else: # single Tensor with specified type if base_types[0] != input_arg.type: assert False, "Unreachable" if input_arg.is_ref: if not all(x.is_ref_dtype for x in types): raise TypeError( "Input '%s' of '%s' Op requires l-value input" % (input_name, op_type_name)) input_types.extend(types) else: input_types.extend(base_types) # Process remaining attrs for attr in op_def.attr: # Skip attrs that have already had their values inferred if attr.name in attrs: if attr.name in keywords: raise TypeError( "Should not specify value for inferred attr '%s'." % attr.name) continue if attr.name in keywords: attrs[attr.name] = keywords.pop(attr.name) elif attr.name + "_" in keywords: # Attrs whose names match Python keywords have an extra '_' # appended, so we must check for that as well. attrs[attr.name] = keywords.pop(attr.name + "_") else: raise TypeError("No argument for attr " + attr.name) # Convert attr values to AttrValue protos. attr_protos = {} for attr_def in op_def.attr: key = attr_def.name value = attrs[key] attr_value = attr_value_pb2.AttrValue() if attr_def.HasField("default_value") and value is None: attr_value.CopyFrom(attr_def.default_value) attr_protos[key] = attr_value continue if attr_def.type.startswith("list("): if not _IsListValue(value): raise TypeError("Expected list for attr " + key) if attr_def.has_minimum: if len(value) < attr_def.minimum: raise ValueError( "Attr '%s' of '%s' Op passed list of length %d " "less than minimum %d." % (key, op_type_name, len(value), attr_def.minimum)) attr_value.list.SetInParent() if attr_def.type == "string": attr_value.s = _MakeStr(value, key) if attr_def.HasField("allowed_values"): if attr_value.s not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % (key, op_type_name, compat.as_text(attr_value.s), '", "'.join( map(compat.as_text, attr_def.allowed_values.list.s)))) elif attr_def.type == "list(string)": attr_value.list.s.extend([_MakeStr(x, key) for x in value]) if attr_def.HasField("allowed_values"): for x in attr_value.list.s: if x not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % (key, op_type_name, compat.as_text(x), '", "'.join( map(compat.as_text, attr_def.allowed_values.list.s)))) elif attr_def.type == "int": attr_value.i = _MakeInt(value, key) if attr_def.has_minimum: if attr_value.i < attr_def.minimum: raise ValueError( "Attr '%s' of '%s' Op passed %d less than minimum %d." 
% (key, op_type_name, attr_value.i, attr_def.minimum)) elif attr_def.type == "list(int)": attr_value.list.i.extend([_MakeInt(x, key) for x in value]) elif attr_def.type == "float": attr_value.f = _MakeFloat(value, key) elif attr_def.type == "list(float)": attr_value.list.f.extend( [_MakeFloat(x, key) for x in value]) elif attr_def.type == "bool": attr_value.b = _MakeBool(value, key) elif attr_def.type == "list(bool)": attr_value.list.b.extend( [_MakeBool(x, key) for x in value]) elif attr_def.type == "type": attr_value.type = _MakeType(value, attr_def) elif attr_def.type == "list(type)": attr_value.list.type.extend( [_MakeType(x, attr_def) for x in value]) elif attr_def.type == "shape": attr_value.shape.CopyFrom(_MakeShape(value, key)) elif attr_def.type == "list(shape)": attr_value.list.shape.extend( [_MakeShape(x, key) for x in value]) elif attr_def.type == "tensor": attr_value.tensor.CopyFrom(_MakeTensor(value, key)) elif attr_def.type == "list(tensor)": attr_value.list.tensor.extend( [_MakeTensor(x, key) for x in value]) elif attr_def.type == "func": if isinstance(value, attr_value_pb2.NameAttrList): attr_value.func.CopyFrom(value) elif isinstance(value, compat.bytes_or_text_types): attr_value.func.name = value else: value.add_to_graph(ops.get_default_graph()) attr_value.func.name = value.name else: raise TypeError("Unrecognized Attr type " + attr_def.type) attr_protos[key] = attr_value del attrs # attrs is no longer authoritative, use attr_protos instead # Determine output types (possibly using attrs) output_types = [] output_structure = [] for arg in op_def.output_arg: types = [] if arg.number_attr: n = _AttrValue(attr_protos, arg.number_attr).i if arg.type_attr: types = [_AttrValue(attr_protos, arg.type_attr).type ] * n else: types = [arg.type] * n output_structure.append(n) elif arg.type_attr: t = _AttrValue(attr_protos, arg.type_attr) types = [t.type] output_structure.append(None) elif arg.type_list_attr: t = _AttrValue(attr_protos, arg.type_list_attr) types = t.list.type output_structure.append(len(types)) else: types = [arg.type] output_structure.append(None) if arg.is_ref: types = [dtypes.as_dtype(x).as_ref for x in types] output_types.extend(types) if keywords: raise TypeError( "apply_op() got unexpected keyword arguments: " + ", ".join(sorted(keywords.keys()))) # NOTE(mrry): We add an explicit colocation constraint between # the newly created op and any of its reference-typed inputs. must_colocate_inputs = [ val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref ] with _MaybeColocateWith(must_colocate_inputs): # Add Op to graph op = g.create_op(op_type_name, inputs, output_types, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) if output_structure: outputs = op.outputs res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure) if isinstance(res, list) and not res and op_def.is_stateful: return op else: return res else: return op
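For context, here is a hedged sketch of how a generated op wrapper typically calls `apply_op`; the `concat` signature and the module-level `_op_def_lib` OpDefLibrary instance are illustrative assumptions rather than definitions from this file. Because the wrapper passes no attrs, the `N` and `T` attrs of a `Concat`-style op are inferred by the input-inference loop above (`N` from `len(values)`, `T` from the tensors' common base dtype), and the single typed output arg means `output_structure` holds one `None` entry, so the caller gets back a single Tensor.

```python
# Hypothetical generated wrapper; _op_def_lib is assumed to be an
# OpDefLibrary that has the "Concat" OpDef registered.
def concat(concat_dim, values, name=None):
  # N and T are deliberately omitted: apply_op infers them from `values`.
  return _op_def_lib.apply_op("Concat",
                              concat_dim=concat_dim,
                              values=values,
                              name=name)

# Usage sketch: both tensors must share a dtype, or apply_op raises the
# "must have the same type" TypeError from the list-input branch above.
# merged = concat(0, [part_a, part_b], name="merged")
```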