Example #1
def translate_beam_search(self, img):
    with torch.no_grad():
        memory = self.transformer(img)
        beam = Beam(beam_size=2,
                    min_length=0,
                    n_top=1,
                    ranker=None,
                    start_token_id=1,
                    end_token_id=2)
        for _ in range(128):
            tgt_inp = beam.get_current_state().transpose(0, 1).to(
                self.device)  # TxN
            decoder_outputs = self.transformer.transformer.forward_decoder(
                tgt_inp, memory)
            log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0),
                                   dim=-1)
            beam.advance(log_prob.cpu())
            if beam.done():
                break
        scores, ks = beam.sort_finished(minimum=1)
        hypothesises = []
        for times, k in ks:
            hypothesis = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)
        encode = [1] + [int(i) for i in hypothesises[0][:-1]]
        return self.vocab.decode(encode)
    def predict_one(self, source, num_candidates=5):
        source_preprocessed = self.preprocess(source)
        source_tensor = torch.tensor(source_preprocessed).unsqueeze(
            0)  # unsqueeze(0) adds the batch dimension: (seq_len,) -> (1, seq_len)
        length_tensor = torch.tensor(len(source_preprocessed)).unsqueeze(0)

        sources_mask = pad_masking(source_tensor, source_tensor.size(1))
        memory_mask = pad_masking(source_tensor, 1)
        memory = self.model.encoder(source_tensor, sources_mask)

        decoder_state = self.model.decoder.init_decoder_state()

        # Repeat beam_size times
        memory_beam = memory.detach().repeat(
            self.beam_size, 1, 1)  # (beam_size, seq_len, hidden_size)

        beam = Beam(beam_size=self.beam_size,
                    min_length=0,
                    n_top=num_candidates,
                    ranker=None)

        for _ in range(self.max_length):

            new_inputs = beam.get_current_state().unsqueeze(
                1)  # (beam_size, seq_len=1)
            decoder_outputs, decoder_state = self.model.decoder(
                new_inputs, memory_beam, memory_mask, state=decoder_state)
            # decoder_outputs: (beam_size, target_seq_len=1, vocabulary_size)
            # attentions['std']: (target_seq_len=1, beam_size, source_seq_len)

            attention = self.model.decoder.decoder_layers[
                -1].memory_attention_layer.sublayer.attention
            beam.advance(decoder_outputs.squeeze(1), attention)

            beam_current_origin = beam.get_current_origin()  # (beam_size, )
            decoder_state.beam_update(beam_current_origin)

            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=num_candidates)
        hypothesises, attentions = [], []
        for i, (times, k) in enumerate(ks[:num_candidates]):
            hypothesis, attention = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)
            attentions.append(attention)

        self.attentions = attentions
        self.hypothesises = [[token.item() for token in h]
                             for h in hypothesises]
        hs = [self.postprocess(h) for h in self.hypothesises]
        return list(reversed(hs))
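
Both methods above drive a snake_case Beam helper (get_current_state / advance / done / sort_finished / get_hypothesis / get_current_origin) that echoes OpenNMT-style beam-search helpers. The sketch below is a minimal illustration of that contract, not the library class: sort_finished, n-best bookkeeping and length penalties are omitted, and implementations differ in whether get_current_state returns only the newest token per beam (cached decoding, as in predict_one) or each beam's full prefix (cache-less decoding, as in translate_beam_search).

import torch

class MinimalBeam:
    """Minimal sketch of the Beam contract; illustrative only."""

    def __init__(self, beam_size, start_token_id=1, end_token_id=2):
        self.size = beam_size
        self.eos = end_token_id
        self.scores = torch.zeros(beam_size)  # cumulative log-prob per beam
        self.prev_ks = []                     # per step: parent beam of each pick
        self.next_ys = [torch.full((beam_size,), start_token_id, dtype=torch.long)]
        self._done = False

    def get_current_state(self):
        # Newest token on each beam, shape (beam_size,).
        return self.next_ys[-1]

    def get_current_origin(self):
        # Parent beam index of each surviving beam, shape (beam_size,).
        return self.prev_ks[-1]

    def advance(self, log_probs):
        # log_probs: (beam_size, vocab_size) log-probabilities of the new step.
        vocab = log_probs.size(1)
        if self.prev_ks:
            scores = log_probs + self.scores.unsqueeze(1)
        else:
            scores = log_probs[0:1]  # step 0: all beams are identical
        best_scores, best_ids = scores.view(-1).topk(self.size)
        self.scores = best_scores
        self.prev_ks.append(best_ids // vocab)  # which beam each pick extends
        self.next_ys.append(best_ids % vocab)   # which token was picked
        if self.next_ys[-1][0].item() == self.eos:
            self._done = True  # best-ranked beam just emitted EOS

    def done(self):
        return self._done

    def get_hypothesis(self, timestep, k):
        # Walk back-pointers from beam k at `timestep` to recover the tokens.
        hyp = []
        for j in range(timestep - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k].item())
            k = self.prev_ks[j][k].item()
        return hyp[::-1]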
Example #3
    def gen(self, event_ids, context_ids, prior, topk):
        preds, scores = [], []
        zero = torch.cuda.LongTensor(1).fill_(0)
        prob, topk_id = prior.topk(topk, -1)
        context_ids_all = context_ids
        for i in range(event_ids.shape[0]):
            beam = Beam(self.beam_size, self.sos_id, self.eos_id, prob[i])
            context_ids = context_ids_all[i:i + 1].repeat(topk, 1, 1)
            context_ids, _ = self.selecter(context_ids, topk_id[i])
            ############################################################################
            # concatenate evidence and event to obtain hidden states
            inputs_ids = torch.cat(
                (context_ids, event_ids[i:i + 1].repeat(topk, 1)), -1)
            transformer_outputs = self.decoder(inputs_ids)
            past_x = [
                x.repeat(1, self.beam_size, 1, 1, 1)
                for x in transformer_outputs[1]
            ]
            past_inputs = inputs_ids.repeat(self.beam_size, 1)
            ############################################################################

            # beam search
            input_ids = None
            for _ in range(self.max_length - 1):
                if beam.done():
                    break
                if input_ids is None:
                    input_ids = beam.getCurrentState()
                else:
                    input_ids = torch.cat((input_ids, beam.getCurrentState()),
                                          -1)

                target_ids = input_ids.unsqueeze(1).repeat(1, topk, 1).view(
                    -1, input_ids.shape[-1])
                transformer_outputs = self.decoder(target_ids, past=past_x)
                hidden_states = transformer_outputs[0]

                out = self.lsm(self.lm_head(hidden_states[:, -1, :])).data
                out = out.view(-1, topk, out.shape[-1])
                beam.advance(out)
                input_ids.data.copy_(
                    input_ids.data.index_select(0, beam.getCurrentOrigin()))
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:10]
            pred = [
                torch.cat([x.view(-1) for x in p] + [zero] *
                          (self.max_length - len(p))).view(1, -1) for p in pred
            ]
            preds.append(torch.cat(pred, 0).unsqueeze(0))
        preds = torch.cat(preds, 0)
        return preds
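
The pivotal line in this loop is input_ids.data.index_select(0, beam.getCurrentOrigin()): getCurrentOrigin() gives, for each surviving beam, the index of the parent beam it extends, and index_select realigns the accumulated token rows to those parents before the next step. A toy illustration of that semantics (values are made up):

import torch

input_ids = torch.tensor([[5, 7], [5, 9], [5, 3]])  # 3 beams, 2 tokens each
origin = torch.tensor([1, 1, 0])  # beams 0 and 1 both extend old beam 1
print(input_ids.index_select(0, origin))
# tensor([[5, 9],
#         [5, 9],
#         [5, 7]])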
Example #4
    def beam_decode(self, batch, max_len, oov_nums):

        bos_token = self.data_utils.bos
        beam_size = self.args.beam_size
        vocab_size = self.data_utils.vocab_size

        src = batch['src'].long()
        src_mask = batch['src_mask']
        src_extended = batch['src_extended'].long()
        memory = self.model.encode(src, src_mask)
        batch_size = src.size(0)

        beam = Beam(self.data_utils.pad, bos_token, self.data_utils.eos,
                    beam_size, batch_size, self.args.n_best, True, max_len)

        ys = torch.full((batch_size, 1), bos_token).type_as(src.data).cuda()
        log_prob = self.model.decode(
            memory, src_mask, Variable(ys),
            Variable(
                subsequent_mask(ys.size(1)).type_as(src.data).expand(
                    (ys.size(0), ys.size(1), ys.size(1)))), src_extended,
            oov_nums)

        # log_prob = [batch_size, 1, voc_size]
        top_prob, top_indices = torch.topk(input=log_prob, k=beam_size, dim=-1)
        top_prob = top_prob.view(-1, 1)
        top_indices = top_indices.view(-1, 1)
        beam.update_prob(top_prob.detach().cpu(), top_indices.detach().cpu())
        # [batch_size, 1, beam_size]
        ys = top_indices
        top_indices = None

        # Repeat encoder-side tensors once per beam:
        # [batch_size, src_len, d_model] -> [batch_size*beam_size, src_len, d_model]
        src = torch.repeat_interleave(src, beam_size, dim=0)
        src_mask = torch.repeat_interleave(src_mask, beam_size, dim=0)
        memory = torch.repeat_interleave(memory, beam_size, dim=0)
        for t in range(1, max_len):
            log_prob = self.model.decode(
                memory, src_mask, Variable(ys),
                Variable(
                    subsequent_mask(ys.size(1)).type_as(src.data).expand(
                        (ys.size(0), ys.size(1), ys.size(1)))), src)
            log_prob = log_prob[:, -1].unsqueeze(1)  # keep only the newest step
            real_top = beam.advance(log_prob.detach().cpu())
            ys = torch.cat((ys, real_top.view(-1, 1).cuda()), dim=-1)

        return [beam.seq[0]]
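
Note the choice of torch.repeat_interleave(..., dim=0) over Tensor.repeat: interleaving keeps the beam_size copies of each batch element adjacent, so row b * beam_size + j holds beam j of sample b, which is the layout the batched beam bookkeeping above assumes. A toy contrast:

import torch

memory = torch.tensor([[1.0], [2.0]])  # batch_size = 2
print(torch.repeat_interleave(memory, 3, dim=0).flatten())  # tensor([1., 1., 1., 2., 2., 2.])
print(memory.repeat(3, 1).flatten())                        # tensor([1., 2., 1., 2., 1., 2.])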
Example #5
def eval_bleu(args, model, tokenizer, file_type='test', num=99999999):
    dataset = CodeChangeDataset(tokenizer,
                                args,
                                logger,
                                file_type=file_type,
                                block_size=args.block_size,
                                mode='test')
    test_sampler = SequentialSampler(dataset)
    test_dataloader = DataLoader(dataset, sampler=test_sampler, batch_size=1)
    model.to(args.device)
    model.zero_grad()
    model.eval()
    preds = []
    for step, (batch, token_labels) in enumerate(
            tqdm(test_dataloader, total=min(num, len(dataset)))):
        if step >= num:
            break
        inputs = batch.to(args.device)
        with torch.no_grad():
            beam_size = args.beam_size
            m = torch.nn.LogSoftmax(dim=-1)
            outputs = model(inputs)[1]
            p = []
            zero = torch.cuda.LongTensor(1).fill_(0)
            for i in range(inputs.shape[0]):
                past_hidden = []
                for x in outputs:
                    _p = x[:, i:i + 1]
                    _q = _p.expand(-1, beam_size, -1, -1, -1)
                    past_hidden.append(_q)
                # context_mask=source_mask[i:i+1,:].expand(beam_size,-1)
                beam = Beam(beam_size, tokenizer.bos_token_id,
                            tokenizer.eos_token_id)
                input_ids = None
                for _ in range(162):
                    if beam.done():
                        break
                    input_ids = beam.getCurrentState()
                    transformer_outputs = model(input_ids, past=past_hidden)
                    out = m(transformer_outputs[0][:, -1, :]).data
                    beam.advance(out)
                    past_hidden = [
                        x.data.index_select(1, beam.getCurrentOrigin())
                        for x in transformer_outputs[1]
                    ]
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[:beam_size]

                pred = [
                    torch.cat([x.view(-1)
                               for x in p] + [zero] * (162 - len(p))).view(
                                   1, -1) for p in pred
                ]
                p.append(torch.cat(pred, 0).unsqueeze(0))
            p = torch.cat(p, 0)
            for pred in p:
                t = pred[0].cpu().numpy()
                t = list(t)
                if 0 in t:
                    t = t[:t.index(0)]
                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                preds.append(text)
    golds = []
    datas = read_data(data_dir=args.data_dir, file_type=file_type)
    for (src, tgt) in datas[:num]:
        golds.append(tgt)

    assert len(preds) == len(golds), 'Pred %d\tGold %d' % (len(preds),
                                                           len(golds))

    EM = []
    with open(os.path.join(args.output_dir, f"{file_type}.output"),
              'w',
              encoding='utf-8') as f, open(os.path.join(
                  args.output_dir, f"{file_type}.gold"),
                                           'w',
                                           encoding='utf-8') as f1:
        for pred, gold in zip(preds, golds):
            f.write(pred + '\n')
            f1.write(gold + '\n')
            EM.append(pred.split() == gold.split())

    bleu_score = round(
        _bleu(os.path.join(args.output_dir, f"{file_type}.gold"),
              os.path.join(args.output_dir, f"{file_type}.output")), 2)
    EM = round(np.mean(EM) * 100, 2)
    return bleu_score, EM
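
The expand/index_select pair in this example follows the legacy GPT-2 cache layout (transformers < 4), where each layer's past is one tensor of shape (2, batch, n_head, seq_len, head_dim) holding the stacked key and value, so the beam axis is dim 1. A sketch with invented sizes:

import torch

layer_past = torch.randn(2, 1, 12, 8, 64)     # cached keys/values for one sample
beams = layer_past.expand(-1, 5, -1, -1, -1)  # view it once per beam, no copy
origin = torch.tensor([0, 0, 2, 1, 4])        # parent beam of each survivor
beams = beams.index_select(1, origin)         # realign the cache to the survivors
print(beams.shape)                            # torch.Size([2, 5, 12, 8, 64])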
Example #6
def eval_bleu(args, model, tokenizer, file_type='test', num=20000):
    dataset = MethodDataset(tokenizer,
                            args,
                            file_type='test',
                            block_size=args.block_size,
                            mode='test')
    test_sampler = SequentialSampler(dataset)
    test_dataloader = DataLoader(dataset, sampler=test_sampler, batch_size=1)
    model.to(args.device)
    model.zero_grad()
    model.eval()

    preds = []
    for step, (batch, token_labels) in enumerate(test_dataloader):
        if step >= num:
            break
        inputs = batch.to(args.device)
        max_gen_len = min(256, args.block_size - inputs.shape[1] - 1)
        try:
            with torch.no_grad():
                beam_size = 5
                m = torch.nn.LogSoftmax(dim=-1)
                outputs = model(inputs, return_dict=True).past_key_values
                p = []
                zero = torch.cuda.LongTensor(1).fill_(0)
                for i in range(inputs.shape[0]):
                    past_hidden = tuple(
                        tuple(xx[i:i + 1, :].expand(beam_size, -1, -1, -1)
                              for xx in x) for x in outputs)
                    # past_hidden = [x[:, i:i+1].expand(-1, beam_size, -1, -1, -1) for x in outputs]
                    beam = Beam(beam_size, tokenizer.bos_token_id,
                                [tokenizer.eos_token_id])
                    input_ids = None
                    for _ in range(max_gen_len):
                        if beam.done():
                            break
                        input_ids = beam.getCurrentState()
                        transformer_outputs = model(
                            input_ids,
                            past_key_values=past_hidden,
                            return_dict=True)
                        out = m(transformer_outputs.logits[:, -1, :]).data
                        beam.advance(out)
                        past_hidden = tuple(
                            tuple(
                                xx.data.index_select(
                                    0, beam.getCurrentOrigin()) for xx in x)
                            for x in transformer_outputs.past_key_values)
                        # past_hidden = [x.data.index_select(1, beam.getCurrentOrigin()) for x in transformer_outputs[1]]
                    hyp = beam.getHyp(beam.getFinal())
                    pred = beam.buildTargetTokens(hyp)[:beam_size]

                    pred = [
                        torch.cat([x.view(-1) for x in p] + [zero] *
                                  (max_gen_len - len(p))).view(1, -1)
                        for p in pred
                    ]
                    p.append(torch.cat(pred, 0).unsqueeze(0))
                p = torch.cat(p, 0)
                for pred in p:
                    t = pred[0].cpu().numpy()
                    t = list(t)
                    if 0 in t:
                        t = t[:t.index(0)]
                    text = tokenizer.decode(
                        t, clean_up_tokenization_spaces=False).rstrip("</s>")
                    preds.append(text)
        except Exception:
            preds.append("")

        if step % args.logging_steps == 0:
            logger.info(f"{step} are done!")

    golds = []
    datafile = os.path.join(args.data_dir, f"{file_type}.jsonl")
    datas = open(datafile).readlines()
    for x in datas[:num]:
        x = json.loads(x)
        golds.append(x["body"])

    # assert len(preds) == len(golds)

    def post_process(code):
        code = code.replace("<EOL>",
                            "\n").replace("<INDENT>",
                                          " ").replace("<DEDENT>", " ")
        code = code.replace("<NUM_LIT>",
                            "0").replace("<STR_LIT>",
                                         "").replace("<CHAR_LIT>", "")
        pattern = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S)
        lits = re.findall(pattern, code)
        for lit in lits:
            code = code.replace(f"<{lit[0]}_LIT:{lit[1]}>", lit[1])
        return " ".join(code.split())

    ES = []
    with open(os.path.join(args.output_dir, f"{file_type}.output"),
              'w') as f, open(
                  os.path.join(args.output_dir, f"{file_type}.gold"),
                  'w') as f1:
        for pred, gold in zip(preds, golds):
            pred = post_process(pred)
            gold = post_process(gold)
            f.write(pred + '\n')
            f1.write(gold + '\n')
            ES.append(fuzz.ratio(pred, gold))

    bleu_score = round(
        _bleu(os.path.join(args.output_dir, f"{file_type}.gold"),
              os.path.join(args.output_dir, f"{file_type}.output")), 2)
    ES = round(np.mean(ES), 2)
    print(bleu_score, ES)
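
post_process resolves the dataset's literal placeholder tokens; the findall-and-replace loop is equivalent to a single re.sub with a back-reference. A standalone version, with a made-up input string:

import re

code = "x = <NUM_LIT:42> <EOL> return <STR_LIT:ok>"
code = code.replace("<EOL>", "\n")
code = re.sub(r"<(?:STR|NUM|CHAR)_LIT:(.*?)>", r"\1", code, flags=re.S)
print(" ".join(code.split()))  # x = 42 return ok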
Example #7
def eval_line_completion(args, model, tokenizer, file_type='test'):
    """
    Evaluate line level code completion on exact match and edit similarity.

    It is recommanded to use single GPU because it could not be batched.
    """
    def DecodeIds(idxs):
        codes = ""
        for idx in idxs:
            to_add = tokenizer.convert_ids_to_tokens(idx)
            if to_add[0] == '\u0120':
                if not codes.endswith(" "):
                    codes += " " + to_add[1:]
                else:
                    codes += to_add[1:]
            elif (idx in [
                    tokenizer.bos_token_id, tokenizer.eos_token_id,
                    tokenizer.sep_token_id, tokenizer.pad_token_id
            ] or to_add.startswith("<NUM_LIT")):
                codes += " " + to_add + " "
            else:
                codes += to_add
        return codes.strip(" ")
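    # Note: '\u0120' is 'Ġ', the prefix GPT-2-style byte-level BPE tokenizers
    # attach to tokens that follow a space; the branch above strips it and
    # re-inserts a literal space instead.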

    dataset = lineDataset(tokenizer,
                          args,
                          logger,
                          file_type=file_type,
                          block_size=args.block_size - 100)
    test_sampler = SequentialSampler(dataset)
    test_dataloader = DataLoader(dataset, sampler=test_sampler, batch_size=1)
    model.to(args.device)
    # model.zero_grad()
    model.eval()

    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(repackage_hidden(v) for v in h)

    if args.langs == "python":
        break_ids = [tokenizer.sep_token_id]
    else:
        break_ids = [
            tokenizer.convert_tokens_to_ids('Ġ;'),
            tokenizer.convert_tokens_to_ids('Ġ}'),
            tokenizer.convert_tokens_to_ids('Ġ{')
        ]
    preds = []
    gts = []
    edit_sim = 0.0
    em = 0.0
    for step, (inputs, gt) in enumerate(test_dataloader):
        inputs = inputs.to(args.device)
        with torch.no_grad():
            beam_size = 5
            m = torch.nn.LogSoftmax(dim=-1)
            outputs = model(inputs[:, :-1])[1]
            p = []
            zero = torch.cuda.LongTensor(1).fill_(0)
            for i in range(inputs.shape[0]):
                if args.model_type == "rnn":
                    past_hidden = tuple(
                        x[:, i:i + 1].expand(-1, beam_size, -1).contiguous()
                        for x in outputs)
                else:
                    past = [
                        torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)],
                                  dim=0) if type(x) == tuple else x
                        for x in outputs
                    ]
                    past_hidden = [
                        x[:, i:i + 1].expand(-1, beam_size, -1, -1, -1)
                        for x in past
                    ]
                beam = Beam(beam_size, inputs[i][-1].cpu().data, break_ids)
                input_ids = None
                for _ in range(100):
                    if beam.done():
                        break
                    input_ids = beam.getCurrentState()
                    if args.model_type == "rnn":
                        outputs = model(input_ids,
                                        hidden=repackage_hidden(past_hidden))
                    else:
                        outputs = model(input_ids, past_key_values=past_hidden)
                    out = m(outputs[0][:, -1, :]).data
                    beam.advance(out)
                    if args.model_type == "rnn":
                        past_hidden = tuple(
                            x.data.index_select(
                                1, beam.getCurrentOrigin()).contiguous()
                            for x in outputs[1])
                    else:
                        past = [
                            torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)],
                                      dim=0) if type(x) == tuple else x
                            for x in outputs[1]
                        ]
                        past_hidden = [
                            x.data.index_select(1, beam.getCurrentOrigin())
                            for x in past
                        ]
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[:beam_size]

                pred = [
                    torch.cat([x.view(-1)
                               for x in p] + [zero] * (100 - len(p))).view(
                                   1, -1) for p in pred
                ]
                p.append(torch.cat(pred, 0).unsqueeze(0))
            p = torch.cat(p, 0)
            for pred in p:
                t = pred[0].cpu().numpy()
                t = t.tolist()
                if 0 in t:
                    t = t[:t.index(0)]
                if args.langs == "python":
                    text = DecodeIds(t).strip("<EOL>").strip()
                else:
                    text = DecodeIds(t).strip("{").strip()
                preds.append(text)
                gts.append(gt[0])
                edit_sim += fuzz.ratio(text, gt[0])
                em += 1 if text == gt[0] else 0
        if step % args.logging_steps == 0:
            logger.info(f"{step} are done!")

    saved_file = os.path.join(args.output_dir, "predictions_line.txt")
    with open(saved_file, "w") as f:
        for pred_text in preds:
            f.write(pred_text + "\n")

    logger.info(f"Test {len(preds)} samples")
    logger.info(f"Edit sim: {edit_sim/len(preds)}, EM: {em/len(preds)}")
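
The '\u0120' convention that DecodeIds undoes can be checked directly, assuming a Hugging Face GPT-2-style tokenizer is available:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
print(tok.convert_ids_to_tokens(tok.encode("hello world")))
# ['hello', 'Ġworld']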
Example #8
def eval_bleu(args, model, tokenizer, file_type='test', num=2000):
    dataset = concodeDataset(tokenizer,
                             args,
                             logger,
                             file_type=file_type,
                             block_size=args.block_size,
                             mode='test')
    test_sampler = SequentialSampler(dataset)
    test_dataloader = DataLoader(dataset, sampler=test_sampler, batch_size=1)
    model.to(args.device)
    model.zero_grad()
    model.eval()

    preds = []
    max_gen_len = 100
    for step, (batch, token_labels) in enumerate(test_dataloader):
        if step >= num:
            break
        inputs = batch.to(args.device)
        # with torch.no_grad():
        #     outputs = model.generate(inputs, max_length=args.block_size, num_beams=10, temperature=0.7, early_stopping=False, top_k=70, \
        #               bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id)
        #     # outputs = model.generate(inputs, max_length=args.block_size, do_sample=True, temperature=0.7, top_k=70, top_p=0.95, \
        #     #         bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id)
        #     # outputs = model.generate(inputs, max_length=args.block_size, num_beams=10, temperature=0.7, early_stopping=False, top_k=70)
        #     # outputs = model.generate(inputs, max_length=args.block_size, do_sample=True, temperature=0.7, top_k=70, top_p=0.95)
        #     generation = tokenizer.decode(outputs[0])[len(tokenizer.decode(inputs[0])):]
        #     preds.append(generation.rstrip("<pad>"))

        with torch.no_grad():
            beam_size = 10
            m = torch.nn.LogSoftmax(dim=-1)
            outputs = model(inputs)[1]
            p = []
            zero = torch.cuda.LongTensor(1).fill_(0)
            for i in range(inputs.shape[0]):
                # Compatible with transformers version 3.3.0 and 4.13.0
                past = [
                    torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)], dim=0)
                    if type(x) == tuple else x for x in outputs
                ]
                past_hidden = [
                    x[:, i:i + 1].expand(-1, beam_size, -1, -1, -1)
                    for x in past
                ]
                # context_mask=source_mask[i:i+1,:].expand(beam_size,-1)
                beam = Beam(beam_size, tokenizer.bos_token_id,
                            tokenizer.eos_token_id)
                input_ids = None
                for _ in range(max_gen_len):
                    if beam.done():
                        break
                    input_ids = beam.getCurrentState()
                    # context_mask=torch.cat((context_mask,input_ids*0+1),-1)
                    # mask=context_mask.unsqueeze(0).unsqueeze(-2).unsqueeze(-2).expand(self.config.n_layer, -1, -1, -1, -1)
                    transformer_outputs = model(input_ids, past=past_hidden)
                    out = m(transformer_outputs[0][:, -1, :]).data
                    # out = self.lsm(self.lm_head(transformer_outputs[0][:,-1,:])).data
                    beam.advance(out)
                    past = [
                        torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)],
                                  dim=0) if type(x) == tuple else x
                        for x in transformer_outputs[1]
                    ]
                    past_hidden = [
                        x.data.index_select(1, beam.getCurrentOrigin())
                        for x in past
                    ]
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[:beam_size]

                pred = [
                    torch.cat([x.view(-1) for x in p] + [zero] *
                              (max_gen_len - len(p))).view(1, -1) for p in pred
                ]
                p.append(torch.cat(pred, 0).unsqueeze(0))
            p = torch.cat(p, 0)
            for pred in p:
                t = pred[0].cpu().numpy()
                t = list(t)
                if 0 in t:
                    t = t[:t.index(0)]
                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                preds.append(text)

        if step % args.logging_steps == 0:
            logger.info(f"{step} are done!")

    golds = []
    datafile = os.path.join(args.data_dir, f"{file_type}.json")
    datas = open(datafile).readlines()
    for x in datas[:num]:
        x = json.loads(x)
        golds.append(x["code"])

    assert len(preds) == len(golds)

    EM = []
    with open(os.path.join(args.output_dir, f"{file_type}.output"),
              'w') as f, open(
                  os.path.join(args.output_dir, f"{file_type}.gold"),
                  'w') as f1:
        for pred, gold in zip(preds, golds):
            f.write(pred + '\n')
            f1.write(gold + '\n')
            EM.append(pred.split() == gold.split())

    if file_type == "test":
        return 0, 0

    bleu_score = round(
        _bleu(os.path.join(args.output_dir, f"{file_type}.gold"),
              os.path.join(args.output_dir, f"{file_type}.output")), 2)
    EM = round(np.mean(EM) * 100, 2)
    return bleu_score, EM
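
The torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)], dim=0) expression appears twice above because transformers 4.x returns each layer's cache as a (key, value) tuple, while the 3.3.0-era code path expects the stacked (2, batch, n_head, seq_len, head_dim) tensor. A standalone version of that normalization (the helper name is mine):

import torch

def stack_past(layers):
    # torch.stack((k, v), dim=0) is equivalent to concatenating the
    # unsqueezed key and value along a new leading dimension.
    return [torch.stack((x[0], x[1]), dim=0) if isinstance(x, tuple) else x
            for x in layers]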