def fuse_prelu_with_fused_conv2d_or_matmul(input_graph_def):
  """Fuses Prelu ops into the preceding fused conv/matmul op as its activation.

  Tensorflow does not support Prelu op, and the grappler remap optimizer will
  not fuse the prelu op with _FusedConv2D op. This method searches for the
  pattern (_FusedConv2D || _FusedMatMul || FusedDepthwiseConv2dNative) + Prelu
  and fuses it into the single fused op: 'Prelu' is appended to its
  `fused_ops`, the Prelu's alpha tensor is appended to its inputs, and
  `num_args` is incremented. The Prelu node itself becomes an Identity of the
  fused op and is handed to `cleanup_graph_def` for removal.

  A Prelu is only fused when its input op has at most one entry in
  `fused_ops` and the Prelu is the only node reading that op's output. Other
  readers would otherwise start receiving the activated values.

  Note: the fused nodes of `input_graph_def` are modified in place.

  Args:
    input_graph_def: A GraphDef containing a model.

  Returns:
    Modified graph with Prelu ops fused with _FusedConv2D, _FusedMatMul or
    FusedDepthwiseConv2dNative as activation function.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """
  input_node_map = {}
  # Number of data (non-control) edges reading each node's outputs.
  data_consumer_count = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError('Duplicate node names detected for %s' % node.name)
    for input_name in node.input:
      if not input_name.startswith('^'):
        producer = input_name.split(':', 1)[0]
        data_consumer_count[producer] = data_consumer_count.get(producer, 0) + 1

  nodes_to_skip = {}
  inputs_to_remove = []
  for node in input_graph_def.node:
    if node.op != 'Prelu':
      continue

    fused_op = graph_rewrite_util.node_from_map(
        input_node_map, node.input[0])
    if (not fused_op or
        fused_op.op not in ('_FusedConv2D', '_FusedMatMul',
                            'FusedDepthwiseConv2dNative') or
        len(fused_op.attr['fused_ops'].list.s) > 1):
      continue
    # After fusion the fused op's own output is the activated value, so this
    # is only safe when the Prelu is its sole reader (this also prevents two
    # Prelus on the same op from applying the activation twice).
    if data_consumer_count.get(fused_op.name, 0) > 1:
      continue

    alpha_tensor_name = node.input[1]
    fused_op.input.extend([alpha_tensor_name])
    fused_op.attr['fused_ops'].list.s.extend([b'Prelu'])
    fused_op.attr['num_args'].i = fused_op.attr['num_args'].i + 1

    # Turn the Prelu into a pass-through of the fused op and schedule it for
    # removal.
    node.op = 'Identity'
    node.input[:] = [node.input[0]]
    nodes_to_skip[node.name] = True
    inputs_to_remove.append(node)

  return graph_rewrite_util.cleanup_graph_def(
      input_graph_def, nodes_to_skip, inputs_to_remove)
def _find_contraction_with_bias(node, node_map):
  """Matches a BiasAdd whose first input is a DepthwiseConv2dNative.

  Args:
    node: Candidate NodeDef for the BiasAdd.
    node_map: Map from node name to NodeDef.

  Returns:
    On a match, a dict with 'contraction' (the depthwise conv node), 'bias'
    (`node`) and 'activation' (None); otherwise False.
  """
  # A node without inputs cannot be a BiasAdd over a depthwise conv.
  if node.op != 'BiasAdd' or not node.input:
    return False
  contraction = graph_rewrite_util.node_from_map(node_map, node.input[0])
  if contraction.op == 'DepthwiseConv2dNative':
    return {'contraction': contraction, 'bias': node, 'activation': None}
  return False
def _find_contraction_with_activation(node, node_map):
  """Matches a supported activation applied directly to a depthwise conv.

  Args:
    node: Candidate NodeDef for the activation.
    node_map: Map from node name to NodeDef.

  Returns:
    On a match, a dict with 'contraction' (the depthwise conv node), 'bias'
    (None) and 'activation' (`node`); otherwise False.
  """
  # The activation must be supported and have exactly one input entry
  # (control inputs included). That input must itself be the
  # DepthwiseConv2dNative, with no BiasAdd in between.
  if not _is_supported_activation(node) or len(node.input) != 1:
    return False
  contraction = graph_rewrite_util.node_from_map(node_map, node.input[0])
  if contraction.op != 'DepthwiseConv2dNative':
    return False
  return {'contraction': contraction, 'bias': None, 'activation': node}
def _find_contraction_with_bias_and_activation(node, node_map):
  """Matches a supported activation over BiasAdd(DepthwiseConv2dNative).

  Args:
    node: Candidate NodeDef for the activation.
    node_map: Map from node name to NodeDef.

  Returns:
    The dict produced by `_find_contraction_with_bias` for the activation's
    input, with 'activation' set to `node`; False if the pattern is absent.
  """
  if not _is_supported_activation(node) or len(node.input) != 1:
    return False
  # The activation's single input must match the conv + BiasAdd pattern.
  bias_add = graph_rewrite_util.node_from_map(node_map, node.input[0])
  match = _find_contraction_with_bias(bias_add, node_map)
  if match:
    match['activation'] = node
    return match
  return False
def fuse_ops_for_prelu(input_graph_def):
  """Modifies the provided graph by fusing a set of ops into a single Prelu op.

  The formula of PReLU is:
  f(x) = alpha * x for x < 0, f(x) = x for x >= 0.

  `x` is the input, and `alpha` is a trainable tensor which can be broadcasted
  to the shape of `x`.

  There's no native PRelu op in TensorFlow, so Keras generates the following
  structure which does the equivalent calculation:
  f(x) = Relu(x) + (-alpha * Relu(-x))

  Practically, alpha is always a constant in the inference graph, and grappler
  can have other graph transformations which fold the activation functions to
  other ops. Therefore, we're looking for the structure:

  f(x) = Relu(x) + (negative_alpha * Neg(x, activation=Relu))

  The two operands of the Add/AddV2 may appear in either order. On a match,
  the positive-branch Relu is turned into a Prelu reading the alpha constant,
  the Add becomes an Identity of it, and the Mul/Relu/Neg of the negative
  branch are handed to `cleanup_graph_def` for removal.

  Args:
    input_graph_def: A GraphDef containing a model.

  Returns:
    Modified graph with Prelu ops generated, and modified weights.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """
  input_node_map = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError('Duplicate node names detected for %s' % node.name)

  def match_prelu_pattern(relu_name, mul_name):
    """Matches Relu(x) + Mul(neg_alpha, Relu(Neg(x))) for one operand order.

    Nothing is modified here.

    Returns:
      (relu_op, mul_op, neg_alpha_op, relu_neg_op, neg_op) when the whole
      pattern matches, otherwise None.
    """
    relu_op = graph_rewrite_util.node_from_map(input_node_map, relu_name)
    if not relu_op or relu_op.op != 'Relu':
      return None
    mul_op = graph_rewrite_util.node_from_map(input_node_map, mul_name)
    if not mul_op or mul_op.op != 'Mul':
      return None

    # The first Const input of the Mul is -alpha, and the first Relu input is
    # the negative branch.
    neg_alpha_op = None
    relu_neg_op = None
    for name in mul_op.input:
      op = graph_rewrite_util.node_from_map(input_node_map, name)
      if op.op == 'Const' and neg_alpha_op is None:
        neg_alpha_op = op
      elif op.op == 'Relu' and relu_neg_op is None:
        relu_neg_op = op
    if neg_alpha_op is None:
      return None
    if relu_neg_op is None or len(relu_neg_op.input) != 1:
      return None

    # This detects a Neg op followed by a separated Relu op.
    neg_op = graph_rewrite_util.node_from_map(
        input_node_map, relu_neg_op.input[0])
    if not neg_op or len(neg_op.input) != 1 or neg_op.op != 'Neg':
      return None

    # Both branches must be computed from the same input tensor x.
    if relu_op.input[0] != neg_op.input[0]:
      return None
    return relu_op, mul_op, neg_alpha_op, relu_neg_op, neg_op

  nodes_to_skip = {}
  inputs_to_remove = []
  updated_alpha = []
  for node in input_graph_def.node:
    if node.op not in ('Add', 'AddV2') or len(node.input) != 2:
      continue

    # Add is commutative, so try both operand orders.
    match = (match_prelu_pattern(node.input[0], node.input[1]) or
             match_prelu_pattern(node.input[1], node.input[0]))
    if match is None:
      continue
    relu_op, mul_op, neg_alpha_op, relu_neg_op, neg_op = match

    alpha_tensor_name = neg_alpha_op.name
    # _create_alpha_node updates the constant that the new Prelu reads as
    # alpha (tracked in updated_alpha). Call it only after the whole pattern
    # has matched, so constants of non-PReLU Add/Mul subgraphs are left
    # untouched.
    _create_alpha_node(neg_alpha_op, updated_alpha)

    relu_op.op = 'Prelu'
    relu_op.input.extend([alpha_tensor_name])
    # Remove the T attr that is defined in Relu op, since our custom Prelu op
    # definition does not have that.
    del relu_op.attr['T']

    # The Add now simply forwards the Prelu output.
    node.op = 'Identity'
    del node.input[:]
    node.input.append(relu_op.name)

    nodes_to_skip[mul_op.name] = True
    nodes_to_skip[relu_neg_op.name] = True
    nodes_to_skip[neg_op.name] = True
    nodes_to_skip[node.name] = True
    inputs_to_remove.append(node)

  return graph_rewrite_util.cleanup_graph_def(
      input_graph_def, nodes_to_skip, inputs_to_remove)
def _fetch_bn_channel_param(bn_node, input_node_map, param, channel_count,
                            offset_node=None):
  """Loads one per-channel constant input (mean/var/beta/gamma) of a batch norm.

  Args:
    bn_node: The batch normalization NodeDef.
    input_node_map: Map from node name to NodeDef.
    param: One of "mean", "var", "beta" or "gamma". The input is located via
      INPUT_ORDER[bn_node.op] using the key `param + "_op"`.
    channel_count: Expected length of the 1-D parameter.
    offset_node: Optional Const NodeDef whose value is subtracted from the
      parameter before its shape is checked (used to fold a conv bias into
      the mean).

  Returns:
    A (const_node, value) tuple. Returns None, after logging a warning, when
    the input is not a Const or its shape is not (channel_count,).
  """
  param_op = graph_rewrite_util.node_from_map(
      input_node_map,
      bn_node.input[INPUT_ORDER[bn_node.op].index(param + "_op")])
  if param_op.op != "Const":
    tf_logging.warning("Didn't find expected %s Constant input to '%s',"
                       " found %s instead. Maybe because freeze_graph wasn't"
                       " run first?" % (param, bn_node.name, param_op))
    return None
  value = graph_rewrite_util.values_from_const(param_op)
  if offset_node is not None:
    value = value - graph_rewrite_util.values_from_const(offset_node)
  if value.shape != (channel_count,):
    tf_logging.warning("Incorrect shape for %s, found %s, expected %s,"
                       " for node %s" % (param, str(value.shape), str(
                           (channel_count,)), bn_node.name))
    return None
  return param_op, value


def fold_batch_norms(input_graph_def):
  """Removes batch normalization ops by folding them into convolutions.

  Batch normalization during training has multiple dynamic parameters that are
  updated, but once the graph is finalized these become constants. That means
  there's an opportunity to reduce the computations down to a scale and
  addition, rather than the more expensive multiple ops, and even bake the
  scaling into the convolution weights. This function identifies the typical
  pattern of batch normalization subgraphs, and performs the transformation to
  fold the computations down into a simpler form. It currently only supports
  batch normalization that's performed by the BatchNormWithGlobalNormalization,
  FusedBatchNorm and FusedBatchNormV3 ops, and will need to be extended in the
  future to handle the newer style.

  A batch norm whose input is a Conv2D or DepthwiseConv2dNative (optionally
  followed by an Add/AddV2/BiasAdd with a constant bias) is replaced by:
  - a Const holding the scaled kernel (`<conv>_weights`),
  - a copy of the conv reading it,
  - a Const offset (`<conv>_bn_offset`),
  - a BiasAdd that takes over the batch norm's name.

  A batch norm is left unchanged, with a warning, when any of these holds:
  - one of its inputs is not a constant;
  - an input has an unexpected shape;
  - the conv or add output is also read by other nodes.

  Args:
    input_graph_def: A GraphDef containing a model.

  Returns:
    Modified graph with BN ops removed, and modified weights.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """
  input_node_map = {}
  # Number of data (non-control) edges reading each node's outputs.
  data_consumer_count = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError("Duplicate node names detected for %s" % node.name)
    for input_name in node.input:
      if not input_name.startswith("^"):
        producer = input_name.split(":", 1)[0]
        data_consumer_count[producer] = data_consumer_count.get(producer, 0) + 1

  nodes_to_skip = {}
  new_ops = []
  for node in input_graph_def.node:
    if (node.op not in ("BatchNormWithGlobalNormalization",
                        "FusedBatchNorm", "FusedBatchNormV3")):
      continue

    bias = None
    add_op = None
    conv_op = graph_rewrite_util.node_from_map(
        input_node_map,
        node.input[INPUT_ORDER[node.op].index("conv_op")])
    # There might be an Add/BiasAdd op between the conv and the batchnorm,
    # which we can fold into the mean param of the batchnorm.
    if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
      add_op = conv_op
      # Follow the first input of the add to get to the conv.
      conv_op = graph_rewrite_util.node_from_map(
          input_node_map, add_op.input[0])
      bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[1])
      if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
        # Follow the second input of the add to get to the conv.
        conv_op = graph_rewrite_util.node_from_map(
            input_node_map, add_op.input[1])
        bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[0])
    if bias is not None and bias.op != 'Const':
      tf_logging.warning("The bias %s after the conv %s was not a constant. "
                         "Maybe because freeze_graph wasn't "
                         "run first?" % (bias.name, conv_op.name))
      continue
    if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
      tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
                         " input to '%s'" % node.name)
      continue
    # Folding replaces the conv with a same-named copy using scaled weights
    # and removes the intermediate add. Any other reader of those outputs
    # would therefore see different values or a missing input.
    if (data_consumer_count.get(conv_op.name, 0) > 1 or
        (add_op is not None and
         data_consumer_count.get(add_op.name, 0) > 1)):
      tf_logging.warning("Not folding '%s' because the output of '%s' is also"
                         " used by other nodes." % (node.name, conv_op.name))
      continue

    weights_op = graph_rewrite_util.node_from_map(
        input_node_map, conv_op.input[1])
    if weights_op.op != "Const":
      tf_logging.warning("Didn't find expected conv Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (conv_op.name, weights_op))
      continue
    weights = graph_rewrite_util.values_from_const(weights_op)
    if conv_op.op == "Conv2D":
      channel_count = weights.shape[3]
    else:  # DepthwiseConv2dNative: in_channels * channel_multiplier outputs.
      channel_count = weights.shape[2] * weights.shape[3]

    # Load the per-channel constants, stopping at the first one that is not
    # usable (a warning has already been logged). The conv bias, if any, is
    # folded into the mean.
    params = {}
    for param in ("mean", "var", "beta", "gamma"):
      fetched = _fetch_bn_channel_param(
          node, input_node_map, param, channel_count,
          offset_node=bias if param == "mean" else None)
      if fetched is None:
        break
      params[param] = fetched
    if len(params) < 4:
      continue
    mean_op, mean_value = params["mean"]
    var_value = params["var"][1]
    beta_value = params["beta"][1]
    gamma_value = params["gamma"][1]

    variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
    nodes_to_skip[node.name] = True
    nodes_to_skip[weights_op.name] = True
    nodes_to_skip[conv_op.name] = True
    if add_op is not None:
      nodes_to_skip[add_op.name] = True

    # y = gamma * (x - mean) / sqrt(var + eps) + beta
    #   = x * scale + (beta - mean * scale)
    # (gamma is omitted when scale_after_normalization(node) is false).
    scale_value = 1.0 / np.vectorize(math.sqrt)(
        var_value + variance_epsilon_value)
    if scale_after_normalization(node):
      scale_value = scale_value * gamma_value
    offset_value = (-mean_value * scale_value) + beta_value

    # Bake the per-output-channel scale into the kernel.
    # - Conv2D: the output channel is kernel axis 3.
    # - Depthwise: the output channel is
    #   in_channel * channel_multiplier + multiplier_index (axes 2 and 3).
    #   Reshaping the scale to (in_channels, channel_multiplier) lines it up
    #   for broadcasting.
    # The product is cast back to the kernel dtype.
    if conv_op.op == "Conv2D":
      channel_scale = scale_value
    else:
      channel_scale = np.reshape(scale_value, weights.shape[2:4])
    scaled_weights = (weights * channel_scale).astype(weights.dtype)

    scaled_weights_op = node_def_pb2.NodeDef()
    scaled_weights_op.op = "Const"
    scaled_weights_op.name = conv_op.name + '_weights'
    scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
    scaled_weights_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            scaled_weights, weights.dtype.type, weights.shape)))

    # Copy the conv rather than editing the caller's node. Point the copy's
    # kernel input (input 1, where weights_op was read from) at the scaled
    # weights. Assigning by index also covers a "name:0" style input.
    new_conv_op = node_def_pb2.NodeDef()
    new_conv_op.CopyFrom(conv_op)
    new_conv_op.input[1] = scaled_weights_op.name

    offset_op = node_def_pb2.NodeDef()
    offset_op.op = "Const"
    offset_op.name = conv_op.name + "_bn_offset"
    offset_op.attr["dtype"].CopyFrom(mean_op.attr["dtype"])
    offset_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            offset_value, mean_value.dtype.type, offset_value.shape)))

    # The BiasAdd takes over the batch norm's name, so readers of the batch
    # norm's first output now read the folded result.
    bias_add_op = node_def_pb2.NodeDef()
    bias_add_op.op = "BiasAdd"
    bias_add_op.name = node.name
    bias_add_op.attr["T"].CopyFrom(new_conv_op.attr["T"])
    if "data_format" in new_conv_op.attr:
      bias_add_op.attr["data_format"].CopyFrom(new_conv_op.attr["data_format"])
    bias_add_op.input.extend([new_conv_op.name, offset_op.name])
    new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])

  result_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in nodes_to_skip:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    # Drop control dependencies on nodes that were removed.
    new_node.input[:] = [
        input_name for input_name in new_node.input
        if not input_name.startswith('^') or input_name[1:] not in nodes_to_skip
    ]
    result_graph_def.node.extend([new_node])

  result_graph_def.node.extend(new_ops)
  # Keep the graph's function library; nodes may reference functions in it.
  result_graph_def.library.CopyFrom(input_graph_def.library)
  result_graph_def.versions.CopyFrom(input_graph_def.versions)
  return result_graph_def