Ejemplo n.º 1
0
def _replace_batch_norm_with_bias_add(g: graph.Graph,
                                      match_info: Dict[str, node.Node],
                                      offset: np.ndarray):
    """
    Replace the fused batch normalization node in the graph with a BiasAdd
    node that applies the offset from the original normalization.
    Then remove the batch normalization node and any of its parameter input
    nodes that are left without consumers.

    Args:
      g: Graph to rewrite in place.
      match_info: Should contain ops under the following keys:
        "batch_norm" ==> fused batch normalization op
        "conv" (or, if absent, "conv0") ==> Convolution or matmul op that
          feeds into the batch normalization; its "data_format" attribute,
          if present, is copied onto the new BiasAdd node
        "mean", "variance", "beta", "gamma" ==> Const nodes containing
          normalization parameters
      offset: Offset that the batch norm node applies at inference time
    """
    batch_norm_node = match_info["batch_norm"]
    # Snapshot everything about the batch norm node that is needed after it
    # has been removed from the graph, so the rewrite does not depend on what
    # removal does to the node object's own state.
    batch_norm_name = batch_norm_node.name
    orig_inputs = tuple(batch_norm_node.inputs)
    orig_collection_names = list(batch_norm_node.collection_names)

    conv_node = (match_info["conv"] if "conv" in match_info
                 else match_info["conv0"])
    data_format = (conv_node.get_attr("data_format")
                   if conv_node.has_attr("data_format") else None)

    # TODO(frreiss): Support non-32-bit offsets
    bias_offset_node = util.make_const(g,
                                       batch_norm_name + "_offset",
                                       np.float32(offset),
                                       uniquify_name=True)
    bias_add_node = g.add_node(batch_norm_name + "_bias_add",
                               "BiasAdd",
                               uniquify_name=True)
    if data_format is not None:
        bias_add_node.add_attr("data_format", data_format)
    bias_add_node.add_attr("T", batch_norm_node.get_attr("T"))
    # Input 0 of the batch norm is the value being normalized.
    bias_add_node.set_inputs([orig_inputs[0], bias_offset_node])
    bias_add_node.set_outputs_from_pairs([(batch_norm_node.output(0).dtype,
                                           batch_norm_node.output(0).shape)])

    # Splice the batch norm op out of the graph and replace with a newly
    # created BiasAdd node.
    # Note that the batch norm node has a bunch of other outputs that aren't
    # used in inference.
    reroute.reroute_ts(bias_add_node.output(0), batch_norm_node.output(0))
    g.remove_node_by_name(batch_norm_name)

    # Original rewrite gave the name of the batch norm node to the BiasAdd.
    # Recreate that behavior here, including putting the node in the
    # collections that the original node was a member of.
    g.rename_node(bias_add_node.name, batch_norm_name)
    for collection_name in orig_collection_names:
        bias_add_node.add_to_collection(collection_name)

    # Remove the input constants if they are no longer used.
    # Inputs 1-4 are the consts that hold normalization parameters. Several of
    # them may come from the same node (e.g. one shared zeros constant), so
    # each node is removed at most once.
    removed_node_names = set()
    for in_tensor in orig_inputs[1:5]:
        input_node_name = in_tensor.node.name
        if input_node_name in removed_node_names:
            continue
        if len(in_tensor.consumers()) == 0:
            g.remove_node_by_name(input_node_name)
            removed_node_names.add(input_node_name)
Ejemplo n.º 2
0
    def handle_relu6(relu6_op: node.Node, scale: np.ndarray):
        """
        Rewrite a ReLU6 op as a ReLU followed by a scaled Minimum.

        ReLU6 computes min(max(features, 0), 6). The max() half is kept by
        converting the op itself to a plain Relu; the min() half becomes a
        separate Minimum node whose threshold is 6 / scale, cast to the
        op's output dtype.

        Args:
          relu6_op: Original Relu6
          scale: Scale factor pulled from the batch normalization
        """
        relu6_out = relu6_op.output(0)
        threshold_dtype = relu6_out.dtype.as_numpy_dtype
        thresholds = (6. / scale).astype(threshold_dtype)

        # The threshold constant is created before the Minimum node that
        # consumes it.
        threshold_const = util.make_const(g, relu6_op.name + "/min/const",
                                          thresholds)
        min_node = util.make_simple_binary_op(g, relu6_op.name + "/min",
                                              "Minimum", relu6_out,
                                              threshold_const.output(0))

        # Move downstream consumers of the Relu6 output onto the Minimum
        # output; the Minimum node itself must keep reading the Relu6 output.
        reroute.reroute_ts(min_node.output(0), relu6_out,
                           cannot_modify=[min_node])
        relu6_op.change_op_type("Relu")