Example #1
class LDTW(torch.nn.Module):
    def __init__(self, config, vocab, max_len_token):
        super(LDTW, self).__init__()
        self.config = config
        self.vocab = vocab
        self.max_len_token = max_len_token

        self.need_flatten = True

        self.EMB = EMB(vocab.size + 1, config.embedding_dim)
        self.LSTM = LSTM(self.EMB, config.embedding_dim, config.rnn_hidden_dim,
                         config.bidirectional)

        # Vector of ones (used for loss)
        self.ones = Variable(torch.ones(config.train_batch_size, 1)).cuda()
        self.loss = BCEWithLogitsLoss()
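
The EMB and LSTM wrappers instantiated above are not defined anywhere on this page. Below is a minimal sketch of what they might look like, assuming EMB is an embedding table with id 0 reserved for padding and LSTM is a batch-first encoder that also returns a float padding mask; the names match the calls above, but the signatures and the padding convention are assumptions, not the original modules.

import torch
import torch.nn as nn


class EMB(nn.Module):
    """Hypothetical embedding wrapper; id 0 is assumed to be the pad token."""

    def __init__(self, num_embeddings, embedding_dim):
        super(EMB, self).__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)

    def forward(self, token_ids):
        return self.embedding(token_ids)


class LSTM(nn.Module):
    """Hypothetical encoder wrapper: embeds token ids, runs an LSTM, and
    returns (per-token hidden states, float padding mask)."""

    def __init__(self, emb, embedding_dim, rnn_hidden_dim, bidirectional):
        super(LSTM, self).__init__()
        self.emb = emb
        self.rnn = nn.LSTM(embedding_dim, rnn_hidden_dim,
                           batch_first=True, bidirectional=bidirectional)

    def forward(self, token_ids):
        mask = (token_ids != 0).float()            # (batch, max_len_token)
        states, _ = self.rnn(self.emb(token_ids))  # (batch, max_len_token, hidden)
        return states, mask

    def flatten_parameters(self):
        self.rnn.flatten_parameters()
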
Example #2
class Stance(torch.nn.Module):
    def __init__(self, config, vocab, max_len_token):
        """
        param config: config object
        param vocab: vocab object
        param max_len_token: max number of tokens 
        """
        super(Stance, self).__init__()
        self.config = config
        self.vocab = vocab
        self.max_len_token = max_len_token

        self.need_flatten = True

        self.EMB = EMB(vocab.size+1, config.embedding_dim)
        self.LSTM = LSTM(self.EMB, config.embedding_dim, config.rnn_hidden_dim, config.bidirectional)
        self.CNN = CNN(config.increasing, config.cnn_num_layers, config.filter_counts, max_len_token)

        # Vector of ones (used for loss)
        self.ones = Variable(torch.ones(config.train_batch_size, 1)).cuda()
        self.loss = BCEWithLogitsLoss()
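
The CNN module that Stance later runs over the aligned similarity matrix is also not shown on this page. A hedged sketch of such a scorer follows; the config arguments (increasing, cnn_num_layers, filter_counts) are interpreted loosely here, with filter_counts treated as a single integer, and the real module may use them differently.

import torch
import torch.nn as nn


class CNN(nn.Module):
    """Hypothetical CNN scorer: treats the token-by-token similarity matrix
    as a one-channel image and reduces it to a single score per pair."""

    def __init__(self, increasing, cnn_num_layers, filter_counts, max_len_token):
        super(CNN, self).__init__()
        layers = []
        in_channels = 1
        for i in range(cnn_num_layers):
            # 'increasing' and 'filter_counts' are interpreted loosely:
            # grow the channel count layer by layer when increasing is set.
            out_channels = filter_counts * (i + 1) if increasing else filter_counts
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            in_channels = out_channels
        self.convs = nn.Sequential(*layers)
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.out = nn.Linear(in_channels, 1)

    def forward(self, sim_matrix):
        # sim_matrix: (batch, max_len_token, max_len_token)
        x = self.convs(sim_matrix.unsqueeze(1))   # (batch, C, L, L)
        x = self.pool(x).flatten(1)               # (batch, C)
        return self.out(x)                        # (batch, 1)
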
Example #3
class AlignDot(torch.nn.Module):
    def __init__(self, config, vocab, max_len_token):
        """
        param config: config object
        param vocab: vocab object
        param max_len_token: max number of tokens 
        """
        super(AlignDot, self).__init__()
        self.config = config
        self.vocab = vocab
        self.max_len_token = max_len_token

        self.need_flatten = True

        self.EMB = EMB(vocab.size + 1, config.embedding_dim)
        self.LSTM = LSTM(self.EMB, config.embedding_dim, config.rnn_hidden_dim,
                         config.bidirectional)

        # Vector of ones (used for loss)
        self.ones = Variable(torch.ones(config.train_batch_size, 1)).cuda()
        self.loss = BCEWithLogitsLoss()
Example #4
class Stance(torch.nn.Module):
    def __init__(self, config, vocab, max_len_token):
        """
        param config: config object
        param vocab: vocab object
        param max_len_token: max number of tokens 
        """
        super(Stance, self).__init__()
        self.config = config
        self.vocab = vocab
        self.max_len_token = max_len_token

        self.need_flatten = True

        self.EMB = EMB(vocab.size+1, config.embedding_dim)
        self.LSTM = LSTM(self.EMB, config.embedding_dim, config.rnn_hidden_dim, config.bidirectional)
        self.CNN = CNN(config.increasing, config.cnn_num_layers, config.filter_counts, max_len_token)

        # Vector of ones (used for loss)
        self.ones = Variable(torch.ones(config.train_batch_size, 1)).cuda()
        self.loss = BCEWithLogitsLoss()

    def compute_loss(self, qry_tok, pos_tok, neg_tok):
        """
        Computes the loss for a batch of (query, positive, negative) triplets

        param qry_tok: query mention tokens (batch of token lists)
        param pos_tok: positive mention tokens (batch of token lists)
        param neg_tok: negative mention tokens (batch of token lists)
        return: BCE-with-logits loss for the batch
        """
        qry_lkup, pos_lkup, neg_lkup = get_qry_pos_neg_tok_lookup(self.vocab, qry_tok, pos_tok, neg_tok)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        pos_emb, pos_mask = self.LSTM(torch.from_numpy(pos_lkup).cuda())
        neg_emb, neg_mask = self.LSTM(torch.from_numpy(neg_lkup).cuda())

        loss = self.loss(
            self.score_pair_train(qry_emb, pos_emb, qry_mask, pos_mask) -
            self.score_pair_train(qry_emb, neg_emb, qry_mask, neg_mask),
            self.ones)

        return loss


    def score_pair_train(self, qry_emb, cnd_emb, qry_msk, cnd_msk):
        """ 
        Scores a batch of query-candidate pairs.
        Takes the dot product of all pairs of token embeddings (with bmm) to get a similarity matrix,
        uses optimal transport (batch_sinkhorn_loss; a sketch follows this example) to align the weights,
        then runs a CNN over the aligned, masked similarity matrix.

        param qry_emb: query mention embedding (batch_size * max_len_token * hidden_state_output_size)
        param cnd_emb: candidate mention embedding (batch_size * max_len_token * hidden_state_output_size)
        param qry_msk: query mention mask (batch_size * max_len_token)
        param cnd_msk: candidate mention mask (batch_size * max_len_token)
        return: score for query-candidate pairs (batch_size * 1)
        """
        qry_cnd_sim = torch.bmm(qry_emb, torch.transpose(cnd_emb, 2, 1))

        qry_mask = qry_msk.unsqueeze(dim=2)
        cnd_msk = cnd_msk.unsqueeze(dim=1)
        qry_cnd_mask = torch.bmm(qry_mask, cnd_msk)

        qry_cnd_dist = torch.cuda.FloatTensor(qry_cnd_sim.size()).fill_(torch.max(qry_cnd_sim)) - qry_cnd_sim + 1e-6
        qry_cnd_pi = batch_sinkhorn_loss(qry_cnd_dist, qry_cnd_mask)
        qry_cnd_sim_aligned = torch.mul(qry_cnd_sim, qry_cnd_pi)
        qry_cnd_sim_aligned = torch.mul(qry_cnd_sim_aligned, qry_cnd_mask)

        return self.CNN(qry_cnd_sim_aligned)

    def score_dev_test_batch(self, qry_tk, cnd_tk):
        """ 
        Returns the scores for a batch of query-candidate pairs

        param qry_tk: query mention tokens (batch of token lists)
        param cnd_tk: candidate mention tokens (batch of token lists)
        return: scores for the query-candidate pairs (batch_size * 1)
        """
        qry_lkup, cnd_lkup = get_qry_cnd_tok_lookup(self.vocab, qry_tk, cnd_tk)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        cnd_embed, cnd_mask = self.LSTM(torch.from_numpy(cnd_lkup).cuda())

        scores = self.score_pair_train(qry_emb, cnd_embed, qry_mask, cnd_mask)
        return scores

    def flatten_parameters(self):
        self.LSTM.flatten_parameters()
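
The batch_sinkhorn_loss helper that produces the soft alignment in score_pair_train above is not included on this page. The sketch below shows the entropy-regularised optimal-transport (Sinkhorn) iteration it presumably performs over the masked distance matrices; epsilon, the iteration count, and the uniform marginals are made-up defaults, and the original interface may differ.

import torch


def batch_sinkhorn_loss(dist, mask, epsilon=0.1, n_iters=20):
    """Hypothetical stand-in for batch_sinkhorn_loss: Sinkhorn iterations over a
    batch of masked token-to-token distance matrices (batch, n, m). Returns a
    soft alignment (transport plan) of the same shape as dist."""
    batch_size, n, m = dist.size()

    # Uniform marginals over the unmasked rows and columns.
    qry_len = mask.max(dim=2)[0].sum(dim=1, keepdim=True).clamp(min=1)  # (batch, 1)
    cnd_len = mask.max(dim=1)[0].sum(dim=1, keepdim=True).clamp(min=1)  # (batch, 1)
    mu = mask.max(dim=2)[0] / qry_len                                   # (batch, n)
    nu = mask.max(dim=1)[0] / cnd_len                                   # (batch, m)

    # Gibbs kernel; masked-out cells get zero weight.
    K = torch.exp(-dist / epsilon) * mask

    u = torch.ones_like(mu)
    v = torch.ones_like(nu)
    for _ in range(n_iters):
        u = mu / (torch.bmm(K, v.unsqueeze(2)).squeeze(2) + 1e-9)
        v = nu / (torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2) + 1e-9)

    # Transport plan: pi = diag(u) K diag(v)
    return u.unsqueeze(2) * K * v.unsqueeze(1)
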
Example #5
class AlignDot(torch.nn.Module):
    def __init__(self, config, vocab, max_len_token):
        """
        param config: config object
        param vocab: vocab object
        param max_len_token: max number of tokens 
        """
        super(AlignDot, self).__init__()
        self.config = config
        self.vocab = vocab
        self.max_len_token = max_len_token

        self.need_flatten = True

        self.EMB = EMB(vocab.size + 1, config.embedding_dim)
        self.LSTM = LSTM(self.EMB, config.embedding_dim, config.rnn_hidden_dim,
                         config.bidirectional)

        # Vector of ones (used for loss)
        self.ones = Variable(torch.ones(config.train_batch_size, 1)).cuda()
        self.loss = BCEWithLogitsLoss()

    def compute_loss(self, qry_tk, pos_tk, neg_tk):
        """
        Computes the loss for a batch of (query, positive, negative) triplets

        param qry_tk: query mention tokens (batch of token lists)
        param pos_tk: positive mention tokens (batch of token lists)
        param neg_tk: negative mention tokens (batch of token lists)
        return: BCE-with-logits loss for the batch
        """
        qry_lkup, pos_lkup, neg_lkup = get_qry_pos_neg_tok_lookup(
            self.vocab, qry_tk, pos_tk, neg_tk)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        pos_emb, pos_mask = self.LSTM(torch.from_numpy(pos_lkup).cuda())
        neg_emb, neg_mask = self.LSTM(torch.from_numpy(neg_lkup).cuda())

        output_dim = qry_emb.shape[2]

        qry_len = torch.sum(qry_mask, dim=1).view(-1, 1).unsqueeze(2).repeat(
            1, 1, output_dim).long()
        pos_len = torch.sum(pos_mask, dim=1).view(-1, 1).unsqueeze(2).repeat(
            1, 1, output_dim).long()
        neg_len = torch.sum(neg_mask, dim=1).view(-1, 1).unsqueeze(2).repeat(
            1, 1, output_dim).long()

        qry_emb = torch.gather(input=qry_emb, dim=1, index=qry_len)
        pos_emb = torch.gather(input=pos_emb, dim=1, index=pos_len)
        neg_emb = torch.gather(input=neg_emb, dim=1, index=neg_len)

        loss = self.loss(
            self.score_pair_train(qry_emb, pos_emb, qry_mask, pos_mask) -
            self.score_pair_train(qry_emb, neg_emb, qry_mask, neg_mask),
            self.ones)

        return loss

    def score_pair_train(self, qry_emb, cnd_emb, qry_mask, cnd_mask):
        """ 
        Scores a batch of query-candidate pairs.
        Each input is the single hidden state gathered in compute_loss /
        score_dev_test_batch, so the score is simply the dot product of the two
        embeddings; the masks are accepted for interface compatibility but are
        not used here.

        param qry_emb: query mention embedding (batch_size * 1 * hidden_state_output_size)
        param cnd_emb: candidate mention embedding (batch_size * 1 * hidden_state_output_size)
        param qry_mask: query mention mask (batch_size * max_len_token)
        param cnd_mask: candidate mention mask (batch_size * max_len_token)
        return: score for query-candidate pairs (batch_size * 1)
        """

        return torch.sum(qry_emb * cnd_emb, dim=2)

    def score_dev_test_batch(self, qry_tk, cnd_tk):
        """ 
        Returns the scores for a batch of query-candidate pairs

        param qry_tk: query mention tokens (batch of token lists)
        param cnd_tk: candidate mention tokens (batch of token lists)
        return: scores for the query-candidate pairs (batch_size * 1)
        """
        qry_lkup, cnd_lkup = get_qry_cnd_tok_lookup(self.vocab, qry_tk, cnd_tk)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        cnd_emb, cnd_mask = self.LSTM(torch.from_numpy(cnd_lkup).cuda())

        output_dim = qry_emb.shape[2]

        qry_len = torch.sum(qry_mask, dim=1).view(-1, 1).unsqueeze(2).repeat(
            1, 1, output_dim).long()
        cnd_len = torch.sum(cnd_mask, dim=1).view(-1, 1).unsqueeze(2).repeat(
            1, 1, output_dim).long()

        qry_emb = torch.gather(input=qry_emb, dim=1, index=qry_len)
        cnd_emb = torch.gather(input=cnd_emb, dim=1, index=cnd_len)

        scores = self.score_pair_train(qry_emb, cnd_emb, qry_mask, cnd_mask)

        return scores

    def flatten_parameters(self):
        self.LSTM.flatten_parameters()
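
The torch.gather calls in AlignDot select a single LSTM time step per sequence, using the mask sum (the sequence length) as the index. The standalone snippet below illustrates that indexing pattern with small made-up tensors; it is not part of the original code.

import torch

# Pretend LSTM output: batch of 2 sequences, 4 time steps, hidden size 3.
emb = torch.arange(24, dtype=torch.float32).view(2, 4, 3)
# Masks marking 2 and 3 valid tokens respectively.
mask = torch.tensor([[1., 1., 0., 0.],
                     [1., 1., 1., 0.]])

output_dim = emb.shape[2]
# Same construction as in compute_loss: one index per sequence (the mask sum,
# i.e. the position just after the last unmasked token), repeated across the
# hidden dimension so it can be fed to torch.gather.
idx = torch.sum(mask, dim=1).view(-1, 1).unsqueeze(2).repeat(1, 1, output_dim).long()

# Selects time step 2 of the first sequence and time step 3 of the second.
last = torch.gather(input=emb, dim=1, index=idx)   # (2, 1, 3)
print(last.squeeze(1))
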
Example #6
class LDTW(torch.nn.Module):
    def __init__(self, config, vocab, max_len_token):
        super(LDTW, self).__init__()
        self.config = config
        self.vocab = vocab
        self.max_len_token = max_len_token

        self.need_flatten = True

        self.EMB = EMB(vocab.size + 1, config.embedding_dim)
        self.LSTM = LSTM(self.EMB, config.embedding_dim, config.rnn_hidden_dim,
                         config.bidirectional)

        # Vector of ones (used for loss)
        self.ones = Variable(torch.ones(config.train_batch_size, 1)).cuda()
        self.loss = BCEWithLogitsLoss()

    def compute_loss(self, qry_tok, pos_tok, neg_tok):
        """ 
        Computes the loss for a batch of (query, positive, negative) triplets

        param qry_tok: query mention tokens (batch of token lists)
        param pos_tok: positive mention tokens (batch of token lists)
        param neg_tok: negative mention tokens (batch of token lists)
        return: BCE-with-logits loss for the batch
        """
        qry_lkup, pos_lkup, neg_lkup = get_qry_pos_neg_tok_lookup(
            self.vocab, qry_tok, pos_tok, neg_tok)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        pos_emb, pos_mask = self.LSTM(torch.from_numpy(pos_lkup).cuda())
        neg_emb, neg_mask = self.LSTM(torch.from_numpy(neg_lkup).cuda())

        loss = self.loss(
            self.score_pair_train(qry_emb, pos_emb, qry_mask, pos_mask) -
            self.score_pair_train(qry_emb, neg_emb, qry_mask, neg_mask),
            self.ones)

        return loss

    def score_pair_train(self, qry_emb, cnd_emb, qry_msk, cnd_msk):
        """ 
        Scores a batch of query-candidate pairs by running soft-DTW (MySoftDTW;
        a sketch of the recursion follows this example) over the negated,
        masked token-to-token similarity matrix.

        param qry_emb: query mention embedding (batch_size * max_len_token * hidden_state_output_size)
        param cnd_emb: candidate mention embedding (batch_size * max_len_token * hidden_state_output_size)
        param qry_msk: query mention mask (batch_size * max_len_token)
        param cnd_msk: candidate mention mask (batch_size * max_len_token)
        return: score for query-candidate pairs (batch_size * 1)
        """
        qry_cnd_sim = torch.bmm(qry_emb, torch.transpose(cnd_emb, 2, 1))

        qry_mask = qry_msk.unsqueeze(dim=2)
        cnd_msk = cnd_msk.unsqueeze(dim=1)
        qry_cnd_mask = torch.bmm(qry_mask, cnd_msk)

        qry_cnd_sim = torch.mul(qry_cnd_sim, qry_cnd_mask)
        qry_cnd_dist = -qry_cnd_sim

        return MySoftDTW()(qry_cnd_dist)

    def score_dev_test_batch(self, qry_tk, cnd_tk):
        """ 
        Returns the scores for a batch of query-candidate pairs

        param qry_tk: query mention tokens (batch of token lists)
        param cnd_tk: candidate mention tokens (batch of token lists)
        return: scores for the query-candidate pairs (batch_size * 1)
        """
        qry_lkup, cnd_lkup = get_qry_cnd_tok_lookup(self.vocab, qry_tk, cnd_tk)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        cnd_embed, cnd_mask = self.LSTM(torch.from_numpy(cnd_lkup).cuda())

        scores = self.score_pair_train(qry_emb, cnd_embed, qry_mask, cnd_mask)
        return scores

    def flatten_parameters(self):
        self.LSTM.flatten_parameters()
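
MySoftDTW is not defined on this page; the class and model names suggest it computes soft dynamic time warping (soft-DTW) over the negated similarity matrix. Under that assumption, the sketch below shows the forward recursion only; the original is presumably a custom autograd Function with its own backward pass and smoothing parameter.

import torch


def soft_min(a, b, c, gamma):
    """Soft minimum used by soft-DTW: -gamma * logsumexp(-x / gamma)."""
    stacked = torch.stack([a, b, c], dim=0)
    return -gamma * torch.logsumexp(-stacked / gamma, dim=0)


def soft_dtw(dist, gamma=1.0):
    """Hypothetical sketch of the recursion MySoftDTW is assumed to perform:
    soft dynamic time warping over a batch of distance matrices,
    (batch, n, m) -> (batch, 1)."""
    batch_size, n, m = dist.size()
    inf = torch.full((batch_size,), float('inf'), device=dist.device)

    # R[i][j] holds the soft-DTW cost of aligning prefixes of length i and j.
    R = [[inf for _ in range(m + 1)] for _ in range(n + 1)]
    R[0][0] = torch.zeros(batch_size, device=dist.device)

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            R[i][j] = dist[:, i - 1, j - 1] + soft_min(
                R[i - 1][j], R[i][j - 1], R[i - 1][j - 1], gamma)

    return R[n][m].unsqueeze(1)
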
Example #7
class AlignLinear(torch.nn.Module):
    def __init__(self, config, vocab, max_len_token):
        """
        param config: config object
        param vocab: vocab object
        param max_len_token: max number of tokens 
        """
        super(AlignLinear, self).__init__()
        self.config = config
        self.vocab = vocab
        self.max_len_token = max_len_token

        self.need_flatten = True

        self.EMB = EMB(vocab.size + 1, config.embedding_dim)
        self.LSTM = LSTM(self.EMB, config.embedding_dim, config.rnn_hidden_dim,
                         config.bidirectional)
        self.align_weights = nn.Parameter(torch.randn(max_len_token,
                                                      max_len_token),
                                          requires_grad=True)

        # Vector of ones (used for loss)
        self.ones = Variable(torch.ones(config.train_batch_size, 1)).cuda()
        self.loss = BCEWithLogitsLoss()

    def compute_loss(self, qry_tk, pos_tk, neg_tk):
        """
        Computes the loss for a batch of (query, positive, negative) triplets

        param qry_tk: query mention tokens (batch of token lists)
        param pos_tk: positive mention tokens (batch of token lists)
        param neg_tk: negative mention tokens (batch of token lists)
        return: BCE-with-logits loss for the batch
        """
        qry_lkup, pos_lkup, neg_lkup = get_qry_pos_neg_tok_lookup(
            self.vocab, qry_tk, pos_tk, neg_tk)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        pos_emb, pos_mask = self.LSTM(torch.from_numpy(pos_lkup).cuda())
        neg_emb, neg_mask = self.LSTM(torch.from_numpy(neg_lkup).cuda())

        loss = self.loss(
            self.score_pair_train(qry_emb, pos_emb, qry_mask, pos_mask) -
            self.score_pair_train(qry_emb, neg_emb, qry_mask, neg_mask),
            self.ones)

        return loss

    def score_pair_train(self, qry_emb, cnd_emb, qry_mask, cnd_mask):
        """ 
        Scores a batch of query-candidate pairs.
        Takes the dot product of all pairs of token embeddings (with bmm) to get a similarity matrix,
        masks out padded positions, then multiplies element-wise by the learned
        alignment weight matrix and sums across rows and columns.

        param qry_emb: query mention embedding (batch_size * max_len_token * hidden_state_output_size)
        param cnd_emb: candidate mention embedding (batch_size * max_len_token * hidden_state_output_size)
        param qry_mask: query mention mask (batch_size * max_len_token)
        param cnd_mask: candidate mention mask (batch_size * max_len_token)
        return: score for query-candidate pairs (batch_size * 1)
        """
        qry_cnd_sim = torch.bmm(qry_emb, torch.transpose(cnd_emb, 2, 1))

        qry_mask = qry_mask.unsqueeze(dim=2)
        cnd_mask = cnd_mask.unsqueeze(dim=1)
        qry_cnd_mask = torch.bmm(qry_mask, cnd_mask)
        qry_cnd_sim = torch.mul(qry_cnd_sim, qry_cnd_mask)

        output = torch.sum(self.align_weights.expand_as(qry_cnd_sim) *
                           qry_cnd_sim,
                           dim=1,
                           keepdim=True)
        output = torch.sum(output, dim=2, keepdim=True)
        output = torch.squeeze(output, dim=2)
        return output

    def score_dev_test_batch(self, qry_tk, cnd_tk):
        """ 
        Returns the scores for a batch of query-candidate pairs

        param qry_tk: query mention tokens (batch of token lists)
        param cnd_tk: candidate mention tokens (batch of token lists)
        return: scores for the query-candidate pairs (batch_size * 1)
        """
        qry_lkup, cnd_lkup = get_qry_cnd_tok_lookup(self.vocab, qry_tk, cnd_tk)

        qry_emb, qry_mask = self.LSTM(torch.from_numpy(qry_lkup).cuda())
        cnd_embed, cnd_mask = self.LSTM(torch.from_numpy(cnd_lkup).cuda())

        scores = self.score_pair_train(qry_emb, cnd_embed, qry_mask, cnd_mask)
        return scores

    def flatten_parameters(self):
        self.LSTM.flatten_parameters()
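
A hedged sketch of how one of these models might be wired into a training loop. The config fields mirror the attributes read in the constructors above, but the values, the optimizer choice, and the train_batches iterator are illustrative stand-ins, not part of the original code.

import torch
from types import SimpleNamespace

# Made-up config/vocab; only the attributes the AlignLinear constructor reads are set.
config = SimpleNamespace(embedding_dim=64, rnn_hidden_dim=64,
                         bidirectional=True, train_batch_size=32)
vocab = SimpleNamespace(size=10000)   # stand-in for the real vocab object
max_len_token = 20

model = AlignLinear(config, vocab, max_len_token).cuda()
model.flatten_parameters()            # compact the cuDNN LSTM weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# train_batches would yield (query, positive, negative) batches of token lists.
for qry_tok, pos_tok, neg_tok in train_batches:
    optimizer.zero_grad()
    loss = model.compute_loss(qry_tok, pos_tok, neg_tok)
    loss.backward()
    optimizer.step()
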