Example No. 1
def fold_batch_norms(g):
    # type: (graph.Graph) -> None
    """
  Python port of the Graph Transform Tool rewrite by the same name.

  Identifies instances of the pattern `Conv2D => Mul` and folds the
  multiplication into the convolution's filter coefficients. This pattern
  occurs as a result of `Conv2D => BatchNorm` turning into
  `Conv2D => Mul => Add` when a multi-op batch normalization is used.

  Also covers the related cases where the `Conv2D` is replaced with a `MatMul`
  or a `DepthwiseConv2D`.
  """
    pattern = TreeExpr(op="Mul",
                       alias="mul",
                       inputs=(TreeExpr(
                           op="Conv2D|MatMul|DepthwiseConv2dNative",
                           alias="conv",
                           inputs=(TreeExpr(),
                                   TreeExpr(op="Const", alias="weights"))),
                               TreeExpr(op="Const", alias="mul_values")))

    def action(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        mul_node = match_info["mul"]
        conv_node = match_info["conv"]
        weights_node = match_info["weights"]
        mul_values_node = match_info["mul_values"]

        # Cast to 64-bit float to avoid losing precision
        scale = np.float64(mul_values_node.get_attr("value"))

        # If there is another direct consumer of the output of the convolution,
        # skip the rewrite.
        if len(conv_node.outputs[0].consumers()) > 1:
            return False

        _add_scale_to_conv_weights(conv_node, weights_node, scale)

        # Cut the Mul node out of the graph
        reroute.reroute_ts(mul_node.inputs[0], mul_node.outputs[0])
        g.remove_node_by_name(mul_node.name, False)

        # Const might still be in use; check before removing it.
        if len(mul_values_node.outputs[0].consumers()) == 0:
            g.remove_node_by_name(mul_values_node.name, False)

        # Original rewrite gave the name of the Mul node to the Conv2D. Recreate
        # that behavior here, including putting the node in the collections that
        # the Mul node was a member of.
        g.rename_node(conv_node.name, mul_node.name)
        conv_node.remove_from_collections()
        for collection_name in mul_node.collection_names:
            conv_node.add_to_collection(collection_name)
        return True

    _fixed_point_apply(pattern, action, g)
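
The fold is valid because convolution and matrix multiplication are linear in the filter: a per-output-channel multiply applied after the op equals scaling the filter's output-channel slices before it. Below is a minimal numpy sketch of that identity for the MatMul case; the array names and shapes are illustrative and not part of the library.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))    # batch of inputs
w = rng.standard_normal((3, 5))    # MatMul weights
scale = rng.standard_normal(5)     # per-output-channel multiplier from the batch norm

before = (x @ w) * scale           # MatMul => Mul, as matched by the pattern
after = x @ (w * scale)            # Mul folded into the weight columns
assert np.allclose(before, after)
```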
Example No. 2
def _fixed_point_apply(pattern: TreeExpr,
                       action: Callable[[graph.Graph, Dict[str, node.Node]],
                                        bool], g: graph.Graph):
    """
  Repeatedly apply a pattern-action rule until the graph stops changing.

  Args:
    pattern: Expression that selects a portion of the graph for modification
    action: Rule (as a Callable) that optionally modifies the graph. Returns
      True if modifications occurred and False otherwise.
    g: The graph to modify in place.
  """
    keep_going = True
    while keep_going:
        keep_going = False
        # Each pass walks the full list of nodes once instead of restarting the
        # scan after every change, which avoids O(n^2) behavior.
        nodes_before = g.nodes
        for n in nodes_before:
            if n.graph is None:
                # Node has been removed from the graph.
                continue
            match_info = pattern.eval_from(n)
            if match_info is not None:
                # Found a structural match rooted at the current node. Perform action.
                change_happened = action(g, match_info)
                if change_happened:
                    keep_going = True
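
`_fixed_point_apply` is a generic fixed-point driver: it keeps rescanning the node list and firing the action until a full pass produces no change. A toy sketch of the same control flow over a plain Python list, independent of the graph library and purely illustrative:

```python
def toy_fixed_point_apply(match, action, items):
    """Re-run a match/action rule over `items` until nothing changes."""
    keep_going = True
    while keep_going:
        keep_going = False
        for i in range(len(items)):
            if match(items, i) and action(items, i):
                keep_going = True   # something changed; schedule another pass

# Example rule: collapse adjacent duplicates.
def match_dup(items, i):
    return i + 1 < len(items) and items[i] == items[i + 1]

def drop_dup(items, i):
    del items[i + 1]
    return True

data = ["a", "a", "a", "b", "b", "c"]
toy_fixed_point_apply(match_dup, drop_dup, data)
assert data == ["a", "b", "c"]
```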
Example No. 3
def fold_batch_norms_up(g):
    # type: (graph.Graph) -> None
    """
  Identifies instances of the pattern
  ```
     Mul => Add => (optional ReLU/ReLU6) => [Conv2D|MatMul|DepthwiseConv2d]
  ```
  and the equivalent pattern
  ```
    FusedBatchNorm => (optional ReLU/ReLU6) => [Conv2D|MatMul|DepthwiseConv2d]
  ```
  Then folds the multiplication into the convolution's filter coefficients and
  divides the Add op's constant by the multiplier to compensate for the
  addition now happening before the multiplication.

  If the nonlinearity is a ReLU6, replaces it with
  ```
    ReLU => Min(6 / multiplier from batch norm)
  """
    def compute_input_dim(n):
        # type: (node.Node) -> int
        if n.op_type == "Conv2D" or n.op_type == "DepthwiseConv2dNative":
            return 2
        elif n.op_type == "MatMul":
            return 0
        else:
            raise ValueError("Unexpected op type {}".format(n.op_type))

    pattern_1 = (TreeExpr(op="Conv2D|MatMul|DepthwiseConv2dNative",
                          alias="conv",
                          inputs=(TreeExpr(
                              op="Relu|Relu6",
                              alias="relu",
                              optional=True,
                              inputs=(TreeExpr(
                                  op="Add",
                                  alias="add",
                                  inputs=(TreeExpr(
                                      op="Mul",
                                      alias="mul",
                                      inputs=(TreeExpr(),
                                              TreeExpr(op="Const",
                                                       alias="mul_values"))),
                                          TreeExpr(op="Const",
                                                   alias="add_values"))))),
                                  TreeExpr(op="Const", alias="weights"))))

    def handle_relu6(relu6_op, scale):
        # type: (node.Node, np.ndarray) -> None
        """
     Additional rewrite logic that replaces a ReLU6 op with a ReLU plus a scaled
     minimum.

    Args:
      relu6_op: Original Relu6
      scale: Scale factor pulled from the batch normalization
    """
        # ReLU6 op: min(max(features, 0), 6). Add min() component to graph.
        target_np_type = relu6_op.output(0).dtype.as_numpy_dtype
        min_values = (6. / scale).astype(target_np_type)
        min_node = util.make_simple_binary_op(
            g, relu6_op.name + "/min", "Minimum", relu6_op.output(0),
            util.make_const(g, relu6_op.name + "/min/const",
                            min_values).output(0))
        reroute.reroute_ts(min_node.output(0),
                           relu6_op.output(0),
                           cannot_modify=[min_node])
        relu6_op.change_op_type("Relu")

    def action_1(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        conv_node = match_info["conv"]
        add_node = match_info["add"]
        mul_node = match_info["mul"]
        weights_node = match_info["weights"]
        mul_values_node = match_info["mul_values"]
        add_values_node = match_info["add_values"]

        # If there is another direct consumer of anything we're about to
        # modify, skip the rewrite.
        for n in (add_node, mul_node, weights_node, add_values_node):
            if len(n.output(0).consumers()) > 1:
                return False

        # Scale the weights to compensate for unscaled inputs.
        scale = np.float64(mul_values_node.get_attr("value"))
        _scale_weights(weights_node, scale, [compute_input_dim(conv_node)])

        # Divide the additive factor to compensate for the multiplication being
        # pulled above the Add.
        add_values = add_values_node.get_attr("value")
        new_add_values = add_values.astype(np.float64) / scale
        add_values_node.replace_attr("value",
                                     new_add_values.astype(add_values.dtype))

        # Cut the Mul node out of the graph
        reroute.reroute_ts(mul_node.inputs[0], mul_node.outputs[0])
        g.remove_node_by_name(mul_node.name, False)

        # Const might still be in use; check before removing it.
        if len(mul_values_node.outputs[0].consumers()) == 0:
            g.remove_node_by_name(mul_values_node.name, False)

        if "relu" in match_info and match_info["relu"].op_type == "Relu6":
            handle_relu6(match_info["relu"], scale)

        return True

    _fixed_point_apply(pattern_1, action_1, g)

    pattern_2 = (TreeExpr(
        op="Conv2D|MatMul|DepthwiseConv2dNative",
        alias="conv",
        inputs=(TreeExpr(op="Relu|Relu6",
                         alias="relu",
                         optional=True,
                         inputs=(TreeExpr(op="FusedBatchNorm",
                                          alias="batch_norm",
                                          inputs=(TreeExpr(),
                                                  TreeExpr(op="Const"),
                                                  TreeExpr(op="Const"),
                                                  TreeExpr(op="Const"),
                                                  TreeExpr(op="Const"))), )),
                TreeExpr(op="Const", alias="weights"))))

    def action_2(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        conv_node = match_info["conv"]
        batch_norm_node = match_info["batch_norm"]
        weights_node = match_info["weights"]

        # If there is another direct consumer of anything we're about to
        # modify, skip the rewrite.
        for n in (batch_norm_node, weights_node):
            if len(n.output(0).consumers()) > 1:
                return False

        scale, offset = _get_scale_and_offset(match_info)

        # Scale the weights to compensate for unscaled inputs.
        _scale_weights(weights_node, scale, [compute_input_dim(conv_node)])

        # Divide the additive factor to compensate for the multiplication being
        # pulled above the fused batch norm's embedded addition.
        offset /= scale
        _replace_batch_norm_with_bias_add(g, match_info, offset)

        if "relu" in match_info and match_info["relu"].op_type == "Relu6":
            handle_relu6(match_info["relu"], scale)

        return True

    _fixed_point_apply(pattern_2, action_2, g)
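
Both actions above rely on identities that hold when the batch-norm multiplier is positive: `relu(s * y) == s * relu(y)`, which lets the scale migrate past the nonlinearity into the weights, and `relu6(s * y) == min(s * relu(y), 6)`, which is why `handle_relu6` replaces the `Relu6` with a `Relu` followed by a `Minimum` against `6 / scale`. An illustrative numpy check of both identities for the MatMul variant; the names and shapes here are ours, not the library's.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
w = rng.standard_normal((3, 5))
scale = rng.uniform(0.5, 2.0, size=3)   # positive per-channel multiplier
offset = rng.standard_normal(3)         # additive term of the batch norm

def relu(a):
    return np.maximum(a, 0.0)

def relu6(a):
    return np.minimum(relu(a), 6.0)

# Original graph: Mul => Add => Relu => MatMul.
before = relu(x * scale + offset) @ w
# After action_1: scale folded into the weights along the input dimension,
# additive constant divided by the scale.
after = relu(x + offset / scale) @ (scale[:, None] * w)
assert np.allclose(before, after)

# ReLU6 case: the rewrite emits Relu followed by Minimum(6 / scale).
before6 = relu6(x * scale + offset) @ w
after6 = np.minimum(relu(x + offset / scale), 6.0 / scale) @ (scale[:, None] * w)
assert np.allclose(before6, after6)
```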
Example No. 4
def fold_old_batch_norms(g):
    # type: (graph.Graph) -> None
    """
  Python port of the Graph Transform Tool rewrite by the same name.

  This rewrite looks for instances of the pattern `Conv2D => [batch norm]`,
  where [batch norm] is a fused batch normalization operator.

  The rewrite also covers instances of `DepthwiseConv2D => [batch norm]` when
  the channel multiplier of the DepthwiseConv2D op is 1.

  The TF documentation says that this rewrite is only for graphs produced by
  legacy code, but this is not true. As of January 2019, the most recent
  version of TensorFlow produces fused batch normalization operators by default.

  Specifically, legacy code uses the `BatchNormWithGlobalNormalization` op,
  while new code uses the `FusedBatchNorm` op.

  In addition to the basic `Conv2D => [batch norm]` pattern, the rewrite also
  covers cases where intermediate nodes (a `BatchToSpaceND`, or a `Concat` of
  two convolutions) sit between the `Conv2D` and the `[batch norm]` parts. As
  a result, the rewrite proceeds in three passes.
  """
    # Perform three passes to cover three different types of subgraph.
    # PASS 1: Simple Conv2D => [batch norm] pattern.
    pattern_1 = TreeExpr(op="BatchNormWithGlobalNormalization|FusedBatchNorm",
                         alias="batch_norm",
                         inputs=(
                             TreeExpr(op="Conv2D|DepthwiseConv2dNative",
                                      alias="conv",
                                      inputs=(TreeExpr(),
                                              TreeExpr(op="Const",
                                                       alias="weights"))),
                             TreeExpr(op="Const"),
                             TreeExpr(op="Const"),
                             TreeExpr(op="Const"),
                             TreeExpr(op="Const"),
                         ))

    def action_1(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        conv_node = match_info["conv"]
        weights_node = match_info["weights"]

        # If there is another direct consumer of the output of the convolution,
        # skip the rewrite.
        if len(conv_node.outputs[0].consumers()) > 1:
            return False

        scale, offset = _get_scale_and_offset(match_info)
        _add_scale_to_conv_weights(conv_node, weights_node, scale)
        _replace_batch_norm_with_bias_add(g, match_info, offset)
        return True

    _fixed_point_apply(pattern_1, action_1, g)

    # PASS 2: Conv2D|DepthwiseConv2D => BatchToSpaceND => [batch norm]
    pattern_2 = TreeExpr(
        op="BatchNormWithGlobalNormalization|FusedBatchNorm",
        alias="batch_norm",
        inputs=(
            TreeExpr(op="BatchToSpaceND",
                     alias="batch_to_space",
                     inputs=(TreeExpr(op="Conv2D|DepthwiseConv2dNative",
                                      alias="conv",
                                      inputs=(TreeExpr(),
                                              TreeExpr(op="Const",
                                                       alias="weights"))))),
            TreeExpr(op="Const"),
            TreeExpr(op="Const"),
            TreeExpr(op="Const"),
            TreeExpr(op="Const"),
        ))

    def action_2(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        conv_node = match_info["conv"]
        weights_node = match_info["weights"]

        # If there is another direct consumer of the output of the convolution,
        # the BatchToSpaceND, or the convolution weights, skip the rewrite.
        for n in (conv_node, weights_node, match_info["batch_to_space"]):
            if len(n.output(0).consumers()) > 1:
                return False

        scale, offset = _get_scale_and_offset(match_info)
        _add_scale_to_conv_weights(conv_node, weights_node, scale)
        _replace_batch_norm_with_bias_add(g, match_info, offset)
        return True

    _fixed_point_apply(pattern_2, action_2, g)

    # PASS 3: Two Conv2D ops => Concat => [batch norm]
    pattern_3 = TreeExpr(
        op="BatchNormWithGlobalNormalization|FusedBatchNorm",
        alias="batch_norm",
        inputs=(
            TreeExpr(op="ConcatV2|Concat",
                     alias="concat",
                     inputs=(TreeExpr(op="Conv2D",
                                      alias="conv0",
                                      inputs=(TreeExpr(),
                                              TreeExpr(op="Const",
                                                       alias="weights0"))),
                             TreeExpr(op="Conv2D",
                                      alias="conv1",
                                      inputs=(TreeExpr(),
                                              TreeExpr(op="Const",
                                                       alias="weights1"))),
                             TreeExpr(op="Const", alias="axis"))),
            TreeExpr(op="Const"),
            TreeExpr(op="Const"),
            TreeExpr(op="Const"),
            TreeExpr(op="Const"),
        ))

    def action_3(_, match_info):
        # type: (Any, Dict[str, node.Node]) -> bool
        # If there is another direct consumer of anything between a conv and the
        # final output, skip the rewrite.
        if len(match_info["conv0"].outputs[0].consumers()) > 1:
            return False
        if len(match_info["conv1"].outputs[0].consumers()) > 1:
            return False
        if len(match_info["concat"].outputs[0].consumers()) > 1:
            return False

        conv0_node = match_info["conv0"]
        conv1_node = match_info["conv1"]
        weights0_node = match_info["weights0"]
        weights1_node = match_info["weights1"]

        scale, offset = _get_scale_and_offset(match_info)

        axis = match_info["axis"].get_attr("value")
        if axis == 3:
            # Concatenating along channel axis ==> Need to split scale and offset
            split_cols = weights0_node.get_attr("value").shape[3]
            scale_0, offset_0 = scale[:split_cols], offset[:split_cols]
            scale_1, offset_1 = scale[split_cols:], offset[split_cols:]
        else:
            # Concatenating along axis other than channel ==> Scale every channel
            scale_0, offset_0 = scale, offset
            scale_1, offset_1 = scale, offset

        _add_scale_to_conv_weights(conv0_node, weights0_node, scale_0)
        _add_scale_to_conv_weights(conv1_node, weights1_node, scale_1)

        _replace_batch_norm_with_bias_add(g, match_info, offset)
        return True

    _fixed_point_apply(pattern_3, action_3, g)
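
In the third pass, when the concatenation runs along the channel axis the per-channel scale vector from the batch norm is split at the first convolution's output-channel count and each half is folded into its own filter. An illustrative numpy check of that split, with matmuls standing in for the two convolutions and shapes chosen arbitrarily:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
w0 = rng.standard_normal((3, 2))   # first branch: 2 output channels
w1 = rng.standard_normal((3, 5))   # second branch: 5 output channels
scale = rng.standard_normal(7)     # per-channel scale over the concatenated output

before = np.concatenate([x @ w0, x @ w1], axis=-1) * scale

split_cols = w0.shape[-1]          # analogous to splitting at weights0's output-channel count
scale_0, scale_1 = scale[:split_cols], scale[split_cols:]
after = np.concatenate([x @ (w0 * scale_0), x @ (w1 * scale_1)], axis=-1)
assert np.allclose(before, after)
```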