Example #1
    def __init__(self, config, language_pretrained_model_path=None):
        super(VisualLinguisticBert, self).__init__(config)

        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # for compatibility with RoBERTa
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:
            self.visual_1x1_text = nn.Linear(config.visual_size, config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size, config.hidden_size)
        if config.visual_ln:
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size, eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(torch.as_tensor(self.config.visual_scale_text_init, dtype=torch.float),
                                             requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(torch.as_tensor(self.config.visual_scale_object_init, dtype=torch.float),
                                               requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        self.encoder = BertEncoder(config)

        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)

        # load language pretrained model
        if language_pretrained_model_path is not None:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])
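
A quick note on what this constructor expects: it only reads a fixed set of attributes from config. The sketch below lists them with placeholder values (a rough assumption for illustration, not the project's actual defaults):

from types import SimpleNamespace

# Hypothetical config covering the attributes read in the __init__ above;
# the real project builds this object from its own config files instead.
vlbert_config = SimpleNamespace(
    vocab_size=30522,
    hidden_size=768,
    max_position_embeddings=512,
    type_vocab_size=3,
    hidden_dropout_prob=0.1,
    position_padding_idx=-1,      # position-id offset; kept for RoBERTa compatibility
    visual_size=768,              # equal to hidden_size, so no 1x1 projections are created
    visual_ln=True,               # LayerNorm on visual features instead of a learned scalar scale
    visual_scale_text_init=0.0,
    visual_scale_object_init=0.0,
    with_pooler=True,
    word_embedding_frozen=False,
)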
Example #2
    def __init__(self, config):
        super(RobertaLMHead, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = BertLayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
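
RobertaLMHead.__init__ above only declares the layers. In typical RoBERTa implementations the forward pass chains them as dense -> GELU -> LayerNorm -> vocabulary projection plus the standalone bias; a minimal sketch under that assumption (the forward is not shown in the snippet):

import torch.nn.functional as F

def roberta_lm_head_forward(self, features):
    # features: [batch, seq_len, hidden_size] hidden states from the encoder
    x = self.dense(features)
    x = F.gelu(x)
    x = self.layer_norm(x)
    # project back onto the vocabulary; the decoder Linear has bias=False,
    # so the separate bias parameter is added here
    return self.decoder(x) + self.bias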
Example #3
    def __init__(self,
                 bert_config,
                 input_size,
                 output_all_encoded_layers=False):
        super(BertEncoderWrapper, self).__init__()
        self.bert_config = bert_config
        self.output_all_encoded_layers = output_all_encoded_layers
        self.input_transform = nn.Linear(input_size, bert_config.hidden_size)
        self.with_position_embeddings = False if 'with_position_embeddings' not in bert_config \
            else bert_config.with_position_embeddings
        if self.with_position_embeddings:
            self.position_embedding = nn.Embedding(
                bert_config.max_position_embeddings, bert_config.hidden_size)
            self.LayerNorm = BertLayerNorm(bert_config.hidden_size, eps=1e-12)
            self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
        self.bert_encoder = BertEncoder(bert_config)

        self.apply(self.init_bert_weights)
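
BertEncoderWrapper's job is to project arbitrary input features to BERT's hidden size, optionally add position embeddings, and hand the result to BertEncoder. The input-side transformation in isolation, with toy dimensions (placeholder sizes; the real ones come from input_size and bert_config.hidden_size):

import torch
import torch.nn as nn

input_size, hidden_size, max_positions = 1024, 768, 512     # placeholder sizes
input_transform = nn.Linear(input_size, hidden_size)
position_embedding = nn.Embedding(max_positions, hidden_size)

features = torch.randn(2, 36, input_size)                   # e.g. 36 region features per sample
x = input_transform(features)                                # [2, 36, 768]
positions = torch.arange(36).unsqueeze(0).expand(2, -1)
x = x + position_embedding(positions)                        # only when with_position_embeddings
print(x.shape)                                               # torch.Size([2, 36, 768])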
Example #4
    def __init__(self, dummy_config):
        super(LXMERT, self).__init__(dummy_config)
        
        frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
        # self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=frcnn_cfg)
        self.backbone, self.roi_heads = build_image_encoder()
        self.lxmert_vqa = LxmertForPreTraining.from_pretrained("unc-nlp/lxmert-base-uncased")
        # self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")
        self.tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
        self.image_preprocess = Preprocess(frcnn_cfg)
        
        hid_dim = self.lxmert_vqa.config.hidden_size
        # transform = BertPredictionHeadTransform(self.config.NETWORK.VLBERT)

        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim),
            GELU(),
            BertLayerNorm(hid_dim),
            nn.Dropout(self.config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            nn.Linear(hid_dim, self.config.NETWORK.CLASSIFIER_CLASS),
        )
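
The logit_fc head above just maps LXMERT's pooled cross-modal representation to answer logits. A toy shape check using standard PyTorch stand-ins (hid_dim and the class count are placeholder assumptions; the snippet takes them from lxmert_vqa.config.hidden_size and config.NETWORK.CLASSIFIER_CLASS):

import torch
import torch.nn as nn

hid_dim, num_classes, dropout_p = 768, 3129, 0.1    # placeholder values
logit_fc = nn.Sequential(
    nn.Linear(hid_dim, hid_dim),
    nn.GELU(),                # stand-in for the custom GELU in the snippet
    nn.LayerNorm(hid_dim),    # stand-in for BertLayerNorm
    nn.Dropout(dropout_p),
    nn.Linear(hid_dim, num_classes),
)
pooled = torch.randn(2, hid_dim)     # e.g. LXMERT pooled_output for a batch of 2
print(logit_fc(pooled).shape)        # torch.Size([2, 3129])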
Example #5
class VisualLinguisticBert(BaseModel):
    def __init__(self,
                 config,
                 language_pretrained_model_path=None,
                 finetune_strategy='standard',
                 is_policy_net=False):
        super(VisualLinguisticBert, self).__init__(config)

        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # for compatibility with RoBERTa
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:
            self.visual_1x1_text = nn.Linear(config.visual_size,
                                             config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size,
                                               config.hidden_size)
        if config.visual_ln:
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size,
                                                  eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_text_init, dtype=torch.float),
                                             requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_object_init, dtype=torch.float),
                                               requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        self.encoder = BertEncoder(config, finetune_strategy=finetune_strategy)

        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(
                self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(
                self.config.visual_scale_object_init)

        # whether this instance acts as the policy network
        self.is_policy_net = is_policy_net

        # load language pretrained model
        if language_pretrained_model_path is not None and not is_policy_net:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(
                NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(
                self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])

    def word_embeddings_wrapper(self, input_ids):
        if self.config.word_embedding_frozen:
            word_embeddings = self.word_embeddings(input_ids)
            word_embeddings[input_ids < NUM_SPECIAL_WORDS] \
                = self.special_word_embeddings(input_ids[input_ids < NUM_SPECIAL_WORDS])
            return word_embeddings
        else:
            return self.word_embeddings(input_ids)

    def forward(self,
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                object_mask,
                output_all_encoded_layers=True,
                output_text_and_object_separately=False,
                output_attention_probs=False,
                policy=None):

        # get seamless concatenate embeddings and mask
        embedding_output, attention_mask, text_mask_new, object_mask_new = self.embedding(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, object_mask)

        # We create a 4D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # extended_attention_mask = 1.0 - extended_attention_mask
        # extended_attention_mask[extended_attention_mask != 0] = float('-inf')

        if output_attention_probs:
            encoded_layers, attention_probs = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs,
                policy=policy)
        else:
            encoded_layers = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs,
                policy=policy)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(
            sequence_output) if self.config.with_pooler else None
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        if output_text_and_object_separately:
            if not output_all_encoded_layers:
                encoded_layers = [encoded_layers]
            encoded_layers_text = []
            encoded_layers_object = []
            for encoded_layer in encoded_layers:
                max_text_len = text_input_ids.shape[1]
                max_object_len = object_vl_embeddings.shape[1]
                encoded_layer_text = encoded_layer[:, :max_text_len]
                encoded_layer_object = encoded_layer.new_zeros(
                    (encoded_layer.shape[0], max_object_len,
                     encoded_layer.shape[2]))
                encoded_layer_object[object_mask] = encoded_layer[
                    object_mask_new]
                encoded_layers_text.append(encoded_layer_text)
                encoded_layers_object.append(encoded_layer_object)
            if not output_all_encoded_layers:
                encoded_layers_text = encoded_layers_text[0]
                encoded_layers_object = encoded_layers_object[0]
            if output_attention_probs:
                return encoded_layers_text, encoded_layers_object, pooled_output, attention_probs
            else:
                return encoded_layers_text, encoded_layers_object, pooled_output
        else:
            if output_attention_probs:
                return encoded_layers, pooled_output, attention_probs
            else:
                return encoded_layers, pooled_output

    def embedding(self, text_input_ids, text_token_type_ids,
                  text_visual_embeddings, text_mask, object_vl_embeddings,
                  object_mask):

        text_linguistic_embedding = self.word_embeddings_wrapper(
            text_input_ids)
        if self.visual_1x1_text is not None:
            text_visual_embeddings = self.visual_1x1_text(
                text_visual_embeddings)
        if self.config.visual_ln:
            text_visual_embeddings = self.visual_ln_text(
                text_visual_embeddings)
        else:
            text_visual_embeddings *= self.visual_scale_text
        text_vl_embeddings = text_linguistic_embedding + text_visual_embeddings

        object_visual_embeddings = object_vl_embeddings[:, :, :self.config.visual_size]
        if self.visual_1x1_object is not None:
            object_visual_embeddings = self.visual_1x1_object(
                object_visual_embeddings)
        if self.config.visual_ln:
            object_visual_embeddings = self.visual_ln_object(
                object_visual_embeddings)
        else:
            object_visual_embeddings *= self.visual_scale_object
        object_linguistic_embeddings = object_vl_embeddings[:, :, self.config.visual_size:]
        object_vl_embeddings = object_linguistic_embeddings + object_visual_embeddings

        bs = text_vl_embeddings.size(0)
        vl_embed_size = text_vl_embeddings.size(-1)
        max_length = (text_mask.sum(1) + object_mask.sum(1)).max() + 1
        grid_ind, grid_pos = torch.meshgrid(
            torch.arange(bs,
                         dtype=torch.long,
                         device=text_vl_embeddings.device),
            torch.arange(max_length,
                         dtype=torch.long,
                         device=text_vl_embeddings.device))
        text_end = text_mask.sum(1, keepdim=True)
        object_end = text_end + object_mask.sum(1, keepdim=True)

        # seamlessly concatenate visual linguistic embeddings of text and object
        _zero_id = torch.zeros((bs, ),
                               dtype=torch.long,
                               device=text_vl_embeddings.device)
        vl_embeddings = text_vl_embeddings.new_zeros(
            (bs, max_length, vl_embed_size))
        vl_embeddings[grid_pos < text_end] = text_vl_embeddings[text_mask]
        vl_embeddings[(grid_pos >= text_end) & (
            grid_pos < object_end)] = object_vl_embeddings[object_mask]
        vl_embeddings[grid_pos == object_end] = self.end_embedding(_zero_id)

        # token type embeddings/ segment embeddings
        token_type_ids = text_token_type_ids.new_zeros((bs, max_length))
        token_type_ids[grid_pos < text_end] = text_token_type_ids[text_mask]
        token_type_ids[(grid_pos >= text_end) & (grid_pos <= object_end)] = 2
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # position embeddings
        position_ids = grid_pos + self.position_padding_idx + 1
        if self.config.obj_pos_id_relative:
            position_ids[(grid_pos >= text_end) & (grid_pos < object_end)] \
                = text_end.expand((bs, max_length))[(grid_pos >= text_end) & (grid_pos < object_end)] \
                + self.position_padding_idx + 1
            position_ids[grid_pos == object_end] = (
                text_end + 1).squeeze(1) + self.position_padding_idx + 1
        else:
            assert False, "Don't use position id 510/511 for objects and [END]!!!"
            position_ids[(grid_pos >= text_end)
                         & (grid_pos < object_end
                            )] = self.config.max_position_embeddings - 2
            position_ids[grid_pos ==
                         object_end] = self.config.max_position_embeddings - 1

        position_embeddings = self.position_embeddings(position_ids)
        mask = text_mask.new_zeros((bs, max_length))
        mask[grid_pos <= object_end] = 1

        embeddings = vl_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.embedding_LayerNorm(embeddings)
        embeddings = self.embedding_dropout(embeddings)

        return embeddings, mask, grid_pos < text_end, (
            grid_pos >= text_end) & (grid_pos < object_end)

    def load_language_pretrained_model(self, language_pretrained_model_path):
        pretrained_state_dict = torch.load(
            language_pretrained_model_path,
            map_location=lambda storage, loc: storage)
        encoder_pretrained_state_dict = {}
        pooler_pretrained_state_dict = {}
        embedding_ln_pretrained_state_dict = {}
        unexpected_keys = []
        for k, v in pretrained_state_dict.items():
            if k.startswith('bert.'):
                k = k[len('bert.'):]
            elif k.startswith('roberta.'):
                k = k[len('roberta.'):]
            else:
                unexpected_keys.append(k)
                continue
            if 'gamma' in k:
                k = k.replace('gamma', 'weight')
            if 'beta' in k:
                k = k.replace('beta', 'bias')
            if k.startswith('encoder.'):
                k_ = k[len('encoder.'):]
                if k_ in self.encoder.state_dict():
                    encoder_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            elif k.startswith('embeddings.'):
                k_ = k[len('embeddings.'):]
                if k_ == 'word_embeddings.weight':
                    self.word_embeddings.weight.data = v.to(
                        dtype=self.word_embeddings.weight.data.dtype,
                        device=self.word_embeddings.weight.data.device)
                elif k_ == 'position_embeddings.weight':
                    self.position_embeddings.weight.data = v.to(
                        dtype=self.position_embeddings.weight.data.dtype,
                        device=self.position_embeddings.weight.data.device)
                elif k_ == 'token_type_embeddings.weight':
                    self.token_type_embeddings.weight.data[:v.size(0)] = v.to(
                        dtype=self.token_type_embeddings.weight.data.dtype,
                        device=self.token_type_embeddings.weight.data.device)
                    if v.size(0) == 1:
                        # TODO: RoBERTa token type embedding
                        self.token_type_embeddings.weight.data[1] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)
                        self.token_type_embeddings.weight.data[2] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)

                elif k_.startswith('LayerNorm.'):
                    k__ = k_[len('LayerNorm.'):]
                    if k__ in self.embedding_LayerNorm.state_dict():
                        embedding_ln_pretrained_state_dict[k__] = v
                    else:
                        unexpected_keys.append(k)
                else:
                    unexpected_keys.append(k)
            elif self.config.with_pooler and k.startswith('pooler.'):
                k_ = k[len('pooler.'):]
                if k_ in self.pooler.state_dict():
                    pooler_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            else:
                unexpected_keys.append(k)
        if len(unexpected_keys) > 0:
            print("Warnings: Unexpected keys: {}.".format(unexpected_keys))
        self.embedding_LayerNorm.load_state_dict(
            embedding_ln_pretrained_state_dict)
        # preprocess encoder state dict for parallel blocks
        if not self.is_policy_net:
            for k in self.encoder.state_dict():
                if 'parallel_' in str(k):
                    encoder_pretrained_state_dict[
                        k] = encoder_pretrained_state_dict[k.replace(
                            'parallel_', '')]
        self.encoder.load_state_dict(encoder_pretrained_state_dict)
        if self.config.with_pooler and len(pooler_pretrained_state_dict) > 0:
            self.pooler.load_state_dict(pooler_pretrained_state_dict)
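
The embedding() method above packs variable-length text and object sequences into a single padded sequence terminated by an [END] slot, using boolean grids built from torch.meshgrid. A self-contained toy illustration of that indexing (the masks are made up for the example):

import torch

text_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]], dtype=torch.bool)
object_mask = torch.tensor([[1, 1],
                            [1, 0]], dtype=torch.bool)
bs = text_mask.size(0)
max_length = int((text_mask.sum(1) + object_mask.sum(1)).max()) + 1   # +1 for [END]
grid_ind, grid_pos = torch.meshgrid(torch.arange(bs), torch.arange(max_length))
text_end = text_mask.sum(1, keepdim=True)
object_end = text_end + object_mask.sum(1, keepdim=True)

print(grid_pos < text_end)                                  # text slots
print((grid_pos >= text_end) & (grid_pos < object_end))     # object slots
print(grid_pos == object_end)                               # the single [END] slot per sample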
Example #6
    def __init__(self, config, language_pretrained_model_path=None):
        super(VisualLinguisticBert, self).__init__(config)

        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # for compatibility with RoBERTa
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:  # Always False
            self.visual_1x1_text = nn.Linear(config.visual_size,
                                             config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size,
                                               config.hidden_size)
        if config.visual_ln:  # Always True
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size,
                                                  eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_text_init, dtype=torch.float),
                                             requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_object_init, dtype=torch.float),
                                               requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        self.encoder = BertEncoder(config)

        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(
                self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(
                self.config.visual_scale_object_init)

        # load language pretrained model
        if language_pretrained_model_path is not None:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:  # False by default
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(
                NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(
                self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])

        self.enhanced_img_feature = False
        self.no_predicate = False
        if config.ENHANCED_IMG_FEATURE:
            self.enhanced_img_feature = True
            if config.NO_PREDICATE:  # VRD
                self.no_predicate = True
                self.lan_img_conv3 = nn.Conv2d(768, 1, kernel_size=(1, 1))
            else:  # SpatialSense
                self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
                self.lan_img_conv4 = nn.Conv2d(768, 1, kernel_size=(1, 1))

            self.obj_feat_downsample = nn.Conv2d(2048, 768, kernel_size=(1, 1))
            self.obj_feat_batchnorm = nn.BatchNorm2d(768)
            self.lan_img_conv1 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            self.lan_img_conv2 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            # self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            # self.lan_img_bn1 = nn.BatchNorm2d(768)
            self.lan_img_avgpool = nn.AvgPool2d(14, stride=1)
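
The layers registered under config.ENHANCED_IMG_FEATURE implement a 1x1-conv fusion over 14x14 region feature maps followed by global average pooling back to a 768-d vector. A toy shape walk-through of that pipeline (the batch size is a placeholder assumption):

import torch
import torch.nn as nn

obj_feat = torch.randn(2, 2048, 14, 14)                 # e.g. backbone ROI feature maps
obj_feat_downsample = nn.Conv2d(2048, 768, kernel_size=(1, 1))
obj_feat_batchnorm = nn.BatchNorm2d(768)
lan_img_avgpool = nn.AvgPool2d(14, stride=1)

x = obj_feat_batchnorm(obj_feat_downsample(obj_feat))   # [2, 768, 14, 14]
pooled = lan_img_avgpool(x).squeeze(-1).squeeze(-1)     # [2, 768] visual embedding per map
print(x.shape, pooled.shape)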
Example #7
class VisualLinguisticBert(BaseModel):
    def __init__(self, config, language_pretrained_model_path=None):
        super(VisualLinguisticBert, self).__init__(config)

        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # for compatibility with RoBERTa
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:  # Always False
            self.visual_1x1_text = nn.Linear(config.visual_size,
                                             config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size,
                                               config.hidden_size)
        if config.visual_ln:  # Always True
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size,
                                                  eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_text_init, dtype=torch.float),
                                             requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_object_init, dtype=torch.float),
                                               requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        self.encoder = BertEncoder(config)

        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(
                self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(
                self.config.visual_scale_object_init)

        # load language pretrained model
        if language_pretrained_model_path is not None:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:  # False by default
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(
                NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(
                self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])

        self.enhanced_img_feature = False
        self.no_predicate = False
        if config.ENHANCED_IMG_FEATURE:
            self.enhanced_img_feature = True
            if config.NO_PREDICATE:  # VRD
                self.no_predicate = True
                self.lan_img_conv3 = nn.Conv2d(768, 1, kernel_size=(1, 1))
            else:  # SpatialSense
                self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
                self.lan_img_conv4 = nn.Conv2d(768, 1, kernel_size=(1, 1))

            self.obj_feat_downsample = nn.Conv2d(2048, 768, kernel_size=(1, 1))
            self.obj_feat_batchnorm = nn.BatchNorm2d(768)
            self.lan_img_conv1 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            self.lan_img_conv2 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            # self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            # self.lan_img_bn1 = nn.BatchNorm2d(768)
            self.lan_img_avgpool = nn.AvgPool2d(14, stride=1)
            # TODO: make these layers trainable during finetuning!

    def word_embeddings_wrapper(self, input_ids):
        if self.config.word_embedding_frozen:  # False by default
            word_embeddings = self.word_embeddings(input_ids)
            word_embeddings[input_ids < NUM_SPECIAL_WORDS] \
                = self.special_word_embeddings(input_ids[input_ids < NUM_SPECIAL_WORDS])
            return word_embeddings
        else:
            return self.word_embeddings(input_ids)

    def forward(self,
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                object_mask,
                object_visual_feat=None,
                spo_len=None,
                output_all_encoded_layers=True,
                output_text_and_object_separately=False,
                output_attention_probs=False):

        # get seamless concatenate embeddings and mask
        embedding_output, attention_mask, text_mask_new, object_mask_new, spo_fused_masks = self.embedding(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, object_mask, object_visual_feat,
            spo_len)

        # We create a 4D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # extended_attention_mask = 1.0 - extended_attention_mask
        # extended_attention_mask[extended_attention_mask != 0] = float('-inf')

        if output_attention_probs:
            encoded_layers, attention_probs = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs)
        else:
            encoded_layers = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(
            sequence_output) if self.config.with_pooler else None
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        if output_text_and_object_separately:  # False
            if not output_all_encoded_layers:
                encoded_layers = [encoded_layers]
            encoded_layers_text = []
            encoded_layers_object = []
            for encoded_layer in encoded_layers:
                max_text_len = text_input_ids.shape[1]
                max_object_len = object_vl_embeddings.shape[1]
                encoded_layer_text = encoded_layer[:, :max_text_len]
                encoded_layer_object = encoded_layer.new_zeros(
                    (encoded_layer.shape[0], max_object_len,
                     encoded_layer.shape[2]))
                encoded_layer_object[object_mask] = encoded_layer[
                    object_mask_new]
                encoded_layers_text.append(encoded_layer_text)
                encoded_layers_object.append(encoded_layer_object)
            if not output_all_encoded_layers:
                encoded_layers_text = encoded_layers_text[0]
                encoded_layers_object = encoded_layers_object[0]
            if output_attention_probs:
                return encoded_layers_text, encoded_layers_object, pooled_output, attention_probs
            else:
                return encoded_layers_text, encoded_layers_object, pooled_output
        else:
            if output_attention_probs:  # False
                return encoded_layers, pooled_output, attention_probs
            else:
                if spo_fused_masks is not None:
                    return encoded_layers, pooled_output, spo_fused_masks
                else:
                    return encoded_layers, pooled_output, None

    def embedding(self,
                  text_input_ids,
                  text_token_type_ids,
                  text_visual_embeddings,
                  text_mask,
                  object_vl_embeddings,
                  object_mask,
                  object_visual_feat=None,
                  spo_len=None):

        # (Text) Token Embedding + Visual Feature Embedding
        text_linguistic_embedding = self.word_embeddings_wrapper(
            text_input_ids)
        # if object_visual_feat is not None and spo_len is not None:
        if self.enhanced_img_feature and object_visual_feat is not None and spo_len is not None:
            # memo the shape
            object_visual_feat_shape = object_visual_feat.shape
            # reshape to [batch_size * nb_imgs, 2048, 14, 14] & downsample from 2048 to 768
            object_visual_feat = object_visual_feat.view(
                -1, object_visual_feat.shape[2], object_visual_feat.shape[3],
                object_visual_feat.shape[4])
            object_visual_feat = self.obj_feat_downsample(object_visual_feat)
            object_visual_feat = self.obj_feat_batchnorm(object_visual_feat)
            # restore to [batch_size, nb_imgs, 768, 14, 14]
            object_visual_feat = object_visual_feat.view(
                (object_visual_feat_shape[0], object_visual_feat_shape[1], 768,
                 object_visual_feat_shape[3], object_visual_feat_shape[4]))

            spo_fused_masks = []
            if self.no_predicate:  # VRD
                for i in range(text_input_ids.shape[0]):  # for each img sample
                    subj_end = 1 + spo_len[i, 0]
                    obj_end = subj_end + spo_len[i, 1]

                    sub_text_emb = text_linguistic_embedding[
                        i, 1:subj_end].view(-1, 1, 1) if spo_len[
                            i, 0] == 1 else text_linguistic_embedding[
                                i, 1:subj_end].sum(dim=0).view(-1, 1, 1)
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=1)
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=2)
                    fused_sub = object_visual_feat[i, 0] + sub_text_emb

                    obj_text_emb = text_linguistic_embedding[
                        i, subj_end:obj_end].view(-1, 1, 1) if spo_len[
                            i, 1] == 1 else text_linguistic_embedding[
                                i, subj_end:obj_end].sum(dim=0).view(-1, 1, 1)
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=1)
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=2)
                    fused_obj = object_visual_feat[i, 0] + obj_text_emb

                    # spo_fused = torch.cat((fused_sub.unsqueeze(0), fused_obj.unsqueeze(0)))
                    fused_sub = self.lan_img_conv1(fused_sub.unsqueeze(0))
                    fused_sub = nn.functional.relu(fused_sub)
                    fused_obj = self.lan_img_conv2(fused_obj.unsqueeze(0))
                    fused_obj = nn.functional.relu(fused_obj)
                    spo_fused = torch.cat((fused_sub, fused_obj))

                    spo_fused = self.lan_img_conv3(spo_fused).squeeze()
                    # spo_fused[spo_fused > 1] = 1
                    # spo_fused[spo_fused < 0] = 0
                    # Save the mask for computing mask loss
                    # spo_fused_masks.append(spo_fused)

                    spo_fused_norm = torch.zeros_like(spo_fused)
                    for j in range(2):
                        max1, min1 = spo_fused[j].max(), spo_fused[j].min()
                        denominator = max1 - min1
                        if not self.config.mask_loss_sum and not self.config.mask_loss_mse:  # For BCE loss
                            # import pdb; pdb.set_trace()
                            denominator += 1e-10
                        spo_fused_norm[j] = (spo_fused[j] - min1) / denominator
                        # spo_fused[j] = torch.sigmoid(spo_fused[j])

                    spo_fused_masks.append(spo_fused_norm)

                    text_visual_embeddings[
                        i, 1:subj_end] = self.lan_img_avgpool(
                            object_visual_feat[i, 0] *
                            spo_fused_norm[0]).squeeze()  # spo_fused[0]
                    text_visual_embeddings[
                        i, subj_end:obj_end] = self.lan_img_avgpool(
                            object_visual_feat[i, 0] *
                            spo_fused_norm[1]).squeeze()  # spo_fused[1]
            else:  # SpatialSense
                for i in range(text_input_ids.shape[0]):  # for each img sample
                    subj_end = 1 + spo_len[i, 0]
                    pred_end = subj_end + spo_len[i, 1]
                    obj_end = pred_end + spo_len[i, 2]

                    sub_text_emb = text_linguistic_embedding[
                        i, 1:subj_end].view(-1, 1, 1) if spo_len[
                            i, 0] == 1 else text_linguistic_embedding[
                                i, 1:subj_end].sum(dim=0).view(-1, 1, 1)
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=1)
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=2)
                    fused_sub = object_visual_feat[i, 0] + sub_text_emb

                    pred_text_emb = text_linguistic_embedding[
                        i, subj_end:pred_end].view(-1, 1, 1) if spo_len[
                            i, 1] == 1 else text_linguistic_embedding[
                                i,
                                subj_end:pred_end].sum(dim=0).view(-1, 1, 1)
                    pred_text_emb = torch.cat(([pred_text_emb] * 14), dim=1)
                    pred_text_emb = torch.cat(([pred_text_emb] * 14), dim=2)
                    fused_pred = object_visual_feat[i, 0] + pred_text_emb

                    obj_text_emb = text_linguistic_embedding[
                        i, pred_end:obj_end].view(-1, 1, 1) if spo_len[
                            i, 2] == 1 else text_linguistic_embedding[
                                i, pred_end:obj_end].sum(dim=0).view(-1, 1, 1)
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=1)
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=2)
                    fused_obj = object_visual_feat[i, 0] + obj_text_emb

                    fused_sub = self.lan_img_conv1(fused_sub.unsqueeze(0))
                    fused_sub = nn.functional.relu(fused_sub)
                    fused_pred = self.lan_img_conv2(fused_pred.unsqueeze(0))
                    fused_pred = nn.functional.relu(fused_pred)
                    fused_obj = self.lan_img_conv3(fused_obj.unsqueeze(0))
                    fused_obj = nn.functional.relu(fused_obj)
                    spo_fused = torch.cat((fused_sub, fused_pred, fused_obj))

                    spo_fused = self.lan_img_conv4(spo_fused).squeeze()
                    # spo_fused[spo_fused > 1] = 1
                    # spo_fused[spo_fused < 0] = 0
                    # Save the mask for computing mask loss
                    # spo_fused_masks.append(spo_fused)

                    # spo_fused += object_visual_feat[i,0].unsqueeze(0)
                    # spo_fused = self.lan_img_avgpool(spo_fused).squeeze()

                    # import pdb; pdb.set_trace()
                    spo_fused_norm = torch.zeros_like(spo_fused)
                    for j in range(3):
                        max1, min1 = spo_fused[j].max(), spo_fused[j].min()
                        denominator = max1 - min1
                        # if not self.config.mask_loss_sum and not self.config.mask_loss_mse: # For BCE loss
                        #     # import pdb; pdb.set_trace()
                        #     denominator += 1e-10
                        spo_fused_norm[j] = (spo_fused[j] - min1) / denominator
                        # spo_fused[j] = torch.sigmoid(spo_fused[j])

                    spo_fused_masks.append(spo_fused_norm)

                    text_visual_embeddings[i,
                                           1:subj_end] = self.lan_img_avgpool(
                                               object_visual_feat[i, 0] *
                                               spo_fused_norm[0]).squeeze()
                    text_visual_embeddings[
                        i, subj_end:pred_end] = self.lan_img_avgpool(
                            object_visual_feat[i, 0] *
                            spo_fused_norm[1]).squeeze()  # spo_fused[1]
                    text_visual_embeddings[
                        i, pred_end:obj_end] = self.lan_img_avgpool(
                            object_visual_feat[i, 0] *
                            spo_fused_norm[2]).squeeze()  # spo_fused[1]
            spo_fused_masks = torch.cat(spo_fused_masks)
        if self.visual_1x1_text is not None:  # always False
            text_visual_embeddings = self.visual_1x1_text(
                text_visual_embeddings)
        if self.config.visual_ln:  # always True
            text_visual_embeddings = self.visual_ln_text(
                text_visual_embeddings)
        else:
            text_visual_embeddings *= self.visual_scale_text
        text_vl_embeddings = text_linguistic_embedding + text_visual_embeddings

        # (Object) Token Embedding + Visual Feature Embedding
        object_visual_embeddings = object_vl_embeddings[:, :, :self.config.visual_size]
        if self.visual_1x1_object is not None:  # always False
            object_visual_embeddings = self.visual_1x1_object(
                object_visual_embeddings)
        if self.config.visual_ln:  # always True
            object_visual_embeddings = self.visual_ln_object(
                object_visual_embeddings)
        else:
            object_visual_embeddings *= self.visual_scale_object
        object_linguistic_embeddings = object_vl_embeddings[:, :, self.config.visual_size:]
        # import pdb; pdb.set_trace()
        object_vl_embeddings = object_linguistic_embeddings + object_visual_embeddings

        # Some indices setup for following process
        bs = text_vl_embeddings.size(0)
        vl_embed_size = text_vl_embeddings.size(-1)
        max_length = (text_mask.sum(1) + object_mask.sum(1)).max() + 1
        grid_ind, grid_pos = torch.meshgrid(
            torch.arange(bs,
                         dtype=torch.long,
                         device=text_vl_embeddings.device),
            torch.arange(max_length,
                         dtype=torch.long,
                         device=text_vl_embeddings.device))
        text_end = text_mask.sum(1, keepdim=True)
        object_end = text_end + object_mask.sum(1, keepdim=True)

        # seamlessly concatenate visual linguistic embeddings of text and object
        _zero_id = torch.zeros((bs, ),
                               dtype=torch.long,
                               device=text_vl_embeddings.device)
        vl_embeddings = text_vl_embeddings.new_zeros(
            (bs, max_length, vl_embed_size))
        vl_embeddings[grid_pos < text_end] = text_vl_embeddings[text_mask]
        vl_embeddings[(grid_pos >= text_end) & (
            grid_pos < object_end)] = object_vl_embeddings[object_mask]
        vl_embeddings[grid_pos == object_end] = self.end_embedding(
            _zero_id)  # '[END]'

        # segment embeddings / token type embeddings
        # import pdb; pdb.set_trace()
        token_type_ids = text_token_type_ids.new_zeros((bs, max_length))
        token_type_ids[grid_pos < text_end] = text_token_type_ids[text_mask]
        token_type_ids[(grid_pos >= text_end) & (grid_pos <= object_end)] = 2
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # position embeddings
        position_ids = grid_pos + self.position_padding_idx + 1
        if self.config.use_img_region_order:
            pass
        elif self.config.obj_pos_id_relative:  # always True!
            position_ids[(grid_pos >= text_end) & (grid_pos < object_end)] \
                = text_end.expand((bs, max_length))[(grid_pos >= text_end) & (grid_pos < object_end)] \
                + self.position_padding_idx + 1
            position_ids[grid_pos == object_end] = (
                text_end + 1).squeeze(1) + self.position_padding_idx + 1
        else:
            assert False, "Don't use position id 510/511 for objects and [END]!!!"
            position_ids[(grid_pos >= text_end)
                         & (grid_pos < object_end
                            )] = self.config.max_position_embeddings - 2
            position_ids[grid_pos ==
                         object_end] = self.config.max_position_embeddings - 1
        position_embeddings = self.position_embeddings(position_ids)
        # import pdb; pdb.set_trace()

        mask = text_mask.new_zeros((bs, max_length))
        mask[grid_pos <= object_end] = 1

        embeddings = vl_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.embedding_LayerNorm(embeddings)
        embeddings = self.embedding_dropout(embeddings)

        if self.enhanced_img_feature and object_visual_feat is not None and spo_len is not None:
            return embeddings, mask, grid_pos < text_end, (
                grid_pos >= text_end) & (grid_pos <
                                         object_end), spo_fused_masks
        else:
            return embeddings, mask, grid_pos < text_end, (
                grid_pos >= text_end) & (grid_pos < object_end), None

    def load_language_pretrained_model(self, language_pretrained_model_path):
        pretrained_state_dict = torch.load(
            language_pretrained_model_path,
            map_location=lambda storage, loc: storage)
        encoder_pretrained_state_dict = {}
        pooler_pretrained_state_dict = {}
        embedding_ln_pretrained_state_dict = {}
        unexpected_keys = []
        for k, v in pretrained_state_dict.items():
            if k.startswith('bert.'):
                k = k[len('bert.'):]
            elif k.startswith('roberta.'):
                k = k[len('roberta.'):]
            else:
                unexpected_keys.append(k)
                continue
            if 'gamma' in k:
                k = k.replace('gamma', 'weight')
            if 'beta' in k:
                k = k.replace('beta', 'bias')
            if k.startswith('encoder.'):
                k_ = k[len('encoder.'):]
                if k_ in self.encoder.state_dict():
                    encoder_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            elif k.startswith('embeddings.'):
                k_ = k[len('embeddings.'):]
                if k_ == 'word_embeddings.weight':
                    self.word_embeddings.weight.data = v.to(
                        dtype=self.word_embeddings.weight.data.dtype,
                        device=self.word_embeddings.weight.data.device)
                elif k_ == 'position_embeddings.weight':
                    self.position_embeddings.weight.data = v.to(
                        dtype=self.position_embeddings.weight.data.dtype,
                        device=self.position_embeddings.weight.data.device)
                elif k_ == 'token_type_embeddings.weight':
                    self.token_type_embeddings.weight.data[:v.size(0)] = v.to(
                        dtype=self.token_type_embeddings.weight.data.dtype,
                        device=self.token_type_embeddings.weight.data.device)
                    if v.size(0) == 1:
                        # TODO: RoBERTa token type embedding
                        self.token_type_embeddings.weight.data[1] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)
                        self.token_type_embeddings.weight.data[2] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)

                elif k_.startswith('LayerNorm.'):
                    k__ = k_[len('LayerNorm.'):]
                    if k__ in self.embedding_LayerNorm.state_dict():
                        embedding_ln_pretrained_state_dict[k__] = v
                    else:
                        unexpected_keys.append(k)
                else:
                    unexpected_keys.append(k)
            elif self.config.with_pooler and k.startswith('pooler.'):
                k_ = k[len('pooler.'):]
                if k_ in self.pooler.state_dict():
                    pooler_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            else:
                unexpected_keys.append(k)
        if len(unexpected_keys) > 0:
            print("Warnings: Unexpected keys: {}.".format(unexpected_keys))
        self.embedding_LayerNorm.load_state_dict(
            embedding_ln_pretrained_state_dict)
        self.encoder.load_state_dict(encoder_pretrained_state_dict)
        if self.config.with_pooler and len(pooler_pretrained_state_dict) > 0:
            self.pooler.load_state_dict(pooler_pretrained_state_dict)

class VisualLinguisticBertDecoder(BaseModel):  # assumed: base class mirrors VisualLinguisticBert above
    def __init__(self, config, language_pretrained_model_path=None):
        super(VisualLinguisticBertDecoder, self).__init__(config)

        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # for compatibility with RoBERTa
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:
            self.visual_1x1_text = nn.Linear(config.visual_size,
                                             config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size,
                                               config.hidden_size)
        if config.visual_ln:
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size,
                                                  eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_text_init, dtype=torch.float),
                                             requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_object_init, dtype=torch.float),
                                               requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        # *********************************************
        # FM addition - set up the decoder layer for MT
        # Initializing a BERT bert-base-uncased style configuration
        configuration = BertConfig()
        configuration.vocab_size = config.vocab_size
        # FM edit: reduce size - 12 layers don't fit in a single 12GB GPU
        configuration.num_hidden_layers = 6
        configuration.is_decoder = True
        # Initializing a model from the bert-base-uncased style configuration
        self.decoder = BertModel(configuration)
        # *********************************************

        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(
                self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(
                self.config.visual_scale_object_init)

        # load language pretrained model
        if language_pretrained_model_path is not None:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(
                NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(
                self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])
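
The "FM addition" block in this constructor builds a standard Hugging Face BERT configured as a decoder. Assuming the transformers-style BertConfig/BertModel used in the snippet, that step in isolation looks like this (the vocabulary size is a placeholder; the snippet copies it from config.vocab_size):

from transformers import BertConfig, BertModel

configuration = BertConfig()            # bert-base-uncased style defaults
configuration.vocab_size = 30522        # placeholder; the snippet uses config.vocab_size
configuration.num_hidden_layers = 6     # halved from 12 so the decoder fits on a single 12GB GPU
configuration.is_decoder = True         # self-attention runs with a causal mask
decoder = BertModel(configuration)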