def __init__(self, text_processor: TextProcessor, config: BertConfig = None, encoder: BertModel = None,
             enc_layer: int = 6, embed_dim: int = 768, intermediate_dim: int = 3072):
    super(LM, self).__init__()
    self.text_processor: TextProcessor = text_processor

    if config is not None:
        self.config = config
    else:
        self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                           pad_token_id=text_processor.pad_token_id(),
                                           bos_token_id=text_processor.bos_token_id(),
                                           eos_token_id=text_processor.sep_token_id(),
                                           enc_layer=enc_layer, embed_dim=embed_dim,
                                           intermediate_dim=intermediate_dim)
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)

    self.masked_lm = BertOutputLayer(self.config)
    if encoder is None:
        self.encoder: BertModel = BertModel(self.config)
        self.encoder.init_weights()
    else:
        self.encoder = encoder
    self.encoder._tie_or_clone_weights(self.masked_lm.decoder, self.encoder.embeddings.word_embeddings)

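# Usage sketch (added for illustration, not part of the original source): build an LM
# from a freshly trained tokenizer, mirroring the unit test further below. The data
# path, vocabulary size, and save directory are placeholder assumptions.
def _example_build_lm(data_path: str, save_dir: str) -> "LM":
    processor = TextProcessor()
    processor.train_tokenizer([data_path], vocab_size=1000, to_save_dir=save_dir,
                              languages={"<en>": 0, "<fa>": 1})
    # Defaults match the constructor above: 6 encoder layers, 768-dim embeddings,
    # 3072-dim intermediate layer.
    return LM(text_processor=processor, enc_layer=6, embed_dim=768, intermediate_dim=3072)
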
def create_batches(sentences, src2dst_dict, text_processor: TextProcessor, resume_index=0, end_index=-1):
    print(len(src2dst_dict))
    print("Getting batches...")
    index = 0
    for sid in src2dst_dict.keys():
        index += 1
        if index >= end_index and end_index > 0:
            break
        if index <= resume_index:
            continue
        tids = list(src2dst_dict[sid])
        source_tokenized = torch.LongTensor(tok_sen(sentences[sid]))
        trans_cands = list(map(lambda i: torch.LongTensor(tok_sen(sentences[i])), tids))
        candidates = pad_sequence(trans_cands, batch_first=True, padding_value=text_processor.pad_token_id())
        target_langs = list(map(lambda i: text_processor.lang_id(sentences[i].strip().split(" ")[0]), tids))
        src_lang = torch.LongTensor([text_processor.lang_id(sentences[sid].strip().split(" ")[0])])
        yield sid, source_tokenized, torch.LongTensor(tids), candidates, src_lang, torch.LongTensor(target_langs)

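# Usage sketch (illustrative, not from the original source): iterate a few of the
# batches yielded above. `sentences` is assumed to be a list of raw lines whose first
# whitespace token is a language tag (e.g. "<en> a cat ..."), `src2dst_dict` maps a
# source sentence id to the ids of its translation candidates, and `tok_sen` is the
# tokenization helper the generator relies on.
def _example_iterate_batches(sentences, src2dst_dict, text_processor: TextProcessor):
    for sid, src, tids, candidates, src_lang, tgt_langs in create_batches(
            sentences, src2dst_dict, text_processor, resume_index=0, end_index=3):
        print(sid, src.size(), candidates.size(), int(src_lang), tgt_langs.tolist())
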
def test_albert_seq2seq_init(self):
    path_dir_name = os.path.dirname(os.path.realpath(__file__))
    data_path = os.path.join(path_dir_name, "sample.txt")

    with tempfile.TemporaryDirectory() as tmpdirname:
        processor = TextProcessor()
        processor.train_tokenizer([data_path], vocab_size=1000, to_save_dir=tmpdirname,
                                  languages={"<en>": 0, "<fa>": 1})
        seq2seq = Seq2Seq(text_processor=processor)
        src_inputs = torch.tensor([[1, 2, 3, 4, 5, processor.pad_token_id(), processor.pad_token_id()],
                                   [1, 2, 3, 4, 5, 6, processor.pad_token_id()]])
        tgt_inputs = torch.tensor([[6, 8, 7, processor.pad_token_id(), processor.pad_token_id()],
                                   [6, 8, 7, 8, processor.pad_token_id()]])
        src_mask = (src_inputs != processor.pad_token_id())
        tgt_mask = (tgt_inputs != processor.pad_token_id())
        src_langs = torch.tensor([[0], [0]]).squeeze()
        tgt_langs = torch.tensor([[1], [1]]).squeeze()
        seq_output = seq2seq(src_inputs, tgt_inputs, src_mask, tgt_mask, src_langs, tgt_langs, log_softmax=True)
        assert list(seq_output.size()) == [5, processor.vocab_size()]
        seq_output = seq2seq(src_inputs, tgt_inputs, src_mask, tgt_mask, src_langs, tgt_langs)
        assert list(seq_output.size()) == [5, processor.vocab_size()]

def mass_mask(mask_prob, pad_indices, src_text, text_processor: TextProcessor) -> Dict:
    """
    20% of times, mask from start to middle
    20% of times, mask from middle to end
    60% of times, mask a random index
    """
    index_range = pad_indices - (1 - mask_prob) * pad_indices
    src_mask = torch.zeros(src_text.size(), dtype=torch.bool)
    to_recover = []
    to_recover_pos = []
    for i, irange in enumerate(index_range):
        range_size = int(pad_indices[i] / 2)
        r = random.random()
        last_idx = int(math.ceil(irange))
        if r > 0.8:
            start = 1
        elif r > 0.6:
            start = last_idx
        else:
            start = random.randint(2, last_idx) if last_idx >= 2 else 2
        end = start + range_size
        src_mask[i, start:end] = True
        to_recover.append(src_text[i, start - 1:end])
        to_recover_pos.append(torch.arange(start - 1, end))
    to_recover = pad_sequence(to_recover, batch_first=True, padding_value=text_processor.pad_token_id())
    to_recover_pos = pad_sequence(to_recover_pos, batch_first=True, padding_value=int(src_text.size(-1)) - 1)

    assert 0 < mask_prob < 1
    masked_ids = src_text[:, 1:][src_mask[:, 1:]]
    mask_idx = src_text[src_mask]
    random_index = lambda: random.randint(len(text_processor.special_tokens), text_processor.vocab_size() - 1)
    rand_select = lambda r, c: text_processor.mask_token_id() if r < 0.8 else (
        random_index() if r < 0.9 else int(mask_idx[c]))
    replacements = list(map(lambda i: rand_select(random.random(), i), range(mask_idx.size(0))))
    src_text[src_mask] = torch.LongTensor(replacements)
    return {"src_mask": src_mask, "targets": masked_ids, "src_text": src_text, "to_recover": to_recover,
            "positions": to_recover_pos, "mask_idx": mask_idx}

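# Worked sketch (illustrative, not from the original source): apply MASS-style masking
# to a toy batch. Following the dataset code below, position 0 holds the language/BOS
# token and `pad_indices` holds the first pad position of each row (or the last index
# when a row has no padding). Token ids are arbitrary placeholders.
def _example_mass_mask(text_processor: TextProcessor):
    pad = text_processor.pad_token_id()
    src_text = torch.LongTensor([[5, 11, 12, 13, 14, 15, pad, pad],
                                 [5, 21, 22, 23, 24, 25, 26, 27]])
    pad_indices = torch.LongTensor([6, 7])
    # mass_mask replaces the selected span in place, so clone if the original is still needed.
    masked = mass_mask(mask_prob=0.5, pad_indices=pad_indices, src_text=src_text.clone(),
                       text_processor=text_processor)
    return masked["src_text"], masked["targets"], masked["to_recover"]
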
def train(options):
    if not os.path.exists(options.model_path):
        os.makedirs(options.model_path)

    text_processor = TextProcessor(options.tokenizer_path)
    assert text_processor.pad_token_id() == 0
    num_processors = max(torch.cuda.device_count(), 1)

    mt_model = SenSim(text_processor=text_processor, enc_layer=options.encoder_layer,
                      embed_dim=options.embed_dim, intermediate_dim=options.intermediate_layer_dim)
    if options.pretrained_path is not None:
        pret = Seq2Seq.load(Seq2Seq, options.pretrained_path, tok_dir=options.tokenizer_path)
        mt_model.init_from_lm(pret)

    print("Model initialization done!")

    optimizer = build_optimizer(mt_model, options.learning_rate, warump_steps=options.warmup)
    trainer = SenSimTrainer(model=mt_model, mask_prob=options.mask_prob, optimizer=optimizer, clip=options.clip,
                            fp16=options.fp16)

    pin_memory = torch.cuda.is_available()
    mt_train_loader = SenSimTrainer.get_mt_train_data(mt_model, num_processors, options, pin_memory)
    src_neg_data = dataset.MassDataset(batch_pickle_dir=options.src_neg,
                                       max_batch_capacity=num_processors * options.total_capacity * 5,
                                       max_batch=num_processors * options.batch * 5,
                                       pad_idx=mt_model.text_processor.pad_token_id(), keep_pad_idx=False,
                                       max_seq_len=options.max_seq_len, keep_examples=False)
    dst_neg_data = dataset.MassDataset(batch_pickle_dir=options.dst_neg,
                                       max_batch_capacity=num_processors * options.total_capacity * 5,
                                       max_batch=num_processors * options.batch * 5,
                                       pad_idx=mt_model.text_processor.pad_token_id(), keep_pad_idx=False,
                                       max_seq_len=options.max_seq_len, keep_examples=False)
    src_neg_loader = data_utils.DataLoader(src_neg_data, batch_size=1, shuffle=True, pin_memory=pin_memory)
    dst_neg_loader = data_utils.DataLoader(dst_neg_data, batch_size=1, shuffle=True, pin_memory=pin_memory)

    mt_dev_loader = None
    if options.mt_dev_path is not None:
        mt_dev_loader = SenSimTrainer.get_mt_dev_data(mt_model, options, pin_memory, text_processor, trainer)

    step, train_epoch = 0, 1
    trainer.best_loss = 1000000
    while options.step > 0 and step < options.step:
        print("train epoch", train_epoch)
        step = trainer.train_epoch(mt_train_iter=mt_train_loader, max_step=options.step,
                                   mt_dev_iter=mt_dev_loader, saving_path=options.model_path, step=step,
                                   src_neg_iter=src_neg_loader, dst_neg_iter=dst_neg_loader)
        train_epoch += 1

def train(options):
    lex_dict = None
    if options.dict_path is not None:
        lex_dict = get_lex_dict(options.dict_path)
    if not os.path.exists(options.model_path):
        os.makedirs(options.model_path)

    text_processor = TextProcessor(options.tokenizer_path)
    assert text_processor.pad_token_id() == 0

    image_captioner = Seq2Seq.load(ImageCaptioning, options.pretrained_path, tok_dir=options.tokenizer_path)
    txt2ImageModel = Caption2Image(text_processor=text_processor, enc_layer=options.encoder_layer,
                                   embed_dim=options.embed_dim, intermediate_dim=options.intermediate_layer_dim)

    print("Model initialization done!")

    # We assume that the collator function returns a list with the size of number of gpus
    # (in case of cpus, it returns a list of size 1).
    collator = dataset.ImageTextCollator()
    num_batches = max(1, torch.cuda.device_count())

    optimizer = build_optimizer(txt2ImageModel, options.learning_rate, warump_steps=options.warmup)

    trainer = Caption2ImageTrainer(model=txt2ImageModel, caption_model=image_captioner,
                                   mask_prob=options.mask_prob, optimizer=optimizer, clip=options.clip,
                                   beam_width=options.beam_width, max_len_a=options.max_len_a,
                                   max_len_b=options.max_len_b, len_penalty_ratio=options.len_penalty_ratio,
                                   fp16=options.fp16, mm_mode=options.mm_mode)

    pin_memory = torch.cuda.is_available()
    img_train_loader = ImageMTTrainer.get_img_loader(collator, dataset.ImageCaptionDataset, options.train_path,
                                                     txt2ImageModel, num_batches, options, pin_memory,
                                                     lex_dict=lex_dict)
    img_dev_loader = ImageMTTrainer.get_img_loader(collator, dataset.ImageCaptionDataset, options.dev_path,
                                                   txt2ImageModel, num_batches, options, pin_memory,
                                                   lex_dict=lex_dict, shuffle=False, denom=2)

    step, train_epoch = 0, 1
    while options.step > 0 and step < options.step:
        print("train epoch", train_epoch)
        step = trainer.train_epoch(img_data_iter=img_train_loader, img_dev_data_iter=img_dev_loader,
                                   max_step=options.step, lex_dict=lex_dict, saving_path=options.model_path,
                                   step=step)
        train_epoch += 1

def __init__(self, text_processor: TextProcessor, enc_layer: int = 6, embed_dim: int = 768,
             intermediate_dim: int = 3072):
    super(SenSim, self).__init__()
    self.text_processor: TextProcessor = text_processor
    self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                       pad_token_id=text_processor.pad_token_id(),
                                       bos_token_id=text_processor.bos_token_id(),
                                       eos_token_id=text_processor.sep_token_id(),
                                       enc_layer=enc_layer, embed_dim=embed_dim,
                                       intermediate_dim=intermediate_dim)
    self.enc_layer = enc_layer
    self.embed_dim = embed_dim
    self.intermediate_dim = intermediate_dim
    self.config["type_vocab_size"] = len(text_processor.languages)
    self.config = BertConfig(**self.config)
    self.encoder = BertEncoderModel(self.config)
    self.encoder.init_weights()
    self.input_attention = nn.Linear(self.config.hidden_size, 1)

def __init__(self, text_processor: TextProcessor, config: ReformerConfig = None, size: int = 1):
    """
    :param size: config size: 1 small, 2 medium, 3 base.
    """
    super(ReformerLM, self).__init__()
    self.text_processor: TextProcessor = text_processor

    if config is not None:
        self.config = config
    else:
        config_func = _small_config if size == 1 else (_base_config if size == 3 else _medium_config)
        self.config = config_func(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                  pad_token_id=text_processor.pad_token_id(),
                                  eos_token_id=text_processor.sep_token_id())
        self.config = ReformerConfig(**self.config)

    reformer = ReformerModelWithLMHead(self.config)
    self.lm_head: ReformerOnlyLMHead = reformer.lm_head
    self.encoder: ReformerModel = reformer.reformer

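# Construction sketch (illustrative, not from the original source): pick the Reformer
# configuration by size, as documented in the constructor above (1 = small, 2 = medium,
# 3 = base).
def _example_build_reformer_lm(processor: TextProcessor, size: int = 2) -> "ReformerLM":
    return ReformerLM(text_processor=processor, size=size)
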
def __init__(self, text_processor: TextProcessor, lang_dec: bool = True, use_proposals=False, tie_embed=False,
             enc_layer: int = 6, dec_layer: int = 3, embed_dim: int = 768, intermediate_dim: int = 3072,
             freeze_image: bool = False, resnet_depth: int = 1):
    super(Seq2Seq, self).__init__()
    self.text_processor: TextProcessor = text_processor
    self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                       pad_token_id=text_processor.pad_token_id(),
                                       bos_token_id=text_processor.bos_token_id(),
                                       eos_token_id=text_processor.sep_token_id(),
                                       enc_layer=enc_layer, embed_dim=embed_dim,
                                       intermediate_dim=intermediate_dim)
    self.enc_layer = enc_layer
    self.dec_layer = dec_layer
    self.embed_dim = embed_dim
    self.intermediate_dim = intermediate_dim
    self.config["type_vocab_size"] = len(text_processor.languages)
    self.config = BertConfig(**self.config)

    dec_config = copy.deepcopy(self.config)
    dec_config.num_hidden_layers = self.dec_layer

    self.encoder = BertEncoderModel(self.config)
    self.encoder.init_weights()

    self.lang_dec = lang_dec
    self.tie_embed = tie_embed
    if not lang_dec:
        self.decoder = BertDecoderModel(dec_config)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.position_embeddings,
                                           self.decoder.embeddings.position_embeddings)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                           self.decoder.embeddings.token_type_embeddings)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.word_embeddings,
                                           self.decoder.embeddings.word_embeddings)

        if tie_embed:
            self.output_layer = BertOutputLayer(dec_config)
            self.encoder._tie_or_clone_weights(self.output_layer, self.encoder.embeddings.word_embeddings)
            self.encoder._tie_or_clone_weights(self.encoder.embeddings.position_embeddings,
                                               self.decoder.embeddings.position_embeddings)
            self.decoder._tie_or_clone_weights(self.output_layer, self.decoder.embeddings.word_embeddings)
        else:
            self.output_layer = nn.ModuleList([BertOutputLayer(dec_config) for _ in text_processor.languages])

            if len(self.encoder.encoder.layer) == len(self.decoder.decoder.layer):
                for i in range(len(self.encoder.encoder.layer)):
                    self.decoder.decoder.layer[i].attention = self.encoder.encoder.layer[i].attention
    else:
        dec = BertDecoderModel(dec_config)
        self.decoder = nn.ModuleList([copy.deepcopy(dec) for _ in text_processor.languages])
        self.output_layer = nn.ModuleList([BertOutputLayer(dec_config) for _ in text_processor.languages])
        for i, dec in enumerate(self.decoder):
            if tie_embed:
                self.encoder._tie_or_clone_weights(self.output_layer[i],
                                                   self.encoder.embeddings.word_embeddings)
            dec.embeddings.position_embeddings = self.encoder.embeddings.position_embeddings
            dec._tie_or_clone_weights(self.output_layer[i], dec.embeddings.word_embeddings)
            dec._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                      dec.embeddings.token_type_embeddings)

    self.use_proposals = use_proposals
    if self.use_proposals:
        self.proposal_embedding = self.encoder.embeddings.word_embeddings
        self.lexical_gate = nn.Parameter(torch.zeros(1, self.config.hidden_size).fill_(0.1),
                                         requires_grad=True)
        self.lexical_layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)

    self.freeze_image = freeze_image
    self.resnet_depth = resnet_depth

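# Construction sketch (illustrative, not from the original source): the two decoder
# layouts supported by the constructor above, given an already trained TextProcessor
# (see the test earlier in this section).
def _example_build_seq2seq(processor: TextProcessor, shared_decoder: bool = True) -> "Seq2Seq":
    if shared_decoder:
        # One decoder shared across languages, with the output layer tied to the embeddings.
        return Seq2Seq(text_processor=processor, lang_dec=False, tie_embed=True,
                       enc_layer=6, dec_layer=3, embed_dim=768, intermediate_dim=3072)
    # One decoder and one output layer per language.
    return Seq2Seq(text_processor=processor, lang_dec=True, tie_embed=False,
                   enc_layer=6, dec_layer=3, embed_dim=768, intermediate_dim=3072)
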
cur_capacity = 2 * (max(int(src_input.size(0)), int(tgt_inputs_all.size(1))) ** 3) * int(tgt_inputs_all.size(0))
split_size = int(math.ceil(cur_capacity / max_capacity))
split_size = max(1, int(math.floor(len(tids_all) / split_size)))
tgt_inputs_spl = torch.split(tgt_inputs_all, split_size)
tids_spl = torch.split(tids_all, split_size)
dst_langs_spl = torch.split(dst_langs_all, split_size)
trans_score = dict()

for spl_i in range(len(tgt_inputs_spl)):
    src_input = src_input.view(-1, src_input.size(0)).to(device)
    src_mask = (src_input != text_processor.pad_token_id())
    src_lang = src_lang.to(device)
    encoder_states = model.encode(src_input, src_mask, src_lang.expand(src_input.size()))[0]

    tgt_inputs, tids, dst_langs = tgt_inputs_spl[spl_i], tids_spl[spl_i], dst_langs_spl[spl_i]
    tgt_mask = (tgt_inputs != text_processor.pad_token_id()).to(device)
    tgt_inputs = tgt_inputs.to(device)
    dst_langs = dst_langs.to(device)
    batch_lang = int(dst_langs[0])
    subseq_mask = future_mask(tgt_mask[:, :-1])
    if subseq_mask.device != tgt_inputs.device:
        # Assumed continuation of the truncated snippet: move the mask to the target device.
        subseq_mask = subseq_mask.to(tgt_inputs.device)

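# Worked example (illustrative numbers, not from the original source): how the
# capacity-based split above behaves for a source of length 20 and 50 candidate
# targets padded to length 30, with a budget of 500_000 capacity units. `math` is
# assumed to be imported as in the surrounding code.
def _example_split_size(src_len=20, tgt_len=30, n_candidates=50, max_capacity=500_000):
    cur_capacity = 2 * (max(src_len, tgt_len) ** 3) * n_candidates  # 2_700_000
    n_chunks = int(math.ceil(cur_capacity / max_capacity))          # 6 chunks wanted
    split_size = max(1, int(math.floor(n_candidates / n_chunks)))   # 8 candidates per chunk
    return split_size                                               # torch.split then yields 7 sub-batches
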
def train(options):
    lex_dict = None
    if options.dict_path is not None:
        lex_dict = get_lex_dict(options.dict_path)
    if not os.path.exists(options.model_path):
        os.makedirs(options.model_path)

    text_processor = TextProcessor(options.tokenizer_path)
    assert text_processor.pad_token_id() == 0
    num_processors = max(torch.cuda.device_count(), 1)

    if options.pretrained_path is not None:
        print("Loading pretrained path", options.pretrained_path)
        mt_model = Seq2Seq.load(ImageMassSeq2Seq, options.pretrained_path, tok_dir=options.tokenizer_path)
    else:
        mt_model = ImageMassSeq2Seq(use_proposals=lex_dict is not None, tie_embed=options.tie_embed,
                                    text_processor=text_processor, resnet_depth=options.resnet_depth,
                                    lang_dec=options.lang_decoder, enc_layer=options.encoder_layer,
                                    dec_layer=options.decoder_layer, embed_dim=options.embed_dim,
                                    intermediate_dim=options.intermediate_layer_dim)

    if options.lm_path is not None:
        lm = LM(text_processor=text_processor, enc_layer=options.encoder_layer, embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)
        mt_model.init_from_lm(lm)

    print("Model initialization done!")

    # We assume that the collator function returns a list with the size of number of gpus
    # (in case of cpus, it returns a list of size 1).
    collator = dataset.ImageTextCollator()
    num_batches = max(1, torch.cuda.device_count())

    if options.continue_train:
        with open(os.path.join(options.pretrained_path, "optim"), "rb") as fp:
            optimizer = pickle.load(fp)
    else:
        optimizer = build_optimizer(mt_model, options.learning_rate, warump_steps=options.warmup)

    trainer = ImageMTTrainer(model=mt_model, mask_prob=options.mask_prob, optimizer=optimizer, clip=options.clip,
                             beam_width=options.beam_width, max_len_a=options.max_len_a,
                             max_len_b=options.max_len_b, len_penalty_ratio=options.len_penalty_ratio,
                             fp16=options.fp16, mm_mode=options.mm_mode)

    pin_memory = torch.cuda.is_available()

    mt_train_loader = None
    if options.mt_train_path is not None:
        mt_train_loader = ImageMTTrainer.get_mt_train_data(mt_model, num_processors, options, pin_memory,
                                                           lex_dict=lex_dict)

    mt_dev_loader = None
    if options.mt_dev_path is not None:
        mt_dev_loader = ImageMTTrainer.get_mt_dev_data(mt_model, options, pin_memory, text_processor, trainer,
                                                       lex_dict=lex_dict)

    step, train_epoch = 0, 1
    while options.step > 0 and step < options.step and train_epoch <= 10:
        print("train epoch", train_epoch, "step:", step)
        step = trainer.train_epoch(mt_train_iter=mt_train_loader, max_step=options.step, lex_dict=lex_dict,
                                   mt_dev_iter=mt_dev_loader, saving_path=options.model_path, step=step,
                                   save_opt=False)
        train_epoch += 1

def train(options):
    if not os.path.exists(options.model_path):
        os.makedirs(options.model_path)

    text_processor = TextProcessor(options.tokenizer_path)

    lm_class = ReformerLM if options.reformer else LM
    if options.pretrained_path is None:
        lm = lm_class(text_processor=text_processor, size=options.model_size)
    else:
        lm = lm_class.load(options.pretrained_path)

    if options.reformer:
        lm.config.hidden_dropout_prob = options.dropout
        lm.config.local_attention_probs_dropout_prob = options.dropout
        lm.config.lsh_attention_probs_dropout_prob = options.dropout
    else:
        LMTrainer.config_dropout(lm, options.dropout)

    train_data = dataset.TextDataset(save_cache_dir=options.train_path, max_cache_size=options.cache_size)
    dev_data = dataset.TextDataset(save_cache_dir=options.dev_path, max_cache_size=options.cache_size,
                                   load_all=True)

    if options.continue_train:
        with open(os.path.join(options.pretrained_path, "optim"), "rb") as fp:
            optimizer = pickle.load(fp)
    else:
        optimizer = build_optimizer(lm, options.learning_rate, options.warmup)

    trainer = LMTrainer(model=lm, mask_prob=options.mask_prob, optimizer=optimizer, clip=options.clip)

    collator = dataset.TextCollator(pad_idx=text_processor.pad_token_id())
    train_sampler, dev_sampler = None, None

    pin_memory = torch.cuda.is_available()
    loader = data_utils.DataLoader(train_data, batch_size=options.batch, shuffle=False, pin_memory=pin_memory,
                                   collate_fn=collator, sampler=train_sampler)
    dev_loader = data_utils.DataLoader(dev_data, batch_size=options.batch, shuffle=False, pin_memory=pin_memory,
                                       collate_fn=collator, sampler=dev_sampler)

    step, train_epoch = 0, 1
    while step <= options.step:
        print("train epoch", train_epoch)
        step = trainer.train_epoch(data_iter=loader, dev_data_iter=dev_loader, saving_path=options.model_path,
                                   step=step)
        train_epoch += 1

def __init__(self, root_img_dir: str, data_bin_file: str, max_capacity: int, text_processor: TextProcessor,
             max_img_per_batch: int, lex_dict=None, ngpu=1):
    self.ngpu = ngpu
    self.lex_dict = lex_dict
    self.size_transform = transforms.Resize(256)
    self.crop = transforms.CenterCrop(224)
    self.to_tensor = transforms.ToTensor()
    self.img_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    self.pad_idx = text_processor.pad_token_id()
    self.batches = []
    self.root_img_dir = root_img_dir
    max_capacity *= 1000000
    self.image_batches = []
    self.lang_ids = set()
    self.all_captions = []

    print("Start", datetime.datetime.now())
    cur_batch, cur_imgs, cur_lex_cand_batch = [], [], []
    cur_max_len = 0
    with open(data_bin_file, "rb") as fp:
        self.unique_images, captions = marshal.load(fp)
        lang_id = text_processor.id2token(captions[0][1][0])
        self.lang_ids.add(int(captions[0][1][0]))
        self.lang = text_processor.languages[lang_id] if lang_id in text_processor.languages else 0

        for caption_info in captions:
            image_id, caption = caption_info
            if self.unique_images[image_id].lower().endswith(".png"):
                continue
            caption = torch.LongTensor(caption)
            cur_batch.append(caption)
            self.all_captions.append(caption)
            if self.lex_dict is not None:
                lex_cands = get_lex_suggestions(self.lex_dict, caption, text_processor.pad_token_id())
                cur_lex_cand_batch.append(lex_cands)

            cur_imgs.append(image_id)
            cur_max_len = max(cur_max_len, len(caption))
            batch_capacity_size = 2 * (cur_max_len ** 3) * len(cur_batch)
            if (len(cur_imgs) > max_img_per_batch or batch_capacity_size > max_capacity) and \
                    len(cur_batch[:-1]) >= self.ngpu and len(cur_batch) > 1:
                batch_tensor = pad_sequence(cur_batch[:-1], batch_first=True, padding_value=self.pad_idx)
                lex_cand_batch = None
                if self.lex_dict is not None:
                    lex_cand_batch = pad_sequence(cur_lex_cand_batch[:-1], batch_first=True,
                                                  padding_value=self.pad_idx)
                    cur_lex_cand_batch = [cur_lex_cand_batch[-1]]
                pads = batch_tensor != self.pad_idx
                pad_indices = [int(pads.size(1)) - 1] * int(pads.size(0))
                pindices = torch.nonzero(~pads)
                for (r, c) in pindices:
                    pad_indices[r] = min(pad_indices[r], int(c))

                self.batches.append((batch_tensor, pads, torch.LongTensor(pad_indices), lex_cand_batch))
                self.image_batches.append(cur_imgs[:-1])

                cur_batch = [cur_batch[-1]]
                cur_imgs = [cur_imgs[-1]]
                cur_max_len = len(cur_batch[0])

        if len(cur_batch) > 0:
            batch_tensor = pad_sequence(cur_batch, batch_first=True, padding_value=self.pad_idx)
            pads = batch_tensor != self.pad_idx
            pad_indices = [int(pads.size(1)) - 1] * int(pads.size(0))
            lex_cand_batch = None
            if self.lex_dict is not None:
                lex_cand_batch = pad_sequence(cur_lex_cand_batch, batch_first=True, padding_value=self.pad_idx)
            pindices = torch.nonzero(~pads)
            for (r, c) in pindices:
                pad_indices[r] = min(pad_indices[r], int(c))

            self.batches.append((batch_tensor, pads, torch.LongTensor(pad_indices), lex_cand_batch))
            self.image_batches.append(cur_imgs)

    print("Loaded %d image batches of %d unique images and %d all captions!" % (
        len(self.batches), len(self.unique_images), len(self.all_captions)))
    print("End", datetime.datetime.now())

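# Minimal sketch (illustrative, not from the original source) of the first-pad-index
# bookkeeping used twice above: for every row, record the position of the first pad
# token, defaulting to the last column when a row has no padding.
def _example_first_pad_index():
    pad_idx = 0
    batch_tensor = torch.LongTensor([[4, 5, 6, 0, 0],
                                     [7, 8, 9, 10, 11]])
    pads = batch_tensor != pad_idx                             # True on real tokens
    pad_indices = [int(pads.size(1)) - 1] * int(pads.size(0))
    for (r, c) in torch.nonzero(~pads):
        pad_indices[r] = min(pad_indices[r], int(c))
    return torch.LongTensor(pad_indices)                       # tensor([3, 4])
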
def train(options):
    lex_dict = None
    if options.dict_path is not None:
        lex_dict = get_lex_dict(options.dict_path)
    if not os.path.exists(options.model_path):
        os.makedirs(options.model_path)

    text_processor = TextProcessor(options.tokenizer_path)
    assert text_processor.pad_token_id() == 0

    if options.pretrained_path is not None:
        mt_model = Seq2Seq.load(ImageCaptioning, options.pretrained_path, tok_dir=options.tokenizer_path)
    else:
        mt_model = ImageCaptioning(use_proposals=lex_dict is not None, tie_embed=options.tie_embed,
                                   text_processor=text_processor, resnet_depth=options.resnet_depth,
                                   lang_dec=options.lang_decoder, enc_layer=options.encoder_layer,
                                   dec_layer=options.decoder_layer, embed_dim=options.embed_dim,
                                   intermediate_dim=options.intermediate_layer_dim)

    if options.lm_path is not None:
        lm = LM(text_processor=text_processor, enc_layer=options.encoder_layer, embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)
        mt_model.init_from_lm(lm)

    print("Model initialization done!")

    # We assume that the collator function returns a list with the size of number of gpus
    # (in case of cpus, it returns a list of size 1).
    collator = dataset.ImageTextCollator()
    num_batches = max(1, torch.cuda.device_count())

    if options.continue_train:
        with open(os.path.join(options.pretrained_path, "optim"), "rb") as fp:
            optimizer = pickle.load(fp)
    else:
        optimizer = build_optimizer(mt_model, options.learning_rate, warump_steps=options.warmup)

    trainer = ImageCaptionTrainer(model=mt_model, mask_prob=options.mask_prob, optimizer=optimizer,
                                  clip=options.clip, beam_width=options.beam_width, max_len_a=options.max_len_a,
                                  max_len_b=options.max_len_b, len_penalty_ratio=options.len_penalty_ratio,
                                  fp16=options.fp16, mm_mode=options.mm_mode)

    pin_memory = torch.cuda.is_available()
    img_train_loader = ImageMTTrainer.get_img_loader(collator, dataset.ImageCaptionDataset, options.train_path,
                                                     mt_model, num_batches, options, pin_memory,
                                                     lex_dict=lex_dict)
    img_dev_loader = ImageMTTrainer.get_img_loader(collator, dataset.ImageCaptionDataset, options.dev_path,
                                                   mt_model, num_batches, options, pin_memory,
                                                   lex_dict=lex_dict, shuffle=False, denom=2)

    trainer.reference = None
    if img_dev_loader is not None:
        trainer.reference = []
        generator = (trainer.generator.module if hasattr(trainer.generator, "module") else trainer.generator)
        for data in img_dev_loader:
            for batch in data:
                captions = [b["captions"] for b in batch]
                for caption in captions:
                    refs = get_outputs_until_eos(text_processor.sep_token_id(), caption,
                                                 remove_first_token=True)
                    ref = [generator.seq2seq_model.text_processor.tokenizer.decode(r.numpy()) for r in refs]
                    trainer.reference += ref

    step, train_epoch = 0, 1
    while options.step > 0 and step < options.step:
        print("train epoch", train_epoch)
        step = trainer.train_epoch(img_data_iter=img_train_loader, img_dev_data_iter=img_dev_loader,
                                   max_step=options.step, lex_dict=lex_dict, saving_path=options.model_path,
                                   step=step)
        train_epoch += 1
