def __init__(self, text_processor: TextProcessor, config: BertConfig = None, encoder: BertModel = None,
             enc_layer: int = 6, embed_dim: int = 768, intermediate_dim: int = 3072):
    super(LM, self).__init__()
    self.text_processor: TextProcessor = text_processor

    if config is not None:
        self.config = config
    else:
        # Build a BertConfig from the tokenizer's vocabulary and the requested model dimensions.
        self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                           pad_token_id=text_processor.pad_token_id(),
                                           bos_token_id=text_processor.bos_token_id(),
                                           eos_token_id=text_processor.sep_token_id(),
                                           enc_layer=enc_layer, embed_dim=embed_dim,
                                           intermediate_dim=intermediate_dim)
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)

    self.masked_lm = BertOutputLayer(self.config)
    if encoder is None:
        self.encoder: BertModel = BertModel(self.config)
        self.encoder.init_weights()
    else:
        self.encoder = encoder
    # Tie the masked-LM output projection to the encoder's input word embeddings.
    self.encoder._tie_or_clone_weights(self.masked_lm.decoder, self.encoder.embeddings.word_embeddings)
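# Illustrative usage sketch, not part of the original module: building an LM either from scratch
# or around an already-trained encoder. The tokenizer directory below is a hypothetical path.
def _example_build_lm():
    tp = TextProcessor("path/to/tokenizer")  # hypothetical tokenizer directory
    # Fresh model: the config is derived from the tokenizer's vocabulary and the size arguments.
    lm = LM(text_processor=tp, enc_layer=6, embed_dim=768, intermediate_dim=3072)
    # Wrapping an existing encoder: its weights are reused, and the masked-LM head is still
    # (re)tied to that encoder's word embeddings.
    lm_wrapped = LM(text_processor=tp, config=lm.config, encoder=lm.encoder)
    return lm, lm_wrapped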
def mask_text(mask_prob, pads, texts, text_processor: TextProcessor, mask_eos: bool = True):
    assert 0 < mask_prob < 1
    # `pads` is expected to be True on real (non-pad) tokens and False on padding positions.
    mask = torch.empty(texts.size()).uniform_(0, 1) < mask_prob
    mask[~pads] = False  # We should not mask pads.
    if not mask_eos:
        eos_idx = texts == text_processor.sep_token_id()
        mask[eos_idx] = False  # We should not mask end-of-sentence (usually in case of BART training).

    masked_ids = texts[mask]
    # BERT-style replacement: 80% [MASK], 10% a random non-special token, 10% the original token.
    random_index = lambda: random.randint(len(text_processor.special_tokens), text_processor.vocab_size() - 1)
    rand_select = lambda r, c: text_processor.mask_token_id() if r < 0.8 else (
        random_index() if r < 0.9 else int(masked_ids[c]))
    replacements = list(map(lambda i: rand_select(random.random(), i), range(masked_ids.size(0))))
    texts[mask] = torch.LongTensor(replacements)
    return mask, masked_ids, texts
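# Illustrative sketch, not part of the original module: what mask_text returns for a toy batch.
# The tokenizer path and token ids are made up; adjust the call to LM.mask_text(...) if the
# function is defined as a static method on the LM class.
def _example_mask_text():
    tp = TextProcessor("path/to/tokenizer")  # hypothetical tokenizer directory
    texts = torch.LongTensor([[tp.bos_token_id(), 15, 27, 42, tp.sep_token_id(), tp.pad_token_id()]])
    pads = texts != tp.pad_token_id()  # True on real tokens, False on padding
    mask, masked_ids, corrupted = mask_text(0.15, pads, texts.clone(), tp, mask_eos=False)
    # mask:       boolean tensor marking the positions selected for masking
    # masked_ids: the original ids at those positions (the prediction targets)
    # corrupted:  the input with ~80% of selected positions replaced by the mask token,
    #             ~10% by a random non-special token, and ~10% left unchanged
    return mask, masked_ids, corrupted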
def __init__(self, text_processor: TextProcessor, enc_layer: int = 6, embed_dim: int = 768,
             intermediate_dim: int = 3072):
    super(SenSim, self).__init__()
    self.text_processor: TextProcessor = text_processor
    self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                       pad_token_id=text_processor.pad_token_id(),
                                       bos_token_id=text_processor.bos_token_id(),
                                       eos_token_id=text_processor.sep_token_id(),
                                       enc_layer=enc_layer, embed_dim=embed_dim,
                                       intermediate_dim=intermediate_dim)
    self.enc_layer = enc_layer
    self.embed_dim = embed_dim
    self.intermediate_dim = intermediate_dim
    self.config["type_vocab_size"] = len(text_processor.languages)
    self.config = BertConfig(**self.config)
    self.encoder = BertEncoderModel(self.config)
    self.encoder.init_weights()
    self.input_attention = nn.Linear(self.config.hidden_size, 1)
def __init__(self, text_processor: TextProcessor, config: ReformerConfig = None, size: int = 1):
    """
    :param size: config size: 1 = small, 2 = medium, 3 = base.
    """
    super(ReformerLM, self).__init__()
    self.text_processor: TextProcessor = text_processor
    if config is not None:
        self.config = config
    else:
        config_func = _small_config if size == 1 else (_base_config if size == 3 else _medium_config)
        self.config = config_func(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                  pad_token_id=text_processor.pad_token_id(),
                                  eos_token_id=text_processor.sep_token_id())
        self.config = ReformerConfig(**self.config)
    reformer = ReformerModelWithLMHead(self.config)
    self.lm_head: ReformerOnlyLMHead = reformer.lm_head
    self.encoder: ReformerModel = reformer.reformer
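# Illustrative sketch, not part of the original module: how the `size` argument selects a config.
# The tokenizer path is a hypothetical stand-in.
def _example_reformer_sizes():
    tp = TextProcessor("path/to/tokenizer")  # hypothetical tokenizer directory
    small = ReformerLM(text_processor=tp, size=1)   # uses _small_config
    medium = ReformerLM(text_processor=tp, size=2)  # uses _medium_config (any size other than 1 or 3)
    base = ReformerLM(text_processor=tp, size=3)    # uses _base_config
    return small, medium, base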
def __init__(self, text_processor: TextProcessor, lang_dec: bool = True, use_proposals=False, tie_embed=False,
             enc_layer: int = 6, dec_layer: int = 3, embed_dim: int = 768, intermediate_dim: int = 3072,
             freeze_image: bool = False, resnet_depth: int = 1):
    super(Seq2Seq, self).__init__()
    self.text_processor: TextProcessor = text_processor

    self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                       pad_token_id=text_processor.pad_token_id(),
                                       bos_token_id=text_processor.bos_token_id(),
                                       eos_token_id=text_processor.sep_token_id(),
                                       enc_layer=enc_layer, embed_dim=embed_dim,
                                       intermediate_dim=intermediate_dim)
    self.enc_layer = enc_layer
    self.dec_layer = dec_layer
    self.embed_dim = embed_dim
    self.intermediate_dim = intermediate_dim
    self.config["type_vocab_size"] = len(text_processor.languages)
    self.config = BertConfig(**self.config)

    # The decoder uses the same configuration as the encoder, but with its own depth.
    dec_config = copy.deepcopy(self.config)
    dec_config.num_hidden_layers = self.dec_layer

    self.encoder = BertEncoderModel(self.config)
    self.encoder.init_weights()

    self.lang_dec = lang_dec
    self.tie_embed = tie_embed
    if not lang_dec:
        # A single decoder shared across languages; its embeddings are tied to the encoder's.
        self.decoder = BertDecoderModel(dec_config)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.position_embeddings,
                                           self.decoder.embeddings.position_embeddings)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                           self.decoder.embeddings.token_type_embeddings)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.word_embeddings,
                                           self.decoder.embeddings.word_embeddings)
        if tie_embed:
            # One output layer, tied to the shared word embeddings.
            self.output_layer = BertOutputLayer(dec_config)
            self.encoder._tie_or_clone_weights(self.output_layer, self.encoder.embeddings.word_embeddings)
            self.encoder._tie_or_clone_weights(self.encoder.embeddings.position_embeddings,
                                               self.decoder.embeddings.position_embeddings)
            self.decoder._tie_or_clone_weights(self.output_layer, self.decoder.embeddings.word_embeddings)
        else:
            # One output layer per language.
            self.output_layer = nn.ModuleList([BertOutputLayer(dec_config) for _ in text_processor.languages])

        if len(self.encoder.encoder.layer) == len(self.decoder.decoder.layer):
            # When encoder and decoder have the same depth, share their self-attention blocks.
            for i in range(len(self.encoder.encoder.layer)):
                self.decoder.decoder.layer[i].attention = self.encoder.encoder.layer[i].attention
    else:
        # One decoder and one output layer per target language.
        dec = BertDecoderModel(dec_config)
        self.decoder = nn.ModuleList([copy.deepcopy(dec) for _ in text_processor.languages])
        self.output_layer = nn.ModuleList([BertOutputLayer(dec_config) for _ in text_processor.languages])
        for i, dec in enumerate(self.decoder):
            if tie_embed:
                self.encoder._tie_or_clone_weights(self.output_layer[i], self.encoder.embeddings.word_embeddings)
                dec.embeddings.position_embeddings = self.encoder.embeddings.position_embeddings
            dec._tie_or_clone_weights(self.output_layer[i], dec.embeddings.word_embeddings)
            dec._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                      dec.embeddings.token_type_embeddings)

    self.use_proposals = use_proposals
    if self.use_proposals:
        # Lexical proposals reuse the encoder's word-embedding table; a learned gate and a layer
        # norm are created for combining them with the decoder states.
        self.proposal_embedding = self.encoder.embeddings.word_embeddings
        self.lexical_gate = nn.Parameter(torch.zeros(1, self.config.hidden_size).fill_(0.1), requires_grad=True)
        self.lexical_layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)

    self.freeze_image = freeze_image
    self.resnet_depth = resnet_depth
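# Illustrative sketch, not part of the original module: the two decoder layouts selected by
# `lang_dec`. The tokenizer path is a hypothetical stand-in.
def _example_seq2seq_layouts():
    tp = TextProcessor("path/to/tokenizer")  # hypothetical tokenizer directory
    shared = Seq2Seq(text_processor=tp, lang_dec=False, tie_embed=True)
    per_lang = Seq2Seq(text_processor=tp, lang_dec=True)
    # lang_dec=False -> one BertDecoderModel whose embeddings are tied to the encoder's.
    assert isinstance(shared.decoder, BertDecoderModel)
    # lang_dec=True -> one decoder and one output layer per language known to the tokenizer.
    assert len(per_lang.decoder) == len(per_lang.output_layer) == len(tp.languages)
    return shared, per_lang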
def train(options):
    lex_dict = None
    if options.dict_path is not None:
        lex_dict = get_lex_dict(options.dict_path)
    if not os.path.exists(options.model_path):
        os.makedirs(options.model_path)

    text_processor = TextProcessor(options.tokenizer_path)
    assert text_processor.pad_token_id() == 0

    if options.pretrained_path is not None:
        mt_model = Seq2Seq.load(ImageCaptioning, options.pretrained_path, tok_dir=options.tokenizer_path)
    else:
        mt_model = ImageCaptioning(use_proposals=lex_dict is not None, tie_embed=options.tie_embed,
                                   text_processor=text_processor, resnet_depth=options.resnet_depth,
                                   lang_dec=options.lang_decoder, enc_layer=options.encoder_layer,
                                   dec_layer=options.decoder_layer, embed_dim=options.embed_dim,
                                   intermediate_dim=options.intermediate_layer_dim)

    if options.lm_path is not None:
        lm = LM(text_processor=text_processor, enc_layer=options.encoder_layer,
                embed_dim=options.embed_dim, intermediate_dim=options.intermediate_layer_dim)
        mt_model.init_from_lm(lm)

    print("Model initialization done!")

    # We assume that the collator function returns a list with the size of the number of GPUs
    # (a single-element list when running on CPU).
    collator = dataset.ImageTextCollator()
    num_batches = max(1, torch.cuda.device_count())

    if options.continue_train:
        with open(os.path.join(options.pretrained_path, "optim"), "rb") as fp:
            optimizer = pickle.load(fp)
    else:
        optimizer = build_optimizer(mt_model, options.learning_rate, warump_steps=options.warmup)

    trainer = ImageCaptionTrainer(model=mt_model, mask_prob=options.mask_prob, optimizer=optimizer,
                                  clip=options.clip, beam_width=options.beam_width,
                                  max_len_a=options.max_len_a, max_len_b=options.max_len_b,
                                  len_penalty_ratio=options.len_penalty_ratio, fp16=options.fp16,
                                  mm_mode=options.mm_mode)

    pin_memory = torch.cuda.is_available()
    img_train_loader = ImageMTTrainer.get_img_loader(collator, dataset.ImageCaptionDataset, options.train_path,
                                                     mt_model, num_batches, options, pin_memory,
                                                     lex_dict=lex_dict)
    img_dev_loader = ImageMTTrainer.get_img_loader(collator, dataset.ImageCaptionDataset, options.dev_path,
                                                   mt_model, num_batches, options, pin_memory,
                                                   lex_dict=lex_dict, shuffle=False, denom=2)

    # Decode the dev captions once so they can serve as references during evaluation.
    trainer.reference = None
    if img_dev_loader is not None:
        trainer.reference = []
        generator = (trainer.generator.module if hasattr(trainer.generator, "module") else trainer.generator)
        for data in img_dev_loader:
            for batch in data:
                captions = [b["captions"] for b in batch]
                for caption in captions:
                    refs = get_outputs_until_eos(text_processor.sep_token_id(), caption, remove_first_token=True)
                    ref = [generator.seq2seq_model.text_processor.tokenizer.decode(r.numpy()) for r in refs]
                    trainer.reference += ref

    step, train_epoch = 0, 1
    while options.step > 0 and step < options.step:
        print("train epoch", train_epoch)
        step = trainer.train_epoch(img_data_iter=img_train_loader, img_dev_data_iter=img_dev_loader,
                                   max_step=options.step, lex_dict=lex_dict,
                                   saving_path=options.model_path, step=step)
        train_epoch += 1
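# Illustrative sketch, not part of the original script: the option fields that train() reads,
# collected in a plain namespace. All values below are assumptions for demonstration only,
# not the project's defaults, and the paths are hypothetical.
def _example_train_options():
    from types import SimpleNamespace
    return SimpleNamespace(
        dict_path=None, model_path="models/caption", tokenizer_path="tokenizer",
        pretrained_path=None, lm_path=None, continue_train=False,
        tie_embed=True, lang_decoder=False, resnet_depth=1,
        encoder_layer=6, decoder_layer=3, embed_dim=768, intermediate_layer_dim=3072,
        learning_rate=1e-4, warmup=4000, mask_prob=0.15, clip=1.0,
        beam_width=4, max_len_a=1.1, max_len_b=5, len_penalty_ratio=0.8,
        fp16=False, mm_mode="mixed",
        train_path="train.batches", dev_path="dev.batches",
        step=100000)
# A run would then look like: train(_example_train_options())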