示例#1
0
    def __init__(self,
                 text_processor: TextProcessor,
                 config: BertConfig = None,
                 encoder: BertModel = None,
                 enc_layer: int = 6,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        """Assemble the language model: configuration, encoder and output head.

        Args:
            text_processor: Source of the tokenizer vocabulary size, the
                pad/bos/sep token ids and the list of languages.
            config: Ready-made BERT configuration. When omitted, one is
                derived from ``text_processor`` and the size arguments.
            encoder: Existing encoder to reuse. When omitted, a fresh
                ``BertModel`` is built from the configuration and its
                weights are initialised.
            enc_layer: Forwarded to ``lm_config.get_config`` as ``enc_layer``;
                only read when ``config`` is omitted.
            embed_dim: Forwarded to ``lm_config.get_config`` as ``embed_dim``;
                only read when ``config`` is omitted.
            intermediate_dim: Forwarded to ``lm_config.get_config`` as
                ``intermediate_dim``; only read when ``config`` is omitted.
        """
        super(LM, self).__init__()
        self.text_processor = text_processor

        if config is None:
            # No configuration supplied: derive one from the text processor.
            tokenizer = text_processor.tokenizer
            settings = lm_config.get_config(
                vocab_size=tokenizer.get_vocab_size(),
                pad_token_id=text_processor.pad_token_id(),
                bos_token_id=text_processor.bos_token_id(),
                # The separator token id is used as the end-of-sequence id.
                eos_token_id=text_processor.sep_token_id(),
                enc_layer=enc_layer,
                embed_dim=embed_dim,
                intermediate_dim=intermediate_dim,
            )
            # One token-type id per language known to the text processor.
            settings["type_vocab_size"] = len(text_processor.languages)
            config = BertConfig(**settings)
        self.config = config

        # The output head is created before the encoder, as in the original
        # ordering of sub-module construction.
        self.masked_lm = BertOutputLayer(self.config)

        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder = BertModel(self.config)
            self.encoder.init_weights()

        # Link the output head's decoder with the encoder's word embeddings
        # through the encoder's tie-or-clone helper.
        word_embeddings = self.encoder.embeddings.word_embeddings
        self.encoder._tie_or_clone_weights(self.masked_lm.decoder, word_embeddings)
示例#2
0
    def __init__(self, text_processor: TextProcessor, enc_layer: int = 6, embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        """Build the caption encoder and the linear layers that sit on top of it.

        Args:
            text_processor: Source of the tokenizer vocabulary size, the
                pad/bos/sep token ids and the list of languages.
            enc_layer: Forwarded to ``lm_config.get_config`` as ``enc_layer``
                and stored on the instance.
            embed_dim: Forwarded to ``lm_config.get_config`` as ``embed_dim``
                and stored on the instance.
            intermediate_dim: Forwarded to ``lm_config.get_config`` as
                ``intermediate_dim`` and stored on the instance.
        """
        super(Caption2Image, self).__init__()
        self.text_processor = text_processor

        # Keep the raw size arguments available on the instance.
        self.enc_layer = enc_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim

        tokenizer = text_processor.tokenizer
        settings = lm_config.get_config(
            vocab_size=tokenizer.get_vocab_size(),
            pad_token_id=text_processor.pad_token_id(),
            bos_token_id=text_processor.bos_token_id(),
            # The separator token id is used as the end-of-sequence id.
            eos_token_id=text_processor.sep_token_id(),
            enc_layer=enc_layer,
            embed_dim=embed_dim,
            intermediate_dim=intermediate_dim,
        )
        # One token-type id per language known to the text processor.
        settings["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**settings)

        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()

        hidden = self.config.hidden_size
        # Maps each hidden vector to a single scalar.
        self.input_attention = nn.Linear(hidden, 1)
        # Projects one hidden vector to 49 hidden-sized vectors
        # (presumably a 7x7 image grid -- TODO confirm).
        self.decoder = nn.Linear(hidden, 49 * hidden)
示例#3
0
    def __init__(self, text_processor: TextProcessor, lang_dec: bool = True, use_proposals=False, tie_embed=False,
                 enc_layer: int = 6, dec_layer: int = 3, embed_dim: int = 768, intermediate_dim: int = 3072,
                 freeze_image: bool = False, resnet_depth: int = 1, use_obj: bool=False):
        """Build the encoder, decoder(s) and output layer(s) of the seq2seq model.

        Args:
            text_processor: Source of the tokenizer vocabulary size, the
                pad/bos/sep token ids and the list of languages.
            lang_dec: If True, one decoder and one output layer are created
                per entry of ``text_processor.languages``; otherwise a single
                decoder is created.
            use_proposals: If True, the proposal embedding, lexical gate and
                lexical layer norm are created.
            tie_embed: Controls how output layers are linked to the encoder's
                word embeddings (see the inline comments).
            enc_layer: Forwarded to ``lm_config.get_config`` as ``enc_layer``.
            dec_layer: Number of hidden layers set on the decoder config.
            embed_dim: Forwarded to ``lm_config.get_config`` as ``embed_dim``.
            intermediate_dim: Forwarded to ``lm_config.get_config`` as
                ``intermediate_dim``.
            freeze_image: Stored on the instance; not otherwise used in the
                code visible here.
            resnet_depth: Stored on the instance; not otherwise used in the
                code visible here.
            use_obj: Not referenced in the code visible here.
        """
        super(Seq2Seq, self).__init__()
        self.text_processor: TextProcessor = text_processor
        # The separator token id is used as the end-of-sequence id.
        self.config = lm_config.get_config(vocab_size=text_processor.tokenizer.get_vocab_size(),
                                           pad_token_id=text_processor.pad_token_id(),
                                           bos_token_id=text_processor.bos_token_id(),
                                           eos_token_id=text_processor.sep_token_id(),
                                           enc_layer=enc_layer, embed_dim=embed_dim, intermediate_dim=intermediate_dim)

        self.enc_layer = enc_layer
        self.dec_layer = dec_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim
        # One token-type id per language known to the text processor.
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)
        # The decoder config is a copy of the encoder config that differs only
        # in its number of hidden layers.
        dec_config = copy.deepcopy(self.config)
        dec_config.num_hidden_layers = self.dec_layer

        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()
        self.lang_dec = lang_dec
        self.tie_embed = tie_embed
        if not lang_dec:
            # Single decoder used for every language.
            self.decoder = BertDecoderModel(dec_config)
            # Link position, token-type and word embeddings between encoder
            # and decoder.
            # NOTE(review): presumably `_tie_or_clone_weights(a, b)` makes `a`
            # use the weights of `b` (HuggingFace convention) -- confirm.
            self.encoder._tie_or_clone_weights(self.encoder.embeddings.position_embeddings,
                                               self.decoder.embeddings.position_embeddings)
            self.encoder._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                               self.decoder.embeddings.token_type_embeddings)
            self.encoder._tie_or_clone_weights(self.encoder.embeddings.word_embeddings,
                                               self.decoder.embeddings.word_embeddings)

            if tie_embed:
                # A single output layer linked to the word embeddings.
                self.output_layer = BertOutputLayer(dec_config)
                # NOTE(review): the whole output layer module is passed here,
                # not a sub-layer such as `.decoder`; confirm that
                # `_tie_or_clone_weights` handles this as intended.
                self.encoder._tie_or_clone_weights(self.output_layer, self.encoder.embeddings.word_embeddings)
                # Repeats the position-embedding tie already performed above.
                self.encoder._tie_or_clone_weights(self.encoder.embeddings.position_embeddings,
                                                   self.decoder.embeddings.position_embeddings)
                self.decoder._tie_or_clone_weights(self.output_layer, self.decoder.embeddings.word_embeddings)
            else:
                # Without tying, one independent output layer per language.
                self.output_layer = nn.ModuleList([BertOutputLayer(dec_config) for _ in text_processor.languages])

            # When encoder and decoder have the same number of layers, each
            # decoder layer reuses the attention module of the encoder layer
            # at the same index.
            if len(self.encoder.encoder.layer) == len(self.decoder.decoder.layer):
                for i in range(len(self.encoder.encoder.layer)):
                    self.decoder.decoder.layer[i].attention = self.encoder.encoder.layer[i].attention

        else:
            # One decoder per language, each a deep copy of the same freshly
            # built template, plus one output layer per language.
            dec = BertDecoderModel(dec_config)
            self.decoder = nn.ModuleList([copy.deepcopy(dec) for _ in text_processor.languages])
            self.output_layer = nn.ModuleList([BertOutputLayer(dec_config) for _ in text_processor.languages])
            # The loop variable `dec` rebinds the template name used above.
            for i, dec in enumerate(self.decoder):
                if tie_embed:
                    self.encoder._tie_or_clone_weights(self.output_layer[i], self.encoder.embeddings.word_embeddings)
                    # Share the encoder's position-embedding module directly.
                    dec.embeddings.position_embeddings = self.encoder.embeddings.position_embeddings
                # The two calls below run for every decoder, whether or not
                # tie_embed is set.
                dec._tie_or_clone_weights(self.output_layer[i], dec.embeddings.word_embeddings)
                dec._tie_or_clone_weights(self.encoder.embeddings.token_type_embeddings,
                                          dec.embeddings.token_type_embeddings)

        self.use_proposals = use_proposals
        if self.use_proposals:
            # The proposal embedding is the encoder's word-embedding module
            # itself, not a copy.
            self.proposal_embedding = self.encoder.embeddings.word_embeddings
            # Trainable gate of shape (1, hidden_size), initialised to 0.1.
            self.lexical_gate = nn.Parameter(torch.zeros(1, self.config.hidden_size).fill_(0.1), requires_grad=True)
            self.lexical_layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)

        self.freeze_image = freeze_image
        self.resnet_depth = resnet_depth