def _fused_batch_norm_decomposition(graph: Graph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
                                    mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
    """
    This is common function for TF, Caffe and MXNet
    It creates Mul->Add->Mul->Add subgraph:

        toutput = ((tinput * mean) + variance) * gamma + beta

    NOTE(review): despite their names, ``mean`` is used as the first multiplier and
    ``variance`` as the first addend, so callers presumably pass an already
    computed scale/shift pair here -- confirm against the callers.

    :param graph: graph the new Mul/Add nodes are created in
    :param tinput: input data node of the decomposed batch norm
    :param toutput: existing output data node; the final Add writes into it
    :param gamma: data node holding the second multiplier; if its value's first
                  dimension differs from gamma.shape[0], the value's first element
                  is broadcast to gamma.shape
    :param beta: data node holding the final addend (used as-is, not broadcast)
    :param mean: values for the first Mul constant
    :param variance: values for the first Add constant
    :param can_be_fused: forwarded to every created Mul/Add node
    """
    # Create first Mul & Add operations
    mul1_node = Mul(graph, dict(name="Mul1_", can_be_fused=can_be_fused))
    add1_node = Add(graph, dict(name="Add1_", can_be_fused=can_be_fused))

    mul1_data = Op.create_input_data_node(graph, "data_mul_", np.array(mean))
    add1_data = Op.create_input_data_node(graph, "data_add_", np.array(variance))

    # Broadcast const from scalar.
    # We can broadcast only when const.value is scalar. A fresh array is built
    # instead of calling ndarray.resize() in place: resize() raises ValueError when
    # the array does not own its data or is referenced elsewhere, and
    # fill(value[0]) after a multi-dimensional resize passes a sub-array, not a scalar.
    if gamma.shape[0] != gamma.value.shape[0]:
        gamma.value = np.full(gamma.shape, gamma.value.flat[0], dtype=gamma.value.dtype)

    # Create second Mul & Add
    mul2_node = Mul(graph, dict(name="Mul2_", can_be_fused=can_be_fused))
    add2_node = Add(graph, dict(name="Add2_", can_be_fused=can_be_fused))

    # Wire tinput -> Mul1 -> Add1 -> Mul2 -> Add2 -> toutput
    mul1_out = mul1_node.create_node_with_data(inputs=[tinput, mul1_data])
    add1_out = add1_node.create_node_with_data(inputs=[mul1_out, add1_data])
    mul2_out = mul2_node.create_node_with_data(inputs=[add1_out, gamma])
    add2_node.create_node_with_data(inputs=[mul2_out, beta], data_nodes=toutput)
def apply_mean_value(graph: nx.MultiDiGraph, input_node: Node, node_mean_scale_values: dict):
    """
    Subtract user-provided mean values right after an input node.

    Inserts an Add node, fed with the negated mean, between ``input_node`` and its
    current output data node. Nothing is done when no 'mean' entry is present,
    when it is None, or when every mean component equals zero.

    :param graph: graph containing ``input_node``; modified in place
    :param input_node: op node whose output gets the mean subtracted
    :param node_mean_scale_values: dict that may hold a 'mean' entry
                                   (an iterable of numeric values)
    :raises Error: if ``input_node`` has no valid 'shape' attribute
    """
    if 'mean' not in node_mean_scale_values or node_mean_scale_values['mean'] is None:
        return
    if not any(component != 0 for component in node_mean_scale_values['mean']):
        # An all-zero mean would make the inserted Add a no-op.
        return

    data_out = input_node.out_node()
    if not input_node.has_valid('shape'):
        raise Error("Node {} has not valid shape attribute".format(input_node.id))
    in_shape = input_node.shape

    # Detach the input from its output data node; the Add is spliced in between.
    graph.remove_edge(input_node.id, data_out.id)

    negated_mean = np.array(node_mean_scale_values['mean']) * (-1)
    add_op = Add(graph, dict(name="Add_"))
    bias_data = Op.create_input_data_node(graph, "data_add_", np.array(negated_mean))

    # For NCHW, append trailing unit dims so the bias broadcasts against axis 1
    # (assumes axis 1 is the channel axis in this layout -- confirm).
    if graph.graph['layout'] == 'NCHW':
        extra_dims = len(in_shape) - 2
    else:
        extra_dims = 0
    Op.expand_node_shape(bias_data, extra_dims)

    fresh_in = Op.create_data_node(graph, input_node, {'shape': data_out.shape})
    add_op.create_node_with_data(inputs=[fresh_in, bias_data], data_nodes=data_out)
def _bn_to_mul_add_action(graph: nx.MultiDiGraph, match: dict):
    """
    Replace a matched BatchNorm op with an equivalent Mul->Add pair.

    The folded constants are computed as
        scale = 1 / sqrt(variance + epsilon)
        shift = -mean * scale
    and are stored back into the matched 'mean' and 'variance' data nodes, which
    are then reused as the Mul and Add constants.

    :param graph: graph containing the match; modified in place
    :param match: pattern match with keys 'input', 'output', 'mean', 'variance'
                  (data nodes) and 'batch_norm' (op node)
    """
    data_in, data_out = match['input'], match['output']
    mean_data, var_data = match['mean'], match['variance']
    bn_op = match['batch_norm']

    # Detach the BatchNorm op from every data node it touches.
    for src, dst in ((data_in, bn_op), (mean_data, bn_op), (var_data, bn_op), (bn_op, data_out)):
        graph.remove_edge(src.node, dst.node)

    # Both results are computed before either data node is overwritten.
    inv_std = 1. / np.sqrt(var_data.value + bn_op.epsilon)
    bias = (mean_data.value * (-1)) * inv_std
    mean_data.value = np.array(inv_std)
    var_data.value = np.array(bias)

    # Expand dims for current layout
    if graph.graph['layout'] == 'NCHW':
        extra_dims = len(data_in.shape) - 2
    else:
        extra_dims = 0
    for const_data in (mean_data, var_data):
        Op.expand_node_shape(const_data, extra_dims)

    fusable = bool(bn_op.soft_get('can_be_fused'))
    mul_op = Mul(graph, dict(name="Mul_", can_be_fused=fusable))
    add_op = Add(graph, dict(name="Add_", can_be_fused=fusable))

    # Connect input->mul->add->output
    scaled = mul_op.create_node_with_data(inputs=[data_in, mean_data])
    add_op.create_node_with_data(inputs=[scaled, var_data], data_nodes=data_out)
def _fuse_linear_sequence(graph: nx.MultiDiGraph, start_node: Node):
    """
    This function finds the sequence of Mul/Add operations and replaces this sequence with two ops (Mul->Add).

    The chain is followed from ``start_node`` as long as each intermediate data node
    has exactly one consumer that is a fusable Mul/Add with a constant operand.
    The whole chain is folded into the affine form ``x * mul + add``.

    :param graph: graph containing ``start_node``; modified in place when fusion happens
    :param start_node: The first operation of the sequence
    :return: True if the sequence was fused, False if there was nothing to fuse
             (the graph is left untouched in that case)
    """
    fnodes = [start_node]
    while True:
        data_node = fnodes[-1].out_node()
        # Stop when the intermediate result is consumed by more than one node:
        # fusing past it would drop a value other consumers need.
        if len(data_node.out_nodes()) != 1:
            break
        next_node = data_node.out_node()
        if (next_node.op in ['Mul', 'Add'] and get_value_id(next_node) is not None
                and next_node.soft_get('can_be_fused') == True):  # noqa: E712 (value may be a numpy bool)
            fnodes.append(next_node)
        else:
            break

    # Nothing to gain: a single op, or a sequence that already is Mul->Add.
    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul' and fnodes[1].op == 'Add'):
        return False

    in_node = fnodes[0].in_node(get_tensor_id(fnodes[0]))
    out_node = fnodes[-1].out_node()

    # Fusion is only valid when the chain preserves the tensor shape. This used to
    # be an `assert`, which is stripped under `python -O` and otherwise aborted the
    # whole conversion. Declining to fuse keeps the graph untouched instead
    # (no graph mutation has happened yet at this point).
    if not np.array_equal(in_node.shape, out_node.shape):
        log.debug('Skip fusing {} operations: sequence changes shape'.format(len(fnodes)))
        return False

    init_dims_cnt = len(in_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 1
    mul = np.ones([1] * init_dims_cnt)
    add = np.zeros([1] * init_dims_cnt)
    first_mul_name = None
    first_add_name = None

    # Compose the chain into one affine map x * mul + add:
    #   Mul by c: (x * mul + add) * c = x * (mul * c) + add * c
    #   Add c:    (x * mul + add) + c = x * mul + (add + c)
    for node in fnodes:
        value = node.in_node(get_value_id(node)).value
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            mul = mul * value
            add = add * value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + value

    # If mul is scalar we broadcast it to biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.full(add.shape[0], mul[0])

    mul_node = Mul(graph, dict(name=first_mul_name + '/Fused_Mul_' if first_mul_name is not None else ''))
    add_node = Add(graph, dict(name=first_add_name + '/Fused_Add_' if first_add_name is not None else ''))

    # Detach the chain from its surroundings
    graph.remove_edge(in_node.id, fnodes[0].id)
    graph.remove_edge(fnodes[-1].id, out_node.id)

    # Remove deleted subgraph
    for node in fnodes:
        for tmp_node in node.in_nodes().values():
            # Remove node only if it has one consumer (for case with shared weights)
            if len(tmp_node.out_nodes()) == 1:
                graph.remove_node(tmp_node.id)
        for tmp_node in node.out_nodes().values():
            graph.remove_node(tmp_node.id)
        graph.remove_node(node.id)

    """
    Four cases considered below:
        1. Mul and Add have valid values (mul value != 1 and add value != 0)
        2. Only Mul has valid values, so we add only Mul node
        3. Only Add has valid values, so we add only Add node
        4. When Mul and Add has not valid values we just merge two data nodes
    """
    has_mul = bool(np.any(mul != 1))
    has_add = bool(np.any(add != 0))
    if has_add and has_mul:
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[mul_node.create_node_with_data([in_node, data_mul]), data_add],
                                       data_nodes=out_node)
    elif has_mul:
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        mul_node.create_node_with_data(inputs=[in_node, data_mul], data_nodes=out_node)
    elif has_add:
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[in_node, data_add], data_nodes=out_node)
    else:
        merge_data_nodes(graph, out_node, in_node)
        graph.remove_node(in_node.id)

    log.debug('Fused %d operations', len(fnodes))
    return True
def convert_scale_shift_to_mul_add(graph: nx.MultiDiGraph):
    """
    Replace ScaleShift ops in ``graph`` with an equivalent Mul->Add pair.

    Ops whose 'can_be_fused' attribute is exactly False are skipped. Depending on
    the constants:
      * a shift that is constant and all zeros is dropped (Mul only);
      * with no shift and a constant all-ones scale, the input and output data
        nodes are merged and no op is inserted.

    :param graph: graph to transform in place
    """
    nodes = [Node(graph, node) for node in graph.nodes() if Node(graph, node).soft_get('op') == 'ScaleShift']
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        input_node = node.in_node(0)
        scale_node = node.in_node(1)
        output_node = node.out_node()

        # We don't need zero biases. A shift without a constant value cannot be
        # proven zero, so it is kept (previously `.value` was iterated without a
        # has_valid check, which failed for non-constant shifts).
        # np.all handles 0-d and multi-dimensional values, which element-wise
        # Python iteration did not.
        shift_node = node.in_node(2) if len(node.in_nodes()) >= 3 else None
        if (shift_node is not None and shift_node.has_valid('value')
                and np.all(np.asarray(shift_node.value) == 0)):
            shift_node = None
        has_biases = shift_node is not None

        has_weights = not (scale_node.has_valid("value") and np.all(np.asarray(scale_node.value) == 1))

        mul_node = Mul(graph, dict(name=node.name + "/Mul_"))
        add_node = Add(graph, dict(name=node.name + "/Add_"))

        # Disconnect ScaleShift node
        graph.remove_edge(input_node.id, node.id)
        graph.remove_edge(node.id, output_node.id)

        # Expand dims for current layout
        broadcast_dims_cnt = len(input_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
        if scale_node.has_valid("value"):
            Op.expand_node_shape(scale_node, broadcast_dims_cnt)
        else:
            # Non-constant scale: insert a Reshape so its shape lines up with the input,
            # placing the scale dims starting at `node.axis` and 1 everywhere else.
            reshape_dims = np.zeros(len(input_node.shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            for i in range(node.axis, node.axis + len(scale_node.shape)):
                reshape_dims[i] = scale_node.shape[i - node.axis]
            for i in range(node.axis + len(scale_node.shape), len(input_node.shape)):
                reshape_dims[i] = 1
            reshape = Reshape(graph, dict(name=scale_node.name + "/Broadcast_", dim=reshape_dims))
            scale_node = reshape.create_node_with_data(inputs=[scale_node])

        # NOTE(review): a non-constant shift gets no Reshape, unlike the scale above.
        if shift_node is not None:
            Op.expand_node_shape(shift_node, broadcast_dims_cnt)

        # Connect input->mul->out->add->out
        if has_biases:
            add_node.create_node_with_data(
                inputs=[mul_node.create_node_with_data(inputs=[input_node, scale_node]), shift_node],
                data_nodes=output_node)
        elif has_weights:
            mul_node.create_node_with_data(inputs=[input_node, scale_node], data_nodes=output_node)
        else:
            merge_data_nodes(graph, input_node, output_node)
            graph.remove_node(output_node.id)