def apply_mean_value(graph: Graph, input_node: Node, node_mean_scale_values: dict):
    # Subtract the per-channel mean from the input by inserting an Add node with the value -mean
    if 'mean' in node_mean_scale_values and node_mean_scale_values['mean'] is not None:
        if all([x == 0 for x in node_mean_scale_values['mean']]):
            return

        out_node = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has no valid shape attribute".format(input_node.id))
        input_shape = input_node.shape

        # Create Add node
        graph.remove_edge(input_node.id, out_node.id)

        value = np.array(node_mean_scale_values['mean']) * (-1)

        add_node = Add(graph, dict(name="Add_"))
        add_data = Op.create_input_data_node(graph, "data_add_", np.array(value))
        Op.expand_node_shape(add_data, len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0)
        add_input = Op.create_data_node(graph, input_node, {'shape': out_node.shape})

        add_node.create_node_with_data(inputs=[add_input, add_data], data_nodes=out_node)
def apply_scale(graph: Graph, input_node: Node, node_mean_scale_values: dict):
    # Divide the input by the per-channel scale by inserting a Mul node with the value 1 / scale
    if 'scale' in node_mean_scale_values and node_mean_scale_values['scale'] is not None:
        if all([x == 1 for x in node_mean_scale_values['scale']]):
            return

        out_node = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has no valid shape attribute".format(input_node.id))
        input_shape = input_node.shape

        # Create Mul node
        value = 1 / np.array(node_mean_scale_values['scale'])
        graph.remove_edge(input_node.id, out_node.id)

        mul_node = Mul(graph, dict(name="Mul_"))
        mul_data = Op.create_input_data_node(graph, "data_mul_", np.array(value))
        Op.expand_node_shape(mul_data, len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0)
        mul_input = Op.create_data_node(graph, input_node, {'shape': out_node.shape})

        mul_node.create_node_with_data(inputs=[mul_input, mul_data], data_nodes=out_node)
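# Illustrative sketch (not part of the Model Optimizer code above): apply_mean_value and
# apply_scale express the preprocessing (x - mean) / scale purely through elementwise
# constants, by inserting Add(-mean) and Mul(1 / scale) nodes. The hypothetical helper below
# checks that identity with plain numpy, assuming the Add(-mean) is applied to the data before
# the Mul(1 / scale); the tensor shapes and mean/scale values are example assumptions only.
def _check_mean_scale_as_add_mul():
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)        # NCHW input
    mean = np.array([123.68, 116.78, 103.94], dtype=np.float32)  # per-channel mean
    scale = np.array([58.4, 57.1, 57.4], dtype=np.float32)       # per-channel scale

    def bc(v):  # broadcast a per-channel vector over an NCHW tensor
        return v.reshape(1, -1, 1, 1)

    reference = (x - bc(mean)) / bc(scale)
    # Graph form: Add with constant -mean, followed by Mul with constant 1 / scale
    rewritten = (x + bc(-mean)) * bc(1.0 / scale)
    assert np.allclose(reference, rewritten, atol=1e-5)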
def _scale_input_action_mul(graph: nx.MultiDiGraph, match: dict, scale: float):
    assert (len(match['placeholder'].out_nodes()))

    tinput = match['placeholder']
    if not tinput.has_valid('shape'):
        raise Error("Node {} has no valid shape attribute".format(tinput.id))
    input_shape = tinput.shape

    toutput = match['data']

    # The inserted Mul multiplies the input by 1 / scale
    value = np.array([1 / scale])

    # Disconnect the input from its data node
    graph.remove_edge(tinput.id, toutput.id)

    # Create Mul node
    mul_node = Mul(graph, dict(name="Mul1_"))
    mul_data = Op.create_input_data_node(graph, "data_mul_scale_", np.array(value))
    Op.expand_node_shape(mul_data, len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0)
    mul_input = Op.create_data_node(graph, tinput, {'shape': toutput.shape})

    mul_node.create_node_with_data(inputs=[mul_input, mul_data], data_nodes=toutput)
def convert_batch_norm(graph: nx.MultiDiGraph):
    """
    This function finds FusedBatchNorm layers (or BatchNorm for MXNet) and replaces each one
    with a Mul->Add->Mul->Add sequence.
    """
    for n in list(graph.nodes()):
        node = Node(graph, n)
        if node.has_valid('op') and (node.op == 'FusedBatchNorm' or node.op == 'BatchNorm'
                                     or node.op == 'BatchNormalization'):
            toutput = node.out_node()
            tinput = node.in_node(0)

            if any([node.in_node(i).value is None for i in range(1, len(node.in_nodes()))]):
                log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(
                    node.name if node.has_valid('name') else '<UNKNOWN>'))
                continue

            const = node.in_node(1)
            beta = node.in_node(2)
            mean = node.in_node(3)
            variance = node.in_node(4)
            eps = node.eps

            if node.has_valid('fix_gamma') and node.fix_gamma:
                const.value.fill(1.)

            can_be_fused = False if not node.soft_get('can_be_fused') else True

            # Remove edges from the FusedBatchNorm node
            graph.remove_edge(tinput.id, node.id)
            graph.remove_edge(beta.id, node.id)
            graph.remove_edge(const.id, node.id)
            graph.remove_edge(mean.id, node.id)
            graph.remove_edge(variance.id, node.id)
            graph.remove_edge(node.id, toutput.id)

            scale = 1. / np.sqrt(variance.value + eps)
            shift = (mean.value * (-1)) * scale

            # Expand dims for the current layout
            broadcast_dims_cnt = len(tinput.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0

            # Update values and shapes with the new shape
            Op.expand_node_shape(const, broadcast_dims_cnt)
            Op.expand_node_shape(beta, broadcast_dims_cnt)

            for idx in range(broadcast_dims_cnt):
                scale = np.expand_dims(scale, axis=-1)
                shift = np.expand_dims(shift, axis=-1)

            _fused_batch_norm_decomposition(graph, tinput, toutput, const, beta, scale, shift, can_be_fused)
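# Illustrative sketch (not part of the pass above): convert_batch_norm relies on the identity
#   (x - mean) / sqrt(variance + eps) * gamma + beta == ((x * scale + shift) * gamma) + beta
# with scale = 1 / sqrt(variance + eps) and shift = -mean * scale, which is what the resulting
# Mul->Add->Mul->Add chain is expected to compute. The hypothetical helper below verifies the
# identity numerically for an NCHW tensor; all names, shapes and values are example assumptions.
def _check_fused_batch_norm_decomposition():
    x = np.random.rand(2, 3, 4, 4).astype(np.float32)
    gamma, beta = np.random.rand(3), np.random.rand(3)
    mean, variance = np.random.rand(3), np.random.rand(3) + 0.5
    eps = 1e-5

    def bc(v):  # broadcast a per-channel vector over an NCHW tensor
        return v.reshape(1, -1, 1, 1)

    reference = (x - bc(mean)) / np.sqrt(bc(variance) + eps) * bc(gamma) + bc(beta)

    # Constants produced by the pass
    scale = 1. / np.sqrt(variance + eps)
    shift = (mean * (-1)) * scale
    decomposed = ((x * bc(scale) + bc(shift)) * bc(gamma)) + bc(beta)

    assert np.allclose(reference, decomposed, atol=1e-6)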
def _bn_to_mul_add_action(graph: nx.MultiDiGraph, match: dict):
    # Data nodes
    tinput = match['input']
    toutput = match['output']
    mean = match['mean']
    variance = match['variance']

    # Op node
    bn_node = match['batch_norm']

    # Disconnect data nodes from the BatchNorm node
    graph.remove_edge(tinput.node, bn_node.node)
    graph.remove_edge(mean.node, bn_node.node)
    graph.remove_edge(variance.node, bn_node.node)

    graph.remove_edge(bn_node.node, toutput.node)

    scale = 1. / np.sqrt(variance.value + bn_node.epsilon)
    shift = (mean.value * (-1)) * scale

    # Reuse the mean and variance data nodes to hold the computed scale and shift values
    mean.value = np.array(scale)
    variance.value = np.array(shift)

    # Expand dims for the current layout
    broadcast_dims_cnt = len(tinput.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0

    # Update values and shapes with the new shape
    Op.expand_node_shape(mean, broadcast_dims_cnt)
    Op.expand_node_shape(variance, broadcast_dims_cnt)

    can_be_fused = False if not bn_node.soft_get('can_be_fused') else True

    mul_node = Mul(graph, dict(name="Mul_", can_be_fused=can_be_fused))
    add_node = Add(graph, dict(name="Add_", can_be_fused=can_be_fused))

    # Connect input->mul->add
    add_node.create_node_with_data(
        inputs=[mul_node.create_node_with_data(inputs=[tinput, mean]), variance],
        data_nodes=toutput)
def convert_scale_shift_to_mul_add(graph: nx.MultiDiGraph):
    nodes = [Node(graph, node) for node in graph.nodes() if Node(graph, node).soft_get('op') == 'ScaleShift']
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        has_biases = True
        has_weights = True

        # We don't need zero biases
        if len(node.in_nodes()) < 3 or all([x == 0 for x in node.in_node(2).value]):
            has_biases = False

        input_node = node.in_node(0)
        scale_node = node.in_node(1)
        shift_node = node.in_node(2) if has_biases else None
        output_node = node.out_node()

        if scale_node.has_valid("value") and all([x == 1 for x in scale_node.value]):
            has_weights = False

        mul_node = Mul(graph, dict(name=node.name + "/Mul_"))
        add_node = Add(graph, dict(name=node.name + "/Add_"))

        # Disconnect ScaleShift node
        graph.remove_edge(input_node.id, node.id)
        graph.remove_edge(node.id, output_node.id)

        # Expand dims for the current layout
        broadcast_dims_cnt = len(input_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
        if scale_node.has_valid("value"):
            Op.expand_node_shape(scale_node, broadcast_dims_cnt)
        else:
            # Insert a Reshape to make the scale broadcastable against the input shape
            reshape_dims = np.zeros(len(input_node.shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            for i in range(node.axis, node.axis + len(scale_node.shape)):
                reshape_dims[i] = scale_node.shape[i - node.axis]
            for i in range(node.axis + len(scale_node.shape), len(input_node.shape)):
                reshape_dims[i] = 1
            reshape = Reshape(graph, dict(name=scale_node.name + "/Broadcast_", dim=reshape_dims))
            scale_node = reshape.create_node_with_data(inputs=[scale_node])

        # Only expand the shift when it actually exists (shift_node is None for zero biases)
        if has_biases:
            Op.expand_node_shape(shift_node, broadcast_dims_cnt)

        # Connect input->mul->out->add->out
        if has_biases:
            add_node.create_node_with_data(
                inputs=[mul_node.create_node_with_data(inputs=[input_node, scale_node]), shift_node],
                data_nodes=output_node)
        elif has_weights:
            mul_node.create_node_with_data(inputs=[input_node, scale_node], data_nodes=output_node)
        else:
            # Neither weights nor biases change the tensor: bypass the ScaleShift entirely
            merge_data_nodes(graph, input_node, output_node)
            graph.remove_node(output_node.id)
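# Illustrative sketch (hypothetical helper, not part of the pass above): when the scale has no
# constant value, convert_scale_shift_to_mul_add builds reshape_dims so the scale broadcasts
# against the input starting at node.axis. The standalone function below reproduces that
# computation with plain Python lists; for example a rank-4 input, a scale of shape (3,) and
# axis=1 give [1, 3, 1, 1].
def _scale_broadcast_dims(input_rank: int, scale_shape: tuple, axis: int) -> list:
    dims = [1] * input_rank
    for i, d in enumerate(scale_shape):
        dims[axis + i] = d
    return dims

assert _scale_broadcast_dims(4, (3,), 1) == [1, 3, 1, 1]
assert _scale_broadcast_dims(4, (3, 224), 1) == [1, 3, 224, 1]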