Exemplo n.º 1
0
    def __init__(self,
                 size,
                 tracker_size,
                 lateral_tracking=True,
                 tracking_ln=True):
        '''Create the tracker's input projections and optional layer norms.

        Args:
            size: width of each parser hidden-state input
                (FLAGS.model_dim); used as the input width of the buf,
                stack1 and stack2 projections and of their layer norms.
            tracker_size: tracker hidden width
                (FLAGS.tracking_lstm_hidden_dim).
            lateral_tracking: if True, build the buf/stack1/stack2/lateral
                projections (each producing 4 * tracker_size features) and
                set state_size to tracker_size; if False, build no
                projections and set state_size to size * 3.
            tracking_ln: if True, build one LayerNormalization of width
                `size` for each of the buf, stack1 and stack2 inputs.
        '''
        super(Tracker, self).__init__()

        # Plain configuration flags; they only need to exist before
        # reset_state() runs at the end.
        self.lateral_tracking = lateral_tracking
        self.tracking_ln = tracking_ln

        if not lateral_tracking:
            self.state_size = size * 3
        else:
            # 4 * tracker_size: presumably the four LSTM gate
            # pre-activations -- TODO confirm against forward().
            gate_dim = 4 * tracker_size
            # Keep this creation order (buf, stack1, stack2, lateral): it
            # likely fixes submodule registration order and the order in
            # which initializers draw random numbers.
            self.buf = Linear()(size, gate_dim, bias=True)
            self.stack1 = Linear()(size, gate_dim, bias=False)
            self.stack2 = Linear()(size, gate_dim, bias=False)
            self.lateral = Linear(initializer=HeKaimingInitializer)(
                tracker_size, gate_dim, bias=False)
            self.state_size = tracker_size

        if tracking_ln:
            # Built in order buf_ln, stack1_ln, stack2_ln.
            self.buf_ln, self.stack1_ln, self.stack2_ln = [
                LayerNormalization(size) for _ in range(3)]

        self.reset_state()
Exemplo n.º 2
0
 def __init__(self, hidden_dim, composition_ln=False):
     '''Build the composition projection of a binary Tree-LSTM layer.

     Args:
         hidden_dim: hidden width of the layer; stored as self.hidden_dim.
         composition_ln: if True, also build four LayerNormalization
             modules of width hidden_dim (left_h, right_h, left_c,
             right_c).
     '''
     super(BinaryTreeLSTMLayer, self).__init__()
     self.hidden_dim = hidden_dim

     # Projects 2 * hidden_dim inputs (presumably the left and right child
     # hidden states concatenated -- TODO confirm in forward) to
     # 5 * hidden_dim outputs (presumably the gate pre-activations).
     make_linear = Linear(initializer=HeKaimingInitializer)
     self.comp_linear = make_linear(
         in_features=hidden_dim * 2, out_features=hidden_dim * 5)

     self.composition_ln = composition_ln
     if composition_ln:
         # Built in the order left_h, right_h, left_c, right_c.
         (self.left_h_ln, self.right_h_ln,
          self.left_c_ln, self.right_c_ln) = [
             LayerNormalization(hidden_dim) for _ in range(4)]