Example #1
    def __init__(self, args, num_authors: int, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # masks are ignored here, so unmatched keys must be allowed
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(
            torch.randn(num_authors, self.sk_dim),
            requires_grad=True)  # (m, d)

        self.attention = nn.Parameter(
            torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim),
            requires_grad=True)
        # alternative: nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
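
The constructor only allocates the attention parameter; forward() is not shown in this example. A minimal sketch of how a tanh/softmax attention between document encodings and author embeddings could be wired (shapes, the 2-D layout, and variable names are assumptions; the dim=2 softmax above hints at a 3-D variant):

import torch
import torch.nn as nn

batch, hid, m, d = 4, 768, 16, 768               # hypothetical sizes; d = sk_dim
doc_enc = torch.randn(batch, hid)                # stand-in for pooled encoder output
author_emb = nn.Parameter(torch.randn(m, d))     # (m, d), as in __init__
attention = nn.Parameter(torch.randn(hid, d))    # bilinear map, as in __init__

# score each document against each author: (batch, hid) -> (batch, m)
scores = torch.tanh(doc_enc @ attention) @ author_emb.t()
weights = torch.softmax(scores, dim=-1)          # normalize over authors
ctx = weights @ author_emb                       # (batch, d) author context
print(ctx.shape)                                 # torch.Size([4, 768])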
Example #2
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # masks are ignored here, so unmatched keys must be allowed
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.author_embeddings = nn.Parameter(
            torch.randn(num_authors, self.num_sk, self.sk_dim),
            requires_grad=True)  # (m, k, d)

        self.multihead_att = TempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(
            torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim),
            requires_grad=True)
        # alternative: nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
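
TempCtxAttention(h=8, d_model=768) is project code that this listing does not define. Assuming it follows the standard multi-head attention pattern, PyTorch's built-in module reproduces the same shape contract for one query per document attending over an author's k=20 skill vectors:

import torch
import torch.nn as nn

# stand-in for TempCtxAttention: 8 heads over d_model=768, as in the constructor
mha = nn.MultiheadAttention(embed_dim=768, num_heads=8, batch_first=True)

query = torch.randn(4, 1, 768)    # (batch, 1, d): one query per document
skills = torch.randn(4, 20, 768)  # (batch, k, d): rows of author_embeddings
out, attn = mha(query, skills, skills)
print(out.shape, attn.shape)      # torch.Size([4, 1, 768]) torch.Size([4, 1, 20])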
Example #3
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # masks are ignored here, so unmatched keys must be allowed
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        # self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 768

        # self.author_dim = self.sk_dim + self.time_dim
        self.author_dim = self.sk_dim

        self.author_embeddings = nn.Parameter(
            torch.randn(num_authors, self.author_dim),
            requires_grad=True)  # (m, d)

        # self.ctx_attention = MultiHeadCtxAttention(h=8, d_model=self.sk_dim + self.time_dim)
        self.temp_ctx_attention_ns = TempCtxAttentionNS(
            h=8,
            d_model=self.author_dim,
            d_query=self.sk_dim,
            d_time=self.time_dim)

        # temporal context
        self.time_encoder = TimeEncoder(self.time_dim,
                                        dropout=0.1,
                                        span=1,
                                        date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = nn.LayerNorm(self.author_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.coherence_func = CoherenceInnerProd()
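
TimeEncoder is likewise project-specific. One common realization of a fixed-dimensional time encoding, assumed here purely for illustration, is transformer-style sinusoidal features over the day offset within date_span:

import math
import torch

def sinusoidal_time_encoding(day_offset: torch.Tensor, dim: int = 768) -> torch.Tensor:
    """Map integer day offsets to dim-dimensional sinusoidal features
    (an assumption about what TimeEncoder computes, not a confirmed detail)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    angles = day_offset.float().unsqueeze(-1) * freqs   # (n, dim/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

enc = sinusoidal_time_encoding(torch.tensor([0, 30, 365]))
print(enc.shape)  # torch.Size([3, 768])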
Example #4
    def __init__(self, args, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # masks are ignored here, so unmatched keys must be allowed
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.CrossEntropyLoss()
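
This is the plainest of the six constructors: a pooled BERT encoding feeding a linear projection trained with cross-entropy. A self-contained sketch of that head, with a random tensor standing in for the pooled encoding:

import torch
import torch.nn as nn

out_sz, hid = 5, 768                     # hypothetical label count; BERT hidden size
projection = nn.Linear(hid, out_sz)
loss_fn = nn.CrossEntropyLoss()

pooled = torch.randn(8, hid)             # stand-in for the encoder output
labels = torch.randint(0, out_sz, (8,))
logits = projection(pooled)              # (8, out_sz)
loss = loss_fn(logits, labels)
print(logits.shape, float(loss))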
Example #5
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any, num_shift: int, span: int):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # masks are ignored here, so unmatched keys must be allowed
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(
            torch.randn(num_authors, self.num_sk, self.sk_dim),
            requires_grad=True)  # (m, k, d)

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        # layer_norm
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)

        self.shift_temp_att = ShiftTempAttention(self.num_authors, self.sk_dim,
                                                 date_span, num_shift, span)

        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
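
weight_temp = 0.3 suggests the temporal term is down-weighted against the other losses. The snippet does not show how the terms combine, so the following is only one plausible reading:

# hypothetical per-batch loss values
triplet, rank, temporal = 1.20, 0.45, 0.80
weight_temp = 0.3
total = triplet + rank + weight_temp * temporal
print(total)  # 1.89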
Example #6
    def __init__(self,
                 num_authors: int,
                 out_sz: int,
                 vocab: Vocabulary,
                 date_span: Any,
                 num_shift: int,
                 spans: List,
                 encoder: Any,
                 max_vocab_size: int,
                 ignore_time: bool,
                 ns_mode: bool = False,
                 num_sk: int = 20):
        super().__init__(vocab)

        self.date_span = date_span

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = num_sk, 768
        self.ignore_time = ignore_time
        self.ns_mode = ns_mode
        if self.ns_mode:
            self.author_embeddings = nn.Parameter(
                torch.randn(num_authors, self.sk_dim),
                requires_grad=True)  # (m, d)
        else:
            self.author_embeddings = nn.Parameter(
                torch.randn(num_authors, self.num_sk, self.sk_dim),
                requires_grad=True)  # (m, k, d)
        self.encode_type = encoder
        if self.encode_type == "bert":
            # init word embedding
            bert_embedder = PretrainedBertEmbedder(
                pretrained_model="bert-base-uncased",
                top_layer_only=True,  # conserve memory
            )
            self.word_embeddings = BasicTextFieldEmbedder(
                {"tokens": bert_embedder},
                # masks are ignored here, so unmatched keys must be allowed
                allow_unmatched_keys=True)
            self.encoder = BertSentencePooler(
                vocab, self.word_embeddings.get_output_dim())
        else:
            # prepare embeddings
            token_embedding = Embedding(num_embeddings=max_vocab_size + 2,
                                        embedding_dim=300,
                                        padding_index=0)
            self.word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder(
                {"tokens": token_embedding})

            self.encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(
                nn.LSTM(self.word_embeddings.get_output_dim(),
                        hidden_size=int(self.sk_dim / 2),
                        bidirectional=True,
                        batch_first=True))

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)  # layer_norm

        # shifted temporal attentions
        self.spans = spans
        self.span_temp_atts = nn.ModuleList()
        for span in self.spans:
            self.span_temp_atts.append(
                ShiftTempAttention(self.num_authors, self.sk_dim, date_span,
                                   num_shift, span, self.ignore_time))
        self.span_projection = nn.Linear(len(spans), 1)
        self.num_shift = num_shift

        # temporal encoder: used only to inject temporal information into the token embeddings
        self.time_encoder = TimeEncoder(self.sk_dim,
                                        dropout=0.1,
                                        span=spans[0],
                                        date_range=date_span)

        # loss
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        # self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
        self.visual_id = 0
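
span_projection = nn.Linear(len(spans), 1) implies the per-span attention outputs are stacked on a trailing axis and fused by a learned weighting. That wiring is an assumption (forward() is not shown), but the shape arithmetic works out:

import torch
import torch.nn as nn

num_spans, batch, d = 3, 4, 768          # hypothetical sizes; d = sk_dim
span_projection = nn.Linear(num_spans, 1)

# one (batch, d) tensor per ShiftTempAttention span, stacked on the last axis
per_span = torch.stack([torch.randn(batch, d) for _ in range(num_spans)], dim=-1)
fused = span_projection(per_span).squeeze(-1)   # (batch, d)
print(fused.shape)  # torch.Size([4, 768])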