예제 #1
0
def expand_dims_for_broadcast(low_tensor, high_tensor):
    """Append trailing singleton dimensions so a low-rank tensor matches a higher rank.

    The returned tensor has the same rank as ``high_tensor``, which makes
    broadcast operations between the two possible.

    Args:
        low_tensor (Tensor): lower-rank Tensor with shape [s_0, ..., s_p]
        high_tensor (Tensor): higher-rank Tensor with shape
            [s_0, ..., s_p, ..., s_n]

    The shape of ``low_tensor`` must be a prefix of the shape of
    ``high_tensor``.

    Returns:
        Tensor: a view of ``low_tensor`` with shape [s_0, ..., s_p, 1, ..., 1].
        When both ranks are already equal, ``low_tensor`` itself is returned.

    Raises:
        ValueError: if the shape of ``low_tensor`` is not a prefix of the
            shape of ``high_tensor`` (this includes ``low_tensor`` having the
            larger rank).
    """
    low_size, high_size = low_tensor.size(), high_tensor.size()
    low_rank, high_rank = len(low_size), len(high_size)

    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let a mismatched shape through silently.
    if low_size != high_size[:low_rank]:
        raise ValueError(
            "shape {} of low_tensor is not a prefix of shape {} of "
            "high_tensor".format(tuple(low_size), tuple(high_size)))

    new_tensor = low_tensor
    for _ in range(high_rank - low_rank):
        # dim=-1 inserts the new singleton axis after the current last axis.
        new_tensor = torch.unsqueeze(new_tensor, -1)

    return new_tensor
예제 #2
0
    def forward(self, inputs, aspects, lengths, aspects_lengths):
        """Score token sequences conditioned on an averaged aspect embedding.

        The aspect tokens are embedded and averaged over their non-padding
        positions (token id > 0) into one vector per example. That vector is
        concatenated to every token embedding of ``inputs``. The result goes
        through each module in ``self.convs``, is max-pooled over time,
        concatenated, passed through ``self.dropout`` and finally ``self.fc``.

        Args:
            inputs: token ids of the sentences. This is assumed to be
                (batch, seq_len), since the embedded tensor is indexed up to
                dim 2 — TODO confirm against the data loader.
            aspects: token ids of the aspect terms, padded with id 0.
            lengths: not used by this method; kept for interface compatibility.
            aspects_lengths: number of non-padding aspect tokens per example
                (one value per batch row). It is the divisor of the aspect
                average; a zero entry would yield NaN for that row.

        Returns:
            The output of ``self.fc`` on the pooled convolutional features.
        """
        # NOTE(review): reseeding the CUDA RNG on every call presumably makes
        # GPU-side random ops (dropout/noise) repeat across batches — confirm
        # this is intended.
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)

        inputs = self.embedding(inputs)
        inputs = self.noise_emb(inputs)
        inputs = self.drop_emb(inputs)

        # Padding mask built from the aspect *token ids*, before they are
        # replaced by embeddings. Computing it on the embedded values would
        # zero out every negative embedding component instead of masking the
        # padding positions. unsqueeze(-1) lets it broadcast over the
        # embedding dimension.
        aspect_mask = (aspects > 0).float().unsqueeze(-1)

        aspects = self.embedding(aspects)
        aspects = self.drop_emb(aspects)

        # Mean of the non-padding aspect embeddings, one vector per example.
        aspects = torch.sum(aspects * aspect_mask, dim=1)
        new_asp = aspects / aspects_lengths.unsqueeze(-1).float()
        # Repeat the aspect vector at every time step of the input sequence.
        new_asp = torch.unsqueeze(new_asp, 1)
        new_asp = new_asp.expand(-1, inputs.size(1), -1)

        concat = torch.cat((inputs, new_asp), 2)
        # Add a singleton channel axis for the convolutions (presumably
        # Conv2d modules whose kernel spans the full feature width, since
        # their output is squeezed on dim 3).
        inputs = concat.unsqueeze(1)

        inputs = [F.relu(conv(inputs)).squeeze(3) for conv in self.convs]
        # Max-pool each feature map over its whole time axis.
        inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs]
        concatenated = torch.cat(inputs, 1)

        concatenated = self.dropout(concatenated)

        return self.fc(concatenated)
예제 #3
0
    def forward(self, message, topic, lengths, topic_lengths):
        """Score a message conditioned on a topic using a shared LSTM encoder.

        Both ``message`` and ``topic`` are embedded with ``self.embedding`` and
        encoded by ``self.shared_lstm``. The topic's LSTM outputs are averaged
        over its non-padding positions (token id > 0) and the result is
        repeated along the message's time axis. It is then concatenated to
        every message timestep. ``self.attention`` pools that sequence and
        ``self.linear`` maps the pooled representation to the output.

        Args:
            message: padded message token ids (batch-first).
            topic: padded topic token ids (batch-first); id 0 counts as
                padding.
            lengths: true length of each message. It is passed to
                ``pack_padded_sequence`` (PyTorch's default
                ``enforce_sorted=True`` requires decreasing order) and to the
                attention layer.
            topic_lengths: number of non-padding topic tokens per example,
                used as the divisor of the topic average.

        Returns:
            The output of ``self.linear`` on the attention-pooled
            representations.
        """
        if torch.cuda.is_available():
            # NOTE(review): reseeds the CUDA RNG on every call — confirm this
            # is intended.
            torch.cuda.manual_seed(self.seed)

        # Message encoder: embed -> noise -> dropout -> packed LSTM.
        msg_vectors = self.dropout_embeds(
            self.noise_emb(self.embedding(message)))
        msg_packed = pack_padded_sequence(msg_vectors,
                                          list(lengths.data),
                                          batch_first=True)
        msg_states_packed, _ = self.shared_lstm(msg_packed)
        # Back to a padded batch-first tensor so every timestep can be paired
        # with the topic summary.
        msg_states, _ = pad_packed_sequence(msg_states_packed,
                                            batch_first=True)

        # Topic encoder: embed -> dropout -> LSTM (unpacked) -> dropout.
        topic_vectors = self.dropout_embeds(self.embedding(topic))
        topic_states, _ = self.shared_lstm(topic_vectors)
        topic_states = self.dropout_rnn(topic_states)

        # Average the topic outputs over non-padding positions only.
        pad_mask = (topic > 0).float().unsqueeze(-1)
        topic_total = (topic_states * pad_mask).sum(dim=1)
        topic_summary = topic_total / topic_lengths.unsqueeze(-1).float()
        topic_summary = topic_summary.unsqueeze(1).expand_as(msg_states)

        combined = torch.cat((msg_states, topic_summary), 2)
        representations, _ = self.attention(combined, lengths)

        return self.linear(representations)