예제 #1
0
    def __init__(self,
                 text_processor: TextProcessor,
                 config: BertConfig = None,
                 encoder: BertModel = None,
                 enc_layer: int = 6,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        """
        Assemble a language model from a BERT encoder and an output layer.

        :param text_processor: provides vocabulary size, special-token ids and
            the language list used to build the default configuration.
        :param config: ready-made ``BertConfig``. When ``None`` a configuration
            is built with ``lm_config.get_config`` from the arguments below.
        :param encoder: existing ``BertModel`` to reuse. When ``None`` a new
            one is created from the configuration and ``init_weights()`` is
            called on it.
        :param enc_layer: forwarded to ``lm_config.get_config``; ignored when
            ``config`` is given.
        :param embed_dim: forwarded to ``lm_config.get_config``; ignored when
            ``config`` is given.
        :param intermediate_dim: forwarded to ``lm_config.get_config``;
            ignored when ``config`` is given.
        """
        super(LM, self).__init__()
        self.text_processor: TextProcessor = text_processor

        if config is None:
            config_args = lm_config.get_config(
                vocab_size=text_processor.tokenizer.get_vocab_size(),
                pad_token_id=text_processor.pad_token_id(),
                bos_token_id=text_processor.bos_token_id(),
                eos_token_id=text_processor.sep_token_id(),
                enc_layer=enc_layer,
                embed_dim=embed_dim,
                intermediate_dim=intermediate_dim)
            # type_vocab_size is set to the number of known languages.
            config_args["type_vocab_size"] = len(text_processor.languages)
            config = BertConfig(**config_args)
        self.config = config

        # The output layer is created before the encoder, as in the original
        # construction order.
        self.masked_lm = BertOutputLayer(self.config)
        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder: BertModel = BertModel(self.config)
            self.encoder.init_weights()

        # Tie the output projection to the encoder's word embeddings.
        word_embeddings = self.encoder.embeddings.word_embeddings
        self.encoder._tie_or_clone_weights(self.masked_lm.decoder,
                                           word_embeddings)
def create_batches(sentences,
                   src2dst_dict,
                   text_processor: TextProcessor,
                   resume_index=0,
                   end_index=-1):
    """
    Lazily yield one tensor bundle per source id in ``src2dst_dict``.

    Entries are counted from 1 in the dictionary's iteration order. Entries
    whose count is ``<= resume_index`` are skipped, and iteration stops at the
    first count ``>= end_index`` when ``end_index`` is positive.

    Each yielded tuple is ``(sid, source_tokens, target_ids, candidates,
    src_lang, target_langs)``; ``candidates`` is the stack of tokenized target
    sentences padded with the processor's pad id. A sentence's language is
    looked up from its first space-separated token.

    :param sentences: indexable by the ids stored in ``src2dst_dict``.
    :param src2dst_dict: maps a source id to an iterable of target ids.
    :param text_processor: supplies ``pad_token_id()`` and ``lang_id()``.
    :param resume_index: number of leading entries to skip.
    :param end_index: exclusive upper bound on the entry count; non-positive
        values disable the bound.
    """
    print(len(src2dst_dict))

    print("Getting batches...")

    def _language_of(sentence_id):
        # The language marker is the first space-separated token.
        first_token = sentences[sentence_id].strip().split(" ")[0]
        return text_processor.lang_id(first_token)

    for position, sid in enumerate(src2dst_dict.keys(), start=1):
        if 0 < end_index <= position:
            break
        if position <= resume_index:
            continue

        target_ids = list(src2dst_dict[sid])
        source_tokens = torch.LongTensor(tok_sen(sentences[sid]))
        tokenized_targets = [
            torch.LongTensor(tok_sen(sentences[tid])) for tid in target_ids
        ]
        padded_targets = pad_sequence(
            tokenized_targets,
            batch_first=True,
            padding_value=text_processor.pad_token_id())
        target_langs = [_language_of(tid) for tid in target_ids]
        source_lang = torch.LongTensor([_language_of(sid)])

        yield (sid, source_tokens, torch.LongTensor(target_ids),
               padded_targets, source_lang, torch.LongTensor(target_langs))
예제 #3
0
    def test_albert_seq2seq_init(self):
        """
        Smoke-test ``Seq2Seq`` on a tiny padded two-sentence batch.

        A throw-away tokenizer is trained on ``sample.txt`` (located next to
        this file), then the model is called once with ``log_softmax=True``
        and once without; both outputs must have size ``[5, vocab_size]``.
        """
        sample_file = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "sample.txt")

        with tempfile.TemporaryDirectory() as workdir:
            processor = TextProcessor()
            processor.train_tokenizer([sample_file],
                                      vocab_size=1000,
                                      to_save_dir=workdir,
                                      languages={"<en>": 0, "<fa>": 1})
            seq2seq = Seq2Seq(text_processor=processor)

            pad = processor.pad_token_id()
            src_inputs = torch.tensor([[1, 2, 3, 4, 5, pad, pad],
                                       [1, 2, 3, 4, 5, 6, pad]])
            tgt_inputs = torch.tensor([[6, 8, 7, pad, pad],
                                       [6, 8, 7, 8, pad]])
            src_mask = src_inputs != pad
            tgt_mask = tgt_inputs != pad
            src_langs = torch.tensor([[0], [0]]).squeeze()
            tgt_langs = torch.tensor([[1], [1]]).squeeze()

            # First with log-softmax output, then with the default output.
            # NOTE(review): 5 presumably counts the non-pad target tokens
            # after the first position of each row (2 + 3) -- confirm against
            # Seq2Seq.forward.
            for extra_kwargs in ({"log_softmax": True}, {}):
                seq_output = seq2seq(src_inputs, tgt_inputs, src_mask,
                                     tgt_mask, src_langs, tgt_langs,
                                     **extra_kwargs)
                assert list(
                    seq_output.size()) == [5, processor.vocab_size()]
예제 #4
0
def mass_mask(mask_prob, pad_indices, src_text,
              text_processor: "TextProcessor") -> Dict:
    """
    Mask one contiguous span in every row of ``src_text``.

    For each row, ``last_idx = ceil(mask_prob * pad_indices[row])`` and the
    span start is chosen as follows:
        20% of times, start at index 1
        20% of times, start at ``last_idx``
        60% of times, start at a random index in ``[2, last_idx]``
        (index 2 when ``last_idx < 2``)
    The span covers ``int(pad_indices[row] / 2)`` positions and is clipped to
    the row length, so the returned positions always index valid columns and
    line up one-to-one with the recovered tokens.

    Every masked token is replaced by the mask id (80%), by a random
    non-special vocabulary id (10%), or left unchanged (10%).

    ``src_text`` is modified in place and also returned.

    :param mask_prob: value strictly between 0 and 1.
    :param pad_indices: one value per row of ``src_text``; used to size and
        place the span. Presumably the index of the first pad token in the
        row -- confirm with the caller.
    :param src_text: 2-D integer tensor of token ids.
    :param text_processor: supplies ``pad_token_id()``, ``mask_token_id()``,
        ``vocab_size()`` and ``special_tokens``.
    :return: dict with keys ``src_mask`` (bool mask of replaced positions),
        ``targets`` and ``mask_idx`` (original ids at the masked positions),
        ``src_text`` (the mutated input), ``to_recover`` (padded tokens from
        one position before the span to its end) and ``positions`` (padded
        column indices matching ``to_recover``).
    :raises AssertionError: if ``mask_prob`` is not in the open interval
        (0, 1).
    """
    # Validate before doing any work; the check used to run only after the
    # span loop had already consumed random numbers.
    assert 0 < mask_prob < 1

    seq_len = int(src_text.size(-1))
    index_range = pad_indices - (1 - mask_prob) * pad_indices
    src_mask = torch.zeros(src_text.size(), dtype=torch.bool)
    to_recover = []
    to_recover_pos = []
    for i, irange in enumerate(index_range):
        range_size = int(pad_indices[i] / 2)
        r = random.random()
        last_idx = int(math.ceil(irange))
        if r > 0.8:
            start = 1
        elif r > 0.6:
            start = last_idx
        else:
            start = random.randint(2, last_idx) if last_idx >= 2 else 2

        # Keep the span inside the row. Without clipping, the token slice is
        # silently truncated by tensor slicing while torch.arange is not, so
        # tokens and positions disagree in length and positions can point
        # past the last column (or at -1 when start is 0).
        start = max(1, min(start, seq_len))
        end = min(start + range_size, seq_len)

        src_mask[i, start:end] = True
        to_recover.append(src_text[i, start - 1:end])
        to_recover_pos.append(torch.arange(start - 1, end))
    to_recover = pad_sequence(to_recover,
                              batch_first=True,
                              padding_value=text_processor.pad_token_id())
    to_recover_pos = pad_sequence(to_recover_pos,
                                  batch_first=True,
                                  padding_value=seq_len - 1)

    # Boolean indexing copies, so both tensors keep the original token ids
    # even though src_text is overwritten below.
    masked_ids = src_text[:, 1:][src_mask[:, 1:]]
    mask_idx = src_text[src_mask]

    replacements = []
    for c in range(mask_idx.size(0)):
        r = random.random()
        if r < 0.8:
            replacements.append(text_processor.mask_token_id())
        elif r < 0.9:
            # Random id drawn from the non-special part of the vocabulary.
            replacements.append(
                random.randint(len(text_processor.special_tokens),
                               text_processor.vocab_size() - 1))
        else:
            replacements.append(int(mask_idx[c]))
    src_text[src_mask] = torch.LongTensor(replacements)
    return {
        "src_mask": src_mask,
        "targets": masked_ids,
        "src_text": src_text,
        "to_recover": to_recover,
        "positions": to_recover_pos,
        "mask_idx": mask_idx
    }
예제 #5
0
    def train(options):
        """
        Train a ``SenSim`` model with ``SenSimTrainer``.

        Creates ``options.model_path`` if missing, builds the model
        (optionally initialised from a pretrained ``Seq2Seq`` checkpoint at
        ``options.pretrained_path``), prepares the MT training loader, the two
        negative-example loaders (``options.src_neg`` / ``options.dst_neg``)
        and an optional dev loader, then runs epochs until ``options.step``
        steps have been reached.

        :param options: parsed command-line options; only the attributes read
            below are required.
        """
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0
        num_processors = max(torch.cuda.device_count(), 1)

        mt_model = SenSim(text_processor=text_processor,
                          enc_layer=options.encoder_layer,
                          embed_dim=options.embed_dim,
                          intermediate_dim=options.intermediate_layer_dim)

        if options.pretrained_path is not None:
            pretrained = Seq2Seq.load(Seq2Seq,
                                      options.pretrained_path,
                                      tok_dir=options.tokenizer_path)
            mt_model.init_from_lm(pretrained)

        print("Model initialization done!")

        # "warump_steps" is the keyword spelling used at this call site.
        optimizer = build_optimizer(mt_model,
                                    options.learning_rate,
                                    warump_steps=options.warmup)
        trainer = SenSimTrainer(model=mt_model,
                                mask_prob=options.mask_prob,
                                optimizer=optimizer,
                                clip=options.clip,
                                fp16=options.fp16)

        pin_memory = torch.cuda.is_available()

        def _negative_dataset(pickle_dir):
            # Both negative datasets share every setting except the directory.
            return dataset.MassDataset(
                batch_pickle_dir=pickle_dir,
                max_batch_capacity=num_processors * options.total_capacity * 5,
                max_batch=num_processors * options.batch * 5,
                pad_idx=mt_model.text_processor.pad_token_id(),
                keep_pad_idx=False,
                max_seq_len=options.max_seq_len,
                keep_examples=False)

        def _shuffled_loader(data):
            return data_utils.DataLoader(data,
                                         batch_size=1,
                                         shuffle=True,
                                         pin_memory=pin_memory)

        mt_train_loader = SenSimTrainer.get_mt_train_data(
            mt_model, num_processors, options, pin_memory)
        src_neg_data = _negative_dataset(options.src_neg)
        dst_neg_data = _negative_dataset(options.dst_neg)
        src_neg_loader = _shuffled_loader(src_neg_data)
        dst_neg_loader = _shuffled_loader(dst_neg_data)

        mt_dev_loader = None
        if options.mt_dev_path is not None:
            mt_dev_loader = SenSimTrainer.get_mt_dev_data(
                mt_model, options, pin_memory, text_processor, trainer)

        global_step, epoch = 0, 1
        trainer.best_loss = 1000000
        # A non-positive options.step disables training entirely.
        while options.step > 0 and global_step < options.step:
            print("train epoch", epoch)
            global_step = trainer.train_epoch(mt_train_iter=mt_train_loader,
                                              max_step=options.step,
                                              mt_dev_iter=mt_dev_loader,
                                              saving_path=options.model_path,
                                              step=global_step,
                                              src_neg_iter=src_neg_loader,
                                              dst_neg_iter=dst_neg_loader)
            epoch += 1
    def train(options):
        """
        Train a ``Caption2Image`` model with ``Caption2ImageTrainer``.

        Loads an ``ImageCaptioning`` model from ``options.pretrained_path``
        (the path is used unconditionally here), builds a fresh
        ``Caption2Image`` model, creates the train/dev image loaders and runs
        epochs until ``options.step`` steps have been reached.

        :param options: parsed command-line options; only the attributes read
            below are required.
        """
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0

        image_captioner = Seq2Seq.load(ImageCaptioning,
                                       options.pretrained_path,
                                       tok_dir=options.tokenizer_path)
        txt2ImageModel = Caption2Image(
            text_processor=text_processor,
            enc_layer=options.encoder_layer,
            embed_dim=options.embed_dim,
            intermediate_dim=options.intermediate_layer_dim)

        print("Model initialization done!")

        # The collator is assumed to return a list with one entry per GPU
        # (the original comment was cut off before describing the CPU case)
        # -- TODO confirm.
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        # "warump_steps" is the keyword spelling used at this call site.
        optimizer = build_optimizer(txt2ImageModel,
                                    options.learning_rate,
                                    warump_steps=options.warmup)

        trainer = Caption2ImageTrainer(
            model=txt2ImageModel,
            caption_model=image_captioner,
            mask_prob=options.mask_prob,
            optimizer=optimizer,
            clip=options.clip,
            beam_width=options.beam_width,
            max_len_a=options.max_len_a,
            max_len_b=options.max_len_b,
            len_penalty_ratio=options.len_penalty_ratio,
            fp16=options.fp16,
            mm_mode=options.mm_mode)

        pin_memory = torch.cuda.is_available()

        def _image_loader(path, **extra):
            # Train and dev loaders differ only in the path and extra kwargs.
            return ImageMTTrainer.get_img_loader(collator,
                                                 dataset.ImageCaptionDataset,
                                                 path,
                                                 txt2ImageModel,
                                                 num_batches,
                                                 options,
                                                 pin_memory,
                                                 lex_dict=lex_dict,
                                                 **extra)

        img_train_loader = _image_loader(options.train_path)
        img_dev_loader = _image_loader(options.dev_path,
                                       shuffle=False,
                                       denom=2)

        global_step, epoch = 0, 1
        # A non-positive options.step disables training entirely.
        while options.step > 0 and global_step < options.step:
            print("train epoch", epoch)
            global_step = trainer.train_epoch(
                img_data_iter=img_train_loader,
                img_dev_data_iter=img_dev_loader,
                max_step=options.step,
                lex_dict=lex_dict,
                saving_path=options.model_path,
                step=global_step)
            epoch += 1
예제 #7
0
    def __init__(self,
                 text_processor: TextProcessor,
                 enc_layer: int = 6,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        """
        Build the encoder and the attention projection.

        :param text_processor: provides vocabulary size, special-token ids and
            the language list used to build the BERT configuration.
        :param enc_layer: forwarded to ``lm_config.get_config`` and stored on
            the instance.
        :param embed_dim: forwarded to ``lm_config.get_config`` and stored on
            the instance.
        :param intermediate_dim: forwarded to ``lm_config.get_config`` and
            stored on the instance.
        """
        super(SenSim, self).__init__()
        self.text_processor: TextProcessor = text_processor

        # Keep the architecture hyper-parameters on the instance.
        self.enc_layer = enc_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim

        bert_args = lm_config.get_config(
            vocab_size=text_processor.tokenizer.get_vocab_size(),
            pad_token_id=text_processor.pad_token_id(),
            bos_token_id=text_processor.bos_token_id(),
            eos_token_id=text_processor.sep_token_id(),
            enc_layer=enc_layer,
            embed_dim=embed_dim,
            intermediate_dim=intermediate_dim)
        # type_vocab_size is set to the number of known languages.
        bert_args["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**bert_args)

        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()
        # Linear map from a hidden vector to a single score.
        self.input_attention = nn.Linear(self.config.hidden_size, 1)
    def __init__(self,
                 text_processor: TextProcessor,
                 config: ReformerConfig = None,
                 size: int = 1):
        """
        Build a Reformer language model (encoder plus LM head).

        :param text_processor: provides vocabulary size and the pad/separator
            token ids used for the default configuration.
        :param config: ready-made ``ReformerConfig``; when given, ``size`` is
            ignored.
        :param size: config size: 1 small, 3 base; any other value (2 by
            convention) selects the medium configuration.
        """
        super(ReformerLM, self).__init__()
        self.text_processor: TextProcessor = text_processor

        if config is None:
            if size == 1:
                build_config = _small_config
            elif size == 3:
                build_config = _base_config
            else:
                build_config = _medium_config
            config_args = build_config(
                vocab_size=text_processor.tokenizer.get_vocab_size(),
                pad_token_id=text_processor.pad_token_id(),
                eos_token_id=text_processor.sep_token_id())
            config = ReformerConfig(**config_args)
        self.config = config

        # Build the full model once and keep its two parts as attributes
        # (LM head first, then the encoder, as in the original order).
        full_model = ReformerModelWithLMHead(self.config)
        self.lm_head: ReformerOnlyLMHead = full_model.lm_head
        self.encoder: ReformerModel = full_model.reformer
예제 #9
0
    def __init__(self,
                 text_processor: TextProcessor,
                 lang_dec: bool = True,
                 use_proposals=False,
                 tie_embed=False,
                 enc_layer: int = 6,
                 dec_layer: int = 3,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072,
                 freeze_image: bool = False,
                 resnet_depth: int = 1):
        """
        Build the encoder, decoder(s) and output layer(s) of the model.

        :param text_processor: provides vocabulary size, special-token ids and
            the language list.
        :param lang_dec: when true, one decoder and one output layer are
            created per language; otherwise a single decoder is shared.
        :param use_proposals: when true, extra lexical-proposal parameters
            (``proposal_embedding``, ``lexical_gate``,
            ``lexical_layer_norm``) are created.
        :param tie_embed: controls additional weight tying between the
            output layer(s) and the word embeddings (see the branches below).
        :param enc_layer: forwarded to ``lm_config.get_config`` and stored.
        :param dec_layer: number of hidden layers of the decoder config.
        :param embed_dim: forwarded to ``lm_config.get_config`` and stored.
        :param intermediate_dim: forwarded to ``lm_config.get_config`` and
            stored.
        :param freeze_image: only stored on the instance here.
        :param resnet_depth: only stored on the instance here.
        """
        super(Seq2Seq, self).__init__()
        self.text_processor: TextProcessor = text_processor
        self.config = lm_config.get_config(
            vocab_size=text_processor.tokenizer.get_vocab_size(),
            pad_token_id=text_processor.pad_token_id(),
            bos_token_id=text_processor.bos_token_id(),
            eos_token_id=text_processor.sep_token_id(),
            enc_layer=enc_layer,
            embed_dim=embed_dim,
            intermediate_dim=intermediate_dim)

        self.enc_layer = enc_layer
        self.dec_layer = dec_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim
        # type_vocab_size is set to the number of known languages.
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)
        # The decoder uses a copy of the encoder config with its own depth.
        dec_config = copy.deepcopy(self.config)
        dec_config.num_hidden_layers = self.dec_layer

        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()
        self.lang_dec = lang_dec
        self.tie_embed = tie_embed
        if not lang_dec:
            # Single decoder shared by all languages.
            # NOTE(review): _tie_or_clone_weights(a, b) presumably makes `a`
            # share `b`'s weights (Hugging Face semantics) -- confirm for the
            # installed transformers version.
            self.decoder = BertDecoderModel(dec_config)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.position_embeddings,
                self.decoder.embeddings.position_embeddings)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.token_type_embeddings,
                self.decoder.embeddings.token_type_embeddings)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.word_embeddings,
                self.decoder.embeddings.word_embeddings)

            if tie_embed:
                # One output layer tied to the word embeddings.
                # NOTE(review): the BertOutputLayer module itself is passed
                # here rather than an inner projection such as `.decoder` --
                # verify that this ties the weights actually used.
                self.output_layer = BertOutputLayer(dec_config)
                self.encoder._tie_or_clone_weights(
                    self.output_layer, self.encoder.embeddings.word_embeddings)
                # NOTE(review): this repeats the position-embedding tying
                # already done above.
                self.encoder._tie_or_clone_weights(
                    self.encoder.embeddings.position_embeddings,
                    self.decoder.embeddings.position_embeddings)
                self.decoder._tie_or_clone_weights(
                    self.output_layer, self.decoder.embeddings.word_embeddings)
            else:
                # Untied: one independent output layer per language.
                self.output_layer = nn.ModuleList([
                    BertOutputLayer(dec_config)
                    for _ in text_processor.languages
                ])

            # When encoder and decoder have the same number of layers, the
            # decoder layers reuse the encoder's attention modules.
            if len(self.encoder.encoder.layer) == len(
                    self.decoder.decoder.layer):
                for i in range(len(self.encoder.encoder.layer)):
                    self.decoder.decoder.layer[
                        i].attention = self.encoder.encoder.layer[i].attention

        else:
            # One decoder (deep copies of the same initial decoder) and one
            # output layer per language.
            dec = BertDecoderModel(dec_config)
            self.decoder = nn.ModuleList(
                [copy.deepcopy(dec) for _ in text_processor.languages])
            self.output_layer = nn.ModuleList([
                BertOutputLayer(dec_config) for _ in text_processor.languages
            ])
            for i, dec in enumerate(self.decoder):
                if tie_embed:
                    self.encoder._tie_or_clone_weights(
                        self.output_layer[i],
                        self.encoder.embeddings.word_embeddings)
                    # Share the encoder's position-embedding module directly.
                    dec.embeddings.position_embeddings = self.encoder.embeddings.position_embeddings
                # These two calls run regardless of tie_embed.
                dec._tie_or_clone_weights(self.output_layer[i],
                                          dec.embeddings.word_embeddings)
                dec._tie_or_clone_weights(
                    self.encoder.embeddings.token_type_embeddings,
                    dec.embeddings.token_type_embeddings)

        self.use_proposals = use_proposals
        if self.use_proposals:
            # Proposal embeddings alias the encoder's word embeddings.
            self.proposal_embedding = self.encoder.embeddings.word_embeddings
            # Trainable gate of shape (1, hidden_size), initialised to 0.1.
            self.lexical_gate = nn.Parameter(torch.zeros(
                1, self.config.hidden_size).fill_(0.1),
                                             requires_grad=True)
            self.lexical_layer_norm = nn.LayerNorm(
                self.config.hidden_size, eps=self.config.layer_norm_eps)

        self.freeze_image = freeze_image
        self.resnet_depth = resnet_depth
예제 #10
0
                cur_capacity = 2 * (max(int(src_input.size(0)),
                                        int(tgt_inputs_all.size(1)))**3) * int(
                                            tgt_inputs_all.size(0))
                split_size = int(math.ceil(cur_capacity / max_capacity))
                split_size = max(1,
                                 int(math.floor(len(tids_all) / split_size)))

                tgt_inputs_spl = torch.split(tgt_inputs_all, split_size)
                tids_spl = torch.split(tids_all, split_size)
                dst_langs_spl = torch.split(dst_langs_all, split_size)

                trans_score = dict()
                for spl_i in range(len(tgt_inputs_spl)):
                    src_input = src_input.view(-1,
                                               src_input.size(0)).to(device)
                    src_mask = (src_input != text_processor.pad_token_id())
                    src_lang = src_lang.to(device)
                    encoder_states = model.encode(
                        src_input, src_mask,
                        src_lang.expand(src_input.size()))[0]

                    tgt_inputs, tids, dst_langs = tgt_inputs_spl[
                        spl_i], tids_spl[spl_i], dst_langs_spl[spl_i]
                    tgt_mask = (tgt_inputs !=
                                text_processor.pad_token_id()).to(device)

                    tgt_inputs = tgt_inputs.to(device)
                    dst_langs = dst_langs.to(device)
                    batch_lang = int(dst_langs[0])
                    subseq_mask = future_mask(tgt_mask[:, :-1])
                    if subseq_mask.device != tgt_inputs.device:
예제 #11
0
    def train(options):
        """
        Train an ``ImageMassSeq2Seq`` model with ``ImageMTTrainer``.

        The model is loaded from ``options.pretrained_path`` when that is
        set, otherwise built from the architecture options. The optimizer is
        either unpickled from ``<pretrained_path>/optim``
        (``options.continue_train``) or freshly built. Training runs until
        ``options.step`` steps are reached or 10 epochs have been started.

        :param options: parsed command-line options; only the attributes read
            below are required.
        """
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0
        num_processors = max(torch.cuda.device_count(), 1)

        if options.pretrained_path is None:
            mt_model = ImageMassSeq2Seq(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)
        else:
            print("Loading pretrained path", options.pretrained_path)
            mt_model = Seq2Seq.load(ImageMassSeq2Seq,
                                    options.pretrained_path,
                                    tok_dir=options.tokenizer_path)

        if options.lm_path is not None:
            # NOTE(review): options.lm_path is only tested against None; a
            # brand-new LM is constructed here instead of being loaded from
            # that path -- confirm this is intended.
            lm = LM(text_processor=text_processor,
                    enc_layer=options.encoder_layer,
                    embed_dim=options.embed_dim,
                    intermediate_dim=options.intermediate_layer_dim)
            mt_model.init_from_lm(lm)

        print("Model initialization done!")

        # The collator is assumed to return a list with one entry per GPU
        # (the original comment was cut off before describing the CPU case)
        # -- TODO confirm. Neither value is used further in this function.
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        if not options.continue_train:
            # "warump_steps" is the keyword spelling used at this call site.
            optimizer = build_optimizer(mt_model,
                                        options.learning_rate,
                                        warump_steps=options.warmup)
        else:
            optimizer_file = os.path.join(options.pretrained_path, "optim")
            with open(optimizer_file, "rb") as fp:
                optimizer = pickle.load(fp)

        trainer_kwargs = dict(model=mt_model,
                              mask_prob=options.mask_prob,
                              optimizer=optimizer,
                              clip=options.clip,
                              beam_width=options.beam_width,
                              max_len_a=options.max_len_a,
                              max_len_b=options.max_len_b,
                              len_penalty_ratio=options.len_penalty_ratio,
                              fp16=options.fp16,
                              mm_mode=options.mm_mode)
        trainer = ImageMTTrainer(**trainer_kwargs)

        pin_memory = torch.cuda.is_available()

        # Both loaders are optional and stay None when no path is given.
        mt_train_loader = None
        if options.mt_train_path is not None:
            mt_train_loader = ImageMTTrainer.get_mt_train_data(
                mt_model,
                num_processors,
                options,
                pin_memory,
                lex_dict=lex_dict)

        mt_dev_loader = None
        if options.mt_dev_path is not None:
            mt_dev_loader = ImageMTTrainer.get_mt_dev_data(mt_model,
                                                           options,
                                                           pin_memory,
                                                           text_processor,
                                                           trainer,
                                                           lex_dict=lex_dict)

        global_step, epoch = 0, 1
        # Stops at options.step steps or after 10 epochs, whichever is first.
        while (options.step > 0 and global_step < options.step
               and epoch <= 10):
            print("train epoch", epoch, "step:", global_step)
            global_step = trainer.train_epoch(mt_train_iter=mt_train_loader,
                                              max_step=options.step,
                                              lex_dict=lex_dict,
                                              mt_dev_iter=mt_dev_loader,
                                              saving_path=options.model_path,
                                              step=global_step,
                                              save_opt=False)
            epoch += 1
예제 #12
0
    def train(options):
        """
        Train a language model (``ReformerLM`` when ``options.reformer`` is
        set, ``LM`` otherwise) with ``LMTrainer``.

        The model is created from scratch or loaded from
        ``options.pretrained_path``; dropout is then set from
        ``options.dropout``. The optimizer is either unpickled from
        ``<pretrained_path>/optim`` (``options.continue_train``) or freshly
        built. Epochs run while the step count is ``<= options.step``.

        :param options: parsed command-line options; only the attributes read
            below are required.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(options.model_path, exist_ok=True)

        text_processor = TextProcessor(options.tokenizer_path)

        lm_class = ReformerLM if options.reformer else LM
        if options.pretrained_path is None:
            # NOTE(review): `size` is passed to both classes -- confirm that
            # the LM constructor used with this script accepts it.
            lm = lm_class(text_processor=text_processor,
                          size=options.model_size)
        else:
            lm = lm_class.load(options.pretrained_path)

        if options.reformer:
            # The Reformer config exposes three separate dropout settings.
            lm.config.hidden_dropout_prob = options.dropout
            lm.config.local_attention_probs_dropout_prob = options.dropout
            lm.config.lsh_attention_probs_dropout_prob = options.dropout
        else:
            LMTrainer.config_dropout(lm, options.dropout)

        train_data = dataset.TextDataset(save_cache_dir=options.train_path,
                                         max_cache_size=options.cache_size)
        dev_data = dataset.TextDataset(save_cache_dir=options.dev_path,
                                       max_cache_size=options.cache_size,
                                       load_all=True)

        if options.continue_train:
            # SECURITY: pickle.load can execute arbitrary code from the file;
            # only resume from checkpoints you trust.
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(lm, options.learning_rate,
                                        options.warmup)

        trainer = LMTrainer(model=lm,
                            mask_prob=options.mask_prob,
                            optimizer=optimizer,
                            clip=options.clip)

        collator = dataset.TextCollator(pad_idx=text_processor.pad_token_id())

        pin_memory = torch.cuda.is_available()
        # No sampler is passed (the original passed sampler=None, which is
        # the DataLoader default).
        loader = data_utils.DataLoader(train_data,
                                       batch_size=options.batch,
                                       shuffle=False,
                                       pin_memory=pin_memory,
                                       collate_fn=collator)
        dev_loader = data_utils.DataLoader(dev_data,
                                           batch_size=options.batch,
                                           shuffle=False,
                                           pin_memory=pin_memory,
                                           collate_fn=collator)

        step, train_epoch = 0, 1
        while step <= options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(data_iter=loader,
                                       dev_data_iter=dev_loader,
                                       saving_path=options.model_path,
                                       step=step)
            # The counter was never advanced, so every epoch was logged as
            # "train epoch 1".
            train_epoch += 1
예제 #13
0
    def __init__(self,
                 root_img_dir: str,
                 data_bin_file: str,
                 max_capacity: int,
                 text_processor: TextProcessor,
                 max_img_per_batch: int,
                 lex_dict=None,
                 ngpu=1):
        """
        Load captions from a marshal file and group them into padded batches.

        Captions are read in file order; entries whose image path ends with
        ``.png`` (case-insensitive) are skipped. A batch is closed when adding
        the newest caption pushes it over ``max_img_per_batch`` images or over
        the capacity estimate ``2 * max_len**3 * batch_size``, provided the
        batch without that caption holds at least ``ngpu`` captions. The
        newest caption then starts the next batch.

        :param root_img_dir: stored on the instance as the image root.
        :param data_bin_file: marshal file holding ``(unique_images,
            captions)``; each caption entry is ``(image_id, token_ids)``.
        :param max_capacity: capacity limit in millions (it is multiplied by
            1,000,000 before use).
        :param text_processor: supplies ``pad_token_id()``, ``id2token()`` and
            ``languages``.
        :param max_img_per_batch: maximum number of captions/images per batch.
        :param lex_dict: optional lexical dictionary; when given, lexical
            candidates are computed for every caption and batched alongside.
        :param ngpu: minimum number of captions a closed batch must contain.
        """
        self.ngpu = ngpu
        self.lex_dict = lex_dict
        self.root_img_dir = root_img_dir
        self.pad_idx = text_processor.pad_token_id()

        # Image preprocessing transforms.
        self.size_transform = transforms.Resize(256)
        self.crop = transforms.CenterCrop(224)
        self.to_tensor = transforms.ToTensor()
        self.img_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])

        self.batches = []
        self.image_batches = []
        self.lang_ids = set()
        self.all_captions = []

        max_capacity *= 1000000
        use_lex = self.lex_dict is not None

        def _store_batch(batch_captions, batch_lex, batch_images):
            # Pad the captions, locate the first pad column of every row and
            # record the finished batch together with its image ids.
            batch_tensor = pad_sequence(batch_captions,
                                        batch_first=True,
                                        padding_value=self.pad_idx)
            lex_tensor = None
            if use_lex:
                lex_tensor = pad_sequence(batch_lex,
                                          batch_first=True,
                                          padding_value=self.pad_idx)
            # True where the token is NOT padding.
            pads = batch_tensor != self.pad_idx
            # Default to the last column for rows that contain no padding.
            first_pad = [int(pads.size(1)) - 1] * int(pads.size(0))
            for row, col in torch.nonzero(~pads):
                first_pad[row] = min(first_pad[row], int(col))

            self.batches.append((batch_tensor, pads,
                                 torch.LongTensor(first_pad), lex_tensor))
            self.image_batches.append(batch_images)

        print("Start", datetime.datetime.now())
        pending_caps, pending_imgs, pending_lex = [], [], []
        longest = 0
        with open(data_bin_file, "rb") as fp:
            self.unique_images, captions = marshal.load(fp)

            # The first token of the first caption identifies the language.
            first_token_id = captions[0][1][0]
            lang_id = text_processor.id2token(first_token_id)
            self.lang_ids.add(int(first_token_id))
            if lang_id in text_processor.languages:
                self.lang = text_processor.languages[lang_id]
            else:
                self.lang = 0

            for image_id, caption in captions:
                if self.unique_images[image_id].lower().endswith(".png"):
                    continue

                caption = torch.LongTensor(caption)
                pending_caps.append(caption)
                self.all_captions.append(caption)
                if use_lex:
                    pending_lex.append(
                        get_lex_suggestions(self.lex_dict, caption,
                                            text_processor.pad_token_id()))
                pending_imgs.append(image_id)

                longest = max(longest, len(caption))
                capacity_estimate = 2 * (longest**3) * len(pending_caps)
                over_limit = (len(pending_imgs) > max_img_per_batch
                              or capacity_estimate > max_capacity)
                # Number of captions in the batch without the newest one.
                settled = len(pending_caps) - 1
                if over_limit and settled >= self.ngpu and settled > 0:
                    _store_batch(pending_caps[:-1], pending_lex[:-1],
                                 pending_imgs[:-1])
                    # The newest caption opens the next batch.
                    pending_caps = pending_caps[-1:]
                    pending_imgs = pending_imgs[-1:]
                    pending_lex = pending_lex[-1:]
                    longest = len(pending_caps[0])

            # Whatever is left forms the final batch.
            if len(pending_caps) > 0:
                _store_batch(pending_caps, pending_lex, pending_imgs)

        print(
            "Loaded %d image batches of %d unique images and %d all captions!"
            % (len(self.batches), len(
                self.unique_images), len(self.all_captions)))
        print("End", datetime.datetime.now())
    def train(options):
        """Build an image-captioning model and trainer from ``options`` and train.

        The function creates the output directory, constructs (or loads) an
        ``ImageCaptioning`` model, builds (or unpickles) the optimizer, creates
        the train/dev image loaders, collects decoded dev references on the
        trainer, and then calls ``trainer.train_epoch`` repeatedly until the
        returned step count reaches ``options.step``.

        Args:
            options: Attribute container (e.g. parsed command-line options).
                Attributes read here: ``dict_path``, ``model_path``,
                ``tokenizer_path``, ``pretrained_path``, ``tie_embed``,
                ``resnet_depth``, ``lang_decoder``, ``encoder_layer``,
                ``decoder_layer``, ``embed_dim``, ``intermediate_layer_dim``,
                ``lm_path``, ``continue_train``, ``learning_rate``,
                ``warmup``, ``mask_prob``, ``clip``, ``beam_width``,
                ``max_len_a``, ``max_len_b``, ``len_penalty_ratio``, ``fp16``,
                ``mm_mode``, ``train_path``, ``dev_path`` and ``step``. The
                whole object is also forwarded to
                ``ImageMTTrainer.get_img_loader``.

        Raises:
            AssertionError: If the tokenizer's pad token id is not 0.
        """
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test followed by os.makedirs().
        os.makedirs(options.model_path, exist_ok=True)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0

        if options.pretrained_path is not None:
            mt_model = Seq2Seq.load(ImageCaptioning,
                                    options.pretrained_path,
                                    tok_dir=options.tokenizer_path)
        else:
            mt_model = ImageCaptioning(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)

        if options.lm_path is not None:
            # NOTE(review): options.lm_path is only tested against None; the LM
            # below is constructed from the option dimensions and nothing is
            # read from that path here -- confirm this is intended.
            lm = LM(text_processor=text_processor,
                    enc_layer=options.encoder_layer,
                    embed_dim=options.embed_dim,
                    intermediate_dim=options.intermediate_layer_dim)
            mt_model.init_from_lm(lm)

        print("Model initialization done!")

        # num_batches is the number of visible CUDA devices, or 1 when there
        # are none (CPU-only). It is handed to the loaders below; presumably
        # the collator yields one list entry per device -- TODO confirm.
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        if options.continue_train:
            # SECURITY: pickle.load can execute arbitrary code; only resume
            # from an "optim" file that comes from a trusted source.
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            # The keyword spelling "warump_steps" is kept as in the original
            # call; presumably it matches build_optimizer's parameter name.
            optimizer = build_optimizer(mt_model,
                                        options.learning_rate,
                                        warump_steps=options.warmup)
        trainer = ImageCaptionTrainer(
            model=mt_model,
            mask_prob=options.mask_prob,
            optimizer=optimizer,
            clip=options.clip,
            beam_width=options.beam_width,
            max_len_a=options.max_len_a,
            max_len_b=options.max_len_b,
            len_penalty_ratio=options.len_penalty_ratio,
            fp16=options.fp16,
            mm_mode=options.mm_mode)

        pin_memory = torch.cuda.is_available()
        img_train_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.train_path,
            mt_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict)

        img_dev_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.dev_path,
            mt_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict,
            shuffle=False,
            denom=2)

        # Collect the decoded dev-set captions on the trainer as references.
        trainer.reference = None
        if img_dev_loader is not None:
            trainer.reference = []
            # Unwrap a wrapper object exposing ``.module`` if one is present.
            generator = (trainer.generator.module if hasattr(
                trainer.generator, "module") else trainer.generator)
            for data in img_dev_loader:
                for batch in data:
                    for item in batch:
                        refs = get_outputs_until_eos(
                            text_processor.sep_token_id(),
                            item["captions"],
                            remove_first_token=True)
                        tokenizer = (
                            generator.seq2seq_model.text_processor.tokenizer)
                        trainer.reference.extend(
                            tokenizer.decode(r.numpy()) for r in refs)

        # Each train_epoch call returns the updated global step; stop once the
        # requested number of steps is reached (no training if step <= 0).
        step, train_epoch = 0, 1
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       img_dev_data_iter=img_dev_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       saving_path=options.model_path,
                                       step=step)
            train_epoch += 1