Example No. 1
import logging
import os
from datetime import datetime

from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

logger = logging.getLogger(__name__)

# `load_and_cache_examples`, `calc_mi`, `calc_au`, `calc_iwnll`, and the table
# client `ts` are assumed to be provided by the surrounding project.
def evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, prefix="", subset="test"):
    # Evaluate the VAE's language-modeling metrics on the requested data split.
    eval_output_dir = args.output_dir

    if subset == 'test':
        eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
    elif subset == 'train':
        eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
    else:
        raise ValueError("subset must be 'train' or 'test', got %r" % subset)
    logger.info("***** Running evaluation on {} dataset *****".format(subset))

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    # Evaluation is forced to one example per GPU at a time.
    args.per_gpu_eval_batch_size = 1
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)


    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    
    model_vae.eval()

    model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae  # unwrap the distributed/parallel training wrapper

    # Mutual information I(x; z), number of active latent units (dimensions
    # whose posterior-mean variance exceeds `delta`), and importance-weighted
    # estimates of perplexity/ELBO/NLL/KL using 100 samples.
    mi = calc_mi(model_vae, eval_dataloader, args)
    au = calc_au(model_vae, eval_dataloader, delta=0.01, args=args)[0]
    ppl, elbo, nll, kl = calc_iwnll(model_vae, eval_dataloader, args, ns=100)

    result = {
        "perplexity": ppl, "elbo": elbo, "kl": kl, "nll": nll, "au": au, "mi": mi
    }

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))


    row = {
        'PartitionKey': 'MILU_Rule_Rule_Template',
        'RowKey': str(datetime.now()),
        'ExpName': args.ExpName,
        'test_perplexity': str(ppl),
        'test_elbo': str(elbo),
        'test_nll': str(nll),
        'test_au': str(au),
        'test_mi': str(mi)
    }
    # record the results in the experiment table (`ts` is the project's table client)
    ts.insert_entity(table_name, row)


    return result
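
For context, a minimal sketch of how `evaluate` might be driven after training. The table name, the prefix string, and the final print are illustrative assumptions, not part of the original code:

# Hypothetical driver; all literal values below are illustrative assumptions.
results = evaluate(
    args, model_vae, encoder_tokenizer, decoder_tokenizer,
    table_name='EvalResults',   # assumed table name
    prefix='checkpoint-final',
    subset='test',
)
print('test ppl = %.2f, nll = %.2f' % (results['perplexity'], results['nll']))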
Example No. 2
import numpy as np

# `calc_mi` and the print-style `logging` function are assumed to be provided
# by the surrounding project at module scope.
def test(model, test_data_batch, mode, args, verbose=True):
    global logging

    report_kl_loss = report_rec_loss = report_loss = 0
    report_num_words = report_num_sents = 0
    # iterate over all batches in a random order
    for i in np.random.permutation(len(test_data_batch)):
        batch_data = test_data_batch[i]
        batch_size, sent_len = batch_data.size()

        # the start symbol is not predicted, so it does not count toward the word total
        report_num_words += (sent_len - 1) * batch_size
        report_num_sents += batch_size

        if args.iw_train_nsamples < 0:
            loss, loss_rc, loss_kl = model.loss(batch_data,
                                                args.beta,
                                                nsamples=args.nsamples)
        else:
            # importance-weighted loss; `ns` (samples evaluated per chunk) is
            # expected to be defined at module scope in the original code
            loss, loss_rc, loss_kl = model.loss_iw(
                batch_data, args.beta, nsamples=args.iw_train_nsamples, ns=ns)

        # sanity check: no gradients should be tracked during evaluation
        assert not loss_rc.requires_grad

        loss_rc = loss_rc.sum()
        loss_kl = loss_kl.sum()
        loss = loss.sum()

        report_rec_loss += loss_rc.item()
        report_kl_loss += loss_kl.item()
        report_loss += loss.item()

    mutual_info = calc_mi(model, test_data_batch)

    test_loss = report_loss / report_num_sents

    # per-sentence NLL and KL; perplexity is exp(total NLL / total word count)
    nll = (report_kl_loss + report_rec_loss) / report_num_sents
    kl = report_kl_loss / report_num_sents
    ppl = np.exp(nll * report_num_sents / report_num_words)
    if verbose:
        logging('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \
               (mode, test_loss, report_kl_loss / report_num_sents, mutual_info,
                report_rec_loss / report_num_sents, nll, ppl))

    return test_loss, nll, kl, ppl, mutual_info
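
A minimal sketch of calling `test`, assuming a trained `model`, a list of padded `(batch_size, sent_len)` LongTensor batches, and an `args` namespace with the fields read above; all names and values here are illustrative assumptions:

from argparse import Namespace

# Illustrative arguments: `beta` weights the KL term, `nsamples` is the number
# of latent samples per example, and iw_train_nsamples < 0 selects the plain
# (non-importance-weighted) loss branch above.
args = Namespace(beta=1.0, nsamples=1, iw_train_nsamples=-1)

loss, nll, kl, ppl, mi = test(model, test_data_batch, mode='VAL', args=args)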