Ejemplo n.º 1
0
    def extract(node):
        """Copy the split parameters of a layer onto ``node``.

        The layer attributes are obtained with ``get_mxnet_layer_attrs`` from
        ``node.symbol_dict``.  ``axis`` is read with ``attrs.int("axis", 1)``
        and ``num_outputs`` with ``attrs.int("num_outputs", 0)``; the latter is
        stored on the node under the ``num_split`` key.

        :param node: node whose ``symbol_dict`` holds the layer description
        :return: the ``enabled`` attribute of the enclosing class
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)

        # 'num_outputs' from the layer is exposed to Split as 'num_split'
        Split.update_node_stat(node, {
            'axis': layer_attrs.int("axis", 1),
            'num_split': layer_attrs.int("num_outputs", 0),
        })
        return __class__.enabled
Ejemplo n.º 2
0
    def split_data(self, data: Node):
        """Attach a two-way Split along axis 0 to ``data``.

        ``data.shape`` must have three dimensions and the first one must be 2;
        both conditions are checked with ``assert``.  Two new data nodes are
        created (named ``<data.name>/SplittedBiLSTM/forward`` and
        ``.../reverse``) and passed as the outputs of the Split operation.

        :param data: data node to split
        :return: the result of ``Split.create_node_with_data``
        """
        shape = data.shape
        assert len(shape) == 3
        assert shape[0] == 2

        graph = data.graph

        # one output data node per direction, created in forward/reverse order
        halves = []
        for direction in ('forward', 'reverse'):
            half_name = data.name + '/SplittedBiLSTM/{}'.format(direction)
            halves.append(Op._create_data_node(graph, name=half_name))

        split_attrs = {
            'name': data.name + '/DecomposedBiLSTM_0',
            'axis': 0,
            'num_split': 2,
        }
        return Split(graph, split_attrs).create_node_with_data(
            [data], data_nodes=halves)
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with a Split followed by an element-wise operation.

        A Split with ``node['num_inputs']`` parts is created and wired to the
        first input of ``node`` (re-using that edge's attributes).  Its outputs
        are then connected, port by port, to an ``Eltwise`` node (exactly two
        parts) or an ``EltwiseN`` node (more than two parts) that carries
        ``node['operation']``.

        :param graph: graph in which the new nodes are created
        :param node: node being replaced
        :return: single-element list with the id of the element-wise node
        :raises Error: if the Split has neither exactly two nor more than two parts
        """
        split_attrs = {
            'name': 'Split_eltwise_' + node.name,
            'num_split': node['num_inputs'],
        }
        ss_node = Split(graph, attrs=split_attrs).create_node()

        # feed the Split from the producer of the first input, keeping edge attrs
        first_input = node.get_inputs()[0]
        graph.add_edge(first_input[0], ss_node.id, **first_input[1])

        parts = ss_node.num_split
        if parts == 2:
            eltwise_cls = Eltwise
        elif parts > 2:
            eltwise_cls = EltwiseN
        else:
            raise Error('Error on replacing Kaldi eltwise')

        eltwise_attrs = {
            'name': 'Eltwise_' + node.name,
            'operation': node['operation'],
        }
        eltwise_node = eltwise_cls(graph, attrs=eltwise_attrs).create_node()

        # output port k of the Split goes to input port k of the eltwise node
        for port_idx in range(parts):
            ss_node.add_output_port(port_idx)
            connection = ss_node.out_port(port_idx).get_connection()
            connection.set_destination(eltwise_node.in_port(port_idx))

        return [eltwise_node.id]
Ejemplo n.º 4
0
    def replace_op(self, graph: Graph, node: Node):
        """Build the replacement sub-graph for ``node`` and return its output node id.

        The sub-graph is assembled in ``graph`` from InnerProduct, Memory, Add,
        Split, ScaleShiftOp, Tanh, Sigmoid, Mul, Clamp and Result operations.
        Weights and biases are taken from attributes of ``node``
        (``gifo_x_weights``, ``gifo_biases``, ``gifo_r_weights``,
        ``input_gate_weights``, ``forget_gate_weights``,
        ``output_gate_weights``, ``projection_weights``) and attached with
        ``embed_input``.  ``node.clip_value`` bounds the Clamp.

        Two pairs of Memory nodes are created.  The nodes of a pair share one
        ``id``; the node with ``index`` 1 is created without inputs and the
        node with ``index`` 0 is fed by the sub-graph and followed by a Result.
        NOTE(review): judging by the names this is an LSTM cell with a
        projection layer, where index 1 reads the previous value and index 0
        stores the next one -- confirm against the Memory operation.

        Nodes are created in a fixed sequence; do not reorder the statements.

        :param graph: graph in which the new nodes are created
        :param node: node being replaced; its first input feeds the sub-graph
        :return: single-element list with the id of the final InnerProduct
                 node (named 'FullyConnected')
        """
        input_node = node.in_node()

        # one id per Memory pair: memory_pair_input is shared by the two
        # "output" Memory nodes, memory_pair_output by the two "state" ones
        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        fc_layer_after_input_attrs = {
            'name': 'input_fullyconnected',
            'num_output': node.gifo_x_weights_shape[0],
            'bias_term': True
        }

        embed_input(fc_layer_after_input_attrs, 1, 'weights',
                    node.gifo_x_weights)
        embed_input(fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
        fc_layer_after_input = InnerProduct(
            graph, fc_layer_after_input_attrs).create_node([input_node])

        # Memory(output) with index 1: created without inputs
        prev_lstm_output = Memory(
            graph, {
                'name': 'prev_memory_output',
                'id': memory_pair_input,
                'index': 1,
                'size': 2,
                'shape': np.array([node.gifo_r_weights_shape[1]],
                                  dtype=np.int64)
            }).create_node()

        # *Memory(output) -> FullyConnected (weights only, no bias)
        fc_layer_from_prev_state_attrs = {
            'name': 'prev_memory_output_fullyconnected',
            'num_output': node.gifo_r_weights_shape[0],
            'bias_term': False
        }

        embed_input(fc_layer_from_prev_state_attrs, 1, 'weights',
                    node.gifo_r_weights)
        fc_layer_from_prev_state = InnerProduct(
            graph,
            fc_layer_from_prev_state_attrs).create_node([prev_lstm_output])

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Add(graph, {
            'name': 'join_input_eltwise',
        }).create_node([fc_layer_from_prev_state, fc_layer_after_input])

        # *Eltwise(sum) -> Split
        # it is split into 4 nodes: Act, Eltw*3
        # the following order is mandatory
        #       ___Tanh
        #      /
        # Split ---(2)Eltwise(sum)
        #     |\
        #     | \__(3)Eltwise(sum)
        #     |____(4)Eltwise(sum)
        split_joined_input = Split(
            graph, {
                'name': 'join_input_split',
                'axis': 1,
                'num_split': 4,
                'out_ports_count': 4,
            }).create_node([join_input_prev_state_sum])

        # Memory(state) with index 1: created without inputs
        prev_lstm_state = Memory(
            graph, {
                'name':
                'prev_memory_state',
                'id':
                memory_pair_output,
                'index':
                1,
                'size':
                2,
                'shape':
                np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
            }).create_node()

        # *Memory(state) -> *ScaleShift(input)
        state_input_scaleshift_attrs = {
            'name': 'input_scaleshift',
            'bias_term': False
        }
        embed_input(state_input_scaleshift_attrs, 1, 'weights',
                    node.input_gate_weights)
        state_input_scaleshift = ScaleShiftOp(
            graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])

        # *Memory(state) -> *ScaleShift(forget)
        state_forget_scaleshift_attrs = {
            'name': 'forget_scaleshift',
            'bias_term': False
        }
        embed_input(state_forget_scaleshift_attrs, 1, 'weights',
                    node.forget_gate_weights)
        state_forget_scaleshift = ScaleShiftOp(
            graph,
            state_forget_scaleshift_attrs).create_node([prev_lstm_state])

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        # NOTE(review): the (split_joined_input, 1) tuple presumably selects
        # output 1 of the Split -- confirm against create_node
        join_prev_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_input_eltwise',
            }).create_node([(split_joined_input, 1), state_input_scaleshift])
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_forget_sum',
            }).create_node([(split_joined_input, 2), state_forget_scaleshift])

        # Split -> Tanh
        remember_tahn = Tanh(graph, {
            'name': 'remember_tahnv'
        }).create_node([(split_joined_input, 0)])

        # Split -> (2)Eltwise(sum) -> *Sigmoid
        remember_sigmoid = Sigmoid(graph, {
            'name': 'remember_sigmoid'
        }).create_node([join_prev_lstm_input_joined_input_sum])

        # Split -> (3)Eltwise(sum) -> **Sigmoid
        forget_sigmoid = Sigmoid(graph, {
            'name': 'forget_sigmoid'
        }).create_node([join_prev_lstm_input_joined_forget_sum])

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Mul(graph, {
            'name': 'join_forget_prev_state_mul',
        }).create_node([forget_sigmoid, prev_lstm_state])

        # Split -> Tanh                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Mul(
            graph, {
                'name': 'join_remember_candidates_mul',
            }).create_node([remember_tahn, remember_sigmoid])

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        join_forget_remember_sum = Add(graph, {
            'name': 'join_forget_remember_sum',
        }).create_node(
            [join_forget_prev_state_mul, join_remember_candidates_mul])

        # (7)Eltwise(sum) -> Clamp, bounded by [-clip_value, clip_value]
        join_forget_clamp = Clamp(
            graph, {
                'name': 'join_forget_clamp',
                'max': node.clip_value,
                'min': -node.clip_value
            }).create_node([join_forget_remember_sum])
        #
        # Clamp -> (2)Memory(state): index 0, same id as prev_lstm_state,
        # terminated by a Result node
        next_lstm_state = Memory(
            graph, {
                'name':
                'next_lstm_state',
                'id':
                memory_pair_output,
                'index':
                0,
                'size':
                2,
                'shape':
                np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
            }).create_node([join_forget_clamp])
        Result(graph, {
            'name': 'next_lstm_state_out'
        }).create_node([next_lstm_state])

        # Clamp -> (2)Tanh
        state_filtered_tahn = Tanh(graph, {
            'name': 'state_filtered_tahn'
        }).create_node([join_forget_clamp])

        # Clamp -> (2)ScaleShift
        clamp_scaleshift_attrs = {
            'name': 'clamp_scaleshift',
            'bias_term': False
        }
        embed_input(clamp_scaleshift_attrs, 1, 'weights',
                    node.output_gate_weights)
        clamp_scaleshift = ScaleShiftOp(
            graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_next_lstm_input_joined_input_sum',
            }).create_node([(split_joined_input, 3), clamp_scaleshift])

        # (4)Eltwise(sum) -> (3)Sigmoid
        output_sigmoid = Sigmoid(graph, {
            'name': 'output_sigmoid'
        }).create_node([join_next_lstm_input_joined_input_sum])

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (5)Eltwise(mul)
        # Clamp -> (2)Tanh                      /
        joined_output_mul = Mul(graph, {
            'name': 'joined_output_mul'
        }).create_node([state_filtered_tahn, output_sigmoid])

        # (5)Eltwise(mul) -> (3)FullyConnected (weights only, no bias)
        fc_output_attrs = {
            'name': 'FullyConnected',
            'num_output': node.projection_weights_shape[0],
            'bias_term': False
        }
        embed_input(fc_output_attrs, 1, 'weights', node.projection_weights)
        fc_output = InnerProduct(graph, fc_output_attrs).create_node(
            [joined_output_mul])

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        # index 0, same id as prev_lstm_output, terminated by a Result node
        next_lstm_output = Memory(
            graph, {
                'name': 'next_lstm_output',
                'id': memory_pair_input,
                'index': 0,
                'size': 2,
                'shape': np.array([node.gifo_r_weights_shape[1]],
                                  dtype=np.int64)
            }).create_node([fc_output])
        Result(graph, {
            'name': 'next_lstm_output_out'
        }).create_node([next_lstm_output])

        # the final InnerProduct is reported as the output of the replacement
        return [fc_output.id]
Ejemplo n.º 5
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with a sub-graph fed by a 5-way Split of its input.

        The connection on input port 0 of ``node`` is redirected to a Split
        whose five outputs are used as (i_part, f_part, c_part, o_part, ct_1).
        From them the sub-graph builds:

            i_t = Sigmoid(i_part + w_ic*ct_1)
            f_t = Sigmoid(f_part + w_fc*ct_1)
            c_t = f_t*ct_1 + i_t * tanh(c_part)
            o_t = Sigmoid(o_part + w_oc*c_t)
            m_t = o_t * Tanh(c_t)

        where the ``w_*`` scalings are ScaleShiftOp nodes whose weights come
        from ``node.i_weights``, ``node.f_weights`` and ``node.o_weights``.
        ``c_t`` and ``m_t`` are finally joined by a Concat node.

        :param graph: graph in which the new nodes are created
        :param node: node being replaced
        :return: single-element list with the id of the Concat node
        """

        def link(src, src_idx, dst, dst_idx):
            # output port `src_idx` of `src` -> input port `dst_idx` of `dst`
            src.out_port(src_idx).connect(dst.in_port(dst_idx))

        def scale_shift(prefix, weights_attr, src, src_idx):
            # bias-free ScaleShift with weights taken from `node`, fed by `src`
            attrs = {'name': graph.unique_id(prefix=prefix),
                     'bias_term': False}
            scale = ScaleShiftOp(graph, attrs).create_node()
            input_as_const(scale, attrs, 1, 'weights',
                           getattr(node, weights_attr))
            link(src, src_idx, scale, 0)
            return scale

        def eltwise(prefix, operation, first, second):
            # two-input Eltwise; `first`/`second` are (node, output port) pairs
            merged = Eltwise(graph, {'name': graph.unique_id(prefix=prefix),
                                     'operation': operation}).create_node()
            for dst_idx, (src, src_idx) in enumerate((first, second)):
                link(src, src_idx, merged, dst_idx)
            return merged

        def activation(op_cls, name, src):
            # single-input activation fed by output port 0 of `src`
            act = op_cls(graph, {'name': name}).create_node()
            link(src, 0, act, 0)
            return act

        # split input to (i_part, f_part, c_part, o_part, ct_1)
        splitter = Split(graph, {
            'name': graph.unique_id(prefix='Split_lstm_input_'),
            'num_split': 5,
        }).create_node()
        node.in_port(0).get_connection().set_destination(splitter.in_port(0))
        for port_idx in range(5):
            splitter.add_output_port(port_idx)

        # i_t = Sigmoid(i_part + w_ic*ct_1)
        i_scale = scale_shift('i_scaleshift', 'i_weights', splitter, 4)
        sum_i_c = eltwise('sum_i_c_', 'sum', (splitter, 0), (i_scale, 0))
        i_gate = activation(Sigmoid, 'i_sigmoid', sum_i_c)

        # f_t = Sigmoid(f_part + w_fc*ct_1)
        f_scale = scale_shift('f_scaleshift', 'f_weights', splitter, 4)
        sum_f_c = eltwise('sum_f_c_', 'sum', (splitter, 1), (f_scale, 0))
        f_gate = activation(Sigmoid, 'f_sigmoid', sum_f_c)

        # c_t = f_t*ct_1 + i_t * tanh(c_part)
        c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
        link(splitter, 2, c_tanh, 0)
        prod_i_c_tanh = eltwise('prod_i_c_tanh_', 'mul',
                                (i_gate, 0), (c_tanh, 0))
        prod_f_ct_1 = eltwise('prod_f_ct_1_', 'mul',
                              (f_gate, 0), (splitter, 4))
        c_t = eltwise('sum_f_i_', 'sum',
                      (prod_f_ct_1, 0), (prod_i_c_tanh, 0))

        #  o_t = Sigmoid(o_part + w_oc*c_t)
        o_scale = scale_shift('o_scaleshift', 'o_weights', c_t, 0)
        sum_o_c = eltwise('sum_o_c_', 'sum', (splitter, 3), (o_scale, 0))
        o_gate = activation(Sigmoid, 'o_sigmoid', sum_o_c)

        # m_t = o_t * Tanh(c_t)
        c_t_tanh = activation(Tanh, 'c_t_tanh', c_t)
        m_t = eltwise('prod_o_c_t_tanh_', 'mul', (o_gate, 0), (c_t_tanh, 0))

        # add concat to create 1 output
        concat = Concat(graph, {
            'name': graph.unique_id(prefix='Concat_c_m'),
        }).create_node()
        concat.add_sequence_of_ports('in', range(2))
        link(c_t, 0, concat, 0)
        link(m_t, 0, concat, 1)

        return [concat.id]