def __init__(self, size, tracker_size, lateral_tracking=True, tracking_ln=True):
    """Create the tracker's projection and (optional) normalization layers.

    Args:
        size: input size (parser hidden state) = FLAGS.model_dim
        tracker_size: FLAGS.tracking_lstm_hidden_dim
        lateral_tracking: when true, build the linear projections
            (`buf`, `stack1`, `stack2`, `lateral`) and use `tracker_size`
            as the state size; otherwise the state size is `size * 3`.
        tracking_ln: when true, build one LayerNormalization(size) each
            for the buffer and the two stack inputs.
    (see FLAGS for the rest)
    """
    super(Tracker, self).__init__()

    # Width of every projection output; presumably the four LSTM gates
    # stacked together -- TODO confirm against forward().
    gate_width = 4 * tracker_size

    # Initialize layers. Creation order is kept deliberately stable
    # (buf, stack1, stack2, lateral): submodule registration order and
    # random initialization both follow it.
    if not lateral_tracking:
        self.state_size = size * 3
    else:
        self.buf = Linear()(size, gate_width, bias=True)
        for proj_name in ("stack1", "stack2"):
            setattr(self, proj_name, Linear()(size, gate_width, bias=False))
        self.lateral = Linear(initializer=HeKaimingInitializer)(
            tracker_size, gate_width, bias=False)
        self.state_size = tracker_size

    if tracking_ln:
        for norm_name in ("buf_ln", "stack1_ln", "stack2_ln"):
            setattr(self, norm_name, LayerNormalization(size))

    self.lateral_tracking = lateral_tracking
    self.tracking_ln = tracking_ln
    self.reset_state()
def __init__(self, hidden_dim, composition_ln=False):
    """Create the binary Tree-LSTM composition layer.

    Args:
        hidden_dim: size of each child's hidden/cell vector. The composition
            projection maps 2 * hidden_dim inputs to 5 * hidden_dim outputs.
        composition_ln: when true, also build four
            LayerNormalization(hidden_dim) modules (left_h, right_h,
            left_c, right_c).
    """
    super(BinaryTreeLSTMLayer, self).__init__()
    self.hidden_dim = hidden_dim

    # 2 * hidden_dim in -> 5 * hidden_dim out; presumably the concatenated
    # left/right hidden states mapped to five gate pre-activations
    # -- TODO confirm against forward().
    he_linear = Linear(initializer=HeKaimingInitializer)
    self.comp_linear = he_linear(
        in_features=2 * hidden_dim, out_features=5 * hidden_dim)

    self.composition_ln = composition_ln
    if not composition_ln:
        return

    # Same creation order as before: left_h, right_h, left_c, right_c.
    for norm_name in ("left_h_ln", "right_h_ln", "left_c_ln", "right_c_ln"):
        setattr(self, norm_name, LayerNormalization(hidden_dim))