# Example 1 (scraped heading "Beispiel #1"; the stray "0" was a vote-count artifact)
    def __init__(self,
                 text_processor: TextProcessor,
                 config: BertConfig = None,
                 encoder: BertModel = None,
                 enc_layer: int = 6,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        """Masked language model: a BERT encoder plus a masked-LM output
        head whose decoder weights are tied to the encoder's word
        embeddings.

        Args:
            text_processor: tokenizer/vocabulary wrapper used to size the
                vocabulary and supply the special-token ids.
            config: optional ready-made BertConfig; when omitted, one is
                derived from ``text_processor`` and the size arguments.
            encoder: optional pre-built BertModel to reuse; when omitted a
                fresh one is created and initialized.
            enc_layer: number of encoder hidden layers (used only when
                ``config`` is None).
            embed_dim: hidden size (used only when ``config`` is None).
            intermediate_dim: feed-forward size (used only when ``config``
                is None).
        """
        super(LM, self).__init__()
        self.text_processor: TextProcessor = text_processor

        if config is None:
            # Derive a config dict from the tokenizer vocabulary and the
            # requested architecture sizes, then freeze it as a BertConfig.
            cfg_dict = lm_config.get_config(
                vocab_size=text_processor.tokenizer.get_vocab_size(),
                pad_token_id=text_processor.pad_token_id(),
                bos_token_id=text_processor.bos_token_id(),
                eos_token_id=text_processor.sep_token_id(),
                enc_layer=enc_layer,
                embed_dim=embed_dim,
                intermediate_dim=intermediate_dim)
            # One token-type id per language handled by the processor.
            cfg_dict["type_vocab_size"] = len(text_processor.languages)
            self.config = BertConfig(**cfg_dict)
        else:
            self.config = config

        self.masked_lm = BertOutputLayer(self.config)
        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder: BertModel = BertModel(self.config)
            self.encoder.init_weights()
        # Share weights between the LM head's decoder and the encoder's
        # input word embeddings.
        self.encoder._tie_or_clone_weights(
            self.masked_lm.decoder, self.encoder.embeddings.word_embeddings)
    def __init__(self,
                 text_processor: TextProcessor,
                 enc_layer: int = 6,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        """Sentence-similarity model: a BERT encoder with a learned linear
        scorer that maps each hidden state to a scalar attention logit.

        Args:
            text_processor: tokenizer/vocabulary wrapper used to size the
                vocabulary and supply the special-token ids.
            enc_layer: number of encoder hidden layers.
            embed_dim: hidden size.
            intermediate_dim: feed-forward size.
        """
        super(SenSim, self).__init__()
        self.text_processor: TextProcessor = text_processor

        # Keep the architecture hyper-parameters around on the instance.
        self.enc_layer = enc_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim

        # Build the config dict from the tokenizer vocabulary, then freeze
        # it as a BertConfig.
        cfg = lm_config.get_config(
            vocab_size=text_processor.tokenizer.get_vocab_size(),
            pad_token_id=text_processor.pad_token_id(),
            bos_token_id=text_processor.bos_token_id(),
            eos_token_id=text_processor.sep_token_id(),
            enc_layer=enc_layer,
            embed_dim=embed_dim,
            intermediate_dim=intermediate_dim)
        # One token-type id per language handled by the processor.
        cfg["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**cfg)

        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()
        # Scores each token's hidden state with a single scalar.
        self.input_attention = nn.Linear(self.config.hidden_size, 1)
# Example 3 (scraped heading "Beispiel #3"; the stray "0" was a vote-count artifact)
    def __init__(self,
                 text_processor: TextProcessor,
                 lang_dec: bool = True,
                 use_proposals=False,
                 tie_embed=False,
                 enc_layer: int = 6,
                 dec_layer: int = 3,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072,
                 freeze_image: bool = False,
                 resnet_depth: int = 1):
        """Sequence-to-sequence model: one shared BERT encoder plus either
        a single shared decoder (``lang_dec=False``) or one decoder copy
        per language (``lang_dec=True``), with several weight-tying
        schemes between encoder, decoder(s) and output layer(s).

        Args:
            text_processor: tokenizer/vocabulary wrapper; its language list
                sizes the token-type vocabulary and the per-language
                decoder/output-layer module lists.
            lang_dec: if True, clone a separate decoder and output layer
                for every language; if False, share one decoder.
            use_proposals: if True, set up lexical-proposal embedding,
                gate and layer-norm parameters.
            tie_embed: if True, tie output layer(s) to the encoder's word
                embeddings and share position embeddings with the decoder.
            enc_layer: encoder depth in hidden layers.
            dec_layer: decoder depth in hidden layers.
            embed_dim: hidden size.
            intermediate_dim: feed-forward size.
            freeze_image: stored on the instance; presumably consumed by
                image-branch code outside this method — confirm.
            resnet_depth: stored on the instance; presumably consumed by
                image-branch code outside this method — confirm.
        """
        super(Seq2Seq, self).__init__()
        self.text_processor: TextProcessor = text_processor
        # Build the config dict from the tokenizer vocabulary and sizes.
        self.config = lm_config.get_config(
            vocab_size=text_processor.tokenizer.get_vocab_size(),
            pad_token_id=text_processor.pad_token_id(),
            bos_token_id=text_processor.bos_token_id(),
            eos_token_id=text_processor.sep_token_id(),
            enc_layer=enc_layer,
            embed_dim=embed_dim,
            intermediate_dim=intermediate_dim)

        self.enc_layer = enc_layer
        self.dec_layer = dec_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim
        # One token-type id per language, then freeze as a BertConfig.
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)
        # Decoder uses the same config except for its own (shallower) depth.
        dec_config = copy.deepcopy(self.config)
        dec_config.num_hidden_layers = self.dec_layer

        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()
        self.lang_dec = lang_dec
        self.tie_embed = tie_embed
        if not lang_dec:
            # Single shared decoder: tie its position, token-type and word
            # embeddings to the encoder's.
            self.decoder = BertDecoderModel(dec_config)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.position_embeddings,
                self.decoder.embeddings.position_embeddings)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.token_type_embeddings,
                self.decoder.embeddings.token_type_embeddings)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.word_embeddings,
                self.decoder.embeddings.word_embeddings)

            if tie_embed:
                # Single output layer tied to both encoder and decoder
                # word embeddings.
                self.output_layer = BertOutputLayer(dec_config)
                self.encoder._tie_or_clone_weights(
                    self.output_layer, self.encoder.embeddings.word_embeddings)
                # NOTE(review): this position-embedding tie repeats the one
                # performed just above — presumably redundant but harmless;
                # confirm before removing.
                self.encoder._tie_or_clone_weights(
                    self.encoder.embeddings.position_embeddings,
                    self.decoder.embeddings.position_embeddings)
                self.decoder._tie_or_clone_weights(
                    self.output_layer, self.decoder.embeddings.word_embeddings)
            else:
                # Untied: one independent output layer per language.
                self.output_layer = nn.ModuleList([
                    BertOutputLayer(dec_config)
                    for _ in text_processor.languages
                ])

            # When encoder and decoder have the same depth, share the
            # self-attention modules layer-for-layer.
            if len(self.encoder.encoder.layer) == len(
                    self.decoder.decoder.layer):
                for i in range(len(self.encoder.encoder.layer)):
                    self.decoder.decoder.layer[
                        i].attention = self.encoder.encoder.layer[i].attention

        else:
            # Language-specific decoders: deep-copy one prototype decoder
            # per language, each with its own output layer.
            dec = BertDecoderModel(dec_config)
            self.decoder = nn.ModuleList(
                [copy.deepcopy(dec) for _ in text_processor.languages])
            self.output_layer = nn.ModuleList([
                BertOutputLayer(dec_config) for _ in text_processor.languages
            ])
            for i, dec in enumerate(self.decoder):
                if tie_embed:
                    # Tie each language's output layer to the encoder's
                    # word embeddings and share its position embeddings.
                    self.encoder._tie_or_clone_weights(
                        self.output_layer[i],
                        self.encoder.embeddings.word_embeddings)
                    dec.embeddings.position_embeddings = self.encoder.embeddings.position_embeddings
                # Always tie each output layer to its own decoder's word
                # embeddings, and share token-type embeddings with the
                # encoder.
                dec._tie_or_clone_weights(self.output_layer[i],
                                          dec.embeddings.word_embeddings)
                dec._tie_or_clone_weights(
                    self.encoder.embeddings.token_type_embeddings,
                    dec.embeddings.token_type_embeddings)

        self.use_proposals = use_proposals
        if self.use_proposals:
            # Lexical-proposal machinery: reuse the encoder's word
            # embeddings, with a learned gate (initialized to 0.1) and a
            # layer norm over the hidden size.
            self.proposal_embedding = self.encoder.embeddings.word_embeddings
            self.lexical_gate = nn.Parameter(torch.zeros(
                1, self.config.hidden_size).fill_(0.1),
                                             requires_grad=True)
            self.lexical_layer_norm = nn.LayerNorm(
                self.config.hidden_size, eps=self.config.layer_norm_eps)

        self.freeze_image = freeze_image
        self.resnet_depth = resnet_depth