Example #1
# Assumes the repo-local concodeDataset class and a module-level logger are already imported.
def load_and_cache_examples(args, tokenizer, evaluate=False):
    dataset = concodeDataset(tokenizer,
                             args,
                             logger,
                             file_type='dev' if evaluate else 'train',
                             block_size=args.block_size)
    return dataset
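
For context, here is a minimal usage sketch on the training side. The RandomSampler choice and the args.batch_size name are assumptions for illustration; only load_and_cache_examples and the (batch, token_labels) pairs it yields come from the snippets on this page.

# Hypothetical usage sketch -- sampler and batch-size argument are assumptions.
from torch.utils.data import DataLoader, RandomSampler

train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
train_dataloader = DataLoader(train_dataset,
                              sampler=RandomSampler(train_dataset),
                              batch_size=args.batch_size)
for step, (batch, token_labels) in enumerate(train_dataloader):
    pass  # forward/backward pass over batch would go here
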
Example #2
# Assumes: os, json, numpy as np, torch, DataLoader and SequentialSampler from
# torch.utils.data, plus the repo-local concodeDataset, Beam, _bleu and logger helpers.
def eval_bleu(args, model, tokenizer, file_type='test', num=2000):
    dataset = concodeDataset(tokenizer,
                             args,
                             logger,
                             file_type=file_type,
                             block_size=args.block_size,
                             mode='test')
    test_sampler = SequentialSampler(dataset)
    test_dataloader = DataLoader(dataset, sampler=test_sampler, batch_size=1)
    model.to(args.device)
    model.zero_grad()
    model.eval()

    preds = []
    max_gen_len = 100
    for step, (batch, token_labels) in enumerate(test_dataloader):
        if step >= num:
            break
        inputs = batch.to(args.device)
        # Alternative kept from the original source: the library's built-in
        # generation instead of the manual beam search below.
        # with torch.no_grad():
        #     outputs = model.generate(inputs, max_length=args.block_size, num_beams=10,
        #                              temperature=0.7, early_stopping=False, top_k=70,
        #                              bos_token_id=tokenizer.bos_token_id,
        #                              eos_token_id=tokenizer.eos_token_id,
        #                              pad_token_id=tokenizer.pad_token_id)
        #     generation = tokenizer.decode(outputs[0])[len(tokenizer.decode(inputs[0])):]
        #     preds.append(generation.rstrip("<pad>"))

        with torch.no_grad():
            beam_size = 10
            m = torch.nn.LogSoftmax(dim=-1)
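            # model(inputs)[1] below is the per-layer key/value cache ("past"),
            # computed once over the full context and reused during beam decoding.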
            outputs = model(inputs)[1]
            p = []
            # Pad token id 0 on the model's device; the original
            # torch.cuda.LongTensor(1).fill_(0) fails on CPU-only runs.
            zero = torch.zeros(1, dtype=torch.long, device=args.device)
            for i in range(inputs.shape[0]):
                # Compatible with transformers version 3.3.0 and 4.13.0
                past = [
                    torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)], dim=0)
                    if type(x) == tuple else x for x in outputs
                ]
                past_hidden = [
                    x[:, i:i + 1].expand(-1, beam_size, -1, -1, -1)
                    for x in past
                ]
                # context_mask=source_mask[i:i+1,:].expand(beam_size,-1)
                beam = Beam(beam_size, tokenizer.bos_token_id,
                            tokenizer.eos_token_id)
                input_ids = None
                for _ in range(max_gen_len):
                    if beam.done():
                        break
                    input_ids = beam.getCurrentState()
                    # context_mask=torch.cat((context_mask,input_ids*0+1),-1)
                    # mask=context_mask.unsqueeze(0).unsqueeze(-2).unsqueeze(-2).expand(self.config.n_layer, -1, -1, -1, -1)
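                    # The past kwarg matches the transformers 3.x API; newer
                    # versions (4.x) call the same argument past_key_values.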
                    transformer_outputs = model(input_ids, past=past_hidden)
                    out = m(transformer_outputs[0][:, -1, :]).data
                    # out = self.lsm(self.lm_head(transformer_outputs[0][:,-1,:])).data
                    beam.advance(out)
                    past = [
                        torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)],
                                  dim=0) if type(x) == tuple else x
                        for x in transformer_outputs[1]
                    ]
                    past_hidden = [
                        x.data.index_select(1, beam.getCurrentOrigin())
                        for x in past
                    ]
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[:beam_size]

                # Pad every hypothesis with the 0 token up to max_gen_len, then
                # stack the beam_size hypotheses into one (beam_size, max_gen_len) tensor.
                pred = [
                    torch.cat([x.view(-1) for x in p] + [zero] *
                              (max_gen_len - len(p))).view(1, -1) for p in pred
                ]
                p.append(torch.cat(pred, 0).unsqueeze(0))
            p = torch.cat(p, 0)
            # Keep only the top-ranked hypothesis per example and truncate at
            # the first 0 (pad) token before decoding.
            for pred in p:
                t = pred[0].cpu().numpy()
                t = list(t)
                if 0 in t:
                    t = t[:t.index(0)]
                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                # print(text)
                preds.append(text)

        if step % args.logging_steps == 0:
            logger.info(f"{step} examples done")

    golds = []
    datafile = os.path.join(args.data_dir, f"{file_type}.json")
    with open(datafile) as f:
        lines = f.readlines()
    for x in lines[:num]:
        x = json.loads(x)
        golds.append(x["code"])

    assert len(preds) == len(golds)

    EM = []
    with open(os.path.join(args.output_dir, f"{file_type}.output"),
              'w') as f, open(
                  os.path.join(args.output_dir, f"{file_type}.gold"),
                  'w') as f1:
        for pred, gold in zip(preds, golds):
            f.write(pred + '\n')
            f1.write(gold + '\n')
            EM.append(pred.split() == gold.split())

    if file_type == "test":
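        # The test split is not scored locally; the prediction and gold files
        # above are written for external evaluation.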
        return 0, 0

    bleu_score = round(
        _bleu(os.path.join(args.output_dir, f"{file_type}.gold"),
              os.path.join(args.output_dir, f"{file_type}.output")), 2)
    EM = round(np.mean(EM) * 100, 2)
    return bleu_score, EM
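
The Beam helper above is repo-local; judging from its call sites, its interface includes getCurrentState(), getCurrentOrigin(), advance(log_probs), done(), getFinal(), getHyp(), and buildTargetTokens(). Below is a minimal driver sketch; the GPT-2 classes and the args.model_name_or_path name are assumptions for illustration, not part of the listing.

# Hypothetical evaluation driver -- the model/tokenizer classes and the
# model_name_or_path argument are assumptions, not from the original source.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
dev_bleu, dev_em = eval_bleu(args, model, tokenizer, file_type='dev', num=2000)
logger.info(f"dev bleu: {dev_bleu}, dev EM: {dev_em}")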