Example #1
    def action(self, sent, index, output, train=True):
        if train:
            action = sent.gold[index]
        else:
            max_score, max_index = torch.max(output, dim=1)
            # print(max_index)
            # print(max_score)
            action = self.id2gold[utils.to_scalar(max_index)]

        sent.action.append(action)
        pos_id = action.find('#')
        if pos_id == -1:
            last_word_record = sent.words_record[-1] + sent.chars[index]
            sent.words_record[-1] = last_word_record
        else:
            sent.words_record.append(sent.chars[index])
            pos_record = action[(pos_id + 1):]
            sent.pos_record.append(pos_record)
            sent.pos_index.append(self.pos2id[pos_record])
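
A minimal standalone sketch of the action-string convention used by action() above (and by action_batch() in Example #6): an action containing '#' starts a new word and carries its POS tag after the separator (e.g. 'SEP#NN'), while an action without '#' appends the current character to the last word. The action name 'APP' and the starting state below are illustrative, not taken from the original code.

def apply_action(words, pos_tags, char, action):
    sep = action.find('#')
    if sep == -1:                        # append: extend the last word
        words[-1] = words[-1] + char
    else:                                # separate: new word plus its POS tag
        words.append(char)
        pos_tags.append(action[sep + 1:])

words, pos_tags = ['中'], ['NR']
apply_action(words, pos_tags, '国', 'APP')       # hypothetical append action (no '#')
apply_action(words, pos_tags, '人', 'SEP#NN')
print(words, pos_tags)                           # ['中国', '人'] ['NR', 'NN']
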
    def forward(self, fea_v, length, target_start, target_end):
        if self.add_char:
            word_v = fea_v[0]
            char_v = fea_v[1]
        else:
            word_v = fea_v
        batch_size = word_v.size(0)
        seq_length = word_v.size(1)

        word_emb = self.embedding(word_v)
        word_emb = self.dropout_emb(word_emb)
        if self.static:
            word_static = self.embedding_static(word_v)
            word_static = self.dropout_emb(word_static)
            word_emb = torch.cat([word_emb, word_static], 2)

        x = torch.transpose(word_emb, 0, 1)
        packed_words = pack_padded_sequence(x, length)
        lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        ##### lstm_out: (seq_len, batch_size, hidden_size)
        lstm_out = self.dropout_lstm(lstm_out)
        x = lstm_out

        ##### batch version
        # x = torch.squeeze(lstm_out, 1)
        # x: variable (seq_len, batch_size, hidden_size)
        # target_start: variable (batch_size)
        _, start = torch.max(target_start.unsqueeze(0), dim=1)
        max_start = utils.to_scalar(target_start[start])
        _, end = torch.min(target_end.unsqueeze(0), dim=1)
        min_end = utils.to_scalar(target_end[end])

        x = x.transpose(0, 1)
        left_save = []
        mask_left_save = []
        right_save = []
        mask_right_save = []
        target_save = []
        for idx in range(batch_size):
            x_len_cur = x[idx].size(0)
            start_cur = utils.to_scalar(target_start[idx])
            left_len_cur = start_cur
            left_len_max = max_start
            if start_cur != 0:
                x_cur_left = x[idx][:start_cur]
                left_len_sub = left_len_max - left_len_cur
                mask_cur_left = [1 for _ in range(left_len_cur)]
            else:
                x_cur_left = x[idx][0].unsqueeze(0)
                left_len_sub = left_len_max - 1
                # mask_cur_left = [-1e+20]
                mask_cur_left = [0]
            # x_cur_left: variable (start_cur, two_hidden_size)
            # mask_cur_left = [1 for _ in range(start_cur)]
            # mask_cur_left: list (start_cur)
            if start_cur < max_start:
                add = Variable(torch.zeros(left_len_sub, self.lstm_hiddens))
                if self.use_cuda: add = add.cuda()
                x_cur_left = torch.cat([x_cur_left, add], 0)
                # x_cur_left: variable (max_start, two_hidden_size)
                left_save.append(x_cur_left.unsqueeze(0))
                # mask_cur_left.extend([-1e+20 for _ in range(left_len_sub)])
                mask_cur_left.extend([0 for _ in range(left_len_sub)])
                # mask_cur_left: list (max_start)
                mask_left_save.append(mask_cur_left)
            else:
                left_save.append(x_cur_left.unsqueeze(0))
                mask_left_save.append(mask_cur_left)

            end_cur = utils.to_scalar(target_end[idx])
            right_len_cur = x_len_cur - end_cur - 1
            right_len_max = x_len_cur - min_end - 1
            if (end_cur + 1) != x_len_cur:
                x_cur_right = x[idx][(end_cur + 1):]
                right_len_sub = right_len_max - right_len_cur
                mask_cur_right = [1 for _ in range(right_len_cur)]
            else:
                x_cur_right = x[idx][end_cur].unsqueeze(0)
                right_len_sub = right_len_max - right_len_cur - 1
                # mask_cur_right = [-1e+20]
                mask_cur_right = [0]
            # x_cur_right: variable ((x_len_cur-end_cur-1), two_hidden_size)
            # mask_cur_right = [1 for _ in range(right_len_cur)]
            # mask_cur_right: list (x_len_cur-end_cur-1==right_len)
            if end_cur > min_end:
                add = Variable(torch.zeros(right_len_sub, self.lstm_hiddens))
                if self.use_cuda: add = add.cuda()
                x_cur_right = torch.cat([x_cur_right, add], 0)
                right_save.append(x_cur_right.unsqueeze(0))
                # mask_cur_right.extend([-1e+20 for _ in range(right_len_sub)])
                mask_cur_right.extend([0 for _ in range(right_len_sub)])
                mask_right_save.append(mask_cur_right)
            else:
                right_save.append(x_cur_right.unsqueeze(0))
                mask_right_save.append(mask_cur_right)

            # target_sub = end_cur-start_cur
            x_target = x[idx][start_cur:(end_cur + 1)]
            x_average_target = torch.mean(x_target, 0)
            target_save.append(x_average_target.unsqueeze(0))
        mask_left_save = Variable(torch.ByteTensor(mask_left_save))
        # mask_left_save: variable (batch_size, left_len_max)
        mask_right_save = Variable(torch.ByteTensor(mask_right_save))
        # mask_right_save: variable (batch_size, right_len_max)
        left_save = torch.cat(left_save, 0)
        right_save = torch.cat(right_save, 0)
        target_save = torch.cat(target_save, 0)
        # left_save: variable (batch_size, left_len_max, two_hidden_size)
        # right_save: variable (batch_size, right_len_max, two_hidden_size)
        # target_save: variable (batch_size, two_hidden_size)
        if self.use_cuda:
            mask_right_save = mask_right_save.cuda()
            mask_left_save = mask_left_save.cuda()
            left_save = left_save.cuda()
            right_save = right_save.cuda()
            target_save = target_save.cuda()

        s = self.attention(x, target_save, None)
        s_l = self.attention_l(left_save, target_save, mask_left_save)
        s_r = self.attention_r(right_save, target_save, mask_right_save)

        result = self.linear(s)  # result: variable (batch_size, label_num)
        # result = self.linear_l(s_l)
        result = torch.add(result, self.linear_l(s_l))
        result = torch.add(result, self.linear_r(s_r))
        # result: variable (batch_size, label_num)
        # print(result)
        return result
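
The loop above pads the variable-length left and right contexts around the target span to a common length and builds 0/1 masks for the attention modules. A simplified, self-contained sketch of that padding step (zeros are used as the placeholder row where the original reuses the first or last hidden state; names and sizes are illustrative):

import torch

def pad_contexts(x, starts, ends):
    # x: (batch, seq_len, hidden); starts/ends: Python ints marking the target span.
    # Returns left/right contexts padded to the batch maximum plus 0/1 masks.
    batch, seq_len, hidden = x.size()
    pad_left = max(1, max(starts))                          # at least one placeholder row
    pad_right = max(1, max(seq_len - e - 1 for e in ends))
    lefts, rights, m_left, m_right = [], [], [], []
    for i in range(batch):
        s, e = starts[i], ends[i]
        left = x[i, :s] if s > 0 else x.new_zeros(1, hidden)
        right = x[i, e + 1:] if e + 1 < seq_len else x.new_zeros(1, hidden)
        ml = [1] * s if s > 0 else [0]
        mr = [1] * (seq_len - e - 1) if e + 1 < seq_len else [0]
        left = torch.cat([left, x.new_zeros(pad_left - left.size(0), hidden)], 0)
        right = torch.cat([right, x.new_zeros(pad_right - right.size(0), hidden)], 0)
        ml += [0] * (pad_left - len(ml))
        mr += [0] * (pad_right - len(mr))
        lefts.append(left)
        rights.append(right)
        m_left.append(ml)
        m_right.append(mr)
    return (torch.stack(lefts), torch.tensor(m_left, dtype=torch.bool),
            torch.stack(rights), torch.tensor(m_right, dtype=torch.bool))

x = torch.randn(2, 5, 4)                                    # batch of 2, length 5
l, ml, r, mr = pad_contexts(x, starts=[1, 0], ends=[2, 0])
print(l.shape, r.shape)                                     # (2, 1, 4) and (2, 4, 4)
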
Example #3
    def forward(self, fea_v, length, target_start, target_end, label_v):
        word_v = fea_v
        batch_size = word_v.size(0)
        seq_length = word_v.size(1)

        word_emb = self.embedding(word_v)
        word_emb = self.dropout_emb(word_emb)
        label_emb = self.embedding_label(label_v)
        label_emb = self.dropout_emb(label_emb)

        # x = torch.cat([word_emb, label_emb], 2)
        x = torch.transpose(word_emb, 0, 1)
        # x = torch.transpose(x, 0, 1)
        # lstm_out, _ = self.lstm(x)

        packed_words = pack_padded_sequence(x, length)
        lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        ##### lstm_out: (seq_len, batch_size, hidden_size)
        lstm_out = self.dropout_lstm(lstm_out)
        x = lstm_out
        # print(x)
        x = x.transpose(0, 1)

        ##### batch version
        # x = torch.squeeze(lstm_out, 1)
        # x: variable (seq_len, batch_size, hidden_size)
        # target_start: variable (batch_size)
        # _, start = torch.max(target_start.unsqueeze(0), dim=1)
        # max_start = utils.to_scalar(target_start[start])
        # _, end = torch.min(target_end.unsqueeze(0), dim=1)
        # min_end = utils.to_scalar(target_end[end])

        max_length = 0
        for index in range(batch_size):
            x_len = x[index].size(0)
            start = utils.to_scalar(target_start[index])
            end = utils.to_scalar(target_end[index])
            none_t = x_len - (end - start + 1)
            if none_t > max_length: max_length = none_t

        none_target = []
        mask_none_target = []
        target_save = []
        for idx in range(batch_size):
            mask_none_t = []
            none_t = None
            x_len_cur = x[idx].size(0)
            start_cur = utils.to_scalar(target_start[idx])
            end_cur = utils.to_scalar(target_end[idx])

            x_target = x[idx][start_cur:(end_cur + 1)]
            x_average_target = torch.mean(x_target, 0)
            target_save.append(x_average_target.unsqueeze(0))
            if start_cur != 0:
                left = x[idx][:start_cur]
                none_t = left
                mask_none_t.extend([1] * start_cur)
            if end_cur != (x_len_cur - 1):
                right = x[idx][(end_cur + 1):]
                if none_t is not None: none_t = torch.cat([none_t, right], 0)
                else: none_t = right
                mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
            if len(mask_none_t) != max_length:
                add_t = Variable(
                    torch.zeros((max_length - len(mask_none_t)),
                                self.lstm_hiddens))
                if self.use_cuda: add_t = add_t.cuda()
                mask_none_t.extend([0] * (max_length - len(mask_none_t)))
                # print(add_t)
                none_t = torch.cat([none_t, add_t], 0)
            mask_none_target.append(mask_none_t)
            none_target.append(none_t.unsqueeze(0))

        target_save = torch.cat(target_save, 0)
        # print(none_target)
        none_target = torch.cat(none_target, 0)
        mask_none_target = Variable(torch.ByteTensor(mask_none_target))
        # target_save: variable (batch_size, two_hidden_size)
        if self.use_cuda:
            target_save = target_save.cuda()
            mask_none_target = mask_none_target.cuda()
            none_target = none_target.cuda()

        # squence = torch.cat(none_target, 1)
        # s = self.attention(x, target_save, None)
        s = self.attention(none_target, target_save, mask_none_target)
        # print(s)
        ##### s: variable (batch_size, lstm_hiddens)

        result = self.linear(s)  # result: variable (batch_size, label_num)
        # result: variable (batch_size, label_num)

        return result
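
These forwards all run the embeddings through pack_padded_sequence / pad_packed_sequence before the per-instance slicing. A minimal sketch of that round trip (sizes are illustrative): the batch must already be sorted by decreasing length, the LSTM skips the padded timesteps, and the output comes back time-major as (seq_len, batch, hidden) with zeros past each true length.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=8, hidden_size=6)      # time-major, as in the forwards above
emb = torch.randn(5, 3, 8)                       # (seq_len, batch, emb_dim)
lengths = [5, 4, 2]                              # one true length per instance, descending

packed = pack_padded_sequence(emb, lengths)
out_packed, (h, c) = lstm(packed)
out, out_lengths = pad_packed_sequence(out_packed)
print(out.shape)                                 # torch.Size([5, 3, 6]); padded steps are zero
print(out_lengths)                               # tensor([5, 4, 2])
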
Example #4
    def forward(self, fea_v, length, target_start, target_end, label_v):
        word_v = fea_v
        batch_size = word_v.size(0)
        seq_length = word_v.size(1)

        word_emb = self.embedding(word_v)
        word_emb = self.dropout_emb(word_emb)

        # label_emb = self.embedding_label(label_v)
        # label_emb = self.dropout_emb(label_emb)

        # x = torch.cat([word_emb, label_emb], 2)
        # x = torch.transpose(x, 0, 1)

        x = torch.transpose(word_emb, 0, 1)
        packed_words = pack_padded_sequence(x, length)
        lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        ##### lstm_out: (seq_len, batch_size, hidden_size)
        lstm_out = self.dropout_lstm(lstm_out)
        x = lstm_out
        x = x.transpose(0, 1)
        ##### batch version
        # x = torch.squeeze(lstm_out, 1)
        # x: variable (seq_len, batch_size, hidden_size)
        # target_start: variable (batch_size)
        _, start = torch.max(target_start.unsqueeze(0), dim=1)
        max_start = utils.to_scalar(target_start[start])
        _, end = torch.min(target_end.unsqueeze(0), dim=1)
        min_end = utils.to_scalar(target_end[end])

        max_length = 0
        for index in range(batch_size):
            x_len = x[index].size(0)
            start = utils.to_scalar(target_start[index])
            end = utils.to_scalar(target_end[index])
            none_t = x_len-(end-start+1)
            if none_t > max_length: max_length = none_t

        left_save = []
        mask_left_save = []
        right_save = []
        mask_right_save = []
        target_save = []
        none_target = []
        mask_none_target = []
        target_save_mean = []
        for idx in range(batch_size):
            mask_none_t = []
            none_t = None
            x_len_cur = x[idx].size(0)
            start_cur = utils.to_scalar(target_start[idx])
            end_cur = utils.to_scalar(target_end[idx])
            # left_len_cur = start_cur
            # left_len_max = max_start
            x_target = x[idx][start_cur:(end_cur + 1)]
            x_target_mean = torch.mean(x_target, 0)
            target_save_mean.append(x_target_mean.unsqueeze(0))

            x_target_fc = F.tanh(self.linear_try(x_target))
            beta = torch.mm(x_target_fc, self.u_try)
            beta = torch.t(beta)
            if self.use_cuda:
                alpha = F.softmax(beta, dim=1)  # alpha: variable (1, len)
            else:
                alpha = F.softmax(beta)  # alpha: variable (1, len)
            x_average_target = torch.mm(alpha, x_target)
            x_average_target = x_average_target.squeeze(0)
            # print(x_average_target)
            target_save.append(x_average_target.unsqueeze(0))

            if start_cur != 0:
                left = x[idx][:start_cur]
                none_t = left
                mask_none_t.extend([1] * start_cur)
            if end_cur != (x_len_cur - 1):
                right = x[idx][(end_cur + 1):]
                if none_t is not None:
                    none_t = torch.cat([none_t, right], 0)
                else:
                    none_t = right
                mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
            if len(mask_none_t) != max_length:
                add_t = Variable(torch.zeros((max_length - len(mask_none_t)), self.lstm_hiddens))
                if self.use_cuda: add_t = add_t.cuda()
                mask_none_t.extend([0] * (max_length - len(mask_none_t)))
                # print(add_t)
                none_t = torch.cat([none_t, add_t], 0)
            mask_none_target.append(mask_none_t)
            none_target.append(none_t.unsqueeze(0))

            x_len_cur = x[idx].size(0)
            start_cur = utils.to_scalar(target_start[idx])
            left_len_cur = start_cur
            left_len_max = max_start
            if start_cur != 0:
                x_cur_left = x[idx][:start_cur]
                left_len_sub = left_len_max - left_len_cur
                mask_cur_left = [1 for _ in range(left_len_cur)]
            else:
                x_cur_left = x[idx][0].unsqueeze(0)
                left_len_sub = left_len_max - 1
                # mask_cur_left = [-1e+20]
                mask_cur_left = [0]
            # x_cur_left: variable (start_cur, two_hidden_size)
            # mask_cur_left = [1 for _ in range(start_cur)]
            # mask_cur_left: list (start_cur)
            if start_cur < max_start:
                add = Variable(torch.zeros(left_len_sub, self.lstm_hiddens))
                if self.use_cuda: add = add.cuda()
                x_cur_left = torch.cat([x_cur_left, add], 0)
                # x_cur_left: variable (max_start, two_hidden_size)
                left_save.append(x_cur_left.unsqueeze(0))
                # mask_cur_left.extend([-1e+20 for _ in range(left_len_sub)])
                mask_cur_left.extend([0 for _ in range(left_len_sub)])
                # mask_cur_left: list (max_start)
                mask_left_save.append(mask_cur_left)
            else:
                left_save.append(x_cur_left.unsqueeze(0))
                mask_left_save.append(mask_cur_left)

            end_cur = utils.to_scalar(target_end[idx])
            right_len_cur = x_len_cur-end_cur-1
            right_len_max = x_len_cur-min_end-1
            if (end_cur+1) != x_len_cur:
                x_cur_right = x[idx][(end_cur+1):]
                right_len_sub = right_len_max - right_len_cur
                mask_cur_right = [1 for _ in range(right_len_cur)]
            else:
                x_cur_right = x[idx][end_cur].unsqueeze(0)
                right_len_sub = right_len_max - right_len_cur - 1
                # mask_cur_right = [-1e+20]
                mask_cur_right = [0]
            # x_cur_right: variable ((x_len_cur-end_cur-1), two_hidden_size)
            # mask_cur_right = [1 for _ in range(right_len_cur)]
            # mask_cur_right: list (x_len_cur-end_cur-1==right_len)
            if end_cur > min_end:
                add = Variable(torch.zeros(right_len_sub, self.lstm_hiddens))
                if self.use_cuda: add = add.cuda()
                x_cur_right = torch.cat([x_cur_right, add], 0)
                right_save.append(x_cur_right.unsqueeze(0))
                # mask_cur_right.extend([-1e+20 for _ in range(right_len_sub)])
                mask_cur_right.extend([0 for _ in range(right_len_sub)])
                mask_right_save.append(mask_cur_right)
            else:
                right_save.append(x_cur_right.unsqueeze(0))
                mask_right_save.append(mask_cur_right)

            # target_sub = end_cur-start_cur
            # x_target = x[idx][start_cur:(end_cur+1)]
            # x_average_target = torch.mean(x_target, 0)
            # target_save.append(x_average_target.unsqueeze(0))
        mask_left_save = Variable(torch.ByteTensor(mask_left_save))
        # mask_left_save: variable (batch_size, left_len_max)
        mask_right_save = Variable(torch.ByteTensor(mask_right_save))
        # mask_right_save: variable (batch_size, right_len_max)
        left_save = torch.cat(left_save, 0)
        right_save = torch.cat(right_save, 0)
        target_save = torch.cat(target_save, 0)
        # left_save: variable (batch_size, left_len_max, two_hidden_size)
        # right_save: variable (batch_size, right_len_max, two_hidden_size)
        # target_save: variable (batch_size, two_hidden_size)
        none_target = torch.cat(none_target, 0)
        mask_none_target = Variable(torch.ByteTensor(mask_none_target))
        target_save_mean = torch.cat(target_save_mean, 0)
        if self.use_cuda:
            mask_right_save = mask_right_save.cuda()
            mask_left_save = mask_left_save.cuda()
            left_save = left_save.cuda()
            right_save = right_save.cuda()
            target_save = target_save.cuda()
            mask_none_target = mask_none_target.cuda()
            none_target = none_target.cuda()
            target_save_mean = target_save_mean.cuda()

        # s = self.attention(x, target_save, None)
        s = self.attention(none_target, target_save_mean, mask_none_target)
        s_l = self.attention_l(left_save, target_save_mean, mask_left_save)
        s_r = self.attention_r(right_save, target_save_mean, mask_right_save)

        # s = torch.cat([s, target_save], 1)
        result = self.linear(s)  # result: variable (batch_size, label_num)
        # result = self.linear_l(s_l)
        result = torch.add(result, self.linear_l(s_l))
        result = torch.add(result, self.linear_r(s_r))
        # result: variable (batch_size, label_num)
        # print(result)
        return result
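
Unlike the other examples, the forward above pools the target span with a small self-attention (linear_try / u_try) rather than a plain mean. An isolated sketch of that pooling step, with illustrative sizes and freshly initialised stand-ins for the model's parameters:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden = 6
linear_try = nn.Linear(hidden, hidden)                        # stand-in for self.linear_try
u_try = nn.Parameter(torch.randn(hidden, 1))                  # stand-in for self.u_try

x_target = torch.randn(4, hidden)                             # (target_len, hidden)
scores = torch.mm(torch.tanh(linear_try(x_target)), u_try)    # (target_len, 1)
alpha = F.softmax(scores.t(), dim=1)                          # (1, target_len)
pooled = torch.mm(alpha, x_target).squeeze(0)                 # (hidden,) weighted sum
print(pooled.shape)                                           # torch.Size([6])
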
def train_segpos(train_insts, dev_insts, test_insts, encode, decode, config,
                 params):
    print('training...')
    # wrap in list() so the parameters can be reused for gradient clipping below
    parameters_en = list(
        filter(lambda p: p.requires_grad, encode.parameters()))
    # optimizer_en = torch.optim.SGD(params=parameters_en, lr=config.learning_rate, momentum=0.9, weight_decay=config.decay)
    optimizer_en = torch.optim.Adam(params=parameters_en,
                                    lr=config.learning_rate,
                                    weight_decay=config.decay)
    parameters_de = list(
        filter(lambda p: p.requires_grad, decode.parameters()))
    # optimizer_de = torch.optim.SGD(params= parameters_de, lr=config.learning_rate, momentum=0.9, weight_decay=config.decay)
    optimizer_de = torch.optim.Adam(params=parameters_de,
                                    lr=config.learning_rate,
                                    weight_decay=config.decay)

    best_dev_f1_seg = float('-inf')
    best_dev_f1_pos = float('-inf')
    best_test_f1_seg = float('-inf')
    best_test_f1_pos = float('-inf')

    dev_eval_seg = Eval()
    dev_eval_pos = Eval()
    test_eval_seg = Eval()
    test_eval_pos = Eval()
    for epoch in range(config.maxIters):
        start_time = time.time()
        encode.train()
        decode.train()
        train_insts = utils.random_instances(train_insts)
        epoch_loss = 0
        train_buckets = params.generate_batch_buckets(config.train_batch_size,
                                                      train_insts)

        for index in range(len(train_buckets)):
            batch_length = np.array(
                [np.sum(mask) for mask in train_buckets[index][-1]])

            var_b, list_b, mask_v, length_v, gold_v = utils.patch_var(
                train_buckets[index], batch_length.tolist(), params)
            encode.zero_grad()
            decode.zero_grad()
            if mask_v.size(0) != config.train_batch_size:
                encode.hidden = encode.init_hidden(mask_v.size(0))
            else:
                encode.hidden = encode.init_hidden(config.train_batch_size)
            lstm_out = encode.forward(var_b, list_b, mask_v,
                                      batch_length.tolist())
            output, state = decode.forward(lstm_out,
                                           var_b,
                                           list_b,
                                           mask_v,
                                           batch_length.tolist(),
                                           is_train=True)
            #### output: variable (batch_size, max_length, segpos_num)

            # num_total = output.size(0)*output.size(1)
            # output = output.contiguous().view(num_total, output.size(2))
            # print(output)
            # gold_v = gold_v.view(num_total)
            # print(output)
            gold_v = gold_v.view(output.size(0))
            # print(gold_v)
            loss = F.cross_entropy(output, gold_v)
            loss.backward()

            nn.utils.clip_grad_norm(parameters_en, max_norm=config.clip_grad)
            nn.utils.clip_grad_norm(parameters_de, max_norm=config.clip_grad)

            optimizer_en.step()
            optimizer_de.step()
            epoch_loss += utils.to_scalar(loss)
        print('\nepoch is {}, average loss is {} '.format(
            epoch,
            (epoch_loss / (config.train_batch_size * len(train_buckets)))))
        # update lr
        # adjust_learning_rate(optimizer, config.learning_rate / (1 + (epoch + 1) * config.decay))
        # acc = float(correct_num) / float(gold_num)
        # print('\nepoch is {}, accuracy is {}'.format(epoch, acc))
        print('the {} epoch training costs time: {} s '.format(
            epoch,
            time.time() - start_time))

        print('\nDev...')
        dev_eval_seg.clear()
        dev_eval_pos.clear()
        test_eval_seg.clear()
        test_eval_pos.clear()

        start_time = time.time()
        dev_f1_seg, dev_f1_pos = eval_batch(dev_insts, encode, decode, config,
                                            params, dev_eval_seg,
                                            dev_eval_pos)
        print('the {} epoch dev costs time: {} s'.format(
            epoch,
            time.time() - start_time))

        if dev_f1_seg > best_dev_f1_seg:
            best_dev_f1_seg = dev_f1_seg
            if dev_f1_pos > best_dev_f1_pos: best_dev_f1_pos = dev_f1_pos
            print('\nTest...')
            start_time = time.time()
            test_f1_seg, test_f1_pos = eval_batch(test_insts, encode, decode,
                                                  config, params,
                                                  test_eval_seg, test_eval_pos)
            print('the {} epoch testing costs time: {} s'.format(
                epoch,
                time.time() - start_time))

            if test_f1_seg > best_test_f1_seg:
                best_test_f1_seg = test_f1_seg
            if test_f1_pos > best_test_f1_pos:
                best_test_f1_pos = test_f1_pos
            print(
                'now, test fscore of seg is {}, test fscore of pos is {}, best test fscore of seg is {}, best test fscore of pos is {} '
                .format(test_f1_seg, test_f1_pos, best_test_f1_seg,
                        best_test_f1_pos))
            torch.save(encode.state_dict(), config.save_encode_path)
            torch.save(decode.state_dict(), config.save_decode_path)
        else:
            if dev_f1_pos > best_dev_f1_pos:
                best_dev_f1_pos = dev_f1_pos
                print('\nTest...')
                start_time = time.time()
                test_f1_seg, test_f1_pos = eval_batch(test_insts, encode,
                                                      decode, config, params,
                                                      test_eval_seg,
                                                      test_eval_pos)
                print('the {} epoch testing costs time: {} s'.format(
                    epoch,
                    time.time() - start_time))

                if test_f1_seg > best_test_f1_seg:
                    best_test_f1_seg = test_f1_seg
                if test_f1_pos > best_test_f1_pos:
                    best_test_f1_pos = test_f1_pos
                print(
                    'now, test fscore of seg is {}, test fscore of pos is {}, best test fscore of seg is {}, best test fscore of pos is {} '
                    .format(test_f1_seg, test_f1_pos, best_test_f1_seg,
                            best_test_f1_pos))
                torch.save(encode.state_dict(), config.save_encode_path)
                torch.save(decode.state_dict(), config.save_decode_path)
        print(
            'now, dev fscore of seg is {}, dev fscore of pos is {}, best dev fscore of seg is {}, best dev fscore of pos is {}, best test fscore of seg is {}, best test fscore of pos is {}'
            .format(dev_f1_seg, dev_f1_pos, best_dev_f1_seg, best_dev_f1_pos,
                    best_test_f1_seg, best_test_f1_pos))
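
A short sketch of the optimizer and gradient-clipping setup used in train_segpos, reduced to a single stand-in module. The point of wrapping the filter in list() is that a bare filter object is a one-shot iterator: the Adam constructor exhausts it, and a later clip_grad_norm call on the same object would then silently clip nothing.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                                        # stand-in for encode/decode
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=1e-6)

loss = model(torch.randn(3, 4)).sum()                          # dummy loss
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(params, max_norm=5.0)                 # clip_grad_norm is the older alias
optimizer.step()
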
Example #6
 def action_batch(self, index, states, output, length):
     # if train:
     #     action = states.golds[index]        # ['SEP#P', 'SEP#CD', 'SEP#DT', 'SEP#NT', 'SEP#NN', 'SEP#P', 'SEP#NR', 'SEP#AD', 'SEP#NR', 'SEP#NN']
     #     if index == 0:
     #         for id, ele in enumerate(states.chars[index]):
     #             states.words_record[id].append(ele)
     #         # pos_record_cur = []
     #         # for ele in states.golds[index]:
     #         #     # states.pos_record[index].append(ele.split('#')[1])
     #         #     pos_record_cur.append(ele.split('#')[1])
     #         # states.pos_record.append(pos_record_cur)
     #         # # states.pos_record[index].append(states.gold[index].split('#')[1])
     #         for id, ele in enumerate(states.golds[index]):
     #             states.pos_record[id].append(ele.split('#')[1])
     #         # pos_index_cur = []
     #         # for ele in states.golds[index]:
     #         #     # states.pos_index[index].append(self.pos2id[ele.split('#')[1]])
     #         #     pos_index_cur.append(self.pos2id[ele.split('#')[1]])
     #         # states.pos_index.append(pos_index_cur)
     #         # # sent.pos_index.append(self.pos2id[sent.gold[index].split('#')[1]])
     #         for id, ele in enumerate(states.golds[index]):
     #             states.pos_index[id].append(self.pos2id[ele.split('#')[1]])
     # else:
     max_score, max_index = torch.max(output, dim=1)
     # print(max_index)
     # action = [self.id2gold[utils.to_scalar(ele)] for ele in max_index]
     action = []
     for id, ele in enumerate(max_index):
         action_cur = self.id2gold[utils.to_scalar(ele)]
         if index >= length[id]:
             action_cur = '<pad>'
         action.append(action_cur)
     # print(action)
     if index == 0:
         for id, ele in enumerate(action):
             pos_id = ele.find('#')
             if pos_id == -1:
                 print('action at the first index is error.')
             else:
                 states.words_record[id].append(states.chars[index][id])
                 pos_record = action[id][(pos_id + 1):]
                 states.pos_record[id].append(pos_record)
                 pos_index = self.pos2id[pos_record]
                 states.pos_index[id].append(pos_index)
     # states.action.append(action)
     for id, ele in enumerate(action):
         states.action[id].append(ele)
     if index != 0:
         # last_words_record = states.words_record[index-1]
         # words_record_cur = []
         # pos_record_cur = []
         # pos_index_cur = []
         for id, ele in enumerate(action):
             if ele == '<pad>':
                 states.words_record[id].append('<pad>')
                 states.pos_record[id].append('<pad>')
                 states.pos_index[id].append(self.posPadID)
             else:
                 pos_id = ele.find('#')
                 if pos_id == -1:
                     last_cur = states.words_record[id][-1] + states.chars[
                         index][id]  # chars and golds differ
                     # words_record_cur.append(last_cur)
                     states.words_record[id][-1] = last_cur
                 else:
                     # words_record_cur.append(states.chars[index][id])
                     states.words_record[id].append(states.chars[index][id])
                     pos_record = action[id][(pos_id + 1):]
                     # pos_record_cur.append(pos_record)
                     states.pos_record[id].append(pos_record)
                     pos_index = self.pos2id[pos_record]
                     # pos_index_cur.append(pos_index)
                     states.pos_index[id].append(pos_index)
    def forward(self, fea_v, length, target_start, target_end, label_v):
        if self.add_char:
            word_v = fea_v[0]
            char_v = fea_v[1]
        else:
            word_v = fea_v
        batch_size = word_v.size(0)
        seq_length = word_v.size(1)

        word_emb = self.embedding(word_v)
        word_emb = self.dropout_emb(word_emb)
        label_emb = self.embedding_label(label_v)
        label_emb = self.dropout_emb(label_emb)
        x = torch.cat([word_emb, label_emb], 2)

        x = torch.transpose(x, 0, 1)
        # x = torch.transpose(word_emb, 0, 1)
        packed_words = pack_padded_sequence(x, length)
        lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        ##### lstm_out: (seq_len, batch_size, hidden_size)
        lstm_out = self.dropout_lstm(lstm_out)
        x = lstm_out

        ##### batch version
        # x: variable (seq_len, batch_size, hidden_size)
        # target_start: variable (batch_size)
        _, start = torch.max(target_start.unsqueeze(0), dim=1)
        max_start = utils.to_scalar(target_start[start])
        _, end = torch.min(target_end.unsqueeze(0), dim=1)
        min_end = utils.to_scalar(target_end[end])

        x = x.transpose(0, 1)

        max_length = 0
        for index in range(batch_size):
            x_len = x[index].size(0)
            start = utils.to_scalar(target_start[index])
            end = utils.to_scalar(target_end[index])
            none_t = x_len - (end - start + 1)
            if none_t > max_length: max_length = none_t

        left_save = []
        mask_left_save = []
        right_save = []
        mask_right_save = []
        target_save_mean = []
        none_target = []
        mask_none_target = []
        for idx in range(batch_size):
            mask_none_t = []
            none_t = None
            x_len_cur = x[idx].size(0)
            start_cur = utils.to_scalar(target_start[idx])
            end_cur = utils.to_scalar(target_end[idx])
            left_len_cur = start_cur
            left_len_max = max_start
            if start_cur != 0:
                x_cur_left = x[idx][:start_cur]
                left_len_sub = left_len_max - left_len_cur
                mask_cur_left = [1 for _ in range(left_len_cur)]
            else:
                x_cur_left = x[idx][0].unsqueeze(0)
                left_len_sub = left_len_max - 1
                # mask_cur_left = [-1e+20]
                mask_cur_left = [0]
            # x_cur_left: variable (start_cur, two_hidden_size)
            # mask_cur_left = [1 for _ in range(start_cur)]
            # mask_cur_left: list (start_cur)
            if start_cur < max_start:
                if left_len_sub == 0: print('error: left padding length is zero')
                add = Variable(torch.rand(left_len_sub, self.lstm_hiddens))
                if self.use_cuda: add = add.cuda()
                x_cur_left = torch.cat([x_cur_left, add], dim=0)
                # x_cur_left: variable (max_start, two_hidden_size)
                left_save.append(x_cur_left.unsqueeze(0))
                # mask_cur_left.extend([-1e+20 for _ in range(left_len_sub)])
                mask_cur_left.extend([0 for _ in range(left_len_sub)])
                # mask_cur_left: list (max_start)
                mask_left_save.append(mask_cur_left)
            else:
                left_save.append(x_cur_left.unsqueeze(0))
                mask_left_save.append(mask_cur_left)

            right_len_cur = x_len_cur - end_cur - 1
            right_len_max = x_len_cur - min_end - 1
            if (end_cur + 1) != x_len_cur:
                x_cur_right = x[idx][(end_cur + 1):]
                right_len_sub = right_len_max - right_len_cur
                mask_cur_right = [1 for _ in range(right_len_cur)]
            else:
                x_cur_right = x[idx][end_cur].unsqueeze(0)
                right_len_sub = right_len_max - right_len_cur - 1
                # mask_cur_right = [-1e+20]
                mask_cur_right = [0]
            # x_cur_right: variable ((x_len_cur-end_cur-1), two_hidden_size)
            # mask_cur_right = [1 for _ in range(right_len_cur)]
            # mask_cur_right: list (x_len_cur-end_cur-1==right_len)
            if end_cur > min_end:
                if right_len_sub == 0: print('error: right padding length is zero')
                add = Variable(torch.rand(right_len_sub, self.lstm_hiddens))
                if self.use_cuda: add = add.cuda()
                x_cur_right = torch.cat([x_cur_right, add], dim=0)
                right_save.append(x_cur_right.unsqueeze(0))
                # mask_cur_right.extend([-1e+20 for _ in range(right_len_sub)])
                mask_cur_right.extend([0 for _ in range(right_len_sub)])
                mask_right_save.append(mask_cur_right)
            else:
                right_save.append(x_cur_right.unsqueeze(0))
                mask_right_save.append(mask_cur_right)

            if start_cur != 0:
                left = x[idx][:start_cur]
                none_t = left
                mask_none_t.extend([1] * start_cur)
            if end_cur != (x_len_cur - 1):
                right = x[idx][(end_cur + 1):]
                if none_t is not None: none_t = torch.cat([none_t, right], 0)
                else: none_t = right
                mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
            if len(mask_none_t) != max_length:
                add_t = Variable(
                    torch.zeros((max_length - len(mask_none_t)),
                                self.lstm_hiddens))
                if self.use_cuda: add_t = add_t.cuda()
                mask_none_t.extend([0] * (max_length - len(mask_none_t)))
                # print(add_t)
                none_t = torch.cat([none_t, add_t], 0)
            mask_none_target.append(mask_none_t)
            none_target.append(none_t.unsqueeze(0))

            x_target = x[idx][start_cur:(end_cur + 1)]
            x_average_target = torch.mean(x_target, 0)
            target_save_mean.append(x_average_target.unsqueeze(0))
        mask_left_save = Variable(torch.ByteTensor(mask_left_save))
        # mask_left_save: variable (batch_size, left_len_max)
        mask_right_save = Variable(torch.ByteTensor(mask_right_save))
        # mask_right_save: variable (batch_size, right_len_max)
        left_save = torch.cat(left_save, dim=0)
        right_save = torch.cat(right_save, dim=0)
        target_save_mean = torch.cat(target_save_mean, dim=0)
        none_target = torch.cat(none_target, 0)
        mask_none_target = Variable(torch.ByteTensor(mask_none_target))
        # left_save: variable (batch_size, left_len_max, two_hidden_size)
        # right_save: variable (batch_size, right_len_max, two_hidden_size)
        # target_save_mean: variable (batch_size, two_hidden_size)
        if self.use_cuda:
            mask_right_save = mask_right_save.cuda()
            mask_left_save = mask_left_save.cuda()
            left_save = left_save.cuda()
            right_save = right_save.cuda()
            target_save_mean = target_save_mean.cuda()
            none_target = none_target.cuda()
            mask_none_target = mask_none_target.cuda()

        # s, s_alpha = self.attention(x, target_save, None)
        # s_l, s_l_alpha = self.attention_l(left_save, target_save, mask_left_save)
        # s_r, s_r_alpha = self.attention_r(right_save, target_save, mask_right_save)
        # s = self.attention(x, target_save, None)
        s = self.attention(none_target, target_save_mean, mask_none_target)
        s_l = self.attention_l(left_save, target_save_mean, mask_left_save)
        s_r = self.attention_r(right_save, target_save_mean, mask_right_save)

        w1s = torch.mm(self.w1, torch.t(s))
        u1t = torch.mm(self.u1, torch.t(target_save_mean))
        if self.use_cuda:
            w1s = w1s.cuda()
            u1t = u1t.cuda()

        if batch_size == self.batch_size:
            z = torch.exp(w1s + u1t + self.b1)
        else:
            z = torch.exp(w1s + u1t)

        z_all = z
        # z_all: variable (two_hidden_size, batch_size)
        z_all = z_all.unsqueeze(2)

        w2s = torch.mm(self.w2, torch.t(s_l))
        u2t = torch.mm(self.u2, torch.t(target_save_mean))
        if self.use_cuda:
            w2s = w2s.cuda()
            u2t = u2t.cuda()
        if batch_size == self.batch_size:
            z_l = torch.exp(w2s + u2t + self.b2)
        else:
            z_l = torch.exp(w2s + u2t)
        # print(z_all)
        # print(z_l)
        z_all = torch.cat([z_all, z_l.unsqueeze(2)], dim=2)

        w3s = torch.mm(self.w3, torch.t(s_r))
        u3t = torch.mm(self.u3, torch.t(target_save_mean))
        if self.use_cuda:
            w3s = w3s.cuda()
            u3t = u3t.cuda()
        if batch_size == self.batch_size:
            z_r = torch.exp(w3s + u3t + self.b3)
        else:
            z_r = torch.exp(w3s + u3t)
        z_all = torch.cat([z_all, z_r.unsqueeze(2)], dim=2)

        # z_all: variable (two_hidden_size, batch_size, 3)
        if self.use_cuda:
            z_all = F.softmax(z_all, dim=2)
        else:
            z_all = F.softmax(z_all)
        # z_all = torch.t(z_all)
        z_all = z_all.permute(2, 1, 0)
        # z = torch.unsqueeze(z_all[:batch_size], 0)
        # z_l = torch.unsqueeze(z_all[batch_size:(2*batch_size)], 0)
        # z_r = torch.unsqueeze(z_all[(2*batch_size):], 0)
        # z = z_all[:batch_size]
        # z_l = z_all[batch_size:(2*batch_size)]
        # z_r = z_all[(2*batch_size):]
        z = z_all[0]
        z_l = z_all[1]
        z_r = z_all[2]

        ss = torch.mul(z, s)
        ss = torch.add(ss, torch.mul(z_l, s_l))
        ss = torch.add(ss, torch.mul(z_r, s_r))

        logit = self.linear_2(ss)
        # print(logit)
        # alpha = [s_alpha, s_l_alpha, s_r_alpha]
        return logit
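
The end of the forward above fuses the three attention summaries s, s_l, s_r with a learned gate: each branch gets a per-dimension score from its own summary and the pooled target, the three scores are normalised across branches, and the summaries are mixed per dimension. A simplified standalone sketch with stand-in parameters (the original additionally applies exp to the scores before the softmax and adds the biases only for full batches; that is omitted here):

import torch
import torch.nn.functional as F

batch, hidden = 3, 6
s, s_l, s_r = (torch.randn(batch, hidden) for _ in range(3))   # attention summaries
target = torch.randn(batch, hidden)                            # pooled target representation
W = torch.randn(3, hidden, hidden)                             # stand-ins for w1/w2/w3
U = torch.randn(3, hidden, hidden)                             # stand-ins for u1/u2/u3

summaries = [s, s_l, s_r]
scores = torch.stack(
    [summaries[i] @ W[i].t() + target @ U[i].t() for i in range(3)], dim=2)
gates = F.softmax(scores, dim=2)                               # (batch, hidden, 3)
fused = sum(gates[:, :, i] * summaries[i] for i in range(3))   # (batch, hidden)
print(fused.shape)                                             # torch.Size([3, 6])
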
Example #8
    def forward(self, fea_v, length, target_start, target_end):
        if self.add_char:
            word_v = fea_v[0]
            char_v = fea_v[1]
        else:
            word_v = fea_v
        batch_size = word_v.size(0)
        seq_length = word_v.size(1)

        word_emb = self.embedding(word_v)
        word_emb = self.dropout_emb(word_emb)
        if self.static:
            word_static = self.embedding_static(word_v)
            word_static = self.dropout_emb(word_static)
            word_emb = torch.cat([word_emb, word_static], 2)

        x = torch.transpose(word_emb, 0, 1)
        packed_words = pack_padded_sequence(x, length)
        lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        ##### lstm_out: (seq_len, batch_size, hidden_size)
        lstm_out = self.dropout_lstm(lstm_out)
        x = lstm_out
        x = x.transpose(0, 1)
        ##### batch version
        # x = torch.squeeze(lstm_out, 1)
        # x: variable (seq_len, batch_size, hidden_size)
        # target_start: variable (batch_size)
        # _, start = torch.max(target_start.unsqueeze(0), dim=1)
        # max_start = utils.to_scalar(target_start[start])
        # _, end = torch.min(target_end.unsqueeze(0), dim=1)
        # min_end = utils.to_scalar(target_end[end])
        max_length = 0
        for index in range(batch_size):
            x_len = x[index].size(0)
            start = utils.to_scalar(target_start[index])
            end = utils.to_scalar(target_end[index])
            none_t = x_len - (end - start + 1)
            if none_t > max_length: max_length = none_t

        # left_save = []
        # mask_left_save = []
        # right_save = []
        # mask_right_save = []
        none_target = []
        mask_none_target = []
        target_save = []
        for idx in range(batch_size):
            mask_none_t = []
            none_t = None
            x_len_cur = x[idx].size(0)
            start_cur = utils.to_scalar(target_start[idx])
            end_cur = utils.to_scalar(target_end[idx])
            # left_len_cur = start_cur
            # left_len_max = max_start
            x_target = x[idx][start_cur:(end_cur + 1)]
            x_average_target = torch.mean(x_target, 0)
            target_save.append(x_average_target.unsqueeze(0))
            if start_cur != 0:
                left = x[idx][:start_cur]
                none_t = left
                mask_none_t.extend([1] * start_cur)
            if end_cur != (x_len_cur - 1):
                right = x[idx][(end_cur + 1):]
                if none_t is not None: none_t = torch.cat([none_t, right], 0)
                else: none_t = right
                mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
            if len(mask_none_t) != max_length:
                add_t = Variable(
                    torch.zeros((max_length - len(mask_none_t)),
                                self.lstm_hiddens))
                if self.use_cuda: add_t = add_t.cuda()
                mask_none_t.extend([0] * (max_length - len(mask_none_t)))
                # print(add_t)
                none_t = torch.cat([none_t, add_t], 0)
            mask_none_target.append(mask_none_t)
            none_target.append(none_t.unsqueeze(0))
            # if start_cur != 0:
            #     x_cur_left = x[idx][:start_cur]
            #     left_len_sub = left_len_max - left_len_cur
            #     mask_cur_left = [1 for _ in range(left_len_cur)]
            # else:
            #     x_cur_left = x[idx][0].unsqueeze(0)
            #     left_len_sub = left_len_max - 1
            #     # mask_cur_left = [-1e+20]
            #     mask_cur_left = [0]
            # # x_cur_left: variable (start_cur, two_hidden_size)
            # # mask_cur_left = [1 for _ in range(start_cur)]
            # # mask_cur_left: list (start_cur)
            # if start_cur < max_start:
            #     add = Variable(torch.zeros(left_len_sub, self.lstm_hiddens))
            #     if self.use_cuda: add = add.cuda()
            #     x_cur_left = torch.cat([x_cur_left, add], 0)
            #     # x_cur_left: variable (max_start, two_hidden_size)
            #     left_save.append(x_cur_left.unsqueeze(0))
            #     # mask_cur_left.extend([-1e+20 for _ in range(left_len_sub)])
            #     mask_cur_left.extend([0 for _ in range(left_len_sub)])
            #     # mask_cur_left: list (max_start)
            #     mask_left_save.append(mask_cur_left)
            # else:
            #     left_save.append(x_cur_left.unsqueeze(0))
            #     mask_left_save.append(mask_cur_left)
            #
            # end_cur = utils.to_scalar(target_end[idx])
            # right_len_cur = x_len_cur - end_cur - 1
            # right_len_max = x_len_cur - min_end - 1
            # if (end_cur + 1) != x_len_cur:
            #     x_cur_right = x[idx][(end_cur + 1):]
            #     right_len_sub = right_len_max - right_len_cur
            #     mask_cur_right = [1 for _ in range(right_len_cur)]
            # else:
            #     x_cur_right = x[idx][end_cur].unsqueeze(0)
            #     right_len_sub = right_len_max - right_len_cur - 1
            #     # mask_cur_right = [-1e+20]
            #     mask_cur_right = [0]
            # # x_cur_right: variable ((x_len_cur-end_cur-1), two_hidden_size)
            # # mask_cur_right = [1 for _ in range(right_len_cur)]
            # # mask_cur_right: list (x_len_cur-end_cur-1==right_len)
            # if end_cur > min_end:
            #     add = Variable(torch.zeros(right_len_sub, self.lstm_hiddens))
            #     if self.use_cuda: add = add.cuda()
            #     x_cur_right = torch.cat([x_cur_right, add], 0)
            #     right_save.append(x_cur_right.unsqueeze(0))
            #     # mask_cur_right.extend([-1e+20 for _ in range(right_len_sub)])
            #     mask_cur_right.extend([0 for _ in range(right_len_sub)])
            #     mask_right_save.append(mask_cur_right)
            # else:
            #     right_save.append(x_cur_right.unsqueeze(0))
            #     mask_right_save.append(mask_cur_right)

        # mask_left_save = Variable(torch.ByteTensor(mask_left_save))
        # # mask_left_save: variable (batch_size, left_len_max)
        # mask_right_save = Variable(torch.ByteTensor(mask_right_save))
        # # mask_right_save: variable (batch_size, right_len_max)
        # left_save = torch.cat(left_save, 0)
        # right_save = torch.cat(right_save, 0)
        target_save = torch.cat(target_save, 0)
        # print(none_target)
        none_target = torch.cat(none_target, 0)
        mask_none_target = Variable(torch.ByteTensor(mask_none_target))
        # left_save: variable (batch_size, left_len_max, two_hidden_size)
        # right_save: variable (batch_size, right_len_max, two_hidden_size)
        # target_save: variable (batch_size, two_hidden_size)
        if self.use_cuda:
            # mask_right_save = mask_right_save.cuda()
            # mask_left_save = mask_left_save.cuda()
            # left_save = left_save.cuda()
            # right_save = right_save.cuda()
            target_save = target_save.cuda()
            mask_none_target = mask_none_target.cuda()
            none_target = none_target.cuda()

        # squence = torch.cat(none_target, 1)
        s = self.attention(none_target, target_save, mask_none_target)
        # s = self.attention(x, target_save, None)
        # s_l = self.attention_l(left_save, target_save, mask_left_save)
        # s_r = self.attention_r(right_save, target_save, mask_right_save)

        result = self.linear(s)  # result: variable (batch_size, label_num)
        # result = self.linear_l(s_l)
        # result = torch.add(result, self.linear_l(s_l))
        # result = torch.add(result, self.linear_r(s_r))
        # result: variable (batch_size, label_num)
        # print(result)
        return result
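
The attention modules themselves (self.attention, self.attention_l, self.attention_r) are not shown in these snippets. As a hedged sketch of what such a target-conditioned, masked attention typically looks like given the padded contexts and 0/1 masks built above: score each context position against the target, push masked positions to -inf before the softmax, and return the weighted sum. Every mask row must keep at least one valid position, which the padding loops above guarantee via the placeholder rows.

import torch
import torch.nn.functional as F

def masked_attention(context, target, mask):
    # context: (batch, len, hidden); target: (batch, hidden); mask: (batch, len), True = keep
    scores = torch.bmm(context, target.unsqueeze(2)).squeeze(2)   # (batch, len)
    scores = scores.masked_fill(~mask, float('-inf'))             # ignore padded positions
    alpha = F.softmax(scores, dim=1)
    return torch.bmm(alpha.unsqueeze(1), context).squeeze(1)      # (batch, hidden)

context = torch.randn(2, 4, 6)
target = torch.randn(2, 6)
mask = torch.tensor([[True, True, False, False],
                     [True, True, True, True]])
print(masked_attention(context, target, mask).shape)              # torch.Size([2, 6])
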