def extract(cls, node):
    """
    Fill ``node`` with split attributes read from its ONNX attributes.

    When the ``split`` attribute lists explicit chunk sizes, the node is set up as an
    AttributedVariadicSplit with those ``size_splits``; otherwise it is set up as an
    AttributedSplit whose ``num_splits`` equals the number of node outputs.
    Both variants use ``axis`` (0 when the attribute is absent).

    :param node: node to populate with split attributes
    :return: ``cls.enabled``
    """
    axis = onnx_attr(node, 'axis', 'i', default=0, dst_type=np.int64)
    sizes = onnx_attr(node, 'split', 'ints', default=None, dst_type=int64_array)

    if sizes is not None:
        # explicit per-output lengths -> variadic split
        AttributedVariadicSplit.update_node_stat(node, {'axis': axis, 'size_splits': sizes})
    else:
        # no lengths given -> one chunk per node output
        AttributedSplit.update_node_stat(node, {'axis': axis, 'num_splits': onnx_get_num_outputs(node)})
    return cls.enabled
def test_splitv_zero(self):
    """After infer, a trailing zero entry is removed from split_lengths and 3 output edges remain."""
    graph = build_graph(self.nodes, self.edges,
                        {
                            'split_input_data': {'shape': int64_array([2, 12, 25, 30])},
                            'split_op': {'axis': np.array(2), 'split_lengths': np.array([2, 13, 10, 0]),
                                         'out_ports_count': 4},
                        }
                        )
    node = Node(graph, 'split_op')
    # create the output ports that have no edges yet, up to out_ports_count
    for p in range(len(node.out_edges()), node.out_ports_count):
        node.add_output_port(p)

    AttributedVariadicSplit.infer(node)

    self.assertEqual(len(node.out_edges()), 3)
    # array_equal also compares shapes; np.all(a == b) would pass on a broadcastable shape mismatch
    self.assertTrue(np.array_equal(node.split_lengths, np.array([2, 13, 10])))
def test_splitv_zero_not_last(self):
    """
    A zero entry in the middle of split_lengths is removed by infer.

    The consumer of the zero-length output (port 2) is moved to port 3 beforehand; after infer
    port 3 must be disconnected and split_lengths must no longer contain the zero.
    """
    graph = build_graph(self.nodes, self.edges,
                        {
                            'split_input_data': {'shape': int64_array([2, 12, 25, 30])},
                            'split_op': {'axis': np.array(2), 'split_lengths': np.array([2, 13, 0, 10]),
                                         'out_ports_count': 4},
                        }
                        )
    node = Node(graph, 'split_op')

    # extractor should do it: create the output ports that have no edges yet
    for p in range(len(node.out_edges()), node.out_ports_count):
        node.add_output_port(p)
    # the data flowing out of port 2 really belongs to port 3 (port 2 has zero length)
    node.out_port(2).get_connection().set_source(node.out_port(3))

    AttributedVariadicSplit.infer(node)

    self.assertTrue(node.out_port(3).disconnected())
    # array_equal also compares shapes; np.all(a == b) would pass on a broadcastable shape mismatch
    self.assertTrue(np.array_equal(node.split_lengths, np.array([2, 13, 10])))
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Load ParallelComponent of the Kaldi model.

    ParallelComponent contains parallel nested networks. Each nested network is read from
    ``file_descr`` in turn and merged into ``graph``. A VariadicSplit fed by ``prev_layer_id``
    cuts the input along axis 1 into one chunk per nested network, and the outputs of the
    nested networks are joined by a Concat layer.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology; nested networks, the split and the concat are added to it
    :param prev_layer_id: id of the input layer for the parallel component layer
    :return: id of the concat layer - last layer of the parallel component layers
    """
    nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))

    split_points = []
    heads = []  # first node (topological order) of every nested network
    tails = []  # last node (topological order) of every nested network

    for net_idx in range(nnet_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        nested = Graph()
        load_kalid_nnet1_model(nested, file_descr, 'Nested_net_{}'.format(net_idx))

        # input to nnet1 models is of a rank 1 but we also insert batch_size to 0th axis;
        # 1st axis contains input_size of the nested subnetwork, which becomes its split size.
        # The nested Parameter is removed; the network is fed from the VariadicSplit below.
        parameter = Node(nested, 'Parameter')
        split_points.append(parameter['shape'][1])
        nested.remove_node(parameter.id)

        # give fresh ids to nested nodes whose ids are already taken in the main graph
        renamed = {}
        for node_id in nested.nodes(data=False):
            if node_id in graph:
                renamed[node_id] = graph.unique_id(node_id)
        nested = nx.relabel_nodes(nested, renamed)
        for new_id in renamed.values():
            # NOTE(review): networkx 2.4 removed Graph.node in favour of Graph.nodes; this relies on
            # an older networkx or on the project Graph class providing `node` -- confirm.
            nested.node[new_id]['name'] = new_id

        graph.add_nodes_from(nested.nodes(data=True))
        graph.add_edges_from(nested.edges(data=True))

        topo_order = tuple(nx.topological_sort(nested))
        tails.append(Node(graph, topo_order[-1]))
        heads.append(Node(graph, topo_order[0]))

    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
    split_node = AttributedVariadicSplit(graph, {
        'out_ports_count': nnet_count,
        'size_splits': split_points,
        'axis': 1,
        'name': split_id,
    }).create_node()

    prev_node = Node(graph, prev_layer_id)
    prev_node.add_output_port(0)
    graph.create_edge(prev_node, split_node, 0, 0,
                      create_edge_attrs(prev_layer_id, split_node.id, prev_layer_id))

    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    concat_node = Node(graph, concat_id)

    # split output port k feeds nested network k; that network's output feeds concat input port k
    for port, (head, tail) in enumerate(zip(heads, tails)):
        tail.add_output_port(0)
        concat_node.add_input_port(port)
        graph.create_edge(tail, concat_node, 0, port,
                          create_edge_attrs(tail.id, concat_id, tail.id, port, 0))
        graph.create_edge(split_node, head, port, 0,
                          create_edge_attrs(split_node.id, head.id, split_node.id, 0, port))
    return concat_id
def replace_op(self, graph: Graph, node: Node):
    """
    Replace ``node`` with a subgraph of elementary operations computing the LSTM cell
    equations written in the comments below (presumably Kaldi's LstmNonlinearity -- confirm
    against the op this replacer is registered for).

    The data input (port 0) is split along axis 1 into five parts
    ``(i_part, f_part, c_part, o_part, ct_1)``. When ``use_dropout`` is set, the input is
    first cut along axis 1 by a VariadicSplit with sizes ``[-1, 1, 1, 1]``: the first chunk is
    the regular input and the three single columns multiply the i, f and o gates.

    :param graph: graph containing ``node``; the new nodes are created in it
    :param node: node to replace; ``i_weights``, ``f_weights`` and ``o_weights`` are read from it,
                 ``use_dropout`` is optional (missing is treated as False)
    :return: single-element list with the id of the Concat node that outputs ``(c_t, m_t)``
    """
    node_name = node.soft_get('name', node.id)
    # Evaluate the flag once: indexing node['use_dropout'] directly raises when the attribute
    # is absent, while has_and_set treats a missing attribute as False.
    use_dropout = node.has_and_set('use_dropout')

    # check if we have dropout
    input_port = node.in_port(0)
    i_drop_scale = f_drop_scale = o_drop_scale = None
    if use_dropout:
        split_dropout = AttributedVariadicSplit(graph, {
            'name': node_name + '/split_dropout',
            'size_splits': int64_array([-1, 1, 1, 1]),
            'axis': int64_array(1),
        }).create_node()
        input_port.get_connection().set_destination(split_dropout.in_port(0))
        input_port = split_dropout.out_port(0)
        i_drop_scale = split_dropout.out_port(1)
        f_drop_scale = split_dropout.out_port(2)
        o_drop_scale = split_dropout.out_port(3)

    # split input to (i_part, f_part, c_part, o_part, ct_1)
    split_node = create_op_with_const_inputs(graph, Split, {1: np.int64(1)}, {
        'name': node_name + '/split_lstm_input',
        'num_splits': 5,
    })
    input_port.get_connection().set_destination(split_node.in_port(0))

    i_part = split_node.out_port(0)
    f_part = split_node.out_port(1)
    c_part = split_node.out_port(2)
    o_part = split_node.out_port(3)
    ct_1 = split_node.out_port(4)

    def _gate(gate, gate_part, peephole_src, weights, drop_scale):
        """
        Build Sigmoid(gate_part + weights*peephole_src) and return its last node.

        The product is a ScaleShift without bias whose weights come from ``weights``. When
        ``drop_scale`` is given, the sigmoid output is additionally multiplied by it.
        """
        scale_attrs = {'name': node_name + '/' + gate + '_scaleshift', 'bias_term': False}
        scale = ScaleShiftOp(graph, scale_attrs).create_node()
        input_as_const(scale, scale_attrs, 1, 'weights', weights)
        peephole_src.connect(scale.in_port(0))

        gate_sum = Add(graph, {'name': node_name + '/sum_' + gate + '_c_'}).create_node()
        gate_part.connect(gate_sum.in_port(0))
        scale.out_port(0).connect(gate_sum.in_port(1))

        sigmoid = Sigmoid(graph, {'name': node_name + '/' + gate + '_sigmoid'}).create_node()
        gate_sum.out_port(0).connect(sigmoid.in_port(0))

        if drop_scale is None:
            return sigmoid
        mul_dropout = Mul(graph, {
            'name': split_node.soft_get('name', split_node.id) + '/mul_' + gate
        }).create_node()
        mul_dropout.in_port(0).connect(sigmoid.out_port(0))
        mul_dropout.in_port(1).connect(drop_scale)
        return mul_dropout

    # i_t = Sigmoid(i_part + w_ic*ct_1)
    i_sigmoid = _gate('i', i_part, ct_1, node.i_weights, i_drop_scale)

    # f_t = Sigmoid(f_part + w_fc*ct_1)
    f_sigmoid = _gate('f', f_part, ct_1, node.f_weights, f_drop_scale)

    # c_t = f_t*ct_1 + i_t * tanh(c_part)
    c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
    c_part.connect(c_tanh.in_port(0))

    prod_i_c_tanh = Mul(graph, {'name': node_name + '/prod_i_c_tanh_'}).create_node()
    i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
    c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

    prod_f_ct_1 = Mul(graph, {'name': node_name + '/prod_f_ct_1_'}).create_node()
    f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
    ct_1.connect(prod_f_ct_1.in_port(1))

    sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
    prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
    prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

    # o_t = Sigmoid(o_part + w_oc*c_t)
    o_sigmoid = _gate('o', o_part, sum_f_i.out_port(0), node.o_weights, o_drop_scale)

    # m_t = o_t * Tanh(c_t)
    c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
    sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

    prod_o_c_t_tanh = Mul(graph, {'name': node_name + '/prod_o_c_t_tanh_'}).create_node()
    o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
    c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

    # add concat to create 1 output: (c_t, m_t)
    concat = Concat(graph, {'name': node_name + '/concat_c_m'}).create_node()
    concat.add_sequence_of_ports('in', range(2))
    sum_f_i.out_port(0).connect(concat.in_port(0))
    prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

    return [concat.id]