def __init__(self, config):
    super(BertEncoder, self).__init__()

    # The BERT encoder needs three pieces here:
    #   text BERT layer:   BertLayer
    #   vision BERT layer: BertImageLayer
    #   Bi-Attention: given the outputs of the two BertLayers, perform
    #   bi-directional attention and add it on top of the two streams.
    t_config = BertConfig.from_dict(config.t_config)
    v_config = BertConfig.from_dict(config.v_config)

    self.FAST_MODE = config.fast_mode
    self.with_coattention = config.with_coattention
    self.v_biattention_id = v_config.biattention_id
    self.t_biattention_id = t_config.biattention_id
    self.in_batch_pairs = config.in_batch_pairs
    self.fixed_t_layer = config.fixed_t_layer
    self.fixed_v_layer = config.fixed_v_layer

    # layer = BertLayer(config)
    layer = BertLayer(t_config)
    v_layer = BertLayer(v_config)
    connect_layer = BertConnectionLayer(config)

    self.layer = nn.ModuleList(
        [copy.deepcopy(layer) for _ in range(t_config.num_hidden_layers)])
    self.v_layer = nn.ModuleList(
        [copy.deepcopy(v_layer) for _ in range(v_config.num_hidden_layers)])
    self.c_layer = nn.ModuleList(
        [copy.deepcopy(connect_layer) for _ in range(len(v_config.biattention_id))])
def init_data(self, use_cuda: bool) -> None:
    test_device = torch.device('cuda:0') if use_cuda else \
        torch.device('cpu:0')
    if not use_cuda:
        torch.set_num_threads(1)

    torch.set_grad_enabled(False)
    self.cfg = BertConfig(attention_probs_dropout_prob=0.0,
                          hidden_dropout_prob=0.0)

    self.torch_bert_layer = BertLayer(self.cfg)
    self.torch_bert_layer.eval()
    if use_cuda:
        self.torch_bert_layer.to(test_device)

    self.hidden_size = self.cfg.hidden_size
    self.input_tensor = torch.rand(size=(batch_size, seq_length,
                                         self.hidden_size),
                                   dtype=torch.float32,
                                   device=test_device)

    self.attention_mask = torch.ones((batch_size, seq_length),
                                     dtype=torch.float32,
                                     device=test_device)
    self.attention_mask = self.attention_mask[:, None, None, :]
    self.attention_mask = (1.0 - self.attention_mask) * -10000.0

    self.turbo_bert_layer = turbo_transformers.BertLayer.from_torch(
        self.torch_bert_layer)
def __init__(self, config):
    super().__init__()
    self.config = config
    self.layer = nn.ModuleList(
        [BertLayer(config) for _ in range(config.num_hidden_layers)])
    assert config.num_hidden_layers % 2 == 0, "num_hidden_layers must be even in Trelm!"
    self.tlayer = BertLayer(config)
    # self.tlayer_position = int(config.num_hidden_layers / 2)
def __init__(self, config):
    super(BertForSequenceClassificationNq, self).__init__(config)
    self.num_labels = config.num_labels
    # config.output_hidden_states = True
    bert_later_dropout = 0.3
    self.dropout = nn.Dropout(bert_later_dropout)
    self.later_model_type = config.later_model_type
    if self.later_model_type == 'linear':
        self.bert = BertModel(config)
        self.projection = nn.Linear(config.hidden_size * 3, config.hidden_size)
        self.projection_dropout = nn.Dropout(0.1)
        self.projection_activation = nn.Tanh()
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
    elif self.later_model_type == '1bert_layer':
        config.num_hidden_layers = 1
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
    elif self.later_model_type == 'bilinear':
        self.bert = BertModel(config)
        lstm_layers = 2
        self.qemb_match = SeqAttnMatch(config.hidden_size)
        doc_input_size = 2 * config.hidden_size
        # RNN document encoder
        self.doc_rnn = StackedBRNN(
            input_size=doc_input_size,
            hidden_size=config.hidden_size,
            num_layers=lstm_layers,
            dropout_rate=bert_later_dropout,
            dropout_output=bert_later_dropout,
            concat_layers=True,
            rnn_type=nn.LSTM,
            padding=False,
        )
        self.bilinear_dropout = nn.Dropout(bert_later_dropout)
        self.bilinear_size = 128
        self.doc_proj = nn.Linear(lstm_layers * 2 * config.hidden_size,
                                  self.bilinear_size)
        self.qs_proj = nn.Linear(config.hidden_size, self.bilinear_size)
        self.bilinear = nn.Bilinear(self.bilinear_size, self.bilinear_size,
                                    self.bilinear_size)
        self.classifier = nn.Linear(self.bilinear_size, config.num_labels)
    elif self.later_model_type == 'transformer':
        self.copy_from_bert_layer_num = 11
        self.bert = BertModel(config)
        self.bert_position_emb = nn.Embedding(config.max_position_embeddings,
                                              config.hidden_size)
        self.bert_type_id_emb = nn.Embedding(config.type_vocab_size,
                                             config.hidden_size)
        self.bert_layer = BertLayer(config)
        self.bert_pooler_qd = BertPoolerQD(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    self.init_weights()
def __init__(self, extractor, config, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.extractor = extractor
    self.config = config

    if config["pretrained"] == "electra-base-msmarco":
        self.bert = ElectraModel.from_pretrained("Capreolus/electra-base-msmarco")
    elif config["pretrained"] == "bert-base-msmarco":
        self.bert = BertModel.from_pretrained("Capreolus/bert-base-msmarco")
    elif config["pretrained"] == "bert-base-uncased":
        self.bert = BertModel.from_pretrained("bert-base-uncased")
    else:
        raise ValueError(
            f"unsupported model: {config['pretrained']}; need to ensure correct tokenizers will be used before arbitrary hgf models are supported"
        )

    self.transformer_layer_1 = BertLayer(self.bert.config)
    self.transformer_layer_2 = BertLayer(self.bert.config)
    self.num_passages = extractor.config["numpassages"]
    self.maxseqlen = extractor.config["maxseqlen"]
    self.linear = nn.Linear(self.bert.config.hidden_size, 1)

    if config["aggregation"] == "max":
        raise NotImplementedError()
    elif config["aggregation"] == "avg":
        raise NotImplementedError()
    elif config["aggregation"] == "attn":
        raise NotImplementedError()
    elif config["aggregation"] == "transformer":
        self.aggregation = self.aggregate_using_transformer
        input_embeddings = self.bert.get_input_embeddings()
        # TODO hardcoded CLS token id
        cls_token_id = torch.tensor([[101]])
        self.initial_cls_embedding = input_embeddings(cls_token_id).view(
            1, self.bert.config.hidden_size)
        self.full_position_embeddings = torch.zeros(
            (1, self.num_passages + 1, self.bert.config.hidden_size),
            requires_grad=True,
            dtype=torch.float)
        torch.nn.init.normal_(self.full_position_embeddings, mean=0.0, std=0.02)
        self.initial_cls_embedding = nn.Parameter(self.initial_cls_embedding,
                                                  requires_grad=True)
        self.full_position_embeddings = nn.Parameter(self.full_position_embeddings,
                                                     requires_grad=True)
    else:
        raise ValueError(f"unknown aggregation type: {self.config['aggregation']}")
def __init__(self, config, scc_n_layer=6):
    super(BertEncoder, self).__init__()
    self.prd_n_layer = config.num_hidden_layers
    self.scc_n_layer = scc_n_layer
    assert self.prd_n_layer % self.scc_n_layer == 0
    self.compress_ratio = self.prd_n_layer // self.scc_n_layer
    self.bernoulli = None
    self.output_attentions = config.output_attentions
    self.output_hidden_states = config.output_hidden_states
    self.layer = nn.ModuleList(
        [BertLayer(config) for _ in range(self.prd_n_layer)])
    self.scc_layer = nn.ModuleList(
        [BertLayer(config) for _ in range(self.scc_n_layer)])
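# The attribute names above (bernoulli, compress_ratio, scc_layer) follow the
# BERT-of-Theseus pattern, where each successor ("scc") layer stochastically
# stands in for a group of compress_ratio predecessor ("prd") layers during
# training. A minimal sketch of what such a forward pass could look like,
# assuming that replacement scheme; the original forward is not shown here and
# this is not a drop-in reproduction of it.
def forward(self, hidden_states, attention_mask=None):
    # Hypothetical Theseus-style forward: for each successor layer, either run
    # it directly or run the corresponding group of predecessor layers, chosen
    # by a Bernoulli draw (self.bernoulli is assumed to be set elsewhere).
    for i in range(self.scc_n_layer):
        replace = self.bernoulli is not None and self.bernoulli.sample() == 1
        if self.training and replace:
            hidden_states = self.scc_layer[i](hidden_states, attention_mask)[0]
        else:
            for offset in range(self.compress_ratio):
                prd_idx = i * self.compress_ratio + offset
                hidden_states = self.layer[prd_idx](hidden_states,
                                                    attention_mask)[0]
    return (hidden_states,)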
def __init__(self, config):
    super(BertEncoderATM, self).__init__()
    self.output_attentions = config.output_attentions
    self.output_hidden_states = config.output_hidden_states
    self.map_linear = nn.Linear(config.hidden_size, config.hidden_size)
    self.layer = nn.ModuleList(
        [BertLayer(config) for _ in range(config.num_hidden_layers)])
def __init__(self, config, num_shared_layers, num_layers):
    super(transformer_block, self).__init__()
    self.num_layers = num_layers
    self.num_shared_layers = num_shared_layers
    self.bert_layers = nn.ModuleList(
        [BertLayer(config) for _ in range(num_layers)])
    self.pooler = BertPooler(config)
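# A minimal sketch of how a block like this is typically driven, assuming the
# usual Hugging Face BertLayer call convention (hidden states plus an extended
# attention mask, new hidden states at index 0 of the returned tuple). The
# block's actual forward is not part of the snippet above.
def forward(self, hidden_states, extended_attention_mask=None):
    # Hypothetical forward: run the stacked BertLayers, then pool the result.
    for layer in self.bert_layers:
        hidden_states = layer(hidden_states, extended_attention_mask)[0]
    pooled_output = self.pooler(hidden_states)
    return hidden_states, pooled_output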
def __init__(
    self,
    config,
    visual_embedding_dim=512,
    embedding_strategy="plain",
    bypass_transformer=False,
    output_attentions=False,
    output_hidden_states=False,
):
    super().__init__(config)
    self.config = config

    config.visual_embedding_dim = visual_embedding_dim
    config.embedding_strategy = embedding_strategy
    config.bypass_transformer = bypass_transformer
    config.output_attentions = output_attentions
    config.output_hidden_states = output_hidden_states

    self.embeddings = BertVisioLinguisticEmbeddings(config)
    self.encoder = BertEncoder(config)
    self.pooler = BertPooler(config)
    self.bypass_transformer = config.bypass_transformer

    if self.bypass_transformer:
        self.additional_layer = BertLayer(config)

    self.output_attentions = self.config.output_attentions
    self.output_hidden_states = self.config.output_hidden_states
    self.fixed_head_masks = [None for _ in range(len(self.encoder.layer))]
    self.init_weights()
def __init__(self, config, visual_start_layer):
    super().__init__()
    self.output_attentions = False  # config.output_attentions
    self.output_hidden_states = True  # config.output_hidden_states
    self.visual_start_layer = visual_start_layer
    self.config = config
    self.layer = torch.nn.ModuleList(
        [BertLayer(config) for _ in range(config.num_hidden_layers)])
def __init__(self, config):
    super().__init__()

    # Obj-level image embedding layer
    self.visn_fc = VisualFeatEncoder(config)

    # Number of layers
    self.num_l_layers = config.l_layers
    self.num_x_layers = config.x_layers
    self.num_r_layers = config.r_layers

    self.layer = nn.ModuleList(
        [BertLayer(config) for _ in range(self.num_l_layers)])
    self.x_layers = nn.ModuleList(
        [LXMERTXLayer(config) for _ in range(self.num_x_layers)])
    self.r_layers = nn.ModuleList(
        [BertLayer(config) for _ in range(self.num_r_layers)])
def __init__(self, config):
    super(BertEncoder4TokenMix, self).__init__()
    # self.output_attentions = config.output_attentions
    # self.output_hidden_states = config.output_hidden_states
    self.output_attentions = False
    self.output_hidden_states = True
    self.layer = nn.ModuleList(
        [BertLayer(config) for _ in range(config.num_hidden_layers)])
def __init__(self, config):
    super().__init__()
    self.output_attentions = config.output_attentions
    self.output_hidden_states = config.output_hidden_states
    self.layer = nn.ModuleList(
        [BertLayer(config) for _ in range(config.num_hidden_layers)])
    self.highway = nn.ModuleList(
        [BertHighway(config) for _ in range(config.num_hidden_layers)])
    self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]
def __init__(self, config):
    super(Stage0, self).__init__()
    self.embedding_layer = BertEmbeddings(config)
    self.layers = []
    for i in range(config.num_hidden_layers // 24):
        self.layers.append(BertLayer(config))
    self.layers = torch.nn.ModuleList(self.layers)
    self.config = config
    self.apply(self.init_bert_weights)
def __init__(self, config):
    super(Stage1, self).__init__()
    self.layers = []
    for i in range(12):  # config.num_hidden_layers
        self.layers.append(BertLayer(config))
    self.layers = torch.nn.ModuleList(self.layers)
    self.pooling_layer = BertPooler(config)
    self.pre_training_heads_layer = BertPreTrainingHeads(config)
    self.config = config
    self.apply(self.init_bert_weights)
def __init__(self, config, embedding_dim, num_groups):
    super().__init__(config)
    self.config = config
    self.embeddings = BertEmbeddings(config, embedding_dim)

    msg = 'Amount of encoder blocks should be divisible by number of groups.'
    assert config.num_hidden_layers % num_groups == 0, msg

    self.encoder = nn.ModuleList([BertLayer(config) for _ in range(num_groups)])
    self.group_size = config.num_hidden_layers // num_groups
    self.init_weights()
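# Only num_groups BertLayer instances are created for num_hidden_layers logical
# layers, so the encoder presumably reuses each group's weights group_size
# times, in the spirit of ALBERT-style cross-layer parameter sharing. A minimal
# sketch of that loop, assuming the standard BertLayer call convention; the
# snippet above does not include the model's actual forward.
def forward(self, hidden_states, extended_attention_mask=None):
    # Hypothetical forward: apply each shared BertLayer group_size times so
    # that num_groups * group_size == num_hidden_layers logical layers run.
    for shared_layer in self.encoder:
        for _ in range(self.group_size):
            hidden_states = shared_layer(hidden_states,
                                         extended_attention_mask)[0]
    return hidden_states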
def __init__(self,
             config,
             num_layers=1,
             num_langs=1,
             struct="transformer",
             add_weights=False,
             tied=True,
             bottle_size=768):
    super().__init__()
    self.nets = []
    self.num_layers = num_layers
    self.num_langs = num_langs
    self.struct = struct
    self.add_weights = add_weights
    self.tied = tied

    for i in range(num_langs):
        for j in range(num_layers):
            if struct == "transformer":
                self.nets.append(BertLayer(config))
            elif struct == "perceptron":
                hidden_size = config.hidden_size
                if add_weights:
                    if tied:
                        self.nets.append(
                            nn.Sequential(
                                nn.Linear(hidden_size, bottle_size),
                                nn.ReLU(),
                                nn.Linear(bottle_size, hidden_size + 1)))
                    else:
                        self.nets.append(
                            nn.Sequential(
                                nn.Linear(hidden_size, bottle_size),
                                nn.ReLU(),
                                nn.Linear(bottle_size, hidden_size)))
                        self.weight_net = nn.Sequential(
                            nn.Linear(hidden_size, bottle_size),
                            nn.ReLU(),
                            nn.Linear(bottle_size, 1))
                else:
                    self.nets.append(
                        nn.Sequential(
                            nn.Linear(hidden_size, hidden_size // 4),
                            nn.ReLU(),
                            nn.Linear(hidden_size // 4, hidden_size)))
            else:
                print("The specified structure is not implemented.")
                sys.exit(0)

    self.nets = nn.ModuleList(self.nets)
    self.alpha = nn.Parameter(torch.zeros(num_langs, num_layers))
    if struct == "perceptron":
        self.init_weights()
def __init__(self, hidden_size, args):
    super().__init__()
    self.extract_num_layers = args.extract_num_layers
    self.pos_embeddings = PositionalEncoding(hidden_size, args.extract_dropout_prob)

    config = BertConfig(hidden_size=hidden_size,
                        intermediate_size=hidden_size * 4,
                        layer_norm_eps=args.extract_layer_norm_eps,
                        hidden_dropout_prob=args.extract_dropout_prob,
                        attention_probs_dropout_prob=args.extract_dropout_prob)
    self.encoder_stack = nn.ModuleList(
        [BertLayer(config) for _ in range(args.extract_num_layers)])

    self.dropout = nn.Dropout(args.extract_dropout_prob)
    self.layer_norm = nn.LayerNorm(hidden_size, args.extract_layer_norm_eps)
    self.linear = nn.Linear(hidden_size, 1)
    self.sigmoid = nn.Sigmoid()
def __init__(self,
             num_labels,
             pretrained_model_name_or_path=None,
             cat_num=0,
             token_size=None,
             MAX_SEQUENCE_LENGTH=512):
    super(BertModelForBinaryMultiLabelClassifier, self).__init__()
    if pretrained_model_name_or_path:
        self.model = BertModel.from_pretrained(pretrained_model_name_or_path)
    else:
        raise NotImplementedError
    self.num_labels = num_labels

    if cat_num > 0:
        self.catembedding = nn.Embedding(cat_num, 768)
        self.catdropout = nn.Dropout(0.2)
        self.catactivate = nn.ReLU()
        self.catembeddingOut = nn.Embedding(cat_num, cat_num // 2 + 1)
        self.catactivateOut = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(
            self.model.pooler.dense.out_features + cat_num // 2 + 1, num_labels)
    else:
        self.catembedding = None
        self.catdropout = None
        self.catactivate = None
        self.catembeddingOut = None
        self.catactivateOut = None
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(self.model.pooler.dense.out_features,
                                    num_labels)

    # resize
    if token_size:
        self.model.resize_token_embeddings(token_size)

    # define input embedding and transformers
    input_model_config = BertConfig(vocab_size=token_size,
                                    max_position_embeddings=MAX_SEQUENCE_LENGTH)
    self.input_embeddings = BertEmbeddings(input_model_config)
    self.input_bert_layer = BertLayer(input_model_config)

    # use bertmodel as decoder
    # self.model.config.is_decoder = True

    # add modules
    self.add_module('my_input_embeddings', self.input_embeddings)
    self.add_module('my_input_bert_layer', self.input_bert_layer)
    self.add_module('fc_output', self.classifier)
def init_bertlayer_models(self, use_cuda: bool) -> None:
    self.test_device = torch.device('cuda:0') if use_cuda else \
        torch.device('cpu:0')
    if not use_cuda:
        torch.set_num_threads(1)

    torch.set_grad_enabled(False)
    self.cfg = BertConfig(attention_probs_dropout_prob=0.0,
                          hidden_dropout_prob=0.0)

    self.torch_model = BertLayer(self.cfg)
    self.torch_model.eval()
    if use_cuda:
        self.torch_model.to(self.test_device)

    self.hidden_size = self.cfg.hidden_size
    self.turbo_model = turbo_transformers.BertLayerSmartBatch.from_torch(
        self.torch_model)
def __init__(self, config):
    super(GLTTokenGrounding, self).__init__()
    self.initial_img_project = nn.Linear(config.input_img_dim, config.hidden_size)
    self.text_project = nn.Linear(config.hidden_size, config.hidden_size)
    self.b_bias = nn.Parameter(torch.zeros((1)))
    self.lnorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    self.pos_project = nn.Linear(6, config.hidden_size)
    self.lnorm_pos = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    self.coreference_jump_gate = nn.Linear(config.hidden_size, 2)

    self.layer = nn.ModuleList([
        BertLayer(config) for _ in range(config.visual_self_attention_layers)
    ])
def __init__(self, config):
    super(GLTEncoder, self).__init__()
    self.ground = GLTTokenGrounding(config)

    modules = [GLTLayer(config) for _ in range(config.max_sentence_length - 1)]
    layers_to_tie = config.layers_to_tie
    if layers_to_tie:
        tie_layers(modules[0], modules[1:], config.layers_to_tie)
    self.layer = nn.ModuleList(modules)

    self.answer_nn = GLTAnswerVisualTextComp(config)

    self._contextualize_inputs_n_layers = config.contextualize_inputs_layers
    if self._contextualize_inputs_n_layers:
        self.self_att_layers = nn.ModuleList([
            BertLayer(config)
            for _ in range(self._contextualize_inputs_n_layers)
        ])
    else:
        self.self_att_layers = None
class TestBertLayer(unittest.TestCase):
    def init_data(self, use_cuda: bool) -> None:
        test_device = torch.device('cuda:0') if use_cuda else \
            torch.device('cpu:0')
        if not use_cuda:
            torch.set_num_threads(1)

        torch.set_grad_enabled(False)
        self.cfg = BertConfig(attention_probs_dropout_prob=0.0,
                              hidden_dropout_prob=0.0)

        self.torch_bert_layer = BertLayer(self.cfg)
        self.torch_bert_layer.eval()
        if use_cuda:
            self.torch_bert_layer.to(test_device)

        self.hidden_size = self.cfg.hidden_size
        self.input_tensor = torch.rand(size=(batch_size, seq_length,
                                             self.hidden_size),
                                       dtype=torch.float32,
                                       device=test_device)

        self.attention_mask = torch.ones((batch_size, seq_length),
                                         dtype=torch.float32,
                                         device=test_device)
        self.attention_mask = self.attention_mask[:, None, None, :]
        self.attention_mask = (1.0 - self.attention_mask) * -10000.0

        self.turbo_bert_layer = turbo_transformers.BertLayer.from_torch(
            self.torch_bert_layer)

    def check_torch_and_turbo(self, use_cuda):
        self.init_data(use_cuda)
        num_iter = 2
        device = "GPU" if use_cuda else "CPU"

        torch_model = lambda: self.torch_bert_layer(
            self.input_tensor, self.attention_mask, output_attentions=True)
        torch_bert_layer_result, torch_qps, torch_time = \
            test_helper.run_model(torch_model, use_cuda, num_iter)
        print(f"BertLayer \"({batch_size},{seq_length:03})\" ",
              f"{device} Torch QPS, {torch_qps}, time, {torch_time}")

        turbo_model = lambda: self.turbo_bert_layer(
            self.input_tensor, self.attention_mask, output_attentions=True)
        turbo_bert_layer_result, turbo_qps, turbo_time = \
            test_helper.run_model(turbo_model, use_cuda, num_iter)
        print(f"BertLayer \"({batch_size},{seq_length:03})\" ",
              f"{device} TurboTransform QPS, {turbo_qps}, time, {turbo_time}")

        # Tensor cores will introduce more error
        tolerate_error = 1e-2 if use_cuda else 1e-3
        self.assertTrue(
            torch.max(
                torch.abs(torch_bert_layer_result[0] -
                          turbo_bert_layer_result[0])) < tolerate_error)
        self.assertTrue(
            torch.max(
                torch.abs(torch_bert_layer_result[1] -
                          turbo_bert_layer_result[1])) < tolerate_error)

        with open(fname, "a") as fh:
            fh.write(
                f"\"({batch_size},{seq_length:03})\", {torch_qps}, {turbo_qps}\n")

    def test_bert_layer(self):
        self.check_torch_and_turbo(use_cuda=False)
        if torch.cuda.is_available() and \
                turbo_transformers.config.is_compiled_with_cuda():
            self.check_torch_and_turbo(use_cuda=True)
def __init__(self, config):
    super().__init__(config)
    self.output_attentions = config.output_attentions
    self.output_hidden_states = config.output_hidden_states
    self.layer = nn.ModuleList(
        [BertLayer(config) for _ in range(config.num_hidden_layers)])
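# For an encoder built this way, the usual forward simply chains the layers,
# optionally collecting hidden states and attention maps. A minimal sketch,
# assuming the standard (pre-4.x) Hugging Face BertLayer outputs: new hidden
# states at index 0 and attention probabilities at index 1 when the config
# requests them. The class's actual forward is not included above.
def forward(self, hidden_states, attention_mask=None, head_mask=None):
    # Hypothetical forward: run each BertLayer in sequence.
    all_hidden_states = ()
    all_attentions = ()
    for i, layer_module in enumerate(self.layer):
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        layer_outputs = layer_module(
            hidden_states, attention_mask,
            head_mask[i] if head_mask is not None else None)
        hidden_states = layer_outputs[0]
        if self.output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)
    if self.output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    outputs = (hidden_states,)
    if self.output_hidden_states:
        outputs = outputs + (all_hidden_states,)
    if self.output_attentions:
        outputs = outputs + (all_attentions,)
    return outputs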
class BertForSequenceClassificationNq(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForSequenceClassificationNq, self).__init__(config)
        self.num_labels = config.num_labels
        # config.output_hidden_states = True
        bert_later_dropout = 0.3
        self.dropout = nn.Dropout(bert_later_dropout)
        self.later_model_type = config.later_model_type
        if self.later_model_type == 'linear':
            self.bert = BertModel(config)
            self.projection = nn.Linear(config.hidden_size * 3, config.hidden_size)
            self.projection_dropout = nn.Dropout(0.1)
            self.projection_activation = nn.Tanh()
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        elif self.later_model_type == '1bert_layer':
            config.num_hidden_layers = 1
            self.bert = BertModel(config)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        elif self.later_model_type == 'bilinear':
            self.bert = BertModel(config)
            lstm_layers = 2
            self.qemb_match = SeqAttnMatch(config.hidden_size)
            doc_input_size = 2 * config.hidden_size
            # RNN document encoder
            self.doc_rnn = StackedBRNN(
                input_size=doc_input_size,
                hidden_size=config.hidden_size,
                num_layers=lstm_layers,
                dropout_rate=bert_later_dropout,
                dropout_output=bert_later_dropout,
                concat_layers=True,
                rnn_type=nn.LSTM,
                padding=False,
            )
            self.bilinear_dropout = nn.Dropout(bert_later_dropout)
            self.bilinear_size = 128
            self.doc_proj = nn.Linear(lstm_layers * 2 * config.hidden_size,
                                      self.bilinear_size)
            self.qs_proj = nn.Linear(config.hidden_size, self.bilinear_size)
            self.bilinear = nn.Bilinear(self.bilinear_size, self.bilinear_size,
                                        self.bilinear_size)
            self.classifier = nn.Linear(self.bilinear_size, config.num_labels)
        elif self.later_model_type == 'transformer':
            self.copy_from_bert_layer_num = 11
            self.bert = BertModel(config)
            self.bert_position_emb = nn.Embedding(config.max_position_embeddings,
                                                  config.hidden_size)
            self.bert_type_id_emb = nn.Embedding(config.type_vocab_size,
                                                 config.hidden_size)
            self.bert_layer = BertLayer(config)
            self.bert_pooler_qd = BertPoolerQD(config)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def init_top_layer_from_bert(self):
        if self.later_model_type == 'transformer':
            # directly load from bert
            copy_dict = copy.deepcopy(self.bert.encoder.layer[
                self.copy_from_bert_layer_num].state_dict())
            self.bert_layer.load_state_dict(copy_dict)
            copy_dict = copy.deepcopy(
                self.bert.embeddings.position_embeddings.state_dict())
            self.bert_position_emb.load_state_dict(copy_dict)
            copy_dict = copy.deepcopy(
                self.bert.embeddings.token_type_embeddings.state_dict())
            self.bert_type_id_emb.load_state_dict(copy_dict)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                position_ids=None,
                head_mask=None,
                input_ids_a=None,
                token_type_ids_a=None,
                attention_mask_a=None,
                input_ids_b=None,
                token_type_ids_b=None,
                attention_mask_b=None):
        # outputs_original = self.bert(input_ids, position_ids=position_ids,
        #                              token_type_ids=token_type_ids,
        #                              attention_mask=attention_mask,
        #                              head_mask=head_mask)
        # pooled_output = self.dropout(pooled_output)
        outputs_a = self.bert(input_ids_a,
                              position_ids=None,
                              token_type_ids=token_type_ids_a,
                              attention_mask=attention_mask_a)
        outputs_b = self.bert(input_ids_b,
                              position_ids=None,
                              token_type_ids=token_type_ids_b,
                              attention_mask=attention_mask_b)

        if self.later_model_type == 'linear':
            pooled_output_a = outputs_a[1]
            pooled_output_a = self.dropout(pooled_output_a)
            pooled_output_b = outputs_b[1]
            pooled_output_b = self.dropout(pooled_output_b)
            pooled_output = torch.cat(
                (pooled_output_a, pooled_output_b,
                 pooled_output_a - pooled_output_b), dim=1)
            pooled_output = self.projection(pooled_output)
            pooled_output = self.projection_activation(pooled_output)
            pooled_output = self.projection_dropout(pooled_output)
        elif self.later_model_type == '1bert_layer':
            encoder_outputs = self.bert(input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
            bert_1stlayer = encoder_outputs[1]
            # pooled_output = self.bert_pooler(bert_1stlayer)
            pooled_output = self.dropout(bert_1stlayer)
        elif self.later_model_type == 'bilinear':
            question_hiddens = outputs_a[0]
            question_hiddens = self.dropout(question_hiddens)
            doc_hiddens = outputs_b[0]
            doc_hiddens = self.dropout(doc_hiddens)
            question_mask = (1 - attention_mask_a).to(torch.bool)
            doc_mask = (1 - attention_mask_b).to(torch.bool)
            x2_weighted_emb = self.qemb_match(doc_hiddens, question_hiddens,
                                              question_mask)
            doc_hiddens = torch.cat((doc_hiddens, x2_weighted_emb), 2)
            doc_hiddens = self.doc_rnn(doc_hiddens, doc_mask)
            question_hidden = outputs_a[1]
            question_hidden = self.qs_proj(question_hidden)
            question_hidden = self.bilinear_dropout(question_hidden)
            doc_hiddens = self.doc_proj(doc_hiddens)
            doc_hiddens = self.bilinear_dropout(doc_hiddens)
            question_hidden = question_hidden.unsqueeze(1).expand_as(
                doc_hiddens).contiguous()
            doc_hiddens = self.bilinear(doc_hiddens, question_hidden)
            pooled_output = doc_hiddens.max(dim=1)[0]
        elif self.later_model_type == 'transformer':
            input_ids = torch.cat((input_ids_a, input_ids_b), dim=1)
            bert_embeddings_a = outputs_a[0]
            bert_embeddings_b = outputs_b[0]
            embeddings_cat = torch.cat((bert_embeddings_a, bert_embeddings_b),
                                       dim=1)

            attention_mask = torch.cat((attention_mask_a, attention_mask_b),
                                       dim=1)
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(
                dtype=next(self.parameters()).dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

            token_type_ids_a = torch.zeros_like(token_type_ids_a)
            token_type_ids_b = torch.ones_like(token_type_ids_b)
            token_type_ids = torch.cat((token_type_ids_a, token_type_ids_b),
                                       dim=1)
            token_type_ids_emb = self.bert_type_id_emb(token_type_ids)

            seq_length = embeddings_cat.size(1)
            if position_ids is None:
                position_ids = torch.arange(seq_length,
                                            dtype=torch.long,
                                            device=input_ids_a.device)
                position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
            embeddings_cat_position_emb = self.bert_position_emb(position_ids)

            transformer_input = (embeddings_cat + embeddings_cat_position_emb +
                                 token_type_ids_emb)
            transformer_outputs = self.bert_layer(transformer_input,
                                                  extended_attention_mask)
            pooled_output = self.bert_pooler_qd(
                transformer_outputs[0], question_size=input_ids_a.size(1))

        logits = self.classifier(pooled_output)

        # add hidden states and attention if they are here
        outputs = (logits, ) + outputs_a[2:] + outputs_b[2:]

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))
            outputs = (loss, ) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)