def __init__(self, text_processor: TextProcessor, config: BertConfig = None, encoder: BertModel = None,
             enc_layer: int = 6, embed_dim: int = 768, intermediate_dim: int = 3072):
    """Masked language model: a BERT encoder with a tied output (LM-head) layer.

    Args:
        text_processor: tokenizer/vocabulary wrapper supplying special-token ids
            and the language list (languages are encoded via token-type ids).
        config: ready-made BertConfig; used as-is when provided.
        encoder: existing BertModel to reuse; a fresh one is built otherwise.
        enc_layer: encoder depth, used only when building a new config.
        embed_dim: hidden size, used only when building a new config.
        intermediate_dim: feed-forward size, used only when building a new config.
    """
    super(LM, self).__init__()
    self.text_processor: TextProcessor = text_processor

    if config is not None:
        self.config = config
    else:
        # Derive a config dict from the tokenizer, then freeze it into a BertConfig.
        self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                           pad_token_id=text_processor.pad_token_id(),
                                           bos_token_id=text_processor.bos_token_id(),
                                           eos_token_id=text_processor.sep_token_id(),
                                           enc_layer=enc_layer,
                                           embed_dim=embed_dim,
                                           intermediate_dim=intermediate_dim)
        # Token-type embeddings are repurposed as per-language embeddings.
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)

    self.masked_lm = BertOutputLayer(self.config)

    if encoder is not None:
        self.encoder = encoder
    else:
        self.encoder: BertModel = BertModel(self.config)
        self.encoder.init_weights()

    # Share the LM-head decoder weights with the input word embeddings.
    self.encoder._tie_or_clone_weights(self.masked_lm.decoder,
                                       self.encoder.embeddings.word_embeddings)
def __init__(self, text_processor: TextProcessor, enc_layer: int = 6, embed_dim: int = 768,
             intermediate_dim: int = 3072):
    """Caption-to-image model: a BERT encoder over caption text whose pooled state
    is projected to a 49-vector image-grid representation.

    Args:
        text_processor: tokenizer/vocabulary wrapper supplying special-token ids
            and the language list (languages are encoded via token-type ids).
        enc_layer: number of encoder layers.
        embed_dim: hidden size of the encoder.
        intermediate_dim: feed-forward size of the encoder.
    """
    super(Caption2Image, self).__init__()
    self.text_processor: TextProcessor = text_processor

    # Build the config dict from the tokenizer, then freeze it into a BertConfig.
    self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                       pad_token_id=text_processor.pad_token_id(),
                                       bos_token_id=text_processor.bos_token_id(),
                                       eos_token_id=text_processor.sep_token_id(),
                                       enc_layer=enc_layer,
                                       embed_dim=embed_dim,
                                       intermediate_dim=intermediate_dim)
    self.enc_layer = enc_layer
    self.embed_dim = embed_dim
    self.intermediate_dim = intermediate_dim
    # Token-type embeddings are repurposed as per-language embeddings.
    self.config["type_vocab_size"] = len(text_processor.languages)
    self.config = BertConfig(**self.config)

    self.encoder = BertEncoderModel(self.config)
    self.encoder.init_weights()

    # Scalar attention score per token, used to pool the encoder states.
    self.input_attention = nn.Linear(self.config.hidden_size, 1)
    # Projects the pooled state to 49 hidden-size vectors
    # (presumably a 7x7 image-patch grid — confirm against the image encoder).
    self.decoder = nn.Linear(self.config.hidden_size, 49 * self.config.hidden_size)
def __init__(self, text_processor: TextProcessor, lang_dec: bool = True, use_proposals=False,
             tie_embed=False, enc_layer: int = 6, dec_layer: int = 3, embed_dim: int = 768,
             intermediate_dim: int = 3072, freeze_image: bool = False, resnet_depth: int = 1,
             use_obj: bool = False):
    """Sequence-to-sequence model: a shared BERT encoder with either one shared
    decoder (``lang_dec=False``) or one decoder per language (``lang_dec=True``).

    Args:
        text_processor: tokenizer/vocabulary wrapper supplying special-token ids
            and the language list (languages are encoded via token-type ids).
        lang_dec: if True, build a separate decoder + output layer per language;
            otherwise share one decoder across languages.
        use_proposals: if True, build the lexical-proposal gating parameters.
        tie_embed: if True, tie the output layer(s) to the encoder word embeddings.
        enc_layer: encoder depth.
        dec_layer: decoder depth.
        embed_dim: hidden size.
        intermediate_dim: feed-forward size.
        freeze_image: stored flag; presumably consumed by image-related code
            elsewhere in the class.
        resnet_depth: stored flag; presumably selects the image backbone depth.
        use_obj: accepted for interface compatibility but not used here.

    Fix vs. previous revision: in the shared-decoder ``tie_embed`` branch, the
    encoder→decoder position-embedding tie was performed a second time even
    though it is already done unconditionally above; the redundant duplicate
    call has been removed (the tie is idempotent, so behavior is unchanged).
    """
    super(Seq2Seq, self).__init__()
    self.text_processor: TextProcessor = text_processor

    # Derive the config dict from the tokenizer, then freeze it into a BertConfig.
    self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                       pad_token_id=text_processor.pad_token_id(),
                                       bos_token_id=text_processor.bos_token_id(),
                                       eos_token_id=text_processor.sep_token_id(),
                                       enc_layer=enc_layer,
                                       embed_dim=embed_dim,
                                       intermediate_dim=intermediate_dim)
    self.enc_layer = enc_layer
    self.dec_layer = dec_layer
    self.embed_dim = embed_dim
    self.intermediate_dim = intermediate_dim
    # Token-type embeddings are repurposed as per-language embeddings.
    self.config["type_vocab_size"] = len(text_processor.languages)
    self.config = BertConfig(**self.config)

    # Decoder shares the encoder config except for its own depth.
    dec_config = copy.deepcopy(self.config)
    dec_config.num_hidden_layers = self.dec_layer

    self.encoder = BertEncoderModel(self.config)
    self.encoder.init_weights()
    self.lang_dec = lang_dec
    self.tie_embed = tie_embed

    if not lang_dec:
        # One decoder shared across all languages; tie all embedding tables
        # to the encoder's.
        self.decoder = BertDecoderModel(dec_config)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.position_embeddings,
                                           self.decoder.embeddings.position_embeddings)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                           self.decoder.embeddings.token_type_embeddings)
        self.encoder._tie_or_clone_weights(self.encoder.embeddings.word_embeddings,
                                           self.decoder.embeddings.word_embeddings)
        if tie_embed:
            # Single output layer tied to both encoder and decoder word embeddings.
            # (A redundant second position-embedding tie was removed here; it was
            # already performed above.)
            self.output_layer = BertOutputLayer(dec_config)
            self.encoder._tie_or_clone_weights(self.output_layer,
                                               self.encoder.embeddings.word_embeddings)
            self.decoder._tie_or_clone_weights(self.output_layer,
                                               self.decoder.embeddings.word_embeddings)
        else:
            # Untied: one output layer per language.
            self.output_layer = nn.ModuleList([BertOutputLayer(dec_config)
                                               for _ in text_processor.languages])
        # When encoder and decoder have the same depth, share self-attention
        # modules layer-for-layer.
        if len(self.encoder.encoder.layer) == len(self.decoder.decoder.layer):
            for i in range(len(self.encoder.encoder.layer)):
                self.decoder.decoder.layer[i].attention = self.encoder.encoder.layer[i].attention
    else:
        # One decoder and one output layer per language.
        dec = BertDecoderModel(dec_config)
        self.decoder = nn.ModuleList([copy.deepcopy(dec) for _ in text_processor.languages])
        self.output_layer = nn.ModuleList([BertOutputLayer(dec_config)
                                           for _ in text_processor.languages])
        for i, dec in enumerate(self.decoder):
            if tie_embed:
                self.encoder._tie_or_clone_weights(self.output_layer[i],
                                                   self.encoder.embeddings.word_embeddings)
            dec.embeddings.position_embeddings = self.encoder.embeddings.position_embeddings
            dec._tie_or_clone_weights(self.output_layer[i], dec.embeddings.word_embeddings)
            dec._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                      dec.embeddings.token_type_embeddings)

    self.use_proposals = use_proposals
    if self.use_proposals:
        # Lexical-proposal path: reuse the encoder word embeddings and gate them.
        self.proposal_embedding = self.encoder.embeddings.word_embeddings
        self.lexical_gate = nn.Parameter(torch.zeros(1, self.config.hidden_size).fill_(0.1),
                                         requires_grad=True)
        self.lexical_layer_norm = nn.LayerNorm(self.config.hidden_size,
                                               eps=self.config.layer_norm_eps)

    self.freeze_image = freeze_image
    self.resnet_depth = resnet_depth