def _replace_batch_norm_with_bias_add(
        g: graph.Graph,
        match_info: Dict[str, node.Node],
        offset: np.ndarray):
    """
    Replace the fused batch normalization node in the graph with a BiasAdd
    node that applies the offset from the original normalization. Then remove
    the batch normalization node and any of its parameter constants that are
    no longer consumed by anything else.

    The new BiasAdd node takes over the name of the batch normalization node
    and is added to every collection the batch normalization node was in.

    Args:
      g: Graph containing the nodes in ``match_info``; modified in place.
      match_info: Should contain ops under the following keys:
        "batch_norm" ==> fused batch normalization op
        "conv" (or, if absent, "conv0") ==> Convolution or matmul op that
          feeds into the batch normalization
        "mean", "variance", "beta", "gamma" ==> Const nodes containing
          normalization parameters
      offset: Offset that the batch norm node applies at inference time
    """
    batch_norm_node = match_info["batch_norm"]

    # Snapshot everything needed from the batch norm node *before* it is
    # removed from the graph. Removal may detach the node from its
    # collections (and otherwise alter the now-orphaned object), so nothing
    # after the removal should read live state off it.
    bn_name = batch_norm_node.name
    orig_inputs = list(batch_norm_node.inputs)
    collection_names = list(batch_norm_node.collection_names)
    out_dtype = batch_norm_node.output(0).dtype
    out_shape = batch_norm_node.output(0).shape

    conv_node = match_info["conv"] if "conv" in match_info else match_info[
        "conv0"]
    data_format = (conv_node.get_attr("data_format")
                   if conv_node.has_attr("data_format") else None)

    # BiasAdd requires its bias input to share the element type T of its
    # value input, so build the offset constant in the batch norm's output
    # dtype instead of hard-coding float32 (which broke e.g. half-precision
    # graphs).
    bias_values = np.asarray(offset).astype(out_dtype.as_numpy_dtype)
    bias_offset_node = util.make_const(g, bn_name + "_offset", bias_values,
                                       uniquify_name=True)
    bias_add_node = g.add_node(bn_name + "_bias_add", "BiasAdd",
                               uniquify_name=True)
    if data_format is not None:
        bias_add_node.add_attr("data_format", data_format)
    bias_add_node.add_attr("T", batch_norm_node.get_attr("T"))
    # Inputs are tensors, not nodes: wire in output 0 of the offset Const.
    bias_add_node.set_inputs([orig_inputs[0], bias_offset_node.output(0)])
    bias_add_node.set_outputs_from_pairs([(out_dtype, out_shape)])

    # Splice the batch norm op out of the graph and replace with the newly
    # created BiasAdd node.
    # Note that the batch norm node has a bunch of other outputs that aren't
    # used in inference.
    reroute.reroute_ts(bias_add_node.output(0), batch_norm_node.output(0))
    g.remove_node_by_name(bn_name)

    # Original rewrite gave the name of the batch norm node to the BiasAdd.
    # Recreate that behavior here, including putting the node in the
    # collections that the original node was a member of (captured above,
    # before removal).
    g.rename_node(bias_add_node.name, bn_name)
    for collection_name in collection_names:
        bias_add_node.add_to_collection(collection_name)

    # Remove the input constants if they are no longer used.
    # Input 0 is the value to be normalized, and inputs 1-4 are the consts that
    # hold normalization parameters. The same Const may feed more than one
    # parameter slot, so never try to remove a node twice.
    removed_names = set()
    for in_tensor in orig_inputs[1:5]:
        const_name = in_tensor.node.name
        if const_name in removed_names:
            continue
        if len(in_tensor.consumers()) == 0:
            g.remove_node_by_name(const_name)
            removed_names.add(const_name)
def handle_relu6(relu6_op: node.Node, scale: np.ndarray):
    """
    Rewrite a ReLU6 op as a plain ReLU followed by a scaled Minimum.

    ReLU6 computes ``min(max(features, 0), 6)``. This helper keeps the
    ``max(., 0)`` part by turning the op into ``Relu`` and re-creates the
    upper clamp as a separate ``Minimum`` against ``6 / scale``, cast to the
    op's output element type. Every former consumer of the op's output is
    redirected to read the Minimum instead.

    NOTE(review): ``g`` is not a parameter; it is resolved from an enclosing
    (or module) scope and is presumably the graph being rewritten -- confirm
    at the definition site.

    Args:
      relu6_op: Original Relu6
      scale: Scale factor pulled from the batch normalization
    """
    relu_out = relu6_op.output(0)
    elem_type = relu_out.dtype.as_numpy_dtype

    # The clamp value 6 is divided by the batch-norm scale; presumably the
    # scale is being moved across this activation by the caller.
    clamp_values = (6. / scale).astype(elem_type)
    clamp_const = util.make_const(g, relu6_op.name + "/min/const", clamp_values)

    minimum = util.make_simple_binary_op(
        g,
        relu6_op.name + "/min",
        "Minimum",
        relu_out,
        clamp_const.output(0),
    )

    # Point all consumers of the activation at the Minimum, except the
    # Minimum itself, which must keep reading the raw activation.
    reroute.reroute_ts(
        minimum.output(0),
        relu6_op.output(0),
        cannot_modify=[minimum],
    )

    # The lower clamp is all that remains of ReLU6 on the original op.
    relu6_op.change_op_type("Relu")