Example #1
    def __init__(self, args, num_authors: int, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.sk_dim), requires_grad=True)  # (m, d)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim), requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
Example #2
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.multihead_att = TempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(torch.randn(
            self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
Example #3
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        # self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 768

        # self.author_dim = self.sk_dim + self.time_dim
        self.author_dim = self.sk_dim

        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.author_dim),
                                              requires_grad=True)  # (m, d)

        # self.ctx_attention = MultiHeadCtxAttention(h=8, d_model=self.sk_dim + self.time_dim)
        self.temp_ctx_attention_ns = TempCtxAttentionNS(
            h=8,
            d_model=self.author_dim,
            d_query=self.sk_dim,
            d_time=self.time_dim)

        # temporal context
        self.time_encoder = TimeEncoder(self.time_dim,
                                        dropout=0.1,
                                        span=1,
                                        date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = LayerNorm(self.author_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.coherence_func = CoherenceInnerProd()
Example #4
    def __init__(self, args, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.CrossEntropyLoss()
Example #5
class BertClassifier(Model):
    def __init__(self, args, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any, label: torch.Tensor) -> Dict[str, torch.Tensor]:
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        output["loss"] = self.loss(class_logits, label)

        return output
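
A note on BertSentencePooler: it is referenced in every example above but never defined in this listing. The sketch below is an assumption, not the repository's implementation; it follows the common "take the [CLS] hidden state" pooling that fits Example #5, where the encoder output feeds a Linear classifier directly. The ranker examples' shape comments instead treat the same call as returning (n, d, l), so the actual class may return the full (transposed) hidden states.

import torch
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder


class BertSentencePooler(Seq2VecEncoder):
    """Minimal sketch (assumption): pool BERT token embeddings into one sentence vector."""

    def __init__(self, vocab: Vocabulary, bert_dim: int) -> None:
        super().__init__()
        self.vocab = vocab        # kept only to mirror the call sites above
        self.bert_dim = bert_dim  # hidden size reported via get_output_dim()

    def forward(self, embs: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # (n, l, d) -> (n, d): keep the hidden state of the first ([CLS]) token
        return embs[:, 0]

    def get_output_dim(self) -> int:
        return self.bert_dim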
Example #6
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any, num_shift: int, span: int):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        # layer_norm
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)

        self.shift_temp_att = ShiftTempAttention(self.num_authors, self.sk_dim,
                                                 date_span, num_shift, span)

        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
Example #7
class TempXCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        # self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 768

        # self.author_dim = self.sk_dim + self.time_dim
        self.author_dim = self.sk_dim

        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.author_dim),
                                              requires_grad=True)  # (m, d)

        # self.ctx_attention = MultiHeadCtxAttention(h=8, d_model=self.sk_dim + self.time_dim)
        self.temp_ctx_attention_ns = TempCtxAttentionNS(
            h=8,
            d_model=self.author_dim,
            d_query=self.sk_dim,
            d_time=self.time_dim)

        # temporal context
        self.time_encoder = TimeEncoder(self.time_dim,
                                        dropout=0.1,
                                        span=1,
                                        date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = LayerNorm(self.author_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.coherence_func = CoherenceInnerProd()

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any,
                date: Any, accept_usr: Any) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)  # (n, l, d)
        token_hidden = self.encoder(embeddings,
                                    mask).transpose(-1, -2)  # (n, l, d)

        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]

        # transfer the date into time embedding
        # TODO: use answer date for time embedding
        time_embed = gen_time_encoding(self.time_encoder, date)

        # token_temp_embed = torch.cat((token_embed, time_embed), 1)
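        # add (rather than concatenate) the time encoding so the fused question embedding stays d-dimensional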
        token_temp_embed = token_embed + time_embed
        author_tctx_embed = self.temp_ctx_attention_ns(token_embed,
                                                       self.author_embeddings,
                                                       self.author_embeddings,
                                                       time_embed)  # (n, m, d)

        # add layer norm for author context embedding
        author_tctx_embed = self.ctx_layer_norm(author_tctx_embed)  # (n, m, d)

        # generate loss
        loss, coherence = self.rank_loss(token_temp_embed, author_tctx_embed,
                                         answerers, accept_usr)

        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]

        # self.rank_recall(predict, truth)
        # self.mrr(predict, truth)
        self.mrr(predict, accept_usr)

        return output
Example #8
class BertNoCtxRanker(ModelBase):
    def __init__(self, args, num_authors: int, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.sk_dim), requires_grad=True)  # (m, d)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim), requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)


    def build_coherence(self, token_hidden, author_embeds):

        # token_hidden (n, d, l)
        # author_embeds (m, d)

        n, _, l = token_hidden.shape
        m = author_embeds.shape[0]

        token_embed = torch.mean(token_hidden, 2)  # (n, d)

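        # pairwise inner products: coherence[i, j] = <token_embed[i], author_embeds[j]>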
        coherence = torch.einsum('nd,md->nm', [token_embed, author_embeds])  # (n, m)

        return coherence


    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any, label: Any, date: Any) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch

        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask)  # (n, d, l)
        token_embed = torch.mean(token_hidden, 2)  # (n, d)

        # coherence = self.build_coherence(token_hidden, self.author_embeddings)

        # generate positive loss
        # all_labels = list(range(self.num_authors))
        # loss = 0
        # for i, pos_labels in enumerate(label):
        #
        #     pos_labels = torch.tensor(pos_labels)
        #     if torch.cuda.is_available(): pos_labels = pos_labels.cuda()
        #     pos_coherence = coherence[i, pos_labels]
        #     pos_loss = torch.sum(-torch.log(self.sigmoid(pos_coherence))) / len(pos_labels)
        #
        #     neg_labels = torch.tensor([item for item in all_labels if item not in pos_labels])
        #     if torch.cuda.is_available(): neg_labels = neg_labels.cuda()
        #     neg_coherence = coherence[i, neg_labels]
        #     neg_loss = torch.sum(-torch.log(self.sigmoid(-neg_coherence))) / len(neg_labels)
        #
        #     loss += (pos_loss + neg_loss)
        #     pass

        # generate negative loss

        # # positive author embeddings
        # pos_author_embeds, pos_size = self.gen_pos_author_embeds(label)  # (n, p, d, k)
        #
        # # negative author embeddings
        # neg_size = pos_size  # choose negative samples the same as positive size
        # neg_author_embeds = self.gen_neg_author_embeds(label, neg_size)
        #
        # pos_coherence = self.build_coherence(token_hidden, pos_author_embeds)
        # neg_coherence = self.build_coherence(token_hidden, neg_author_embeds)
        #
        # pos_loss = torch.sum(torch.sum(torch.log(self.sigmoid(-pos_coherence)))) / pos_size
        # neg_loss = torch.sum(torch.sum(torch.log(self.sigmoid(neg_coherence)))) / neg_size

        # loss = pos_loss + neg_loss

        # loss, coherence = self.cohere_loss(token_embed, self.author_embeddings, label, no_ctx=True)
        loss, coherence = self.triplet_loss(token_embed, self.author_embeddings, label, no_ctx=True)

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        self.rank_recall(predict, label)

        return output
Example #9
    def __init__(self,
                 num_authors: int,
                 out_sz: int,
                 vocab: Vocabulary,
                 date_span: Any,
                 num_shift: int,
                 spans: List,
                 encoder: Any,
                 max_vocab_size: int,
                 ignore_time: bool,
                 ns_mode: bool = False,
                 num_sk: int = 20):
        super().__init__(vocab)

        self.date_span = date_span

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = num_sk, 768
        self.ignore_time = ignore_time
        self.ns_mode = ns_mode
        if self.ns_mode:
            self.author_embeddings = nn.Parameter(torch.randn(
                num_authors, self.sk_dim),
                                                  requires_grad=True)  # (m, d)
        else:
            self.author_embeddings = nn.Parameter(
                torch.randn(num_authors, self.num_sk, self.sk_dim),
                requires_grad=True)  # (m, k, d)
        self.encode_type = encoder
        if self.encode_type == "bert":
            # init word embedding
            bert_embedder = PretrainedBertEmbedder(
                pretrained_model="bert-base-uncased",
                top_layer_only=True,  # conserve memory
            )
            self.word_embeddings = BasicTextFieldEmbedder(
                {"tokens": bert_embedder},
                # we'll be ignoring masks so we'll need to set this to True
                allow_unmatched_keys=True)
            self.encoder = BertSentencePooler(
                vocab, self.word_embeddings.get_output_dim())
        else:
            # prepare embeddings
            token_embedding = Embedding(num_embeddings=max_vocab_size + 2,
                                        embedding_dim=300,
                                        padding_index=0)
            self.word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder(
                {"tokens": token_embedding})

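            # bidirectional LSTM with hidden_size = sk_dim / 2 so that the concatenated
            # forward/backward states match the 768-d output of the BERT branch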
            self.encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(
                nn.LSTM(self.word_embeddings.get_output_dim(),
                        hidden_size=int(self.sk_dim / 2),
                        bidirectional=True,
                        batch_first=True))

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)  # layer_norm

        # shifted temporal attentions
        self.spans = spans
        self.span_temp_atts = nn.ModuleList()
        for span in self.spans:
            self.span_temp_atts.append(
                ShiftTempAttention(self.num_authors, self.sk_dim, date_span,
                                   num_shift, span, self.ignore_time))
        self.span_projection = nn.Linear(len(spans), 1)
        self.num_shift = num_shift

        # temporal encoder: used only for adding temporal information into token embedding
        self.time_encoder = TimeEncoder(self.sk_dim,
                                        dropout=0.1,
                                        span=spans[0],
                                        date_range=date_span)

        # loss
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        # self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
        self.visual_id = 0
Example #10
class MultiSpanTempModel(ModelBase):
    def __init__(self,
                 num_authors: int,
                 out_sz: int,
                 vocab: Vocabulary,
                 date_span: Any,
                 num_shift: int,
                 spans: List,
                 encoder: Any,
                 max_vocab_size: int,
                 ignore_time: bool,
                 ns_mode: bool = False,
                 num_sk: int = 20):
        super().__init__(vocab)

        self.date_span = date_span

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = num_sk, 768
        self.ignore_time = ignore_time
        self.ns_mode = ns_mode
        if self.ns_mode:
            self.author_embeddings = nn.Parameter(torch.randn(
                num_authors, self.sk_dim),
                                                  requires_grad=True)  # (m, d)
        else:
            self.author_embeddings = nn.Parameter(
                torch.randn(num_authors, self.num_sk, self.sk_dim),
                requires_grad=True)  # (m, k, d)
        self.encode_type = encoder
        if self.encode_type == "bert":
            # init word embedding
            bert_embedder = PretrainedBertEmbedder(
                pretrained_model="bert-base-uncased",
                top_layer_only=True,  # conserve memory
            )
            self.word_embeddings = BasicTextFieldEmbedder(
                {"tokens": bert_embedder},
                # we'll be ignoring masks so we'll need to set this to True
                allow_unmatched_keys=True)
            self.encoder = BertSentencePooler(
                vocab, self.word_embeddings.get_output_dim())
        else:
            # prepare embeddings
            token_embedding = Embedding(num_embeddings=max_vocab_size + 2,
                                        embedding_dim=300,
                                        padding_index=0)
            self.word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder(
                {"tokens": token_embedding})

            self.encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(
                nn.LSTM(self.word_embeddings.get_output_dim(),
                        hidden_size=int(self.sk_dim / 2),
                        bidirectional=True,
                        batch_first=True))

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)  # layer_norm

        # shifted temporal attentions
        self.spans = spans
        self.span_temp_atts = nn.ModuleList()
        for span in self.spans:
            self.span_temp_atts.append(
                ShiftTempAttention(self.num_authors, self.sk_dim, date_span,
                                   num_shift, span, self.ignore_time))
        self.span_projection = nn.Linear(len(spans), 1)
        self.num_shift = num_shift

        # temporal encoder: used only for adding temporal information into token embedding
        self.time_encoder = TimeEncoder(self.sk_dim,
                                        dropout=0.1,
                                        span=spans[0],
                                        date_range=date_span)

        # loss
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        # self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
        self.visual_id = 0

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any,
                date: Any, accept_usr: Any) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)

        if self.encode_type == "bert":
            token_hidden = self.encoder(embeddings,
                                        mask).transpose(-1, -2)  # (n, l, d)
            token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        else:
            token_embed = self.encoder(embeddings, mask)  # (n, d)

        time_embed = gen_time_encoding(self.time_encoder, date)

        token_temp_embed = token_embed if self.ignore_time else token_embed + time_embed
        # if self.ignore_time:
        #     token_temp_embed = token_embed
        # else:
        #     token_temp_embed = token_embed + time_embed  # add time embedding

        # generate the token_embed with temporal information
        # time_embed_zs = [self.time_encoder.get_time_encoding(d, num_shift=0) for d in date]
        # time_embed_zs = torch.stack(time_embed_zs, dim=0)  # (n, d)
        # token_temp_embed = token_embed + time_embed_zs

        if self.ns_mode:
            author_ctx_embed = self.author_embeddings.unsqueeze(0).expand(
                token_embed.size(0), -1, -1)  # (n, m, d)
        else:
            # token_embed = token_hidden[:, :, -1]
            author_ctx_embed = self.ctx_attention(
                token_temp_embed, self.author_embeddings,
                self.author_embeddings)  # (n, m, d)

            # add layer norm for author context embedding
            author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)

        # multi-span shifted time attention layer
        span_temp_ctx_embeds, history_embeds = [], []
        for i in range(len(self.spans)):
            temp_ctx_embed, history_embed = self.span_temp_atts[i](
                token_embed, author_ctx_embed, date)  # (n, m, d)
            span_temp_ctx_embeds.append(temp_ctx_embed)
            history_embeds.append(history_embed)
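        # stack the per-span embeddings into (n, m, d, num_spans) and project them back to a single (n, m, d) embedding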
        temp_ctx_embed_sp = torch.stack(span_temp_ctx_embeds, dim=-1)
        # temp_ctx_embed_sp = torch.transpose(torch.stack(temp_ctx_embed_splist), 0, -1)
        temp_ctx_embed = torch.squeeze(self.span_projection(temp_ctx_embed_sp),
                                       dim=-1)

        # print temporal context-aware embedding for visualization
        for i, answerer in enumerate(answerers):

            # generate the visualization embedding file
            if len(answerer) > 10:
                print("QID:", id[i], "Answerers:", len(answerer))
                embed_pq = temp_ctx_embed[i].cpu().numpy()
                qid = id[i]
                answerer_set = set([j[0] for j in answerer])

                with open("./exp_results/ve_" + str(qid), 'a') as f:
                    for j in range(embed_pq.shape[0]):
                        embed_pa = embed_pq[j]
                        embed_dump = "\t".join(str(v) for v in embed_pa)
                        category = 1 if j in answerer_set else 0
                        f.write(str(category) + "\t" + embed_dump + "\n")
                self.visual_id += 1

        # generate loss
        # loss, coherence = self.cohere_loss(token_embed, temp_ctx_embed, label)
        # triplet_loss, coherence = self.triplet_loss(token_embed, temp_ctx_embed, label)
        triplet_loss, coherence = self.rank_loss(token_embed, temp_ctx_embed,
                                                 answerers, accept_usr)

        truth = [[j[0] for j in i] for i in answerers]
        if self.num_shift > 2:  # no temporal loss between 1st and 2nd shifts
            temp_loss = sum([
                self.temp_loss(token_embed, history_embed, truth)
                for history_embed in history_embeds
            ])
        else:
            temp_loss = 0
        loss = triplet_loss + temp_loss * self.weight_temp
        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)

        #print("Truth:", accept_usr)
        self.mrr(predict, accept_usr)

        return output
Example #11
class BertCtxRanker(ModelBase):
    def __init__(self, args, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.sk_dim, self.num_sk),
                                              requires_grad=True)  # (m, d, k)

        self.attention = nn.Parameter(torch.randn(
            self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

    def build_author_ctx_embed(self, token_hidden, author_embeds):

        # token_hidden (n, d, l)
        # author_embeds (m, d, k)

        n, _, l = token_hidden.shape
        m = author_embeds.shape[0]

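        # bilinear similarity between every token hidden state and every author skill
        # vector through the learned attention matrix -> (n, m, l, k); averaging over
        # tokens and softmax-ing over skills gives per-skill attention weights used to
        # mix each author's skill vectors into a single context-aware embedding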
        F_sim = torch.einsum('ndl,de,mek->nmlk',
                             [token_hidden, self.attention, author_embeds])
        F_tanh = self.tanh(F_sim.contiguous().view(
            n * m, l, self.num_sk))  # (n * m, l, k)
        F_tanh = F_tanh.view(n, m, l, self.num_sk)  # (n, m, l, k)
        g_u = torch.mean(F_tanh, 2)  # (n, m, k)
        a_u = self.softmax(g_u)  # (n, m, k)

        author_ctx_embed = torch.einsum('mdk,nmk->nmd',
                                        [author_embeds, a_u])  # (n, m, d)

        return author_ctx_embed

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                id: Any,
                answerers: Any,
                date: Any,
                accept_usr: Any,
                att_l=False) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch

        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask)  # (n, d, l)

        author_ctx_embed = self.build_author_ctx_embed(
            token_hidden, self.author_embeddings)  # (n, m, d)
        token_embed = torch.mean(token_hidden, 2)  # (n, d)

        # coherence = torch.einsum('nd,nmd->nm', [token_embed, author_ctx_embed])  # (n, m)
        # loss, coherence = self.cohere_loss(token_embed, author_ctx_embed, label)
        # loss, coherence = self.triplet_loss(token_embed, author_ctx_embed, label)
        loss, coherence = self.rank_loss(token_embed, author_ctx_embed,
                                         answerers, accept_usr)

        # generate positive loss
        # all_labels = list(range(self.num_authors))
        # loss = 0
        # for i, pos_labels in enumerate(label):
        #
        #     num_pos = len(pos_labels)
        #     if num_pos == 0:
        #         continue
        #
        #     # BR-DEV relation
        #     pos_labels = torch.tensor(pos_labels)
        #     if torch.cuda.is_available(): pos_labels = pos_labels.cuda()
        #     pos_coherence = coherence[i, pos_labels]
        #     pos_loss = torch.sum(-torch.log(self.sigmoid(pos_coherence))) / num_pos
        #
        #     neg_labels = torch.tensor([item for item in all_labels if item not in pos_labels])
        #     num_neg = len(neg_labels)
        #     if torch.cuda.is_available(): neg_labels = neg_labels.cuda()
        #     neg_coherence = coherence[i, neg_labels]
        #     neg_loss = torch.sum(-torch.log(self.sigmoid(-neg_coherence))) / num_neg
        #
        #     loss += (pos_loss + neg_loss)
        #
        #     # DEV-DEV relation
        #     pos_authors = author_ctx_embed[i, pos_labels]  # (pos, d)
        #     neg_authors = author_ctx_embed[i, neg_labels]  # (neg, d)
        #
        #     auth_pos_coherence = torch.einsum('pd,qd->pq', [pos_authors, pos_authors])  # (pos, pos)
        #     auth_neg_coherence = torch.einsum('pd,nd->pn', [pos_authors, neg_authors])  # (pos, neg)
        #
        #     log_sig_auth = -torch.log(self.sigmoid(auth_pos_coherence))
        #     auth_pos_loss = (torch.sum(log_sig_auth) - torch.sum(torch.diagonal(log_sig_auth, 0)))
        #     if num_pos > 1:
        #         auth_pos_loss /= (num_pos * num_pos - num_pos)
        #
        #     auth_neg_loss = torch.sum(-torch.log(self.sigmoid(-auth_neg_coherence))) / (num_pos * num_neg)
        #
        #     # loss += (auth_pos_loss + auth_neg_loss)
        #     loss += (auth_pos_loss)
        #
        #     if torch.isnan(loss):
        #         raise ValueError("nan loss encountered")

        output = {"loss": loss, "coherence": coherence}
        # output = {"class_logits": class_logits}
        # output["loss"] = self.loss(class_logits, label)

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]

        self.mrr(predict, accept_usr)
        return output
Example #12
class MultiHeadCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.multihead_att = TempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(torch.randn(
            self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                id: Any,
                label: Any,
                date: Any,
                att_l=False) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch

        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings,
                                    mask).transpose(-1, -2)  # (n, l, d)

        token_embed = torch.mean(token_hidden, 1).squeeze()  # (n, d)
        # token_embed = token_hidden[:, :, -1]
        if att_l:
            author_ctx_embed = self.multihead_att(token_hidden,
                                                  self.author_embeddings,
                                                  self.author_embeddings)
        else:
            author_ctx_embed = self.multihead_att(token_embed,
                                                  self.author_embeddings,
                                                  self.author_embeddings)

        # generate loss
        loss, coherence = self.cohere_loss(token_embed, author_ctx_embed,
                                           label)
        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        self.rank_recall(predict, label)

        return output
Example #13
class TempCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int,
                 vocab: Vocabulary, date_span: Any):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.num_sk, self.sk_dim), requires_grad=True)  # (m, k, d)

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.temp_ctx_attention = MHTempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim), requires_grad=True)

        # temporal context
        self.time_encoder = TimeEncoder(self.sk_dim, dropout=0.1, span=1, date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = LayerNorm(self.sk_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.coherence_func = CoherenceInnerProd()

    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any, answerers: Any, date: Any, accept_usr: Any) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)  # (n, l, d)
        token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)  # (n, l, d)

        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]
        author_ctx_embed = self.ctx_attention(token_embed, self.author_embeddings, self.author_embeddings)  # (n, m, d)

        # add layer norm for author context embedding
        author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)  # (n, m, d)

        # transfer the date into time embedding
        # TODO: use answer date for time embedding
        time_embed = gen_time_encoding(self.time_encoder, answerers, date, embeddings.size(2), self.num_authors, train_mode=self.training)
        # time_embed = [self.time_encoder.get_time_encoding(i) for i in date]
        # time_embed = torch.stack(time_embed, dim=0)  # (n, d)
        # time_embed = time_embed.unsqueeze(1).expand(-1, self.num_authors, -1)  # (n, m, d)

        author_ctx_embed_te = author_ctx_embed + time_embed
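        # temporal attention: the time encoding attends over the context-aware author
        # embeddings (assuming the (query, key, value) argument order of the other attention calls)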
        author_tctx_embed = self.temp_ctx_attention(time_embed, author_ctx_embed, author_ctx_embed)  # (n, m, d)
        # author_tctx_embed = self.temp_ctx_attention(author_ctx_embed_te, author_ctx_embed_te, author_ctx_embed_te)  # (n, m, d)

        # get horizontal temporal time embeddings
        # htemp_embeds = []
        # truth = [[j[0] for j in i] for i in answerers]
        # for i, d in enumerate(date):
        #     pos_labels = br_utils.to_cuda(torch.tensor(truth[i]))
        #     post_time_embeds = self.time_encoder.get_post_encodings(d)  # (t, d)
        #     post_time_embeds = post_time_embeds.unsqueeze(1).expand(-1, pos_labels.size(0), -1)  # (t, pos, d)
        #
        #     pos_embed = author_ctx_embed[i, pos_labels, :]  # (pos, d)
        #     pos_embed = pos_embed.unsqueeze(0).expand(post_time_embeds.size(0), -1, -1)  # (t, pos, d)
        #     author_post_ctx_embed_te = pos_embed + post_time_embeds
        #     # author_post_ctx_embed = self.temp_ctx_attention(author_post_ctx_embed_te, author_post_ctx_embed_te, author_post_ctx_embed_te)  # (t, pos, d)
        #     author_post_ctx_embed = self.temp_ctx_attention(post_time_embeds, pos_embed, pos_embed)  # (t, pos, d)
        #     htemp_embeds.append(author_post_ctx_embed)
        # htemp_loss = self.htemp_loss(token_embed, htemp_embeds)

        # generate loss
        # loss, coherence = self.rank_loss(token_embed, author_tctx_embed, answerers)
        loss, coherence = self.rank_loss(token_embed, author_tctx_embed, answerers, accept_usr)
        # loss += 0.5 * htemp_loss

        # coherence = self.coherence_func(token_embed, None, author_tctx_embed)
        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]

        # self.rank_recall(predict, truth)
        # self.mrr(predict, truth)
        self.mrr(predict, accept_usr)


        return output
Example #14
class ShiftTempModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any, num_shift: int, span: int):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        # layer_norm
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)

        self.shift_temp_att = ShiftTempAttention(self.num_authors, self.sk_dim,
                                                 date_span, num_shift, span)

        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
        # self.loss = nn.CrossEntropyLoss()

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                id: Any,
                answerers: Any,
                date: Any,
                accept_usr: Any,
                att_l=False) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings,
                                    mask).transpose(-1, -2)  # (n, l, d)

        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]
        if att_l:
            author_ctx_embed = self.ctx_attention(token_hidden,
                                                  self.author_embeddings,
                                                  self.author_embeddings,
                                                  att_l=att_l)
        else:
            author_ctx_embed = self.ctx_attention(token_embed,
                                                  self.author_embeddings,
                                                  self.author_embeddings,
                                                  att_l=att_l)  # (n, m, d)

        # add layer norm for author context embedding
        author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)  # (n, m, d)

        temp_ctx_embed, history_temp_embeds = self.shift_temp_att(
            author_ctx_embed, date)  # (n, m, d), plus the shifted history embeddings

        # generate loss
        # loss, coherence = self.cohere_loss(token_embed, temp_ctx_embed, label)
        # triplet_loss, coherence = self.triplet_loss(token_embed, temp_ctx_embed, label)
        triplet_loss, coherence = self.rank_loss(token_embed, temp_ctx_embed,
                                                 answerers, accept_usr)

        truth = [[j[0] for j in i] for i in answerers]
        temp_loss = self.temp_loss(token_embed, history_temp_embeds, truth)
        loss = triplet_loss + temp_loss * self.weight_temp

        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)

        # print("Truth:", accept_usr)
        self.mrr(predict, accept_usr)

        return output