def forward(self, matrix_1: torch.Tensor,
             matrix_2: torch.Tensor) -> torch.Tensor:
     a_norm = matrix_1 / (matrix_1.norm(p=2, dim=-1, keepdim=True) +
                          util.tiny_value_of_dtype(matrix_1.dtype))
     b_norm = matrix_2 / (matrix_2.norm(p=2, dim=-1, keepdim=True) +
                          util.tiny_value_of_dtype(matrix_2.dtype))
     return torch.bmm(a_norm, b_norm.transpose(-1, -2))
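Every snippet on this page uses `util.tiny_value_of_dtype` to keep denominators (and arguments to `log` or `sqrt`) strictly positive. As a rough sketch of the helper (the exact constants may differ between AllenNLP versions), it behaves roughly like this:

import torch

def tiny_value_of_dtype(dtype: torch.dtype) -> float:
    # Return a small positive constant used to avoid division by zero or log(0).
    # Only floating-point dtypes are supported.
    if not dtype.is_floating_point:
        raise TypeError("Only supports floating point dtypes.")
    if dtype in (torch.float, torch.double):
        return 1e-13
    if dtype == torch.half:
        return 1e-4
    raise TypeError("Does not support dtype " + str(dtype))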
Example #2
 def _forward_internal(self, vector: torch.Tensor,
                       matrix: torch.Tensor) -> torch.Tensor:
     a_norm = vector / (vector.norm(p=2, dim=-1, keepdim=True) +
                        util.tiny_value_of_dtype(vector.dtype))
     b_norm = matrix / (matrix.norm(p=2, dim=-1, keepdim=True) +
                        util.tiny_value_of_dtype(matrix.dtype))
     return torch.bmm(a_norm.unsqueeze(dim=1),
                      b_norm.transpose(-1, -2)).squeeze(1)
Example #3
    def forward(self, tensor: torch.Tensor,
                mask: torch.BoolTensor) -> torch.Tensor:

        broadcast_mask = mask.unsqueeze(-1)
        num_elements = broadcast_mask.sum() * self.size
        mean = (tensor * broadcast_mask).sum() / num_elements
        masked_centered = (tensor - mean) * broadcast_mask
        std = torch.sqrt((masked_centered * masked_centered).sum() /
                         num_elements + util.tiny_value_of_dtype(tensor.dtype))
        return (self.gamma * (tensor - mean) /
                (std + util.tiny_value_of_dtype(tensor.dtype)) + self.beta)
Example #4
    def test_scalar_mix_layer_norm(self):
        mixture = ScalarMix(3, do_layer_norm="scalar_norm_reg")

        tensors = [torch.randn([3, 4, 5]) for _ in range(3)]
        numpy_mask = numpy.ones((3, 4), dtype="int32")
        numpy_mask[1, 2:] = 0
        mask = torch.from_numpy(numpy_mask).bool()

        weights = [0.1, 0.2, 0.3]
        for k in range(3):
            mixture.scalar_parameters[k].data[0] = weights[k]
        mixture.gamma.data[0] = 0.5
        result = mixture(tensors, mask)

        normed_weights = numpy.exp(weights) / numpy.sum(numpy.exp(weights))
        expected_result = numpy.zeros((3, 4, 5))
        for k in range(3):
            mean = numpy.mean(tensors[k].data.numpy()[numpy_mask == 1])
            std = numpy.std(tensors[k].data.numpy()[numpy_mask == 1])
            normed_tensor = (tensors[k].data.numpy() - mean) / (
                std + util.tiny_value_of_dtype(torch.float))
            expected_result += normed_tensor * normed_weights[k]
        expected_result *= 0.5

        numpy.testing.assert_almost_equal(expected_result,
                                          result.data.numpy(),
                                          decimal=6)
Example #5
 def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
     tensor_masked = tensor * broadcast_mask
     mean = torch.sum(tensor_masked) / num_elements_not_masked
     variance = (torch.sum(
         ((tensor_masked - mean) * broadcast_mask)**2) /
                 num_elements_not_masked)
     return (tensor - mean) / torch.sqrt(
         variance + util.tiny_value_of_dtype(variance.dtype))
Example #6
 def log_gradient_updates(self, model: Model, param_updates: Dict[str, torch.Tensor]) -> None:
     for name, param in model.named_parameters():
         update_norm = torch.norm(param_updates[name].view(-1))
         param_norm = torch.norm(param.view(-1)).cpu()
         self.add_train_scalar(
             "gradient_update/" + name,
             update_norm / (param_norm + nn_util.tiny_value_of_dtype(param_norm.dtype)),
         )
Example #7
 def _log_gradient_updates(self, param_updates: Dict[str, torch.Tensor]) -> None:
     gradient_update_scalars: Dict[str, float] = {}
     for name, param in self.trainer.model.named_parameters():  # type: ignore[union-attr]
         update_norm = torch.norm(param_updates[name].view(-1))
         param_norm = torch.norm(param.view(-1)).cpu()
         gradient_update_scalars[name] = (
             update_norm / (param_norm + tiny_value_of_dtype(param_norm.dtype))
         ).item()
     self.log_scalars(gradient_update_scalars, log_prefix="gradient_update")
Example #8
 def forward(self, tokens: torch.Tensor) -> torch.Tensor:
     # (batch_size, sentence_length, features_vocab_length)
     mask = (tokens > 0).float()
     # (batch_size, sentence_length, features_vocab_length, embedding_dim)
     x = super().forward(tokens)
     # (batch_size, sentence_length, embedding_dim)
     return x.sum(dim=-2) / (
         (mask.sum(dim=-1) +
          util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1))
Example #9
    def test_masked_layer_norm(self):
        x_n = np.random.rand(2, 3, 7)
        mask_n = np.array([[1, 1, 0], [1, 1, 1]])

        x = torch.from_numpy(x_n).float()
        mask = torch.from_numpy(mask_n).bool()

        layer_norm = MaskedLayerNorm(7, gamma0=0.2)
        normed_x = layer_norm(x, mask)

        N = 7 * 5
        mean = (x_n * np.expand_dims(mask_n, axis=-1)).sum() / N
        std = np.sqrt(((
            (x_n - mean) * np.expand_dims(mask_n, axis=-1))**2).sum() / N +
                      util.tiny_value_of_dtype(torch.float))
        expected = 0.2 * (x_n - mean) / (std +
                                         util.tiny_value_of_dtype(torch.float))

        assert np.allclose(normed_x.data.numpy(), expected)
Example #10
    def forward(
        self,
        source: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        # Shape: (batch_size, embedding_dim)
        source_norm = source / (
            source.norm(p=2, dim=-1, keepdim=True) +
            tiny_value_of_dtype(source.dtype)  # type: ignore
        )
        # Shape: (batch_size, embedding_dim)
        target_norm = target / (
            target.norm(p=2, dim=-1, keepdim=True) +
            tiny_value_of_dtype(target.dtype)  # type: ignore
        )
        # Shape: (batch_size, )
        similarity = (source_norm * target_norm).sum(-1)
        distances = 0.5 * (1 - similarity)

        return cast(torch.Tensor, distances)
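The `similarity` above is a cosine value in [-1, 1], so the final line maps it to a distance in [0, 1]: identical directions give 0.5 * (1 - 1) = 0 and opposite directions give 0.5 * (1 - (-1)) = 1. A minimal sanity-check sketch (not part of the original module) comparing the normalize-and-dot computation with PyTorch's built-in cosine similarity:

import torch
import torch.nn.functional as F

source = torch.randn(8, 16)  # (batch_size, embedding_dim)
target = torch.randn(8, 16)
# Up to the tiny epsilon, the distances computed by the forward() above should
# agree with this reference value.
reference_distances = 0.5 * (1 - F.cosine_similarity(source, target, dim=-1))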
Example #11
def multi_perspective_match_pairwise(vector1: torch.Tensor,
                                     vector2: torch.Tensor,
                                     weight: torch.Tensor) -> torch.Tensor:
    """
    Calculate multi-perspective cosine matching between each time step of
    one vector and each time step of another vector.

    # Parameters

    vector1 : `torch.Tensor`
        A tensor of shape `(batch, seq_len1, hidden_size)`
    vector2 : `torch.Tensor`
        A tensor of shape `(batch, seq_len2, hidden_size)`
    weight : `torch.Tensor`
        A tensor of shape `(num_perspectives, hidden_size)`

    # Returns

    `torch.Tensor` :
        A tensor of shape `(batch, seq_len1, seq_len2, num_perspectives)` consisting of
        the multi-perspective matching results.
    """
    num_perspectives = weight.size(0)

    # (1, num_perspectives, 1, hidden_size)
    weight = weight.unsqueeze(0).unsqueeze(2)

    # (batch, num_perspectives, seq_len*, hidden_size)
    vector1 = weight * vector1.unsqueeze(1).expand(-1, num_perspectives, -1, -1)
    vector2 = weight * vector2.unsqueeze(1).expand(-1, num_perspectives, -1, -1)

    # (batch, num_perspectives, seq_len*, 1)
    vector1_norm = vector1.norm(p=2, dim=3, keepdim=True)
    vector2_norm = vector2.norm(p=2, dim=3, keepdim=True)

    # (batch, num_perspectives, seq_len1, seq_len2)
    mul_result = torch.matmul(vector1, vector2.transpose(2, 3))
    norm_value = vector1_norm * vector2_norm.transpose(2, 3)

    # (batch, seq_len1, seq_len2, num_perspectives)
    return (
        mul_result / norm_value.clamp(min=tiny_value_of_dtype(norm_value.dtype))
    ).permute(0, 2, 3, 1)
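A minimal shape-check sketch for the function above; the sizes are illustrative, and `tiny_value_of_dtype` is assumed to be imported from `allennlp.nn.util` as in the snippet:

import torch

batch, seq_len1, seq_len2, hidden_size, num_perspectives = 2, 5, 7, 16, 4
vector1 = torch.randn(batch, seq_len1, hidden_size)
vector2 = torch.randn(batch, seq_len2, hidden_size)
weight = torch.randn(num_perspectives, hidden_size)

matched = multi_perspective_match_pairwise(vector1, vector2, weight)
# Each entry is a cosine similarity between a time step of vector1 and a time
# step of vector2 under one learned "perspective" weighting of the hidden dims.
assert matched.shape == (batch, seq_len1, seq_len2, num_perspectives)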
Example #12
def sparse_clip_norm(parameters, max_norm, norm_type=2) -> float:
    """Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.
    Supports sparse gradients.

    # Parameters

    parameters : `Iterable[torch.Tensor]`
        An iterable of Tensors that will have gradients normalized.
    max_norm : `float`
        The max norm of the gradients.
    norm_type : `float`
        The type of the used p-norm. Can be `'inf'` for infinity norm.

    # Returns

    Total norm of the parameters (viewed as a single vector).
    """
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == float("inf"):
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            if p.grad.is_sparse:
                # need to coalesce the repeated indices before finding norm
                grad = p.grad.data.coalesce()
                param_norm = grad._values().norm(norm_type)
            else:
                param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm**norm_type
        total_norm = total_norm**(1.0 / norm_type)
    clip_coef = max_norm / (total_norm +
                            nn_util.tiny_value_of_dtype(total_norm.dtype))
    if clip_coef < 1:
        for p in parameters:
            if p.grad.is_sparse:
                p.grad.data._values().mul_(clip_coef)
            else:
                p.grad.data.mul_(clip_coef)
    return total_norm
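A hedged usage sketch: like `torch.nn.utils.clip_grad_norm_`, the function is meant to be called after `loss.backward()` and before `optimizer.step()`. The `model`, `loss`, and `optimizer` names below are placeholders, not part of the snippet above:

loss.backward()
total_norm = sparse_clip_norm(model.parameters(), max_norm=5.0, norm_type=2)
optimizer.step()
optimizer.zero_grad()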
Example #13
 def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     batch_size, source_length = state["source_mask"].size()
     trimmed_source_length = source_length - 2
     # Initialize the copy scores to zero.
     state["copy_log_probs"] = (
         state["decoder_hidden"].new_zeros((batch_size, trimmed_source_length))
         + util.tiny_value_of_dtype(state["decoder_hidden"].dtype)
     ).log()
     # shape: (batch_size,)
     start_predictions = state["source_mask"].new_full(
         (batch_size,), fill_value=self._start_index, dtype=torch.long
     )
     # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
     # shape (log_probabilities): (batch_size, beam_size)
     all_top_k_predictions, log_probabilities = self._beam_search.search(
         start_predictions, state, self.take_search_step
     )
     return {"predicted_log_probs": log_probabilities, "predictions": all_top_k_predictions}
Example #14
    def batch_end_logging(self, trainer):
        # Log parameter values to tensorboard
        if self.tensorboard.should_log_this_batch():
            self.tensorboard.log_parameter_and_gradient_statistics(
                trainer.model, trainer.batch_grad_norm)
            self.tensorboard.log_learning_rates(trainer.model,
                                                trainer.optimizer)

            self.tensorboard.add_train_scalar("loss/loss_train",
                                              trainer.train_metrics["loss"])
            self.tensorboard.log_metrics({
                "epoch_metrics/" + k: v
                for k, v in trainer.train_metrics.items()
            })

        if self.log_batch_size_period:
            cur_batch = training_util.get_batch_size(trainer.batch)
            self.cumulative_batch_size += cur_batch
            if (trainer.batches_this_epoch -
                    1) % self.log_batch_size_period == 0:
                average = self.cumulative_batch_size / trainer.batches_this_epoch
                logger.debug(
                    f"current batch size: {cur_batch} mean batch size: {average}"
                )
                self.tensorboard.add_train_scalar("current_batch_size",
                                                  cur_batch)
                self.tensorboard.add_train_scalar("mean_batch_size", average)

        if self.tensorboard.should_log_histograms_this_batch():
            for name, param in trainer.model.named_parameters():
                self.param_updates[name].sub_(param.detach().cpu())
                update_norm = torch.norm(self.param_updates[name].view(-1))
                param_norm = torch.norm(param.view(-1)).cpu()
                self.tensorboard.add_train_scalar(
                    "gradient_update/" + name,
                    update_norm /
                    (param_norm +
                     nn_util.tiny_value_of_dtype(param_norm.dtype)),
                )
            self.param_updates.clear()
            self.tensorboard.log_histograms(trainer.model,
                                            self.histogram_parameters)
Example #15
    def _get_ll_contrib(
        self,
        generation_scores: torch.Tensor,
        generation_scores_mask: torch.BoolTensor,
        copy_scores: torch.Tensor,
        target_tokens: torch.Tensor,
        target_to_source: torch.Tensor,
        copy_mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the log-likelihood contribution from a single timestep.

        # Parameters

        generation_scores : `torch.Tensor`
            Shape: `(batch_size, target_vocab_size)`
        generation_scores_mask : `torch.BoolTensor`
            Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
        copy_scores : `torch.Tensor`
            Shape: `(batch_size, trimmed_source_length)`
        target_tokens : `torch.Tensor`
            Shape: `(batch_size,)`
        target_to_source : `torch.Tensor`
            Shape: `(batch_size, trimmed_source_length)`
        copy_mask : `torch.BoolTensor`
            Shape: `(batch_size, trimmed_source_length)`

        # Returns

        Tuple[torch.Tensor, torch.Tensor]
            Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
        """
        _, target_size = generation_scores.size()

        # The point of this mask is to mask out all source token scores
        # that just represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = (log_probs[:, target_size:] +
                          (target_to_source.to(log_probs.dtype) +
                           util.tiny_value_of_dtype(log_probs.dtype)).log())
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:],
                                                target_to_source)
        # This mask ensures that an item in the batch has a non-zero generation probability
        # for this timestep only when the gold target token is not OOV or there are no
        # matching tokens in the source sentence.
        # shape: (batch_size, 1)
        gen_mask = (target_tokens !=
                    self._oov_index) | (target_to_source.sum(-1) == 0)
        log_gen_mask = (
            gen_mask +
            util.tiny_value_of_dtype(log_probs.dtype)).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(
            1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + trimmed_source_length)
        combined_gen_and_copy = torch.cat(
            (generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)

        return step_log_likelihood, selective_weights
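The `(mask + tiny_value).log()` pattern used above for `copy_log_probs` and `log_gen_mask` is a log-space mask: where the mask is 1 it adds log(1 + tiny) ≈ 0, and where the mask is 0 it adds log(tiny), a very negative number that effectively removes the entry from a later logsumexp. A tiny illustration (1e-13 is just an example value; the code uses `util.tiny_value_of_dtype`):

import torch

log_probs = torch.tensor([-1.0, -2.0, -3.0])
mask = torch.tensor([1.0, 0.0, 1.0])
tiny = 1e-13

masked_log_probs = log_probs + (mask + tiny).log()
# Entries with mask == 1 are essentially unchanged; the entry with mask == 0
# becomes roughly -30 and contributes almost nothing to the logsumexp.
total = torch.logsumexp(masked_log_probs, dim=-1)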
Example #16
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        peak_cpu_usage = common_util.peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in common_util.gpu_memory_mb().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self._pytorch_model.train()

        # Get tqdm for the training batches
        batch_generator = iter(self.data_loader)
        batch_group_generator = common_util.lazy_groups_of(
            batch_generator, self._num_gradient_accumulation_steps
        )

        logger.info("Training")

        num_training_batches = math.ceil(
            len(self.data_loader) / self._num_gradient_accumulation_steps
        )
        # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the master's
        # progress is shown
        if self._master:
            batch_group_generator_tqdm = Tqdm.tqdm(
                batch_group_generator, total=num_training_batches
            )
        else:
            batch_group_generator_tqdm = batch_group_generator

        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

        cumulative_batch_group_size = 0
        done_early = False
        for batch_group in batch_group_generator_tqdm:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing training early! "
                        "This implies that there is an imbalance in your training "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced."
                    )
                    break

            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self.optimizer.zero_grad()

            for batch in batch_group:
                loss = self.batch_loss(batch, for_training=True)
                if torch.isnan(loss):
                    raise ValueError("nan loss encountered")
                loss = loss / len(batch_group)
                if self._opt_level is not None:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_loss += loss.item()

            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using a
            # scheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)
            if self._momentum_scheduler:
                self._momentum_scheduler.step_batch(batch_num_total)

            if self._tensorboard.should_log_histograms_this_batch() and self._master:
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {
                    name: param.detach().cpu().clone()
                    for name, param in self.model.named_parameters()
                }
                self.optimizer.step()
                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1))
                    param_norm = torch.norm(param.view(-1)).cpu()
                    self._tensorboard.add_train_scalar(
                        "gradient_update/" + name,
                        update_norm / (param_norm + nn_util.tiny_value_of_dtype(param_norm.dtype)),
                    )
            else:
                self.optimizer.step()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)

            # Update the description with the latest metrics
            metrics = training_util.get_metrics(
                self.model,
                train_loss,
                batches_this_epoch,
                world_size=self._world_size,
                cuda_device=[self.cuda_device],
            )

            # Updating tqdm only for the master as the trainers wouldn't have one
            if self._master:
                description = training_util.description_from_metrics(metrics)
                batch_group_generator_tqdm.set_description(description, refresh=False)

            # Log parameter values to Tensorboard (only from the master)
            if self._tensorboard.should_log_this_batch() and self._master:
                self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm)
                self._tensorboard.log_learning_rates(self.model, self.optimizer)

                self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"])
                self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})

            if self._tensorboard.should_log_histograms_this_batch() and self._master:
                self._tensorboard.log_histograms(self.model, histogram_parameters)

            if self._log_batch_size_period:
                batch_group_size = sum(training_util.get_batch_size(batch) for batch in batch_group)
                cumulative_batch_group_size += batch_group_size
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_group_size / batches_this_epoch
                    logger.info(
                        f"current batch size: {batch_group_size} mean batch size: {average}"
                    )
                    self._tensorboard.add_train_scalar("current_batch_size", batch_group_size)
                    self._tensorboard.add_train_scalar("mean_batch_size", average)

            # Save model if needed.
            if (
                self._model_save_interval is not None
                and (time.time() - last_save_time > self._model_save_interval)
                and self._master
            ):
                last_save_time = time.time()
                self._save_checkpoint(
                    "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
                )
        if self._distributed and not done_early:
            logger.warning(
                f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)."
            )
            # Indicate that we're done so that any workers that have remaining data stop the epoch early.
            done = torch.tensor(1, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            assert done.item()

        # Let all workers finish their epoch before computing
        # the final statistics for the epoch.
        if self._distributed:
            dist.barrier()

        metrics = training_util.get_metrics(
            self.model,
            train_loss,
            batches_this_epoch,
            reset=True,
            world_size=self._world_size,
            cuda_device=[self.cuda_device],
        )
        metrics["cpu_memory_MB"] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
        return metrics
Example #17
 def forward(self, tensor: torch.Tensor):
     mean = tensor.mean(-1, keepdim=True)
     std = tensor.std(-1, unbiased=False, keepdim=True)
     return (self.gamma * (tensor - mean) /
             (std + util.tiny_value_of_dtype(std.dtype)) + self.beta)
Example #18
    def _forward_train(self, embeddings: torch.Tensor, targets: torch.Tensor,
                       target_token_embedding: torch.Tensor) -> torch.Tensor:

        # (target_token_embedding is only used in the tie_embeddings case,
        #  which is not implemented)

        # want to compute (n, n_samples + 1) array with the log
        # probabilities where the first index is the true target
        # and the remaining ones are the negative samples.
        # then we can just select the first column

        # NOTE: targets input has padding removed (so 0 == the first id, NOT the padding id)

        (
            sampled_ids,
            target_expected_count,
            sampled_expected_count,
        ) = self.log_uniform_candidate_sampler(targets,
                                               choice_func=self.choice_func)

        long_targets = targets.long()
        long_targets.requires_grad_(False)

        # Get the softmax weights (so we can compute logits)
        # shape (batch_size * max_sequence_length + num_samples)
        all_ids = torch.cat([long_targets, sampled_ids], dim=0)

        if self.sparse:
            all_ids_1 = all_ids.unsqueeze(1)
            all_w = self.softmax_w(all_ids_1).squeeze(1)
            all_b = self.softmax_b(all_ids_1).squeeze(2).squeeze(1)
        else:
            all_w = torch.nn.functional.embedding(all_ids, self.softmax_w)
            # the unsqueeze / squeeze works around an issue with 1 dim
            # embeddings
            all_b = torch.nn.functional.embedding(
                all_ids, self.softmax_b.unsqueeze(1)).squeeze(1)

        batch_size = long_targets.size(0)
        true_w = all_w[:batch_size, :]
        sampled_w = all_w[batch_size:, :]
        true_b = all_b[:batch_size]
        sampled_b = all_b[batch_size:]

        # compute the logits and remove log expected counts
        # [batch_size, ]
        true_logits = (
            (true_w * embeddings).sum(dim=1) + true_b -
            torch.log(target_expected_count +
                      util.tiny_value_of_dtype(target_expected_count.dtype)))
        # [batch_size, n_samples]
        sampled_logits = (
            torch.matmul(embeddings, sampled_w.t()) + sampled_b -
            torch.log(sampled_expected_count +
                      util.tiny_value_of_dtype(sampled_expected_count.dtype)))

        # remove true labels -- we will take
        # softmax, so set the sampled logits of true values to a large
        # negative number
        # [batch_size, n_samples]
        true_in_sample_mask = sampled_ids == long_targets.unsqueeze(1)
        masked_sampled_logits = sampled_logits.masked_fill(
            true_in_sample_mask, -10000.0)
        # now concat the true logits as index 0
        # [batch_size, n_samples + 1]
        logits = torch.cat([true_logits.unsqueeze(1), masked_sampled_logits],
                           dim=1)

        # finally take log_softmax
        log_softmax = torch.nn.functional.log_softmax(logits, dim=1)
        # true log likelihood is index 0, loss = -1.0 * sum over batch
        # the likelihood loss can become very large if the corresponding
        # true logit is very small; a per-target cap could be applied here
        # so that a single logit for a very rare word does not dominate the batch.
        nll_loss = -1.0 * log_softmax[:, 0].sum()
        return nll_loss
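The `- torch.log(expected_count + tiny)` terms above are the usual sampled-softmax correction: candidates that the log-uniform sampler proposes frequently have their logits reduced by the log of their expected count, so the softmax over the small candidate set better approximates a softmax over the full vocabulary. A toy illustration with made-up numbers:

import torch

logits = torch.tensor([2.0, 1.0, 1.0])          # [true target, sample A, sample B]
expected_count = torch.tensor([0.8, 2.0, 0.1])  # sample A is proposed far more often than B
corrected = logits - torch.log(expected_count + 1e-13)
# After the correction, the rarely proposed sample B keeps relatively more
# probability mass than the frequently proposed sample A.
probs = torch.nn.functional.softmax(corrected, dim=0)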
Example #19
    def _gather_final_log_probs(
        self,
        generation_log_probs: torch.Tensor,
        copy_log_probs: torch.Tensor,
        state: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Combine copy probabilities with generation probabilities for matching tokens.

        # Parameters

        generation_log_probs : `torch.Tensor`
            Shape: `(group_size, target_vocab_size)`
        copy_log_probs : `torch.Tensor`
            Shape: `(group_size, trimmed_source_length)`
        state : `Dict[str, torch.Tensor]`

        # Returns

        torch.Tensor
            Shape: `(group_size, target_vocab_size + trimmed_source_length)`.
        """
        _, trimmed_source_length = state["source_to_target"].size()
        source_token_ids = state["source_token_ids"]

        # shape: [(batch_size, *)]
        modified_log_probs_list: List[torch.Tensor] = []
        for i in range(trimmed_source_length):
            # shape: (group_size,)
            copy_log_probs_slice = copy_log_probs[:, i]
            # `source_to_target` is a matrix of shape (group_size, trimmed_source_length)
            # where element (i, j) is the vocab index of the target token that matches the jth
            # source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
            # We'll use this to add copy scores to corresponding generation scores.
            # shape: (group_size,)
            source_to_target_slice = state["source_to_target"][:, i]
            # The OOV index in the source_to_target_slice indicates that the source
            # token is not in the target vocab, so we don't want to add that copy score
            # to the OOV token.
            copy_log_probs_to_add_mask = source_to_target_slice != self._oov_index
            copy_log_probs_to_add = (
                copy_log_probs_slice +
                (copy_log_probs_to_add_mask +
                 util.tiny_value_of_dtype(copy_log_probs_slice.dtype)).log())
            # shape: (batch_size, 1)
            copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
            # shape: (batch_size, 1)
            selected_generation_log_probs = generation_log_probs.gather(
                1, source_to_target_slice.unsqueeze(-1))
            combined_scores = util.logsumexp(
                torch.cat(
                    (selected_generation_log_probs, copy_log_probs_to_add),
                    dim=1))
            generation_log_probs = generation_log_probs.scatter(
                -1, source_to_target_slice.unsqueeze(-1),
                combined_scores.unsqueeze(-1))
            # We have to combine copy scores for duplicate source tokens so that
            # we can find the overall most likely source token. So, if this is the first
            # occurrence of this particular source token, we add the log_probs from all other
            # occurrences, otherwise we zero it out since it was already accounted for.
            if i < (trimmed_source_length - 1):
                # Sum copy scores from future occurrences of source token.
                # shape: (group_size, trimmed_source_length - i)
                source_future_occurences = (
                    source_token_ids[:, (i + 1):] == source_token_ids[:, i].unsqueeze(-1)
                )
                # shape: (group_size, trimmed_source_length - i)
                future_copy_log_probs = (
                    copy_log_probs[:, (i + 1):] +
                    (source_future_occurences +
                     util.tiny_value_of_dtype(copy_log_probs.dtype)).log())
                # shape: (group_size, 1 + trimmed_source_length - i)
                combined = torch.cat((copy_log_probs_slice.unsqueeze(-1),
                                      future_copy_log_probs),
                                     dim=-1)
                # shape: (group_size,)
                copy_log_probs_slice = util.logsumexp(combined)
            if i > 0:
                # Remove copy log_probs that we have already accounted for.
                # shape: (group_size, i)
                source_previous_occurences = (
                    source_token_ids[:, 0:i] == source_token_ids[:, i].unsqueeze(-1)
                )
                # shape: (group_size,)
                duplicate_mask = source_previous_occurences.sum(dim=-1) == 0
                copy_log_probs_slice = (
                    copy_log_probs_slice +
                    (duplicate_mask + util.tiny_value_of_dtype(
                        copy_log_probs_slice.dtype)).log())

            # Finally, we zero-out copy scores that we added to the generation scores
            # above so that we don't double-count them.
            # shape: (group_size,)
            left_over_copy_log_probs = (
                copy_log_probs_slice +
                (~copy_log_probs_to_add_mask +
                 util.tiny_value_of_dtype(copy_log_probs_slice.dtype)).log())
            modified_log_probs_list.append(
                left_over_copy_log_probs.unsqueeze(-1))
        modified_log_probs_list.insert(0, generation_log_probs)

        # shape: (group_size, target_vocab_size + trimmed_source_length)
        modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

        return modified_log_probs
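Throughout `_gather_final_log_probs`, probabilities that belong to the same target token are combined by concatenating their log-probabilities and calling `util.logsumexp`, which is the log-space equivalent of adding the probabilities. A tiny sanity-check sketch of that identity:

import torch

p1, p2 = 0.2, 0.05
log_p1 = torch.tensor([p1]).log()
log_p2 = torch.tensor([p2]).log()

combined = torch.logsumexp(torch.cat([log_p1, log_p2]), dim=0)
assert torch.allclose(combined.exp(), torch.tensor(p1 + p2))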