class BertPosattnForSequenceClassification(PreTrainedBertModel):

    def __init__(self, config, num_labels=2, max_offset=10, offset_emb=30):
        """
        :param config: BERT configuration
        :param num_labels: number of classification labels
        :param max_offset: maximum absolute offset; the offset embedding
            tables have 2 * max_offset + 1 entries
        :param offset_emb: size of the position-offset embedding, 0 to disable
        """
        print('model_post attention')
        print('max_offset:', max_offset)
        print('offset_emb:', offset_emb)
        super(BertPosattnForSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if offset_emb > 0:
            self.offset1_emb = nn.Embedding(2 * max_offset + 1, offset_emb)
            self.offset2_emb = nn.Embedding(2 * max_offset + 1, offset_emb)
        self.attn_layer_1 = nn.Linear((config.hidden_size + offset_emb) * 2,
                                      config.hidden_size)
        self.attn_tanh = nn.Tanh()
        self.attn_layer_2 = nn.Linear(config.hidden_size, 1)
        self.attn_softmax = nn.Softmax(dim=1)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, offset1, offset2, token_type_ids=None,
                attention_mask=None, labels=None):
        encoded_layers, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=False)
        # Batch x Tok_pos x Hidden_dim
        batch_n, tok_n, hid_n = encoded_layers.size()
        # Batch x Hidden_dim -> Batch x Tok_pos x Hidden_dim
        global_vec = pooled_output.unsqueeze(1).repeat(1, tok_n, 1)
        tensors_to_cat = [encoded_layers, global_vec]
        if hasattr(self, 'offset1_emb') and hasattr(self, 'offset2_emb'):
            tensors_to_cat += [
                self.offset1_emb(offset1),
                self.offset2_emb(offset2)
            ]
        # (Batch*Tok_pos) x (Hidden_dim*2 + offset_emb*2)
        attn_input = torch.cat(tensors_to_cat, 2).view(batch_n * tok_n, -1)
        # (Batch*Tok_pos) x Hidden_dim
        attn_1 = self.attn_layer_1(attn_input)
        attn_1 = self.attn_tanh(attn_1)
        # (Batch*Tok_pos) x 1 -> Batch x Tok_pos
        attn_2 = self.attn_layer_2(attn_1).view(batch_n, tok_n)
        attn_weight = self.attn_softmax(attn_2)
        # Batch x Tok_pos x Hidden_dim -> Batch x 1 x Hidden_dim
        weighted_layers = torch.sum(attn_weight.unsqueeze(2) * encoded_layers,
                                    dim=1, keepdim=True)
        # the pooler reads the single remaining "token": Batch x Hidden_dim
        pooled_output = self.bert.pooler(weighted_layers)
        pooled_output = self.dropout(pooled_output)
        # Batch x num_labels
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss, logits
        else:
            return logits
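# --- Usage sketch (not part of the original model code) ----------------------
# A minimal, hypothetical example of driving BertPosattnForSequenceClassification.
# It assumes the two offset inputs are per-token relative distances to two
# anchor positions (e.g. a pair of target mentions), clipped to
# [-max_offset, max_offset] and shifted into [0, 2*max_offset] so they index
# the (2*max_offset + 1)-row embedding tables. The helper `build_offset_ids`,
# the anchor interpretation, and the BertConfig construction below are
# assumptions for illustration, not taken from the original code; a real run
# would load a pretrained checkpoint and tokenizer instead of random ids.
# Module-level imports (torch, nn, BertModel, ...) are assumed to exist
# earlier in this file, as the classes themselves require them.


def build_offset_ids(seq_len, anchor_pos, max_offset=10):
    """Clipped relative offset of every token w.r.t. `anchor_pos`, shifted to
    non-negative indices suitable for nn.Embedding lookup (assumed scheme)."""
    import torch
    positions = torch.arange(seq_len)
    rel = torch.clamp(positions - anchor_pos, -max_offset, max_offset)
    return rel + max_offset  # values in [0, 2 * max_offset]


def _posattn_usage_example():
    import torch
    from pytorch_pretrained_bert.modeling import BertConfig

    # tiny randomly initialised config, just to exercise the forward pass
    config = BertConfig(vocab_size_or_config_json_file=30522, hidden_size=96,
                        num_hidden_layers=2, num_attention_heads=4,
                        intermediate_size=192)
    model = BertPosattnForSequenceClassification(
        config, num_labels=2, max_offset=10, offset_emb=30)
    model.eval()

    batch, seq_len = 2, 16
    input_ids = torch.randint(0, 30522, (batch, seq_len))
    offset1 = build_offset_ids(seq_len, anchor_pos=3).unsqueeze(0).repeat(batch, 1)
    offset2 = build_offset_ids(seq_len, anchor_pos=8).unsqueeze(0).repeat(batch, 1)
    with torch.no_grad():
        logits = model(input_ids, offset1, offset2)
    print(logits.shape)  # torch.Size([batch, num_labels])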
class BertKnrm(BertPreTrainedModel):
    """Implementation of K-NRM on top of BERT for ad-hoc ranking.

    See [1], which builds this kind of model, and [2] and [3] for K-NRM
    and Convolutional K-NRM.

    References
    ----------
    [1] MacAvaney, S., Yates, A., Cohan, A., & Goharian, N. (2019).
        CEDR: Contextualized Embeddings for Document Ranking. CoRR.
        (https://arxiv.org/pdf/1904.07094.pdf)
    [2] Xiong, C., Dai, Z., Callan, J., Liu, Z., & Power, R. (2017, August).
        End-to-end neural ad-hoc ranking with kernel pooling. In Proceedings
        of the 40th International ACM SIGIR Conference on Research and
        Development in Information Retrieval (pp. 55-64). ACM.
        (http://www.cs.cmu.edu/~zhuyund/papers/end-end-neural.pdf)
    [3] Dai, Z., Xiong, C., Callan, J., & Liu, Z. (2018, February).
        Convolutional neural networks for soft-matching n-grams in ad-hoc
        search. In Proceedings of the Eleventh ACM International Conference
        on Web Search and Data Mining (pp. 126-134). ACM.
        (http://www.cs.cmu.edu/~./callan/Papers/wsdm18-zhuyun-dai.pdf)
    """

    def __init__(self, config, use_knrm=False, K=11, lamb=0.5, use_exact=True,
                 last_layer_only=True, N=None, method="mean", weights=None,
                 mu_sigma_learnable=False):
        super(BertKnrm, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # options for combining BERT encodings from different layers
        if not last_layer_only:
            # if N is None, consider the last 5 layers
            if N is None:
                N = 5
            self.N = N
            if method not in ("avg", "wavg", "sum", "wsum", "max", "selfattn"):
                method = "avg"
            self.method = method
            if method.startswith("w"):
                if weights is None or len(weights) != self.N:
                    # better weights setting? maybe make them learnable if used?
                    # FIXME: hard-coded 12 layers
                    self.weights = torch.linspace(
                        0.01, 1.0, self.N if self.N else 12)
                else:
                    self.weights = torch.tensor(weights, dtype=torch.float)
        self.last_layer_only = last_layer_only

        if use_knrm:
            # kernel options
            self.K = K
            # make mu and sigma learnable, otherwise use values from the paper
            if not mu_sigma_learnable:
                self.mus = torch.tensor(
                    self.kernel_mus(K, use_exact), dtype=torch.float
                )
                self.sigmas = torch.tensor(
                    self.kernel_sigmas(K, lamb, use_exact), dtype=torch.float
                )
            else:
                self.mus = nn.Parameter(torch.randn(K).float())
                self.sigmas = nn.Parameter(torch.randn(K).float())
            self.mu_sigma_learnable = mu_sigma_learnable
            # output layer for the final score
            self.linear = nn.Linear(K, 1)
        else:
            self.linear = nn.Linear(768, 1)
        self.use_knrm = use_knrm
        self.activation = nn.Tanh()
        self.apply(self.init_bert_weights)

    def to_device(self, device):
        # mus/sigmas are plain tensors (not parameters) unless learnable,
        # so they have to be moved alongside the model explicitly
        if self.use_knrm and not self.mu_sigma_learnable:
            self.mus = self.mus.to(device)
            self.sigmas = self.sigmas.to(device)
        if not self.last_layer_only and self.method.startswith("w"):
            self.weights = self.weights.to(device)

    def kernel_mus(self, n_kernels, use_exact):
        """Get mu value for each Gaussian kernel.

        Mu is the middle of each bin.

        Parameters
        ----------
        n_kernels : int
            Number of kernels (including exact match); the first one is
            the exact-match kernel.
        use_exact : bool
            Whether to use the exact-match kernel.

        Returns
        -------
        l_mu : list of float
            List of mu values.

        References
        ----------
        Taken from the K-NRM source:
        https://github.com/AdeDZY/K-NRM/blob/master/knrm/model/model_base.py
        """
        if use_exact:
            l_mu = [1]
        else:
            l_mu = [2]
        if n_kernels == 1:
            return l_mu
        bin_size = 2.0 / (n_kernels - 1)  # scores range over [-1, 1]
        l_mu.append(1 - bin_size / 2)  # mu: middle of the bin
        for i in range(1, n_kernels - 1):
            l_mu.append(l_mu[i] - bin_size)
        return l_mu

    def kernel_sigmas(self, n_kernels, lamb, use_exact):
        """Get sigma value for each Gaussian kernel.

        Parameters
        ----------
        n_kernels : int
            Number of kernels (including exact match).
        lamb : float
            Defines the Gaussian kernels' sigma value.
        use_exact : bool
            Whether to use the exact-match kernel.

        Returns
        -------
        l_sigma : list of float
            List of sigma values.

        References
        ----------
        Taken from the K-NRM source:
        https://github.com/AdeDZY/K-NRM/blob/master/knrm/model/model_base.py
        """
        bin_size = 2.0 / (n_kernels - 1)
        l_sigma = [0.00001]  # for exact match; small variance -> exact match
        if n_kernels == 1:
            return l_sigma
        l_sigma += [bin_size * lamb] * (n_kernels - 1)
        return l_sigma

    def encoded_layers_transform(self, encoded_layers):
        """Utility function to play with BERT layers."""
        if self.last_layer_only:
            return encoded_layers[-1]
        # list of N layers of shape B x L x H
        if self.N:
            encoded_layers = encoded_layers[-self.N:]
        else:
            self.N = len(encoded_layers)
        N = self.N
        B, L, H = encoded_layers[0].shape
        # > [(B x L x H), ..., (B x L x H)] -> B x L x N*H
        encoded_layers = torch.cat(encoded_layers, dim=2)
        # > B x L x N x H
        encoded_layers = encoded_layers.reshape(B, L, N, H)
        # > B x N x L x H
        encoded_layers = encoded_layers.permute(0, 2, 1, 3)
        if self.method == "selfattn":
            # > B x N x L*H
            encoded_layers = encoded_layers.contiguous().view(B, N, L * H)
            # > B x N x L*H @ B x L*H x N --bmm--> B x N x N
            attention_scores = encoded_layers.bmm(
                encoded_layers.transpose(-1, -2))
            attention_scores = torch.softmax(attention_scores, dim=-1)
            # soft (attended) layers
            # > B x N x L*H
            soft_encoded_layers = attention_scores.bmm(encoded_layers)
            # add residual connection
            encoded_layers = encoded_layers + soft_encoded_layers
            # average over all layers
            # > B x L*H
            encoded_layers = encoded_layers.mean(dim=1)
            # > B x L x H
            output = encoded_layers.view(B, L, H)
        else:
            if self.method.startswith("w"):
                # > N --unsqueeze--> N x 1 --expand--> N x L*H
                #   --unsqueeze--> 1 x N x L*H --expand--> B x N x L*H
                weights = (self.weights.unsqueeze(-1).expand(-1, L * H)
                           .unsqueeze(0).expand(B, -1, -1))
                # > B x N x L x H
                weights = weights.reshape(B, N, L, H)
                output = weights * encoded_layers
                # > B x L x H
                output = output.sum(1) if "sum" in self.method else output.mean(1)
            else:
                # > B x L x H
                if self.method == "sum":
                    output = encoded_layers.sum(1)
                elif self.method == "avg":
                    output = encoded_layers.mean(1)
                else:
                    output, _ = encoded_layers.max(1)
        return output

    def knrm(self, embedded, segment_ids, input_mask):
        #
        # input_mask : B x L
        # segment_ids : B x L
        #
        # * note: BERT's default LRs ([2|3|5]e-5) did not work for K-NRM.
        #   The loss does not change because the LR is too small; changing
        #   it to 1e-4 works.
        #
        # Original K-NRM implementation:
        # https://github.com/AdeDZY/K-NRM/blob/master/knrm/model/model_knrm.py#L74
        #
        document_ids_mask = segment_ids * input_mask
        query_ids_mask = (1 - segment_ids) * input_mask
        # batch-wise outer product to get query-doc masks
        query_doc_mask = torch.bmm(
            query_ids_mask.unsqueeze(2).float(),
            document_ids_mask.unsqueeze(1).float()
        )
        embedded_normalized = F.normalize(embedded, p=2, dim=2)
        # B x L x H * B x H x L --> B x L x L
        M = embedded_normalized.bmm(embedded_normalized.transpose(1, 2))
        # B x L x L x 1
        phi_M = M.unsqueeze(-1)
        # eq. 4 numerator and denominator
        # > (B x L x L x 1 - K) --broadcast--> B x L x L x K
        numerator = phi_M - self.mus
        numerator = -(numerator ** 2)
        # denominator is a K-sized vector
        denominator = 2 * (self.sigmas ** 2)
        # eq. 4 without summation
        # > B x L x L x K
        phi_M = torch.exp(numerator / denominator)
        # apply masks
        query_doc_mask = query_doc_mask.unsqueeze(-1).float()
        phi_M = phi_M * query_doc_mask
        # sum along the document dimension (eq. 4 with summation)
        # > B x L x K
        phi_M = phi_M.sum(2)
        # clip small values
        phi_M[phi_M < 1e-10] = 1e-10
        phi_M = torch.log(phi_M) * 0.01
        # sum over query positions to get the soft-TF features
        # > B x K
        phi_M = phi_M.sum(1)
        return phi_M

    def forward(self, input_ids, segment_ids, input_mask):
        #
        # embedded : B x L x H
        # cls_embed : B x H
        # where L is the joint length, i.e. "query len + doc len"
        #
        embedded, cls_embed = self.bert(
            input_ids, segment_ids, input_mask,
            output_all_encoded_layers=not self.last_layer_only
        )
        if not self.last_layer_only:
            embedded = embedded[-self.N:]
            embedded = self.encoded_layers_transform(embedded)
            cls_embed = self.bert.pooler(embedded)

        if self.use_knrm:
            phi_M = self.knrm(embedded, segment_ids, input_mask)
            output = self.linear(phi_M).squeeze(-1)
        else:
            output = self.linear(cls_embed).squeeze(-1)
        output = self.activation(output)
        return output
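# --- Usage sketch (not part of the original model code) ----------------------
# A minimal, hypothetical example of scoring query-document pairs with the
# K-NRM head. Query and document are packed into one sequence, with
# segment_ids marking query tokens as 0 and document tokens as 1, which is
# what the masking in `knrm()` expects. The BertConfig construction and the
# random ids below are illustration-only assumptions; a real run would use a
# pretrained checkpoint and tokenizer.


def _bert_knrm_usage_example():
    import torch
    from pytorch_pretrained_bert.modeling import BertConfig

    # tiny randomly initialised config, just to exercise the forward pass
    config = BertConfig(vocab_size_or_config_json_file=30522, hidden_size=96,
                        num_hidden_layers=2, num_attention_heads=4,
                        intermediate_size=192)
    model = BertKnrm(config, use_knrm=True, K=11, lamb=0.5)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.to_device(device)  # mus/sigmas are plain tensors, moved separately

    # With K=11, lamb=0.5 and use_exact=True, the kernels tile the cosine
    # range [-1, 1]: mus = [1, 0.9, 0.7, ..., -0.9], sigmas = [1e-5, 0.1, ...].
    print(model.mus.tolist())
    print(model.sigmas.tolist())

    batch, q_len, d_len = 2, 8, 24
    input_ids = torch.randint(0, 30522, (batch, q_len + d_len)).to(device)
    segment_ids = torch.cat([torch.zeros(batch, q_len, dtype=torch.long),
                             torch.ones(batch, d_len, dtype=torch.long)],
                            dim=1).to(device)
    input_mask = torch.ones(batch, q_len + d_len, dtype=torch.long).to(device)

    with torch.no_grad():
        scores = model(input_ids, segment_ids, input_mask)
    print(scores.shape)  # torch.Size([batch]): one relevance score per pair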