def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """Rewrite ``node`` as a ScaleShift followed by an Add.

    Inputs 1 and 0 of ``node`` (in that order) feed a new ScaleShiftOp with
    ``axis=0``; its result and input 2 of ``node`` feed a new Add.

    Returns:
        A one-element list holding the id of the created Add node.
    """
    src0, src1, src2 = node.in_node(0), node.in_node(1), node.in_node(2)

    scale_shift_node = ScaleShiftOp(
        graph, {'name': node.id + "/ScaleShift_", 'axis': 0}
    ).create_node(inputs=[src1, src0])

    sum_node = Add(graph, {'name': node.id + "/Add_"}).create_node(
        inputs=[scale_shift_node, src2]
    )
    return [sum_node.id]
def replace_op(self, graph: Graph, node: Node):
    """Expand an instance-normalization node into an MVN -> Mul -> Add chain.

    Input 0 is normalized by an MVN using ``node.epsilon``. The MVN output is
    multiplied by input 1 and then added to input 2; both the Mul and the Add
    use ``axis=1``. All new nodes are named under
    ``<node.name>/InstanceNormalization``.

    Returns:
        A one-element list holding the id of the final Add node.
    """
    base = node.name + '/InstanceNormalization'

    mvn_op = MVN(graph, dict(name=base + '/MVN', eps=node.epsilon))
    mul_op = Mul(graph, dict(name=base + '/Mul', axis=1))
    add_op = Add(graph, dict(name=base + '/Add', axis=1))

    normalized = mvn_op.create_node([node.in_node(0)])
    scaled = mul_op.create_node([normalized, node.in_node(1)])
    shifted = add_op.create_node([scaled, node.in_node(2)])
    return [shifted.id]
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """Replace ``node`` by an Add of its input and the constant ``node.scalar``.

    The edge from the input into ``node`` is removed. A Const holding
    ``node.scalar`` and an Add consuming (input, Const) are created. Every
    outgoing edge of ``node`` is then re-attached to the Add, keeping its
    attributes.

    Args:
        graph: graph transformed in place.
        node: node to replace; its ``scalar`` attribute must have ``.shape``.

    Returns:
        A one-element list holding the id of the new Add node.
    """
    in_node = node.in_node()
    # Snapshot consumers before any rewiring happens.
    out_nodes = list(node.out_nodes().values())

    graph.remove_edge(in_node.id, node.id)

    scalar_value_op = Const(graph, dict(value=node.scalar, shape=node.scalar.shape,
                                        symbol_dict={'name': node.id + '/const'}))
    add_op = Add(graph, dict(name=node.id + '/add_', symbol_dict={'name': node.id + '/add_'}))
    add_node = add_op.create_node(inputs=[in_node, scalar_value_op.create_node()])

    for out_node in out_nodes:
        # The graph is a MultiDiGraph: move every parallel edge by its own key
        # instead of assuming a single edge stored under key 0. A consumer
        # already handled (e.g. listed twice) has no edges left and is skipped.
        parallel_edges = graph.get_edge_data(node.id, out_node.id) or {}
        for key, edge_attrs in list(parallel_edges.items()):
            graph.remove_edge(node.id, out_node.id, key=key)
            graph.add_edges_from([(add_node.id, out_node.id, edge_attrs)])

    return [add_node.id]
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Replace an ImageScaler op with an equivalent Mul -> Add sequence.

    A Mul by ``op.scale`` is inserted only if some scale value differs from 1.
    An Add of ``op.bias`` is inserted only if some bias value differs from 0.
    Consumers of the ImageScaler are re-attached to the last node of the new
    chain, which is the ImageScaler's input node when neither Mul nor Add is
    needed. Finally the ImageScaler is disconnected from its input.

    Args:
        graph: graph transformed in place.
        match: pattern match; ``match['op']`` is the ImageScaler node.
    """
    op = match['op']

    # Skip arithmetic that would be a no-op: an all-ones scale or all-zeros bias.
    has_weights = not np.all(np.asarray(op.scale) == 1)
    has_bias = not np.all(np.asarray(op.bias) == 0)

    # Snapshot consumers before any edges are rewired.
    out_nodes = list(op.out_nodes().values())

    assert len(op.in_nodes()) == 1
    last_node = op.in_node()

    # Build the Mul & Add chain starting from the ImageScaler's input.
    if has_weights:
        mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape))
        mul_op = Mul(graph, dict(name=op.id + '/mul_'))
        last_node = mul_op.create_node(inputs=[last_node, mul_weights.create_node()])

    if has_bias:
        add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape))
        add_op = Add(graph, dict(name=op.id + '/add_'))
        last_node = add_op.create_node(inputs=[last_node, add_bias.create_node()])

    # Move edges from the ImageScaler to last_node (Mul, Add, or the input).
    # Every parallel edge of the MultiDiGraph is moved by its own key.
    # NOTE(review): edge attributes are copied verbatim. In the no-op case the
    # input node inherits the ImageScaler's outgoing-edge attributes; confirm
    # they are valid for it.
    for out_node in out_nodes:
        parallel_edges = graph.get_edge_data(op.id, out_node.id) or {}
        for key, edge_attrs in list(parallel_edges.items()):
            graph.remove_edge(op.id, out_node.id, key=key)
            graph.add_edges_from([(last_node.id, out_node.id, edge_attrs)])

    # Disconnect the ImageScaler node from its input.
    graph.remove_edge(op.in_node().id, op.id)
def convert_scale_shift_to_mul_add(graph: Graph):
    """Decompose every fusable ScaleShift node in ``graph`` into Mul/Add nodes.

    Nodes whose ``can_be_fused`` attribute is ``False`` are skipped. For every
    other ScaleShift node:

    * a Mul is created only when the scale input (port 1) is connected and is
      not a constant made entirely of ones;
    * an Add is created only when the shift input (port 2) is connected and is
      not a constant made entirely of zeros;
    * a constant scale/shift value gets trailing unit axes appended (input
      rank - 2 of them when ``graph.graph['layout'] == 'NCHW'``, none
      otherwise); a non-constant one is routed through a Reshape built from
      ``node.axis``;
    * when neither a Mul nor an Add is needed, the ScaleShift's producer is
      wired straight to its consumers.

    The graph is modified in place.
    """
    nodes = graph.get_op_nodes(op='ScaleShift')
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)
        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
        output_port = node.out_port(0)

        has_biases = True
        has_weights = True

        # We don't need zero biases. np.all works for values of any rank;
        # all([x == 0 for x in value]) raises for 0-d arrays and can raise
        # (ambiguous truth value) for rank > 1.
        if shift_port is None or (shift_port.data.get_value() is not None and
                                  np.all(np.asarray(shift_port.data.get_value()) == 0)):
            has_biases = False

        # We don't need weights that are all ones.
        if scale_port is None or (scale_port.data.get_value() is not None and
                                  np.all(np.asarray(scale_port.data.get_value()) == 1)):
            has_weights = False

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Number of trailing axes to append for the current layout.
        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0

        # Constant weights/biases are broadcast in place according to the graph
        # layout; otherwise a Reshape with the broadcast dims is inserted.
        def broadcast_value(port):
            value = np.array(port.data.get_value())
            for _ in range(broadcast_dims_cnt):
                value = np.expand_dims(value, axis=-1)
            port.data.set_value(value)

        def broadcast_with_reshape(port):
            input_shape = input_port.data.get_shape()
            reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            data_shape = port.data.get_shape()
            for i in range(node.axis, node.axis + len(data_shape)):
                reshape_dims[i] = data_shape[i - node.axis]
            for i in range(node.axis + len(data_shape), len(input_shape)):
                reshape_dims[i] = 1
            reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
            port.get_connection().set_destination(reshape.in_port(0))
            reshape.out_port(0).connect(port)

        if has_weights and scale_port.data.get_value() is not None:
            broadcast_value(scale_port)
        elif has_weights:
            broadcast_with_reshape(scale_port)

        if has_biases and shift_port.data.get_value() is not None:
            broadcast_value(shift_port)
        elif has_biases:
            broadcast_with_reshape(shift_port)

        if has_biases and has_weights:
            # Connect input->mul->out->add->out
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            # Connect Add operation with inputs
            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # Connect input->mul->out
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            output_port.get_connection().set_source(mul_node.out_port(0))
        elif has_biases:
            # Connect input->add->out
            add_node = add_op.create_node()

            # Connect Add operation with inputs
            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        else:
            # Connect input->out: bypass the ScaleShift entirely.
            producer_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(producer_port)