def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
    a_norm = matrix_1 / (
        matrix_1.norm(p=2, dim=-1, keepdim=True) + util.tiny_value_of_dtype(matrix_1.dtype)
    )
    b_norm = matrix_2 / (
        matrix_2.norm(p=2, dim=-1, keepdim=True) + util.tiny_value_of_dtype(matrix_2.dtype)
    )
    return torch.bmm(a_norm, b_norm.transpose(-1, -2))
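# For reference, a minimal sketch of the `tiny_value_of_dtype` helper that all of
# these snippets rely on. This is hedged: it mirrors AllenNLP's utility as I
# understand it, and the exact constants in your version of `util` may differ.
# The idea is an additive epsilon small enough not to perturb results but large
# enough to avoid division by zero at the given floating-point precision.
import torch

def tiny_value_of_dtype(dtype: torch.dtype) -> float:
    """Returns a moderately tiny value for the given floating-point dtype,
    used to avoid numerical issues such as division by zero."""
    if not dtype.is_floating_point:
        raise TypeError("Only supports floating point dtypes.")
    if dtype in (torch.float, torch.double):
        return 1e-13
    elif dtype == torch.half:
        return 1e-4
    else:
        raise TypeError("Does not support dtype " + str(dtype))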
def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    a_norm = vector / (
        vector.norm(p=2, dim=-1, keepdim=True) + util.tiny_value_of_dtype(vector.dtype)
    )
    b_norm = matrix / (
        matrix.norm(p=2, dim=-1, keepdim=True) + util.tiny_value_of_dtype(matrix.dtype)
    )
    return torch.bmm(a_norm.unsqueeze(dim=1), b_norm.transpose(-1, -2)).squeeze(1)
def forward(self, tensor: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    broadcast_mask = mask.unsqueeze(-1)
    num_elements = broadcast_mask.sum() * self.size
    mean = (tensor * broadcast_mask).sum() / num_elements
    masked_centered = (tensor - mean) * broadcast_mask
    std = torch.sqrt(
        (masked_centered * masked_centered).sum() / num_elements
        + util.tiny_value_of_dtype(tensor.dtype)
    )
    return (
        self.gamma * (tensor - mean) / (std + util.tiny_value_of_dtype(tensor.dtype))
        + self.beta
    )
def test_scalar_mix_layer_norm(self):
    mixture = ScalarMix(3, do_layer_norm="scalar_norm_reg")
    tensors = [torch.randn([3, 4, 5]) for _ in range(3)]
    numpy_mask = numpy.ones((3, 4), dtype="int32")
    numpy_mask[1, 2:] = 0
    mask = torch.from_numpy(numpy_mask).bool()

    weights = [0.1, 0.2, 0.3]
    for k in range(3):
        mixture.scalar_parameters[k].data[0] = weights[k]
    mixture.gamma.data[0] = 0.5
    result = mixture(tensors, mask)

    normed_weights = numpy.exp(weights) / numpy.sum(numpy.exp(weights))
    expected_result = numpy.zeros((3, 4, 5))
    for k in range(3):
        mean = numpy.mean(tensors[k].data.numpy()[numpy_mask == 1])
        std = numpy.std(tensors[k].data.numpy()[numpy_mask == 1])
        normed_tensor = (tensors[k].data.numpy() - mean) / (
            std + util.tiny_value_of_dtype(torch.float)
        )
        expected_result += normed_tensor * normed_weights[k]
    expected_result *= 0.5
    numpy.testing.assert_almost_equal(expected_result, result.data.numpy(), decimal=6)
def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
    tensor_masked = tensor * broadcast_mask
    mean = torch.sum(tensor_masked) / num_elements_not_masked
    variance = (
        torch.sum(((tensor_masked - mean) * broadcast_mask) ** 2) / num_elements_not_masked
    )
    return (tensor - mean) / torch.sqrt(variance + util.tiny_value_of_dtype(variance.dtype))
def log_gradient_updates(self, model: Model, param_updates: Dict[str, torch.Tensor]) -> None:
    for name, param in model.named_parameters():
        update_norm = torch.norm(param_updates[name].view(-1))
        param_norm = torch.norm(param.view(-1)).cpu()
        self.add_train_scalar(
            "gradient_update/" + name,
            update_norm / (param_norm + nn_util.tiny_value_of_dtype(param_norm.dtype)),
        )
def _log_gradient_updates(self, param_updates: Dict[str, torch.Tensor]) -> None:
    gradient_update_scalars: Dict[str, float] = {}
    for name, param in self.trainer.model.named_parameters():  # type: ignore[union-attr]
        update_norm = torch.norm(param_updates[name].view(-1))
        param_norm = torch.norm(param.view(-1)).cpu()
        gradient_update_scalars[name] = (
            update_norm / (param_norm + tiny_value_of_dtype(param_norm.dtype))
        ).item()
    self.log_scalars(gradient_update_scalars, log_prefix="gradient_update")
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    # (batch_size, sentence_length, features_vocab_length)
    mask = (tokens > 0).float()
    # (batch_size, sentence_length, features_vocab_length, embedding_dim)
    x = super().forward(tokens)
    # (batch_size, sentence_length, embedding_dim)
    return x.sum(dim=-2) / (
        (mask.sum(dim=-1) + util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1)
    )
def test_masked_layer_norm(self):
    x_n = np.random.rand(2, 3, 7)
    mask_n = np.array([[1, 1, 0], [1, 1, 1]])

    x = torch.from_numpy(x_n).float()
    mask = torch.from_numpy(mask_n).bool()

    layer_norm = MaskedLayerNorm(7, gamma0=0.2)
    normed_x = layer_norm(x, mask)

    N = 7 * 5
    mean = (x_n * np.expand_dims(mask_n, axis=-1)).sum() / N
    std = np.sqrt(
        (((x_n - mean) * np.expand_dims(mask_n, axis=-1)) ** 2).sum() / N
        + util.tiny_value_of_dtype(torch.float)
    )
    expected = 0.2 * (x_n - mean) / (std + util.tiny_value_of_dtype(torch.float))

    assert np.allclose(normed_x.data.numpy(), expected)
def forward(
    self,
    source: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    # Shape: (batch_size, embedding_dim)
    source_norm = source / (
        source.norm(p=2, dim=-1, keepdim=True) + tiny_value_of_dtype(source.dtype)  # type: ignore
    )
    # Shape: (batch_size, embedding_dim)
    target_norm = target / (
        target.norm(p=2, dim=-1, keepdim=True) + tiny_value_of_dtype(target.dtype)  # type: ignore
    )
    # Shape: (batch_size,)
    similarity = (source_norm * target_norm).sum(-1)
    distances = 0.5 * (1 - similarity)
    return cast(torch.Tensor, distances)
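# A quick self-contained check of the cosine-distance forward above (a sketch:
# the standalone function below just inlines the module body, with 1e-13 standing
# in for tiny_value_of_dtype(torch.float)). Distances land in [0, 1], and the
# tiny value keeps all-zero rows finite instead of producing NaN.
import torch

TINY = 1e-13

def cosine_distance(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    source_norm = source / (source.norm(p=2, dim=-1, keepdim=True) + TINY)
    target_norm = target / (target.norm(p=2, dim=-1, keepdim=True) + TINY)
    return 0.5 * (1 - (source_norm * target_norm).sum(-1))

source = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
target = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
print(cosine_distance(source, target))  # ~[0.0, 0.5]: identical pair, zero-vector pair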
def multi_perspective_match_pairwise(
    vector1: torch.Tensor, vector2: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """
    Calculate multi-perspective cosine matching between each time step of one vector
    and each time step of another vector.

    # Parameters

    vector1 : `torch.Tensor`
        A tensor of shape `(batch, seq_len1, hidden_size)`
    vector2 : `torch.Tensor`
        A tensor of shape `(batch, seq_len2, hidden_size)`
    weight : `torch.Tensor`
        A tensor of shape `(num_perspectives, hidden_size)`

    # Returns

    `torch.Tensor` :
        A tensor of shape `(batch, seq_len1, seq_len2, num_perspectives)` consisting of
        multi-perspective matching results.
    """
    num_perspectives = weight.size(0)

    # (1, num_perspectives, 1, hidden_size)
    weight = weight.unsqueeze(0).unsqueeze(2)

    # (batch, num_perspectives, seq_len*, hidden_size)
    vector1 = weight * vector1.unsqueeze(1).expand(-1, num_perspectives, -1, -1)
    vector2 = weight * vector2.unsqueeze(1).expand(-1, num_perspectives, -1, -1)

    # (batch, num_perspectives, seq_len*, 1)
    vector1_norm = vector1.norm(p=2, dim=3, keepdim=True)
    vector2_norm = vector2.norm(p=2, dim=3, keepdim=True)

    # (batch, num_perspectives, seq_len1, seq_len2)
    mul_result = torch.matmul(vector1, vector2.transpose(2, 3))
    norm_value = vector1_norm * vector2_norm.transpose(2, 3)

    # (batch, seq_len1, seq_len2, num_perspectives)
    return (
        mul_result / norm_value.clamp(min=tiny_value_of_dtype(norm_value.dtype))
    ).permute(0, 2, 3, 1)
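# Shape sanity check for `multi_perspective_match_pairwise` above (a sketch;
# assumes the function and `tiny_value_of_dtype` are in scope). Each output
# entry is a cosine similarity, so values lie in [-1, 1].
import torch

batch, seq_len1, seq_len2, hidden_size, num_perspectives = 2, 5, 7, 16, 4
vector1 = torch.randn(batch, seq_len1, hidden_size)
vector2 = torch.randn(batch, seq_len2, hidden_size)
weight = torch.randn(num_perspectives, hidden_size)

result = multi_perspective_match_pairwise(vector1, vector2, weight)
assert result.shape == (batch, seq_len1, seq_len2, num_perspectives)
assert result.abs().max() <= 1.0 + 1e-5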
def sparse_clip_norm(parameters, max_norm, norm_type=2) -> float:
    """
    Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.
    Supports sparse gradients.

    # Parameters

    parameters : `Iterable[torch.Tensor]`
        An iterable of Tensors that will have gradients normalized.
    max_norm : `float`
        The max norm of the gradients.
    norm_type : `float`
        The type of the used p-norm. Can be `'inf'` for infinity norm.

    # Returns

    Total norm of the parameters (viewed as a single vector).
    """
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == float("inf"):
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            if p.grad.is_sparse:
                # need to coalesce the repeated indices before finding norm
                grad = p.grad.data.coalesce()
                param_norm = grad._values().norm(norm_type)
            else:
                param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm ** norm_type
        total_norm = total_norm ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + nn_util.tiny_value_of_dtype(total_norm.dtype))
    if clip_coef < 1:
        for p in parameters:
            if p.grad.is_sparse:
                p.grad.data._values().mul_(clip_coef)
            else:
                p.grad.data.mul_(clip_coef)
    return total_norm
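# Usage sketch for `sparse_clip_norm` above, on a toy model mixing a sparse
# gradient (as produced by `torch.nn.Embedding(..., sparse=True)`) with a dense
# one. The layer names here are illustrative; the helper rescales gradients in
# place when the combined norm exceeds `max_norm`.
import torch

embedding = torch.nn.Embedding(10, 4, sparse=True)
dense = torch.nn.Linear(4, 1)
out = dense(embedding(torch.tensor([1, 2, 2]))).sum()
out.backward()  # embedding.weight.grad is a sparse tensor here

total_norm = sparse_clip_norm(
    list(embedding.parameters()) + list(dense.parameters()), max_norm=1.0
)
print(total_norm)  # pre-clipping norm; gradients are now scaled if it exceeded 1.0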
def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    batch_size, source_length = state["source_mask"].size()
    trimmed_source_length = source_length - 2
    # Initialize the copy scores to zero.
    state["copy_log_probs"] = (
        state["decoder_hidden"].new_zeros((batch_size, trimmed_source_length))
        + util.tiny_value_of_dtype(state["decoder_hidden"].dtype)
    ).log()
    # shape: (batch_size,)
    start_predictions = state["source_mask"].new_full(
        (batch_size,), fill_value=self._start_index, dtype=torch.long
    )
    # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
    # shape (log_probabilities): (batch_size, beam_size)
    all_top_k_predictions, log_probabilities = self._beam_search.search(
        start_predictions, state, self.take_search_step
    )
    return {"predicted_log_probs": log_probabilities, "predictions": all_top_k_predictions}
def batch_end_logging(self, trainer):
    # Log parameter values to tensorboard
    if self.tensorboard.should_log_this_batch():
        self.tensorboard.log_parameter_and_gradient_statistics(
            trainer.model, trainer.batch_grad_norm
        )
        self.tensorboard.log_learning_rates(trainer.model, trainer.optimizer)
        self.tensorboard.add_train_scalar("loss/loss_train", trainer.train_metrics["loss"])
        self.tensorboard.log_metrics(
            {"epoch_metrics/" + k: v for k, v in trainer.train_metrics.items()}
        )

    if self.log_batch_size_period:
        cur_batch = training_util.get_batch_size(trainer.batch)
        self.cumulative_batch_size += cur_batch
        if (trainer.batches_this_epoch - 1) % self.log_batch_size_period == 0:
            average = self.cumulative_batch_size / trainer.batches_this_epoch
            logger.debug(f"current batch size: {cur_batch} mean batch size: {average}")
            self.tensorboard.add_train_scalar("current_batch_size", cur_batch)
            self.tensorboard.add_train_scalar("mean_batch_size", average)

    if self.tensorboard.should_log_histograms_this_batch():
        for name, param in trainer.model.named_parameters():
            self.param_updates[name].sub_(param.detach().cpu())
            update_norm = torch.norm(self.param_updates[name].view(-1))
            param_norm = torch.norm(param.view(-1)).cpu()
            self.tensorboard.add_train_scalar(
                "gradient_update/" + name,
                update_norm / (param_norm + nn_util.tiny_value_of_dtype(param_norm.dtype)),
            )
        self.param_updates.clear()
        self.tensorboard.log_histograms(trainer.model, self.histogram_parameters)
def _get_ll_contrib(
    self,
    generation_scores: torch.Tensor,
    generation_scores_mask: torch.BoolTensor,
    copy_scores: torch.Tensor,
    target_tokens: torch.Tensor,
    target_to_source: torch.Tensor,
    copy_mask: torch.BoolTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get the log-likelihood contribution from a single timestep.

    # Parameters

    generation_scores : `torch.Tensor`
        Shape: `(batch_size, target_vocab_size)`
    generation_scores_mask : `torch.BoolTensor`
        Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
    copy_scores : `torch.Tensor`
        Shape: `(batch_size, trimmed_source_length)`
    target_tokens : `torch.Tensor`
        Shape: `(batch_size,)`
    target_to_source : `torch.Tensor`
        Shape: `(batch_size, trimmed_source_length)`
    copy_mask : `torch.BoolTensor`
        Shape: `(batch_size, trimmed_source_length)`

    # Returns

    Tuple[torch.Tensor, torch.Tensor]
        Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
    """
    _, target_size = generation_scores.size()

    # The point of this mask is to just mask out all source token scores
    # that just represent padding. We apply the mask to the concatenation
    # of the generation scores and the copy scores to normalize the scores
    # correctly during the softmax.
    # shape: (batch_size, target_vocab_size + trimmed_source_length)
    mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
    # shape: (batch_size, target_vocab_size + trimmed_source_length)
    all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
    # Normalize generation and copy scores.
    # shape: (batch_size, target_vocab_size + trimmed_source_length)
    log_probs = util.masked_log_softmax(all_scores, mask)
    # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
    # that matches the current target token. We use the sum of these copy probabilities
    # for matching tokens in the source sentence to get the total probability
    # for the target token. We also need to normalize the individual copy probabilities
    # to create `selective_weights`, which are used in the next timestep to create
    # a selective read state.
    # shape: (batch_size, trimmed_source_length)
    copy_log_probs = log_probs[:, target_size:] + (
        target_to_source.to(log_probs.dtype) + util.tiny_value_of_dtype(log_probs.dtype)
    ).log()
    # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
    # we use a non-log softmax to get the normalized non-log copy probabilities.
    selective_weights = util.masked_softmax(log_probs[:, target_size:], target_to_source)
    # This mask ensures that an item in the batch has non-zero generation probabilities
    # for this timestep only when the gold target token is not OOV or there are no
    # matching tokens in the source sentence.
    # shape: (batch_size, 1)
    gen_mask = (target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)
    log_gen_mask = (gen_mask + util.tiny_value_of_dtype(log_probs.dtype)).log().unsqueeze(-1)
    # Now we get the generation score for the gold target token.
    # shape: (batch_size, 1)
    generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1)) + log_gen_mask
    # ... and add the copy score to get the step log likelihood.
    # shape: (batch_size, 1 + trimmed_source_length)
    combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
    # shape: (batch_size,)
    step_log_likelihood = util.logsumexp(combined_gen_and_copy)

    return step_log_likelihood, selective_weights
def _train_epoch(self, epoch: int) -> Dict[str, float]:
    """
    Trains one epoch and returns metrics.
    """
    logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
    peak_cpu_usage = common_util.peak_memory_mb()
    logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
    gpu_usage = []
    for gpu, memory in common_util.gpu_memory_mb().items():
        gpu_usage.append((gpu, memory))
        logger.info(f"GPU {gpu} memory usage MB: {memory}")

    train_loss = 0.0
    # Set the model to "train" mode.
    self._pytorch_model.train()

    # Get tqdm for the training batches
    batch_generator = iter(self.data_loader)
    batch_group_generator = common_util.lazy_groups_of(
        batch_generator, self._num_gradient_accumulation_steps
    )

    logger.info("Training")

    num_training_batches = math.ceil(
        len(self.data_loader) / self._num_gradient_accumulation_steps
    )

    # Having multiple tqdm bars in case of distributed training would be a mess,
    # hence only the master's progress is shown.
    if self._master:
        batch_group_generator_tqdm = Tqdm.tqdm(
            batch_group_generator, total=num_training_batches
        )
    else:
        batch_group_generator_tqdm = batch_group_generator

    self._last_log = time.time()
    last_save_time = time.time()

    batches_this_epoch = 0
    if self._batch_num_total is None:
        self._batch_num_total = 0

    histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

    cumulative_batch_group_size = 0
    done_early = False
    for batch_group in batch_group_generator_tqdm:
        if self._distributed:
            # Check whether the other workers have stopped already (due to differing amounts of
            # data in each). If so, we can't proceed because we would hang when we hit the
            # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
            # here because NCCL process groups apparently don't support BoolTensor.
            done = torch.tensor(0, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            if done.item() > 0:
                done_early = True
                logger.warning(
                    f"Worker {torch.distributed.get_rank()} finishing training early! "
                    "This implies that there is an imbalance in your training "
                    "data across the workers and that some amount of it will be "
                    "ignored. A small amount of this is fine, but a major imbalance "
                    "should be avoided. Note: This warning will appear unless your "
                    "data is perfectly balanced."
                )
                break

        batches_this_epoch += 1
        self._batch_num_total += 1
        batch_num_total = self._batch_num_total

        self.optimizer.zero_grad()

        for batch in batch_group:
            loss = self.batch_loss(batch, for_training=True)
            if torch.isnan(loss):
                raise ValueError("nan loss encountered")
            loss = loss / len(batch_group)
            if self._opt_level is not None:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            train_loss += loss.item()

        batch_grad_norm = self.rescale_gradients()

        # This does nothing if batch_num_total is None or you are using a
        # scheduler which doesn't update per batch.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step_batch(batch_num_total)
        if self._momentum_scheduler:
            self._momentum_scheduler.step_batch(batch_num_total)

        if self._tensorboard.should_log_histograms_this_batch() and self._master:
            # Get the magnitude of parameter updates for logging. We need a copy of
            # the current parameters to compute the magnitude of updates, and we copy
            # them to CPU so large models won't go OOM on the GPU.
            param_updates = {
                name: param.detach().cpu().clone()
                for name, param in self.model.named_parameters()
            }
            self.optimizer.step()
            for name, param in self.model.named_parameters():
                param_updates[name].sub_(param.detach().cpu())
                update_norm = torch.norm(param_updates[name].view(-1))
                param_norm = torch.norm(param.view(-1)).cpu()
                self._tensorboard.add_train_scalar(
                    "gradient_update/" + name,
                    update_norm / (param_norm + nn_util.tiny_value_of_dtype(param_norm.dtype)),
                )
        else:
            self.optimizer.step()

        # Update moving averages
        if self._moving_average is not None:
            self._moving_average.apply(batch_num_total)

        # Update the description with the latest metrics
        metrics = training_util.get_metrics(
            self.model,
            train_loss,
            batches_this_epoch,
            world_size=self._world_size,
            cuda_device=[self.cuda_device],
        )

        # Updating tqdm only for the master as the other workers wouldn't have one
        if self._master:
            description = training_util.description_from_metrics(metrics)
            batch_group_generator_tqdm.set_description(description, refresh=False)

        # Log parameter values to Tensorboard (only from the master)
        if self._tensorboard.should_log_this_batch() and self._master:
            self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm)
            self._tensorboard.log_learning_rates(self.model, self.optimizer)

            self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"])
            self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})

        if self._tensorboard.should_log_histograms_this_batch() and self._master:
            self._tensorboard.log_histograms(self.model, histogram_parameters)

        if self._log_batch_size_period:
            batch_group_size = sum(training_util.get_batch_size(batch) for batch in batch_group)
            cumulative_batch_group_size += batch_group_size
            if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                average = cumulative_batch_group_size / batches_this_epoch
                logger.info(f"current batch size: {batch_group_size} mean batch size: {average}")
                self._tensorboard.add_train_scalar("current_batch_size", batch_group_size)
                self._tensorboard.add_train_scalar("mean_batch_size", average)

        # Save model if needed.
        if (
            self._model_save_interval is not None
            and (time.time() - last_save_time > self._model_save_interval)
            and self._master
        ):
            last_save_time = time.time()
            self._save_checkpoint(
                "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
            )

    if self._distributed and not done_early:
        logger.warning(
            f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)."
        )
        # Indicate that we're done so that any workers that have remaining data stop the epoch early.
        done = torch.tensor(1, device=self.cuda_device)
        torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
        assert done.item()

    # Let all workers finish their epoch before computing
    # the final statistics for the epoch.
    if self._distributed:
        dist.barrier()

    metrics = training_util.get_metrics(
        self.model,
        train_loss,
        batches_this_epoch,
        reset=True,
        world_size=self._world_size,
        cuda_device=[self.cuda_device],
    )
    metrics["cpu_memory_MB"] = peak_cpu_usage
    for (gpu_num, memory) in gpu_usage:
        metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
    return metrics
def forward(self, tensor: torch.Tensor):
    mean = tensor.mean(-1, keepdim=True)
    std = tensor.std(-1, unbiased=False, keepdim=True)
    return (
        self.gamma * (tensor - mean) / (std + util.tiny_value_of_dtype(std.dtype)) + self.beta
    )
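# A quick numerical sanity check (a sketch): with gamma = 1 and beta = 0, the
# std-based normalization above should agree with `torch.nn.functional.layer_norm`
# up to the small epsilon difference (`std + tiny` outside the sqrt vs.
# `sqrt(var + eps)` inside it).
import torch

x = torch.randn(2, 3, 8)
mean = x.mean(-1, keepdim=True)
std = x.std(-1, unbiased=False, keepdim=True)
ours = (x - mean) / (std + 1e-13)  # 1e-13 standing in for tiny_value_of_dtype
reference = torch.nn.functional.layer_norm(x, (8,), eps=1e-13)
assert torch.allclose(ours, reference, atol=1e-5)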
def _forward_train(
    self, embeddings: torch.Tensor, targets: torch.Tensor, target_token_embedding: torch.Tensor
) -> torch.Tensor:
    # (target_token_embedding is only used in the tie_embeddings case,
    # which is not implemented)

    # want to compute (n, n_samples + 1) array with the log
    # probabilities where the first index is the true target
    # and the remaining ones are the negative samples.
    # then we can just select the first column

    # NOTE: targets input has padding removed (so 0 == the first id, NOT the padding id)

    (
        sampled_ids,
        target_expected_count,
        sampled_expected_count,
    ) = self.log_uniform_candidate_sampler(targets, choice_func=self.choice_func)

    long_targets = targets.long()
    long_targets.requires_grad_(False)

    # Get the softmax weights (so we can compute logits)
    # shape (batch_size * max_sequence_length + num_samples)
    all_ids = torch.cat([long_targets, sampled_ids], dim=0)

    if self.sparse:
        all_ids_1 = all_ids.unsqueeze(1)
        all_w = self.softmax_w(all_ids_1).squeeze(1)
        all_b = self.softmax_b(all_ids_1).squeeze(2).squeeze(1)
    else:
        all_w = torch.nn.functional.embedding(all_ids, self.softmax_w)
        # the unsqueeze / squeeze works around an issue with 1-dim embeddings
        all_b = torch.nn.functional.embedding(all_ids, self.softmax_b.unsqueeze(1)).squeeze(1)

    batch_size = long_targets.size(0)
    true_w = all_w[:batch_size, :]
    sampled_w = all_w[batch_size:, :]
    true_b = all_b[:batch_size]
    sampled_b = all_b[batch_size:]

    # compute the logits and remove log expected counts
    # [batch_size, ]
    true_logits = (
        (true_w * embeddings).sum(dim=1)
        + true_b
        - torch.log(
            target_expected_count + util.tiny_value_of_dtype(target_expected_count.dtype)
        )
    )
    # [batch_size, n_samples]
    sampled_logits = (
        torch.matmul(embeddings, sampled_w.t())
        + sampled_b
        - torch.log(
            sampled_expected_count + util.tiny_value_of_dtype(sampled_expected_count.dtype)
        )
    )

    # remove true labels -- we will take softmax, so set the sampled logits
    # of true values to a large negative number
    # [batch_size, n_samples]
    true_in_sample_mask = sampled_ids == long_targets.unsqueeze(1)
    masked_sampled_logits = sampled_logits.masked_fill(true_in_sample_mask, -10000.0)
    # now concat the true logits as index 0
    # [batch_size, n_samples + 1]
    logits = torch.cat([true_logits.unsqueeze(1), masked_sampled_logits], dim=1)

    # finally take log_softmax
    log_softmax = torch.nn.functional.log_softmax(logits, dim=1)
    # true log likelihood is index 0, loss = -1.0 * sum over batch
    # the likelihood loss can become very large if the corresponding
    # true logit is very small, so we apply a per-target cap here
    # so that a single logit for a very rare word won't dominate the batch.
    nll_loss = -1.0 * log_softmax[:, 0].sum()
    return nll_loss
def _gather_final_log_probs(
    self,
    generation_log_probs: torch.Tensor,
    copy_log_probs: torch.Tensor,
    state: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """
    Combine copy probabilities with generation probabilities for matching tokens.

    # Parameters

    generation_log_probs : `torch.Tensor`
        Shape: `(group_size, target_vocab_size)`
    copy_log_probs : `torch.Tensor`
        Shape: `(group_size, trimmed_source_length)`
    state : `Dict[str, torch.Tensor]`

    # Returns

    torch.Tensor
        Shape: `(group_size, target_vocab_size + trimmed_source_length)`.
    """
    _, trimmed_source_length = state["source_to_target"].size()
    source_token_ids = state["source_token_ids"]

    # shape: [(batch_size, *)]
    modified_log_probs_list: List[torch.Tensor] = []
    for i in range(trimmed_source_length):
        # shape: (group_size,)
        copy_log_probs_slice = copy_log_probs[:, i]
        # `source_to_target` is a matrix of shape (group_size, trimmed_source_length)
        # where element (i, j) is the vocab index of the target token that matches the jth
        # source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
        # We'll use this to add copy scores to corresponding generation scores.
        # shape: (group_size,)
        source_to_target_slice = state["source_to_target"][:, i]
        # The OOV index in the source_to_target_slice indicates that the source
        # token is not in the target vocab, so we don't want to add that copy score
        # to the OOV token.
        copy_log_probs_to_add_mask = source_to_target_slice != self._oov_index
        copy_log_probs_to_add = copy_log_probs_slice + (
            copy_log_probs_to_add_mask
            + util.tiny_value_of_dtype(copy_log_probs_slice.dtype)
        ).log()
        # shape: (group_size, 1)
        copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
        # shape: (group_size, 1)
        selected_generation_log_probs = generation_log_probs.gather(
            1, source_to_target_slice.unsqueeze(-1)
        )
        combined_scores = util.logsumexp(
            torch.cat((selected_generation_log_probs, copy_log_probs_to_add), dim=1)
        )
        generation_log_probs = generation_log_probs.scatter(
            -1, source_to_target_slice.unsqueeze(-1), combined_scores.unsqueeze(-1)
        )
        # We have to combine copy scores for duplicate source tokens so that
        # we can find the overall most likely source token. So, if this is the first
        # occurrence of this particular source token, we add the log_probs from all other
        # occurrences, otherwise we zero it out since it was already accounted for.
        if i < (trimmed_source_length - 1):
            # Sum copy scores from future occurrences of the source token.
            # shape: (group_size, trimmed_source_length - i)
            source_future_occurrences = (
                source_token_ids[:, (i + 1):] == source_token_ids[:, i].unsqueeze(-1)
            )
            # shape: (group_size, trimmed_source_length - i)
            future_copy_log_probs = copy_log_probs[:, (i + 1):] + (
                source_future_occurrences + util.tiny_value_of_dtype(copy_log_probs.dtype)
            ).log()
            # shape: (group_size, 1 + trimmed_source_length - i)
            combined = torch.cat(
                (copy_log_probs_slice.unsqueeze(-1), future_copy_log_probs), dim=-1
            )
            # shape: (group_size,)
            copy_log_probs_slice = util.logsumexp(combined)
        if i > 0:
            # Remove copy log_probs that we have already accounted for.
            # shape: (group_size, i)
            source_previous_occurrences = (
                source_token_ids[:, 0:i] == source_token_ids[:, i].unsqueeze(-1)
            )
            # shape: (group_size,)
            duplicate_mask = source_previous_occurrences.sum(dim=-1) == 0
            copy_log_probs_slice = copy_log_probs_slice + (
                duplicate_mask + util.tiny_value_of_dtype(copy_log_probs_slice.dtype)
            ).log()

        # Finally, we zero-out copy scores that we added to the generation scores
        # above so that we don't double-count them.
        # shape: (group_size,)
        left_over_copy_log_probs = copy_log_probs_slice + (
            ~copy_log_probs_to_add_mask
            + util.tiny_value_of_dtype(copy_log_probs_slice.dtype)
        ).log()
        modified_log_probs_list.append(left_over_copy_log_probs.unsqueeze(-1))
    modified_log_probs_list.insert(0, generation_log_probs)

    # shape: (group_size, target_vocab_size + trimmed_source_length)
    modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

    return modified_log_probs