Example #1
    def __init__(self, config):
        super(BertEncoder, self).__init__()

        # In the BERT encoder we need three components:
        # text BERT layer: BertLayer
        # vision BERT layer: BertImageLayer
        # bi-attention: given the outputs of the two BertLayer stacks,
        # perform bi-directional attention and add it back onto both streams.
        t_config = BertConfig.from_dict(config.t_config)
        v_config = BertConfig.from_dict(config.v_config)

        self.FAST_MODE = config.fast_mode
        self.with_coattention = config.with_coattention
        self.v_biattention_id = v_config.biattention_id
        self.t_biattention_id = t_config.biattention_id
        self.in_batch_pairs = config.in_batch_pairs
        self.fixed_t_layer = config.fixed_t_layer
        self.fixed_v_layer = config.fixed_v_layer

        # note: the text layers use t_config, not the joint config
        layer = BertLayer(t_config)
        v_layer = BertLayer(v_config)
        connect_layer = BertConnectionLayer(config)

        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(t_config.num_hidden_layers)])
        self.v_layer = nn.ModuleList([
            copy.deepcopy(v_layer) for _ in range(v_config.num_hidden_layers)
        ])
        self.c_layer = nn.ModuleList([
            copy.deepcopy(connect_layer)
            for _ in range(len(v_config.biattention_id))
        ])
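
Example #1 builds its stacks by deep-copying one prototype module per stack, so every clone starts from the same random initialization, while most of the later examples construct a fresh BertLayer per slot, giving each layer its own initialization. A minimal, runnable sketch of the difference using the Hugging Face transformers BertLayer (the small config values are assumptions chosen for brevity):

import copy
import torch.nn as nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertLayer

config = BertConfig(hidden_size=64, num_attention_heads=4,
                    intermediate_size=128)

# deepcopy: the clones are independent modules, but start with equal weights
proto = BertLayer(config)
cloned = nn.ModuleList([copy.deepcopy(proto) for _ in range(2)])
assert cloned[0] is not cloned[1]
assert (cloned[0].attention.self.query.weight ==
        cloned[1].attention.self.query.weight).all()

# per-slot construction: each layer draws its own random initialization
fresh = nn.ModuleList([BertLayer(config) for _ in range(2)])
assert not (fresh[0].attention.self.query.weight ==
            fresh[1].attention.self.query.weight).all()

Either way the layers train independently afterwards; only the starting point differs.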
Example #2
        def init_data(self, use_cuda: bool) -> None:
            test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(1)

            torch.set_grad_enabled(False)
            self.cfg = BertConfig(attention_probs_dropout_prob=0.0,
                                  hidden_dropout_prob=0.0)

            self.torch_bert_layer = BertLayer(self.cfg)
            self.torch_bert_layer.eval()
            if use_cuda:
                self.torch_bert_layer.to(test_device)

            self.hidden_size = self.cfg.hidden_size
            self.input_tensor = torch.rand(size=(batch_size, seq_length,
                                                 self.hidden_size),
                                           dtype=torch.float32,
                                           device=test_device)

            self.attention_mask = torch.ones((batch_size, seq_length),
                                             dtype=torch.float32,
                                             device=test_device)
            self.attention_mask = self.attention_mask[:, None, None, :]
            self.attention_mask = (1.0 - self.attention_mask) * -10000.0

            self.turbo_bert_layer = turbo_transformers.BertLayer.from_torch(
                self.torch_bert_layer)
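
The mask arithmetic at the end of init_data is the standard additive attention-mask recipe: a (batch, seq) tensor of 1s (real tokens) and 0s (padding) is reshaped to (batch, 1, 1, seq) so it broadcasts over heads and query positions, then mapped to 0 for tokens and -10000 for padding, which effectively zeroes padded positions after softmax. A tiny numeric check:

import torch

mask = torch.tensor([[1., 1., 0.]])   # 1 = real token, 0 = padding
ext = mask[:, None, None, :]          # shape (1, 1, 1, 3)
ext = (1.0 - ext) * -10000.0          # token -> -0.0, padding -> -10000.0
print(ext)                            # tensor([[[[-0., -0., -10000.]]]])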
Example #3
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

        assert config.num_hidden_layers % 2 == 0, "num_hidden_layers must be even in Trelm!"
        self.tlayer = BertLayer(config)
        self.tlayer_position = config.num_hidden_layers // 2
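
Only the constructor is shown. A hedged sketch of a forward pass consistent with this wiring (an assumption, not the actual Trelm source): run the regular stack and splice self.tlayer in at the midpoint:

# Hedged sketch (assumed), not the actual Trelm forward.
def forward(self, hidden_states, attention_mask=None):
    for i, layer in enumerate(self.layer):
        if i == self.tlayer_position:
            hidden_states = self.tlayer(hidden_states, attention_mask)[0]
        hidden_states = layer(hidden_states, attention_mask)[0]
    return hidden_states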
Example #4
    def __init__(self, config):
        super(BertForSequenceClassificationNq, self).__init__(config)
        self.num_labels = config.num_labels
        # config.output_hidden_states = True
        bert_later_dropout = 0.3
        self.dropout = nn.Dropout(bert_later_dropout)
        self.later_model_type = config.later_model_type

        if self.later_model_type == 'linear':
            self.bert = BertModel(config)
            self.projection = nn.Linear(config.hidden_size * 3,
                                        config.hidden_size)
            self.projection_dropout = nn.Dropout(0.1)
            self.projection_activation = nn.Tanh()
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        elif self.later_model_type == '1bert_layer':
            config.num_hidden_layers = 1
            self.bert = BertModel(config)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        elif self.later_model_type == 'bilinear':
            self.bert = BertModel(config)
            lstm_layers = 2
            self.qemb_match = SeqAttnMatch(config.hidden_size)
            doc_input_size = 2 * config.hidden_size
            # RNN document encoder
            self.doc_rnn = StackedBRNN(
                input_size=doc_input_size,
                hidden_size=config.hidden_size,
                num_layers=lstm_layers,
                dropout_rate=bert_later_dropout,
                dropout_output=bert_later_dropout,
                concat_layers=True,
                rnn_type=nn.LSTM,
                padding=False,
            )

            self.bilinear_dropout = nn.Dropout(bert_later_dropout)
            self.bilinear_size = 128
            self.doc_proj = nn.Linear(lstm_layers * 2 * config.hidden_size,
                                      self.bilinear_size)
            self.qs_proj = nn.Linear(config.hidden_size, self.bilinear_size)
            self.bilinear = nn.Bilinear(self.bilinear_size, self.bilinear_size,
                                        self.bilinear_size)
            self.classifier = nn.Linear(self.bilinear_size, config.num_labels)
        elif self.later_model_type == 'transformer':
            self.copy_from_bert_layer_num = 11
            self.bert = BertModel(config)
            self.bert_position_emb = nn.Embedding(
                config.max_position_embeddings, config.hidden_size)
            self.bert_type_id_emb = nn.Embedding(config.type_vocab_size,
                                                 config.hidden_size)

            self.bert_layer = BertLayer(config)
            self.bert_pooler_qd = BertPoolerQD(config)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
Example #5
    def __init__(self, extractor, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.extractor = extractor
        self.config = config

        if config["pretrained"] == "electra-base-msmarco":
            self.bert = ElectraModel.from_pretrained(
                "Capreolus/electra-base-msmarco")
        elif config["pretrained"] == "bert-base-msmarco":
            self.bert = BertModel.from_pretrained(
                "Capreolus/bert-base-msmarco")
        elif config["pretrained"] == "bert-base-uncased":
            self.bert = BertModel.from_pretrained("bert-base-uncased")
        else:
            raise ValueError(
                f"unsupported model: {config['pretrained']}; need to ensure correct tokenizers will be used before arbitrary hgf models are supported"
            )

        self.transformer_layer_1 = BertLayer(self.bert.config)
        self.transformer_layer_2 = BertLayer(self.bert.config)
        self.num_passages = extractor.config["numpassages"]
        self.maxseqlen = extractor.config["maxseqlen"]
        self.linear = nn.Linear(self.bert.config.hidden_size, 1)

        if config["aggregation"] == "max":
            raise NotImplementedError()
        elif config["aggregation"] == "avg":
            raise NotImplementedError()
        elif config["aggregation"] == "attn":
            raise NotImplementedError()
        elif config["aggregation"] == "transformer":
            self.aggregation = self.aggregate_using_transformer
            input_embeddings = self.bert.get_input_embeddings()
            # TODO hardcoded CLS token id
            cls_token_id = torch.tensor([[101]])
            self.initial_cls_embedding = input_embeddings(cls_token_id).view(
                1, self.bert.config.hidden_size)
            self.full_position_embeddings = torch.zeros(
                (1, self.num_passages + 1, self.bert.config.hidden_size),
                requires_grad=True,
                dtype=torch.float)
            torch.nn.init.normal_(self.full_position_embeddings,
                                  mean=0.0,
                                  std=0.02)

            self.initial_cls_embedding = nn.Parameter(
                self.initial_cls_embedding, requires_grad=True)
            self.full_position_embeddings = nn.Parameter(
                self.full_position_embeddings, requires_grad=True)
        else:
            raise ValueError(
                f"unknown aggregation type: {self.config['aggregation']}")
Example #6
    def __init__(self, config, scc_n_layer=6):
        super(BertEncoder, self).__init__()
        self.prd_n_layer = config.num_hidden_layers
        self.scc_n_layer = scc_n_layer
        assert self.prd_n_layer % self.scc_n_layer == 0
        self.compress_ratio = self.prd_n_layer // self.scc_n_layer
        self.bernoulli = None
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(self.prd_n_layer)])
        self.scc_layer = nn.ModuleList(
            [BertLayer(config) for _ in range(self.scc_n_layer)])
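
This is the BERT-of-Theseus encoder: prd_n_layer predecessor layers are compressed into scc_n_layer successor layers, each successor standing in for compress_ratio consecutive predecessors. A hedged sketch of the replacement sampling (assuming self.bernoulli is set elsewhere to a torch.distributions.Bernoulli over the replacement rate):

# Hedged sketch of Theseus-style module replacement during training.
def forward(self, hidden_states, attention_mask=None):
    for i in range(self.scc_n_layer):
        if self.training and self.bernoulli.sample() == 1:
            # substitute the successor layer for this block
            hidden_states = self.scc_layer[i](hidden_states, attention_mask)[0]
        else:
            # keep the block's compress_ratio predecessor layers
            for k in range(self.compress_ratio):
                idx = i * self.compress_ratio + k
                hidden_states = self.layer[idx](hidden_states, attention_mask)[0]
    return hidden_states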
Example #7
    def __init__(self, config):
        super(BertEncoderATM, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.map_linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer = nn.ModuleList([BertLayer(config)
                                    for _ in range(config.num_hidden_layers)])
Example #8
    def __init__(self, config, num_shared_layers, num_layers):
        super(transformer_block, self).__init__()
        self.num_layers = num_layers
        self.num_shared_layers = num_shared_layers
        self.bert_layers = nn.ModuleList(
            [BertLayer(config) for _ in range(num_layers)])
        self.pooler = BertPooler(config)
Example #9
File: visual_bert.py  Project: zpppy/mmf
    def __init__(
        self,
        config,
        visual_embedding_dim=512,
        embedding_strategy="plain",
        bypass_transformer=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        super().__init__(config)
        self.config = config

        config.visual_embedding_dim = visual_embedding_dim
        config.embedding_strategy = embedding_strategy
        config.bypass_transformer = bypass_transformer
        config.output_attentions = output_attentions
        config.output_hidden_states = output_hidden_states

        self.embeddings = BertVisioLinguisticEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.bypass_transformer = config.bypass_transformer

        if self.bypass_transformer:
            self.additional_layer = BertLayer(config)

        self.output_attentions = self.config.output_attentions
        self.output_hidden_states = self.config.output_hidden_states
        self.fixed_head_masks = [None for _ in range(len(self.encoder.layer))]
        self.init_weights()
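
When bypass_transformer is set, the extra BertLayer fuses visual embeddings with already-encoded text instead of running both through the full encoder. A hedged sketch of that path, loosely following the mmf implementation (names are assumptions):

import torch

# Hedged sketch: text alone goes through the full encoder, the visual
# embeddings skip it, and the single additional BertLayer fuses the streams.
def encode_with_bypass(self, text_embeddings, text_mask,
                       visual_embeddings, full_mask):
    encoded_text = self.encoder(text_embeddings, text_mask,
                                self.fixed_head_masks)[0]
    fused = torch.cat([encoded_text, visual_embeddings], dim=1)
    return self.additional_layer(fused, full_mask)[0]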
Example #10
    def __init__(self, config, visual_start_layer):
        super().__init__()
        self.output_attentions = False  # config.output_attentions
        self.output_hidden_states = True  # config.output_hidden_states
        self.visual_start_layer = visual_start_layer
        self.config = config
        self.layer = torch.nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
Example #11
File: lxmert.py  Project: slbinilkumar/mmf
    def __init__(self, config):
        super().__init__()

        # Obj-level image embedding layer
        self.visn_fc = VisualFeatEncoder(config)

        # Number of layers
        self.num_l_layers = config.l_layers
        self.num_x_layers = config.x_layers
        self.num_r_layers = config.r_layers
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(self.num_l_layers)])
        self.x_layers = nn.ModuleList(
            [LXMERTXLayer(config) for _ in range(self.num_x_layers)])
        self.r_layers = nn.ModuleList(
            [BertLayer(config) for _ in range(self.num_r_layers)])
Example #12
    def __init__(self, config):
        super(BertEncoder4TokenMix, self).__init__()
        # self.output_attentions = config.output_attentions
        # self.output_hidden_states = config.output_hidden_states
        self.output_attentions = False
        self.output_hidden_states = True
        self.layer = nn.ModuleList([BertLayer(config)
                                    for _ in range(config.num_hidden_layers)])
Example #13
    def __init__(self, config):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])

        self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]
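
This is a DeeBERT-style encoder: every BertLayer gets a matching BertHighway off-ramp, and early_exit_entropy stores one confidence threshold per layer (-1 disables the exit). A hedged sketch of the exit test (assumed names; the entropy helper is the usual Shannon entropy of the softmax):

import torch

def entropy(logits):
    # Shannon entropy of the softmax distribution, one value per example
    p = torch.softmax(logits, dim=-1)
    return -(p * torch.log(p + 1e-12)).sum(dim=-1)

# Inside the encoder loop, after layer i (hedged sketch):
#     highway_logits = self.highway[i](hidden_states)[0]
#     if not self.training and entropy(highway_logits).mean() < self.early_exit_entropy[i]:
#         return highway_logits   # confident enough: take the off-ramp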
Example #14
    def __init__(self, config):
        super(Stage0, self).__init__()
        self.embedding_layer = BertEmbeddings(config)
        self.layers = []
        for i in range(config.num_hidden_layers // 24):
            self.layers.append(BertLayer(config))
        self.layers = torch.nn.ModuleList(self.layers)
        self.config = config
        self.apply(self.init_bert_weights)
Example #15
    def __init__(self, config):
        super(Stage1, self).__init__()
        self.layers = []
        for i in range(12):  # config.num_hidden_layers
            self.layers.append(BertLayer(config))
        self.layers = torch.nn.ModuleList(self.layers)
        self.pooling_layer = BertPooler(config)
        self.pre_training_heads_layer = BertPreTrainingHeads(config)
        self.config = config
        self.apply(self.init_bert_weights)
Example #16
    def __init__(self, config, embedding_dim, num_groups):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config, embedding_dim)
        msg = 'Amount of encoder blocks should be divisible by number of groups.'
        assert config.num_hidden_layers % num_groups == 0, msg
        self.encoder = nn.ModuleList([BertLayer(config) for _ in range(num_groups)])
        self.group_size = config.num_hidden_layers // num_groups
        self.init_weights()
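
Only num_groups distinct BertLayer modules are allocated and each one is reused group_size times, i.e. ALBERT-style cross-layer parameter sharing. A hedged sketch of the matching forward:

# Hedged sketch: layer i of the virtual stack maps to shared block i // group_size.
def forward(self, hidden_states, attention_mask=None):
    for i in range(self.config.num_hidden_layers):
        shared = self.encoder[i // self.group_size]
        hidden_states = shared(hidden_states, attention_mask)[0]
    return hidden_states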
Example #17
    def __init__(self,
                 config,
                 num_layers=1,
                 num_langs=1,
                 struct="transformer",
                 add_weights=False,
                 tied=True,
                 bottle_size=768):
        super().__init__()

        self.nets = []
        self.num_layers = num_layers
        self.num_langs = num_langs
        self.struct = struct
        self.add_weights = add_weights
        self.tied = tied
        for i in range(num_langs):
            for j in range(num_layers):
                if struct == "transformer":
                    self.nets.append(BertLayer(config))
                elif struct == "perceptron":
                    hidden_size = config.hidden_size
                    if add_weights:
                        if tied:
                            self.nets.append(
                                nn.Sequential(
                                    nn.Linear(hidden_size, bottle_size),
                                    nn.ReLU(),
                                    nn.Linear(bottle_size, hidden_size + 1)))
                        else:
                            self.nets.append(
                                nn.Sequential(
                                    nn.Linear(hidden_size, bottle_size),
                                    nn.ReLU(),
                                    nn.Linear(bottle_size, hidden_size)))
                            self.weight_net = nn.Sequential(
                                nn.Linear(hidden_size, bottle_size), nn.ReLU(),
                                nn.Linear(bottle_size, 1))
                    else:
                        self.nets.append(
                            nn.Sequential(
                                nn.Linear(hidden_size, hidden_size // 4),
                                nn.ReLU(),
                                nn.Linear(hidden_size // 4, hidden_size)))
                else:
                    print("The specified structure is not implemented.")
                    sys.exit(0)

        self.nets = nn.ModuleList(self.nets)
        self.alpha = nn.Parameter(torch.zeros(num_langs, num_layers))

        if struct == "perceptron":
            self.init_weights()
Example #18
    def __init__(self, hidden_size, args):
        super().__init__()
        self.extract_num_layers = args.extract_num_layers
        self.pos_embeddings = PositionalEncoding(hidden_size, args.extract_dropout_prob)

        config = BertConfig(hidden_size=hidden_size, intermediate_size=hidden_size*4, layer_norm_eps=args.extract_layer_norm_eps,
                            hidden_dropout_prob=args.extract_dropout_prob, attention_probs_dropout_prob=args.extract_dropout_prob)
        self.encoder_stack = nn.ModuleList([BertLayer(config) for _ in range(args.extract_num_layers)])

        self.dropout = nn.Dropout(args.extract_dropout_prob)
        self.layer_norm = nn.LayerNorm(hidden_size, args.extract_layer_norm_eps)
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
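
This head scores positions for extractive summarization. A hedged sketch of one plausible forward (assuming PositionalEncoding adds its encodings and returns the result): run the small BertLayer stack, then squash each position to a probability:

# Hedged sketch of the sentence-scoring forward for the extractor above.
def forward(self, hidden_states, attention_mask=None):
    x = self.pos_embeddings(hidden_states)
    for layer in self.encoder_stack:
        x = layer(x, attention_mask)[0]
    x = self.layer_norm(self.dropout(x))
    return self.sigmoid(self.linear(x)).squeeze(-1)   # (batch, seq) scores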
Example #19
    def __init__(self,
                 num_labels,
                 pretrained_model_name_or_path=None,
                 cat_num=0,
                 token_size=None,
                 MAX_SEQUENCE_LENGTH=512):
        super(BertModelForBinaryMultiLabelClassifier, self).__init__()
        if pretrained_model_name_or_path:
            self.model = BertModel.from_pretrained(
                pretrained_model_name_or_path)
        else:
            raise NotImplementedError
        self.num_labels = num_labels
        if cat_num > 0:
            self.catembedding = nn.Embedding(cat_num, 768)
            self.catdropout = nn.Dropout(0.2)
            self.catactivate = nn.ReLU()

            self.catembeddingOut = nn.Embedding(cat_num, cat_num // 2 + 1)
            self.catactivateOut = nn.ReLU()
            self.dropout = nn.Dropout(0.2)
            self.classifier = nn.Linear(
                self.model.pooler.dense.out_features + cat_num // 2 + 1,
                num_labels)
        else:
            self.catembedding = None
            self.catdropout = None
            self.catactivate = None
            self.catembeddingOut = None
            self.catactivateOut = None
            self.dropout = nn.Dropout(0.2)
            self.classifier = nn.Linear(self.model.pooler.dense.out_features,
                                        num_labels)

        # resize
        if token_size:
            self.model.resize_token_embeddings(token_size)

        # define input embedding and transformers
        input_model_config = BertConfig(
            vocab_size=token_size, max_position_embeddings=MAX_SEQUENCE_LENGTH)
        self.input_embeddings = BertEmbeddings(input_model_config)
        self.input_bert_layer = BertLayer(input_model_config)

        # use bertmodel as decoder
        # self.model.config.is_decoder = True

        # add modules
        self.add_module('my_input_embeddings', self.input_embeddings)
        self.add_module('my_input_bert_layer', self.input_bert_layer)
        self.add_module('fc_output', self.classifier)
Example #20
        def init_bertlayer_models(self, use_cuda: bool) -> None:
            self.test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(1)

            torch.set_grad_enabled(False)
            self.cfg = BertConfig(attention_probs_dropout_prob=0.0,
                                  hidden_dropout_prob=0.0)

            self.torch_model = BertLayer(self.cfg)
            self.torch_model.eval()
            if use_cuda:
                self.torch_model.to(self.test_device)

            self.hidden_size = self.cfg.hidden_size

            self.turbo_model = turbo_transformers.BertLayerSmartBatch.from_torch(
                self.torch_model)
Example #21
    def __init__(self, config):
        super(GLTTokenGrounding, self).__init__()
        self.initial_img_project = nn.Linear(config.input_img_dim,
                                             config.hidden_size)
        self.text_project = nn.Linear(config.hidden_size, config.hidden_size)

        self.b_bias = nn.Parameter(torch.zeros((1)))

        self.lnorm = BertLayerNorm(config.hidden_size,
                                   eps=config.layer_norm_eps)

        self.pos_project = nn.Linear(6, config.hidden_size)
        self.lnorm_pos = BertLayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)

        self.coreference_jump_gate = nn.Linear(config.hidden_size, 2)

        self.layer = nn.ModuleList([
            BertLayer(config)
            for _ in range(config.visual_self_attention_layers)
        ])
Example #22
    def __init__(self, config):
        super(GLTEncoder, self).__init__()
        self.ground = GLTTokenGrounding(config)

        modules = [
            GLTLayer(config) for _ in range(config.max_sentence_length - 1)
        ]

        layers_to_tie = config.layers_to_tie
        if layers_to_tie:
            tie_layers(modules[0], modules[1:], config.layers_to_tie)

        self.layer = nn.ModuleList(modules)
        self.answer_nn = GLTAnswerVisualTextComp(config)

        self._contextualize_inputs_n_layers = config.contextualize_inputs_layers
        if self._contextualize_inputs_n_layers:
            self.self_att_layers = nn.ModuleList([
                BertLayer(config)
                for _ in range(self._contextualize_inputs_n_layers)
            ])
        else:
            self.self_att_layers = None
Example #23
    class TestBertLayer(unittest.TestCase):
        def init_data(self, use_cuda: bool) -> None:
            test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(1)

            torch.set_grad_enabled(False)
            self.cfg = BertConfig(attention_probs_dropout_prob=0.0,
                                  hidden_dropout_prob=0.0)

            self.torch_bert_layer = BertLayer(self.cfg)
            self.torch_bert_layer.eval()
            if use_cuda:
                self.torch_bert_layer.to(test_device)

            self.hidden_size = self.cfg.hidden_size
            self.input_tensor = torch.rand(size=(batch_size, seq_length,
                                                 self.hidden_size),
                                           dtype=torch.float32,
                                           device=test_device)

            self.attention_mask = torch.ones((batch_size, seq_length),
                                             dtype=torch.float32,
                                             device=test_device)
            self.attention_mask = self.attention_mask[:, None, None, :]
            self.attention_mask = (1.0 - self.attention_mask) * -10000.0

            self.turbo_bert_layer = turbo_transformers.BertLayer.from_torch(
                self.torch_bert_layer)

        def check_torch_and_turbo(self, use_cuda):
            self.init_data(use_cuda)
            num_iter = 2
            device = "GPU" if use_cuda else "CPU"
            torch_model = lambda: self.torch_bert_layer(
                self.input_tensor, self.attention_mask, output_attentions=True)
            torch_bert_layer_result, torch_qps, torch_time = \
                test_helper.run_model(torch_model, use_cuda, num_iter)
            print(f"BertLayer \"({batch_size},{seq_length:03})\" ",
                  f"{device} Torch QPS,  {torch_qps}, time, {torch_time}")

            turbo_model = lambda: self.turbo_bert_layer(
                self.input_tensor, self.attention_mask, output_attentions=True)
            turbo_bert_layer_result, turbo_qps, turbo_time = \
                test_helper.run_model(turbo_model, use_cuda, num_iter)
            print(
                f"BertLayer \"({batch_size},{seq_length:03})\"  ",
                f"{device} TurboTransform QPS, {turbo_qps}, time, {turbo_time}"
            )

            # Tensor cores introduce larger numerical errors
            tolerate_error = 1e-2 if use_cuda else 1e-3
            self.assertTrue(
                torch.max(
                    torch.abs(torch_bert_layer_result[0] -
                              turbo_bert_layer_result[0])) < tolerate_error)

            self.assertTrue(
                torch.max(
                    torch.abs(torch_bert_layer_result[1] -
                              turbo_bert_layer_result[1])) < tolerate_error)

            with open(fname, "a") as fh:
                fh.write(
                    f"\"({batch_size},{seq_length:03})\", {torch_qps}, {turbo_qps}\n"
                )

        def test_bert_layer(self):
            self.check_torch_and_turbo(use_cuda=False)
            if torch.cuda.is_available() and \
                turbo_transformers.config.is_compiled_with_cuda():
                self.check_torch_and_turbo(use_cuda=True)
Example #24
File: berts.py  Project: OerneMand/K-BERT
    def __init__(self, config):
        super().__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])
Example #25
class BertForSequenceClassificationNq(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForSequenceClassificationNq, self).__init__(config)
        self.num_labels = config.num_labels
        # config.output_hidden_states = True
        bert_later_dropout = 0.3
        self.dropout = nn.Dropout(bert_later_dropout)
        self.later_model_type = config.later_model_type

        if self.later_model_type == 'linear':
            self.bert = BertModel(config)
            self.projection = nn.Linear(config.hidden_size * 3,
                                        config.hidden_size)
            self.projection_dropout = nn.Dropout(0.1)
            self.projection_activation = nn.Tanh()
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        elif self.later_model_type == '1bert_layer':
            config.num_hidden_layers = 1
            self.bert = BertModel(config)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        elif self.later_model_type == 'bilinear':
            self.bert = BertModel(config)
            lstm_layers = 2
            self.qemb_match = SeqAttnMatch(config.hidden_size)
            doc_input_size = 2 * config.hidden_size
            # RNN document encoder
            self.doc_rnn = StackedBRNN(
                input_size=doc_input_size,
                hidden_size=config.hidden_size,
                num_layers=lstm_layers,
                dropout_rate=bert_later_dropout,
                dropout_output=bert_later_dropout,
                concat_layers=True,
                rnn_type=nn.LSTM,
                padding=False,
            )

            self.bilinear_dropout = nn.Dropout(bert_later_dropout)
            self.bilinear_size = 128
            self.doc_proj = nn.Linear(lstm_layers * 2 * config.hidden_size,
                                      self.bilinear_size)
            self.qs_proj = nn.Linear(config.hidden_size, self.bilinear_size)
            self.bilinear = nn.Bilinear(self.bilinear_size, self.bilinear_size,
                                        self.bilinear_size)
            self.classifier = nn.Linear(self.bilinear_size, config.num_labels)
        elif self.later_model_type == 'transformer':
            self.copy_from_bert_layer_num = 11
            self.bert = BertModel(config)
            self.bert_position_emb = nn.Embedding(
                config.max_position_embeddings, config.hidden_size)
            self.bert_type_id_emb = nn.Embedding(config.type_vocab_size,
                                                 config.hidden_size)

            self.bert_layer = BertLayer(config)
            self.bert_pooler_qd = BertPoolerQD(config)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def init_top_layer_from_bert(self):
        if self.later_model_type == 'transformer':
            # directly load from bert
            copy_dict = copy.deepcopy(self.bert.encoder.layer[
                self.copy_from_bert_layer_num].state_dict())
            self.bert_layer.load_state_dict(copy_dict)
            copy_dict = copy.deepcopy(
                self.bert.embeddings.position_embeddings.state_dict())
            self.bert_position_emb.load_state_dict(copy_dict)
            copy_dict = copy.deepcopy(
                self.bert.embeddings.token_type_embeddings.state_dict())
            self.bert_type_id_emb.load_state_dict(copy_dict)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                position_ids=None,
                head_mask=None,
                input_ids_a=None,
                token_type_ids_a=None,
                attention_mask_a=None,
                input_ids_b=None,
                token_type_ids_b=None,
                attention_mask_b=None):

        # outputs_original = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
        #                     attention_mask=attention_mask, head_mask=head_mask)
        # pooled_output = self.dropout(pooled_output)
        outputs_a = self.bert(input_ids_a,
                              position_ids=None,
                              token_type_ids=token_type_ids_a,
                              attention_mask=attention_mask_a)
        outputs_b = self.bert(input_ids_b,
                              position_ids=None,
                              token_type_ids=token_type_ids_b,
                              attention_mask=attention_mask_b)
        if self.later_model_type == 'linear':
            pooled_output_a = outputs_a[1]
            pooled_output_a = self.dropout(pooled_output_a)
            pooled_output_b = outputs_b[1]
            pooled_output_b = self.dropout(pooled_output_b)
            pooled_output = torch.cat((pooled_output_a, pooled_output_b,
                                       pooled_output_a - pooled_output_b),
                                      dim=1)
            pooled_output = self.projection(pooled_output)
            pooled_output = self.projection_activation(pooled_output)
            pooled_output = self.projection_dropout(pooled_output)
        elif self.later_model_type == '1bert_layer':
            encoder_outputs = self.bert(input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)

            bert_1stlayer = encoder_outputs[1]
            # pooled_output = self.bert_pooler(bert_1stlayer)
            pooled_output = self.dropout(bert_1stlayer)
        elif self.later_model_type == 'bilinear':
            question_hiddens = outputs_a[0]
            question_hiddens = self.dropout(question_hiddens)
            doc_hiddens = outputs_b[0]
            doc_hiddens = self.dropout(doc_hiddens)
            question_mask = (1 - attention_mask_a).to(torch.bool)
            doc_mask = (1 - attention_mask_b).to(torch.bool)
            x2_weighted_emb = self.qemb_match(doc_hiddens, question_hiddens,
                                              question_mask)
            doc_hiddens = torch.cat((doc_hiddens, x2_weighted_emb), 2)
            doc_hiddens = self.doc_rnn(doc_hiddens, doc_mask)

            question_hidden = outputs_a[1]
            question_hidden = self.qs_proj(question_hidden)
            question_hidden = self.bilinear_dropout(question_hidden)
            doc_hiddens = self.doc_proj(doc_hiddens)
            doc_hiddens = self.bilinear_dropout(doc_hiddens)

            question_hidden = question_hidden.unsqueeze(1).expand_as(
                doc_hiddens).contiguous()
            doc_hiddens = self.bilinear(doc_hiddens, question_hidden)
            pooled_output = doc_hiddens.max(dim=1)[0]
        elif self.later_model_type == 'transformer':
            input_ids = torch.cat((input_ids_a, input_ids_b), dim=1)
            bert_embeddings_a = outputs_a[0]
            bert_embeddings_b = outputs_b[0]
            embeddings_cat = torch.cat((bert_embeddings_a, bert_embeddings_b),
                                       dim=1)
            attention_mask = torch.cat((attention_mask_a, attention_mask_b),
                                       dim=1)
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(
                dtype=next(self.parameters()).dtype)
            extended_attention_mask = (1.0 -
                                       extended_attention_mask) * -10000.0
            token_type_ids_a = torch.zeros_like(token_type_ids_a)
            token_type_ids_b = torch.ones_like(token_type_ids_b)
            token_type_ids = torch.cat((token_type_ids_a, token_type_ids_b),
                                       dim=1)
            token_type_ids_emb = self.bert_type_id_emb(token_type_ids)
            seq_length = embeddings_cat.size(1)
            if position_ids is None:
                position_ids = torch.arange(seq_length,
                                            dtype=torch.long,
                                            device=input_ids_a.device)
                position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
            embeddings_cat_position_emb = self.bert_position_emb(position_ids)
            transformer_input = embeddings_cat + embeddings_cat_position_emb + token_type_ids_emb
            transformer_outputs = self.bert_layer(transformer_input,
                                                  extended_attention_mask)
            pooled_output = self.bert_pooler_qd(
                transformer_outputs[0], question_size=input_ids_a.size(1))

        logits = self.classifier(pooled_output)

        outputs = (logits, ) + outputs_a[2:] + outputs_b[
            2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))
            outputs = (loss, ) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)