def replace_sub_graph(self, graph: Graph, match: dict):
    """Collapse a matched Mean/SquaredDifference/FusedBatchNorm subgraph into MVN -> Mul -> Add.

    The FusedBatchNorm node is replaced by an MVN node (normalization) followed by a
    multiply with gamma and an add of beta, reusing the reduction-index inputs of the
    matched Mean and Variance nodes.
    """
    fbn = match['fbn']
    data = fbn.in_node(0)
    log.debug('Found potential MVN pattern after {} with name {}'.format(data.op, data.name))

    # The very same tensor must feed both the mean branch and the squared-difference
    # branch, otherwise this is not the MVN pattern we are looking for.
    if data.id != match['mean'].in_node(0).id or data.id != match['sqdiff'].in_node(0).id:
        return
    log.debug('Confirmed MVN pattern after {} with name {}'.format(data.op, data.name))

    mvn = MVN(graph, dict(
        name=fbn.name + '/MVN_',
        eps=fbn.eps,
        eps_mode='outside_sqrt',
        normalize_variance=1,
    ))
    # Preserve the stock shape-inference function and install this transform's own.
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
    add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

    gamma = fbn.in_node(1)
    beta = fbn.in_node(2)
    mean_axes = match['mean'].in_node(1)
    variance_axes = match['variance'].in_node(1)

    # Build MVN -> Mul(gamma) -> Add(beta) and swap it in for the FusedBatchNorm.
    normalized = mvn.create_node([data, mean_axes, variance_axes])
    scaled = mul.create_node([normalized, gamma])
    shifted = add.create_node([scaled, beta])
    fbn.replace_node(shifted)
def replace_sub_graph(self, graph: Graph, match: dict):
    """Rewrite a Softmax with temperature T != 1 as (input * 1/T) -> Softmax.

    A Const holding 1/T and a Mul node are inserted between the softmax's
    producer and the softmax itself; the original input edge is removed and
    the Mul output is wired into the softmax with the former edge attributes.
    """
    node = match['softmax']
    # Nothing to do when no temperature is set or it is the neutral value 1.0.
    if 'temperature' not in node or node['temperature'] == 1.0:
        return

    producer = node.in_node()
    consumers = list(node.out_nodes().values())
    graph.remove_edge(node.in_node().id, node.id)

    inv_temperature = np.array([1.0 / node.temperature])
    scale_const_op = Const(graph, dict(value=inv_temperature,
                                       shape=inv_temperature.shape,
                                       symbol_dict={'name': node.id + '/const'}))
    mul_op = Mul(graph, dict(name=node.id + '/mul_',
                             symbol_dict={'name': node.id + '/mul_'}))
    mul_node = mul_op.create_node(inputs=[producer, scale_const_op.create_node()])

    # Reuse the attributes of the softmax's first outgoing edge for the new
    # mul -> softmax edge.
    edge_attrs = graph.get_edge_data(node.id, consumers[0].id)[0]
    graph.add_edges_from([(mul_node.id, node.id, edge_attrs)])
def replace_sub_graph(self, graph: Graph, match: dict):
    """Collapse a matched Mean/SquaredDifference/FusedBatchNorm subgraph into MVN -> Mul -> Add.

    Variant that resolves the MVN op class by name and records the required
    reduction indices derived from the FusedBatchNorm data format.
    """
    fbn = match['fbn']
    data = fbn.in_node(0)
    log.debug('Found potential MVN pattern after {} with name {}'.format(data.op, data.name))

    # Both the mean branch and the squared-difference branch must read the same
    # tensor for this to be a genuine MVN pattern.
    if data.id != match['mean'].in_node(0).id or data.id != match['sqdiff'].in_node(0).id:
        return
    log.debug('Confirmed MVN pattern after {} with name {}'.format(data.op, data.name))

    MVN = Op.get_op_class_by_name('MVN')
    # Spatial axes differ by layout: NHWC reduces over H,W = [1, 2];
    # otherwise (NCHW) over [2, 3].
    mvn = MVN(graph, dict(
        name=fbn.name + '/MVN_',
        eps=fbn.eps,
        required_reduction_indices=[1, 2] if fbn.data_format == b'NHWC' else [2, 3],
    ))
    # Preserve the stock shape-inference function and install this transform's own.
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
    add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

    gamma = fbn.in_node(1)
    beta = fbn.in_node(2)
    mean_axes = match['mean'].in_node(1)
    variance_axes = match['variance'].in_node(1)

    # Build MVN -> Mul(gamma) -> Add(beta) and swap it in for the FusedBatchNorm.
    normalized = mvn.create_node([data, mean_axes, variance_axes])
    scaled = mul.create_node([normalized, gamma])
    shifted = add.create_node([scaled, beta])
    fbn.replace_node(shifted)
def convert_scale_shift_to_mul_add(graph: Graph):
    """Replace every fusable ScaleShift op in the graph with Mul and/or Add ops.

    For each ScaleShift node the scale (input 1) and shift (input 2) are examined:
    an all-ones scale and an all-zeros shift are dropped as no-ops, and the node is
    rewired as input->Mul->Add->out, input->Mul->out, input->Add->out, or a plain
    pass-through, depending on which operands remain.
    """
    nodes = graph.get_op_nodes(op='ScaleShift')
    for node in nodes:
        # Respect an explicit opt-out set elsewhere in the pipeline.
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)
        # Scale/shift ports are optional; treat a missing or disconnected port as absent.
        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
        output_port = node.out_port(0)

        has_biases = True
        has_weights = True

        # We don't need zero biases
        if shift_port is None or (shift_port.data.get_value() is not None
                                  and all([x == 0 for x in shift_port.data.get_value()])):
            has_biases = False

        # We don't need weights with ones
        if scale_port is None or (scale_port.data.get_value() is not None
                                  and all([x == 1 for x in scale_port.data.get_value()])):
            has_weights = False

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Expand dims for current layout
        # NCHW needs trailing singleton dims so a per-channel vector broadcasts over
        # the spatial axes; other layouts broadcast as-is.
        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0

        # In case if we have constant weights/biases we have to broadcast them according to graph layout
        # otherwise we insert Reshape with broadcast dim attribute.
        def broadcast_value(port):
            # Constant operand: append `broadcast_dims_cnt` singleton axes in place.
            value = np.array(port.data.get_value())
            for idx in range(broadcast_dims_cnt):
                value = np.expand_dims(value, axis=-1)
            port.data.set_value(value)

        def broadcast_with_reshape(port):
            # Non-constant operand: insert a Reshape producing a shape of 1s with the
            # operand's dims placed starting at `node.axis` so it broadcasts correctly.
            input_shape = input_port.data.get_shape()
            reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            data_shape = port.data.get_shape()
            for i in range(node.axis, node.axis + len(data_shape)):
                reshape_dims[i] = data_shape[i - node.axis]
            for i in range(node.axis + len(data_shape), len(input_shape)):
                reshape_dims[i] = 1
            reshape = create_op_node_with_second_input(graph, Reshape, reshape_dims,
                                                       dict(name=port.node.name + "/Broadcast_"))
            port.get_connection().set_destination(reshape.in_port(0))
            reshape.out_port(0).connect(port)

        if has_weights and scale_port.data.get_value() is not None:
            broadcast_value(scale_port)
        elif has_weights:
            broadcast_with_reshape(scale_port)

        if has_biases and shift_port.data.get_value() is not None:
            broadcast_value(shift_port)
        elif has_biases:
            broadcast_with_reshape(shift_port)

        if has_biases and has_weights:
            # Connect input->mul->out->add->out
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            # Connect Add operation with inputs
            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # Connect input->mul->out
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            output_port.get_connection().set_source(mul_node.out_port(0))
        elif has_biases:
            # Connect input->add->out
            add_node = add_op.create_node()

            # Connect Add operation with inputs
            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        else:
            # Connect input->out
            # Both operands were no-ops: bypass the ScaleShift entirely.
            producer_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(producer_port)
def _fuse_linear_sequence(graph: Graph, start_node: Node):
    """
    This function finds the sequence of Mul/Add operations and replaces this sequence with two ops (Mul->Add).
    :param graph: graph to operate on
    :param start_node: The first operation of the sequence
    :return: True when a fusion was performed, False when the sequence was too short
      or already in canonical Mul->Add form.
    """
    # Walk forward collecting a linear chain of fusable Mul/Add nodes: each node must
    # have a single consumer, a constant operand, and not be marked unfusable.
    fnodes = [start_node]
    while True:
        node = fnodes[-1]
        destinations = node.out_port(0).get_destinations()
        if len(destinations) != 1:
            break
        dst_node = destinations[0].node
        if dst_node.soft_get('op') in ['Mul', 'Add'] and get_value_in_port(dst_node) is not None \
                and dst_node.soft_get('can_be_fused') is True:
            fnodes.append(dst_node)
        else:
            break

    # A single node, or an already-canonical Mul->Add pair, needs no fusion.
    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul' and fnodes[1].op == 'Add'):
        return False

    input_shape = get_tensor_in_port(start_node).data.get_shape()

    # Accumulator rank depends on layout (NCHW keeps trailing spatial dims).
    init_dims_cnt = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1

    # Fold the whole chain into one effective multiplier and one effective addend:
    # y = x*m + a; a subsequent Mul scales both, a subsequent Add shifts only `add`.
    mul = np.ones([1 for x in range(init_dims_cnt)])
    add = np.zeros([1 for x in range(init_dims_cnt)])

    first_mul_name = None
    first_add_name = None

    for node in fnodes:
        const_port_value = get_value_in_port(node).data.get_value()
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            mul = mul * const_port_value
            add = add * const_port_value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + const_port_value

    # If mul is scalar we broadcast it to biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.array([mul[0] for x in range(add.shape[0])])

    # The chain must not change the tensor shape end to end.
    assert (np.array_equal(get_tensor_in_port(fnodes[0]).data.get_shape(),
                           fnodes[-1].out_port(0).data.get_shape()))

    mul_op = Mul(graph, dict(name='{}/Fused_Mul_'.format(first_mul_name or '')))
    add_op = Add(graph, dict(name='{}/Fused_Add_'.format(first_add_name or '')))

    in_port = get_tensor_in_port(fnodes[0])
    out_port = fnodes[-1].out_port(0)
    """
    Four cases considered below:
        1. Mul and Add have valid values (mul value != 1 and add value != 0)
        2. Only Mul has valid values, so we add only Mul node
        3. Only Add has valid values, so we add only Add node
        4. When Mul and Add has not valid values we just merge two data nodes
    """
    if any([x != 0 for x in np.nditer(add)]) and any([x != 1 for x in np.nditer(mul)]):
        #  Const\    Const\
        #  ----->Mul------>Add-->
        mul_const = Const(graph, dict(name="data_mul_", value=np.array(mul))).create_node()
        add_const = Const(graph, dict(name="data_add_", value=np.array(add))).create_node()

        mul_node = mul_op.create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))

        mul_node.out_port(0).connect(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    elif any([x != 1 for x in np.nditer(mul)]):
        #  Const\
        #  ----->Mul-->
        mul_const = Const(graph, dict(name="data_mul_", value=np.array(mul))).create_node()
        mul_node = mul_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))
        out_port.get_connection().set_source(mul_node.out_port(0))
    elif any([x != 0 for x in np.nditer(add)]):
        #  Const\
        #  ----->Add-->
        add_const = Const(graph, dict(name="data_add_", value=np.array(add))).create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    else:
        # Identity chain: bypass it entirely.
        source_node = in_port.get_source()
        in_port.disconnect()
        out_port.get_connection().set_source(source_node)

    # Remove fused nodes
    for node in fnodes:
        graph.remove_node(node.id)

    log.debug('Fused {} operations'.format(len(fnodes)))
    return True