def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
  """Creates a constant tensor.

  The resulting tensor is populated with values of type `dtype`, as
  specified by arguments `value` and (optionally) `shape` (see examples
  below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the `shape` argument (if
  specified). In the case where the list length is less than the number of
  elements specified by `shape`, the last element in the list will be used
  to fill the remaining entries.

  The argument `shape` is optional. If present, it specifies the dimensions of
  the resulting tensor. If not present, the shape of `value` is used.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  `tf.constant` differs from `tf.fill` in a few ways:

  *   `tf.constant` supports arbitrary constants, not just uniform scalar
      Tensors like `tf.fill`.
  *   `tf.constant` creates a `Const` node in the computation graph with the
      exact value at graph construction time. On the other hand, `tf.fill`
      creates an Op in the graph that is expanded at runtime.
  *   Because `tf.constant` only embeds constant values in the graph, it does
      not support dynamic shapes based on other runtime Tensors, whereas
      `tf.fill` does.

  Args:
    value: A constant value (or list) of output type `dtype`.
    dtype: The type of the elements of the resulting tensor.
    shape: Optional dimensions of the resulting tensor.
    name: Optional name for the tensor.
    verify_shape: Boolean that enables verification of the shape of `value`
      against `shape`.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
  ctx = context.context()
  if ctx.executing_eagerly():
    t = convert_to_eager_tensor(value, ctx, dtype)
    if shape is None:
      return t
    shape = tensor_shape.as_shape(shape)
    if shape == t.shape:
      return t
    if verify_shape:
      raise TypeError("Expected Tensor's shape: %s, got %s." % (tuple(shape),
                                                                tuple(t.shape)))
    num_t = t.shape.num_elements()
    # TODO(josh11b): Implement shape -> eager tensor conversion.
    if num_t == shape.num_elements():
      return _eager_reshape(t, shape.as_list(), ctx)
    if num_t == 1:
      if t.dtype == dtypes.bool:
        # We don't have a Fill kernel for bool dtype on GPU. So we first run
        # Fill on CPU and then copy to GPU if needed.
        with ops.device("/device:CPU:0"):
          x = _eager_fill(shape.as_list(), t.cpu(), ctx)
        return _eager_identity(x, ctx)
      else:
        return _eager_fill(shape.as_list(), t, ctx)
    raise TypeError("Eager execution of tf.constant with unsupported shape "
                    "(value has %d elements, shape is %s with %d elements)." %
                    (num_t, shape, shape.num_elements()))
  g = ops.get_default_graph()
  tensor_value = attr_value_pb2.AttrValue()
  tensor_value.tensor.CopyFrom(
      tensor_util.make_tensor_proto(
          value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
  const_tensor = g.create_op(
      "Const", [], [dtype_value.type],
      attrs={"value": tensor_value,
             "dtype": dtype_value},
      name=name).outputs[0]
  return const_tensor
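The behaviors documented above are easiest to see in TF 1.x-style graph mode, where `constant` goes through `make_tensor_proto`. Below is a minimal, hedged sketch (assuming `tensorflow.compat.v1` is available) exercising dtype inference, scalar fill, and the documented fill-from-last-element behavior; it is an illustration, not part of the source.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

a = tf.constant([1, 2, 3, 4, 5, 6, 7])    # dtype inferred as int32
b = tf.constant(-1.0, shape=[2, 3])       # scalar -1.0 broadcast to a 2x3 tensor
c = tf.constant([1, 2, 3], shape=[2, 3])  # per the docstring, the last element (3)
                                          # fills the remaining entries

with tf.Session() as sess:
    print(sess.run(b))  # [[-1. -1. -1.], [-1. -1. -1.]]
    print(sess.run(c))  # [[1 2 3], [3 3 3]]
```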
Example #2
    def __exit__(self, unused_type, unused_value, unused_traceback):
        # pylint: disable=protected-access
        if context.executing_eagerly():
            return

        if self._graph is not ops.get_default_graph():
            raise RuntimeError(
                "Within the automatic control dependency context, the default graph"
                f" cannot change. Upon entry it was {self._graph}, but on exit it"
                f" changed to {ops.get_default_graph()}")

        outer_graph = getattr(self._graph, "outer_graph", None)
        if outer_graph is not None:
            self._graph._add_control_dependencies = outer_graph._add_control_dependencies
        else:
            self._graph._add_control_dependencies = False
        self._graph.experimental_acd_manager = None

        # map from resource tensor to the last op which wrote to it
        last_write_to_resource = {}
        # map from resource tensor to the list of reads from it since the last
        # write or since the beginning of the function.
        reads_since_last_write_to_resource = collections.defaultdict(list)
        # CollectiveManager manager_ids within a particular function call should not
        # be needed outside of that function call. So we keep them separate (though
        # the general idea of the maps is the same; in the future, we'll need to
        # correctly thread the control output outside).
        # Map from collective manager scope to the last op which used it
        collective_manager_scopes_opened = {}
        collective_manager_scopes_used = {}
        # set of conditional and loop exits
        ops_which_must_run = set()
        # merge which must depend on ops which use this resource
        merge_for_resource = {}

        new_operations = self._graph.get_operations()[self._n_operations:]
        first_use_for_res = {}
        resources_by_op = {}

        # Ensures that uses of resource tensors get serialized properly and all
        # execute. This is done by keeping a map from resource tensor to the last op
        # in graph-construction order which used it (last_write_to_resource).
        #
        # Conditionals are written in TensorFlow such that every external tensor
        # accessed in the conditional goes through a switch op and every return
        # tensor (it's guaranteed that there will be at least one) goes through a
        # merge op.
        #
        # To handle conditionals, switches are handled in a special way (see
        # comments for _process_switch). Merge nodes created by TF's conditional
        # logic (as opposed to by _process_switch) are forced to run and also get a
        # control dependency added to them to ensure all stateful ops inside their
        # control flow context run.
        #
        # We also ensure that if an op is using a resource output by a switch node
        # (that is, a resource tensor for which there's a value in
        # merge_for_resource) this op will run before the merge for that resource.
        #
        # We try to add control inputs to nodes respecting their control flow
        # contexts to avoid dead nodes propagating everywhere and leading to
        # "retval[0] doesn't have value" errors. If a node gets a control dependency
        # on a dead node (i.e. a node from an untaken control flow branch) that node
        # will be marked as dead unless it's a merge node.
        #
        # TODO(apassos): serialize non-resource-taking stateful ops as well, and
        # test that it works. Support while loops. Support init_scope escaping from
        # this.
        for op in new_operations:
            # TODO(apassos) make this code safely support while loops.
            if control_flow_util.IsInWhileLoop(op):
                continue
            control_inputs = set()

            if op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS:
                # This will add it to self._independent_ops, but also mark it with an
                # attribute.
                self.run_independently(op)

            if op in self._independent_ops:
                ops_which_must_run.add(op)
                continue

            # Ensure stateful ops run.
            # Read-only ops are added to control outputs if the read value is
            # consumed. This covers the case when the read value is returned from
            # the function since that goes through a tf.identity in mark_as_return.
            if ((op_def_registry.get(op.type) is None) or
                (op_is_stateful(op) and (op.type not in utils.RESOURCE_READ_OPS
                                         or any(output.consumers()
                                                for output in op.outputs)))):
                ops_which_must_run.add(op)

            # Make a note of all opened manager_ids.
            if op.type == "NoOp":
                try:
                    collective_manager_scopes_opened[op.get_attr(
                        "_collective_manager_id")] = op
                except ValueError:
                    pass
            # Ignore switches (they're handled separately)
            if op.type == "Switch" and op.inputs[
                    0].dtype == dtypes_module.resource:
                continue
            # Make merges trigger all other computation which must run
            # TODO(mdan): Don't do this. Write a transform to chains instead.
            # See core/common_runtime/control_flow_deps_to_chains.cc.
            if op.type == "Merge":
                for o in ops_which_must_run:
                    op._add_control_input(o)
                    for inp in o.inputs:
                        input_id = ops.tensor_id(inp)
                        if input_id in last_write_to_resource:
                            last_write_to_resource[input_id] = op
                ops_which_must_run = set([op])
                continue

            resource_inputs = set()
            # Check for any resource inputs. If we find any, we update control_inputs
            # and last_write_to_resource.
            for inp, resource_type in _get_resource_inputs(op):
                is_read = resource_type == ResourceType.READ_ONLY
                input_id = ops.tensor_id(inp)

                # If the op receives the same resource tensor twice as an input, we skip
                # to avoid the op getting a control dependency on itself.
                if input_id in resource_inputs:
                    continue

                resource_inputs.add(input_id)
                # Deal with switches, finally.
                if inp.op.type == "Switch":
                    self._process_switch(inp.op, ops_which_must_run,
                                         last_write_to_resource,
                                         merge_for_resource)
                is_building_function = op.graph.building_function
                # Ensure uses of resources are serialized
                if input_id in last_write_to_resource:
                    if is_building_function or (
                            last_write_to_resource[input_id].
                            _control_flow_context is op._control_flow_context):
                        control_inputs.add(last_write_to_resource[input_id])
                # Ensure merges happen after the closing of a cond block
                if input_id in merge_for_resource:
                    merge_for_resource[input_id]._add_control_input(op)

                do_record = (self.record_initial_resource_uses
                             and input_id not in first_use_for_res)

                if is_read:
                    reads_list = reads_since_last_write_to_resource[input_id]
                    reads_list.append(op)

                    if do_record:
                        # Note: this will track the entire list that
                        # reads_since_last_write_to_resource maintains. Updates to it will
                        # and should be tracked, until the first write is encountered. At
                        # that point, reads_since_last_write_to_resource will contain a new
                        # empty list. This logic relies on that behavior.
                        first_use_for_res[input_id] = reads_list

                else:
                    control_inputs.update(
                        reads_since_last_write_to_resource[input_id])
                    reads_since_last_write_to_resource[input_id] = []
                    last_write_to_resource[input_id] = op

                    if do_record:
                        first_use_for_res[input_id] = [op]

            if self.record_initial_resource_uses and op_is_stateful(op):
                if resource_inputs:
                    resources_by_op[op] = tuple(resource_inputs)
                else:
                    if None not in first_use_for_res:
                        first_use_for_res[None] = [op]
                    resources_by_op[op] = (None, )

            if (op_is_stateful(op) and not resource_inputs
                    and op._control_flow_context is None):
                if None in last_write_to_resource:
                    op._add_control_input(last_write_to_resource[None])
                last_write_to_resource[None] = op

            # Ensure ordering of collective ops
            manager_ids = collective_manager_ids_from_op(op)
            for manager_id in manager_ids:
                if manager_id in collective_manager_scopes_opened:
                    # Chain this function call if the scope was opened.
                    op._add_control_input(
                        collective_manager_scopes_opened[manager_id])
                    collective_manager_scopes_opened[manager_id] = op
                else:
                    # If this op is in a scope not created here, create a chain starting
                    # at this op.
                    if manager_id in collective_manager_scopes_used:
                        op._add_control_input(
                            collective_manager_scopes_used[manager_id])
                    collective_manager_scopes_used[manager_id] = op

            if control_inputs and not is_building_function:
                control_inputs = [
                    c for c in control_inputs
                    if c._control_flow_context is op._control_flow_context
                ]

            op._add_control_inputs(control_inputs)

        # Record the ops which first use resources touched by "ops which must run".
        if self.record_initial_resource_uses:
            first_uses_by_output_ops = {}
            for op in ops_which_must_run:
                if op not in resources_by_op:
                    # This may happen with Merge/Switch nodes which are special cased
                    # above.
                    continue
                for r in resources_by_op[op]:
                    if op not in first_uses_by_output_ops:
                        first_uses_by_output_ops[op] = set()
                    first_uses_by_output_ops[op].update(first_use_for_res[r])
            # For each "op which must run", set a private attr indicating the ops that
            # used the same resources it did.
            for op in first_uses_by_output_ops:
                others = [
                    other.name.encode()
                    for other in first_uses_by_output_ops[op]
                ]
                l = attr_value_pb2.AttrValue.ListValue(s=others)
                # TODO(mdan): Is there a way which doesn't use anonymous attrs?
                op._set_attr("_res_first_used_by",
                             attr_value_pb2.AttrValue(list=l))

        # Ensure all ops which must run do run
        self.ops_which_must_run.update(ops_which_must_run)
        control_output_op = None
        for idx, r in enumerate(
                nest.flatten(list(self._returned_tensors),
                             expand_composites=True)):
            if self.ops_which_must_run:
                updated_ops_which_must_run = []
                if r.graph.building_function:
                    # There may be many stateful ops in the graph. Adding them as
                    # control inputs to each function output could create excessive
                    # control edges in the graph. Thus we create an intermediate No-op
                    # to chain the control dependencies between stateful ops and
                    # function outputs.
                    if idx == 0:
                        control_output_op = control_flow_ops.no_op()
                        control_output_op._set_attr(
                            "_acd_function_control_output",
                            attr_value_pb2.AttrValue(b=True))
                        control_output_op._add_control_inputs(
                            self.ops_which_must_run)
                    updated_ops_which_must_run = [control_output_op]
                else:
                    updated_ops_which_must_run = [
                        o for o in self.ops_which_must_run if
                        o._control_flow_context is r.op._control_flow_context
                    ]
                r.op._add_control_inputs(updated_ops_which_must_run)

        self.collective_manager_ids_used = collective_manager_scopes_used
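The bookkeeping above is what makes stateful ops inside a `tf.function` execute in program order. A minimal, hedged TF 2.x sketch of the observable behavior (illustrative only, not part of the source):

```python
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def update_and_read():
    v.assign(2.0)          # first write to the variable's resource
    v.assign_add(3.0)      # second write; gets a control dependency on the first
    return v.read_value()  # read is ordered after the last write

print(update_and_read().numpy())  # 5.0
```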
Example #3
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True):
    """Builds graph operators that runs compilation and replicated computation.

  This is a lower-level interface than `replicate`; it returns the compile op
  and the execute output tensors separately. In the generated graph the
  compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).
  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.
  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
    del name
    inputs = [[]] if inputs is None else inputs

    metadata_kwargs = {}
    if device_assignment is not None:
        # Turn the Numpy array into a flattened list so we can pass it as an
        # operator attribute.
        metadata_kwargs = {
            "topology":
            device_assignment.topology.serialized(),
            "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
            "computation_shape":
            device_assignment.computation_shape.tolist()
        }

    if ((not isinstance(inputs, list))
            or any(not isinstance(inp, (list, tuple)) for inp in inputs)):
        raise TypeError(
            "tpu.replicate() inputs must be a list of lists/tuples")

    num_replicas = len(inputs)

    # No replicas? Nothing to do.
    if num_replicas == 0:
        return []

    # Converts inputs to Tensors.
    inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

    # Verifies that all replicas have matching numbers and types of inputs
    input_types = [x.dtype for x in inputs[0]]
    input_arity = len(input_types)
    for i in range(num_replicas):
        if len(inputs[i]) != input_arity:
            raise ValueError("Replicas must have the same number of inputs. "
                             "Replica 0 had {} inputs, replica {} had {} "
                             "inputs.".format(input_arity, i, len(inputs[i])))

        types = [x.dtype for x in inputs[i]]
        if types != input_types:
            raise ValueError(
                "Replicas must have matching input types. Replica 0 had "
                "input types {}, replica {} had input types {}".format(
                    input_types, i, types))

    arg_error = tpu_function.check_function_argument_count(
        computation, input_arity, infeed_queue)
    if arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s, but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]), arg_error))
        else:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s and %d additional inputs from infeed,"
                " but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]),
                 infeed_queue.number_of_tuple_elements, arg_error))

    graph = ops.get_default_graph()

    # Fan-in: Builds a TPUReplicatedInput node for each input.
    computation_inputs = []
    for i in range(0, input_arity):
        replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
        computation_inputs.append(
            tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

    cluster_name = graph.unique_name("cluster")
    pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
    context = TPUReplicateContext(name=cluster_name,
                                  num_replicas=num_replicas,
                                  pivot=pivot)
    try:
        context.Enter()

        metadata = tpu_ops.tpu_replicate_metadata(num_replicas=num_replicas,
                                                  use_tpu=use_tpu,
                                                  **metadata_kwargs)

        with tpu_function.tpu_shard_context(
                num_replicas), ops.control_dependencies([metadata]):

            # Add identity ops so even unused inputs are "consumed" by the
            # computation. This is to avoid orphaned TPUReplicatedInput nodes.
            # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
            # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
            computation_inputs = [
                array_ops.identity(x, name="replicated_input_{}".format(i))
                for i, x in enumerate(computation_inputs)
            ]

            # If there is an infeed queue, adds the dequeued values to the
            # computation's inputs.
            if infeed_queue is not None:
                infeed_queue.set_number_of_shards(num_replicas)
                for t in infeed_queue.generate_dequeue_op():
                    computation_inputs.append(t)

            # Only resource variables work inside a TPU computation, so turn on
            # resource variables for the computation.
            # TODO(phawkins): consider removing this code. It will
            # be less confusing to clients if they knowingly choose to use resource
            # variables.
            vscope = variable_scope.get_variable_scope()
            saved_use_resource = vscope.use_resource
            vscope.set_use_resource(True)

            outputs = computation(*computation_inputs)

            vscope.set_use_resource(saved_use_resource)

        # If the computation returns `None`, make it an empty tuple.
        if outputs is None:
            outputs = tuple()
        # If the computation only returned one value, makes it a tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs, )

        # Append `no_op` here so that fetching any return value of this function
        # will trigger TPUExecute node.
        outputs += (control_flow_ops.no_op(), )
        try:
            with ops.device(core(0)):
                outputs = [
                    o if isinstance(o, ops.Operation) else
                    ops.convert_to_tensor(o) for o in outputs
                ]
        except Exception as e:
            raise ValueError(
                "TPU function return values must all either be Operations or "
                "convertible to Tensors. Got '%s'" % str(e))

        # Separates the returned Operations and Tensors.
        output_operations = [
            o for o in outputs if isinstance(o, ops.Operation)
        ]
        output_tensors = [
            o for o in outputs if not isinstance(o, ops.Operation)
        ]

        if outputs != output_tensors + output_operations:
            raise ValueError(
                "TPU functions must return zero-or more Tensor values followed by "
                "zero or more Operations.")
        output_arity = len(output_tensors)

        # Wraps outputs in Identity ops. Otherwise a replicated input copied
        # straight to an output would bypass the replicate(). This would be bad
        # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
        # be rewritten away, leading to a runtime error.
        # TODO(phawkins): extend the rewrite to elide these nodes instead.
        new_output_tensors = []
        for t in output_tensors:
            with ops.device(t.device if t.device else core(0)):
                new_output_tensors.append(array_ops.identity(t))
        output_tensors = new_output_tensors
        context.ExitResult(output_tensors)
    finally:
        context.report_unsupported_operations()
        context.Exit()
        host_compute_core = context.HostComputeCore()

    if host_compute_core:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.list.s.extend(
            [compat.as_bytes(x) for x in host_compute_core])
        metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

    # Fan-out: Builds a TPUReplicatedOutput node for each output.
    outputs = [
        tpu_ops.tpu_replicated_output(output_tensors[i],
                                      num_replicas,
                                      name="output{}".format(i))
        for i in xrange(output_arity)
    ]

    with ops.control_dependencies([metadata]):
        if use_tpu:
            compile_status = tpu_ops.tpu_compilation_result()
            op = compile_status.op
            attr_value = attr_value_pb2.AttrValue(
                s=compat.as_bytes(cluster_name))
            op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
        else:
            compile_status = control_flow_ops.no_op(name="compilation_status")

    with ops.control_dependencies(output_operations):
        if output_arity == 0:
            # Returns a list of NoOps dependent on the replication Op, indexed by
            # [replica_num].
            return [
                compile_status,
                [
                    control_flow_ops.no_op(name="shard_%d" % i)
                    for i in range(num_replicas)
                ]
            ]
        else:
            # Wraps the outputs in identity operators so the names of any possible
            # `fetch` nodes are preserved by the replication rewrite.
            return [
                compile_status,
                [[
                    array_ops.identity(outputs[out][replica],
                                       name="output_%d_shard_%d" %
                                       (out, replica))
                    for out in xrange(output_arity)
                ] for replica in xrange(num_replicas)]
            ]
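Several steps above attach private attributes to ops via `op._set_attr` (the compilation-status tag, `host_compute_core`). A standalone, hedged sketch of that attr-tagging pattern in TF 1.x graph mode; the attr name below is hypothetical, not the one the TPU code uses:

```python
import tensorflow.compat.v1 as tf
from tensorflow.core.framework import attr_value_pb2

tf.disable_eager_execution()

status = tf.no_op(name="compilation_status")
# Tag the op with a string-valued attr so a later graph rewrite can locate it.
status._set_attr(  # pylint: disable=protected-access
    "_example_cluster_tag",  # hypothetical attr name
    attr_value_pb2.AttrValue(s=b"cluster"))
print(status.node_def.attr["_example_cluster_tag"].s)  # b'cluster'
```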
Example #4
 def patch_dtype(input_node, field_name, output_node):
     if use_fp16 and (field_name in input_node.attr) and (
             input_node.attr[field_name].type == types_pb2.DT_FLOAT):
         output_node.attr[field_name].CopyFrom(
             attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))
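A self-contained, hedged sketch of how a helper like `patch_dtype` is typically applied over a `GraphDef` (the toy graph, attr names, and fp16 flag are assumptions; a real pass would also convert the `value` tensors of Const nodes):

```python
import tensorflow.compat.v1 as tf
from tensorflow.core.framework import attr_value_pb2, graph_pb2, types_pb2

tf.disable_eager_execution()
g = tf.Graph()
with g.as_default():
    tf.constant(1.0, name="c")  # Const node with a DT_FLOAT "dtype" attr

use_fp16 = True
output_graph_def = graph_pb2.GraphDef()
for node in g.as_graph_def().node:
    new_node = output_graph_def.node.add()
    new_node.CopyFrom(node)
    for field_name in ("dtype", "T"):  # attr names that commonly carry a dtype
        if use_fp16 and field_name in node.attr and (
                node.attr[field_name].type == types_pb2.DT_FLOAT):
            new_node.attr[field_name].CopyFrom(
                attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))

print(output_graph_def.node[0].attr["dtype"])  # type: DT_HALF
```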
Example #5
 def set_attr_int_list(node, key, value):
     """Set the node's attr which data type is int list.
     """
     list_value = attr_value_pb2.AttrValue.ListValue(i=value)
     node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
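A hedged usage example for the helper above; the node and attr name are illustrative:

```python
from tensorflow.core.framework import node_def_pb2

node = node_def_pb2.NodeDef(name="conv", op="Conv2D")
set_attr_int_list(node, "strides", [1, 2, 2, 1])
print(list(node.attr["strides"].list.i))  # [1, 2, 2, 1]
```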
Example #6
    def generate_output_graph(self, input_graph_def, input_node_map,
                              fuse_op_list):
        output_graph_def = graph_pb2.GraphDef()
        skip_list = []
        skip_node_name = []
        float32_type = dtypes.float32.as_datatype_enum
        for index, node in enumerate(input_graph_def.node):
            if index in fuse_op_list:
                input_node = input_node_map[node.input[0]]
                if input_node.op == 'QuantizeV2':
                    new_node = node_def_pb2.NodeDef()

                    new_node.op = node.op + "AndDequantize"
                    for _, value in enumerate(node.input):
                        new_node.input.append(value)

                    dequantize_node = input_graph_def.node[index + 4]
                    frozen_max_node = input_graph_def.node[index + 2]
                    frozen_min_node = input_graph_def.node[index + 1]

                    new_node.name = dequantize_node.name

                    new_node.input.append(frozen_min_node.name)
                    new_node.input.append(frozen_max_node.name)

                    new_node.attr["T1"].CopyFrom(node.attr['T1'])
                    new_node.attr["T2"].CopyFrom(node.attr['T2'])

                    new_node.attr["Tbias"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))
                    new_node.attr["Toutput"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))

                    skip_list.append(index + 1)
                    skip_list.append(index + 2)
                    skip_list.append(index + 3)
                    skip_list.append(index + 4)
                    output_graph_def.node.extend(
                        [new_node, frozen_max_node, frozen_min_node])
                elif input_node.op == "Requantize":
                    new_node = node_def_pb2.NodeDef()
                    new_node.op = node.op + "AndDequantize"
                    for _, value in enumerate(node.input):
                        new_node.input.append(value)

                    dequantize_node = input_graph_def.node[index + 4]
                    frozen_max_node = input_graph_def.node[index + 2]
                    frozen_min_node = input_graph_def.node[index + 1]
                    new_node.name = dequantize_node.name
                    skip_list.append(index + 1)
                    skip_list.append(index + 2)
                    skip_list.append(index + 3)
                    skip_list.append(index + 4)
                    new_node.input.append(frozen_min_node.name)
                    new_node.input.append(frozen_max_node.name)

                    new_node.attr["T1"].CopyFrom(node.attr['T1'])
                    new_node.attr["T2"].CopyFrom(node.attr['T2'])

                    new_node.attr["Tbias"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))
                    new_node.attr["Toutput"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))

                    output_graph_def.node.extend(
                        [new_node, frozen_max_node, frozen_min_node])
                else:
                    new_node = node_def_pb2.NodeDef()
                    new_node.CopyFrom(node)
                    output_graph_def.node.extend([new_node])

            elif index in skip_list or node.name in skip_node_name:
                continue
            else:
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(node)
                output_graph_def.node.extend([new_node])
        return output_graph_def
Example #7
def _get_defun_inputs(args, names, structure, flat_shapes=None):
    """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args)`; and locations where `args` are
      instances of `Tensor` must have a corresponding `TensorShape` in
      `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
  """
    func_graph = ops.get_default_graph()
    function_inputs = []
    if names is None:
        names = [None] * len(args)
    if flat_shapes is None:
        shapes_iter = itertools.repeat(None)
    else:
        len_flat_args = len(nest.flatten(args))
        if len_flat_args != len(flat_shapes):
            raise RuntimeError(
                "Length of fully flat shapes (%d) must match that of "
                "flatten(args) (%d).  args: %s, flat_shapes: %s" %
                (len(flat_shapes), len_flat_args, args, flat_shapes))
        shapes_iter = iter(flat_shapes)
    for arg_value, name in zip(args, names):
        flattened = nest.flatten(arg_value)
        tensor_specs = [
            arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec)
        ]
        specified_names = [arg.name for arg in tensor_specs if arg.name]
        if specified_names and len(specified_names) < len(tensor_specs):
            raise ValueError(
                "If specifying TensorSpec names for nested structures, "
                "either zero or all names have to be specified.")

        for arg in flattened:
            # We have a shape entry for each arg, regardless of whether it's a real
            # Tensor or not.  For non-tensor entries it should be None.
            shape = next(shapes_iter)
            if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
                if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
                    requested_name = arg.name
                else:
                    requested_name = name
                placeholder_shape = shape if shape is not None else arg.shape
                try:
                    placeholder = graph_placeholder(arg.dtype,
                                                    placeholder_shape,
                                                    name=requested_name)
                except ValueError:
                    # Sometimes parameter names are not valid op names, so fall back to
                    # unnamed placeholders.
                    placeholder = graph_placeholder(arg.dtype,
                                                    placeholder_shape)
                if name is not None:
                    # Record the requested/user-specified name in case it's different than
                    # the uniquified name, for validation when exporting signatures.
                    placeholder.op._set_attr(  # pylint: disable=protected-access
                        "_user_specified_name",
                        attr_value_pb2.AttrValue(
                            s=compat.as_bytes(requested_name)))
                function_inputs.append(placeholder)
            elif isinstance(arg, resource_variable_ops.ResourceVariable):
                # Capture arg variables to create placeholders for them. These will be
                # removed as captures after the function is traced (since otherwise we'd
                # just add it back with a new placeholder when the variable was
                # referenced).
                placeholder = func_graph.capture(arg.handle, name=name)
                placeholder.op._set_attr(  # pylint: disable=protected-access
                    "_user_specified_name",
                    attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
                function_inputs.append(arg)
            else:
                if shape is not None:
                    raise RuntimeError(
                        "Expected provided shape override to be None for arg that isn't "
                        "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s" %
                        (arg, shape, args))
                function_inputs.append(arg)
    return nest.pack_sequence_as(structure, function_inputs)
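A hedged TF 2.x sketch of the user-visible effect of the `_user_specified_name` tagging above: a name given on a `TensorSpec` carries through to the traced placeholder (names here are illustrative):

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="prices")])
def double(prices):
    return prices * 2.0

concrete = double.get_concrete_function()
print([t.name for t in concrete.inputs])  # e.g. ['prices:0']
```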
Example #8
def replicate(computation,
              inputs=None,
              infeed_queue=None,
              device_assignment=None,
              name=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: The name of the operator.
  Returns:
    A list of lists of output tensors, indexed by `[replica_num][output_num]`.
  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
  if name is None:
    name = "TPUReplicate"
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
        "computation_shape":
            device_assignment.computation_shape.tolist()
    }

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in inputs[i]]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (input_arity, str(
              [i.name
               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
                                             arg_error))

  graph = ops.get_default_graph()

  with ops.name_scope(name, "replicate"):
    # Fan-in: Builds a TPUReplicatedInput node for each input.
    computation_inputs = []
    for i in range(0, input_arity):
      replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
      computation_inputs.append(
          tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

    context = TPUReplicateContext(name=graph.unique_name("cluster"))
    try:
      context.Enter()

      metadata = tpu_ops.tpu_replicate_metadata(
          num_replicas=num_replicas, **metadata_kwargs)

      with tpu_function.tpu_shard_context(
          num_replicas), ops.control_dependencies([metadata]):

        # The EncapsulateTPUComputations rewrite needs to identify the
        # replicated arguments inside each computation. Adds identity operators
        # tagged with an attribute _tpu_replicated_input to identify the
        # replicated inputs.
        # pylint: disable=protected-access
        with graph._attr_scope({"_tpu_replicated_input":
                                attr_value_pb2.AttrValue(b=True)}):
          computation_inputs = [
              array_ops.identity(x, name="replicated_input_{}".format(i))
              for i, x in enumerate(computation_inputs)]
        # pylint: enable=protected-access

        # If there is an infeed queue, adds the dequeued values to the
        # computation's inputs.
        if infeed_queue is not None:
          infeed_queue.set_number_of_shards(num_replicas)
          for t in infeed_queue.generate_dequeue_op():
            computation_inputs.append(t)

        # Only resource variables work inside a TPU computation, so turn on
        # resource variables for the computation.
        # TODO(phawkins): consider removing this code. It will
        # be less confusing to clients if they knowingly choose to use resource
        # variables.
        vscope = variable_scope.get_variable_scope()
        saved_use_resource = vscope.use_resource
        vscope.set_use_resource(True)

        outputs = computation(*computation_inputs)

        vscope.set_use_resource(saved_use_resource)

      # If the computation only returned one value, makes it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

      try:
        with ops.device(core(0)):
          outputs = [
              o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
              for o in outputs
          ]
      except Exception as e:
        raise ValueError(
            "TPU function return values must all either be Operations or "
            "convertible to Tensors. Got '%s'" % str(e))

      # Separates the returned Operations and Tensors.
      output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
      output_tensors = [o for o in outputs
                        if not isinstance(o, ops.Operation)]

      if outputs != output_tensors + output_operations:
        raise ValueError(
            "TPU functions must return zero-or more Tensor values followed by "
            "zero or more Operations.")
      output_arity = len(output_tensors)

      # Wraps outputs in Identity ops. Otherwise a replicated input copied
      # straight to an output would bypass the replicate(). This would be bad
      # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
      # be rewritten away, leading to a runtime error.
      # TODO(phawkins): extend the rewrite to elide these nodes instead.
      new_output_tensors = []
      for t in output_tensors:
        with ops.device(t.device if t.device else core(0)):
          new_output_tensors.append(array_ops.identity(t))
      output_tensors = new_output_tensors
    finally:
      context.Exit()

    # Fan-out: Builds a TPUReplicatedOutput node for each output.
    outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
                                             name="output{}".format(i))
               for i in xrange(output_arity)]

    with ops.control_dependencies(output_operations):
      if output_arity == 0:
        # Returns a list of NoOps dependent on the replication Op, indexed by
        # [replica_num].
        return [
            control_flow_ops.no_op(name="%s_shard_%d" % (name, i))
            for i in range(num_replicas)
        ]
      else:
        # Wraps the outputs in identity operators so the names of any possible
        # `fetch` nodes are preserved by the replication rewrite.
        return [
            [array_ops.identity(outputs[out][replica],
                                name="output_%d_shard_%d" % (out, replica))
             for out in xrange(output_arity)]
            for replica in xrange(num_replicas)
        ]
Example #9
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except:  # pylint: disable=bare-except
    num_original_outputs = len(while_op.outputs)

  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  # Do not ignore grads wrt extra outputs when computing higher order
  # derivatives.
  while_op._set_attr("_num_original_outputs",
                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
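A hedged TF 2.x sketch of the behavior this gradient function ultimately enables: differentiating through a `tf.while_loop` traced inside a `tf.function` (toy loop, illustrative only):

```python
import tensorflow as tf

@tf.function
def cube_grad(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        # y = x**3 computed by repeated multiplication inside a while loop.
        _, y = tf.while_loop(lambda i, acc: i < 3,
                             lambda i, acc: (i + 1, acc * x),
                             [tf.constant(0), tf.constant(1.0)])
    return tape.gradient(y, x)

print(cube_grad(tf.constant(2.0)))  # 3 * x**2 at x = 2 -> 12.0
```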
Example #10
  def _GetLayerMatch(match_result):
    """Populates a layer match object containing ops/tensors for folding BNs.

    Args:
      match_result: Matched result from graph matcher

    Returns:
      layer_op: Matching conv/fc op prior to batch norm
      BatchNormMatch: _BatchNormMatch containing all required batch norm
      parameters.
    """
    moving_mean_tensor = None
    moving_variance_tensor = None
    bn_decay_mean_tensor = None
    bn_decay_var_tensor = None
    batch_to_space_op = None
    layer_op = match_result.get_op(layer_pattern)
    layer_tensor = match_result.get_tensor(layer_pattern)
    bn_id_op = match_result.get_op(batch_norm_identity_pattern)
    bn_op = match_result.get_op(batch_norm_pattern)
    if bn_id_op is None:
      bn_id_op = bn_op

    batch_epsilon = bn_op.get_attr('epsilon')

    # In the MatMul case, the output of batch norm is reshaped back into a
    # 2D tensor, so the output_tensor is the output of the Reshape op.
    output_tensor = bn_op.outputs[0]
    if layer_op.type == 'MatMul':
      output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
      # If the matcher didn't match matmul_bn_output_reshape, there will be
      # another match for this 'MatMul' later, so we can skip this one.
      if output_reshape_op is None:
        return None, None
      output_tensor = output_reshape_op.outputs[0]

    # Ensure that the output tensor has consumers, otherwise this is a dangling
    # node and not a match.
    if not output_tensor.consumers():
      return None, None

    batch_to_space_op = match_result.get_op(batch_to_space_pattern)
    input_tensor = match_result.get_tensor(input_pattern)
    weight_tensor = match_result.get_tensor(weight_pattern)
    gamma_tensor = match_result.get_tensor(gamma_pattern)
    beta_tensor = match_result.get_tensor(beta_pattern)
    # FusedBatchNorm in training is different from that in inference. It takes
    # empty 'mean' and empty 'variance', and produces the mean and the variance
    # of the batch. Therefore, when is_training is true, mean_tensor and
    # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
    # respectively; when is_training is false, they point to bn_op's inputs.
    is_training = bn_op.get_attr('is_training')
    if is_training:
      # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
      # batch_variance outputs, so we need to substitute our own custom
      # gradient.
      # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
      # pylint: disable=protected-access
      bn_op._set_attr(
          '_gradient_op_type',
          attr_value_pb2.AttrValue(s=compat.as_bytes('FoldFusedBatchNormGrad')))
      # pylint: enable=protected-access
      mean_tensor = bn_op.outputs[1]
      # The batch variance used during forward and backward prop is biased,
      # i.e it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average
      # calculation, the variance is corrected by the term N/N-1 (Bessel's
      # correction). The variance tensor read from FuseBatchNorm has Bessel's
      # correction applied, so we undo it here.
      scope, sep, _ = bn_op.name.rpartition('/')
      g = ops.get_default_graph()
      with g.as_default(), g.name_scope(scope + sep):
        n = math_ops.cast(
            array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
            dtypes.float32)
        variance_tensor = math_ops.multiply(
            bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
      # TODO(suharshs): Find a way to get rid of this inner match.
      for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
        sub_op = mul_match_result.get_op(moving_average_sub_pattern)
        if sub_op.inputs[1].name == bn_op.outputs[1].name:
          # During training: Batch Mean is bn_op.outputs[1]
          moving_mean_tensor = sub_op.inputs[0]
          bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
        if (sub_op.inputs[1].name == bn_op.outputs[2].name or
            sub_op.inputs[1].name.split("/")[0] + "/FusedBatchNorm:2" ==
            bn_op.outputs[2].name):
          # During training: Batch Var is bn_op.outputs[2]
          moving_variance_tensor = sub_op.inputs[0]
          bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
    else:
      mean_tensor = match_result.get_tensor(mean_pattern)
      variance_tensor = match_result.get_tensor(variance_pattern)

    return layer_op, _BatchNormMatch(
        layer_op=layer_op,
        bn_op=bn_op,
        output_tensor=output_tensor,
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        gamma_tensor=gamma_tensor,
        beta_tensor=beta_tensor,
        mean_tensor=mean_tensor,
        variance_tensor=variance_tensor,
        moving_mean_tensor=moving_mean_tensor,
        moving_variance_tensor=moving_variance_tensor,
        bn_decay_mean_tensor=bn_decay_mean_tensor,
        bn_decay_var_tensor=bn_decay_var_tensor,
        batch_epsilon=batch_epsilon,
        batch_to_space_op=batch_to_space_op)
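A hedged numeric check of the Bessel-correction undo performed in the `is_training` branch above: multiplying the corrected variance by (n - 1) / n recovers the biased batch variance V = sum((x - mu)^2) / N used in forward prop (toy data, NumPy only):

```python
import numpy as np

x = np.array([1.0, 2.0, 4.0, 7.0])
n = x.size
biased = ((x - x.mean()) ** 2).sum() / n           # variance used by FusedBatchNorm internally
corrected = ((x - x.mean()) ** 2).sum() / (n - 1)  # Bessel-corrected variance it reports
print(np.isclose(corrected * (n - 1) / n, biased))  # True
```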
Example #11
def append_postprocessing_op(frozen_graph_def,
                             max_detections,
                             max_classes_per_detection,
                             nms_score_threshold,
                             nms_iou_threshold,
                             num_classes,
                             scale_values,
                             detections_per_class=100,
                             use_regular_nms=False):
    """Appends postprocessing custom op.

  Args:
    frozen_graph_def: Frozen GraphDef for SSD model after freezing the
      checkpoint
    max_detections: Maximum number of detections (boxes) to show
    max_classes_per_detection: Number of classes to display per detection
    nms_score_threshold: Score threshold used in Non-maximal suppression in
      post-processing
    nms_iou_threshold: Intersection-over-union threshold used in Non-maximal
      suppression in post-processing
    num_classes: number of classes in SSD detector
    scale_values: A dict with the following key-value pairs: {y_scale: 10,
      x_scale: 10, h_scale: 5, w_scale: 5}, used to decode centersize boxes.
    detections_per_class: In regular NonMaxSuppression, number of anchors used
      for NonMaxSuppression per class
    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
      Fast NMS.

  Returns:
    transformed_graph_def: Frozen GraphDef with postprocessing custom op
    appended
    TFLite_Detection_PostProcess custom op node has four outputs:
    detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
    locations
    detection_classes: a float32 tensor of shape [1, num_boxes]
    with class indices
    detection_scores: a float32 tensor of shape [1, num_boxes]
    with class scores
    num_boxes: a float32 tensor of size 1 containing the number of detected
    boxes
  """
    new_output = frozen_graph_def.node.add()
    new_output.op = 'TFLite_Detection_PostProcess'
    new_output.name = 'TFLite_Detection_PostProcess'
    new_output.attr['_output_quantized'].CopyFrom(
        attr_value_pb2.AttrValue(b=True))
    new_output.attr['_output_types'].list.type.extend([
        types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT,
        types_pb2.DT_FLOAT
    ])
    new_output.attr['_support_output_type_float_in_quantized_op'].CopyFrom(
        attr_value_pb2.AttrValue(b=True))
    new_output.attr['max_detections'].CopyFrom(
        attr_value_pb2.AttrValue(i=max_detections))
    new_output.attr['max_classes_per_detection'].CopyFrom(
        attr_value_pb2.AttrValue(i=max_classes_per_detection))
    new_output.attr['nms_score_threshold'].CopyFrom(
        attr_value_pb2.AttrValue(f=nms_score_threshold.pop()))
    new_output.attr['nms_iou_threshold'].CopyFrom(
        attr_value_pb2.AttrValue(f=nms_iou_threshold.pop()))
    new_output.attr['num_classes'].CopyFrom(
        attr_value_pb2.AttrValue(i=num_classes))

    new_output.attr['y_scale'].CopyFrom(
        attr_value_pb2.AttrValue(f=scale_values['y_scale'].pop()))
    new_output.attr['x_scale'].CopyFrom(
        attr_value_pb2.AttrValue(f=scale_values['x_scale'].pop()))
    new_output.attr['h_scale'].CopyFrom(
        attr_value_pb2.AttrValue(f=scale_values['h_scale'].pop()))
    new_output.attr['w_scale'].CopyFrom(
        attr_value_pb2.AttrValue(f=scale_values['w_scale'].pop()))
    new_output.attr['detections_per_class'].CopyFrom(
        attr_value_pb2.AttrValue(i=detections_per_class))
    new_output.attr['use_regular_nms'].CopyFrom(
        attr_value_pb2.AttrValue(b=use_regular_nms))

    new_output.input.extend([
        'raw_outputs/box_encodings', 'raw_outputs/class_predictions', 'anchors'
    ])
    # Transform the graph to append new postprocessing op
    input_names = []
    output_names = ['TFLite_Detection_PostProcess']
    transforms = ['strip_unused_nodes']
    transformed_graph_def = TransformGraph(frozen_graph_def, input_names,
                                           output_names, transforms)
    return transformed_graph_def
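A hedged usage sketch for the function above. The file path, thresholds, and scale values are hypothetical, and the frozen graph is assumed to already expose the 'raw_outputs/box_encodings', 'raw_outputs/class_predictions', and 'anchors' tensors wired in above:

from tensorflow.core.framework import graph_pb2

frozen_graph_def = graph_pb2.GraphDef()
with open("frozen_ssd_graph.pb", "rb") as f:       # hypothetical path
    frozen_graph_def.ParseFromString(f.read())

tflite_graph_def = append_postprocessing_op(
    frozen_graph_def,
    max_detections=10,
    max_classes_per_detection=1,
    nms_score_threshold=[0.3],                     # consumed via .pop() above
    nms_iou_threshold=[0.6],
    num_classes=90,
    scale_values={"y_scale": [10.0], "x_scale": [10.0],
                  "h_scale": [5.0], "w_scale": [5.0]})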
Example #12
def While(input_, cond, body, name=None, hostmem=None):
    r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function that takes 'input' and returns a tensor. If the tensor is
      a non-boolean scalar, it is converted to a boolean according to the
      following rule: if the scalar is numerical, non-zero means True and zero
      means False; if the scalar is a string, non-empty means True and empty
      means False. If the tensor is not a scalar, non-emptiness means True and
      emptiness means False.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a
      host memory tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
    if cond.captured_inputs:
        raise ValueError("While op 'cond' argument must be a function "
                         "without implicitly captured inputs.")

    if cond.declared_input_types != body.declared_input_types:
        raise ValueError(
            "While op 'cond' and 'body' signatures do not match. %r vs %r" %
            (cond.declared_input_types, body.declared_input_types))

    if body.captured_inputs:
        cond_dtypes = list(body.declared_input_types) + [
            t.dtype for t in body.captured_inputs
        ]

        @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
        def CondWrapper(*args):
            """A wrapper that handles loop-carried captured inputs."""
            return cond(*args[:len(body.declared_input_types)])

        ret = gen_functional_ops._while(input_ + body.captured_inputs,
                                        CondWrapper,
                                        _LoopBodyCaptureWrapper(body),
                                        name=name)
        # Slice off the loop-carried captured inputs.
        ret = ret[:-len(body.captured_inputs)]
    else:
        ret = gen_functional_ops._while(input_, cond, body, name=name)
    if hostmem:
        input_attr = attr_value_pb2.AttrValue()
        input_attr.list.i.extend(hostmem)
        ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

        output_attr = attr_value_pb2.AttrValue()
        output_attr.list.i.extend(hostmem)
        ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
    return ret
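A hedged, TF1 graph-mode sketch of driving the While wrapper above with Defun-defined cond/body functions; the names and loop bound are illustrative only:

from tensorflow.python.framework import constant_op, dtypes, function, ops

@function.Defun(dtypes.int32)
def Cond(x):
    return x < 10            # scalar predicate

@function.Defun(dtypes.int32)
def Body(x):
    return x + 1             # same signature (T = [int32]) as Cond

with ops.Graph().as_default():
    counter = constant_op.constant(0)
    result = While([counter], Cond, Body)  # loops until Cond returns False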
Example #13
def partitioned_call(args,
                     f,
                     tout=None,
                     executing_eagerly=None,
                     config=None,
                     executor_type=None):
    """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A tensorflow::RewriterConfig proto, serialized. If
      `None`, all optimizations are disabled. Currently only handled for eager
      defined functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """

    if tout is None:
        tout = tuple(x.type for x in f.definition.signature.output_arg)

    if executing_eagerly is None:
        executing_eagerly = context.executing_eagerly()

    if config is None:
        config = _get_disabled_rewriter_config()

    if executor_type is None:
        executor_type = ""

    if executing_eagerly or len(tout):
        if f.stateful_ops:
            outputs = gen_functional_ops.stateful_partitioned_call(
                args=args,
                Tout=tout,
                f=f,
                config=config,
                executor_type=executor_type)
        else:
            outputs = gen_functional_ops.partitioned_call(
                args=args,
                Tout=tout,
                f=f,
                config=config,
                executor_type=executor_type)
        return outputs if outputs else None

    # The generated binding returns an empty list for functions that don't
    # return any Tensors, hence the need to use `create_op` directly.
    args = [ops.internal_convert_to_tensor(x) for x in args]
    tin_attr = attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(
            type=[x.dtype.as_datatype_enum for x in args]))
    tout_attr = attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(type=tout))
    func_attr = attr_value_pb2.AttrValue(func=attr_value_pb2.NameAttrList(
        name=f.name))
    executor_type_attr = attr_value_pb2.AttrValue(
        s=compat.as_bytes(executor_type))

    # When running in graph mode, the graph and function graphs are optimized
    # (i.e. run through grappler) per the session options, so we can disable any
    # eager-specific rewriting.
    rewriter_config = attr_value_pb2.AttrValue(
        s=_get_disabled_rewriter_config())

    graph = ops.get_default_graph()
    f.add_to_graph(graph)
    op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
    op = graph.create_op(op_name,
                         args,
                         tout,
                         compute_shapes=False,
                         name="PartitionedFunctionCall",
                         attrs={
                             "Tin": tin_attr,
                             "Tout": tout_attr,
                             "f": func_attr,
                             "config": rewriter_config,
                             "executor_type": executor_type_attr,
                         })
    outputs = op.outputs
    return outputs if outputs else op
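A hedged graph-mode sketch of invoking partitioned_call with a simple Defun-defined function; the function and its constant argument are illustrative only:

from tensorflow.python.framework import constant_op, dtypes, function, ops

@function.Defun(dtypes.float32)
def Square(x):
    return x * x

with ops.Graph().as_default():
    # Tout is inferred from Square's signature; config/executor_type use defaults.
    outputs = partitioned_call(args=[constant_op.constant(3.0)], f=Square)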
Example #14
class TestFoldConstant(unittest.TestCase):
    x_node = node_def_pb2.NodeDef()
    x_node.name = "placeholder"
    x_node.op = "Placeholder"

    input0_node = node_def_pb2.NodeDef()
    input0_node.name = "input0"
    input0_node.op = "Const"
    input0_value = np.float32(np.abs(np.random.randn(4, 3, 2)))
    input0_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input0_value, input0_value.dtype.type, input0_value.shape)))

    input1_node = node_def_pb2.NodeDef()
    input1_node.name = "input1"
    input1_node.op = "Const"
    input1_value = np.float32(np.abs(np.random.randn(4, 1, 1)))
    input1_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input1_value, input1_value.dtype.type, input1_value.shape)))

    add_node = node_def_pb2.NodeDef()
    add_node.op = "Add"
    add_node.name = "add"
    add_node.input.extend([input0_node.name, input1_node.name])

    input2_node = node_def_pb2.NodeDef()
    input2_node.name = "input2"
    input2_node.op = "Const"
    input2_value = np.float32(np.abs(np.random.randn(1)))
    input2_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input2_value, input2_value.dtype.type, input2_value.shape)))

    input3_node = node_def_pb2.NodeDef()
    input3_node.name = "input3"
    input3_node.op = "Const"
    input3_value = np.float32(np.abs(np.random.randn(1)))
    input3_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input3_value, input3_value.dtype.type, input3_value.shape)))

    switch_node = node_def_pb2.NodeDef()
    switch_node.name = "switch"
    switch_node.op = "Switch"

    input4_node = node_def_pb2.NodeDef()
    input4_node.name = "input4"
    input4_node.op = "Const"
    input4_value = np.float32(np.abs(np.random.randn(1)))
    input4_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input4_value, input4_value.dtype.type, input4_value.shape)))
    input4_node.input.extend([switch_node.name])

    input5_node = node_def_pb2.NodeDef()
    input5_node.name = "input5"
    input5_node.op = "Const"
    input5_value = np.float32(np.abs(np.random.randn(1)))
    input5_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input5_value, input5_value.dtype.type, input5_value.shape)))
    input5_node.input.extend([switch_node.name])

    cond_end = node_def_pb2.NodeDef()
    cond_end.name = "cond"
    cond_end.op = "Add"
    cond_end.input.extend([input4_node.name, input5_node.name])

    mul_node = node_def_pb2.NodeDef()
    mul_node.op = "Mul"
    mul_node.name = "mul"
    mul_node.input.extend([add_node.name, input3_node.name])

    sqrt_node = node_def_pb2.NodeDef()
    sqrt_node.name = "rsqrt"
    sqrt_node.op = "Rsqrt"
    sqrt_node.input.extend([mul_node.name])

    relu_node = node_def_pb2.NodeDef()
    relu_node.op = "Relu"
    relu_node.name = "relu"
    relu_node.input.extend([sqrt_node.name])

    block_node = node_def_pb2.NodeDef()
    block_node.name = "block_output"
    block_node.op = "Add"
    block_node.input.extend([x_node.name, relu_node.name])

    res_node = node_def_pb2.NodeDef()
    res_node.name = "res_add"
    res_node.op = "Add"
    res_node.input.extend([sqrt_node.name, input2_node.name])

    end_node = node_def_pb2.NodeDef()
    end_node.name = "end"
    end_node.op = "Add"
    end_node.input.extend([block_node.name, res_node.name])

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        x_node, input0_node, input1_node, input2_node, input3_node, add_node,
        mul_node, sqrt_node, relu_node, block_node, res_node, end_node
    ])

    def test_fold_constant(self):

        graph = self.graph_def
        rewriter = GraphFoldConstantOptimizer(graph)
        new_graph = rewriter.do_transformation()

        for node in new_graph.node:
            assert node.name in [
                "placeholder", "block_output", "rsqrt_const", "relu",
                "res_add_const", "end"
            ]

    def test_condition_fold_constant(self):
        graph_def = graph_pb2.GraphDef()
        graph_def.node.extend([
            self.cond_end, self.input4_node, self.input5_node, self.switch_node
        ])
        rewriter = GraphFoldConstantOptimizer(graph_def)
        new_graph = rewriter.do_transformation()
        for node in new_graph.node:
            assert node.name in ["switch", "cond", "input4", "input5"]

    def test_slice_int_input(self):
        graph_def = graph_pb2.GraphDef()
        index0_node = node_def_pb2.NodeDef()
        index0_node.name = "index0"
        index0_node.op = "Const"
        index0_value = np.array(3).astype(np.int32).reshape(())
        index0_node.attr["value"].CopyFrom(
            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                index0_value, index0_value.dtype.type, index0_value.shape)))

        index1_node = node_def_pb2.NodeDef()
        index1_node.name = "index1"
        index1_node.op = "Const"
        index1_value = np.array(1).astype(np.int32).reshape(())
        index1_node.attr["value"].CopyFrom(
            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                index1_value, index1_value.dtype.type, index1_value.shape)))

        minus_node = node_def_pb2.NodeDef()
        minus_node.name = "sub"
        minus_node.op = "Sub"
        minus_node.input.extend([index0_node.name, index1_node.name])

        graph_def.node.extend([index0_node, index1_node, minus_node])
        rewriter = GraphFoldConstantOptimizer(graph_def)
        new_graph = rewriter.do_transformation()
        with tf.compat.v1.Session() as sess:
            tf.compat.v1.import_graph_def(new_graph)
Example #15
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None):
    """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
                              all variables are converted).

  Returns:
    GraphDef containing a simplified version of the original.
  """
    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    found_variables = {}
    variable_names = []
    variable_dict_names = []
    for node in inference_graph.node:
        if node.op == "Variable":
            variable_name = node.name
            if (variable_names_whitelist is not None
                    and variable_name not in variable_names_whitelist):
                continue
            variable_dict_names.append(variable_name)
            variable_names.append(variable_name + ":0")
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))
    logging.info("Froze %d variables." % len(returned_variables))

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])
    print("Converted %d variables to const ops." % how_many_converted)
    return output_graph_def
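A hedged sketch of the call; it assumes an active v1 Session `sess` whose graph contains legacy 'Variable' nodes and an output node named 'logits' (both hypothetical). Note that this variant only matches the 'Variable' op type, so graphs built with VariableV2 or resource variables need the extended version shown in Example #21:

frozen_def = convert_variables_to_constants(
    sess, sess.graph_def, output_node_names=["logits"],
    variable_names_whitelist=None)   # None freezes every matched variable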
Example #16
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
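A hedged graph-mode sketch of calling the while_loop above directly, mirroring tf.while_loop semantics; the loop bound is illustrative:

from tensorflow.python.framework import constant_op, ops

with ops.Graph().as_default():
    i0 = constant_op.constant(0)
    # Emits a single While op; returns the loop vars in their original structure.
    result = while_loop(cond=lambda i: i < 5,
                        body=lambda i: i + 1,
                        loop_vars=[i0])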
Example #17
    def _init_from_args(self,
                        initial_value=None,
                        trainable=True,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        dtype=None,
                        constraint=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    They are not implicitly added to the GLOBAL_VARIABLES or TRAINABLE_VARIABLES
    collections, and the `collections` argument is ignored.
    @end_compatibility
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        self._trainable = trainable
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        self._save_slice_info = None
        self._in_graph_mode = context.in_graph_mode()
        with ops.control_dependencies(None):
            with ops.name_scope(
                    name, "Variable",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                handle_name = ops._name_from_scope_name(name)
                if init_from_fn:
                    # Use attr_scope and device(None) to simulate the behavior of
                    # colocate_with when the variable we want to colocate with doesn't
                    # yet exist.
                    if self._in_graph_mode:
                        attr = attr_value_pb2.AttrValue(
                            list=attr_value_pb2.AttrValue.ListValue(
                                s=[compat.as_bytes("loc:@%s" % handle_name)]))
                        with ops.get_default_graph()._attr_scope(
                            {"_class": attr}):
                            with ops.name_scope("Initializer"), ops.device(
                                    None):
                                initial_value = ops.convert_to_tensor(
                                    initial_value(),
                                    name="initial_value",
                                    dtype=dtype)
                            self._handle = _eager_safe_variable_handle(
                                shape=initial_value.get_shape(),
                                dtype=initial_value.dtype.base_dtype,
                                shared_name=handle_name,
                                name=name,
                                graph_mode=self._in_graph_mode)
                            self._handle_device = (
                                self._handle.device if self._in_graph_mode else
                                context.get_default_context().device_name)
                            self._graph_shape = initial_value.get_shape()
                    else:
                        initial_value = initial_value()
                        with ops.name_scope("Initializer"):
                            initial_value = ops.convert_to_tensor(
                                initial_value,
                                name="initial_value",
                                dtype=dtype)
                        self._handle = _eager_safe_variable_handle(
                            shape=initial_value.get_shape(),
                            dtype=initial_value.dtype.base_dtype,
                            shared_name=handle_name,
                            name=name,
                            graph_mode=False)
                        self._handle_device = (
                            self._handle.device if self._in_graph_mode else
                            context.get_default_context().device_name)
                        self._graph_shape = initial_value.get_shape()
                # pylint: enable=protected-access

                # Or get the initial value from a Tensor or Python object.
                else:
                    with ops.name_scope("Initializer"):
                        initial_value = ops.convert_to_tensor(
                            initial_value, name="initial_value", dtype=dtype)
                    # pylint: disable=protected-access
                    if (self._in_graph_mode and initial_value is not None
                            and initial_value.op._get_control_flow_context()
                            is not None):
                        raise ValueError(
                            "Initializer for variable %s is from inside a control-flow "
                            "construct, such as a loop or conditional. When creating a "
                            "variable inside a loop or conditional, use a lambda as the "
                            "initializer." % name)
                    # pylint: enable=protected-access
                    self._handle = _eager_safe_variable_handle(
                        shape=initial_value.get_shape(),
                        dtype=initial_value.dtype.base_dtype,
                        shared_name=handle_name,
                        name=name,
                        graph_mode=self._in_graph_mode)
                    self._handle_device = (
                        self._handle.device if self._in_graph_mode else
                        context.get_default_context().device_name)
                    self._graph_shape = initial_value.get_shape()

                self._initial_value = initial_value if self._in_graph_mode else None
                self._handle_name = handle_name + ":0"
                self._dtype = initial_value.dtype.base_dtype
                self._constraint = constraint

                if self._in_graph_mode:
                    with ops.name_scope("IsInitialized"):
                        self._is_initialized_op = (
                            gen_resource_variable_ops.var_is_initialized_op(
                                self._handle))
                    if initial_value is not None:
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                self._handle):
                            self._initializer_op = (
                                gen_resource_variable_ops.assign_variable_op(
                                    self._handle,
                                    self._build_initializer_expr(
                                        initial_value),
                                    name=n))
                    with ops.name_scope("Read"), ops.colocate_with(
                            self._handle):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(self._handle_device):
                            value = self._read_variable_op()
                        self._graph_element = value
                        if caching_device is not None:
                            # Variables may be created in a tf.device() or ops.colocate_with()
                            # context. At the same time, users would expect caching device to
                            # be independent of this context, and/or would not expect the
                            # current device context to be merged with the caching device
                            # spec.  Therefore we reset the colocation stack before creating
                            # the cached value. Note that resetting the colocation stack will
                            # also reset the device stack.
                            with ops.colocate_with(None, ignore_existing=True):
                                with ops.device(caching_device):
                                    self._cached_value = array_ops.identity(
                                        value)
                        else:
                            self._cached_value = None
                else:
                    gen_resource_variable_ops.assign_variable_op(
                        self._handle, initial_value)
                    self._is_initialized_op = None
                    self._initializer_op = None
                    self._graph_element = None
                    if caching_device:
                        with ops.device(caching_device):
                            self._cached_value = self._read_variable_op()
                    else:
                        self._cached_value = None
                if context.in_graph_mode():
                    ops.add_to_collections(collections, self)
                elif ops.GraphKeys.GLOBAL_STEP in collections:
                    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
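A hedged sketch of the public construction path that reaches _init_from_args above: passing a zero-argument callable as initial_value so the value is only materialized under the "Initializer" name scope. The shape and name are illustrative:

from tensorflow.python.ops import init_ops, resource_variable_ops

v = resource_variable_ops.ResourceVariable(
    initial_value=lambda: init_ops.ones_initializer()(shape=[2, 3]),
    name="my_var")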
Example #18
    def do_transformation(self):
        """Removes batch normalization ops by folding them into convolutions.

        Batch normalization during training has multiple dynamic parameters that are
        updated, but once the graph is finalized these become constants. That means
        there's an opportunity to reduce the computations down to a scale and
        addition, rather than the more expensive multiple ops, and even bake the
        scaling into the convolution weights. This function identifies the typical
        pattern of batch normalization subgraphs, and performs the transformation to
        fold the computations down into a simpler form. It currently only spots batch
        normalization that's performed by the BatchNormWithGlobalNormalization,
        FusedBatchNorm, and FusedBatchNormV3 ops, and will need to be extended in the
        future to handle newer styles.

        Returns:
          Modified graph with BN ops removed, and modified weights.

        Raises:
          ValueError: If the graph is badly formed with duplicate node names.
        """
        cur_graph = GraphAnalyzer()
        cur_graph.graph = self.model

        graph_info = cur_graph.parse_graph()
        target_nodes = cur_graph.query_fusion_pattern_nodes(
            [["Conv2D", "DepthwiseConv2dNative"], ("BiasAdd", "Add", "AddV2"),
             ["BatchNormWithGlobalNormalization", "FusedBatchNorm", "FusedBatchNormV3"]])
        for node_combination in target_nodes:
            matched_node = node_combination[:-1]
            has_add_op = len(node_combination[-1]) == 3
            conv_node = graph_info[Helper.node_name_from_input(matched_node[0])].node
            weights_node_name = graph_info[Helper.node_name_from_input(
                matched_node[0])].node.input[1]
            weights_node = graph_info[Helper.node_name_from_input(weights_node_name)].node
            bn_node = graph_info[Helper.node_name_from_input(matched_node[-1])].node

            if weights_node.op != "Const":
                self.logger.warning("Didn't find expected conv Constant input to '%s',"
                                    " found %s instead. Maybe because freeze_graph wasn't"
                                    " run first?" % (bn_node.name, weights_node_name))
                continue
            weights = Helper.values_from_const(weights_node)

            if conv_node.op == "Conv2D":
                channel_count = weights.shape[3]
            elif conv_node.op == "DepthwiseConv2dNative":
                channel_count = weights.shape[2] * weights.shape[3]

            mean_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("mean_op")])
            mean_node = graph_info[mean_node_name].node

            if mean_node.op != "Const":
                continue

            mean_value = Helper.values_from_const(mean_node)

            if has_add_op:
                bias_node_name = graph_info[Helper.node_name_from_input(
                    matched_node[1])].node.input[1]
                bias_node = graph_info[Helper.node_name_from_input(bias_node_name)].node
                if bias_node.op != "Const":
                    continue

                if mean_value.shape != (channel_count, ):
                    continue

                mean_value = mean_value - Helper.values_from_const(bias_node)
                cur_graph.remove_node(bias_node.name)
                cur_graph.remove_node(matched_node[1])

            if mean_value.shape != (channel_count, ):
                self.logger.warning("Incorrect shape for mean, found %s, expected %s,"
                                    " for node %s" % (str(mean_value.shape), str(
                                        (channel_count, )), conv_node.name))
                continue
            var_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("var_op")])
            var_node = graph_info[var_node_name].node
            if var_node.op != "Const":
                continue
            var_value = Helper.values_from_const(var_node)

            if var_value.shape != (channel_count, ):
                continue

            beta_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("beta_op")])
            beta_node = graph_info[beta_node_name].node
            if beta_node.op != "Const":
                continue
            beta_value = Helper.values_from_const(beta_node)

            if beta_value.shape != (channel_count, ):
                continue

            gamma_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("gamma_op")])
            gamma_node = graph_info[gamma_node_name].node

            if gamma_node.op != "Const":
                continue
            gamma_value = Helper.values_from_const(gamma_node)

            if gamma_value.shape != (channel_count, ):
                continue

            variance_epsilon_value = bn_node.attr[self.EPSILON_ATTR[bn_node.op]].f

            if self.scale_after_normalization(bn_node):
                scale_value = (
                    (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) *
                    gamma_value)
            else:
                scale_value = (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value))

            offset_value = (-mean_value * scale_value) + beta_value


            if conv_node.op == "Conv2D":
                original_shape = weights.shape
                tmp_shape = (original_shape[-1], int(weights.size/original_shape[-1]))
                tmp_order = [weights.ndim - 1] + [i for i in range(weights.ndim - 1)]
                scaled_weights = np.copy(weights).transpose(tmp_order).ravel().reshape(tmp_shape)
                reshape_scale = np.array(scale_value).reshape(len(scale_value), 1)
                scaled_weights = np.multiply(
                    scaled_weights, reshape_scale).transpose().reshape(original_shape)
            elif conv_node.op == "DepthwiseConv2dNative":
                scaled_weights = np.copy(weights)
                it = np.nditer(scaled_weights, flags=["multi_index"], op_flags=["readwrite"])
                channel_multiplier = weights.shape[3]
                while not it.finished:
                    current_scale = scale_value[it.multi_index[2] * channel_multiplier +
                                                it.multi_index[3]]
                    it[0] *= current_scale
                    it.iternext()

            scaled_weights_node = node_def_pb2.NodeDef()
            scaled_weights_node.op = "Const"
            scaled_weights_node.name = weights_node_name + "_bn_offset"
            scaled_weights_node.attr["dtype"].CopyFrom(weights_node.attr["dtype"])
            scaled_weights_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    scaled_weights, weights.dtype.type, weights.shape)))
            cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], weights_node_name)

            offset_node = node_def_pb2.NodeDef()
            offset_node.op = "Const"
            offset_node.name = conv_node.name + "_bn_offset"
            offset_node.attr["dtype"].CopyFrom(mean_node.attr["dtype"])
            offset_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    offset_value, mean_value.dtype.type, offset_value.shape)))
            bias_add_node = node_def_pb2.NodeDef()
            bias_add_node.op = "BiasAdd"
            bias_add_node.name = bn_node.name
            bias_add_node.attr["T"].CopyFrom(conv_node.attr["T"])
            bias_add_node.attr["data_format"].CopyFrom(conv_node.attr["data_format"])
            bias_add_node.input.extend([conv_node.name, offset_node.name])

            cur_graph.add_node(offset_node, [], [bias_add_node.name])
            cur_graph.add_node(bias_add_node, conv_node.name,
                               graph_info[Helper.node_name_from_input(matched_node[-1])].outputs)
            cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], weights_node_name)

            cur_graph.remove_node(weights_node_name)
            cur_graph.remove_node(mean_node_name)
            cur_graph.remove_node(var_node_name)
            cur_graph.remove_node(beta_node_name)
            cur_graph.remove_node(gamma_node_name)

        return cur_graph.dump_graph()
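A hedged NumPy sketch of the per-channel arithmetic the transformation above bakes into the graph, with illustrative values: the batch-norm parameters collapse into one scale folded into the convolution weights and one offset that becomes the new BiasAdd constant.

import numpy as np

gamma = np.float32([1.5, 0.8])
beta = np.float32([0.1, -0.2])
mean = np.float32([0.4, 1.0])
var = np.float32([4.0, 9.0])
eps = np.float32(1e-3)

scale = gamma / np.sqrt(var + eps)   # multiplied into each output channel of the weights
offset = beta - mean * scale         # stored in the "<conv>_bn_offset" Const node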
Example #19
    def AddOp(self, op):
        """Create op in XLACompileContext and notifies outer context recursively."""
        # pylint: disable=protected-access
        if op.type in _BLACKLISTED_OPS:
            logging.error(
                'Operation of type %s (%s) is not supported in XLA. Execution will '
                'fail if this op is used in the graph. ', op.type, op.name)

        # TODO(ycao): Automatically disable summaries instead of reporting them.
        if op.type in _UNSUPPORTED_OPS:
            self._unsupported_ops.append(op)

        if any(x.dtype._is_ref_dtype for x in op.inputs):
            raise NotImplementedError(
                'Non-resource Variables are not supported inside XLA computations '
                '(operator name: %s)' % op.name)

        if _XLA_COMPILE_ATTR in op.node_def.attr:
            raise ValueError(
                'XLA compiled computations cannot be nested, (operator '
                'name: %s)' % op.name)

        op._set_attr(_XLA_COMPILE_ATTR,
                     attr_value_pb2.AttrValue(s=self._name_as_bytes))

        op.graph.prevent_feeding(op)
        op.graph.prevent_fetching(op)

        # Remove any control edges from outer control flow contexts. These may cause
        # mismatched frame errors. An example is when one of op's inputs is
        # generated in a different While control flow context.
        (internal_control_inputs,
         external_control_inputs) = self._RemoveExternalControlEdges(op)

        if not op.inputs:
            # Add a control edge from the control pivot to this op.
            if not internal_control_inputs:
                # pylint: disable=protected-access
                op._add_control_input(self._pivot)
                # pylint: enable=protected-access
        else:
            for index in xrange(len(op.inputs)):
                x = op.inputs[index]
                real_x = self.AddValue(x)
                if real_x != x:
                    op._update_input(index, real_x)  # pylint: disable=protected-access

        if external_control_inputs:
            # Use an identity to pull control inputs as data inputs. Note that we
            # ignore ops which don't have outputs. TODO(phawkins): fix that.
            with ops.control_dependencies(None):
                self.Enter()
                external_control_inputs = [
                    array_ops.identity(x.outputs[0]).op
                    for x in external_control_inputs if x.outputs
                ]
                self.Exit()
            # pylint: disable=protected-access
            op._add_control_inputs(external_control_inputs)
            # pylint: enable=protected-access

        # Mark op's outputs as seen by this context and any outer contexts.
        output_names = [x.name for x in op.outputs]
        context = self
        while context is not None:
            # pylint: disable=protected-access
            context._values.update(output_names)
            context = context._outer_context
            # pylint: enable=protected-access

        if self._outer_context:
            self._outer_context.AddInnerOp(op)
Example #20
def strip_unused(input_graph_def, input_tensor_names, output_tensor_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_tensor_names: A list of the nodes we use as inputs.
    output_tensor_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed, and a map from the old input
    names to the new input names.

  Raises:
    ValueError: If any element in `input_tensor_names` refers to an operation
      instead of a tensor.
    KeyError: If any element in `input_tensor_names` is not found in the graph.
  """
    for name in input_tensor_names:
        if ":" not in name:
            raise ValueError("Input '%s' appears to refer to a Operation, "
                             "not a Tensor." % name)

    old2new = {}

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = {name for name in input_tensor_names}
    input_node_names = {name.split(":")[0] for name in input_tensor_names}
    output_node_names = list(
        {name.split(":")[0]
         for name in output_tensor_names})
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        if node.name not in input_node_names:
            for i in range(len(node.input)):
                if _append_port(node.input[i]) in input_tensor_names:
                    old_name = _append_port(node.input[i])
                    not_found.remove(old_name)
                    new_input_name = node.input[i].replace(":", "_")
                    placeholder_node = node_def_pb2.NodeDef()
                    placeholder_node.op = "Placeholder"
                    placeholder_node.name = new_input_name
                    if isinstance(placeholder_type_enum, list):
                        input_node_index = input_tensor_names.index(old_name)
                        placeholder_node.attr["dtype"].CopyFrom(
                            attr_value_pb2.AttrValue(
                                type=placeholder_type_enum[input_node_index]))
                    else:
                        placeholder_node.attr["dtype"].CopyFrom(
                            attr_value_pb2.AttrValue(
                                type=placeholder_type_enum))
                    if "_output_shapes" in node.attr:
                        placeholder_node.attr["_output_shapes"].CopyFrom(
                            node.attr["_output_shapes"])
                    node.input[i] = new_input_name
                    old2new[old_name] = new_input_name + ":0"
                    inputs_replaced_graph_def.node.extend([placeholder_node])
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

    if not_found:
        raise KeyError("The following input nodes were not found: %s\n" %
                       not_found)

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def, old2new
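For context, a minimal usage sketch of `strip_unused` as defined above; the file path and tensor names are placeholders, and the helpers it relies on (`_append_port`, `graph_util.extract_sub_graph`, the proto imports) are assumed to be available as in the surrounding example:

```python
# Hypothetical usage of strip_unused(); path and tensor names are illustrative.
from tensorflow.core.framework import graph_pb2, types_pb2

graph_def = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:       # assumed frozen GraphDef on disk
    graph_def.ParseFromString(f.read())

stripped_def, name_map = strip_unused(
    graph_def,
    input_tensor_names=["input:0"],            # tensors to turn into Placeholders
    output_tensor_names=["logits:0"],          # outputs that must stay reachable
    placeholder_type_enum=types_pb2.DT_FLOAT)  # dtype for the new Placeholders
# name_map maps each old input tensor name to its new placeholder tensor name.
```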
Example #21
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None,
                                   use_fp16=False):
    from tensorflow.python.framework.graph_util_impl import extract_sub_graph
    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.framework import node_def_pb2
    from tensorflow.core.framework import attr_value_pb2
    from tensorflow.core.framework import types_pb2
    from tensorflow.python.framework import tensor_util

    def patch_dtype(input_node, field_name, output_node):
        if use_fp16 and (field_name in input_node.attr) and (input_node.attr[field_name].type == types_pb2.DT_FLOAT):
            output_node.attr[field_name].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))

    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    variable_names = []
    variable_dict_names = []
    for node in inference_graph.node:
        if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = node.name
            if ((variable_names_whitelist is not None and
                 variable_name not in variable_names_whitelist) or
                    (variable_names_blacklist is not None and
                     variable_name in variable_names_blacklist)):
                continue
            variable_dict_names.append(variable_name)
            if node.op == "VarHandleOp":
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]

            if use_fp16 and dtype.type == types_pb2.DT_FLOAT:
                output_node.attr["value"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(data.astype('float16'),
                                                             dtype=types_pb2.DT_HALF,
                                                             shape=data.shape)))
            else:
                output_node.attr["dtype"].CopyFrom(dtype)
                output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(
                    tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type,
                                                         shape=data.shape)))
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp" and (input_node.input[0] in found_variables):
            # ReadVariableOp nodes reading a frozen variable become Identity ops
            # print('- %s | %s ' % (input_node.name, input_node.attr["dtype"]))
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            # mostly op nodes
            output_node.CopyFrom(input_node)

        patch_dtype(input_node, 'dtype', output_node)
        patch_dtype(input_node, 'T', output_node)
        patch_dtype(input_node, 'DstT', output_node)
        patch_dtype(input_node, 'SrcT', output_node)
        patch_dtype(input_node, 'Tparams', output_node)

        if use_fp16 and ('value' in output_node.attr) and (
                output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT):
            # hard-coded value need to be converted as well
            output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
                tensor=tensor_util.make_tensor_proto(
                    output_node.attr['value'].tensor.float_val[0],
                    dtype=types_pb2.DT_HALF)))

        output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(inference_graph.library)
    return output_graph_def
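A hedged usage sketch for this `convert_variables_to_constants` variant, assuming a TF 1.x session whose graph has a node named `logits`; the names and output path are placeholders:

```python
# Hypothetical freezing flow; node names and output path are illustrative.
import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    frozen_def = convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        output_node_names=["logits"],  # node names, without the ":0" port suffix
        use_fp16=False)                # True also casts float32 weights to float16

with tf.gfile.GFile("frozen_model.pb", "wb") as f:
    f.write(frozen_def.SerializeToString())
```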
Example #22
0
def set_attr_dtype(node, key, value):
    try:
        node.attr[key].CopyFrom(
            attr_value_pb2.AttrValue(type=value.as_datatype_enum))
    except KeyError:
        pass
Example #23
def set_attr_dtype(node, key, value):
    """Set a type-valued attr on `node` from a `tf.DType` value."""
    node.attr[key].CopyFrom(attr_value_pb2.AttrValue(type=value.as_datatype_enum))
Example #24
def set_attr_int_list(node, key, value):
    list_value = attr_value_pb2.AttrValue.ListValue(i=value)
    try:
        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
    except KeyError:
        pass
Example #25
def set_attr_bool(node, key, value):
    """Set a bool-valued attr on `node`."""
    node.attr[key].CopyFrom(attr_value_pb2.AttrValue(b=value))
Example #26
def set_attr_float(node, key, value):
    try:
        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(f=value))
    except KeyError:
        pass
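The `set_attr_*` helpers in Examples #22-#26 all wrap the same `AttrValue` constructors; a small illustrative sketch of applying them to a `NodeDef` (the node and attr names below are made up, not taken from the examples):

```python
# Illustrative only: builds a NodeDef and fills attrs via the helpers above.
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import dtypes

node = node_def_pb2.NodeDef()
node.op = "MaxPool"
node.name = "pool1"

set_attr_dtype(node, "T", dtypes.float32)       # type attr from a tf.DType
set_attr_int_list(node, "ksize", [1, 2, 2, 1])  # list(int) attr
set_attr_int_list(node, "strides", [1, 2, 2, 1])
set_attr_bool(node, "some_flag", True)          # bool attr (made-up name)
set_attr_float(node, "some_scale", 0.5)         # float attr (made-up name)
```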
Example #27
    def AddOp(self, op):
        # pylint: disable=protected-access
        if op.type in _BLACKLISTED_OPS:
            logging.error(
                "Operation of type %s (%s) is not supported on the TPU. "
                "Execution will fail if this op is used in the graph. " %
                (op.type, op.name))

        if op.type in _UNSUPPORTED_OPS:
            self._unsupported_ops.append(op)

        if any(x.dtype._is_ref_dtype for x in op.inputs):
            raise NotImplementedError(
                "Non-resource Variables are not supported inside TPU computations "
                "(operator name: %s)" % op.name)
        if _TPU_REPLICATE_ATTR in op.node_def.attr:
            raise ValueError("TPU computations cannot be nested")
        op._set_attr(_TPU_REPLICATE_ATTR,
                     attr_value_pb2.AttrValue(s=self._name_as_bytes))
        if self._outside_compilation_cluster:
            op._set_attr(
                _OUTSIDE_COMPILATION_ATTR,
                attr_value_pb2.AttrValue(
                    s=compat.as_bytes(self._outside_compilation_cluster)))
        if self._num_replicas > 1 or not self._outside_compilation_cluster:
            # Prevent feeding or fetching anything that is being compiled,
            # and any replicated outside_compilation Op.
            op.graph.prevent_feeding(op)
            op.graph.prevent_fetching(op)

        # Remove any control edges from outer control flow contexts. These may cause
        # mismatched frame errors.
        (internal_control_inputs,
         external_control_inputs) = self._RemoveExternalControlEdges(op)

        if not op.inputs:
            # Add a control edge from the control pivot to this op.
            if not internal_control_inputs:
                # pylint: disable=protected-access
                op._add_control_input(self.GetControlPivot())
                # pylint: enable=protected-access
        else:
            for index in xrange(len(op.inputs)):
                x = op.inputs[index]
                real_x = self.AddValue(x)
                if real_x != x:
                    op._update_input(index, real_x)  # pylint: disable=protected-access

        if external_control_inputs:
            # Use an identity to pull control inputs as data inputs. Note that we
            # ignore ops which don't have outputs. TODO(phawkins): fix that.
            with ops.control_dependencies(None):
                self.Enter()
                external_control_inputs = [
                    array_ops.identity(x.outputs[0]).op
                    for x in external_control_inputs if x.outputs
                ]
                self.Exit()
            # pylint: disable=protected-access
            op._add_control_inputs(external_control_inputs)
            # pylint: enable=protected-access

        # Mark op's outputs as seen by this context and any outer contexts.
        output_names = [x.name for x in op.outputs]
        context = self
        while context is not None:
            # pylint: disable=protected-access
            context._values.update(output_names)
            context = context._outer_context
            # pylint: enable=protected-access

        if self._outer_context:
            self._outer_context.AddInnerOp(op)
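The key `AttrValue` usage in `AddOp` is tagging each op with string-valued attrs (`_tpu_replicate` and the outside-compilation cluster). A stand-alone sketch of the same pattern on a plain `NodeDef` proto; the attr name and cluster value are illustrative:

```python
# Illustrative: string attrs are carried as bytes in AttrValue.s.
from tensorflow.core.framework import attr_value_pb2, node_def_pb2
from tensorflow.python.util import compat

node = node_def_pb2.NodeDef()
node.op = "Identity"
node.name = "replicated_op"

node.attr["_tpu_replicate"].CopyFrom(
    attr_value_pb2.AttrValue(s=compat.as_bytes("cluster_0")))
```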
Example #28
    def _init_from_args(
        self,
        initial_value=None,
        trainable=None,
        collections=None,
        caching_device=None,
        name=None,
        dtype=None,
        constraint=None,
        synchronization=None,
        aggregation=None,
        distribute_strategy=None,
        shape=None,
    ):
        """Creates a variable.

        Args:
          initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
            which is the initial value for the Variable. The initial value must have
            a shape specified unless `validate_shape` is set to False. Can also be a
            callable with no argument that returns the initial value when called.
            (Note that initializer functions from init_ops.py must first be bound
             to a shape before being used here.)
          trainable: If `True`, the default, also adds the variable to the graph
            collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
            the default list of variables to use by the `Optimizer` classes.
            Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
            which case it defaults to `False`.
          collections: List of graph collections keys. The new variable is added to
            these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
          caching_device: Optional device string or function describing where the
            Variable should be cached for reading.  Defaults to the Variable's
            device.  If not `None`, caches on another device.  Typical use is to
            cache on the device where the Ops using the Variable reside, to
            deduplicate copying through `Switch` and other conditional statements.
          name: Optional name for the variable. Defaults to `'Variable'` and gets
            uniquified automatically.
          dtype: If set, initial_value will be converted to the given type.
            If None, either the datatype will be kept (if initial_value is
            a Tensor) or float32 will be used (if it is a Python object convertible
            to a Tensor).
          constraint: An optional projection function to be applied to the variable
            after being updated by an `Optimizer` (e.g. used to implement norm
            constraints or value constraints for layer weights). The function must
            take as input the unprojected Tensor representing the value of the
            variable and return the Tensor for the projected value
            (which must have the same shape). Constraints are not safe to
            use when doing asynchronous distributed training.
          synchronization: Indicates when a distributed variable will be
            aggregated. Accepted values are constants defined in the class
            `tf.VariableSynchronization`. By default the synchronization is set to
            `AUTO` and the current `DistributionStrategy` chooses
            when to synchronize.
          aggregation: Indicates how a distributed variable will be aggregated.
            Accepted values are constants defined in the class
            `tf.VariableAggregation`.
          distribute_strategy: DistributionStrategy under which this variable
            was created.
          shape: (optional) The shape of this variable. If None, the shape of
            `initial_value` will be used. When setting this argument to
            `tf.TensorShape(None)` (representing an unspecified shape), the variable
            can be assigned with values of different shapes.

        Raises:
          ValueError: If the initial value is not specified, or does not have a
            shape and `validate_shape` is `True`.

        @compatibility(eager)
        When Eager Execution is enabled, variables are never added to collections.
        They are not implicitly added to the `GLOBAL_VARIABLES` or
        `TRAINABLE_VARIABLES` collections, and the `collections` argument is
        ignored.
        @end_compatibility
        """
        (
            synchronization,
            aggregation,
            trainable,
        ) = variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name)
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if (isinstance(initial_value, ops.Tensor)
                and hasattr(initial_value, "graph")
                and initial_value.graph.building_function):
            raise ValueError(
                "Tensor-typed variable initializers must either be "
                "wrapped in an init_scope or callable "
                "(e.g., `tf.Variable(lambda : "
                "tf.truncated_normal([10, 40]))`) when building "
                "functions. Please file a feature request if this "
                "restriction inconveniences you.")

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        if isinstance(initial_value, trackable.CheckpointInitialValue):
            self._maybe_initialize_trackable()
            self._update_uid = initial_value.checkpoint_position.restore_uid
            initial_value = initial_value.wrapped_value

        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        with ops.init_scope():
            self._in_graph_mode = not context.executing_eagerly()
            with ops.name_scope(
                    name, "TrainableWrapper",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                handle_name = ops.name_from_scope_name(name)
                handle_name = handle_name or "TrainableWrapperHandle"
                if self._in_graph_mode:
                    shared_name = handle_name
                    unique_id = shared_name
                else:
                    # When in eager mode use a uid for the shared_name, to prevent
                    # accidental sharing.
                    unique_id = "%s_%d" % (handle_name, ops.uid())
                    shared_name = None
                # Use attr_scope and device(None) to simulate the behavior of
                # colocate_with when the variable we want to colocate with doesn't
                # yet exist.
                device_context_manager = (ops.device if self._in_graph_mode
                                          else ops.NullContextmanager)
                attr = attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(
                        s=[compat.as_bytes("loc:@%s" % handle_name)]))
                with ops.get_default_graph()._attr_scope({"_class": attr}):
                    with ops.name_scope("Initializer"), device_context_manager(
                            None):
                        initial_value = ops.convert_to_tensor(
                            initial_value() if init_from_fn else initial_value,
                            name="initial_value",
                            dtype=dtype,
                        )
                    if shape is None:
                        shape = initial_value.shape
                    handle = resource_variable_ops.eager_safe_variable_handle(
                        initial_value=initial_value,
                        shape=None,  # shape,
                        shared_name=shared_name,
                        name=name,
                        graph_mode=self._in_graph_mode,
                    )
                # pylint: disable=protected-access
                if (self._in_graph_mode and initial_value is not None
                        and initial_value.op._get_control_flow_context()
                        is not None):
                    raise ValueError(
                        "Initializer for variable %s is from inside a control-flow "
                        "construct, such as a loop or conditional. When creating a "
                        "variable inside a loop or conditional, use a lambda as the "
                        "initializer." % name)
                # pylint: enable=protected-access
                dtype = initial_value.dtype.base_dtype

                if self._in_graph_mode:
                    with ops.name_scope("IsInitialized"):
                        is_initialized_op = (gen_resource_variable_ops.
                                             var_is_initialized_op(handle))
                    if initial_value is not None:
                        # pylint: disable=g-backslash-continuation
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                None, ignore_existing=True), ops.device(
                                    handle.device):
                            # pylint: disable=protected-access
                            initializer_op = gen_resource_variable_ops.assign_variable_op(
                                handle,
                                variables.
                                _try_guard_against_uninitialized_dependencies(
                                    name, initial_value),
                                name=n,
                            )
                            # pylint: enable=protected-access
                        # pylint: enable=g-backslash-continuation
                    with ops.name_scope("Read"):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(handle.device):
                            with ops.control_dependencies([
                                    gen_resource_variable_ops.
                                    assign_variable_op(
                                        handle,
                                        self.prefetch_values(),
                                        name="AssignBeforeInitRead",
                                    )
                            ]):
                                value = gen_resource_variable_ops.read_variable_op(
                                    handle, dtype)
                        graph_element = value
                        if caching_device is not None:
                            # Variables may be created in a tf.device() or ops.colocate_with()
                            # context. At the same time, users would expect caching device to
                            # be independent of this context, and/or would not expect the
                            # current device context to be merged with the caching device
                            # spec.  Therefore we reset the colocation stack before creating
                            # the cached value. Note that resetting the colocation stack will
                            # also reset the device stack.
                            with ops.colocate_with(None, ignore_existing=True):
                                with ops.device(caching_device):
                                    cached_value = array_ops.identity(value)
                        else:
                            cached_value = None
                else:
                    gen_resource_variable_ops.assign_variable_op(
                        handle, initial_value)
                    is_initialized_op = None
                    initializer_op = None
                    graph_element = None
                    if caching_device:
                        with ops.device(caching_device):
                            with ops.control_dependencies([
                                    gen_resource_variable_ops.
                                    assign_variable_op(
                                        handle,
                                        self.prefetch_values(),
                                        name="AssignBeforeInitRead",
                                    )
                            ]):
                                cached_value = (
                                    gen_resource_variable_ops.read_variable_op(
                                        handle, dtype))
                    else:
                        cached_value = None
                if not context.executing_eagerly():
                    # Eager variables are only added to collections if they are part of an
                    # eager variable store (otherwise in an interactive session they would
                    # hog memory and cause OOM). This is done in ops/variable_scope.py.
                    ops.add_to_collections(collections, self)
                elif ops.GraphKeys.GLOBAL_STEP in collections:
                    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
            initial_value = initial_value if self._in_graph_mode else None
            super(resource_variable_ops.ResourceVariable, self).__init__(
                trainable=trainable,
                shape=shape,
                dtype=dtype,
                handle=handle,
                synchronization=synchronization,
                constraint=constraint,
                aggregation=aggregation,
                distribute_strategy=distribute_strategy,
                name=name,
                unique_id=unique_id,
                handle_name=handle_name,
                graph_element=graph_element,
                initial_value=initial_value,
                initializer_op=initializer_op,
                is_initialized_op=is_initialized_op,
                cached_value=cached_value,
            )
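The only `AttrValue` construction in this example is the colocation attr passed to `_attr_scope`; a stand-alone sketch of that proto, with the handle name as a placeholder:

```python
# Illustrative: "_class" is a list(string) attr of b"loc:@<node>" entries.
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.util import compat

handle_name = "TrainableWrapper"  # placeholder
colocation_attr = attr_value_pb2.AttrValue(
    list=attr_value_pb2.AttrValue.ListValue(
        s=[compat.as_bytes("loc:@%s" % handle_name)]))
```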
Example #29
def generate_output_graph(input_graph_def, input_node_map, output_node_map,
                          fuse_op_list, fuse_op_deq_list):
    output_graph_def = graph_pb2.GraphDef()
    skip_list = []
    skip_node_name = []
    int8_type = dtypes.qint8.as_datatype_enum
    uint8_type = dtypes.quint8.as_datatype_enum
    float32_type = dtypes.float32.as_datatype_enum
    qint32_type = dtypes.qint32.as_datatype_enum
    for index, node in enumerate(input_graph_def.node):
        if index in fuse_op_list:
            const_node_1 = input_graph_def.node[index + 1]
            const_node_2 = input_graph_def.node[index + 2]
            requantize_node = input_graph_def.node[index + 3]
            new_node = node_def_pb2.NodeDef()

            new_node.op = node.op + "AndRequantize"
            new_node.name = requantize_node.name
            for _, value in enumerate(node.input):
                new_node.input.append(value)

            new_node.input.append(const_node_1.name)
            new_node.input.append(const_node_2.name)

            new_node.attr["Tinput"].CopyFrom(node.attr['Tinput'])
            new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter'])
            new_node.attr["strides"].CopyFrom(node.attr['strides'])
            new_node.attr["padding"].CopyFrom(node.attr['padding'])
            if input_node_map[new_node.input[0]].op.find("Requantize") != -1:
                bias_node = input_node_map[new_node.input[2]]
                last_node = input_node_map[new_node.input[0]]
                max_input_node = (input_node_map[last_node.input[4][:-2]])
                min_input_node = (input_node_map[last_node.input[3][:-2]])
                max_filter = input_node_map[new_node.input[6]]
                min_filter = input_node_map[new_node.input[5]]

                min_input = (min_input_node.attr['value'].tensor.float_val)[0]
                max_input = (max_input_node.attr['value'].tensor.float_val)[0]
                if 'Depthwise' in node.op or "RequantizePerChannel" in [
                        node.op for node in output_node_map[node.name]
                ]:

                    channel_size = max_filter.attr[
                        'value'].tensor.tensor_shape.dim[0].size
                    max_filter_tensor = tensor_util.MakeNdarray(
                        max_filter.attr['value'].tensor)
                    min_filter_tensor = tensor_util.MakeNdarray(
                        min_filter.attr['value'].tensor)
                else:

                    channel_size = 1
                    max_filter_tensor = []
                    min_filter_tensor = []
                    max_filter_tensor.append(
                        (max_filter.attr['value'].tensor.float_val)[0])
                    min_filter_tensor.append(
                        (min_filter.attr['value'].tensor.float_val)[0])

                bias_tensor = tensor_util.MakeNdarray(
                    input_node_map[new_node.input[2]].attr['value'].tensor)
                bias_length = bias_tensor.shape[0]
                scales = []
                for i in range(channel_size):
                    scales.append(255.0 * 127.0 /
                                  (max(abs(max_input), abs(min_input)) *
                                   max(abs(max_filter_tensor[i]),
                                       abs(min_filter_tensor[i]))))
                int32_bias = []
                if channel_size > 1:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[i]))
                else:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[0]))
                bias_node.attr['dtype'].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
                bias_node.attr['value'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            int32_bias, dtypes.int32, bias_tensor.shape)))

                bias_node.attr['value'].tensor.dtype = qint32_type
                skip_node_name.append(bias_node.name)
                output_graph_def.node.extend([bias_node])
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
            else:
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=float32_type))

            if "padding_list" in node.attr:
                new_node.attr["padding_list"].CopyFrom(
                    node.attr['padding_list'])
            if "dilations" in node.attr:
                new_node.attr["dilations"].CopyFrom(node.attr['dilations'])

            if node.op == "QuantizedConv2D" or node.op == "QuantizedConv2DWithBias":
                new_node.attr["out_type"].CopyFrom(
                    attr_value_pb2.AttrValue(type=int8_type))
            else:
                new_node.attr["out_type"].CopyFrom(
                    attr_value_pb2.AttrValue(type=uint8_type))

            skip_list.append(index + 1)
            skip_list.append(index + 2)
            skip_list.append(index + 3)
            output_graph_def.node.extend(
                [new_node, const_node_1, const_node_2])
        elif index in skip_list or node.name in skip_node_name:
            continue
        elif node.op == "Dequantize":
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            new_node.attr["mode"].s = b"SCALED"
            p_node = input_node_map[new_node.input[0]]
            pp_node = input_node_map[p_node.name].input[0]
            if input_node_map[pp_node].op.find("Relu") != -1 or p_node.op in (
                    "QuantizedAvgPool", "QuantizedMaxPool",
                    "QuantizedConcatV2"):
                new_node.attr["T"].CopyFrom(
                    attr_value_pb2.AttrValue(type=uint8_type))
            elif input_node_map[pp_node].op.find(
                    "QuantizedMatMulWithBias") != -1 and p_node.op.find(
                        "Requantize") != -1:
                new_node.attr["mode"].s = node.attr["mode"].s
                new_node.attr["T"].CopyFrom(
                    attr_value_pb2.AttrValue(type=node.attr["T"].type))
            else:
                new_node.attr["T"].CopyFrom(
                    attr_value_pb2.AttrValue(type=int8_type))
            output_graph_def.node.extend([new_node])
        elif index in fuse_op_deq_list:
            original_summand_node = input_node_map[
                input_graph_def.node[index].input[-1]]
            sum_const_node_1 = input_graph_def.node[index + 1]
            sum_const_node_2 = input_graph_def.node[index + 2]
            sum_requantize_node = input_graph_def.node[index + 3]

            new_node = node_def_pb2.NodeDef()

            new_node.op = node.op + "AndRequantize"
            new_node.name = sum_requantize_node.name
            for _, value in enumerate(node.input[:-1]):
                new_node.input.append(value)
            new_node.input.append(sum_const_node_1.name)
            new_node.input.append(sum_const_node_2.name)
            new_node.input.append(
                input_node_map[original_summand_node.name].input[0])
            new_node.input.append(
                input_node_map[original_summand_node.name].input[0] + ":1")
            new_node.input.append(
                input_node_map[original_summand_node.name].input[0] + ":2")

            # skip_list.append(index + 1)
            # skip_list.append(index + 2)
            skip_list.append(index + 3)

            new_node.attr["Tinput"].CopyFrom(node.attr['Tinput'])
            new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter'])
            new_node.attr["strides"].CopyFrom(node.attr['strides'])
            new_node.attr["padding"].CopyFrom(node.attr['padding'])
            if input_node_map[new_node.input[0]].op.find("Requantize") != -1:

                bias_node = input_node_map[new_node.input[2]]
                last_node = input_node_map[new_node.input[0]]
                max_input_node = (input_node_map[last_node.input[4][:-2]])
                min_input_node = (input_node_map[last_node.input[3][:-2]])
                max_filter = input_node_map[new_node.input[6]]
                min_filter = input_node_map[new_node.input[5]]

                min_input = (min_input_node.attr['value'].tensor.float_val)[0]
                max_input = (max_input_node.attr['value'].tensor.float_val)[0]

                if "RequantizePerChannel" in [
                        node.op for node in output_node_map[node.name]
                ]:
                    channel_size = max_filter.attr[
                        'value'].tensor.tensor_shape.dim[0].size
                    max_filter_tensor = tensor_util.MakeNdarray(
                        max_filter.attr['value'].tensor)
                    min_filter_tensor = tensor_util.MakeNdarray(
                        min_filter.attr['value'].tensor)
                else:
                    channel_size = 1
                    max_filter_tensor = []
                    min_filter_tensor = []
                    max_filter_tensor.append(
                        (max_filter.attr['value'].tensor.float_val)[0])
                    min_filter_tensor.append(
                        (min_filter.attr['value'].tensor.float_val)[0])

                bias_tensor = (tensor_util.MakeNdarray(
                    input_node_map[new_node.input[2]].attr['value'].tensor))
                bias_length = bias_tensor.shape[0]
                scales = []
                for i in range(channel_size):
                    scales.append(255.0 * 127.0 /
                                  (max(abs(max_input), abs(min_input)) *
                                   max(abs(max_filter_tensor[i]),
                                       abs(min_filter_tensor[i]))))
                int32_bias = []
                if channel_size > 1:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[i]))
                else:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[0]))
                bias_node.attr['dtype'].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
                bias_node.attr['value'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            int32_bias, dtypes.int32, bias_tensor.shape)))
                bias_node.attr['value'].tensor.dtype = qint32_type
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
                skip_node_name.append(bias_node.name)
                output_graph_def.node.extend([bias_node])
            else:
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=float32_type))

            if "padding_list" in node.attr:
                new_node.attr["padding_list"].CopyFrom(
                    node.attr['padding_list'])
            if "dilations" in node.attr:
                new_node.attr["dilations"].CopyFrom(node.attr['dilations'])

            new_node.attr["out_type"].CopyFrom(
                attr_value_pb2.AttrValue(type=uint8_type))

            summand_op_type = uint8_type if dtypes.as_dtype(
                original_summand_node.attr["T"].type
            ) == uint8_type else int8_type

            if summand_op_type == int8_type:
                new_node.op = "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"

            new_node.attr["Tsummand"].CopyFrom(
                attr_value_pb2.AttrValue(type=summand_op_type))
            output_graph_def.node.extend([new_node])
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            output_graph_def.node.extend([new_node])
    return output_graph_def
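The heart of the bias handling above is repacking an int32 bias into a `Const` node that is then labeled `qint32`; a condensed, stand-alone sketch of just that step (bias values and node name are made up):

```python
# Illustrative bias requantization: int32 payload, relabeled as qint32.
from tensorflow.core.framework import attr_value_pb2, node_def_pb2
from tensorflow.python.framework import dtypes, tensor_util

qint32_type = dtypes.qint32.as_datatype_enum
int32_bias = [12, -7, 3]  # already scaled, made-up values

bias_node = node_def_pb2.NodeDef()
bias_node.op = "Const"
bias_node.name = "conv1/bias_quantized"
bias_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=qint32_type))
bias_node.attr["value"].CopyFrom(
    attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
        int32_bias, dtypes.int32, [len(int32_bias)])))
bias_node.attr["value"].tensor.dtype = qint32_type  # same relabeling trick as above
```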
Example #30
    def apply_op(self, op_type_name, name=None, **keywords):
        # pylint: disable=g-doc-args
        """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
        op_info = self._ops.get(op_type_name, None)
        if op_info is None:
            raise RuntimeError("Unrecognized Op name " + op_type_name)
        op_def = op_info.op_def

        # Determine the graph context.
        try:
            # Need to flatten all the arguments into a list.
            # pylint: disable=protected-access
            g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
            # pylint: enable=protected-access
        except AssertionError as e:
            raise RuntimeError(
                "Cannot determine graph for Op '%s' due to: %s" %
                (op_type_name, e))

        # Default name if not specified.
        if name is None:
            name = op_type_name

        # Check for deprecation
        deprecation_version = op_def.deprecation.version
        if deprecation_version:
            producer = g.graph_def_versions.producer
            if producer >= deprecation_version:
                raise NotImplementedError(
                    ("Op %s is not available in GraphDef version %d. "
                     "It has been removed in version %d. %s.") %
                    (op_type_name, producer, deprecation_version,
                     op_def.deprecation.explanation))

        # Fill in the list of default types for all "type" attrs.  This
        # will be used to choose a preferred dtype to convert to in the
        # absence of input type information.
        #
        # TODO(b/31302892): Currently the defaults don't work in the right
        # way if you have two inputs, one of whose type resolution depends
        # on the other.  Handling this will require restructuring this code
        # significantly.
        default_type_attr_map = {}
        for attr_def in op_def.attr:
            if attr_def.type != "type":
                continue
            key = attr_def.name
            if attr_def.HasField("default_value"):
                default_type_attr_map[key] = dtypes.as_dtype(
                    attr_def.default_value.type)

        # Requires that op_def has passed validation (using the C++
        # ValidateOpDef() from ../framework/op_def_util.h).
        attrs = {}
        inputs = []
        input_types = []
        with g.as_default(), ops.name_scope(name) as scope:

            # Perform input type inference
            inferred_from = {}
            for input_arg in op_def.input_arg:
                input_name = input_arg.name
                if input_name in keywords:
                    values = keywords.pop(input_name)
                elif input_name + "_" in keywords:
                    # Handle the case where the name is a keyword or built-in
                    # for Python so we use the name + _ instead.
                    input_name += "_"
                    values = keywords.pop(input_name)
                else:
                    raise TypeError("No argument for input " + input_name)

                # Goals:
                # * Convert values to Tensors if it contains constants.
                # * Verify that values is a list if that matches the input_arg's
                #   type.
                # * If the input_arg's type is determined by attrs, either set
                #   those attrs and validate those attr values are legal (if
                #   they have not yet been set) or validate the input matches
                #   the type indicated by the attrs (if they have already been
                #   inferred via an earlier input).
                # * If the input_arg has an explicit type, make sure the input
                #   conforms.

                if _IsListParameter(input_arg):
                    if not _IsListValue(values):
                        raise TypeError(
                            "Expected list for '%s' argument to '%s' Op, not %s."
                            % (input_name, op_type_name, values))
                    # In cases where we expect all elements of the list to have the
                    # same dtype, try to cast non-Tensor elements to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.number_attr:
                        if input_arg.type_attr in attrs:
                            dtype = attrs[input_arg.type_attr]
                        else:
                            for t in values:
                                if isinstance(t, ops.Tensor):
                                    dtype = t.dtype
                                    break

                        # dtype still not found, prefer using the default dtype
                        # from the attr.
                        if dtype is None and input_arg.type_attr in default_type_attr_map:
                            default_dtype = default_type_attr_map[
                                input_arg.type_attr]

                    try:
                        if not input_arg.is_ref and dtype:
                            dtype = dtypes.as_dtype(dtype).base_dtype
                        values = ops.convert_n_to_tensor(
                            values,
                            name=input_arg.name,
                            dtype=dtype if dtype else None,
                            preferred_dtype=default_dtype,
                            as_ref=input_arg.is_ref)
                        if input_arg.number_attr and len(
                                set(v.dtype.base_dtype for v in values)) > 1:
                            raise TypeError()  # All types should match.
                    except (TypeError, ValueError):
                        # What types does the conversion function think values have?
                        observed_types = []
                        for value in values:
                            try:
                                converted_value = ops.convert_to_tensor(
                                    value, as_ref=input_arg.is_ref)
                                observed_types.append(
                                    converted_value.dtype.base_dtype.name)
                            except (TypeError, ValueError):
                                observed_types.append(
                                    "<NOT CONVERTIBLE TO TENSOR>")
                        observed = ", ".join(observed_types)

                        prefix = (
                            "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                            % (input_name, op_type_name, observed))
                        if input_arg.number_attr:
                            if input_arg.type != types_pb2.DT_INVALID:
                                raise TypeError(
                                    "%s that do not match expected type %s." %
                                    (prefix, dtype.name))
                            elif input_arg.type_attr in attrs:
                                raise TypeError(
                                    "%s that do not match type %s inferred from "
                                    "earlier arguments." %
                                    (prefix, dtype.name))
                            else:
                                raise TypeError("%s that don't all match." %
                                                prefix)
                        else:
                            raise TypeError("%s that are invalid." % prefix)

                    types = [x.dtype for x in values]
                    inputs.extend(values)
                else:
                    # In cases where we have an expected type, try to convert non-Tensor
                    # arguments to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    elif input_arg.type_attr in default_type_attr_map:
                        # The dtype could not be inferred solely from the inputs,
                        # so we prefer the attr's default, so code that adds a new attr
                        # with a default is backwards compatible.
                        default_dtype = default_type_attr_map[
                            input_arg.type_attr]

                    try:
                        values = ops.convert_to_tensor(
                            values,
                            name=input_arg.name,
                            dtype=dtype,
                            as_ref=input_arg.is_ref,
                            preferred_dtype=default_dtype)
                    except TypeError as err:
                        if dtype is None:
                            raise err
                        else:
                            raise TypeError(
                                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                                "type '%s' instead." %
                                (dtypes.as_dtype(dtype).name,
                                 input_arg.name, op_type_name, repr(values),
                                 type(values).__name__))
                    except ValueError:
                        # What type does convert_to_tensor think it has?
                        observed = ops.convert_to_tensor(
                            values, as_ref=input_arg.is_ref).dtype.name
                        prefix = (
                            "Input '%s' of '%s' Op has type %s that does not match"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s expected type of %s." %
                                (prefix, dtypes.as_dtype(input_arg.type).name))
                        else:
                            # Update the maps with the default, if needed.
                            k = input_arg.type_attr
                            if k in default_type_attr_map:
                                if k not in attrs:
                                    attrs[k] = default_type_attr_map[k]
                                    if k not in inferred_from:
                                        inferred_from[k] = "Default in OpDef"

                            raise TypeError(
                                "%s type %s of argument '%s'." %
                                (prefix,
                                 dtypes.as_dtype(
                                     attrs[input_arg.type_attr]).name,
                                 inferred_from[input_arg.type_attr]))

                    types = [values.dtype]
                    inputs.append(values)
                base_types = [x.base_dtype for x in types]

                if input_arg.number_attr:
                    # <number-attr> * <type> or <number-attr> * <type-attr>
                    if input_arg.number_attr in attrs:
                        if len(values) != attrs[input_arg.number_attr]:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d must match "
                                "length %d of argument '%s'." %
                                (input_name, op_type_name, len(values),
                                 attrs[input_arg.number_attr],
                                 inferred_from[input_arg.number_attr]))
                    else:
                        attrs[input_arg.number_attr] = len(values)
                        inferred_from[input_arg.number_attr] = input_name
                        num_attr = _Attr(op_def, input_arg.number_attr)
                        if num_attr.has_minimum and len(
                                values) < num_attr.minimum:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d shorter "
                                "than minimum length %d." %
                                (input_name, op_type_name, len(values),
                                 num_attr.minimum))
                    # All tensors must have the same base type.
                    if any([bt != base_types[0] for bt in base_types]):
                        raise TypeError(
                            "All tensors passed to '%s' of '%s' Op "
                            "must have the same type." %
                            (input_name, op_type_name))
                    if input_arg.type != types_pb2.DT_INVALID:
                        # <number-attr> * <type> case
                        if base_types and base_types[0] != input_arg.type:
                            assert False, "Unreachable"
                    elif input_arg.type_attr in attrs:
                        # <number-attr> * <type-attr> case, where <type-attr> already
                        # has an inferred value.
                        if base_types and base_types[0] != attrs[
                                input_arg.type_attr]:
                            assert False, "Unreachable"
                    else:
                        # <number-attr> * <type-attr> case, where we are now setting
                        # the <type-attr> based on this input
                        if not base_types:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                        _SatisfiesTypeConstraint(base_types[0], type_attr)
                elif input_arg.type_attr:
                    # <type-attr>
                    attr_value = base_types[0]
                    if input_arg.type_attr in attrs:
                        if attrs[input_arg.type_attr] != attr_value:
                            assert False, "Unreachable"
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type, _Attr(op_def, input_arg.type_attr))
                        attrs[input_arg.type_attr] = attr_value
                        inferred_from[input_arg.type_attr] = input_name
                elif input_arg.type_list_attr:
                    # <type-list-attr>
                    attr_value = base_types
                    if input_arg.type_list_attr in attrs:
                        if attrs[input_arg.type_list_attr] != attr_value:
                            raise TypeError(
                                "Input '%s' of '%s' Op has type list of %s that does not "
                                "match type list %s of argument '%s'." %
                                (input_name, op_type_name, ", ".join(
                                    dtypes.as_dtype(x).name
                                    for x in attr_value),
                                 ", ".join(
                                     dtypes.as_dtype(x).name
                                     for x in attrs[input_arg.type_list_attr]),
                                 inferred_from[input_arg.type_list_attr]))
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type,
                                _Attr(op_def, input_arg.type_list_attr))
                        attrs[input_arg.type_list_attr] = attr_value
                        inferred_from[input_arg.type_list_attr] = input_name
                else:
                    # single Tensor with specified type
                    if base_types[0] != input_arg.type:
                        assert False, "Unreachable"

                if input_arg.is_ref:
                    if not all(x.is_ref_dtype for x in types):
                        raise TypeError(
                            "Input '%s' of '%s' Op requires l-value input" %
                            (input_name, op_type_name))
                    input_types.extend(types)
                else:
                    input_types.extend(base_types)

            # Process remaining attrs
            for attr in op_def.attr:
                # Skip attrs that have already had their values inferred
                if attr.name in attrs:
                    if attr.name in keywords:
                        raise TypeError(
                            "Should not specify value for inferred attr '%s'."
                            % attr.name)
                    continue
                if attr.name in keywords:
                    attrs[attr.name] = keywords.pop(attr.name)
                elif attr.name + "_" in keywords:
                    # Attrs whose names match Python keywords have an extra '_'
                    # appended, so we must check for that as well.
                    attrs[attr.name] = keywords.pop(attr.name + "_")
                else:
                    raise TypeError("No argument for attr " + attr.name)

            # Convert attr values to AttrValue protos.
            attr_protos = {}
            for attr_def in op_def.attr:
                key = attr_def.name
                value = attrs[key]
                attr_value = attr_value_pb2.AttrValue()
                if attr_def.HasField("default_value") and value is None:
                    attr_value.CopyFrom(attr_def.default_value)
                    attr_protos[key] = attr_value
                    continue
                if attr_def.type.startswith("list("):
                    if not _IsListValue(value):
                        raise TypeError("Expected list for attr " + key)
                    if attr_def.has_minimum:
                        if len(value) < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed list of length %d "
                                "less than minimum %d." %
                                (key, op_type_name, len(value),
                                 attr_def.minimum))
                    attr_value.list.SetInParent()
                if attr_def.type == "string":
                    attr_value.s = _MakeStr(value, key)
                    if attr_def.HasField("allowed_values"):
                        if attr_value.s not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name,
                                   compat.as_text(attr_value.s), '", "'.join(
                                       map(compat.as_text,
                                           attr_def.allowed_values.list.s))))
                elif attr_def.type == "list(string)":
                    attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                    if attr_def.HasField("allowed_values"):
                        for x in attr_value.list.s:
                            if x not in attr_def.allowed_values.list.s:
                                raise ValueError(
                                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                    %
                                    (key, op_type_name, compat.as_text(x),
                                     '", "'.join(
                                         map(compat.as_text,
                                             attr_def.allowed_values.list.s))))
                elif attr_def.type == "int":
                    attr_value.i = _MakeInt(value, key)
                    if attr_def.has_minimum:
                        if attr_value.i < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed %d less than minimum %d."
                                % (key, op_type_name, attr_value.i,
                                   attr_def.minimum))
                elif attr_def.type == "list(int)":
                    attr_value.list.i.extend([_MakeInt(x, key) for x in value])
                elif attr_def.type == "float":
                    attr_value.f = _MakeFloat(value, key)
                elif attr_def.type == "list(float)":
                    attr_value.list.f.extend(
                        [_MakeFloat(x, key) for x in value])
                elif attr_def.type == "bool":
                    attr_value.b = _MakeBool(value, key)
                elif attr_def.type == "list(bool)":
                    attr_value.list.b.extend(
                        [_MakeBool(x, key) for x in value])
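                # "type" attrs are converted to DataType enum values by
                # _MakeType.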
                elif attr_def.type == "type":
                    attr_value.type = _MakeType(value, attr_def)
                elif attr_def.type == "list(type)":
                    attr_value.list.type.extend(
                        [_MakeType(x, attr_def) for x in value])
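                # Shape attrs are stored as TensorShapeProto messages and
                # tensor attrs as TensorProto messages inside the AttrValue.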
                elif attr_def.type == "shape":
                    attr_value.shape.CopyFrom(_MakeShape(value, key))
                elif attr_def.type == "list(shape)":
                    attr_value.list.shape.extend(
                        [_MakeShape(x, key) for x in value])
                elif attr_def.type == "tensor":
                    attr_value.tensor.CopyFrom(_MakeTensor(value, key))
                elif attr_def.type == "list(tensor)":
                    attr_value.list.tensor.extend(
                        [_MakeTensor(x, key) for x in value])
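                # "func" attrs accept a NameAttrList proto, a function name
                # (bytes or str), or a defined function object; the latter is
                # added to the default graph so it can be referenced by name.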
                elif attr_def.type == "func":
                    if isinstance(value, attr_value_pb2.NameAttrList):
                        attr_value.func.CopyFrom(value)
                    elif isinstance(value, compat.bytes_or_text_types):
                        attr_value.func.name = value
                    else:
                        value.add_to_graph(ops.get_default_graph())
                        attr_value.func.name = value.name
                else:
                    raise TypeError("Unrecognized Attr type " + attr_def.type)

                attr_protos[key] = attr_value
            del attrs  # attrs is no longer authoritative, use attr_protos instead

            # Determine output types (possibly using attrs)
            output_types = []
            output_structure = []
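            # Each output arg is one of: a homogeneous list whose length comes
            # from a number_attr, a single tensor whose dtype comes from a
            # type_attr (or is fixed in the OpDef), or a heterogeneous list
            # described by a type_list_attr. output_structure records None for
            # a single tensor and the list length otherwise, so the flat list
            # of op outputs can be regrouped after the op is created.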
            for arg in op_def.output_arg:
                types = []
                if arg.number_attr:
                    n = _AttrValue(attr_protos, arg.number_attr).i
                    if arg.type_attr:
                        types = ([_AttrValue(attr_protos, arg.type_attr).type]
                                 * n)
                    else:
                        types = [arg.type] * n
                    output_structure.append(n)
                elif arg.type_attr:
                    t = _AttrValue(attr_protos, arg.type_attr)
                    types = [t.type]
                    output_structure.append(None)
                elif arg.type_list_attr:
                    t = _AttrValue(attr_protos, arg.type_list_attr)
                    types = t.list.type
                    output_structure.append(len(types))
                else:
                    types = [arg.type]
                    output_structure.append(None)
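                # Reference-typed outputs use the _ref variant of their dtype.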
                if arg.is_ref:
                    types = [dtypes.as_dtype(x).as_ref for x in types]
                output_types.extend(types)

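            # Anything still left in keywords was consumed neither as an input
            # nor as an attr, so treat it as an unexpected argument.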
            if keywords:
                raise TypeError(
                    "apply_op() got unexpected keyword arguments: " +
                    ", ".join(sorted(keywords.keys())))

            # NOTE(mrry): We add an explicit colocation constraint between
            # the newly created op and any of its reference-typed inputs.
            must_colocate_inputs = [
                val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref
            ]
            with _MaybeColocateWith(must_colocate_inputs):
                # Add Op to graph
                op = g.create_op(op_type_name,
                                 inputs,
                                 output_types,
                                 name=scope,
                                 input_types=input_types,
                                 attrs=attr_protos,
                                 op_def=op_def)
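                # Regroup the flat op outputs to mirror the OpDef's output
                # args. A stateful op whose outputs restructure to an empty
                # list returns the Operation itself, so callers can still
                # depend on its side effects.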
                if output_structure:
                    outputs = op.outputs
                    res = _Restructure(ops.convert_n_to_tensor(outputs),
                                       output_structure)
                    if (isinstance(res, list) and not res
                            and op_def.is_stateful):
                        return op
                    else:
                        return res
                else:
                    return op