def forward(self, word_inputs: PaddedSequence, query_v_for_attention: torch.Tensor = None,
            normalize_attention_distribution=True):
    if isinstance(word_inputs, PaddedSequence):
        embedded = self.embedding(word_inputs.data)
        as_padded = PaddedSequence(embedded, word_inputs.batch_sizes, word_inputs.batch_first)
    else:
        raise ValueError("Got an unexpected type {} for word_inputs {}".format(type(word_inputs), word_inputs))
    if self.use_attention:
        a = self.attention_mechanism(as_padded, query_v_for_attention,
                                     normalize=normalize_attention_distribution)
        # weight each embedding by its attention score, zeroing out padded positions
        output = torch.sum(a * embedded * as_padded.mask().unsqueeze(2).cuda(), dim=1)
        return embedded, output, a
    else:
        # no attention: average the embeddings over the (unpadded) document lengths
        output = torch.sum(embedded, dim=1) / word_inputs.batch_sizes.unsqueeze(-1).to(torch.float)
        return embedded, output, None
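# A minimal, self-contained sketch of the two pooling paths in the forward() above, using
# dummy tensors. It assumes (this is not guaranteed by the code here) that the attention
# mechanism returns weights of shape <batch, seq_len, 1> that broadcast over the embeddings,
# and that padding positions embed to zero vectors (e.g. via padding_idx) in the no-attention path.
import torch

batch, seq_len, emb_dim = 2, 4, 6
embedded = torch.randn(batch, seq_len, emb_dim)
mask = torch.tensor([[1., 1., 1., 0.],        # doc 1 has 3 real tokens
                     [1., 1., 0., 0.]])       # doc 2 has 2 real tokens
lengths = mask.sum(dim=1)                     # plays the role of batch_sizes

# attention path: weights broadcast over the embedding dimension, padding masked out
a = torch.softmax(torch.randn(batch, seq_len, 1), dim=1)        # stand-in attention weights
attended = torch.sum(a * embedded * mask.unsqueeze(2), dim=1)   # <batch, emb_dim>

# no-attention path: sum of the embeddings normalized by document length
averaged = torch.sum(embedded, dim=1) / lengths.unsqueeze(-1)   # <batch, emb_dim>
print(attended.shape, averaged.shape)   # torch.Size([2, 6]) torch.Size([2, 6])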
def forward(self, word_inputs: PaddedSequence, mask=None, query_v_for_attention=None,
            normalize_attention_distribution=True):
    embedded = self.embedding(word_inputs.data)
    projected = self.projection_layer(embedded)
    mask = word_inputs.mask().to("cuda")
    # Now to the Star-Transformer. The model returns a tuple comprising a
    # <batch, words, dims> tensor and a second tensor (the relay nodes) of
    # <batch, dims>; we return the latter in the case where no attention is used.
    token_vectors, a_v = self.st(projected, mask=mask)
    if self.use_attention:
        token_vectors = PaddedSequence(token_vectors, word_inputs.batch_sizes, batch_first=True)
        if self.concat_relay:
            # concatenate a_v <batch x model_d> to every token vector in every article
            token_vectors_with_relay = self._concat_relay_to_tokens_in_batches(
                token_vectors, a_v, word_inputs.batch_sizes)
            a = self.attention_mechanism(
                token_vectors_with_relay, query_v_for_attention,
                normalize=normalize_attention_distribution)
        else:
            a = self.attention_mechanism(
                token_vectors, query_v_for_attention,
                normalize=normalize_attention_distribution)
        # element-wise multiplication: each hidden state is weighted by its attention score
        weighted_hidden = torch.sum(a * token_vectors.data, dim=1)
        return token_vectors, weighted_hidden, a
    return a_v
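# A minimal, self-contained sketch of the attention path above when concat_relay is set.
# The shapes and the softmax-style scorer are assumptions; the real attention_mechanism is
# defined elsewhere. Note that, as in the forward() above, attention scores are computed over
# the relay-augmented tokens, but the weighted sum is taken over the original token vectors.
import torch

B, T, D = 2, 5, 8                       # batch, tokens per document, model dim
token_vectors = torch.randn(B, T, D)    # per-token output of the Star-Transformer
relay = torch.randn(B, D)               # relay node: one summary vector per document

# score the relay-augmented tokens ...
augmented = torch.cat((token_vectors, relay.unsqueeze(1).expand(-1, T, -1)), dim=2)  # <B, T, 2D>
scores = augmented.sum(dim=2, keepdim=True)    # stand-in for a learned scorer, <B, T, 1>
a = torch.softmax(scores, dim=1)               # normalized over the token dimension

# ... then pool the un-augmented token vectors with those weights
weighted_hidden = torch.sum(a * token_vectors, dim=1)   # <B, D>
print(weighted_hidden.shape)   # torch.Size([2, 8])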
def _concat_relay_to_tokens_in_batches(self, article_token_batches, relay_batches, batch_sizes):
    '''
    Takes a <batch x doc_len x embedding> tensor (article_token_batches) and builds and
    returns a <batch x doc_len x (embedding + relay_embedding)> version which concatenates
    repeated copies of the relay embedding associated with each batch.
    '''
    # create an empty <batch x doc_len x (token embedding + relay embedding)> tensor
    article_tokens_with_relays = torch.zeros(
        article_token_batches.data.shape[0],
        article_token_batches.data.shape[1],
        article_token_batches.data.shape[2] + relay_batches.shape[1])
    for b in range(article_token_batches.data.shape[0]):
        # repeat this document's relay vector once per token position, then concatenate
        batch_relay = relay_batches[b].repeat(article_tokens_with_relays.shape[1], 1)
        article_tokens_with_relays[b] = torch.cat((article_token_batches.data[b], batch_relay), 1)
    return PaddedSequence(article_tokens_with_relays.to("cuda"), batch_sizes, batch_first=True)
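# A loop-free sketch of the same concatenation, assuming article_token_batches.data is a
# <batch x doc_len x embedding> tensor and relay_batches is <batch x relay_embedding>: expand
# the relay vectors across the token dimension and concatenate once. This is an illustrative
# alternative on dummy tensors, not a drop-in replacement for the method above.
import torch

tokens = torch.randn(3, 7, 16)                                        # <batch x doc_len x embedding>
relay = torch.randn(3, 16)                                            # <batch x relay_embedding>
expanded_relay = relay.unsqueeze(1).expand(-1, tokens.shape[1], -1)   # <batch x doc_len x relay_embedding>
tokens_with_relay = torch.cat((tokens, expanded_relay), dim=2)        # <batch x doc_len x 32>
print(tokens_with_relay.shape)   # torch.Size([3, 7, 32])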