import sys

import torch
import torch.nn as nn
from torch.autograd import Variable

# Encoder, SoftAttention, NeuralTensorNetwork and CRF are assumed to be
# defined (or imported) elsewhere in this project.


class BiLSTMCRFSplitImpExp(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=0, sentence_embedding_type='last', sentence_zero_inithidden=False, crf_decode_method='viterbi', loss_function='likelihood', cross_attention=False, attention_function='dot', NTN_flag=False, batch_size=1, num_layers=1, dropout=0, bidirectional=True, batch_first=True):
        super(BiLSTMCRFSplitImpExp, self).__init__()

        if hidden_size <= 0:
            hidden_size = input_size
        if hidden_size % 2 != 0 and bidirectional:
            hidden_size = hidden_size - 1
        self.batch_size = batch_size
        self.cross_attention = cross_attention
        self.NTN_flag = NTN_flag

        self.crf_decode_method = crf_decode_method
        self.loss_function = loss_function

        self.encoder = Encoder(input_size, hidden_size, batch_size=batch_size, num_layers=num_layers, dropout=dropout, bidirectional=bidirectional, batch_first=batch_first, sentence_embedding_type=sentence_embedding_type, sentence_zero_inithidden=sentence_zero_inithidden)

        if self.cross_attention:
            self.soft_attention = SoftAttention(hidden_size, attention_function=attention_function, nonlinear=False, temporal=False)

        if self.NTN_flag:
            # With the Neural Tensor Network enabled, the implicit-relation head
            # takes a wider feature vector (hidden_size * 6) than the explicit
            # head (hidden_size * 2).
            self.NTN = NeuralTensorNetwork(hidden_size, r=output_size)
            self.explicit_out = nn.Linear(hidden_size * 2, output_size)
            self.implicit_out = nn.Linear(hidden_size * 6, output_size)
        else:
            self.explicit_out = nn.Linear(hidden_size * 2, output_size)
            self.implicit_out = nn.Linear(hidden_size * 2, output_size)

        self.Dropout = nn.Dropout(dropout)
        self.CRF = CRF(output_size)


class BiLSTMCRF(nn.Module):
    def __init__(self,
                 input_size,
                 output_size,
                 hidden_size=0,
                 sentence_embedding_type='last',
                 sentence_zero_inithidden=False,
                 attention=None,
                 crf_decode_method='viterbi',
                 loss_function='likelihood',
                 batch_size=1,
                 num_layers=1,
                 dropout=0,
                 bidirectional=True,
                 batch_first=True):
        super(BiLSTMCRF, self).__init__()

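        # Default the hidden size to the input size, and keep it even
        # (presumably so it can be split evenly across the two directions
        # of the bidirectional encoder).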
        if hidden_size <= 0:
            hidden_size = input_size
        if hidden_size % 2 != 0 and bidirectional:
            hidden_size = hidden_size - 1
        self.batch_size = batch_size
        self.attention = attention
        self.crf_decode_method = crf_decode_method
        self.loss_function = loss_function

        self.encoder = Encoder(
            input_size,
            hidden_size,
            batch_size=batch_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=batch_first,
            sentence_embedding_type=sentence_embedding_type,
            sentence_zero_inithidden=sentence_zero_inithidden)
        if self.attention:
            self.soft_attention = SoftAttention(
                hidden_size,
                attention_function=self.attention,
                nonlinear=False,
                temporal=False)
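        # Linear layer mapping each sentence vector to per-class emission
        # scores for the CRF.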
        self.out = nn.Linear(hidden_size, output_size)

        self.Dropout = nn.Dropout(dropout)
        self.CRF = CRF(output_size)
        #self.softmax = nn.Softmax()

    def _get_lstm_features(self,
                           input,
                           eos_position_list,
                           crf_target,
                           connective_position_list=None):
        # Assumption: the encoder's first return value is the word-level
        # (token-level) output, which the attention branch below relies on.
        word_level_output, sentence_level_output, _, _ = self.encoder(
            input, eos_position_list, connective_position_list)

        if self.batch_size == 1:
            # process the samples one at a time
            output_list = []
            prev_eos = 0
            target_seq = []

            for i in range(len(eos_position_list)):
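                # Use sentence i's vector either directly or as a query that
                # attends over its own word-level states when attention is on.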
                if self.attention:
                    output, _ = self.soft_attention(
                        sentence_level_output[:, i, :],
                        word_level_output[:, prev_eos:eos_position_list[i], :])
                else:
                    output = sentence_level_output[:, i, :].view(1, -1)
                prev_eos = eos_position_list[i]

                if crf_target[i] >= 0:
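                    # Only sentences with a valid CRF label (>= 0) contribute
                    # an emission row and a target entry; -1 marks rows to skip.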
                    output = self.Dropout(output)
                    output = self.out(output)
                    output_list.append(output)
                    target_seq.append(crf_target[i])
        else:
            # TODO: support processing a batch of samples
            print("Batch sizes larger than 1 are not supported yet!")
            sys.exit()

        return torch.cat(output_list), torch.LongTensor(target_seq)

    def get_loss(self,
                 input,
                 eos_position_list,
                 target,
                 connective_position_list=None):
        # Get the emission scores from the BiLSTM
        feats, target_seq = self._get_lstm_features(
            input, eos_position_list, self.prepare_sequence(target),
            connective_position_list)

        if self.loss_function == 'likelihood':
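            # Only the negative log-likelihood loss is implemented here; other
            # values of loss_function fall through and return None.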
            return self.CRF._get_neg_log_likilihood_loss(feats, target_seq)

    def forward(self,
                input,
                eos_position_list,
                target,
                connective_position_list=None):
        # Get the emission scores from the BiLSTM
        feats, _ = self._get_lstm_features(input, eos_position_list,
                                           self.prepare_sequence(target),
                                           connective_position_list)

        # Find the best path, given the features.
        if self.crf_decode_method == 'marginal':
            score, tag_seq = self.CRF._marginal_decode(feats)
        elif self.crf_decode_method == 'viterbi':
            score, tag_seq = self.CRF._viterbi_decode(feats)

        target = target.abs()
        predict = torch.zeros(target.size())
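        # Scatter the decoded tag sequence back onto the rows of the target
        # that carried a label; rows with no label stay all zero.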
        j = 0
        for i in range(predict.size(0)):
            if torch.max(target[i, :]).data[0] > 0:
                predict[i, tag_seq[j]] = 1
                j += 1

        return Variable(predict)

    def prepare_sequence(self, target):
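        # Map each row of the (possibly signed) target matrix to its argmax
        # class index, or to -1 when the row is all zero and should be skipped.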
        target = target.abs()
        max_value, indexs = torch.max(target, 1)
        tensor = []

        for i in range(indexs.size(0)):
            if max_value[i].data[0] > 0:
                tensor.append(indexs[i].data[0])
            else:
                tensor.append(-1)
        return tensor
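

if __name__ == '__main__':
    # Minimal sketch (not part of the original code) exercising only
    # prepare_sequence, which needs nothing beyond torch itself. The toy
    # labels below are made up, and the call assumes Python 3 plus the
    # legacy (pre-0.4, Variable-based) PyTorch API that this file targets.
    toy_target = Variable(torch.Tensor([[0, 1, 0],
                                        [0, 0, 0],
                                        [-1, 0, 0]]))
    # Rows with a non-zero entry map to their argmax class; the all-zero row
    # maps to -1, so this prints [1, -1, 0].
    print(BiLSTMCRF.prepare_sequence(None, toy_target))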