def forward(self, hidden_input_states: PaddedSequence, query_v_for_attention, normalize=True):
    """Compute an attention distribution over the tokens of a padded batch.

    :param hidden_input_states: PaddedSequence of encoder states,
        <batch, max_tokens, dim> in ``.data`` plus a padding mask.
    :param query_v_for_attention: per-example query vector, <batch, query_dim>;
        only consulted when ``self.condition_attention`` is set.
    :param normalize: when True, renormalize the masked distribution so each
        example's weights sum to one.
    :return: attention weights shaped like the scorer's output.
    :raises TypeError: if ``hidden_input_states`` is not a PaddedSequence.
    """
    if not isinstance(hidden_input_states, PaddedSequence):
        raise TypeError("Expected an input of type PaddedSequence but got {}".format(type(hidden_input_states)))

    hidden = hidden_input_states.data
    if self.condition_attention:
        # Tile the query vector along the (max) token dimension so it can be
        # appended to every hidden state: |batch| x |max_tokens| x |query_dim|.
        max_tokens = hidden.size()[1]
        tiled_query = torch.cat(max_tokens * [query_v_for_attention.unsqueeze(dim=1)], dim=1)
        scorer_input = torch.cat([hidden, tiled_query], dim=2)
    else:
        scorer_input = hidden

    token_scores = self.token_attention_F(scorer_input)
    # Zero the scores of padding positions before the softmax.
    token_scores = token_scores * hidden_input_states.mask(on=1.0, off=0.0, size=token_scores.size(), device=token_scores.device)
    # TODO this should probably become a logsumexp depending on condition
    distribution = self.attn_sm(token_scores)
    # The softmax still assigns mass to padding (score 0 -> exp 0 = 1),
    # so kill any support outside the mask after the fact.
    distribution = distribution * hidden_input_states.mask(on=1.0, off=0.0, size=distribution.size(), device=distribution.device)

    if normalize:
        # Renormalize so the surviving weights sum to one per example; this
        # reduces the variance of the input to the next layer (the tokenwise
        # attention's sum isn't otherwise constrained).
        totals = torch.sum(distribution, dim=1).unsqueeze(1)
        distribution = distribution / totals
    return distribution
def forward(self, word_inputs: PaddedSequence, query_v_for_attention: torch.Tensor=None, normalize_attention_distribution=True):
    """Embed a padded token batch and pool it into one vector per example.

    :param word_inputs: PaddedSequence of token ids, <batch, max_tokens>.
    :param query_v_for_attention: optional per-example query passed through to
        the attention mechanism (used only when ``self.use_attention``).
    :param normalize_attention_distribution: forwarded to the attention
        mechanism's ``normalize`` flag.
    :return: (embedded tokens, pooled output, attention weights or None).
    :raises ValueError: if ``word_inputs`` is not a PaddedSequence.
    """
    if isinstance(word_inputs, PaddedSequence):
        embedded = self.embedding(word_inputs.data)
        # Re-wrap so downstream consumers keep the batch-size/padding metadata.
        as_padded = PaddedSequence(embedded, word_inputs.batch_sizes, word_inputs.batch_first)
    else:
        raise ValueError("Got an unexpected type {} for word_inputs {}".format(type(word_inputs), word_inputs))
    if self.use_attention:
        a = self.attention_mechanism(as_padded, query_v_for_attention, normalize=normalize_attention_distribution)
        # Bug fix: the mask was hard-pinned with .cuda(), which crashed on
        # CPU-only machines and ignored the embeddings' actual device; follow
        # the tensors instead (consistent with the attention mechanism's own
        # device handling).
        output = torch.sum(a * embedded * as_padded.mask().unsqueeze(2).to(embedded.device), dim=1)
        return embedded, output, a
    else:
        # No attention: mean-pool over the true (unpadded) sequence lengths.
        output = torch.sum(embedded, dim=1) / word_inputs.batch_sizes.unsqueeze(-1).to(torch.float)
        return embedded, output, None
def forward(self, word_inputs: PaddedSequence, mask=None, query_v_for_attention=None, normalize_attention_distribution=True):
    """Encode a padded token batch with a star transformer.

    :param word_inputs: PaddedSequence of token ids.
    :param mask: optional boolean/float padding mask; if None it is derived
        from ``word_inputs``. (Bug fix: previously a caller-supplied mask was
        silently discarded, and the recomputed mask was hard-pinned to "cuda",
        breaking CPU-only runs.)
    :param query_v_for_attention: optional per-example query for the attention
        mechanism (used only when ``self.use_attention``).
    :param normalize_attention_distribution: forwarded to the attention
        mechanism's ``normalize`` flag.
    :return: when ``self.use_attention`` is set, a 3-tuple of
        (token PaddedSequence, attention-weighted hidden vector, attention
        weights); otherwise just the relay-node tensor <batch, dims>.
    """
    embedded = self.embedding(word_inputs.data)
    projected = self.projection_layer(embedded)
    if mask is None:
        mask = word_inputs.mask()
    # Follow the model's device rather than assuming a GPU is present.
    mask = mask.to(projected.device)
    # The star transformer returns a tuple of <batch, words, dims> token
    # states and the relay node <batch, dims>; the relay alone serves as the
    # output when no attention is used.
    token_vectors, a_v = self.st(projected, mask=mask)
    if not self.use_attention:
        return a_v
    token_vectors = PaddedSequence(token_vectors, word_inputs.batch_sizes, batch_first=True)
    if self.concat_relay:
        ###
        # need to concatenate a_v <batch x model_d> for all articles
        ###
        attention_input = self._concat_relay_to_tokens_in_batches(
            token_vectors, a_v, word_inputs.batch_sizes)
    else:
        attention_input = token_vectors
    a = self.attention_mechanism(
        attention_input, query_v_for_attention,
        normalize=normalize_attention_distribution)
    # note this is an element-wise multiplication, so each of the hidden
    # states is weighted by the attention vector
    weighted_hidden = torch.sum(a * token_vectors.data, dim=1)
    return token_vectors, weighted_hidden, a