Esempio n. 1
0
    def action(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        """Fold the scale factors of a Mul node into the weights of the
        convolution that feeds it, then cut the Mul out of the graph.

        Args:
          _: Unused.
          match_info: Maps pattern names to matched nodes. Must contain the
            keys "mul", "conv", "weights" and "mul_values".

        Returns:
          True if the graph was rewritten, False if the rewrite was skipped.
        """
        mul_node = match_info["mul"]
        conv_node = match_info["conv"]
        weights_node = match_info["weights"]
        mul_values_node = match_info["mul_values"]

        # Cast to 64-bit float to avoid losing precision
        scale = np.float64(mul_values_node.get_attr("value"))

        # If there is another direct consumer of the output of the convolution,
        # skip the rewrite.
        if len(conv_node.outputs[0].consumers()) > 1:
            return False

        # The weights Const is handed to _add_scale_to_conv_weights below,
        # which presumably rewrites it in place (as action_1 does with
        # _scale_weights). If another op also reads these weights it would
        # silently see the scaled values, so skip the rewrite in that case.
        if len(weights_node.outputs[0].consumers()) > 1:
            return False

        _add_scale_to_conv_weights(conv_node, weights_node, scale)

        # Cut the Mul node out of the graph
        reroute.reroute_ts(mul_node.inputs[0], mul_node.outputs[0])
        g.remove_node_by_name(mul_node.name, False)

        # Const might still be in use; check before removing it.
        if len(mul_values_node.outputs[0].consumers()) == 0:
            g.remove_node_by_name(mul_values_node.name, False)

        # Original rewrite gave the name of the Mul node to the Conv2D. Recreate
        # that behavior here, including putting the node in the collections that
        # the Mul node was a member of.
        g.rename_node(conv_node.name, mul_node.name)
        conv_node.remove_from_collections()
        for collection_name in mul_node.collection_names:
            conv_node.add_to_collection(collection_name)
        return True
Esempio n. 2
0
def _replace_batch_norm_with_bias_add(g: graph.Graph,
                                      match_info: Dict[str, node.Node],
                                      offset: np.ndarray):
    """
    Replace the fused batch normalization node in the graph with a BiasAdd
    node that applies the offset from the original normalization.
    Then remove the batch normalization node and those of its parameter
    constants that are no longer used.

    Args:
      g: Graph to modify in place.
      match_info: Should contain ops under the following keys:
        "batch_norm" ==> fused batch normalization op
        "conv" (or "conv0") ==> Convolution or matmul op that feeds into the
          batch normalization
        "mean", "variance", "beta", "gamma" ==> Const nodes containing
          normalization parameters
      offset: Offset that the batch norm node applies at inference time.
        It is cast to the element type of the batch norm's first output so
        that it matches the BiasAdd's "T" attribute.
    """
    batch_norm_node = match_info["batch_norm"]
    # Capture the inputs before the batch norm node is removed below.
    orig_inputs = batch_norm_node.inputs
    conv_node = match_info["conv"] if "conv" in match_info else match_info[
        "conv0"]
    data_format = (conv_node.get_attr("data_format")
                   if conv_node.has_attr("data_format") else None)

    # BiasAdd uses a single type attribute "T" (copied from the batch norm
    # below) for both its data and bias inputs. Cast the offset to the batch
    # norm's output element type instead of hard-coding float32, so that
    # half/bfloat16 graphs stay well-typed.
    bn_output = batch_norm_node.output(0)
    offset_values = np.asarray(offset).astype(bn_output.dtype.as_numpy_dtype)

    bias_offset_node = util.make_const(g,
                                       batch_norm_node.name + "_offset",
                                       offset_values,
                                       uniquify_name=True)
    bias_add_node = g.add_node(batch_norm_node.name + "_bias_add",
                               "BiasAdd",
                               uniquify_name=True)
    if data_format is not None:
        bias_add_node.add_attr("data_format", data_format)
    bias_add_node.add_attr("T", batch_norm_node.get_attr("T"))
    bias_add_node.set_inputs([batch_norm_node.inputs[0], bias_offset_node])
    bias_add_node.set_outputs_from_pairs([(bn_output.dtype, bn_output.shape)])

    # Splice the batch norm op out of the graph and replace with a newly
    # created BiasAdd node.
    # Note that the batch norm node has a bunch of other outputs that aren't
    # used in inference.
    reroute.reroute_ts(bias_add_node.output(0), batch_norm_node.output(0))
    g.remove_node_by_name(batch_norm_node.name)

    # Original rewrite gave the name of the batch norm node to the BiasAdd.
    # Recreate that behavior here, including putting the node in the
    # collections that the original node was a member of.
    g.rename_node(bias_add_node.name, batch_norm_node.name)
    for collection_name in batch_norm_node.collection_names:
        bias_add_node.add_to_collection(collection_name)

    # Remove the input constants if they are no longer used.
    # Input 0 is the value to be normalized, and inputs 1-4 are the consts that
    # hold normalization parameters. Several parameters may be produced by the
    # same node (e.g. after constant deduplication), so remove each producer
    # at most once; a second removal of the same name would fail.
    removed_names = set()
    for ix in range(1, 5):
        in_tensor = orig_inputs[ix]
        producer_name = in_tensor.node.name
        if producer_name in removed_names:
            continue
        if len(in_tensor.consumers()) == 0:
            g.remove_node_by_name(producer_name)
            removed_names.add(producer_name)
Esempio n. 3
0
    def handle_relu6(relu6_op: node.Node, scale: np.ndarray):
        """Rewrite a ReLU6 op as a ReLU followed by a scaled Minimum.

        ReLU6 computes min(max(features, 0), 6). The op is converted in place
        to a plain Relu, and the min() part is re-added as a separate Minimum
        op against the constant 6 / scale (cast to the op's output element
        type).

        Args:
          relu6_op: The original Relu6 node; its op type is changed to Relu.
          scale: Scale factor being folded out of the graph around this op.
        """
        relu_out = relu6_op.output(0)
        numpy_type = relu_out.dtype.as_numpy_dtype
        clip_values = (6. / scale).astype(numpy_type)

        clip_const = util.make_const(g, relu6_op.name + "/min/const",
                                     clip_values)
        clip_node = util.make_simple_binary_op(g, relu6_op.name + "/min",
                                               "Minimum", relu_out,
                                               clip_const.output(0))

        # Move the op's downstream readers onto the Minimum; cannot_modify
        # keeps the Minimum itself reading the Relu output.
        reroute.reroute_ts(clip_node.output(0), relu6_op.output(0),
                           cannot_modify=[clip_node])
        relu6_op.change_op_type("Relu")
Esempio n. 4
0
def bypass(sgv):
    """Short-circuit a subgraph by connecting its inputs to its outputs.

    Args:
      sgv: The subgraph view to bypass. It is converted to a subgraph view
        using the same rules as subgraph.make_view. Note that sgv is
        modified in place.

    Returns:
      A tuple `(sgv, detached_inputs)` where:
        `sgv` is a new subgraph view of the bypassed subgraph;
        `detached_inputs` is a list of the created input placeholders.

    Raises:
      Whatever subgraph.make_view raises when sgv cannot be converted to a
      SubGraphView (documented upstream as StandardError).
    """
    # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
    view = subgraph.make_view(sgv)
    # Remember the real upstream tensors before detach_inputs swaps them out.
    upstream = list(view.inputs)
    detached_view, placeholders = detach_inputs(view)
    reroute.reroute_ts(upstream, detached_view.outputs)
    return detached_view, placeholders
Esempio n. 5
0
    def action_1(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        """Fold a Mul's scale into the weights of a downstream convolution.

        The weights are scaled along the convolution's input dimension, and
        the Add's constant is divided by the scale. The Mul is then cut out of
        the graph. If the match contains a Relu6, that op is rewritten by
        handle_relu6.

        Args:
          _: Unused.
          match_info: Maps pattern names to matched nodes. Must contain "conv",
            "add", "mul", "weights", "mul_values" and "add_values", and may
            contain "relu".

        Returns:
          True if the graph was rewritten, False if the rewrite was skipped.
        """
        conv_node = match_info["conv"]
        add_node = match_info["add"]
        mul_node = match_info["mul"]
        weights_node = match_info["weights"]
        mul_values_node = match_info["mul_values"]
        add_values_node = match_info["add_values"]
        relu_node = match_info.get("relu")

        # If there is another direct consumer of anything we're about to
        # modify, skip the rewrite. The matched activation's output also
        # changes (it ends up divided by the scale), so it is checked too.
        nodes_to_check = [add_node, mul_node, weights_node, add_values_node]
        if relu_node is not None:
            nodes_to_check.append(relu_node)
        for n in nodes_to_check:
            if len(n.output(0).consumers()) > 1:
                return False

        # Scale the weights to compensate for unscaled inputs.
        scale = np.float64(mul_values_node.get_attr("value"))

        # The Add's constant is divided by the scale below, so a zero scale
        # would produce inf/NaN constants. Moving the multiplication past a
        # Relu/Relu6 relies on max(s * y, 0) == s * max(y, 0), which only
        # holds for s > 0.
        if relu_node is None:
            scale_is_foldable = np.all(scale != 0)
        else:
            scale_is_foldable = np.all(scale > 0)
        if not scale_is_foldable:
            return False

        _scale_weights(weights_node, scale, [compute_input_dim(conv_node)])

        # Divide the additive factor to compensate for the multiplication being
        # pulled above the Add.
        add_values = add_values_node.get_attr("value")
        new_add_values = add_values.astype(np.float64) / scale
        add_values_node.replace_attr("value",
                                     new_add_values.astype(add_values.dtype))

        # Cut the Mul node out of the graph
        reroute.reroute_ts(mul_node.inputs[0], mul_node.outputs[0])
        g.remove_node_by_name(mul_node.name, False)

        # Const might still be in use; check before removing it.
        if len(mul_values_node.outputs[0].consumers()) == 0:
            g.remove_node_by_name(mul_values_node.name, False)

        if relu_node is not None and relu_node.op_type == "Relu6":
            handle_relu6(relu_node, scale)

        return True