def forward(
    self,
    encoded_paths: torch.Tensor,
    contexts_per_label: List[int],
) -> torch.Tensor:
    """Classify given paths

    :param encoded_paths: [n paths; classifier size]
    :param contexts_per_label: [n1, n2, ..., nk] sum = n paths
    :return: [batch size; num classes]
    """
    # [batch size; max context size; classifier input size], [batch size; max context size]
    batched_context, attention_mask = cut_encoded_contexts(
        encoded_paths, contexts_per_label, self._negative_value)

    # [batch size; max context size; 1]
    attn_weights = self.attention(batched_context, attention_mask)

    # [batch size; classifier input size]
    context = torch.bmm(attn_weights.transpose(1, 2),
                        batched_context).squeeze(1)

    # [batch size; hidden size]
    hidden = self.hidden_layers(context)

    # [batch size; num classes]
    output = self.classification_layer(hidden)

    return output
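# NOTE: `cut_encoded_contexts` is used above but not defined in this section.
# The following is a minimal sketch of the contract that the classifier (and
# the decoder and unit test below) assume: the concatenated contexts are padded
# into a dense batch, and an additive attention mask is built that holds
# `mask_value` (a large negative number in the model, `self._negative_value`)
# at padded positions. This is an illustrative reconstruction, not necessarily
# the project's actual implementation.
from typing import List, Tuple

import torch


def cut_encoded_contexts(
    encoded_contexts: torch.Tensor,
    contexts_per_label: List[int],
    mask_value: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad per-sample context groups into a dense batch

    :param encoded_contexts: [n contexts; units] -- contexts of all samples concatenated
    :param contexts_per_label: [n1, n2, ..., nk] sum = n contexts
    :param mask_value: value written into the attention mask at padded positions
    :return: [batch size; max context len; units], [batch size; max context len]
    """
    batch_size = len(contexts_per_label)
    max_context_len = max(contexts_per_label)

    batched_contexts = encoded_contexts.new_zeros(
        (batch_size, max_context_len, encoded_contexts.shape[-1]))
    attention_mask = encoded_contexts.new_zeros((batch_size, max_context_len))
    for i, cur_contexts in enumerate(
            encoded_contexts.split(contexts_per_label)):
        n_contexts = contexts_per_label[i]
        batched_contexts[i, :n_contexts] = cur_contexts
        attention_mask[i, n_contexts:] = mask_value

    return batched_contexts, attention_mask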
def forward(
    self,
    encoded_paths: torch.Tensor,
    contexts_per_label: List[int],
    output_length: int,
    target_sequence: torch.Tensor = None,
) -> torch.Tensor:
    """Decode given paths into sequence

    :param encoded_paths: [n paths; decoder size]
    :param contexts_per_label: [n1, n2, ..., nk] sum = n paths
    :param output_length: length of output sequence
    :param target_sequence: [sequence length; batch size]
    :return: [output length; batch size; vocab size]
    """
    batch_size = len(contexts_per_label)

    # [batch size; max context size; decoder size], [batch size; max context size]
    batched_context, attention_mask = cut_encoded_contexts(
        encoded_paths, contexts_per_label, self._negative_value)

    # [n layers; batch size; decoder size]
    initial_state = (torch.cat([
        ctx_batch.mean(0).unsqueeze(0)
        for ctx_batch in encoded_paths.split(contexts_per_label)
    ]).unsqueeze(0).repeat(self.num_decoder_layers, 1, 1))
    h_prev, c_prev = initial_state, initial_state

    # [target len; batch size; vocab size]
    output = encoded_paths.new_zeros(
        (output_length, batch_size, self.out_size))
    # [batch size]
    current_input = encoded_paths.new_full((batch_size, ),
                                           self.sos_token,
                                           dtype=torch.long)
    for step in range(output_length):
        current_output, (h_prev, c_prev) = self.decoder_step(
            current_input, h_prev, c_prev, batched_context, attention_mask)
        output[step] = current_output
        # teacher forcing: sometimes feed the ground-truth token as the next input
        if target_sequence is not None and torch.rand(1) < self.teacher_forcing:
            current_input = target_sequence[step]
        else:
            current_input = output[step].argmax(dim=-1)

    return output
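# NOTE: `self.decoder_step` is called above but not defined in this section.
# Below is a hedged sketch of one plausible single step, assuming an embedding
# layer, a multi-layer LSTM, and Luong-style dot-product attention over the
# padded contexts; all names here (`DecoderStep`, `embedding`, `lstm`,
# `concat_layer`, `projection`) are illustrative, not taken from the source.
# The additive `attention_mask` is assumed to hold a large negative value at
# padded positions so that the softmax ignores them.
from typing import Tuple

import torch
from torch import nn


class DecoderStep(nn.Module):
    def __init__(self, vocab_size: int, decoder_size: int, num_layers: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_size)
        self.lstm = nn.LSTM(decoder_size, decoder_size, num_layers)
        self.concat_layer = nn.Linear(2 * decoder_size, decoder_size)
        self.projection = nn.Linear(decoder_size, vocab_size)

    def forward(
        self,
        input_tokens: torch.Tensor,  # [batch size]
        h_prev: torch.Tensor,  # [n layers; batch size; decoder size]
        c_prev: torch.Tensor,  # [n layers; batch size; decoder size]
        batched_context: torch.Tensor,  # [batch size; max context size; decoder size]
        attention_mask: torch.Tensor,  # [batch size; max context size]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # [1; batch size; decoder size] -- one LSTM step over a length-1 sequence
        rnn_output, (h, c) = self.lstm(
            self.embedding(input_tokens).unsqueeze(0), (h_prev, c_prev))

        # [batch size; max context size] -- dot-product scores, padding masked out
        scores = torch.bmm(batched_context,
                           h[-1].unsqueeze(2)).squeeze(2) + attention_mask
        attn_weights = torch.softmax(scores, dim=1)

        # [batch size; decoder size] -- attention-weighted sum of contexts
        attended = torch.bmm(attn_weights.unsqueeze(1),
                             batched_context).squeeze(1)

        concat = torch.tanh(
            self.concat_layer(
                torch.cat([rnn_output.squeeze(0), attended], dim=1)))

        # [batch size; vocab size]
        return self.projection(concat), (h, c)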
def test_cut_encoded_contexts(self):
    units = 10
    mask_value = -1
    batch_size = 5
    contexts_per_label = list(range(1, batch_size + 1))
    max_context_len = max(contexts_per_label)

    # sample i consists of i contexts, each filled with the value i
    encoded_contexts = torch.cat([
        torch.full((i, units), i, dtype=torch.float)
        for i in contexts_per_label
    ])

    def create_true_batch(fill_value: int, counts: int,
                          size: int) -> torch.Tensor:
        return torch.cat(
            [
                torch.full((1, counts, units), fill_value, dtype=torch.float),
                torch.zeros((1, size - counts, units)),
            ],
            dim=1,
        )

    def create_batch_mask(counts: int, size: int) -> torch.Tensor:
        return torch.cat(
            [
                torch.zeros(1, counts),
                torch.full((1, size - counts), mask_value, dtype=torch.float),
            ],
            dim=1,
        )

    true_batched_context = torch.cat([
        create_true_batch(i, i, max_context_len) for i in contexts_per_label
    ])
    true_attention_mask = torch.cat([
        create_batch_mask(i, max_context_len) for i in contexts_per_label
    ])

    batched_context, attention_mask = cut_encoded_contexts(
        encoded_contexts, contexts_per_label, mask_value)

    torch.testing.assert_allclose(batched_context, true_batched_context)
    torch.testing.assert_allclose(attention_mask, true_attention_mask)