def extract(cls, node):
    """Read SpliceComponent parameters from ``node.parameters`` into Splice attributes.

    The stream must start with one of two context encodings:

    * ``<LeftContext>`` L ``<RightContext>`` R -- expanded to the list of
      consecutive integers ``-L, ..., R`` (inclusive);
    * ``<Context>`` followed by a binary int32 vector -- used as-is.

    If the next tag is ``<ConstComponentDim>``, its integer value is stored
    as ``const_dim``. The collected attributes are applied with
    ``Splice.update_node_stat``.

    Args:
        cls: Extractor class; its ``enabled`` attribute is returned.
        node: Node whose ``parameters`` stream is consumed.

    Returns:
        ``cls.enabled``.

    Raises:
        Error: if the first tag is neither ``<LeftContext>`` nor
            ``<Context>``, or ``<LeftContext>`` is not followed by
            ``<RightContext>``.
    """
    stream = node.parameters
    unknown_token_msg = 'Unknown token {} in SpliceComponent node {}'

    first_tag = find_next_tag(stream)
    if first_tag == '<Context>':
        # Explicit list of offsets stored as a binary int32 vector.
        collect_until_whitespace(stream)
        offsets = read_binary_vector(stream, False, dtype=np.int32)
    elif first_tag == '<LeftContext>':
        read_placeholder(stream, 1)
        left = read_binary_integer32_token(stream)
        second_tag = find_next_tag(stream)
        if second_tag != '<RightContext>':
            raise Error(unknown_token_msg.format(second_tag, node.id))
        read_placeholder(stream, 1)
        right = read_binary_integer32_token(stream)
        # Contiguous window: every integer from -left up to +right inclusive.
        offsets = list(range(-left, right + 1))
    else:
        raise Error(unknown_token_msg.format(first_tag, node.id))

    attrs = {'context': offsets}

    # Optional trailing tag; only consumed into attrs when it matches.
    if find_next_tag(stream) == '<ConstComponentDim>':
        read_placeholder(stream, 1)
        attrs['const_dim'] = read_binary_integer32_token(stream)

    Splice.update_node_stat(node, attrs)
    return cls.enabled
def replace_pattern(graph: Graph, match: dict):
    """Replace a matched node and its pair with a Splice followed by Crops.

    The matched node (``match['op']``; presumably a MemoryOffset, given its
    ``pair_name`` and ``t`` attributes -- confirm against the pattern) and
    its pair are removed. A Splice with context ``t..0`` (when ``t < 0``) or
    ``0..t`` (otherwise) is fed from the pair's real input. A Crop then
    extracts the slot for offset ``t`` for the pair's downstream consumers.
    Any other consumers of the original input, except MemoryOffset and
    Splice nodes, are re-routed through their own Crop, which extracts the
    slot for offset 0.

    Does nothing if the pair node has ``has_default`` set.

    Args:
        graph: Graph being transformed (mutated in place).
        match: Pattern match; ``match['op']`` is the node to replace.
    """
    offset_node = match['op']
    partner = Node(graph, offset_node.pair_name)

    if partner.has_default:
        return

    # If the matched node has a connected input, it is the "head": its
    # input feeds the Splice, and the partner's consumers get the delayed
    # Crop. Otherwise the roles are swapped.
    source_port = offset_node.in_port(0).get_source()
    if source_port is not None:
        head, tail = offset_node, partner
    else:
        head, tail = partner, offset_node
        source_port = partner.in_port(0).get_source()
    stale_output_id = head.out_port(0).get_destination().node.id
    consumer_ports = tail.out_port(0).get_destinations()

    in_shape = source_port.data.get_shape().copy()
    width = in_shape[1]
    t = offset_node.t

    if t < 0:
        context = range(t, 1)
    else:
        context = range(0, t + 1)

    splice = Splice(graph, {'name': offset_node.name,
                            'id': offset_node.id,
                            'context': int64_array(context)}).create_node()
    splice.in_port(0).connect(source_port)

    def new_crop(start):
        # Fresh Crop node taking `width` elements along axis 1 from `start`.
        return Crop(graph, {'name': 'Splice_Crop',
                            'axis': int64_array([1]),
                            'offset': int64_array([start]),
                            'dim': int64_array([width])}).create_node()

    shift = width * t
    # Slot for offset t: first slot (offset 0) when t < 0, last slot
    # (width * t) when t > 0 -- assuming Splice concatenates width-sized
    # slots in context order.
    delayed_crop = new_crop(max(0, shift))
    splice.out_port(0).connect(delayed_crop.in_port(0))
    splice.out_port(0).data.set_shape(int64_array([in_shape[0], (abs(t) + 1) * width]))

    # Slot for offset 0 (the mirror of the one above). This is queried
    # after the Splice is attached, so the Splice's own input port appears
    # here and is skipped.
    current_start = -min(0, shift)
    for port in source_port.get_destinations():
        if port.node['op'] in ('MemoryOffset', 'Splice'):
            continue
        passthrough = new_crop(current_start)
        splice.out_port(0).connect(passthrough.in_port(0))
        port.disconnect()
        passthrough.out_port(0).connect(port)
        passthrough.out_port(0).data.set_shape(in_shape)

    for port in consumer_ports:
        port.connect(delayed_crop.out_port(0))

    graph.remove_node(stale_output_id)
    graph.remove_node(offset_node.id)
    graph.remove_node(partner.id)