import torch


def expand_dims_for_broadcast(low_tensor, high_tensor):
    """Expand the dimensions of a lower-rank tensor, so that its rank matches
    that of a higher-rank tensor. This makes it possible to perform broadcast
    operations between low_tensor and high_tensor.

    Args:
        low_tensor (Tensor): lower-rank Tensor with shape [s_0, ..., s_p]
        high_tensor (Tensor): higher-rank Tensor with shape
            [s_0, ..., s_p, ..., s_n]

    Note that the shape of low_tensor must be a prefix of the shape of
    high_tensor.

    Returns:
        Tensor: the lower-rank tensor, but with shape expanded to be
            [s_0, ..., s_p, 1, 1, ..., 1]
    """
    low_size, high_size = low_tensor.size(), high_tensor.size()
    low_rank, high_rank = len(low_size), len(high_size)

    # verify that low_tensor shape is a prefix of high_tensor shape
    assert low_size == high_size[:low_rank]

    # append singleton dimensions until the ranks match
    new_tensor = low_tensor
    for _ in range(high_rank - low_rank):
        new_tensor = torch.unsqueeze(new_tensor, len(new_tensor.size()))

    return new_tensor
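
# A minimal usage sketch (assumes only that torch is available): the tensor
# names and sizes below are illustrative, not taken from the original project.
# A [batch]-shaped weight is expanded so it broadcasts against a
# [batch, seq, dim] tensor.
#
#     weights = torch.ones(4)             # shape [4]
#     hidden = torch.randn(4, 7, 16)      # shape [4, 7, 16]
#     expanded = expand_dims_for_broadcast(weights, hidden)  # shape [4, 1, 1]
#     scaled = expanded * hidden          # broadcasts to [4, 7, 16]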

def forward(self, inputs, aspects, lengths, aspects_lengths):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(self.seed)

    # embed and regularize the message words
    inputs = self.embedding(inputs)
    inputs = self.noise_emb(inputs)
    inputs = self.drop_emb(inputs)

    # build a padding mask from the aspect token ids (id 0 = padding),
    # then embed the aspect words
    mask = (aspects > 0).float().unsqueeze(-1)
    aspects = self.embedding(aspects)
    aspects = self.drop_emb(aspects)

    # masked mean-pooling of the aspect embeddings
    aspects = torch.sum(aspects * mask, dim=1)
    new_asp = aspects / aspects_lengths.unsqueeze(-1).float()

    # repeat the aspect vector over every timestep and concatenate it
    # with the message embeddings
    new_asp = torch.unsqueeze(new_asp, 1)
    new_asp = new_asp.expand(inputs.size(0), inputs.size(1), inputs.size(2))
    concat = torch.cat((inputs, new_asp), 2)

    # add a channel dimension, convolve, and max-pool over time
    inputs = concat.unsqueeze(1)
    inputs = [F.relu(conv(inputs)).squeeze(3) for conv in self.convs]
    inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs]

    concatenated = torch.cat(inputs, 1)
    concatenated = self.dropout(concatenated)
    return self.fc(concatenated)
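
# A minimal sketch of the masked mean-pooling step used in the forward pass
# above. All names, shapes, and values here are hypothetical stand-ins for
# the model's embedding output; they are not taken from the actual config.
#
#     batch, seq_len, emb_dim = 2, 5, 8
#     aspect_ids = torch.tensor([[3, 7, 0, 0, 0],
#                                [4, 2, 9, 1, 0]])       # 0 = padding id
#     aspect_lengths = (aspect_ids > 0).sum(dim=1)       # tensor([2, 4])
#     aspect_emb = torch.randn(batch, seq_len, emb_dim)  # stand-in for self.embedding(aspects)
#
#     mask = (aspect_ids > 0).float().unsqueeze(-1)      # [batch, seq_len, 1]
#     summed = torch.sum(aspect_emb * mask, dim=1)       # [batch, emb_dim]
#     mean_aspect = summed / aspect_lengths.unsqueeze(-1).float()
#     print(mean_aspect.shape)                           # torch.Size([2, 8])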

def forward(self, message, topic, lengths, topic_lengths):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(self.seed)

    ### MESSAGE MODEL ###
    embeds = self.embedding(message)
    embeds = self.noise_emb(embeds)
    embeds = self.dropout_embeds(embeds)

    # pack the batch
    embeds_pckd = pack_padded_sequence(embeds, list(lengths.data),
                                       batch_first=True)
    mout_pckd, (hx1, cx1) = self.shared_lstm(embeds_pckd)

    # unpack output - no need if we are going to use only the last outputs
    mout_unpckd, _ = pad_packed_sequence(
        mout_pckd, batch_first=True)  # [batch_size, seq_length, 300]

    # Last timestep output is not used
    # message_output = self.last_timestep(self.shared_lstm, hx1)
    # message_output = self.dropout_rnn(message_output)

    ### TOPIC MODEL ###
    topic_embeds = self.embedding(topic)
    topic_embeds = self.dropout_embeds(topic_embeds)
    tout, (hx2, cx2) = self.shared_lstm(topic_embeds)
    tout = self.dropout_rnn(tout)

    mask = (topic > 0).float().unsqueeze(-1)
    tout = torch.sum(tout * mask, dim=1)
    tout = tout / topic_lengths.unsqueeze(-1).float()

    tout = torch.unsqueeze(tout, 1)
    tout = tout.expand(mout_unpckd.size(0), mout_unpckd.size(1),
                       mout_unpckd.size(2))
    out = torch.cat((mout_unpckd, tout), 2)

    representations, attentions = self.attention(out, lengths)
    return self.linear(representations)
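
# A minimal sketch of the pack/unpack pattern used above, with a toy LSTM.
# The sizes and the lstm/embeds/lengths names are hypothetical and unrelated
# to the actual model configuration.
#
#     import torch.nn as nn
#     from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
#
#     lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
#     embeds = torch.randn(3, 6, 8)        # [batch, max_seq_len, emb_dim]
#     lengths = [6, 4, 2]                  # true lengths, sorted descending
#
#     packed = pack_padded_sequence(embeds, lengths, batch_first=True)
#     packed_out, (h_n, c_n) = lstm(packed)
#     out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
#     print(out.shape)                     # torch.Size([3, 6, 16])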