def criterion(logits, targets, sent_num, decoder):
    return sequence_loss(logits,
                         targets,
                         sent_num,
                         decoder,
                         ce,
                         pad_idx=-1)
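All of the criterion wrappers in these examples delegate to a sequence_loss helper. As a rough illustration only, here is a minimal sketch of what such a helper could look like, assuming it masks pad_idx positions and applies the supplied per-token loss function (ce, nll, or bce); the variants in these examples take extra arguments (sent_num, decoder, cls_pro, if_aux, fp16) that this sketch omits.

def sequence_loss_sketch(logits, targets, xent_fn, pad_idx=0):
    # logits: [batch, seq_len, vocab]; targets: [batch, seq_len]
    mask = targets != pad_idx
    # Keep only non-padding positions before computing the per-token loss.
    logit = logits.masked_select(mask.unsqueeze(2)).contiguous().view(-1, logits.size(-1))
    target = targets.masked_select(mask)
    loss = xent_fn(logit, target)   # unreduced: one value per surviving token
    return loss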
def criterion(logits, cls_pro, targets, word2id):
    return sequence_loss(logits,
                         cls_pro,
                         targets,
                         word2id,
                         True,
                         nll,
                         pad_idx=PAD)
def criterion(logit1_sent, logit_en, target_sent, target_en):
    # Sentence-level extraction loss (cross entropy, padding ignored).
    sent_loss = sequence_loss(logit1_sent, target_sent, ce, pad_idx=-1)
    # entity_loss = F.binary_cross_entropy_with_logits(logit_en, target_en)
    # Debug output for the entity head.
    print('logit_en:', logit_en)
    print('target_en:', target_en)
    # Entity loss: binary cross entropy over the padded sequence.
    entity_loss = binary_sequence_loss(logit_en, target_en, bce, pad_idx=-1)
    print('entity loss: {:.4f}'.format(entity_loss.mean().item()), end=' ')
    loss = sent_loss.mean() + entity_loss.mean()
    del entity_loss, sent_loss
    return loss
Example #4
def criterion(logits1, logits2, targets1, targets2):
    # Accumulate the auxiliary (binary) loss over every auxiliary logit.
    aux_loss = None
    for logit in logits2:
        if aux_loss is None:
            aux_loss = sequence_loss(logit,
                                     targets2,
                                     bce,
                                     pad_idx=-1,
                                     if_aux=True,
                                     fp16=False).mean()
        else:
            aux_loss += sequence_loss(logit,
                                      targets2,
                                      bce,
                                      pad_idx=-1,
                                      if_aux=True,
                                      fp16=False).mean()
    # Main negative log-likelihood loss plus the accumulated auxiliary loss.
    return (sequence_loss(logits1, targets1, nll,
                          pad_idx=PAD).mean(), aux_loss)
def criterion(logits, targets):
    return sequence_loss(logits, targets, ce, pad_idx=-1)
Example #6
def criterion(logits, targets):
    return sequence_loss(logits, targets, nll, pad_idx=PAD)
def criterion(logits, targets):
    return sequence_loss(logits, targets, ce, pad_idx=-1)
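For context, a hypothetical training-loop snippet showing how one of the two-argument criterion variants above would typically be consumed; model, loader, and opt are placeholder names, not identifiers from these examples.

for src, targets in loader:
    logits = model(src)
    loss = criterion(logits, targets)  # per-token losses from sequence_loss
    loss.mean().backward()             # reduce to a scalar before backprop
    opt.step()
    opt.zero_grad()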
Example #8
def main(args):
    print("\nParameters:")
    for attr, value in sorted(vars(args).items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Select which GPU to use
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_list
    args.cuda = torch.cuda.is_available() and not args.no_cuda

    # Output directory for models and summaries
    out_dir = os.path.join(args.log, args.exp_name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    print('Writing to {}\n'.format(out_dir))
    save_hparams(args, os.path.join(out_dir, 'hparams'))

    # Checkpoint directory
    checkpoint_dir = os.path.join(out_dir, 'checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Build dataset
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset begain... | %s " % time_str)

    test_seen_dataset = KGDataset(args.test_seen_file, max_knowledge=999)
    test_unseen_dataset = KGDataset(args.test_unseen_file, max_knowledge=999)

    test_seen_loader = get_batch_loader(test_seen_dataset,
                                        collate_fn=collate_fn,
                                        batch_size=args.eval_batch_size,
                                        is_test=True)
    test_unseen_loader = get_batch_loader(test_unseen_dataset,
                                          collate_fn=collate_fn,
                                          batch_size=args.eval_batch_size,
                                          is_test=True)

    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset end... | %s " % time_str)

    # Batcher
    dis_batcher = DisBatcher(args.bert_truncate, args.bert_config, args.cuda)
    gen_batcher = GenBatcher(args.knowledge_truncate, args.text_truncate,
                             args.gpt2_truncate, args.gpt2_config, args.cuda)

    # Load model
    dis_model = load_dis_net(args.emb_dim, args.lstm_hidden, args.lstm_layer,
                             args.bert_config, args.dis_pretrain_file,
                             args.load_dis, args.cuda)
    gen_model = load_gen_net(gen_batcher.tokenizer, args.segment,
                             args.gpt2_config, args.gen_pretrain_file,
                             args.load_gen, args.cuda)

    # Per-token cross entropy (no reduction), wrapped by sequence_loss.
    # reduction='none' replaces the deprecated reduce=False keyword.
    ce = lambda logit, target: F.cross_entropy(logit, target, reduction='none')
    gen_criterion = lambda logits, targets: sequence_loss(
        logits, targets, ce, pad_idx=-1)

    def dev_step(split, global_step):

        if split == 'test_seen':
            test_loader = test_seen_loader
        elif split == 'test_unseen':
            test_loader = test_unseen_loader
        else:
            raise ValueError

        dis_model.eval()
        gen_model.eval()

        n_token, test_loss = 0, 0.0  # ppl
        test_hyp, test_ref = [], []
        count = 0

        with torch.no_grad():
            for knowledges, histories, users, responses, knowledge_lens in test_loader:
                knowledges = [know.split('\n\n') for know in knowledges]
                histories = [his.split('\n\n') for his in histories]

                dis_args = dis_batcher(knowledges, histories, knowledge_lens,
                                       args.n_sent)
                dis_out = dis_model(*dis_args)
                dis_knowledges = [[knowledges[bi][dis_out[0][bi].item()]]
                                  for bi in range(len(knowledges))]

                gen_args = gen_batcher(dis_knowledges, histories, users,
                                       responses, args.segment, True)
                loss = gen_criterion(
                    gen_model(gen_args[0], token_type_ids=gen_args[1])[0],
                    gen_args[2])
                n_token += loss.size(0)
                test_loss += loss.sum().item()

                for bi in range(len(dis_knowledges)):
                    dec_in = gen_batcher(dis_knowledges[bi:bi + 1],
                                         histories[bi:bi + 1],
                                         users[bi:bi + 1],
                                         segment=args.segment,
                                         training=False)
                    dec_out = gen_model.batch_decode(
                        dec_in, args.max_length, args.min_length,
                        args.early_stopping, args.beam_size,
                        args.repetition_penalty, gen_batcher.eos_id,
                        args.length_penalty, args.no_repeat_ngram_size)
                    dec_out = dec_out[0].tolist()[dec_in.size(1):]
                    _hyp = gen_batcher.tokenizer.decode(
                        dec_out,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False)
                    _ref = responses[bi]
                    test_hyp.append(_hyp)
                    test_ref.append(_ref)

                    count += 1
                    if count % 1000 == 0:
                        print(count)

        with open(
                os.path.join(
                    out_dir,
                    '{}-decoded-iter-{}.txt'.format(split, global_step)),
                'w') as f:
            for _hyp, _ref in zip(test_hyp, test_ref):
                f.write('{} ||| {}\n'.format(_hyp, _ref))

        MeanLoss = test_loss / n_token
        b1, b2, b3, b4 = bleu_metric(test_hyp, test_ref)
        d1, d2 = distinct_metric(test_hyp)
        f1 = f1_metric(test_hyp, test_ref)

        time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("**********************************")
        print("{} results..........".format(split))
        print('hypothesis: ', len(test_hyp))
        print("Step: %d \t| ppl: %.3f \t|  %s" %
              (global_step, math.exp(MeanLoss), time_str))
        print("BLEU-1/2/3/4: {:.4f}/{:.4f}/{:.4f}/{:.4f}".format(
            b1, b2, b3, b4))
        print("Distinct-1/2: {:.4f}/{:.4f}".format(d1, d2))
        print("F1: {:.4f}".format(f1))
        print("**********************************")

        return {
            'f1': f1,
            'loss': MeanLoss,
            'bleu1': b1,
            'bleu2': b2,
            'bleu3': b3,
            'bleu4': b4,
            'distinct1': d1,
            'distinct2': d2
        }

    dev_step("test_seen", 0)  # test_random_split
    dev_step("test_unseen", 0)  # test_topic_split
def criterion(logits, targets):
    return sequence_loss(logits, targets, nll, pad_idx=PAD)