def eval(self, saving_path: str = None, mt_dev_iter: data_utils.DataLoader = None):
        """Score every src/tgt pair in ``mt_dev_iter`` with the model and write
        ``src<TAB>tgt<TAB>score`` lines to ``saving_path``.

        NOTE(review): ``saving_path`` is annotated ``str`` but defaults to
        ``None``; ``open(None, "w")`` would raise, so callers presumably always
        pass a real path — confirm before tightening the signature.
        """
        # Unwrap DataParallel/DistributedDataParallel if the model is wrapped.
        model = (
            self.model.module if hasattr(self.model, "module") else self.model
        )
        with open(saving_path, "w") as w:
            for i, batch in enumerate(mt_dev_iter):
                try:
                    with torch.no_grad():
                        # Batches arrive pre-collated with a leading batch dim
                        # of 1 (DataLoader batch_size=1) — hence squeeze(0).
                        src_inputs = batch["src_texts"].squeeze(0)
                        src_mask = batch["src_pad_mask"].squeeze(0)
                        tgt_inputs = batch["dst_texts"].squeeze(0)
                        tgt_mask = batch["dst_pad_mask"].squeeze(0)
                        src_langs = batch["src_langs"].squeeze(0)
                        dst_langs = batch["dst_langs"].squeeze(0)
                        # Skip tail batches smaller than the GPU count — they
                        # presumably cannot be scattered across devices.
                        if src_inputs.size(0) < self.num_gpu:
                            continue
                        sims = self.model(src_inputs=src_inputs, tgt_inputs=tgt_inputs,
                                          src_mask=src_mask, tgt_mask=tgt_mask, src_langs=src_langs,
                                          tgt_langs=dst_langs, normalize=False)
                        # Cut each sequence at the first sep/EOS token and drop
                        # the first token (remove_first_token=True) before decoding.
                        srcs = get_outputs_until_eos(model.text_processor.sep_token_id(), src_inputs,
                                                     remove_first_token=True)
                        targets = get_outputs_until_eos(model.text_processor.sep_token_id(), tgt_inputs,
                                                        remove_first_token=True)
                        src_txts = list(map(lambda src: model.text_processor.tokenizer.decode(src.numpy()), srcs))
                        target_txts = list(
                            map(lambda tgt: model.text_processor.tokenizer.decode(tgt.numpy()), targets))
                        # One output line per scored pair.
                        for s in range(len(sims)):
                            w.write(src_txts[s] + "\t" + target_txts[s] + "\t" + str(float(sims[s])) + "\n")
                        print(i, "/", len(mt_dev_iter), end="\r")


                except RuntimeError as err:
                    # Best-effort evaluation: on a runtime error (typically CUDA
                    # OOM) report it, free cached GPU memory, and continue.
                    print(repr(err))
                    torch.cuda.empty_cache()
            print("\n")
# Beispiel #2  (scrape artifact: stray example-separator text; commented out
# so the file remains valid Python)
# 0
def translate_batch(batch, generator, text_processor, verbose=False):
    """Decode one pre-collated batch with ``generator``.

    Returns a pair ``(mt_output, src_text)``: the decoded translations, and —
    only when ``verbose`` is set — the decoded source sentences (otherwise
    ``None``).
    """
    # Every field carries a leading batch dimension of one; strip it.
    src_inputs = batch["src_texts"].squeeze(0)
    src_mask = batch["src_pad_mask"].squeeze(0)
    tgt_inputs = batch["dst_texts"].squeeze(0)
    src_langs = batch["src_langs"].squeeze(0)
    dst_langs = batch["dst_langs"].squeeze(0)
    src_pad_idx = batch["pad_idx"].squeeze(0)

    src_text = None
    if verbose:
        # Cut sources at the sep/EOS token (dropping the first token) and
        # decode them for side-by-side inspection.
        src_ids = get_outputs_until_eos(text_processor.sep_token_id(),
                                        src_inputs,
                                        remove_first_token=True)
        src_text = [
            text_processor.tokenizer.decode(ids.numpy()) for ids in src_ids
        ]

    outputs = generator(src_inputs=src_inputs,
                        src_sizes=src_pad_idx,
                        first_tokens=tgt_inputs[:, 0],
                        src_mask=src_mask,
                        src_langs=src_langs,
                        tgt_langs=dst_langs,
                        pad_idx=text_processor.pad_token_id())
    if torch.cuda.device_count() > 1:
        # Multi-GPU generation yields one list per device; flatten them.
        outputs = [seq for per_gpu in outputs for seq in per_gpu]
    # Drop the first (start) token of each hypothesis before decoding.
    mt_output = [
        text_processor.tokenizer.decode(seq[1:].numpy()) for seq in outputs
    ]
    return mt_output, src_text
    def get_mt_dev_data(mt_model,
                        options,
                        pin_memory,
                        text_processor,
                        trainer,
                        lex_dict=None):
        """Build one DataLoader per comma-separated path in
        ``options.mt_dev_path`` and collect the decoded target sentences of
        every batch into ``trainer.reference``.

        Returns the list of DataLoaders.
        """
        dev_loaders = []
        trainer.reference = []
        for dev_path in options.mt_dev_path.split(","):
            dev_set = dataset.MTDataset(
                batch_pickle_dir=dev_path,
                max_batch_capacity=options.total_capacity,
                keep_pad_idx=True,
                # Beam search multiplies memory use, so shrink the batch cap.
                max_batch=int(options.batch / (options.beam_width * 2)),
                pad_idx=mt_model.text_processor.pad_token_id(),
                lex_dict=lex_dict)
            loader = data_utils.DataLoader(dev_set,
                                           batch_size=1,
                                           shuffle=False,
                                           pin_memory=pin_memory)
            dev_loaders.append(loader)

            print(options.local_rank, "creating reference")

            # Unwrap DataParallel if the generator is wrapped.
            generator = (trainer.generator.module if hasattr(
                trainer.generator, "module") else trainer.generator)

            for batch in loader:
                tgt_inputs = batch["dst_texts"].squeeze()
                refs = get_outputs_until_eos(text_processor.sep_token_id(),
                                             tgt_inputs,
                                             remove_first_token=True)
                trainer.reference += [
                    generator.seq2seq_model.text_processor.tokenizer.decode(
                        r.numpy()) for r in refs
                ]
        return dev_loaders
    def eval_bleu(self, dev_data_iter, saving_path, save_opt: bool = False):
        """Translate all dev batches, score them with corpus BLEU against
        ``self.reference``, dump outputs to disk, and checkpoint the model
        (and optionally the optimizer) whenever the score improves.

        Args:
            dev_data_iter: iterable of DataLoaders yielding pre-collated
                batches (dicts of tensors with a leading batch dim of 1).
            saving_path: directory for ``bleu.output`` / ``bleu.best.output``
                and model checkpoints.
            save_opt: when True, also pickle ``self.optimizer`` on a new best.

        Returns:
            The BLEU score (float) of this evaluation run.
        """
        mt_output = []
        src_text = []
        # Unwrap DataParallel/DistributedDataParallel if the model is wrapped.
        model = (self.model.module
                 if hasattr(self.model, "module") else self.model)
        model.eval()

        with torch.no_grad():
            for iter in dev_data_iter:
                for batch in iter:
                    # Strip the leading batch dimension added by DataLoader.
                    src_inputs = batch["src_texts"].squeeze(0)
                    src_mask = batch["src_pad_mask"].squeeze(0)
                    tgt_inputs = batch["dst_texts"].squeeze(0)
                    src_langs = batch["src_langs"].squeeze(0)
                    dst_langs = batch["dst_langs"].squeeze(0)
                    src_pad_idx = batch["pad_idx"].squeeze(0)
                    # Optional lexical proposals; None when no lex_dict was
                    # used to build the batches.
                    proposal = batch["proposal"].squeeze(
                        0) if batch["proposal"] is not None else None

                    # Decode sources (cut at sep/EOS, drop first token) for
                    # the side-by-side output file.
                    src_ids = get_outputs_until_eos(
                        model.text_processor.sep_token_id(),
                        src_inputs,
                        remove_first_token=True)
                    src_text += list(
                        map(
                            lambda src: model.text_processor.tokenizer.decode(
                                src.numpy()), src_ids))

                    outputs = self.generator(
                        src_inputs=src_inputs,
                        src_sizes=src_pad_idx,
                        first_tokens=tgt_inputs[:, 0],
                        src_mask=src_mask,
                        src_langs=src_langs,
                        tgt_langs=dst_langs,
                        pad_idx=model.text_processor.pad_token_id(),
                        proposals=proposal)
                    # NOTE(review): rank < 0 appears to mean a single-process
                    # multi-GPU run where the generator returns one list per
                    # GPU that must be flattened — confirm against the trainer.
                    if self.num_gpu > 1 and self.rank < 0:
                        new_outputs = []
                        for output in outputs:
                            new_outputs += output
                        outputs = new_outputs

                    # Drop the first (start) token of each hypothesis.
                    mt_output += list(
                        map(
                            lambda x: model.text_processor.tokenizer.decode(x[
                                1:].numpy()), outputs))

            model.train()
        bleu = sacrebleu.corpus_bleu(mt_output,
                                     [self.reference[:len(mt_output)]],
                                     lowercase=True,
                                     tokenize="intl")

        # Always dump the latest outputs for inspection.
        with open(os.path.join(saving_path, "bleu.output"), "w") as writer:
            writer.write("\n".join([
                src + "\n" + ref + "\n" + o + "\n\n***************\n"
                for src, ref, o in zip(src_text, mt_output,
                                       self.reference[:len(mt_output)])
            ]))

        if bleu.score > self.best_bleu:
            self.best_bleu = bleu.score
            print("Saving best BLEU", self.best_bleu)
            with open(os.path.join(saving_path, "bleu.best.output"),
                      "w") as writer:
                writer.write("\n".join([
                    src + "\n" + ref + "\n" + o + "\n\n***************\n"
                    for src, ref, o in zip(src_text, mt_output,
                                           self.reference[:len(mt_output)])
                ]))
            # rank < 0: single-process run — save from CPU, then move back to
            # the device; rank == 0: distributed, only the main process saves.
            if self.rank < 0:
                model.cpu().save(saving_path)
                model = model.to(self.device)
            elif self.rank == 0:
                model.save(saving_path)

            if save_opt:
                with open(os.path.join(saving_path, "optim"), "wb") as fp:
                    pickle.dump(self.optimizer, fp)

        return bleu.score
# Beispiel #5  (scrape artifact: stray example-separator text; commented out
# so the file remains valid Python)
# 0
    def train(options):
        """End-to-end training entry point for the image-captioning model.

        Builds (or loads) the model, optionally warm-starts encoder/decoder
        from a pretrained MT model, constructs train/dev loaders and
        reference captions/translations, then runs ``trainer.train_epoch``
        until ``options.step`` steps have been taken.

        NOTE(review): the visible code ends inside the training loop; any
        final checkpointing after the loop is outside this view.
        """
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        # Downstream code relies on pad id 0 (e.g. mask construction).
        assert text_processor.pad_token_id() == 0

        # Either resume from a pretrained captioning checkpoint or build fresh.
        if options.pretrained_path is not None:
            caption_model = Seq2Seq.load(ImageCaptioning,
                                         options.pretrained_path,
                                         tok_dir=options.tokenizer_path)
        else:
            caption_model = ImageCaptioning(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim,
                use_obj=not options.no_obj)

        if options.lm_path is not None:  # In our case, this is an MT model.
            # Warm-start encoder/decoder/output layer from the MT model;
            # layer counts must match for the swap to be valid.
            mt_pret_model = Seq2Seq.load(ImageMassSeq2Seq,
                                         options.lm_path,
                                         tok_dir=options.tokenizer_path)
            assert len(caption_model.encoder.encoder.layer) == len(
                mt_pret_model.encoder.encoder.layer)
            assert len(caption_model.decoder.decoder.layer) == len(
                mt_pret_model.decoder.decoder.layer)
            caption_model.encoder = mt_pret_model.encoder
            caption_model.decoder = mt_pret_model.decoder
            caption_model.output_layer = mt_pret_model.output_layer

        print("Model initialization done!")

        # We assume that the collator function returns a list with the size of
        # number of gpus (in case of cpus, a single entry — num_batches is
        # clamped to at least 1 below).
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        # Resuming reloads the pickled optimizer state from the pretrained dir.
        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(caption_model,
                                        options.learning_rate,
                                        warump_steps=options.warmup)
        trainer = ImageCaptionTrainer(
            model=caption_model,
            mask_prob=options.mask_prob,
            optimizer=optimizer,
            clip=options.clip,
            beam_width=options.beam_width,
            max_len_a=options.max_len_a,
            max_len_b=options.max_len_b,
            len_penalty_ratio=options.len_penalty_ratio,
            fp16=options.fp16,
            mm_mode=options.mm_mode)

        pin_memory = torch.cuda.is_available()
        # Shuffle only in single-process mode (local_rank < 0); distributed
        # runs presumably rely on their own sharding — confirm.
        img_train_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.train_path,
            caption_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict,
            shuffle=(options.local_rank < 0))
        num_processors = max(torch.cuda.device_count(),
                             1) if options.local_rank < 0 else 1
        mt_train_loader = None
        if options.mt_train_path is not None:
            mt_train_loader = ImageMTTrainer.get_mt_train_data(
                caption_model,
                num_processors,
                options,
                pin_memory,
                lex_dict=lex_dict)

        img_dev_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionTestDataset,
            options.dev_path,
            caption_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict,
            shuffle=False,
            denom=2)

        # Pre-decode all dev captions once into trainer.caption_reference
        # (image id -> list of reference captions) for evaluation.
        trainer.caption_reference = None
        if img_dev_loader is not None:
            trainer.caption_reference = defaultdict(list)
            generator = (trainer.generator.module if hasattr(
                trainer.generator, "module") else trainer.generator)
            for data in img_dev_loader:
                for batch in data:
                    for b in batch:
                        captions = b["captions"]
                        for id in captions:
                            for caption in captions[id]:
                                refs = get_outputs_until_eos(
                                    text_processor.sep_token_id(),
                                    caption,
                                    remove_first_token=True)
                                ref = [
                                    generator.seq2seq_model.text_processor.
                                    tokenizer.decode(ref.numpy())
                                    for ref in refs
                                ]
                                trainer.caption_reference[id] += ref
            print("Number of dev image/captions",
                  len(trainer.caption_reference))

        mt_dev_loader = None
        if options.mt_dev_path is not None:
            # Also populates trainer.reference with decoded target sentences.
            mt_dev_loader = ImageMTTrainer.get_mt_dev_data(caption_model,
                                                           options,
                                                           pin_memory,
                                                           text_processor,
                                                           trainer,
                                                           lex_dict=lex_dict)
            print("Number of dev sentences", len(trainer.reference))

        # Train epoch by epoch until the global step budget is exhausted.
        step, train_epoch = 0, 1
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       img_dev_data_iter=img_dev_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       mt_train_iter=mt_train_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       accum=options.accum,
                                       mt_dev_iter=mt_dev_loader,
                                       mtl_weight=options.mtl_weight)
            train_epoch += 1