Example #1
    @classmethod
    def extract(cls, node):
        axis = onnx_attr(node, 'axis', 'i', default=0, dst_type=np.int64)
        size_splits = onnx_attr(node, 'split', 'ints', default=None, dst_type=int64_array)
        if size_splits is None:
            AttributedSplit.update_node_stat(node, {
                'axis': axis,
                'num_splits': onnx_get_num_outputs(node),
            })
        else:
            AttributedVariadicSplit.update_node_stat(node, {
                'axis': axis,
                'size_splits': size_splits,
            })
        return cls.enabled
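
The extractor above maps ONNX Split to AttributedSplit when the 'split' attribute is absent (equal sections, one per output) and to AttributedVariadicSplit when explicit section sizes are given. A minimal numpy sketch of the two split semantics, not using the MO API (shapes and sizes here are illustrative assumptions):

import numpy as np

x = np.arange(24).reshape(2, 12)

# AttributedSplit semantics: 'num_splits' equal sections along 'axis'
equal_parts = np.split(x, 3, axis=1)                     # three (2, 4) chunks

# AttributedVariadicSplit semantics: explicit 'size_splits' along 'axis'
size_splits = [2, 4, 6]
variadic_parts = np.split(x, np.cumsum(size_splits)[:-1], axis=1)
print([p.shape for p in variadic_parts])                 # [(2, 2), (2, 4), (2, 6)]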
Example #2
    def test_splitv_zero(self):
        graph = build_graph(self.nodes, self.edges,
                            {
                                'split_input_data': {'shape': int64_array([2, 12, 25, 30])},
                                'split_op': {'axis': np.array(2), 'split_lengths': np.array([2, 13, 10, 0]),
                                             'out_ports_count': 4},
                            }
                            )
        node = Node(graph, 'split_op')
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)

        AttributedVariadicSplit.infer(node)

        self.assertTrue(len(node.out_edges()) == 3)
        self.assertTrue(np.all(node.split_lengths == np.array([2, 13, 10])))
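
This test checks that AttributedVariadicSplit.infer drops a trailing zero-length section: split_lengths shrinks from [2, 13, 10, 0] to [2, 13, 10] and only three output edges remain. A hedged numpy illustration of why the fourth section is empty (numpy only, shapes taken from the test, not the MO infer function):

import numpy as np

x = np.zeros((2, 12, 25, 30))
lengths = np.array([2, 13, 10, 0])                   # sums to 25, the size of axis 2
parts = np.split(x, np.cumsum(lengths)[:-1], axis=2)
print([p.shape[2] for p in parts])                   # [2, 13, 10, 0] -- the last part is empty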
Example #3
    def test_splitv_zero_not_last(self):
        graph = build_graph(self.nodes, self.edges,
                            {
                                'split_input_data': {'shape': int64_array([2, 12, 25, 30])},
                                'split_op': {'axis': np.array(2), 'split_lengths': np.array([2, 13, 0, 10]),
                                             'out_ports_count': 4},
                            }
                            )
        node = Node(graph, 'split_op')

        # the extractor is expected to add the missing output ports; add them manually here
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)
        node.out_port(2).get_connection().set_source(node.out_port(3))

        AttributedVariadicSplit.infer(node)

        self.assertTrue(node.out_port(3).disconnected())
        self.assertTrue(np.all(node.split_lengths == np.array([2, 13, 10])))
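
This test covers a zero-length section that is not last: the consumer of port 2 is first moved to port 3, and after infer the zero section is gone, port 3 is disconnected, and split_lengths is [2, 13, 10]. A hedged pure-Python sketch of the pruning the assertions imply (an assumed model of the behavior, not the MO implementation):

# remove zero-length sections and renumber the remaining ports; consumers of
# later ports are assumed to shift down with them, which is why out_port(3)
# ends up disconnected in the test above
split_lengths = [2, 13, 0, 10]
kept = [length for length in split_lengths if length != 0]
port_map = {old: new for new, old in
            enumerate(i for i, l in enumerate(split_lengths) if l != 0)}
print(kept)       # [2, 13, 10]
print(port_map)   # {0: 0, 1: 1, 3: 2} -- old port 3 becomes port 2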
Example #4
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Load the ParallelComponent of a Kaldi model.
    A ParallelComponent contains parallel nested networks.
    A VariadicSplit is inserted before the nested networks.
    The outputs of the nested networks are concatenated with a Concat layer.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology
    :param prev_layer_id: id of the input layer for the parallel component
    :return: id of the Concat layer, the last layer of the parallel component
    """
    nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug(
        'Model contains parallel component with {} nested networks'.format(
            nnet_count))

    split_points = []
    outputs = []
    inputs = []

    for i in range(nnet_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        g = Graph()
        load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i))

        # the input to nnet1 models has rank 1, but batch_size is also inserted at the 0th axis;
        # the 1st axis holds the input_size of the nested subnetwork,
        # so the input from the main network is split across the subnetworks
        input_node = Node(g, 'Parameter')
        split_points.append(input_node['shape'][1])
        g.remove_node(input_node.id)

        mapping = {
            node: graph.unique_id(node)
            for node in g.nodes(data=False) if node in graph
        }
        g = nx.relabel_nodes(g, mapping)
        for val in mapping.values():
            g.node[val]['name'] = val
        graph.add_nodes_from(g.nodes(data=True))
        graph.add_edges_from(g.edges(data=True))
        sorted_nodes = tuple(nx.topological_sort(g))

        outputs.append(Node(graph, sorted_nodes[-1]))
        inputs.append(Node(graph, sorted_nodes[0]))

    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
    attrs = {
        'out_ports_count': nnet_count,
        'size_splits': split_points,
        'axis': 1,
        'name': split_id
    }
    variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node()
    prev_layer_node = Node(graph, prev_layer_id)
    prev_layer_node.add_output_port(0)
    graph.create_edge(
        prev_layer_node, variadic_split_node, 0, 0,
        create_edge_attrs(prev_layer_id, variadic_split_node.id,
                          prev_layer_id))

    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    concat_node = Node(graph, concat_id)

    # Connect each output of variadic_split_node to the corresponding subnetwork input
    # and each subnetwork output to concat_node
    for i, (input_node, output_node) in enumerate(zip(inputs, outputs)):
        output_node.add_output_port(0)
        concat_node.add_input_port(i)
        graph.create_edge(
            output_node, concat_node, 0, i,
            create_edge_attrs(output_node.id, concat_id, output_node.id, i, 0))
        graph.create_edge(
            variadic_split_node, input_node, i, 0,
            create_edge_attrs(variadic_split_node.id, input_node.id,
                              variadic_split_node.id, 0, i))
    return concat_id
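
A hedged numpy sketch of the dataflow this loader builds (not the MO graph API): the previous layer's output is split along axis 1 by each nested net's input size, each chunk feeds one nested subnetwork, and the subnetwork outputs are concatenated back. The "subnetworks" below are stand-in identity functions and the sizes are illustrative assumptions.

import numpy as np

batch, split_points = 4, [16, 8, 32]                          # assumed nested-net input sizes
x = np.random.rand(batch, sum(split_points))

chunks = np.split(x, np.cumsum(split_points)[:-1], axis=1)    # VariadicSplit along axis 1
outputs = [chunk for chunk in chunks]                         # nested nets (identity here)
result = np.concatenate(outputs, axis=1)                      # Concat of subnetwork outputs

assert result.shape == (batch, sum(split_points))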
Example #5
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)
        # check if we have dropout
        input_port = node.in_port(0)
        if node.has_and_set('use_dropout'):
            split_dropout = AttributedVariadicSplit(
                graph, {
                    'name': node_name + '/split_dropout',
                    'size_splits': int64_array([-1, 1, 1, 1]),
                    'axis': int64_array(1)
                }).create_node()
            input_port.get_connection().set_destination(
                split_dropout.in_port(0))
            input_port = split_dropout.out_port(0)
            i_drop_scale = split_dropout.out_port(1)
            f_drop_scale = split_dropout.out_port(2)
            o_drop_scale = split_dropout.out_port(3)

        # split input to (i_part, f_part, c_part, o_part, ct_1)
        split_node = create_op_with_const_inputs(
            graph, Split, {1: np.int64(1)}, {
                'name': node_name + '/split_lstm_input',
                'num_splits': 5
            })
        input_port.get_connection().set_destination(split_node.in_port(0))

        i_part = split_node.out_port(0)
        f_part = split_node.out_port(1)
        c_part = split_node.out_port(2)
        o_part = split_node.out_port(3)
        ct_1 = split_node.out_port(4)

        # i_t = Sigmoid(i_part + w_ic*ct_1)
        i_scale_attrs = {
            'name': node_name + '/i_scaleshift',
            'bias_term': False
        }
        i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
        input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
        ct_1.connect(i_scale.in_port(0))

        sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
        i_part.connect(sum_i_c.in_port(0))
        i_scale.out_port(0).connect(sum_i_c.in_port(1))

        i_sigmoid = Sigmoid(graph, {
            'name': node_name + '/i_sigmoid'
        }).create_node()
        sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

        if node['use_dropout']:
            mul_dropout_i = Mul(graph, {
                'name':
                split_node.soft_get('name', split_node.id) + '/mul_i'
            }).create_node()
            mul_dropout_i.in_port(0).connect(i_sigmoid.out_port(0))
            mul_dropout_i.in_port(1).connect(i_drop_scale)
            i_sigmoid = mul_dropout_i

        # f_t = Sigmoid(f_part + w_fc*ct_1)
        f_scale_attrs = {
            'name': node_name + '/f_scaleshift',
            'bias_term': False
        }
        f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
        input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
        ct_1.connect(f_scale.in_port(0))

        sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
        f_part.connect(sum_f_c.in_port(0))
        f_scale.out_port(0).connect(sum_f_c.in_port(1))

        f_sigmoid = Sigmoid(graph, {
            'name': node_name + '/f_sigmoid'
        }).create_node()
        sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

        if node['use_dropout']:
            mul_dropout_f = Mul(graph, {
                'name':
                split_node.soft_get('name', split_node.id) + '/mul_f'
            }).create_node()
            mul_dropout_f.in_port(0).connect(f_sigmoid.out_port(0))
            mul_dropout_f.in_port(1).connect(f_drop_scale)
            f_sigmoid = mul_dropout_f

        # c_t = f_t*ct_1 + i_t * tanh(c_part)
        c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
        c_part.connect(c_tanh.in_port(0))

        prod_i_c_tanh = Mul(graph, {
            'name': node_name + '/prod_i_c_tanh_'
        }).create_node()
        i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
        c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

        prod_f_ct_1 = Mul(graph, {
            'name': node_name + '/prod_f_ct_1_'
        }).create_node()
        f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
        ct_1.connect(prod_f_ct_1.in_port(1))

        sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
        prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
        prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

        #  o_t = Sigmoid(o_part + w_oc*c_t)
        o_scale_attrs = {
            'name': node_name + '/o_scaleshift',
            'bias_term': False
        }
        o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
        input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
        sum_f_i.out_port(0).connect(o_scale.in_port(0))

        sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
        o_part.connect(sum_o_c.in_port(0))
        o_scale.out_port(0).connect(sum_o_c.in_port(1))

        o_sigmoid = Sigmoid(graph, {
            'name': node_name + '/o_sigmoid'
        }).create_node()
        sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

        if node['use_dropout']:
            mul_dropout_o = Mul(graph, {
                'name':
                split_node.soft_get('name', split_node.id) + '/mul_o'
            }).create_node()
            mul_dropout_o.in_port(0).connect(o_sigmoid.out_port(0))
            mul_dropout_o.in_port(1).connect(o_drop_scale)
            o_sigmoid = mul_dropout_o

        # m_t = o_t * Tanh(c_t)
        c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
        sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

        prod_o_c_t_tanh = Mul(graph, {
            'name': node_name + '/prod_o_c_t_tanh_'
        }).create_node()
        o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
        c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

        # add concat to create 1 output
        concat = Concat(graph, {
            'name': node_name + '/concat_c_m'
        }).create_node()
        concat.add_sequence_of_ports('in', range(2))
        sum_f_i.out_port(0).connect(concat.in_port(0))
        prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

        return [concat.id]
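
A hedged numpy reference of the cell math the replacement graph computes (dropout path omitted; names and shapes are illustrative assumptions, not the MO API). The node input is split into i, f, c, o parts plus the previous cell state ct_1, exactly as split_lstm_input does above, and the peephole weights act elementwise, matching the bias-free ScaleShift nodes.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_nonlinearity(parts, w_ic, w_fc, w_oc):
    i_part, f_part, c_part, o_part, ct_1 = np.split(parts, 5, axis=1)
    i_t = sigmoid(i_part + w_ic * ct_1)        # input gate, peephole on c_{t-1}
    f_t = sigmoid(f_part + w_fc * ct_1)        # forget gate, peephole on c_{t-1}
    c_t = f_t * ct_1 + i_t * np.tanh(c_part)   # new cell state
    o_t = sigmoid(o_part + w_oc * c_t)         # output gate, peephole on c_t
    m_t = o_t * np.tanh(c_t)                   # cell output
    return np.concatenate([c_t, m_t], axis=1)  # single output, as the Concat above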