class Text_Embedding(TextFieldEmbedder): def __init__(self, vocab: Vocabulary, dense_dim=75, l2=1e-5, l1=1e-7, drop=0.1)-> None: super(Text_Embedding, self).__init__() self.dense_dim = dense_dim self.dropout_p = drop self.l1_lambda = l1 self.l2_lambda = l2 self.final_l2_norm = True self.embed_direction = Embedding(num_embeddings = vocab.get_vocab_size('tokens'), embedding_dim = self.dense_dim, norm_type = 2, max_norm = self.l2_lambda) self.embed_magnitude = Embedding(num_embeddings = vocab.get_vocab_size('tokens'), embedding_dim = 1, norm_type = 1, max_norm = self.l1_lambda) #pytorch hasn't implemented spatial dropout for 1d self.dropout = Dropout(p = self.dropout_p) def forward(self, text_field_input: tensor, num_wrapping_dims: int = 0)-> tensor: if numel(text_field_input) == 0: if text_field_input.is_cuda: return zeros(self.dense_dim).cuda() else: return zeros(self.dense_dim) direction = self.embed_direction.forward(text_field_input) direction_normalized = normalize(direction,p=2,dim=-1) magnitude = self.embed_magnitude.forward(text_field_input) embedding = direction_normalized*magnitude if self.final_l2_norm: summed = sum(embedding,dim=-2) normalized_sum = normalize(summed,p=2,dim=-1) return self.dropout.forward(normalized_sum) else: summed = sum(embedding,dim=-2) return self.dropout.forward(summed) def get_output_dim(self) -> int: return self.dense_dim
class PointerGeneratorNetwork(Model): def __init__(self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, attention: Attention, max_decoding_steps: int, beam_size: int = None, target_namespace: str = "tokens", target_embedding_dim: int = None, scheduled_sampling_ratio: float = 0., projection_dim: int = None, use_coverage: bool = False, coverage_loss_weight: float = None) -> None: super(PointerGeneratorNetwork, self).__init__(vocab) self._target_namespace = target_namespace self._start_index = self.vocab.get_token_index(START_SYMBOL, target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, target_namespace) self._unk_index = self.vocab.get_token_index(DEFAULT_OOV_TOKEN, target_namespace) self._vocab_size = self.vocab.get_vocab_size(target_namespace) # Encoder self._source_embedder = source_embedder self._encoder = encoder self._encoder_output_dim = self._encoder.get_output_dim() # Decoder self._target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim() self._num_classes = self.vocab.get_vocab_size(target_namespace) self._target_embedder = Embedding(self._num_classes, self._target_embedding_dim) self._decoder_input_dim = self._encoder_output_dim + self._target_embedding_dim self._decoder_output_dim = self._encoder_output_dim self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) self._projection_dim = projection_dim or self._source_embedder.get_output_dim() self._hidden_projection_layer = Linear(self._decoder_output_dim, self._projection_dim) self._output_projection_layer = Linear(self._projection_dim, self._num_classes) self._p_gen_layer = Linear(self._decoder_output_dim * 3 + self._decoder_input_dim, 1) self._attention = attention self._use_coverage = use_coverage self._coverage_loss_weight = coverage_loss_weight self._eps = 1e-31 # Decoding self._scheduled_sampling_ratio = scheduled_sampling_ratio self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size or 1) def forward(self, source_tokens: Dict[str, torch.LongTensor], source_token_ids: torch.Tensor, source_to_target: torch.Tensor, target_tokens: Dict[str, torch.LongTensor] = None, target_token_ids: torch.Tensor = None, metadata=None) -> Dict[str, torch.Tensor]: state = self._encode(source_tokens) target_tokens_tensor = target_tokens["tokens"].long() if target_tokens else None extra_zeros, modified_source_tokens, modified_target_tokens = self._prepare( source_to_target, source_token_ids, target_tokens_tensor, target_token_ids) state["tokens"] = modified_source_tokens state["extra_zeros"] = extra_zeros output_dict = {} if target_tokens: state["target_tokens"] = modified_target_tokens state = self._init_decoder_state(state) output_dict = self._forward_loop(state, target_tokens) output_dict["metadata"] = metadata output_dict["source_to_target"] = source_to_target if not self.training: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) return output_dict def _prepare(self, source_tokens: torch.LongTensor, source_token_ids: torch.Tensor, target_tokens: torch.LongTensor = None, target_token_ids: torch.Tensor = None): batch_size = source_tokens.size(0) source_max_length = source_tokens.size(1) tokens = source_tokens token_ids = source_token_ids.long() # Concat target tokens if exist if target_tokens is not None: tokens = torch.cat((tokens, target_tokens), 1) token_ids = torch.cat((token_ids, target_token_ids.long()), 1) is_unk = torch.eq(tokens, self._unk_index).long() # Create tensor with ids of unknown tokens only. # Those ids are batch-local. unk_only = token_ids * is_unk # Recalculate batch-local ids to range [1, count_of_unique_unk_tokens]. # All known tokens have zero id. unk_token_nums = token_ids.new_zeros((batch_size, token_ids.size(1))) for i in range(batch_size): unique = torch.unique(unk_only[i, :], return_inverse=True, sorted=True)[1] unk_token_nums[i, :] = unique # Replace DEFAULT_OOV_TOKEN id with new batch-local ids starting from vocab_size # For example, if vocabulary size is 50000, the first unique unknown token will have 50000 index, # the second will have 50001 index and so on. tokens = tokens - tokens * is_unk + (self._vocab_size - 1) * is_unk + unk_token_nums modified_target_tokens = None modified_source_tokens = tokens if target_tokens is not None: # Remove target unknown tokens that do not exist in source tokens max_source_num = torch.max(tokens[:, :source_max_length], dim=1)[0] vocab_size = max_source_num.new_full((1,), self._vocab_size-1) max_source_num = torch.max(max_source_num, other=vocab_size).unsqueeze(1).expand((-1, tokens.size(1))) unk_target_tokens_mask = torch.gt(tokens, max_source_num).long() tokens = tokens - tokens * unk_target_tokens_mask + self._unk_index * unk_target_tokens_mask modified_target_tokens = tokens[:, source_max_length:] modified_source_tokens = tokens[:, :source_max_length] # Count unique unknown source tokens to create enough zeros for final distribution source_unk_count = torch.max(unk_token_nums[:, :source_max_length]) extra_zeros = tokens.new_zeros((batch_size, source_unk_count), dtype=torch.float32) return extra_zeros, modified_source_tokens, modified_target_tokens def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder.forward(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder.forward(embedded_input, source_mask) return { "source_mask": source_mask, "encoder_outputs": encoder_outputs, } def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output encoder_outputs = state["encoder_outputs"] state["decoder_context"] = encoder_outputs.new_zeros(batch_size, self._decoder_output_dim) if self._use_coverage: state["coverage"] = encoder_outputs.new_zeros(batch_size, encoder_outputs.size(1)) return state def _prepare_output_projections(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] is_unk = (last_predictions >= self._vocab_size).long() last_predictions_fixed = last_predictions - last_predictions * is_unk + self._unk_index * is_unk embedded_input = self._target_embedder.forward(last_predictions_fixed) if not self._use_coverage: attn_scores = self._attention.forward(decoder_hidden, encoder_outputs, source_mask) else: coverage = state["coverage"] attn_scores = self._attention.forward(decoder_hidden, encoder_outputs, source_mask, coverage) coverage = coverage + attn_scores state["coverage"] = coverage attn_context = util.weighted_sum(encoder_outputs, attn_scores) decoder_input = torch.cat((attn_context, embedded_input), -1) decoder_hidden, decoder_context = self._decoder_cell(decoder_input, (decoder_hidden, decoder_context)) output_projections = self._output_projection_layer(self._hidden_projection_layer(decoder_hidden)) state["decoder_input"] = decoder_input state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context state["attn_scores"] = attn_scores state["attn_context"] = attn_context return output_projections, state def _get_final_dist(self, state: Dict[str, torch.Tensor], output_projections): attn_dist = state["attn_scores"] tokens = state["tokens"] extra_zeros = state["extra_zeros"] attn_context = state["attn_context"] decoder_input = state["decoder_input"] decoder_hidden = state["decoder_hidden"] decoder_context = state["decoder_context"] decoder_state = torch.cat((decoder_hidden, decoder_context), 1) p_gen = self._p_gen_layer(torch.cat((attn_context, decoder_state, decoder_input), 1)) p_gen = torch.sigmoid(p_gen) vocab_dist = F.softmax(output_projections, dim=-1) vocab_dist = vocab_dist * p_gen attn_dist = attn_dist * (1.0 - p_gen) if extra_zeros.size(1) != 0: vocab_dist = torch.cat((vocab_dist, extra_zeros), 1) final_dist = vocab_dist.scatter_add(1, tokens, attn_dist) normalization_factor = final_dist.sum(1, keepdim=True) final_dist = final_dist / normalization_factor return final_dist def _forward_loop(self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size(0) num_decoding_steps = self._max_decoding_steps if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() num_decoding_steps = target_sequence_length - 1 if self._use_coverage: coverage_loss = source_mask.new_zeros(1, dtype=torch.float32) last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index) step_proba: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio: input_choices = last_predictions elif not target_tokens: input_choices = last_predictions else: input_choices = targets[:, timestep] if self._use_coverage: old_coverage = state["coverage"] output_projections, state = self._prepare_output_projections(input_choices, state) final_dist = self._get_final_dist(state, output_projections) step_proba.append(final_dist) last_predictions = torch.max(final_dist, 1)[1] step_predictions.append(last_predictions.unsqueeze(1)) if self._use_coverage: step_coverage_loss = torch.sum(torch.min(state["attn_scores"], old_coverage), 1) coverage_loss = coverage_loss + step_coverage_loss # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = {"predictions": predictions} if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) num_classes = step_proba[0].size(1) proba = step_proba[0].new_zeros((batch_size, num_classes, len(step_proba))) for i, p in enumerate(step_proba): proba[:, :, i] = p loss = self._get_loss(proba, state["target_tokens"], self._eps) if self._use_coverage: coverage_loss = torch.mean(coverage_loss / num_decoding_steps) loss = loss + self._coverage_loss_weight * coverage_loss output_dict["loss"] = loss return output_dict @staticmethod def _get_loss(proba: torch.LongTensor, targets: torch.LongTensor, eps: float) -> torch.Tensor: targets = targets[:, 1:] proba = torch.log(proba + eps) loss = torch.nn.NLLLoss(ignore_index=0)(proba, targets) return loss def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def take_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections(last_predictions, state) final_dist = self._get_final_dist(state, output_projections) log_probabilities = torch.log(final_dist + self._eps) return log_probabilities, state def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, np.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] all_meta = output_dict["metadata"] all_source_to_target = output_dict["source_to_target"] for (indices, metadata), source_to_target in zip(zip(predicted_indices, all_meta), all_source_to_target): all_predicted_tokens.append(self._decode_sample(indices, metadata, source_to_target)) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _decode_sample(self, indices, metadata, source_to_target): predicted_tokens = [] # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] # Get all unknown tokens from source original_source_tokens = metadata["source_tokens"] unk_tokens = list() for i, token_vocab_index in enumerate(source_to_target): if token_vocab_index != self._unk_index: continue token = original_source_tokens[i] if token in unk_tokens: continue unk_tokens.append(token) for token_vocab_index in indices: if token_vocab_index < self._vocab_size: token = self.vocab.get_token_from_index(token_vocab_index, namespace=self._target_namespace) else: unk_number = token_vocab_index - self._vocab_size token = unk_tokens[unk_number] predicted_tokens.append(token) return predicted_tokens