Пример #1
0
    def replace_op(self, graph: Graph, node: Node):
        """Rewrite the node as ``in0 * Power(in1)`` with power -1.

        Input 1 is sent through a Power node configured with scale=1, power=-1
        and shift=0 (presumably 1 / x — confirm against the Power op). Input 0 is
        then multiplied by that result in an Eltwise 'mul' node.

        :param graph: graph being transformed.
        :param node: node to replace; both of its input connections are moved.
        :return: single-element list with the id of the new 'mul' node, whose
            output becomes the replacement output.
        """
        power_attrs = {'scale': 1, 'power': np.float64(-1), 'shift': 0,
                       'name': node.name + '/reciprocal_'}
        inverse = Power(graph, power_attrs).create_node()
        product = Eltwise(graph, {'operation': 'mul', 'name': node.name + '/mul_'}).create_node()

        # Input 1 goes to the Power node; input 0 goes to the second input of the mul.
        node.in_port(1).get_connection().set_destination(inverse.in_port(0))
        node.in_port(0).get_connection().set_destination(product.in_port(1))
        inverse.out_port(0).connect(product.in_port(0))

        # Equivalent explicit form of this return value: [(product.id, 0)]
        return [product.id]
    def replace_op(self, graph: Graph, node: Node):
        """Rewrite an N-inputs-in-1 eltwise node as Split -> Eltwise/EltwiseN.

        The node's first incoming edge is re-created into a Split configured with
        ``node['num_inputs']`` splits. Output i of the Split feeds input i of a
        combining node that applies ``node['operation']``.

        :param graph: graph being transformed.
        :param node: node to replace; reads ``name``, ``num_inputs`` and ``operation``.
        :return: single-element list with the id of the combining node.
        :raises Error: when the split count is neither 2 nor greater than 2.
        """
        splitter = Split(graph,
                         attrs={
                             'name': 'Split_eltwise_' + node.name,
                             'num_split': node['num_inputs']
                         }).create_node()

        # Re-create the first incoming edge (with its attributes) into the Split.
        first_input = node.get_inputs()[0]
        graph.add_edge(first_input[0], splitter.id, **first_input[1])

        # Exactly two operands use a binary Eltwise; more use the N-ary variant.
        if splitter.num_split == 2:
            combiner_cls = Eltwise
        elif splitter.num_split > 2:
            combiner_cls = EltwiseN
        else:
            raise Error('Error on replacing Kaldi eltwise')
        combiner = combiner_cls(graph,
                                attrs={
                                    'name': 'Eltwise_' + node.name,
                                    'operation': node['operation']
                                }).create_node()

        for idx in range(splitter.num_split):
            # Each Split output port is created explicitly before being wired.
            splitter.add_output_port(idx)
            splitter.out_port(idx).get_connection().set_destination(combiner.in_port(idx))
        return [combiner.id]
    def replace_op(self, graph: Graph, node: Node):
        """Expand the node into an explicit LSTM-cell subgraph with peephole terms.

        Input 0 is split into five parts (i_part, f_part, c_part, o_part, ct_1).
        A Const of 1 is fed to the Split's second input, presumably as the split
        axis. The cell is then rebuilt as:

            i_t = Sigmoid(i_part + w_ic * ct_1)
            f_t = Sigmoid(f_part + w_fc * ct_1)
            c_t = f_t * ct_1 + i_t * Tanh(c_part)
            o_t = Sigmoid(o_part + w_oc * c_t)
            m_t = o_t * Tanh(c_t)

        The peephole weights come from ``node.i_weights``, ``node.f_weights`` and
        ``node.o_weights``. They are attached as constant input 1 of ScaleShift
        nodes created with ``bias_term=False``.

        :param graph: graph being transformed.
        :param node: node to replace; its input-0 connection is moved to the Split.
        :return: single-element list with the id of a Concat node whose
            inputs are (c_t, m_t), in that order.
        """

        def _eltwise(prefix, operation):
            # Create an Eltwise node named with a graph-unique id derived from ``prefix``.
            return Eltwise(graph, {
                'name': graph.unique_id(prefix=prefix),
                'operation': operation
            }).create_node()

        def _peephole_gate(weights_attr, peephole_src, peephole_idx, part_idx,
                           scale_prefix, sum_prefix, sigmoid_name):
            # Build Sigmoid(split_part + weights * peephole) and return the Sigmoid node.
            scale_attrs = {
                'name': graph.unique_id(prefix=scale_prefix),
                'bias_term': False
            }
            scale = ScaleShiftOp(graph, scale_attrs).create_node()
            input_as_const(scale, scale_attrs, 1, 'weights', getattr(node, weights_attr))
            peephole_src.out_port(peephole_idx).connect(scale.in_port(0))

            gate_sum = _eltwise(sum_prefix, 'sum')
            split_node.out_port(part_idx).connect(gate_sum.in_port(0))
            scale.out_port(0).connect(gate_sum.in_port(1))

            gate = Sigmoid(graph, {'name': sigmoid_name}).create_node()
            gate_sum.out_port(0).connect(gate.in_port(0))
            return gate

        # split input to (i_part, f_part, c_part, o_part, ct_1)
        axis_const = Const(graph, {'value': np.int64(1)}).create_node()
        split_attrs = {
            'name': graph.unique_id(prefix='Split_lstm_input_'),
            'num_splits': 5
        }
        split_node = Split(graph, split_attrs).create_node()
        # NOTE(review): output ports 0-4 are used without add_output_port(); this
        # assumes the Split op creates them from num_splits — confirm.
        node.in_port(0).get_connection().set_destination(split_node.in_port(0))
        split_node.in_port(1).connect(axis_const.out_port(0))

        # i_t = Sigmoid(i_part + w_ic*ct_1)
        i_sigmoid = _peephole_gate('i_weights', split_node, 4, 0,
                                   'i_scaleshift', 'sum_i_c_', 'i_sigmoid')

        # f_t = Sigmoid(f_part + w_fc*ct_1)
        f_sigmoid = _peephole_gate('f_weights', split_node, 4, 1,
                                   'f_scaleshift', 'sum_f_c_', 'f_sigmoid')

        # c_t = f_t*ct_1 + i_t*tanh(c_part)
        c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
        split_node.out_port(2).connect(c_tanh.in_port(0))

        i_times_tanh = _eltwise('prod_i_c_tanh_', 'mul')
        i_sigmoid.out_port(0).connect(i_times_tanh.in_port(0))
        c_tanh.out_port(0).connect(i_times_tanh.in_port(1))

        f_times_prev_c = _eltwise('prod_f_ct_1_', 'mul')
        f_sigmoid.out_port(0).connect(f_times_prev_c.in_port(0))
        split_node.out_port(4).connect(f_times_prev_c.in_port(1))

        c_t = _eltwise('sum_f_i_', 'sum')
        f_times_prev_c.out_port(0).connect(c_t.in_port(0))
        i_times_tanh.out_port(0).connect(c_t.in_port(1))

        # o_t = Sigmoid(o_part + w_oc*c_t); this peephole reads the new cell state c_t.
        o_sigmoid = _peephole_gate('o_weights', c_t, 0, 3,
                                   'o_scaleshift', 'sum_o_c_', 'o_sigmoid')

        # m_t = o_t * Tanh(c_t)
        c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
        c_t.out_port(0).connect(c_t_tanh.in_port(0))

        m_t = _eltwise('prod_o_c_t_tanh_', 'mul')
        o_sigmoid.out_port(0).connect(m_t.in_port(0))
        c_t_tanh.out_port(0).connect(m_t.in_port(1))

        # Join both results into one output: input 0 is c_t, input 1 is m_t.
        # NOTE(review): no 'axis' attribute is set, so this relies on Concat's default — confirm.
        concat = Concat(graph, {
            'name': graph.unique_id(prefix='Concat_c_m')
        }).create_node()
        concat.add_sequence_of_ports('in', range(2))
        c_t.out_port(0).connect(concat.in_port(0))
        m_t.out_port(0).connect(concat.in_port(1))

        return [concat.id]