def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Fold a non-unit LRN ``bias`` into ``alpha`` plus a trailing scalar Mul.

    Assuming the LRN divides its input by ``(bias + alpha * S) ** beta`` for a
    local sum ``S`` (the rewrite below relies on that form), factoring out
    ``bias`` gives ``bias ** -beta * x / (1 + (alpha / bias) * S) ** beta``.
    So the node keeps ``alpha / bias`` with a unit bias, and a new Mul by the
    constant ``1 / bias ** beta`` is inserted after it. Every consumer of the
    LRN is re-pointed at that Mul.

    Args:
        graph: graph transformed in place.
        match: pattern match; ``match['op']`` is the LRN node.
    """
    node = match['op']

    # Nothing to fold when bias is absent or already the identity value.
    if not node.has_valid('bias') or node.bias == 1:
        return

    bias = node.bias
    scale_value = np.array(1. / pow(bias, node.beta))
    node.alpha /= bias
    # The bias now lives in alpha and in the Mul below. Reset it so the node
    # stays self-consistent and a repeated run of this pass is a no-op.
    node.bias = 1.0

    const_node = Const(graph, dict(value=scale_value, shape=scale_value.shape))

    # Snapshot consumers before the Mul is attached. The Mul itself becomes a
    # consumer of `node`, and its input edge must not be moved.
    consumer_ids = list(dict.fromkeys(out.id for out in node.out_nodes().values()))

    mul_op = Mul(graph, dict(name=node.id + "/Mul_"))
    mnode = mul_op.create_node(inputs=[node, const_node.create_node()])

    # Move every LRN -> consumer edge onto the Mul. A consumer may be linked by
    # several parallel edges (e.g. it reads the tensor on two inputs), so each
    # edge is removed by its own key and re-added with its own attributes.
    for consumer_id in consumer_ids:
        for key, edge_attrs in list(graph.get_edge_data(node.id, consumer_id).items()):
            graph.remove_edge(node.id, consumer_id, key)
            graph.add_edges_from([(mnode.id, consumer_id, edge_attrs)])
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """Swap ``node`` for a Mul that multiplies its first two inputs.

    The Mul is named ``<node.id>/mul_``, and the same name is recorded in its
    ``symbol_dict``. ``node`` is handed over to it via ``replace_node``.

    Returns:
        A one-element list with the id of the new Mul node.
    """
    layer_name = node.id + '/mul_'
    mul_attrs = dict(name=layer_name, symbol_dict={'name': layer_name})
    product_op = Mul(graph, mul_attrs)
    product = product_op.create_node(inputs=[node.in_node(0), node.in_node(1)])
    replace_node(node, product)
    return [product.id]
def replace_op(self, graph: Graph, node: Node):
    """Lower InstanceNormalization to ``Add(Mul(MVN(in0), in1), in2)``.

    MVN is built with ``eps=node.epsilon``. The Mul and the Add are both built
    with ``axis=1``. Inputs 1 and 2 are presumably the scale and the bias
    (confirm against the op definition).

    Returns:
        A one-element list holding the id of the final Add node.
    """
    scope = node.name + '/InstanceNormalization'
    mvn_op = MVN(graph, dict(name=scope + '/MVN', eps=node.epsilon))
    scale_op = Mul(graph, dict(name=scope + '/Mul', axis=1))
    shift_op = Add(graph, dict(name=scope + '/Add', axis=1))

    normalized = mvn_op.create_node([node.in_node(0)])
    scaled = scale_op.create_node([normalized, node.in_node(1)])
    shifted = shift_op.create_node([scaled, node.in_node(2)])
    return [shifted.id]
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with ``Mul(input, Const(node.scalar))``.

    The producer -> node edge is removed, and the producer feeds the new Mul
    instead. Every edge from ``node`` to its consumers is then moved onto the
    Mul.

    Returns:
        A one-element list with the id of the new Mul node.
    """
    in_node = node.in_node()
    # Snapshot consumers (deduplicated, order kept) before rewiring anything.
    consumer_ids = list(dict.fromkeys(out.id for out in node.out_nodes().values()))
    graph.remove_edge(in_node.id, node.id)

    scalar_value_op = Const(graph, dict(value=node.scalar, shape=node.scalar.shape,
                                        symbol_dict={'name': node.id + '/const'}))
    mul_op = Mul(graph, dict(name=node.id + '/mul_', symbol_dict={'name': node.id + '/mul_'}))
    mul_node = mul_op.create_node(inputs=[in_node, scalar_value_op.create_node()])

    # Move each node -> consumer edge by its own key. Parallel edges (a
    # consumer reading this tensor on several inputs) are all carried over
    # with their own attributes.
    for consumer_id in consumer_ids:
        for key, edge_attrs in list(graph.get_edge_data(node.id, consumer_id).items()):
            graph.remove_edge(node.id, consumer_id, key)
            graph.add_edges_from([(mul_node.id, consumer_id, edge_attrs)])
    return [mul_node.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """Fold a non-unit softmax ``temperature`` into a Mul on the softmax input.

    A Mul by the constant ``1 / temperature`` is inserted between the softmax
    and its producer. Presumably the op computes ``softmax(x / temperature)``,
    so the softmax itself is left with a unit temperature.

    Args:
        graph: graph transformed in place.
        match: pattern match; ``match['softmax']`` is the softmax node.
    """
    node = match['softmax']
    if 'temperature' not in node or node['temperature'] == 1.0:
        return

    in_node = node.in_node()
    # Detach producer -> softmax, keeping that edge's attributes. The new
    # Mul -> softmax edge must land on the same softmax input port. The
    # consumer edge's attributes (used before) describe the consumer's port.
    in_key, in_edge_attrs = list(graph.get_edge_data(in_node.id, node.id).items())[-1]
    graph.remove_edge(in_node.id, node.id, in_key)

    temperature = np.array([1.0 / node.temperature])
    scalar_value_op = Const(
        graph,
        dict(value=temperature, shape=temperature.shape, symbol_dict={'name': node.id + '/const'}))
    mul_op = Mul(
        graph,
        dict(name=node.id + '/mul_', symbol_dict={'name': node.id + '/mul_'}))
    mul_node = mul_op.create_node(
        inputs=[in_node, scalar_value_op.create_node()])

    mul_edge_attrs = dict(in_edge_attrs)
    mul_edge_attrs['out'] = 0  # the Mul has a single output port
    graph.add_edges_from([(mul_node.id, node.id, mul_edge_attrs)])

    # The scaling now happens in the Mul; keep the node consistent so the
    # factor is never applied twice.
    node.temperature = 1.0
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Replace an ImageScaler op with an equivalent Mul -> Add sequence.

    ``op.scale`` becomes a Mul and ``op.bias`` becomes an Add. A scale of all
    ones, or a bias of all zeros, is an identity and gets no node. If both are
    identities, the op is simply bypassed. The consumers of ``op`` are then
    re-pointed at the last node of the chain, and ``op`` is disconnected from
    its input.

    Args:
        graph: graph transformed in place.
        match: pattern match; ``match['op']`` is the ImageScaler node.
    """
    op = match['op']

    # Skip weights/biases that would not change the data.
    has_weights = not np.all(np.asarray(op.scale) == 1)
    has_bias = not np.all(np.asarray(op.bias) == 0)

    # Snapshot consumers (deduplicated, order kept) before adding new nodes.
    consumer_ids = list(dict.fromkeys(out.id for out in op.out_nodes().values()))

    assert len(op.in_nodes()) == 1
    in_node = op.in_node()
    last_node = in_node

    if has_weights:
        mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape))
        mul_op = Mul(graph, dict(name=op.id + '/mul_'))
        last_node = mul_op.create_node(inputs=[last_node, mul_weights.create_node()])
    if has_bias:
        add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape))
        add_op = Add(graph, dict(name=op.id + '/add_'))
        last_node = add_op.create_node(inputs=[last_node, add_bias.create_node()])

    # Move every op -> consumer edge (including parallel ones) onto last_node,
    # removing each by its own key so the attributes stay paired correctly.
    for consumer_id in consumer_ids:
        for key, edge_attrs in list(graph.get_edge_data(op.id, consumer_id).items()):
            graph.remove_edge(op.id, consumer_id, key)
            graph.add_edges_from([(last_node.id, consumer_id, edge_attrs)])

    # Disconnect the ImageScaler node from its input.
    graph.remove_edge(in_node.id, op.id)
def _is_missing_or_filled_with(port, fill):
    """Return True if ``port`` is None or holds a constant equal to ``fill`` everywhere.

    A port whose value is not constant (``get_value()`` is None) yields False.
    ``np.all`` is used instead of iterating the value, so 0-d and
    multi-dimensional constants work too.
    """
    if port is None:
        return True
    value = port.data.get_value()
    return value is not None and bool(np.all(np.asarray(value) == fill))


def _broadcast_const_value(port, dims_cnt):
    """Append ``dims_cnt`` trailing unit dimensions to the constant on ``port``."""
    value = np.array(port.data.get_value())
    for _ in range(dims_cnt):
        value = np.expand_dims(value, axis=-1)
    port.data.set_value(value)


def _broadcast_with_reshape(graph, node, input_port, port):
    """Insert a Reshape in front of ``port`` so its data lines up with ``input_port``.

    The target shape has the input's rank. The operand's own dims start at
    ``node.axis``, and every other position is 1.
    """
    input_shape = input_port.data.get_shape()
    reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
    for i in range(0, node.axis):
        reshape_dims[i] = 1
    data_shape = port.data.get_shape()
    for i in range(node.axis, node.axis + len(data_shape)):
        reshape_dims[i] = data_shape[i - node.axis]
    for i in range(node.axis + len(data_shape), len(input_shape)):
        reshape_dims[i] = 1
    reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
    port.get_connection().set_destination(reshape.in_port(0))
    reshape.out_port(0).connect(port)


def convert_scale_shift_to_mul_add(graph: Graph):
    """Decompose every fusable ScaleShift op into Mul and/or Add ops.

    Each ``ScaleShift`` node is rewired as input -> Mul(port 1) -> Add(port 2)
    -> output. Nodes whose ``can_be_fused`` is explicitly False are skipped.
    A missing or all-ones constant scale drops the Mul. A missing or all-zeros
    constant shift drops the Add. When both are dropped, the node is bypassed.
    Constant operands get trailing unit dims for NCHW layout (input rank - 2 of
    them). Non-constant operands get a Reshape inserted in front of them.

    Args:
        graph: graph transformed in place.
    """
    nodes = graph.get_op_nodes(op='ScaleShift')
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)
        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
        output_port = node.out_port(0)

        # Zero biases and unit weights are identities: no Add / Mul needed.
        has_biases = not _is_missing_or_filled_with(shift_port, 0)
        has_weights = not _is_missing_or_filled_with(scale_port, 1)

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Expand dims for current layout
        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0

        # Constant weights/biases are broadcast in place according to the
        # graph layout; non-constant ones get a Reshape inserted instead.
        if has_weights:
            if scale_port.data.get_value() is not None:
                _broadcast_const_value(scale_port, broadcast_dims_cnt)
            else:
                _broadcast_with_reshape(graph, node, input_port, scale_port)
        if has_biases:
            if shift_port.data.get_value() is not None:
                _broadcast_const_value(shift_port, broadcast_dims_cnt)
            else:
                _broadcast_with_reshape(graph, node, input_port, shift_port)

        if has_biases and has_weights:
            # input -> Mul -> Add -> output
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))
            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))
            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # input -> Mul -> output
            mul_node = mul_op.create_node()
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))
            output_port.get_connection().set_source(mul_node.out_port(0))
        elif has_biases:
            # input -> Add -> output
            add_node = add_op.create_node()
            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))
            output_port.get_connection().set_source(add_node.out_port(0))
        else:
            # Identity: connect the producer straight to the consumers.
            producer_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(producer_port)