def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with an explicit LSTM-cell sub-graph.

        The single input of ``node`` is split into five chunks
        ``(i_part, f_part, c_part, o_part, ct_1)`` and the following
        computation is assembled from them::

            i_t = Sigmoid(i_part + w_ic * ct_1)
            f_t = Sigmoid(f_part + w_fc * ct_1)
            c_t = f_t * ct_1 + i_t * Tanh(c_part)
            o_t = Sigmoid(o_part + w_oc * c_t)
            m_t = o_t * Tanh(c_t)

        The scale weights come from ``node.i_weights``, ``node.f_weights``
        and ``node.o_weights``.

        :param graph: graph in which the new nodes are created
        :param node: node that is being replaced
        :return: single-element list with the id of the Concat node joining
                 ``c_t`` (input 0) and ``m_t`` (input 1)
        """
        def link(producer, out_idx, consumer, in_idx):
            # Attach output ``out_idx`` of ``producer`` to input ``in_idx``
            # of ``consumer``.
            producer.out_port(out_idx).connect(consumer.in_port(in_idx))

        # Chunk the input: outputs 0..4 are i_part, f_part, c_part, o_part, ct_1.
        axis_const = Const(graph, {'value': np.int64(1)}).create_node()
        splitter = Split(graph, {'name': 'Split_lstm_input_', 'num_splits': 5}).create_node()
        node.in_port(0).get_connection().set_destination(splitter.in_port(0))
        splitter.in_port(1).connect(axis_const.out_port(0))

        # Input gate: i_t = Sigmoid(i_part + w_ic * ct_1)
        in_gate_scale_attrs = {'name': 'i_scaleshift', 'bias_term': False}
        in_gate_scale = ScaleShiftOp(graph, in_gate_scale_attrs).create_node()
        input_as_const(in_gate_scale, in_gate_scale_attrs, 1, 'weights', node.i_weights)
        link(splitter, 4, in_gate_scale, 0)

        in_gate_sum = Add(graph, {'name': 'sum_i_c_'}).create_node()
        link(splitter, 0, in_gate_sum, 0)
        link(in_gate_scale, 0, in_gate_sum, 1)

        in_gate = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
        link(in_gate_sum, 0, in_gate, 0)

        # Forget gate: f_t = Sigmoid(f_part + w_fc * ct_1)
        forget_gate_scale_attrs = {'name': 'f_scaleshift', 'bias_term': False}
        forget_gate_scale = ScaleShiftOp(graph, forget_gate_scale_attrs).create_node()
        input_as_const(forget_gate_scale, forget_gate_scale_attrs, 1, 'weights', node.f_weights)
        link(splitter, 4, forget_gate_scale, 0)

        forget_gate_sum = Add(graph, {'name': 'sum_f_c_'}).create_node()
        link(splitter, 1, forget_gate_sum, 0)
        link(forget_gate_scale, 0, forget_gate_sum, 1)

        forget_gate = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
        link(forget_gate_sum, 0, forget_gate, 0)

        # New cell state: c_t = f_t * ct_1 + i_t * Tanh(c_part)
        candidate = Tanh(graph, {'name': 'c_tanh'}).create_node()
        link(splitter, 2, candidate, 0)

        gated_candidate = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
        link(in_gate, 0, gated_candidate, 0)
        link(candidate, 0, gated_candidate, 1)

        kept_state = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
        link(forget_gate, 0, kept_state, 0)
        link(splitter, 4, kept_state, 1)

        cell_state = Add(graph, {'name': 'sum_f_i_'}).create_node()
        link(kept_state, 0, cell_state, 0)
        link(gated_candidate, 0, cell_state, 1)

        # Output gate: o_t = Sigmoid(o_part + w_oc * c_t)
        out_gate_scale_attrs = {'name': 'o_scaleshift', 'bias_term': False}
        out_gate_scale = ScaleShiftOp(graph, out_gate_scale_attrs).create_node()
        input_as_const(out_gate_scale, out_gate_scale_attrs, 1, 'weights', node.o_weights)
        link(cell_state, 0, out_gate_scale, 0)

        out_gate_sum = Add(graph, {'name': 'sum_o_c_'}).create_node()
        link(splitter, 3, out_gate_sum, 0)
        link(out_gate_scale, 0, out_gate_sum, 1)

        out_gate = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
        link(out_gate_sum, 0, out_gate, 0)

        # Cell output: m_t = o_t * Tanh(c_t)
        squashed_state = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
        link(cell_state, 0, squashed_state, 0)

        cell_output = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
        link(out_gate, 0, cell_output, 0)
        link(squashed_state, 0, cell_output, 1)

        # The replaced node has a single output, so c_t and m_t are concatenated.
        joined = Concat(graph, {'name': 'Concat_c_m'}).create_node()
        joined.add_sequence_of_ports('in', range(2))
        link(cell_state, 0, joined, 0)
        link(cell_output, 0, joined, 1)

        return [joined.id]
Example #2
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with an explicit LSTM (with projection) sub-graph.

        The sub-graph is built from InnerProduct, Memory, Split, ScaleShift,
        Add/Mul, Sigmoid/Tanh and Clamp nodes. Weights and biases are taken
        from attributes of ``node`` (``gifo_x_weights``, ``gifo_biases``,
        ``gifo_r_weights``, ``input_gate_weights``, ``forget_gate_weights``,
        ``output_gate_weights``, ``projection_weights``) and embedded into
        the created nodes.

        Two Memory pairs are created, each pair sharing a freshly generated
        id: one carries the projected output between invocations, the other
        carries the clamped cell state.

        NOTE: nodes are created in a fixed order (see the "order is
        mandatory" remark at the Split below); do not reorder statements.

        :param graph: graph in which the new nodes are created
        :param node: node that is being replaced
        :return: single-element list with the id of the output (projection)
                 InnerProduct node
        """
        input_node = node.in_node()

        # Ids that tie together the reading and the writing Memory node of
        # each pair.
        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        fc_layer_after_input_attrs = {
            'name': 'input_fullyconnected',
            'num_output': node.gifo_x_weights_shape[0],
            'bias_term': True
        }

        embed_input(fc_layer_after_input_attrs, 1, 'weights',
                    node.gifo_x_weights)
        embed_input(fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
        fc_layer_after_input = InnerProduct(
            graph, fc_layer_after_input_attrs).create_node([input_node])

        # Memory node without inputs: provides the previously stored output.
        prev_lstm_output = Memory(
            graph, {
                'name': 'prev_memory_output',
                'id': memory_pair_input,
                'index': 1,
                'size': 2,
                'shape': np.array([node.gifo_r_weights_shape[1]],
                                  dtype=np.int64)
            }).create_node()

        # *Memory(output) -> FullyConnected
        fc_layer_from_prev_state_attrs = {
            'name': 'prev_memory_output_fullyconnected',
            'num_output': node.gifo_r_weights_shape[0],
            'bias_term': False
        }

        embed_input(fc_layer_from_prev_state_attrs, 1, 'weights',
                    node.gifo_r_weights)
        fc_layer_from_prev_state = InnerProduct(
            graph,
            fc_layer_from_prev_state_attrs).create_node([prev_lstm_output])

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Add(graph, {
            'name': 'join_input_eltwise',
        }).create_node([fc_layer_from_prev_state, fc_layer_after_input])

        # *Eltwise(sum) -> Split
        # it is split into 4 nodes: Act, Eltw*3
        # the following order is mandatory
        #       ___Tanh
        #      /
        # Split ---(2)Eltwise(sum)
        #     |\
        #     | \__(3)Eltwise(sum)
        #     |____(4)Eltwise(sum)
        split_joined_input = Split(
            graph, {
                'name': 'join_input_split',
                'axis': 1,
                'num_split': 4,
                'out_ports_count': 4,
            }).create_node([join_input_prev_state_sum])

        # Memory node without inputs: provides the previously stored cell state.
        prev_lstm_state = Memory(
            graph, {
                'name':
                'prev_memory_state',
                'id':
                memory_pair_output,
                'index':
                1,
                'size':
                2,
                'shape':
                np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
            }).create_node()

        # *Memory(state) -> *ScaleShift(input)
        state_input_scaleshift_attrs = {
            'name': 'input_scaleshift',
            'bias_term': False
        }
        embed_input(state_input_scaleshift_attrs, 1, 'weights',
                    node.input_gate_weights)
        state_input_scaleshift = ScaleShiftOp(
            graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])

        # *Memory(state) -> *ScaleShift(forget)
        state_forget_scaleshift_attrs = {
            'name': 'forget_scaleshift',
            'bias_term': False
        }
        embed_input(state_forget_scaleshift_attrs, 1, 'weights',
                    node.forget_gate_weights)
        state_forget_scaleshift = ScaleShiftOp(
            graph,
            state_forget_scaleshift_attrs).create_node([prev_lstm_state])

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        join_prev_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_input_eltwise',
            }).create_node([(split_joined_input, 1), state_input_scaleshift])
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_forget_sum',
            }).create_node([(split_joined_input, 2), state_forget_scaleshift])

        # Split -> Tanh
        remember_tahn = Tanh(graph, {
            'name': 'remember_tahnv'
        }).create_node([(split_joined_input, 0)])

        # Split -> (2)Eltwise(sum) -> *Sigmoid
        remember_sigmoid = Sigmoid(graph, {
            'name': 'remember_sigmoid'
        }).create_node([join_prev_lstm_input_joined_input_sum])

        # Split -> (3)Eltwise(sum) -> **Sigmoid
        forget_sigmoid = Sigmoid(graph, {
            'name': 'forget_sigmoid'
        }).create_node([join_prev_lstm_input_joined_forget_sum])

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Mul(graph, {
            'name': 'join_forget_prev_state_mul',
        }).create_node([forget_sigmoid, prev_lstm_state])

        # Split -> Tanh                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Mul(
            graph, {
                'name': 'join_remember_candidates_mul',
            }).create_node([remember_tahn, remember_sigmoid])

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        join_forget_remember_sum = Add(graph, {
            'name': 'join_forget_remember_sum',
        }).create_node(
            [join_forget_prev_state_mul, join_remember_candidates_mul])

        # (7)Eltwise(sum) -> Clamp
        # The new cell state is limited to [-clip_value, clip_value].
        join_forget_clamp = Clamp(
            graph, {
                'name': 'join_forget_clamp',
                'max': node.clip_value,
                'min': -node.clip_value
            }).create_node([join_forget_remember_sum])
        #
        # Clamp -> (2)Memory(state)
        # Writing half of the state pair (same id as prev_lstm_state).
        next_lstm_state = Memory(
            graph, {
                'name':
                'next_lstm_state',
                'id':
                memory_pair_output,
                'index':
                0,
                'size':
                2,
                'shape':
                np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
            }).create_node([join_forget_clamp])
        Result(graph, {
            'name': 'next_lstm_state_out'
        }).create_node([next_lstm_state])

        # Clamp -> (2)Tanh
        state_filtered_tahn = Tanh(graph, {
            'name': 'state_filtered_tahn'
        }).create_node([join_forget_clamp])

        # Clamp -> (2)ScaleShift
        clamp_scaleshift_attrs = {
            'name': 'clamp_scaleshift',
            'bias_term': False
        }
        embed_input(clamp_scaleshift_attrs, 1, 'weights',
                    node.output_gate_weights)
        clamp_scaleshift = ScaleShiftOp(
            graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_next_lstm_input_joined_input_sum',
            }).create_node([(split_joined_input, 3), clamp_scaleshift])

        # (4)Eltwise(sum) -> (3)Sigmoid
        output_sigmoid = Sigmoid(graph, {
            'name': 'output_sigmoid'
        }).create_node([join_next_lstm_input_joined_input_sum])

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (5)Eltwise(mul)
        # Clamp -> (2)Tanh                      /
        joined_output_mul = Mul(graph, {
            'name': 'joined_output_mul'
        }).create_node([state_filtered_tahn, output_sigmoid])

        # (5)Eltwise(mul) -> (3)FullyConnected
        fc_output_attrs = {
            'name': 'FullyConnected',
            'num_output': node.projection_weights_shape[0],
            'bias_term': False
        }
        embed_input(fc_output_attrs, 1, 'weights', node.projection_weights)
        fc_output = InnerProduct(graph, fc_output_attrs).create_node(
            [joined_output_mul])

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        # Writing half of the output pair (same id as prev_lstm_output).
        next_lstm_output = Memory(
            graph, {
                'name': 'next_lstm_output',
                'id': memory_pair_input,
                'index': 0,
                'size': 2,
                'shape': np.array([node.gifo_r_weights_shape[1]],
                                  dtype=np.int64)
            }).create_node([fc_output])
        Result(graph, {
            'name': 'next_lstm_output_out'
        }).create_node([next_lstm_output])

        return [fc_output.id]
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with an explicit LSTM-cell sub-graph.

        The single input of ``node`` is split into five chunks
        ``(i_part, f_part, c_part, o_part, ct_1)`` and the following
        computation is assembled from Eltwise, ScaleShift, Sigmoid and Tanh
        nodes::

            i_t = Sigmoid(i_part + w_ic * ct_1)
            f_t = Sigmoid(f_part + w_fc * ct_1)
            c_t = f_t * ct_1 + i_t * Tanh(c_part)
            o_t = Sigmoid(o_part + w_oc * c_t)
            m_t = o_t * Tanh(c_t)

        The scale weights come from ``node.i_weights``, ``node.f_weights``
        and ``node.o_weights`` and are embedded into the ScaleShift nodes.

        :param graph: graph in which the new nodes are created
        :param node: node that is being replaced
        :return: single-element list with the id of the Concat node joining
                 ``c_t`` (input 0) and ``m_t`` (input 1)
        """
        def link(producer, out_idx, consumer, in_idx):
            # Attach output ``out_idx`` of ``producer`` to input ``in_idx``
            # of ``consumer``.
            producer.out_port(out_idx).connect(consumer.in_port(in_idx))

        # Chunk the input: outputs 0..4 are i_part, f_part, c_part, o_part, ct_1.
        splitter_attrs = {
            'name': graph.unique_id(prefix='Split_lstm_input_'),
            'num_split': 5
        }
        splitter = Split(graph, splitter_attrs).create_node()
        node.in_port(0).get_connection().set_destination(splitter.in_port(0))
        for port_idx in range(5):
            splitter.add_output_port(port_idx)

        # Input gate: i_t = Sigmoid(i_part + w_ic * ct_1)
        in_gate_scale_attrs = {
            'name': graph.unique_id(prefix='i_scaleshift'),
            'bias_term': False
        }
        embed_input(in_gate_scale_attrs, 1, 'weights', node.i_weights)
        in_gate_scale = ScaleShiftOp(graph, in_gate_scale_attrs).create_node()
        link(splitter, 4, in_gate_scale, 0)

        in_gate_sum_attrs = {
            'name': graph.unique_id(prefix='sum_i_c_'),
            'operation': 'sum'
        }
        in_gate_sum = Eltwise(graph, in_gate_sum_attrs).create_node()
        link(splitter, 0, in_gate_sum, 0)
        link(in_gate_scale, 0, in_gate_sum, 1)

        in_gate = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
        link(in_gate_sum, 0, in_gate, 0)

        # Forget gate: f_t = Sigmoid(f_part + w_fc * ct_1)
        forget_gate_scale_attrs = {
            'name': graph.unique_id(prefix='f_scaleshift'),
            'bias_term': False
        }
        embed_input(forget_gate_scale_attrs, 1, 'weights', node.f_weights)
        forget_gate_scale = ScaleShiftOp(graph, forget_gate_scale_attrs).create_node()
        link(splitter, 4, forget_gate_scale, 0)

        forget_gate_sum_attrs = {
            'name': graph.unique_id(prefix='sum_f_c_'),
            'operation': 'sum'
        }
        forget_gate_sum = Eltwise(graph, forget_gate_sum_attrs).create_node()
        link(splitter, 1, forget_gate_sum, 0)
        link(forget_gate_scale, 0, forget_gate_sum, 1)

        forget_gate = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
        link(forget_gate_sum, 0, forget_gate, 0)

        # New cell state: c_t = f_t * ct_1 + i_t * Tanh(c_part)
        candidate = Tanh(graph, {'name': 'c_tanh'}).create_node()
        link(splitter, 2, candidate, 0)

        gated_candidate_attrs = {
            'name': graph.unique_id(prefix='prod_i_c_tanh_'),
            'operation': 'mul'
        }
        gated_candidate = Eltwise(graph, gated_candidate_attrs).create_node()
        link(in_gate, 0, gated_candidate, 0)
        link(candidate, 0, gated_candidate, 1)

        kept_state_attrs = {
            'name': graph.unique_id(prefix='prod_f_ct_1_'),
            'operation': 'mul'
        }
        kept_state = Eltwise(graph, kept_state_attrs).create_node()
        link(forget_gate, 0, kept_state, 0)
        link(splitter, 4, kept_state, 1)

        cell_state_attrs = {
            'name': graph.unique_id(prefix='sum_f_i_'),
            'operation': 'sum'
        }
        cell_state = Eltwise(graph, cell_state_attrs).create_node()
        link(kept_state, 0, cell_state, 0)
        link(gated_candidate, 0, cell_state, 1)

        # Output gate: o_t = Sigmoid(o_part + w_oc * c_t)
        out_gate_scale_attrs = {
            'name': graph.unique_id(prefix='o_scaleshift'),
            'bias_term': False
        }
        embed_input(out_gate_scale_attrs, 1, 'weights', node.o_weights)
        out_gate_scale = ScaleShiftOp(graph, out_gate_scale_attrs).create_node()
        link(cell_state, 0, out_gate_scale, 0)

        out_gate_sum_attrs = {
            'name': graph.unique_id(prefix='sum_o_c_'),
            'operation': 'sum'
        }
        out_gate_sum = Eltwise(graph, out_gate_sum_attrs).create_node()
        link(splitter, 3, out_gate_sum, 0)
        link(out_gate_scale, 0, out_gate_sum, 1)

        out_gate = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
        link(out_gate_sum, 0, out_gate, 0)

        # Cell output: m_t = o_t * Tanh(c_t)
        squashed_state = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
        link(cell_state, 0, squashed_state, 0)

        cell_output_attrs = {
            'name': graph.unique_id(prefix='prod_o_c_t_tanh_'),
            'operation': 'mul'
        }
        cell_output = Eltwise(graph, cell_output_attrs).create_node()
        link(out_gate, 0, cell_output, 0)
        link(squashed_state, 0, cell_output, 1)

        # The replaced node has a single output, so c_t and m_t are concatenated.
        joined_attrs = {'name': graph.unique_id(prefix='Concat_c_m')}
        joined = Concat(graph, joined_attrs).create_node()
        joined.add_sequence_of_ports('in', range(2))
        link(cell_state, 0, joined, 0)
        link(cell_output, 0, joined, 1)

        return [joined.id]
Example #4
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with an explicit LSTM (with projection) sub-graph.

        The sub-graph is built from FullyConnected, ReadValue/Assign, Split,
        ScaleShift, Add/Mul, Sigmoid/Tanh and Clamp nodes. Weights and biases
        are taken from attributes of ``node`` (``gifo_x_weights``,
        ``gifo_biases``, ``gifo_r_weights``, ``input_gate_weights``,
        ``forget_gate_weights``, ``output_gate_weights``,
        ``projection_weights``) and attached as constant inputs.

        Two ReadValue/Assign pairs are created, each pair sharing a freshly
        generated ``variable_id``: one carries the projected output between
        invocations, the other carries the clamped cell state.

        NOTE: nodes are created in a fixed order (see the "order is
        mandatory" remark at the Split below); do not reorder statements.

        :param graph: graph in which the new nodes are created
        :param node: node that is being replaced
        :return: single-element list with the id of the output (projection)
                 FullyConnected node
        """
        input_out_port = node.in_port(0).get_source()

        # Variable ids that tie each ReadValue to its matching Assign.
        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        fc_layer_after_input_attrs = {
            'name': 'input_fullyconnected',
            'out-size': node.gifo_x_weights_shape[0],
            'transpose_weights': True,
            'bias_term': True,
        }

        fc_layer_after_input = FullyConnected(
            graph, fc_layer_after_input_attrs).create_node()
        fc_layer_after_input.in_port(0).connect(input_out_port)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1,
                       'weights', node.gifo_x_weights)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2,
                       'biases', node.gifo_biases)

        # Initial value for the stored output.
        # NOTE(review): presumably a zero tensor whose batch size follows
        # input_out_port -- confirm against create_zero_value_with_batch_from_input.
        init_value_prev_lstm_output = create_zero_value_with_batch_from_input(
            input_out_port, node.gifo_r_weights_shape[1])
        prev_lstm_output = ReadValue(graph, {
            'name': 'prev_memory_output',
            'variable_id': memory_pair_input
        }).create_node()
        prev_lstm_output.in_port(0).connect(
            init_value_prev_lstm_output.out_port(0))

        # *Memory(output) -> FullyConnected
        fc_layer_from_prev_state_attrs = {
            'name': 'prev_memory_output_fullyconnected',
            'out-size': node.gifo_r_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False,
        }

        fc_layer_from_prev_state = FullyConnected(
            graph, fc_layer_from_prev_state_attrs).create_node()
        fc_layer_from_prev_state.in_port(0).connect(
            prev_lstm_output.out_port(0))
        input_as_const(fc_layer_from_prev_state,
                       fc_layer_from_prev_state_attrs, 1, 'weights',
                       node.gifo_r_weights)

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Add(graph, {
            'name': 'join_input_eltwise'
        }).create_node()
        join_input_prev_state_sum.in_port(0).connect(
            fc_layer_from_prev_state.out_port(0))
        join_input_prev_state_sum.in_port(1).connect(
            fc_layer_after_input.out_port(0))

        # *Eltwise(sum) -> Split
        # it is split into 4 nodes: Act, Eltw*3
        # the following order is mandatory
        #       ___Tanh
        #      /
        # Split ---(2)Eltwise(sum)
        #     |\
        #     | \__(3)Eltwise(sum)
        #     |____(4)Eltwise(sum)
        # The split axis (1) is supplied through a constant on input port 1.
        split_joined_input_axis = Const(graph, {
            'value': np.int64(1)
        }).create_node()
        split_joined_input = Split(graph, {
            'name': 'join_input_split',
            'num_splits': 4,
            'out_ports_count': 4
        }).create_node()
        split_joined_input.in_port(0).connect(
            join_input_prev_state_sum.out_port(0))
        split_joined_input.in_port(1).connect(
            split_joined_input_axis.out_port(0))

        # Earlier Memory-based form of the node below, kept for reference:
        # prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
        #                                 'id': memory_pair_output,
        #                                 'index': 1,
        #                                 'size': 2,
        #                                 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
        #                                 }).create_node()
        # Initial value for the stored cell state, derived from the first
        # Split output.
        init_value_prev_lstm_state = create_zero_value_with_batch_from_input(
            split_joined_input.out_port(0), node.input_gate_weights.shape[0])
        prev_lstm_state = ReadValue(graph, {
            'name': 'prev_memory_state',
            'variable_id': memory_pair_output
        }).create_node()
        prev_lstm_state.in_port(0).connect(
            init_value_prev_lstm_state.out_port(0))

        # *Memory(state) -> *ScaleShift(input)
        state_input_scaleshift_attrs = {
            'name': 'input_scaleshift',
            'bias_term': False
        }
        state_input_scaleshift = ScaleShiftOp(
            graph, state_input_scaleshift_attrs).create_node()
        state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1,
                       'weights', node.input_gate_weights)

        # *Memory(state) -> *ScaleShift(forget)
        state_forget_scaleshift_attrs = {
            'name': 'forget_scaleshift',
            'bias_term': False
        }
        state_forget_scaleshift = ScaleShiftOp(
            graph, state_forget_scaleshift_attrs).create_node()
        state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs,
                       1, 'weights', node.forget_gate_weights)

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        join_prev_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_input_eltwise'
            }).create_node()
        join_prev_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(1))
        join_prev_lstm_input_joined_input_sum.in_port(1).connect(
            state_input_scaleshift.out_port(0))
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_forget_sum',
            }).create_node()
        join_prev_lstm_input_joined_forget_sum.in_port(0).connect(
            split_joined_input.out_port(2))
        join_prev_lstm_input_joined_forget_sum.in_port(1).connect(
            state_forget_scaleshift.out_port(0))

        # Split -> Tanh
        remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
        remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

        # Split -> (2)Eltwise(sum) -> *Sigmoid
        remember_sigmoid = Sigmoid(graph, {
            'name': 'remember_sigmoid'
        }).create_node()
        remember_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_input_sum.out_port(0))

        # Split -> (3)Eltwise(sum) -> **Sigmoid
        forget_sigmoid = Sigmoid(graph, {
            'name': 'forget_sigmoid'
        }).create_node()
        forget_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_forget_sum.out_port(0))

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Mul(graph, {
            'name': 'join_forget_prev_state_mul'
        }).create_node()
        join_forget_prev_state_mul.in_port(0).connect(
            forget_sigmoid.out_port(0))
        join_forget_prev_state_mul.in_port(1).connect(
            prev_lstm_state.out_port(0))

        # Split -> Tanh                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Mul(
            graph, {
                'name': 'join_remember_candidates_mul'
            }).create_node()
        join_remember_candidates_mul.in_port(0).connect(
            remember_tahn.out_port(0))
        join_remember_candidates_mul.in_port(1).connect(
            remember_sigmoid.out_port(0))

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        join_forget_remember_sum = Add(graph, {
            'name': 'join_forget_remember_sum'
        }).create_node()
        join_forget_remember_sum.in_port(0).connect(
            join_forget_prev_state_mul.out_port(0))
        join_forget_remember_sum.in_port(1).connect(
            join_remember_candidates_mul.out_port(0))

        # (7)Eltwise(sum) -> Clamp
        # Constant inputs 1 and 2 carry -clip_value and +clip_value.
        join_forget_clamp = create_op_with_const_inputs(
            graph, Clamp, {
                1: np.array(-node.clip_value, dtype=np.float32),
                2: np.array(node.clip_value, dtype=np.float32)
            }, {'name': 'join_forget_clamp'}, join_forget_remember_sum)
        #
        # Clamp -> (2)Memory(state)
        # Assign shares its variable_id with prev_lstm_state.
        next_lstm_state = Assign(graph, {
            'name': 'next_lstm_state',
            'variable_id': memory_pair_output
        }).create_node()
        next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

        res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
        res_node.in_port(0).connect(next_lstm_state.out_port(0))

        # Clamp -> (2)Tanh
        state_filtered_tahn = Tanh(graph, {
            'name': 'state_filtered_tahn'
        }).create_node()
        state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

        # Clamp -> (2)ScaleShift
        clamp_scaleshift_attrs = {
            'name': 'clamp_scaleshift',
            'bias_term': False
        }
        clamp_scaleshift = ScaleShiftOp(graph,
                                        clamp_scaleshift_attrs).create_node()
        clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
        input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights',
                       node.output_gate_weights)

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_next_lstm_input_joined_input_sum',
            }).create_node()
        join_next_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(3))
        join_next_lstm_input_joined_input_sum.in_port(1).connect(
            clamp_scaleshift.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid
        output_sigmoid = Sigmoid(graph, {
            'name': 'output_sigmoid'
        }).create_node()
        output_sigmoid.in_port(0).connect(
            join_next_lstm_input_joined_input_sum.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (5)Eltwise(mul)
        # Clamp -> (2)Tanh                      /
        joined_output_mul = Mul(graph, {
            'name': 'joined_output_mul'
        }).create_node()
        joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
        joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

        # (5)Eltwise(mul) -> (3)FullyConnected
        fc_output_attrs = {
            'name': 'FullyConnected',
            'out-size': node.projection_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False
        }
        fc_output = FullyConnected(graph, fc_output_attrs).create_node()
        fc_output.in_port(0).connect(joined_output_mul.out_port(0))
        input_as_const(fc_output, fc_output_attrs, 1, 'weights',
                       node.projection_weights)

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        # Assign shares its variable_id with prev_lstm_output.
        next_lstm_output = Assign(graph, {
            'name': 'next_lstm_output',
            'variable_id': memory_pair_input
        }).create_node()
        next_lstm_output.in_port(0).connect(fc_output.out_port(0))

        res_node_lstm_output = Result(graph, {
            'name': 'next_lstm_output_out'
        }).create_node()
        res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

        return [fc_output.id]
Example #5
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with an explicit LSTM-cell sub-graph.

        The input of ``node`` is split along axis 1 into five chunks
        ``(i_part, f_part, c_part, o_part, ct_1)`` and the following
        computation is assembled from them::

            i_t = Sigmoid(i_part + w_ic * ct_1)
            f_t = Sigmoid(f_part + w_fc * ct_1)
            c_t = f_t * ct_1 + i_t * Tanh(c_part)
            o_t = Sigmoid(o_part + w_oc * c_t)
            m_t = o_t * Tanh(c_t)

        When the ``use_dropout`` attribute of ``node`` is set, the input is
        first split into ``[-1, 1, 1, 1]`` sized parts along axis 1: the
        first part continues as the LSTM input, and the remaining three are
        multiplied onto the outputs of the i, f and o sigmoids respectively.

        The scale weights come from ``node.i_weights``, ``node.f_weights``
        and ``node.o_weights``. All created nodes are named after ``node``.

        :param graph: graph in which the new nodes are created
        :param node: node that is being replaced
        :return: single-element list with the id of the Concat node joining
                 ``c_t`` (input 0) and ``m_t`` (input 1)
        """
        node_name = node.soft_get('name', node.id)
        input_port = node.in_port(0)

        # Read the flag once with the tolerant accessor and reuse it below.
        # Mixing ``has_and_set`` here with ``node['use_dropout']`` later
        # raised as soon as the attribute was absent.
        use_dropout = node.has_and_set('use_dropout')

        if use_dropout:
            split_dropout = AttributedVariadicSplit(
                graph, {
                    'name': node_name + '/split_dropout',
                    'size_splits': int64_array([-1, 1, 1, 1]),
                    'axis': int64_array(1)
                }).create_node()
            input_port.get_connection().set_destination(
                split_dropout.in_port(0))
            # From here on the LSTM input is the first output of the dropout
            # split; the other three outputs scale the gates.
            input_port = split_dropout.out_port(0)
            i_drop_scale = split_dropout.out_port(1)
            f_drop_scale = split_dropout.out_port(2)
            o_drop_scale = split_dropout.out_port(3)

        # split input to (i_part, f_part, c_part, o_part, ct_1)
        split_node = create_op_with_const_inputs(
            graph, Split, {1: np.int64(1)}, {
                'name': node_name + '/split_lstm_input',
                'num_splits': 5
            })
        input_port.get_connection().set_destination(split_node.in_port(0))

        i_part = split_node.out_port(0)
        f_part = split_node.out_port(1)
        c_part = split_node.out_port(2)
        o_part = split_node.out_port(3)
        ct_1 = split_node.out_port(4)

        # i_t = Sigmoid(i_part + w_ic*ct_1)
        i_scale_attrs = {
            'name': node_name + '/i_scaleshift',
            'bias_term': False
        }
        i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
        input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
        ct_1.connect(i_scale.in_port(0))

        sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
        i_part.connect(sum_i_c.in_port(0))
        i_scale.out_port(0).connect(sum_i_c.in_port(1))

        i_sigmoid = Sigmoid(graph, {
            'name': node_name + '/i_sigmoid'
        }).create_node()
        sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

        if use_dropout:
            mul_dropout_i = Mul(graph, {
                'name':
                split_node.soft_get('name', split_node.id) + '/mul_i'
            }).create_node()
            mul_dropout_i.in_port(0).connect(i_sigmoid.out_port(0))
            mul_dropout_i.in_port(1).connect(i_drop_scale)
            # Downstream consumers must see the scaled gate.
            i_sigmoid = mul_dropout_i

        # f_t = Sigmoid(f_part + w_fc*ct_1)
        f_scale_attrs = {
            'name': node_name + '/f_scaleshift',
            'bias_term': False
        }
        f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
        input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
        ct_1.connect(f_scale.in_port(0))

        sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
        f_part.connect(sum_f_c.in_port(0))
        f_scale.out_port(0).connect(sum_f_c.in_port(1))

        f_sigmoid = Sigmoid(graph, {
            'name': node_name + '/f_sigmoid'
        }).create_node()
        sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

        if use_dropout:
            mul_dropout_f = Mul(graph, {
                'name':
                split_node.soft_get('name', split_node.id) + '/mul_f'
            }).create_node()
            mul_dropout_f.in_port(0).connect(f_sigmoid.out_port(0))
            mul_dropout_f.in_port(1).connect(f_drop_scale)
            f_sigmoid = mul_dropout_f

        # c_t = f_t*ct_1 + i_t * tanh(c_part)
        c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
        c_part.connect(c_tanh.in_port(0))

        prod_i_c_tanh = Mul(graph, {
            'name': node_name + '/prod_i_c_tanh_'
        }).create_node()
        i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
        c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

        prod_f_ct_1 = Mul(graph, {
            'name': node_name + '/prod_f_ct_1_'
        }).create_node()
        f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
        ct_1.connect(prod_f_ct_1.in_port(1))

        sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
        prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
        prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

        #  o_t = Sigmoid(o_part + w_oc*c_t)
        o_scale_attrs = {
            'name': node_name + '/o_scaleshift',
            'bias_term': False
        }
        o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
        input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
        sum_f_i.out_port(0).connect(o_scale.in_port(0))

        sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
        o_part.connect(sum_o_c.in_port(0))
        o_scale.out_port(0).connect(sum_o_c.in_port(1))

        o_sigmoid = Sigmoid(graph, {
            'name': node_name + '/o_sigmoid'
        }).create_node()
        sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

        if use_dropout:
            mul_dropout_o = Mul(graph, {
                'name':
                split_node.soft_get('name', split_node.id) + '/mul_o'
            }).create_node()
            mul_dropout_o.in_port(0).connect(o_sigmoid.out_port(0))
            mul_dropout_o.in_port(1).connect(o_drop_scale)
            o_sigmoid = mul_dropout_o

        # m_t = o_t * Tanh(c_t)
        c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
        sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

        prod_o_c_t_tanh = Mul(graph, {
            'name': node_name + '/prod_o_c_t_tanh_'
        }).create_node()
        o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
        c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

        # add concat to create 1 output
        concat = Concat(graph, {
            'name': node_name + '/concat_c_m'
        }).create_node()
        concat.add_sequence_of_ports('in', range(2))
        sum_f_i.out_port(0).connect(concat.in_port(0))
        prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

        return [concat.id]