class BertClassifier(Model):
    def __init__(self, args, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)
        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any,
                label: torch.Tensor) -> torch.Tensor:
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        output["loss"] = self.loss(class_logits, label)
        return output
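# A minimal, self-contained sketch (not part of the original model code): the
# classification head of BertClassifier in isolation, with a random tensor
# standing in for the pooled BERT state produced by self.encoder. All names and
# sizes here are illustrative assumptions only.
def _demo_classifier_head():
    import torch
    import torch.nn as nn

    n, d, out_sz = 4, 768, 3                   # batch size, BERT hidden dim, number of classes
    state = torch.randn(n, d)                  # stand-in for self.encoder(embeddings, mask)
    projection = nn.Linear(d, out_sz)
    class_logits = projection(state)           # (n, out_sz)
    label = torch.randint(0, out_sz, (n,))
    loss = nn.CrossEntropyLoss()(class_logits, label)
    return class_logits.shape, loss.item()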
class TempXCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary, date_span: Any):
        super().__init__(vocab)
        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.num_authors = num_authors

        # skills dim
        # self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 768
        # self.author_dim = self.sk_dim + self.time_dim
        self.author_dim = self.sk_dim
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.author_dim),
                                              requires_grad=True)  # (m, d)

        # self.ctx_attention = MultiHeadCtxAttention(h=8, d_model=self.sk_dim + self.time_dim)
        self.temp_ctx_attention_ns = TempCtxAttentionNS(
            h=8, d_model=self.author_dim, d_query=self.sk_dim, d_time=self.time_dim)

        # temporal context
        self.time_encoder = TimeEncoder(self.time_dim, dropout=0.1, span=1, date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = LayerNorm(self.author_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)
        self.coherence_func = CoherenceInnerProd()

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any,
                date: Any, accept_usr: Any) -> torch.Tensor:
        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)  # (n, l, d)
        token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)  # (n, l, d)
        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]

        # transfer the date into time embedding
        # TODO: use answer date for time embedding
        time_embed = gen_time_encoding(self.time_encoder, date)

        # token_temp_embed = torch.cat((token_embed, time_embed), 1)
        token_temp_embed = token_embed + time_embed
        author_tctx_embed = self.temp_ctx_attention_ns(
            token_embed, self.author_embeddings, self.author_embeddings, time_embed)  # (n, m, d)

        # add layer norm for author context embedding
        author_tctx_embed = self.ctx_layer_norm(author_tctx_embed)  # (n, m, d)

        # generate loss
        loss, coherence = self.rank_loss(token_temp_embed, author_tctx_embed, answerers, accept_usr)

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]
        # self.rank_recall(predict, truth)
        # self.mrr(predict, truth)
        self.mrr(predict, accept_usr)
        return output
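# Sketch only: TimeEncoder and gen_time_encoding are defined elsewhere in this
# repository and are not reproduced here. A plain sinusoidal encoding is assumed
# below purely to illustrate why self.time_dim is set equal to self.sk_dim (768)
# in TempXCtxModel: the time vector is added element-wise to the question
# embedding, so the two dimensions must match.
def _demo_additive_time_encoding():
    import math
    import torch

    d = 768

    def sinusoid(pos: int, dim: int = d) -> torch.Tensor:
        enc = torch.zeros(dim)
        for i in range(0, dim, 2):
            div = math.exp(-(i / dim) * math.log(10000.0))
            enc[i] = math.sin(pos * div)
            if i + 1 < dim:
                enc[i + 1] = math.cos(pos * div)
        return enc

    token_embed = torch.randn(4, d)                       # (n, d) question embeddings
    months_since_start = [0, 3, 7, 12]                    # hypothetical date offsets
    time_embed = torch.stack([sinusoid(m) for m in months_since_start])  # (n, d)
    token_temp_embed = token_embed + time_embed           # element-wise add -> dims must agree
    return token_temp_embed.shape                         # torch.Size([4, 768])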
class BertNoCtxRanker(ModelBase):
    def __init__(self, args, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)
        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.sk_dim),
                                              requires_grad=True)  # (m, d)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)

    def build_coherence(self, token_hidden, author_embeds):
        # token_hidden (n, d, l)
        # author_embeds (m, d)
        n, _, l = token_hidden.shape
        m = author_embeds.shape[0]
        token_embed = torch.mean(token_hidden, 2)  # (n, d)
        coherence = torch.einsum('nd,md->nm', [token_embed, author_embeds])  # (n, m)
        return coherence

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, label: Any,
                date: Any) -> torch.Tensor:
        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask)  # (n, d, l)
        token_embed = torch.mean(token_hidden, 2)  # (n, d)

        # coherence = self.build_coherence(token_hidden, self.author_embeddings)

        # generate positive loss
        # all_labels = list(range(self.num_authors))
        # loss = 0
        # for i, pos_labels in enumerate(label):
        #
        #     pos_labels = torch.tensor(pos_labels)
        #     if torch.cuda.is_available(): pos_labels = pos_labels.cuda()
        #     pos_coherence = coherence[i, pos_labels]
        #     pos_loss = torch.sum(-torch.log(self.sigmoid(pos_coherence))) / len(pos_labels)
        #
        #     neg_labels = torch.tensor([item for item in all_labels if item not in pos_labels])
        #     if torch.cuda.is_available(): neg_labels = neg_labels.cuda()
        #     neg_coherence = coherence[i, neg_labels]
        #     neg_loss = torch.sum(-torch.log(self.sigmoid(-neg_coherence))) / len(neg_labels)
        #
        #     loss += (pos_loss + neg_loss)
        #     pass

        # generate negative loss
        # # positive author embeddings
        # pos_author_embeds, pos_size = self.gen_pos_author_embeds(label)  # (n, p, d, k)
        #
        # # negative author embeddings
        # neg_size = pos_size  # choose negative samples the same as positive size
        # neg_author_embeds = self.gen_neg_author_embeds(label, neg_size)
        #
        # pos_coherence = self.build_coherence(token_hidden, pos_author_embeds)
        # neg_coherence = self.build_coherence(token_hidden, neg_author_embeds)
        #
        # pos_loss = torch.sum(torch.sum(torch.log(self.sigmoid(-pos_coherence)))) / pos_size
        # neg_loss = torch.sum(torch.sum(torch.log(self.sigmoid(neg_coherence)))) / neg_size
        # loss = pos_loss + neg_loss

        # loss, coherence = self.cohere_loss(token_embed, self.author_embeddings, label, no_ctx=True)
        loss, coherence = self.triplet_loss(token_embed, self.author_embeddings, label, no_ctx=True)

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        self.rank_recall(predict, label)
        return output
class MultiSpanTempModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary, date_span: Any,
                 num_shift: int, spans: List, encoder: Any, max_vocab_size: int,
                 ignore_time: bool, ns_mode: bool = False, num_sk: int = 20):
        super().__init__(vocab)
        self.date_span = date_span
        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = num_sk, 768
        self.ignore_time = ignore_time
        self.ns_mode = ns_mode
        if self.ns_mode:
            self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.sk_dim),
                                                  requires_grad=True)  # (m, d)
        else:
            self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.num_sk, self.sk_dim),
                                                  requires_grad=True)  # (m, k, d)

        self.encode_type = encoder
        if self.encode_type == "bert":
            # init word embedding
            bert_embedder = PretrainedBertEmbedder(
                pretrained_model="bert-base-uncased",
                top_layer_only=True,  # conserve memory
            )
            self.word_embeddings = BasicTextFieldEmbedder(
                {"tokens": bert_embedder},
                # we'll be ignoring masks so we'll need to set this to True
                allow_unmatched_keys=True)
            self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        else:
            # prepare embeddings
            token_embedding = Embedding(num_embeddings=max_vocab_size + 2,
                                        embedding_dim=300, padding_index=0)
            self.word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": token_embedding})
            self.encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(
                nn.LSTM(self.word_embeddings.get_output_dim(), hidden_size=int(self.sk_dim / 2),
                        bidirectional=True, batch_first=True))

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)  # layer_norm

        # shifted temporal attentions
        self.spans = spans
        self.span_temp_atts = nn.ModuleList()
        for span in self.spans:
            self.span_temp_atts.append(
                ShiftTempAttention(self.num_authors, self.sk_dim, date_span, num_shift, span,
                                   self.ignore_time))
        self.span_projection = nn.Linear(len(spans), 1)
        self.num_shift = num_shift

        # temporal encoder: used only for adding temporal information into token embedding
        self.time_encoder = TimeEncoder(self.sk_dim, dropout=0.1, span=spans[0], date_range=date_span)

        # loss
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        # self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)
        self.weight_temp = 0.3
        self.visual_id = 0

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any,
                date: Any, accept_usr: Any) -> torch.Tensor:
        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        # token_hidden2 = self.encoder(embeddings, mask)  # unused duplicate encoder pass
        if self.encode_type == "bert":
            token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)  # (n, l, d)
            token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        else:
            token_embed = self.encoder(embeddings, mask)  # (n, d)

        time_embed = gen_time_encoding(self.time_encoder, date)
        token_temp_embed = token_embed if self.ignore_time else token_embed + time_embed
        # if self.ignore_time:
        #     token_temp_embed = token_embed
        # else:
        #     token_temp_embed = token_embed + time_embed  # add time embedding

        # generate the token_embed with temporal information
        # time_embed_zs = [self.time_encoder.get_time_encoding(d, num_shift=0) for d in date]
        # time_embed_zs = torch.stack(time_embed_zs, dim=0)  # (n, d)
        # token_temp_embed = token_embed + time_embed_zs

        if self.ns_mode:
            author_ctx_embed = self.author_embeddings.unsqueeze(0).expand(
                token_embed.size(0), -1, -1)  # (n, m, d)
        else:
            # token_embed = token_hidden[:, :, -1]
            author_ctx_embed = self.ctx_attention(
                token_temp_embed, self.author_embeddings, self.author_embeddings)  # (n, m, d)
            # add layer norm for author context embedding
            author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)

        # multi-span shifted time attention layer
        span_temp_ctx_embeds, history_embeds = [], []
        for i in range(len(self.spans)):
            temp_ctx_embed, history_embed = self.span_temp_atts[i](
                token_embed, author_ctx_embed, date)  # (n, m, d)
            span_temp_ctx_embeds.append(temp_ctx_embed)
            history_embeds.append(history_embed)
        temp_ctx_embed_sp = torch.stack(span_temp_ctx_embeds, dim=-1)
        # temp_ctx_embed_sp = torch.transpose(torch.stack(temp_ctx_embed_splist), 0, -1)
        temp_ctx_embed = torch.squeeze(self.span_projection(temp_ctx_embed_sp), dim=-1)

        # print temporal context-aware embedding for visualization
        for i, answerer in enumerate(answerers):
            # generate the visualization embedding file
            if len(answerer) > 10:
                print("QID:", id[i], "Answerers:", len(answerer))
                embed_pq = temp_ctx_embed[i].cpu().numpy()
                qid = id[i]
                answerer_set = set([j[0] for j in answerer])
                with open("./exp_results/ve_" + str(qid), 'a') as f:
                    for j in range(embed_pq.shape[0]):
                        embed_pa = embed_pq[j]
                        embed_dump = "\t".join([str(v) for v in embed_pa])
                        category = 1 if j in answerer_set else 0
                        f.write(str(category) + "\t" + embed_dump + "\n")
                self.visual_id += 1

        # generate loss
        # loss, coherence = self.cohere_loss(token_embed, temp_ctx_embed, label)
        # triplet_loss, coherence = self.triplet_loss(token_embed, temp_ctx_embed, label)
        triplet_loss, coherence = self.rank_loss(token_embed, temp_ctx_embed, answerers, accept_usr)

        truth = [[j[0] for j in i] for i in answerers]
        if self.num_shift > 2:
            # no temporal loss between 1st and 2nd shifts
            temp_loss = sum([self.temp_loss(token_embed, history_embed, truth)
                             for history_embed in history_embeds])
        else:
            temp_loss = 0
        loss = triplet_loss + temp_loss * self.weight_temp

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        # print("Truth:", accept_usr)
        self.mrr(predict, accept_usr)
        return output
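# Sketch: how MultiSpanTempModel merges its per-span context embeddings. Each
# ShiftTempAttention produces an (n, m, d) tensor; stacking them on a trailing
# span axis and applying nn.Linear(len(spans), 1) collapses that axis back to a
# single (n, m, d) embedding. Sizes below are illustrative assumptions.
def _demo_span_projection():
    import torch
    import torch.nn as nn

    n, m, d, num_spans = 2, 6, 768, 3
    span_embeds = [torch.randn(n, m, d) for _ in range(num_spans)]
    stacked = torch.stack(span_embeds, dim=-1)                 # (n, m, d, num_spans)
    span_projection = nn.Linear(num_spans, 1)                  # acts on the last dimension
    merged = torch.squeeze(span_projection(stacked), dim=-1)   # (n, m, d)
    return merged.shape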
class BertCtxRanker(ModelBase):
    def __init__(self, args, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)
        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.sk_dim, self.num_sk),
                                              requires_grad=True)  # (m, d, k)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

    def build_author_ctx_embed(self, token_hidden, author_embeds):
        # token_hidden (n, d, l)
        # author_embeds (m, d, k)
        n, _, l = token_hidden.shape
        m = author_embeds.shape[0]
        F_sim = torch.einsum('ndl,de,mek->nmlk', [token_hidden, self.attention, author_embeds])
        F_tanh = self.tanh(F_sim.contiguous().view(n * m, l, self.num_sk))  # (n * m, l, k)
        F_tanh = F_tanh.view(n, m, l, self.num_sk)  # (n, m, l, k)
        g_u = torch.mean(F_tanh, 2)  # (n, m, k)
        a_u = self.softmax(g_u)  # (n, m, k)
        author_ctx_embed = torch.einsum('mdk,nmk->nmd', [author_embeds, a_u])  # (n, m, d)
        return author_ctx_embed

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any, date: Any,
                accept_usr: Any, att_l=False) -> torch.Tensor:
        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask)  # (n, d, l)

        author_ctx_embed = self.build_author_ctx_embed(token_hidden, self.author_embeddings)  # (n, m, d)
        token_embed = torch.mean(token_hidden, 2)  # (n, d)

        # coherence = torch.einsum('nd,nmd->nm', [token_embed, author_ctx_embed])  # (n, m)
        # loss, coherence = self.cohere_loss(token_embed, author_ctx_embed, label)
        # loss, coherence = self.triplet_loss(token_embed, author_ctx_embed, label)
        loss, coherence = self.rank_loss(token_embed, author_ctx_embed, answerers, accept_usr)

        # generate positive loss
        # all_labels = list(range(self.num_authors))
        # loss = 0
        # for i, pos_labels in enumerate(label):
        #
        #     num_pos = len(pos_labels)
        #     if num_pos == 0:
        #         continue
        #
        #     # BR-DEV relation
        #     pos_labels = torch.tensor(pos_labels)
        #     if torch.cuda.is_available(): pos_labels = pos_labels.cuda()
        #     pos_coherence = coherence[i, pos_labels]
        #     pos_loss = torch.sum(-torch.log(self.sigmoid(pos_coherence))) / num_pos
        #
        #     neg_labels = torch.tensor([item for item in all_labels if item not in pos_labels])
        #     num_neg = len(neg_labels)
        #     if torch.cuda.is_available(): neg_labels = neg_labels.cuda()
        #     neg_coherence = coherence[i, neg_labels]
        #     neg_loss = torch.sum(-torch.log(self.sigmoid(-neg_coherence))) / num_neg
        #
        #     loss += (pos_loss + neg_loss)
        #
        #     # DEV-DEV relation
        #     pos_authors = author_ctx_embed[i, pos_labels]  # (pos, d)
        #     neg_authors = author_ctx_embed[i, neg_labels]  # (neg, d)
        #
        #     auth_pos_coherence = torch.einsum('pd,qd->pq', [pos_authors, pos_authors])  # (pos, pos)
        #     auth_neg_coherence = torch.einsum('pd,nd->pn', [pos_authors, neg_authors])  # (pos, neg)
        #
        #     log_sig_auth = -torch.log(self.sigmoid(auth_pos_coherence))
        #     auth_pos_loss = (torch.sum(log_sig_auth) - torch.sum(torch.diagonal(log_sig_auth, 0)))
        #     if num_pos > 1:
        #         auth_pos_loss /= (num_pos * num_pos - num_pos)
        #
        #     auth_neg_loss = torch.sum(-torch.log(self.sigmoid(-auth_neg_coherence))) / (num_pos * num_neg)
        #
        #     # loss += (auth_pos_loss + auth_neg_loss)
        #     loss += (auth_pos_loss)
        #
        # if torch.isnan(loss):
        #     raise ValueError("nan loss encountered")

        output = {"loss": loss, "coherence": coherence}
        # output = {"class_logits": class_logits}
        # output["loss"] = self.loss(class_logits, label)
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]
        self.mrr(predict, accept_usr)
        return output
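# Sketch: the bilinear topic attention in build_author_ctx_embed, run with
# random tensors to make the einsum shapes explicit. Dimension letters follow
# the comments above: n questions, m authors, l tokens, k skills, d hidden.
def _demo_topic_attention():
    import torch

    n, m, l, k, d = 2, 4, 7, 20, 768
    token_hidden = torch.randn(n, d, l)
    attention = torch.randn(d, d)
    author_embeds = torch.randn(m, d, k)
    F_sim = torch.einsum('ndl,de,mek->nmlk', [token_hidden, attention, author_embeds])  # (n, m, l, k)
    a_u = torch.softmax(torch.tanh(F_sim).mean(dim=2), dim=-1)                          # (n, m, k)
    author_ctx_embed = torch.einsum('mdk,nmk->nmd', [author_embeds, a_u])               # (n, m, d)
    return author_ctx_embed.shape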
class MultiHeadCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)
        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.multihead_att = TempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, label: Any, date: Any,
                att_l=False) -> torch.Tensor:
        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)  # (n, l, d)
        token_embed = torch.mean(token_hidden, 1).squeeze()  # (n, d)
        # token_embed = token_hidden[:, :, -1]

        if att_l:
            author_ctx_embed = self.multihead_att(token_hidden, self.author_embeddings,
                                                  self.author_embeddings)
        else:
            author_ctx_embed = self.multihead_att(token_embed, self.author_embeddings,
                                                  self.author_embeddings)

        # generate loss
        loss, coherence = self.cohere_loss(token_embed, author_ctx_embed, label)

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        self.rank_recall(predict, label)
        return output
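# Sketch only: TempCtxAttention is defined elsewhere in this repository and is
# not reproduced here. As a rough stand-in, torch.nn.MultiheadAttention is used
# below to show the query/key/value roles assumed by the models above -- the
# question embedding attends over each author's k skill vectors to build a
# question-specific author embedding. This is an illustrative assumption, not
# the repository's attention implementation.
def _demo_author_attention_standin():
    import torch
    import torch.nn as nn

    n, m, k, d = 2, 4, 20, 768
    token_embed = torch.randn(n, d)                    # question embedding (query)
    author_embeddings = torch.randn(m, k, d)           # per-author skill vectors (keys/values)
    mha = nn.MultiheadAttention(embed_dim=d, num_heads=8)   # expects (seq_len, batch, d)

    author_ctx = []
    for a in range(m):
        query = token_embed.unsqueeze(0)                             # (1, n, d)
        keys = author_embeddings[a].unsqueeze(1).expand(-1, n, -1)   # (k, n, d)
        attended, _ = mha(query, keys, keys)                         # (1, n, d)
        author_ctx.append(attended.squeeze(0))                       # (n, d)
    return torch.stack(author_ctx, dim=1).shape                      # (n, m, d)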
class TempCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary, date_span: Any):
        super().__init__(vocab)
        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.temp_ctx_attention = MHTempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)

        # temporal context
        self.time_encoder = TimeEncoder(self.sk_dim, dropout=0.1, span=1, date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = LayerNorm(self.sk_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)
        self.coherence_func = CoherenceInnerProd()

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any, date: Any,
                accept_usr: Any) -> torch.Tensor:
        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)  # (n, l, d)
        token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)  # (n, l, d)
        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]

        author_ctx_embed = self.ctx_attention(token_embed, self.author_embeddings,
                                              self.author_embeddings)  # (n, m, d)
        # add layer norm for author context embedding
        author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)  # (n, m, d)

        # transfer the date into time embedding
        # TODO: use answer date for time embedding
        time_embed = gen_time_encoding(self.time_encoder, answerers, date, embeddings.size(2),
                                       self.num_authors, train_mode=self.training)
        # time_embed = [self.time_encoder.get_time_encoding(i) for i in date]
        # time_embed = torch.stack(time_embed, dim=0)  # (n, d)
        # time_embed = time_embed.unsqueeze(1).expand(-1, self.num_authors, -1)  # (n, m, d)

        author_ctx_embed_te = author_ctx_embed + time_embed
        author_tctx_embed = self.temp_ctx_attention(time_embed, author_ctx_embed,
                                                    author_ctx_embed)  # (n, m, d)
        # author_tctx_embed = self.temp_ctx_attention(author_ctx_embed_te, author_ctx_embed_te,
        #                                             author_ctx_embed_te)  # (n, m, d)

        # get horizontal temporal time embeddings
        # htemp_embeds = []
        # truth = [[j[0] for j in i] for i in answerers]
        # for i, d in enumerate(date):
        #     pos_labels = br_utils.to_cuda(torch.tensor(truth[i]))
        #     post_time_embeds = self.time_encoder.get_post_encodings(d)  # (t, d)
        #     post_time_embeds = post_time_embeds.unsqueeze(1).expand(-1, pos_labels.size(0), -1)  # (t, pos, d)
        #
        #     pos_embed = author_ctx_embed[i, pos_labels, :]  # (pos, d)
        #     pos_embed = pos_embed.unsqueeze(0).expand(post_time_embeds.size(0), -1, -1)  # (t, pos, d)
        #     author_post_ctx_embed_te = pos_embed + post_time_embeds
        #
        #     author_post_ctx_embed = self.temp_ctx_attention(author_post_ctx_embed_te,
        #                                                     author_post_ctx_embed_te,
        #                                                     author_post_ctx_embed_te)  # (t, pos, d)
        #     # author_post_ctx_embed = self.temp_ctx_attention(post_time_embeds, pos_embed, pos_embed)  # (t, pos, d)
        #     htemp_embeds.append(author_post_ctx_embed)
        # htemp_loss = self.htemp_loss(token_embed, htemp_embeds)

        # generate loss
        # loss, coherence = self.rank_loss(token_embed, author_tctx_embed, answerers)
        loss, coherence = self.rank_loss(token_embed, author_tctx_embed, answerers, accept_usr)
        # loss += 0.5 * htemp_loss
        # coherence = self.coherence_func(token_embed, None, author_tctx_embed)

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]
        # self.rank_recall(predict, truth)
        # self.mrr(predict, truth)
        self.mrr(predict, accept_usr)
        return output
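# Sketch (an assumption, not the repository's CoherenceInnerProd): once the
# author embeddings are question-specific, the coherence used for ranking is a
# batched inner product, one score per (question, author) pair. This mirrors
# the commented-out einsum in BertCtxRanker.forward above.
def _demo_ctx_coherence():
    import torch

    n, m, d = 2, 5, 768
    token_embed = torch.randn(n, d)             # (n, d)
    author_tctx_embed = torch.randn(n, m, d)    # (n, m, d) question-specific author embeddings
    coherence = torch.einsum('nd,nmd->nm', [token_embed, author_tctx_embed])  # (n, m)
    return coherence.shape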
class ShiftTempModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary, date_span: Any,
                 num_shift: int, span: int):
        super().__init__(vocab)
        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)

        # layer_norm
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)

        self.shift_temp_att = ShiftTempAttention(self.num_authors, self.sk_dim, date_span,
                                                 num_shift, span)

        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)
        self.weight_temp = 0.3
        # self.loss = nn.CrossEntropyLoss()

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any, date: Any,
                accept_usr: Any, att_l=False) -> torch.Tensor:
        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)  # (n, l, d)
        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]

        if att_l:
            author_ctx_embed = self.ctx_attention(token_hidden, self.author_embeddings,
                                                  self.author_embeddings, att_l=att_l)
        else:
            author_ctx_embed = self.ctx_attention(token_embed, self.author_embeddings,
                                                  self.author_embeddings, att_l=att_l)  # (n, m, d)

        # add layer norm for author context embedding
        author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)  # (n, m, d)

        temp_ctx_embed, history_temp_embeds = self.shift_temp_att(author_ctx_embed, date)  # (n, m, d)

        # generate loss
        # loss, coherence = self.cohere_loss(token_embed, temp_ctx_embed, label)
        # triplet_loss, coherence = self.triplet_loss(token_embed, temp_ctx_embed, label)
        triplet_loss, coherence = self.rank_loss(token_embed, temp_ctx_embed, answerers, accept_usr)
        truth = [[j[0] for j in i] for i in answerers]
        temp_loss = self.temp_loss(token_embed, history_temp_embeds, truth)
        loss = triplet_loss + temp_loss * self.weight_temp

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        # print("Truth:", accept_usr)
        self.mrr(predict, accept_usr)
        return output
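# Sketch: the mean reciprocal rank update invoked as self.mrr(predict, accept_usr)
# in the models above. The real metric object lives in ModelBase and is not shown
# here; this standalone version only illustrates the computation on the same kind
# of inputs -- a ranked author list per question and the accepted answerer id
# (the exact accept_usr format in the repository is an assumption).
def _demo_mrr():
    import numpy as np

    predict = np.array([[2, 0, 1, 3],          # ranked author ids for question 0
                        [1, 3, 2, 0]])         # ranked author ids for question 1
    accept_usr = [0, 3]                        # accepted answerer per question (assumed format)
    ranks = [int(np.where(row == usr)[0][0]) + 1 for row, usr in zip(predict, accept_usr)]
    return float(np.mean([1.0 / r for r in ranks]))   # (1/2 + 1/2) / 2 = 0.5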