def action(_, match_info):
  # type: (Any, Dict[str, node.Node]) -> bool
  mul_node = match_info["mul"]
  conv_node = match_info["conv"]
  weights_node = match_info["weights"]
  mul_values_node = match_info["mul_values"]

  # Cast to 64-bit float to avoid losing precision
  scale = np.float64(mul_values_node.get_attr("value"))

  # If there is another direct consumer of the output of the convolution,
  # skip the rewrite.
  if len(conv_node.outputs[0].consumers()) > 1:
    return False

  _add_scale_to_conv_weights(conv_node, weights_node, scale)

  # Cut the Mul node out of the graph
  reroute.reroute_ts(mul_node.inputs[0], mul_node.outputs[0])
  g.remove_node_by_name(mul_node.name, False)

  # Const might still be in use; check before removing it.
  if len(mul_values_node.outputs[0].consumers()) == 0:
    g.remove_node_by_name(mul_values_node.name, False)

  # Original rewrite gave the name of the Mul node to the Conv2D. Recreate
  # that behavior here, including putting the node in the collections that
  # the Mul node was a member of.
  g.rename_node(conv_node.name, mul_node.name)
  conv_node.remove_from_collections()
  for collection_name in mul_node.collection_names:
    conv_node.add_to_collection(collection_name)
  return True
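# Illustrative sketch, not part of the rewrite above: the fold relies on the
# identity conv(x, W) * s == conv(x, W * s) when the per-channel scale s is
# broadcast over the output-channel axis of W. The helper name and shapes
# below are hypothetical; einsum stands in for a 1x1 convolution in
# NHWC/HWIO layout.
def _sketch_fold_scale_into_conv_weights():
  import numpy as np
  rng = np.random.default_rng(0)
  x = rng.normal(size=(2, 4, 4, 3))           # NHWC input
  w = rng.normal(size=(1, 1, 3, 5))           # HWIO weights of a 1x1 conv
  s = rng.normal(size=(5,))                   # per-output-channel scale
  conv = lambda inp, ker: np.einsum("nhwi,hwio->nhwo", inp, ker)
  mul_after_conv = conv(x, w) * s             # original graph: Conv2D -> Mul
  scale_in_weights = conv(x, w * s)           # rewritten graph: scaled weights
  assert np.allclose(mul_after_conv, scale_in_weights)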
def _replace_batch_norm_with_bias_add(g: graph.Graph,
                                      match_info: Dict[str, node.Node],
                                      offset: np.ndarray):
  """
  Replace the fused batch normalization node in the graph with a BiasAdd
  node that applies the offset from the original normalization.
  Then remove the batch normalization node and its input constants.

  Args:
    match_info: Should contain ops under the following keys:
      "batch_norm" ==> fused batch normalization op
      "conv" ==> Convolution or matmul op that feeds into the batch
        normalization
      "mean", "variance", "beta", "gamma" ==> Const nodes containing
        normalization parameters
    offset: Offset that the batch norm node applies at inference time
  """
  batch_norm_node = match_info["batch_norm"]
  orig_inputs = batch_norm_node.inputs
  conv_node = match_info["conv"] if "conv" in match_info else match_info[
    "conv0"]
  data_format = conv_node.get_attr("data_format") if conv_node.has_attr(
    "data_format") else None

  # TODO(frreiss): Support non-32-bit offsets
  bias_offset_node = util.make_const(g, batch_norm_node.name + "_offset",
                                     np.float32(offset), uniquify_name=True)
  bias_add_node = g.add_node(batch_norm_node.name + "_bias_add", "BiasAdd",
                             uniquify_name=True)
  if data_format is not None:
    bias_add_node.add_attr("data_format", data_format)
  bias_add_node.add_attr("T", batch_norm_node.get_attr("T"))
  bias_add_node.set_inputs([batch_norm_node.inputs[0], bias_offset_node])
  bias_add_node.set_outputs_from_pairs([(batch_norm_node.output(0).dtype,
                                         batch_norm_node.output(0).shape)])

  # Splice the batch norm op out of the graph and replace it with the newly
  # created BiasAdd node.
  # Note that the batch norm node has a bunch of other outputs that aren't
  # used in inference.
  reroute.reroute_ts(bias_add_node.output(0), batch_norm_node.output(0))
  g.remove_node_by_name(batch_norm_node.name)

  # Original rewrite gave the name of the batch norm node to the BiasAdd.
  # Recreate that behavior here, including putting the node in the
  # collections that the original node was a member of.
  g.rename_node(bias_add_node.name, batch_norm_node.name)
  for collection_name in batch_norm_node.collection_names:
    bias_add_node.add_to_collection(collection_name)

  # Remove the input constants if they are no longer used.
  # Input 0 is the value to be normalized, and inputs 1-4 are the consts that
  # hold normalization parameters.
  for ix in range(1, 5):
    in_tensor = orig_inputs[ix]
    if len(in_tensor.consumers()) == 0:
      g.remove_node_by_name(in_tensor.node.name)
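# Illustrative sketch, not part of the library code above: at inference time a
# fused batch norm computes
#   y = gamma * (x - mean) / sqrt(variance + eps) + beta
#     = scale * x + offset,  with  scale  = gamma / sqrt(variance + eps)
#                                  offset = beta - mean * scale.
# Once `scale` has been folded into the upstream conv weights, only `offset`
# remains, which is what the BiasAdd created above applies. The function name
# and values below are hypothetical.
def _sketch_batch_norm_offset():
  import numpy as np
  rng = np.random.default_rng(0)
  x = rng.normal(size=(2, 8))                 # pre-normalization activations
  gamma = rng.normal(size=(8,))
  beta = rng.normal(size=(8,))
  mean = rng.normal(size=(8,))
  variance = np.abs(rng.normal(size=(8,)))
  eps = 1e-3
  scale = gamma / np.sqrt(variance + eps)
  offset = beta - mean * scale
  batch_norm_output = gamma * (x - mean) / np.sqrt(variance + eps) + beta
  folded_output = scale * x + offset          # scaled conv output + BiasAdd
  assert np.allclose(batch_norm_output, folded_output)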
def handle_relu6(relu6_op: node.Node, scale: np.ndarray):
  """
  Additional rewrite logic that replaces a ReLU6 op with a ReLU plus a scaled
  minimum.

  Args:
    relu6_op: Original Relu6
    scale: Scale factor pulled from the batch normalization
  """
  # ReLU6 op: min(max(features, 0), 6). Add min() component to graph.
  target_np_type = relu6_op.output(0).dtype.as_numpy_dtype
  min_values = (6. / scale).astype(target_np_type)
  min_node = util.make_simple_binary_op(
    g, relu6_op.name + "/min", "Minimum", relu6_op.output(0),
    util.make_const(g, relu6_op.name + "/min/const", min_values).output(0))
  reroute.reroute_ts(min_node.output(0), relu6_op.output(0),
                     cannot_modify=[min_node])
  relu6_op.change_op_type("Relu")
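# Illustrative sketch, not part of the library code above: handle_relu6()
# relies on the identity (for a positive scale s)
#   min(max(s * z, 0), 6) == s * min(max(z, 0), 6 / s),
# so once the scale is deferred to the downstream scaled conv weights, the
# Relu6 becomes a plain Relu followed by a Minimum against 6 / s (the
# `min_values` above). The function name below is hypothetical.
def _sketch_relu6_identity():
  import numpy as np
  rng = np.random.default_rng(0)
  z = rng.normal(size=(16,)) * 10.
  s = np.abs(rng.normal()) + 0.1              # positive scale factor
  relu6 = lambda v: np.minimum(np.maximum(v, 0.), 6.)
  original = relu6(s * z)                     # Relu6 applied to scaled input
  rewritten = s * np.minimum(np.maximum(z, 0.), 6. / s)  # Relu, then Minimum
  assert np.allclose(original, rewritten)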
def bypass(sgv):
  """Bypass the given subgraph by connecting its inputs to its outputs.

  Args:
    sgv: the subgraph view to be bypassed. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
  Returns:
    A tuple `(sgv, detached_inputs)` where:
      `sgv` is a new subgraph view of the bypassed subgraph;
      `detached_inputs` is a list of the created input placeholders.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using the
      same rules as the function subgraph.make_view.
  """
  # TODO(fkp): allow plugging sgv.inputs into individual sgv.outputs consumers
  sgv = subgraph.make_view(sgv)
  sgv_inputs = list(sgv.inputs)
  sgv, detached_inputs = detach_inputs(sgv)
  reroute.reroute_ts(sgv_inputs, sgv.outputs)
  return sgv, detached_inputs
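# Illustrative usage sketch, not part of the library code above. It assumes a
# TensorFlow 1.x graph-mode environment where this module is exposed as
# tf.contrib.graph_editor; the graph and names below are hypothetical.
def _sketch_bypass_usage():
  import tensorflow as tf
  from tensorflow.contrib import graph_editor as ge
  with tf.Graph().as_default():
    a = tf.constant(1.0, name="a")
    b = tf.identity(a, name="b")              # op to be routed around
    c = tf.add(b, 2.0, name="c")
    # After bypass, c consumes a directly; b is fed by a new placeholder.
    ge.bypass(ge.sgv(b.op))
    with tf.Session() as sess:
      assert sess.run(c) == 3.0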
def action_1(_, match_info):
  # type: (Any, Dict[str, node.Node]) -> bool
  conv_node = match_info["conv"]
  add_node = match_info["add"]
  mul_node = match_info["mul"]
  weights_node = match_info["weights"]
  mul_values_node = match_info["mul_values"]
  add_values_node = match_info["add_values"]

  # If there is another direct consumer of anything we're about to
  # modify, skip the rewrite.
  for n in (add_node, mul_node, weights_node, add_values_node):
    if len(n.output(0).consumers()) > 1:
      return False

  # Scale the weights to compensate for unscaled inputs.
  scale = np.float64(mul_values_node.get_attr("value"))
  _scale_weights(weights_node, scale, [compute_input_dim(conv_node)])

  # Divide the additive factor to compensate for the multiplication being
  # pulled above the Add.
  add_values = add_values_node.get_attr("value")
  new_add_values = add_values.astype(np.float64) / scale
  add_values_node.replace_attr("value",
                               new_add_values.astype(add_values.dtype))

  # Cut the Mul node out of the graph
  reroute.reroute_ts(mul_node.inputs[0], mul_node.outputs[0])
  g.remove_node_by_name(mul_node.name, False)

  # Const might still be in use; check before removing it.
  if len(mul_values_node.outputs[0].consumers()) == 0:
    g.remove_node_by_name(mul_values_node.name, False)

  if "relu" in match_info and match_info["relu"].op_type == "Relu6":
    handle_relu6(match_info["relu"], scale)
  return True
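# Illustrative sketch, not part of the rewrite above: removing the Mul while
# keeping the Add relies on the identity  x * s + b == (x + b / s) * s  for a
# nonzero scale s. The scale moves into the conv weights (per input channel),
# and the Add constant is divided by s so the value reaching the convolution
# is unchanged. The function name and shapes below are hypothetical.
def _sketch_pull_mul_above_add():
  import numpy as np
  rng = np.random.default_rng(0)
  x = rng.normal(size=(4, 3))
  s = np.abs(rng.normal(size=(3,))) + 0.5     # nonzero per-channel scale
  b = rng.normal(size=(3,))
  mul_then_add = x * s + b                    # original graph: Mul -> Add
  add_then_scale = (x + b / s) * s            # rewritten: Add(b / s), scale folded
  assert np.allclose(mul_then_add, add_then_scale)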