Example #1
import numpy as np
import torch
from tqdm import tqdm


def train(model, dataloader, optimizer, dcmn_scheduler, loss_fun, seq2seq,
          seq_optimizer, seq_scheduler, seq_loss_fun, epoch, config):
    model.train()
    seq2seq.train()
    tr_dcmn_loss = 0
    tr_seq_loss = 0
    nb_steps = 0
    nb_dcmn_steps = 0
    for step, (seq_batches, dcmn_batches) in enumerate(
            tqdm(dataloader, desc="Iteration", ncols=200)):
        # transpose the batch: collect field j across all examples
        seq_srcs, seq_tars, k_cs = [[ex[j] for ex in seq_batches]
                                    for j in range(3)]

        outs = []
        if len(dcmn_batches) > 0:
            for p in range(0, len(dcmn_batches), config.batch_size):
                dcmn_batches_smaller = dcmn_batches[p:p + config.batch_size]
                input_ids, input_mask, segment_ids, doc_len, ques_len, option_len, labels = [
                    torch.LongTensor([ex[j] for ex in dcmn_batches_smaller]).to(
                        config.dcmn_device)
                    for j in range(7)
                ]
                # after num_dcmn_epochs the DCMN is frozen (eval mode, no
                # gradients); only the seq2seq keeps training below
                if epoch >= config.num_dcmn_epochs:
                    model.eval()
                    with torch.no_grad():
                        outputs = model(input_ids, segment_ids, input_mask,
                                        doc_len, ques_len, option_len)
                        loss = loss_fun(outputs, labels)
                        tr_dcmn_loss += loss.item()
                else:
                    model.train()
                    outputs = model(input_ids, segment_ids, input_mask,
                                    doc_len, ques_len, option_len)
                    loss = loss_fun(outputs, labels)
                    tr_dcmn_loss += loss.item()
                    loss.backward()
                    optimizer.step()
                    dcmn_scheduler.step()
                    optimizer.zero_grad()
                nb_dcmn_steps += 1
                outs_smaller = np.argmax(outputs.detach().cpu().numpy(),
                                         axis=1)
                outs.extend(outs_smaller)

        if epoch < config.num_dcmn_epochs:
            # DCMN warm-up epochs: skip seq2seq training entirely
            continue
        # outs = [0 for _ in range(len(seq_batches))]

        seq_srcs = remove_unk(seq_srcs, outs, k_cs)
        src_ids, src_masks = seq_tokenize(seq_srcs, config)
        # bug fix: decoder targets must be tokenized from seq_tars, not seq_srcs
        tar_ids, tar_masks = seq_tokenize(seq_tars, config)
        decoder_outputs, decoder_hidden, ret_dict = seq2seq(
            [src_ids, src_masks], tar_ids, 0.5)
        target = tar_ids[:, 1:].reshape(-1)
        mask = tar_masks[:, 1:].reshape(-1).float()
        logit = torch.stack(decoder_outputs, 1).view(target.shape[0], -1)
        seq_loss = (seq_loss_fun(input=logit, target=target) *
                    mask).sum() / mask.sum()
        tr_seq_loss += seq_loss.item()
        seq_loss.backward()
        seq_optimizer.step()
        seq_scheduler.step()
        seq_optimizer.zero_grad()
        nb_steps += 1

        if step % 100 == 0:
            # `loss` is only bound on steps that contained DCMN batches
            dcmn_loss = loss.item() if len(dcmn_batches) > 0 else 0.0
            print('train loss:{},{}'.format(dcmn_loss, seq_loss.item()))

    if nb_steps == 0:
        nb_steps = 1
    if nb_dcmn_steps == 0:
        nb_dcmn_steps = 1
    return (tr_dcmn_loss / nb_dcmn_steps, tr_seq_loss / nb_steps,
            dcmn_scheduler.get_last_lr(), seq_scheduler.get_last_lr())
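For context, a minimal sketch of the outer loop that could drive train() above. build_dataset and build_dcmn are hypothetical builder names introduced here (only their eval-side counterparts, build_dataset_eval and build_seq2seq, appear in test.py below), and config.num_epochs is an assumed field:

def run_training(config):
    # hypothetical builders, mirroring build_dataset_eval/build_iterator below
    seq_dataset, dcmn_dataset = build_dataset(config)
    dataloader = build_iterator(seq_dataset, dcmn_dataset, config)
    seq2seq, seq_optimizer, seq_scheduler, seq_loss_fun = build_seq2seq(
        config, 768, config.no_cuda)
    dcmn, optimizer, dcmn_scheduler, loss_fun = build_dcmn(config)  # hypothetical
    for epoch in range(config.num_epochs):  # assumed config field
        dcmn_loss, seq_loss, dcmn_lr, seq_lr = train(
            dcmn, dataloader, optimizer, dcmn_scheduler, loss_fun, seq2seq,
            seq_optimizer, seq_scheduler, seq_loss_fun, epoch, config)
        print('epoch {}: dcmn_loss={:.4f}, seq_loss={:.4f}'.format(
            epoch, dcmn_loss, seq_loss))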
Example #2
File: test.py  Project: cscyuge/seq2seq
import pickle

import numpy as np
import torch
from tqdm import tqdm

# DCMN_Config, build_dataset_eval, build_iterator, build_seq2seq,
# BertForMultipleChoiceWithMatch, seq_tokenize, remove_unk, decode_sentence,
# and get_score come from the repo's own modules


def main():
    config = DCMN_Config()
    eval_seq_dataset, eval_dcmn_dataset = build_dataset_eval(config)
    eval_dataloader = build_iterator(eval_seq_dataset, eval_dcmn_dataset,
                                     config)
    seq2seq, seq_optimizer, seq_scheduler, seq_loss_fun = build_seq2seq(
        config, 768, config.no_cuda)
    dcmn = BertForMultipleChoiceWithMatch.from_pretrained(
        config.bert_model, num_choices=config.num_choices)
    dcmn.to(config.dcmn_device)

    # seq2seq weights come from the bert backup; note map_location is
    # hard-coded to cuda:2, where config.dcmn_device would be more portable
    save_file_best = torch.load('./backup/bert/best_save.data',
                                map_location=torch.device('cuda:2'))
    # dcmn.load_state_dict(save_file_best['dcmn_para'])
    seq2seq.load_state_dict(save_file_best['para'])

    # the DCMN weights come from a separate dcmn backup
    save_file_best = torch.load('./backup/dcmn/best_save.data',
                                map_location=torch.device('cuda:2'))
    dcmn.load_state_dict(save_file_best['dcmn_para'])

    dcmn.eval()
    seq2seq.eval()

    # src = '[CLS] [MASK] computed tomography [MASK] showed [MASK] patient [MASK] was [MASK] fine [MASK] . [SEP]'
    # src_ids, src_masks = seq_tokenize([src], config)
    # decoder_outputs, decoder_hidden, ret_dict = seq2seq([src_ids, src_masks], src_ids, 0.0, False)
    # symbols = ret_dict['sequence']
    # symbols = torch.cat(symbols, 1).data.cpu().numpy()
    # results = decode_sentence(symbols, config)
    # print(results)

    results = []
    seq_srcs_all = []
    for step, (seq_batches, dcmn_batches) in enumerate(
            tqdm(eval_dataloader, desc="Evaluating")):
        seq_srcs, seq_tars, cudics, k_cs = [[ex[j] for ex in seq_batches]
                                            for j in range(4)]
        outs = []

        if len(dcmn_batches) > 0:
            for p in range(0, len(dcmn_batches), config.batch_size):
                dcmn_batches_smaller = dcmn_batches[p:p + config.batch_size]
                input_ids, input_mask, segment_ids, doc_len, ques_len, option_len, labels = [
                    torch.LongTensor([ex[j] for ex in dcmn_batches_smaller]).to(
                        config.dcmn_device)
                    for j in range(7)
                ]

                with torch.no_grad():
                    logits = dcmn(input_ids, segment_ids, input_mask, doc_len,
                                  ques_len, option_len)
                    outs_smaller = np.argmax(logits.detach().cpu().numpy(),
                                             axis=1)
                    outs.extend(outs_smaller)

        seq_srcs = remove_unk(seq_srcs, outs, k_cs)
        seq_srcs_all.extend(seq_srcs)
        src_ids, src_masks = seq_tokenize(seq_srcs, config)
        decoder_outputs, decoder_hidden, ret_dict = seq2seq(
            [src_ids, src_masks], src_ids, 0.0, False)

        symbols = ret_dict['sequence']
        symbols = torch.cat(symbols, 1).data.cpu().numpy()
        results.extend(decode_sentence(symbols, config))
    with open('./outs/outs-new.pkl', 'wb') as f:
        pickle.dump(results, f)

    sentences = []
    for words in results:
        # drop mask placeholders and re-join split punctuation
        # (see the helper sketched after this function)
        words = words.replace('[MASK] ', '')
        words = words.replace(' - ', '-').replace(' . ', '.').replace(' / ', '/')
        sentences.append(words.strip())

    with open('./result/tmp.out.txt', 'w', encoding='utf-8') as f:
        f.writelines([x.lower() + '\n' for x in sentences])
    bleu, hit, com, ascore = get_score()
    print('bleu:{}, hit:{}, com:{}, ascore:{}'.format(bleu, hit, com, ascore))
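The cleanup loop above can be factored into a small helper for reuse and testing. A sketch of the same logic (clean_output is a name introduced here for illustration):

def clean_output(words):
    # drop '[MASK] ' placeholders, then re-join punctuation that
    # detokenization left space-separated
    words = words.replace('[MASK] ', '')
    words = words.replace(' - ', '-').replace(' . ', '.').replace(' / ', '/')
    return words.strip()

# clean_output('[MASK] computed tomography [MASK] showed a mass . ')
# -> 'computed tomography showed a mass.'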
Example #3
def valid(dcmn, dataloader, loss_fun, seq2seq, epoch, config, is_val=True):
    dcmn.eval()
    seq2seq.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    results = []
    seq_srcs_all = []
    for step, (seq_batches, dcmn_batches) in enumerate(
            tqdm(dataloader, desc="Evaluating", ncols=200)):
        seq_srcs, seq_tars, cudics, k_cs = [[ex[j] for ex in seq_batches]
                                            for j in range(4)]
        outs = []

        if len(dcmn_batches) > 0:
            for p in range(0, len(dcmn_batches), config.eval_batch_size):
                dcmn_batches_smaller = dcmn_batches[p:p +
                                                    config.eval_batch_size]
                input_ids, input_mask, segment_ids, doc_len, ques_len, option_len, labels = [
                    torch.LongTensor([ex[j] for ex in dcmn_batches_smaller]).to(
                        config.dcmn_device)
                    for j in range(7)
                ]

                with torch.no_grad():
                    logits = dcmn(input_ids, segment_ids, input_mask, doc_len,
                                  ques_len, option_len)
                    tmp_eval_loss = loss_fun(logits, labels)
                    labels = labels.to('cpu').numpy()
                    logits_np = logits.detach().cpu().numpy()
                    tmp_eval_accuracy = accuracy(logits_np, labels)
                    outs_smaller = np.argmax(logits_np, axis=1)
                    outs.extend(outs_smaller)
                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy
                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

        seq_srcs = remove_unk(seq_srcs, outs, k_cs)
        seq_srcs_all.extend(seq_srcs)
        src_ids, src_masks = seq_tokenize(seq_srcs, config)
        decoder_outputs, decoder_hidden, ret_dict = seq2seq(
            [src_ids, src_masks], src_ids, 0.0, False)

        symbols = ret_dict['sequence']
        symbols = torch.cat(symbols, 1).data.cpu().numpy()
        results.extend(decode_sentence(symbols, config))

    # strip mask placeholders (with and without a trailing space)
    sentences = [u.replace('[MASK] ', '').replace('[MASK]', '')
                 for u in results]

    with open('./result/tmp.out.txt', 'w', encoding='utf-8') as f:
        f.writelines([x.lower() + '\n' for x in sentences])
    bleu, hit, com, ascore = get_score(config, is_val=is_val)

    if nb_eval_steps == 0:
        nb_eval_steps = 1
    if nb_eval_examples == 0:
        nb_eval_examples = 1
    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples

    return eval_loss, eval_accuracy, sentences, bleu, hit, com, ascore
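accuracy() is accumulated per batch and divided by nb_eval_examples at the end, so it presumably returns a raw count of correct predictions rather than a ratio. A minimal sketch under that assumption:

import numpy as np

def accuracy(out, labels):
    # out: (batch, num_choices) logits as a numpy array; labels: (batch,)
    preds = np.argmax(out, axis=1)
    return int(np.sum(preds == labels))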
Example #4
File: train.py  Project: cscyuge/seq2seq
def valid(dcmn, dataloader, loss_fun, seq2seq, epoch, config):
    dcmn.eval()
    seq2seq.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    results = []
    seq_srcs_all = []
    for step, (seq_batches,
               dcmn_batches) in enumerate(tqdm(dataloader, desc="Evaluating")):
        seq_srcs, seq_tars, cudics, k_cs = [[ex[j] for ex in seq_batches]
                                            for j in range(4)]
        outs = []

        if len(dcmn_batches) > 0:
            for p in range(0, len(dcmn_batches), config.batch_size):
                dcmn_batches_smaller = dcmn_batches[p:p + config.batch_size]
                input_ids, input_mask, segment_ids, doc_len, ques_len, option_len, labels = [
                    torch.LongTensor([ex[j] for ex in dcmn_batches_smaller]).to(
                        config.dcmn_device)
                    for j in range(7)
                ]

                with torch.no_grad():
                    logits = dcmn(input_ids, segment_ids, input_mask, doc_len,
                                  ques_len, option_len)
                    tmp_eval_loss = loss_fun(logits, labels)
                    labels = labels.to('cpu').numpy()
                    logits_np = logits.detach().cpu().numpy()
                    tmp_eval_accuracy = accuracy(logits_np, labels)
                    outs_smaller = np.argmax(logits_np, axis=1)
                    outs.extend(outs_smaller)
                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy
                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

        seq_srcs = remove_unk(seq_srcs, outs, k_cs)
        seq_srcs_all.extend(seq_srcs)
        src_ids, src_masks = seq_tokenize(seq_srcs, config)
        decoder_outputs, decoder_hidden, ret_dict = seq2seq(
            [src_ids, src_masks], None, 0.0)

        symbols = ret_dict['sequence']
        symbols = torch.cat(symbols, 1).data.cpu().numpy()
        results.extend(decode_sentence(symbols, config))

    with open('./outs/outs{}.pkl'.format(epoch), 'wb') as f:
        pickle.dump(results, f)

    sentences = []
    # splice the decoder's expansions back into the source: even-indexed
    # segments after splitting on [MASK] come from the source, odd-indexed
    # ones from the decoded output (a standalone sketch follows this function)
    for src, words in zip(seq_srcs_all, results):
        src = src.split('[MASK]')
        words = words.split('[MASK]')
        sts = ''
        for i, u in enumerate(src):
            if i % 2 == 1 and i < len(words):
                sts += words[i].strip() + ' '
            else:
                sts += u.strip() + ' '
        sts = sts.split('[CLS]')[1]
        sts = sts.split('[SEP]')[0]
        sentences.append(sts.strip())

    with open('./result/tmp.out.txt', 'w', encoding='utf-8') as f:
        f.writelines([x.lower() + '\n' for x in sentences])
    bleu, hit, com, ascore = get_score()
    # bleu, hit, com, ascore = 0,0,0,0
    # guard against empty loaders, as in the other valid() variant
    if nb_eval_steps == 0:
        nb_eval_steps = 1
    if nb_eval_examples == 0:
        nb_eval_examples = 1
    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples

    return eval_loss, eval_accuracy, results, bleu, hit, com, ascore
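The splice loop above is what distinguishes this variant. A self-contained sketch of the same logic (merge_masked is a name introduced here, with the [CLS]/[SEP] trimming made defensive against missing tokens):

def merge_masked(src, decoded):
    src_parts = src.split('[MASK]')
    dec_parts = decoded.split('[MASK]')
    sts = ''
    for i, part in enumerate(src_parts):
        # odd-indexed segments are the spans the decoder rewrote
        if i % 2 == 1 and i < len(dec_parts):
            sts += dec_parts[i].strip() + ' '
        else:
            sts += part.strip() + ' '
    if '[CLS]' in sts:
        sts = sts.split('[CLS]')[1]
    return sts.split('[SEP]')[0].strip()

# merge_masked('[CLS] the [MASK] ct [MASK] was normal [SEP]',
#              'the [MASK] computed tomography [MASK] was normal')
# -> 'the computed tomography was normal'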