Example #1
def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
    """Roll tensor with different shifts for each row.

    Note:
      We assume ``src`` is a 3-dimensional tensor and roll its last dimension.

    Example:

      >>> src = torch.arange(15).reshape((1,3,5))
      >>> src
      tensor([[[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14]]])
      >>> shift = torch.tensor([[1, 2, 3]])
      >>> shift
      tensor([[1, 2, 3]])
      >>> _roll_by_shifts(src, shift)
      tensor([[[ 4,  0,  1,  2,  3],
               [ 8,  9,  5,  6,  7],
               [12, 13, 14, 10, 11]]])
    """
    assert src.dim() == 3
    (B, T, S) = src.shape
    assert shifts.shape == (B, T)

    index = (
        torch.arange(S, device=src.device)
        .view((1, S))
        .repeat((T, 1))
        .repeat((B, 1, 1))
    )
    index = (index - shifts.reshape(B, T, 1)) % S
    return torch.gather(src, 2, index)
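A quick sanity check of the gather-based roll against per-row torch.roll (a minimal sketch, assuming _roll_by_shifts above is in scope):

import torch

src = torch.arange(15).reshape(1, 3, 5)
shifts = torch.tensor([[1, 2, 3]])
rolled = _roll_by_shifts(src, shifts)
# torch.roll with a positive shift also rotates to the right, so the two agree row by row
expected = torch.stack(
    [torch.roll(src[0, t], int(shifts[0, t])) for t in range(src.shape[1])]
).unsqueeze(0)
assert torch.equal(rolled, expected)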
Example #2
 def fit(self, s, a_index, Q, critic_loss_coef, entropy_coef):

     self.net.train()

     s = tensor(s, dtype=float)
     a_index = LongTensor(a_index.reshape((-1, 1)))
     Q = tensor(Q, dtype=float)

     output_V, output_pi = self.net(s.float())  # get state value V and policy π from the network

     log_prob = (output_pi.gather(1, a_index).log()).view(-1)  # log-probability of the chosen actions

     adv = Q - output_V.view(-1)  # advantage estimate

     actor_loss = -(adv.detach() * log_prob).mean()  # actor loss from the policy gradient theorem

     critic_loss = critic_loss_coef * adv.pow(2).mean()  # critic loss as the squared advantage (MSE)

     entropy = entropy_coef * (output_pi * output_pi.log()).sum(axis=1).mean()  # entropy term of the policy (Σ π log π)

     total_loss = actor_loss + critic_loss - entropy

     self.optim.zero_grad()
     total_loss.backward()
     utils.clip_grad_norm_(self.net.parameters(), 0.5)  # clip gradients to keep updates small
     self.optim.step()
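fit() assumes that self.net maps a batch of states to a pair (state value V, action probabilities π) and that self.optim optimises its parameters. A hypothetical minimal sketch of such a network (TinyActorCritic, obs_dim and n_actions are illustrative names, not from the original code):

import torch
import torch.nn as nn

class TinyActorCritic(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.v_head = nn.Linear(hidden, 1)           # critic head: V(s)
        self.pi_head = nn.Linear(hidden, n_actions)  # actor head: π(a|s)

    def forward(self, s: torch.Tensor):
        h = self.body(s)
        return self.v_head(h).view(-1), torch.softmax(self.pi_head(h), dim=-1)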
Example #3
 def forward(self, x: torch.LongTensor):
     mask = x != 1
     mask = mask.reshape(-1, mask.shape[-1])
     mask[torch.sum(mask, dim=1) == 0, 0] = 1  # give all-padding rows length 1 so packing does not fail
     x = self.embedding[x].to(self.device)
     batch_size, seq_len, max_char_num, vector_size = x.shape
     x = x.reshape(-1, max_char_num, vector_size)
     x = self.dropout_layer(x)
     x = nn.utils.rnn.pack_padded_sequence(x,
                                           mask.sum(1).int(),
                                           batch_first=True,
                                           enforce_sorted=False)
     h, _ = self.char_encoder(x, None)
     h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first=True)
     h = h[:, 0, :self.hidden_size] + h[:, -1, self.hidden_size:]  # forward state at the first step + backward state at the last step
     embed = h.reshape(batch_size, seq_len, -1)
     return embed
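The last two lines of forward() add the forward-direction state at the first timestep to the backward-direction state at the last timestep. A small illustration of how a bidirectional RNN output is laid out (sizes are hypothetical):

import torch
import torch.nn as nn

rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)
out, _ = rnn(torch.randn(4, 10, 8))       # out: (batch, seq_len, 2 * hidden_size)
forward_first = out[:, 0, :16]            # forward direction at t = 0
backward_last = out[:, -1, 16:]           # backward direction at t = seq_len - 1
combined = forward_first + backward_last  # same combination as above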
Example #4
    def forward(self, scores: _torch.FloatTensor, relevance: _torch.LongTensor,
                n: _torch.LongTensor) -> _torch.FloatTensor:
        """Computes the loss for given batch of samples.

        Args:
            scores: A batch of per-query-document scores.
            relevance: A batch of per-query-document relevance labels.
            n: A batch of per-query number of documents (for padding purposes).
        """
        # Reshape relevance if necessary.
        if relevance.ndimension() == 2:
            relevance = relevance.reshape(
                (relevance.shape[0], relevance.shape[1], 1))
        if scores.ndimension() == 2:
            scores = scores.reshape((scores.shape[0], scores.shape[1], 1))

        # Compute ranking and sort scores and relevance
        ranking = _rank_by_score(scores, n)
        ranking = ranking.view((ranking.shape[0], ranking.shape[1], 1))
        scores = _torch.gather(scores, 1, ranking)
        relevance = _torch.gather(relevance, 1, ranking)

        # Compute pairwise differences for scores and relevances.
        score_pairs = _batch_pairs(scores)
        rel_pairs = _batch_pairs(relevance)

        # Compute loss per doc pair.
        loss_pairs = self._loss_per_doc_pair(score_pairs, rel_pairs, n)

        # Mask out padded documents per query in the batch
        n_grid = n[:, None, None].repeat(1, score_pairs.shape[1],
                                         score_pairs.shape[2])
        arange = _torch.arange(score_pairs.shape[1],
                               device=score_pairs.device)
        range_grid = _torch.max(*_torch.meshgrid([arange, arange]))
        range_grid = range_grid[None, :, :].repeat(n.shape[0], 1, 1)
        loss_pairs[n_grid <= range_grid] = 0.0

        # Reduce final list loss from per doc pair loss to a per query loss.
        loss = self._loss_reduction(loss_pairs)

        # Return loss
        return loss
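The pair mask above relies on a max over a meshgrid: entry (i, j) of range_grid is max(i, j), so n_grid <= range_grid flags every score pair that involves a padded document. A small illustration with plain torch:

import torch

arange = torch.arange(4)
range_grid = torch.max(*torch.meshgrid([arange, arange]))
# tensor([[0, 1, 2, 3],
#         [1, 1, 2, 3],
#         [2, 2, 2, 3],
#         [3, 3, 3, 3]])
n = 2                   # a query with only two real documents
print(n <= range_grid)  # True wherever i >= 2 or j >= 2, i.e. any pair touching padding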
Example #5
def mask_padded_values(xs: _torch.FloatTensor, n: _torch.LongTensor,
                       mask_value: float = -float('inf'),
                       mutate: bool = False):
    """Turns padded values into given mask value.

    Args:
        xs: A tensor of size (batch_size, list_size, 1) containing padded
            values.
        n: A tensor of size (batch_size) containing list size of each query.
        mask_value: The value to mask with (default: -inf).
        mutate: Whether to mutate the values of xs or return a copy.
    """
    mask = _torch.repeat_interleave(
        _torch.arange(xs.shape[1], device=xs.device).reshape((1, xs.shape[1])),
        xs.shape[0], dim=0)
    n_mask = _torch.repeat_interleave(
        n.reshape((n.shape[0], 1)), xs.shape[1], dim=1)
    if not mutate:
        xs = xs.clone()
    xs[mask >= n_mask] = mask_value
    return xs
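A minimal usage sketch of mask_padded_values (assuming torch is imported as _torch, as in the snippet above):

import torch as _torch

xs = _torch.arange(6, dtype=_torch.float).reshape(2, 3, 1)
n = _torch.tensor([3, 1])      # the second query has only one real document
masked = mask_padded_values(xs, n)
# masked[1, 1:, 0] is now [-inf, -inf]; xs itself is untouched because mutate=False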
Example #6
 def get_loss(
     self,
     rule_probs: torch.FloatTensor,
     target_rules: torch.LongTensor,
     target_mask: torch.FloatTensor,
 ):
     """
     :param rule_probs    (batch_size, target_length, num_rules)
     :param target_rules  (batch_size, target_length)
     :param target_mask   (batch_size, target_length)
     """
     batch_size, target_length = target_rules.size()
     rule_probs = torch.gather(
         rule_probs.reshape(-1, self._num_rules),
         dim=1,
         index=target_rules.reshape(-1).unsqueeze(-1).long())
     rule_probs = rule_probs.reshape(batch_size, target_length)
     rule_log_probs = (rule_probs + 1e-10).log()
     rule_log_probs *= target_mask.float()
     rule_normalize_factor = target_mask.sum(-1)
     rule_normalize_factor[rule_normalize_factor == 0] = 1
     rule_loss = rule_log_probs.sum(-1) / rule_normalize_factor.float()
     rule_loss = -1 * (rule_loss.sum() / batch_size)
     return rule_loss
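The same masked, length-normalised negative log-likelihood can be written more directly; a hedged standalone sketch with the shapes from the docstring (not part of the original class):

import torch

def masked_rule_nll(rule_probs: torch.FloatTensor,
                    target_rules: torch.LongTensor,
                    target_mask: torch.FloatTensor) -> torch.Tensor:
    # rule_probs: (B, T, R) probabilities; target_rules and target_mask: (B, T)
    picked = torch.gather(rule_probs, 2, target_rules.unsqueeze(-1)).squeeze(-1)
    log_p = (picked + 1e-10).log() * target_mask.float()
    lengths = target_mask.sum(-1).clamp(min=1).float()
    return -(log_p.sum(-1) / lengths).mean()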
Example #7
    def _unfold_long_sequences(
        self,
        embeddings: torch.FloatTensor,
        mask: torch.LongTensor,
        batch_size: int,
        num_segment_concat_wordpieces: int,
    ) -> torch.FloatTensor:
        """
        We take 2D segments of a long sequence and flatten them out to get the whole sequence
        representation while removing unnecessary special tokens.

        [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
        -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

        We truncate the start and end tokens for all segments, recombine the segments,
        and manually add back the start and end tokens.

        # Parameters

        embeddings: `torch.FloatTensor`
            Shape: [batch_size * num_segments, self._max_length, embedding_size].
        mask: `torch.LongTensor`
            Shape: [batch_size * num_segments, self._max_length].
            The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
            in `forward()`.
        batch_size: `int`
        num_segment_concat_wordpieces: `int`
            The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
            the original `token_ids.size(1)`.

        # Returns:

        embeddings: `torch.FloatTensor`
            Shape: [batch_size, self._num_wordpieces, embedding_size].
        """

        def lengths_to_mask(lengths, max_len, device):
            return torch.arange(max_len, device=device).expand(
                lengths.size(0), max_len
            ) < lengths.unsqueeze(1)

        device = embeddings.device
        num_segments = int(embeddings.size(0) / batch_size)
        embedding_size = embeddings.size(2)

        # We want to remove all segment-level special tokens but maintain sequence-level ones
        num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * self._num_added_tokens

        embeddings = embeddings.reshape(batch_size, num_segments * self._max_length, embedding_size)
        mask = mask.reshape(batch_size, num_segments * self._max_length)
        # We assume that all 1s in the mask precede all 0s, and raise an error below if not.
        # Open an issue on GitHub if this breaks for you.
        # Shape: (batch_size,)
        seq_lengths = mask.sum(-1)
        if not (lengths_to_mask(seq_lengths, mask.size(1), device) == mask).all():
            raise ValueError(
                "Long sequence splitting only supports masks with all 1s preceding all 0s."
            )
        # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
        end_token_indices = (
            seq_lengths.unsqueeze(-1) - torch.arange(self._num_added_end_tokens, device=device) - 1
        )

        # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
        start_token_embeddings = embeddings[:, : self._num_added_start_tokens, :]
        # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
        end_token_embeddings = batched_index_select(embeddings, end_token_indices)

        embeddings = embeddings.reshape(batch_size, num_segments, self._max_length, embedding_size)
        embeddings = embeddings[
            :, :, self._num_added_start_tokens : -self._num_added_end_tokens, :
        ]  # truncate segment-level start/end tokens
        embeddings = embeddings.reshape(batch_size, -1, embedding_size)  # flatten

        # Now try to put end token embeddings back which is a little tricky.

        # The number of segments each sequence spans, excluding padding. Mimicking ceiling operation.
        # Shape: (batch_size,)
        num_effective_segments = torch.div(
            seq_lengths + self._max_length - 1, self._max_length, rounding_mode="floor"
        )
        # The number of indices that end tokens should shift back.
        num_removed_non_end_tokens = (
            num_effective_segments * self._num_added_tokens - self._num_added_end_tokens
        )
        # Shape: (batch_size, self._num_added_end_tokens)
        end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
        assert (end_token_indices >= self._num_added_start_tokens).all()
        # Add space for end embeddings
        embeddings = torch.cat([embeddings, torch.zeros_like(end_token_embeddings)], 1)
        # Add end token embeddings back
        embeddings.scatter_(
            1, end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings), end_token_embeddings
        )

        # Now put back start tokens. We can do this before putting back end tokens, but then
        # we need to change `num_removed_non_end_tokens` a little.
        embeddings = torch.cat([start_token_embeddings, embeddings], 1)

        # Truncate to original length
        embeddings = embeddings[:, :num_wordpieces, :]
        return embeddings
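The "mimicking ceiling" comment above uses the integer identity ceil(a / b) == (a + b - 1) // b; a quick check:

import torch

seq_lengths = torch.tensor([5, 512, 513])
max_length = 512
ceil_div = torch.div(seq_lengths + max_length - 1, max_length, rounding_mode="floor")
assert torch.equal(ceil_div, torch.tensor([1, 1, 2]))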
Example #8
    def forward(self,
                indices: torch.LongTensor,
                offsets: Optional[torch.LongTensor] = None,
                per_index_weights: Optional[torch.Tensor] = None):
        """
        Forward process to the embedding bag layer.
        :param indices: Tensor containing bags of indices into the embedding matrix.
        :param offsets: Only used when indices is 1D. offsets determines the starting index position of each bag
        (sequence) in indices.
        :param per_index_weights: a tensor of float / double weights, or None to indicate all weights should be taken
        to be 1. If specified, per_index_weights must have exactly the same shape as indices and is treated as having
        the same offsets, if those are not None.
        :return: a num_bags x embedding_dim Tensor.
        """

        # always move indices to cpu, as we need to get its corresponding minhash values from table in memory
        indices = indices.cpu()

        # Validate the inputs.
        if per_index_weights is not None and indices.size() != per_index_weights.size():
            raise ValueError("embedding_bag: If per_index_weights ({}) is not None, "
                             "then it must have the same shape as the indices ({})"
                             .format(per_index_weights.shape, indices.shape))
        if indices.dim() == 2:
            if offsets is not None:
                raise ValueError("if input is 2D, then offsets has to be None"
                                 ", as input is treated is a mini-batch of"
                                 " fixed length sequences. However, found "
                                 "offsets of type {}".format(type(offsets)))
            offsets = torch.arange(0, indices.numel(), indices.size(1), dtype=torch.long, device=indices.device)
            indices = indices.reshape(-1)
            if per_index_weights is not None:
                per_index_weights = per_index_weights.reshape(-1)
        elif indices.dim() == 1:
            if offsets is None:
                raise ValueError("offsets has to be a 1D Tensor but got None")
            if offsets.dim() != 1:
                raise ValueError("offsets has to be a 1D Tensor")
        else:
            ValueError("input has to be 1D or 2D Tensor,"
                       " but got Tensor of dimension {}".format(input.dim()))

        num_bags = offsets.size(0)

        # get the min-hash for each category value, note that lsh_weight_index is in cpu memory
        lsh_weight_index = self._minhash_table[indices]
        # print("In forward: ", lsh_weight_index, indices, self._minhash_table[indices], self.lsh_weight_size)

        # move the min-hash values to target device
        lsh_weight_index = lsh_weight_index.to(self.hashed_weight.device)
        lsh_weight_index %= self.lsh_weight_size

        # indices_embedding_vector is a |indices| x |embedding_dim| tensor.
        indices_embedding_vectors = self.hashed_weight[lsh_weight_index]
        # print('indices_embedding_vectors: ', lsh_weight_index, indices_embedding_vectors)

        # multiply embedding vectors by weights
        if per_index_weights is not None:
            per_index_weights = per_index_weights.to(indices_embedding_vectors.device)
            indices_embedding_vectors *= per_index_weights[:, None]
        # print("per_index_weights",per_index_weights)
        offsets2bag = make_offset2bag(offsets, indices)
        # print("offsets2bag: ", offsets2bag)
        if self._mode == "sum" or self._mode == "mean":
            result = \
                torch.zeros(num_bags, self.embedding_dim, dtype=indices_embedding_vectors.dtype,
                            device=self.hashed_weight.device)
            result.index_add_(0, offsets2bag, indices_embedding_vectors)
            if self._mode == "sum":
                return result

            # self._mode == "mean":
            bag_size = make_bag_size(offsets, indices).to(result.device)
            result /= bag_size[:, None]
            return result
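make_offset2bag and make_bag_size are helpers of the original project; the mapping they provide can be illustrated with plain torch (a sketch of the idea only, not their implementation):

import torch

indices = torch.tensor([4, 9, 2, 7, 7])   # flat indices across all bags
offsets = torch.tensor([0, 2, 3])         # bag i starts at position offsets[i]
positions = torch.arange(indices.numel())
offsets2bag = torch.bucketize(positions, offsets, right=True) - 1
# tensor([0, 0, 1, 2, 2]): element j belongs to bag offsets2bag[j]
bag_size = torch.diff(torch.cat([offsets, torch.tensor([indices.numel()])]))
# tensor([2, 1, 2])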
Example #9
    def forward(self,
                question: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                label: torch.LongTensor = None,
                binary_labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:

        self._debug -= 1
        input_ids = question['tokens']['token_ids']
        batch_size = input_ids.size(0)
        num_choices = input_ids.size(1)
        num_binary_choices = 1

        # question_mask = (input_ids != self._padding_value).long()
        question_mask = question['tokens']['mask']

        if self._debug > 0:
            logger.info(f"batch_size = {batch_size}")
            logger.info(f"num_choices = {num_choices}")
            logger.info(f"question_mask = {question_mask}")
            logger.info(f"input_ids.size() = {input_ids.size()}")
            logger.info(f"input_ids = {input_ids}")
            logger.info(f"segment_ids = {segment_ids}")
            logger.info(f"label = {label}")
            logger.info(f"binary_labels = {binary_labels}")

        # Segment ids are not used by RoBERTa

        transformer_outputs = self._transformer_model(
            input_ids=util.combine_initial_dims(input_ids),
            # token_type_ids=util.combine_initial_dims(segment_ids),
            attention_mask=util.combine_initial_dims(question_mask))

        cls_output = transformer_outputs[0]

        if self._debug > 0:
            logger.info(f"cls_output = {cls_output}")

        label_logits = self._classifier(cls_output)
        label_logits_binary = label_logits.view(-1, num_binary_choices)
        label_logits = label_logits.view(-1, num_choices)

        output_dict = {}
        output_dict['label_logits'] = label_logits

        if self._binary_loss:
            output_dict['label_probs'] = self._sigmoid(label_logits)
        else:
            output_dict['label_probs'] = torch.nn.functional.softmax(
                label_logits, dim=1)
        output_dict['answer_index'] = label_logits.argmax(1)

        if self._binary_loss and binary_labels is not None:
            labels_float_reshaped = binary_labels.reshape(
                -1, num_binary_choices).to(label_logits.dtype)
            loss = self._loss(label_logits_binary, labels_float_reshaped)
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss
        elif label is not None:
            loss = self._loss(label_logits, label)
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if self._debug > 0:
            logger.info(output_dict)
        return output_dict
Example #10
def sequence_cross_entropy_with_logits(
    logits: torch.FloatTensor,
    targets: torch.LongTensor,
    weights: torch.FloatTensor,
    average: str = "batch",
    label_smoothing: float = None,
    gamma: float = None,
    alpha: Union[float, List[float], torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.
    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    average: str, optional (default = "batch")
        If "batch", average the loss across the batches. If "token", average
        the loss across each item in the input. If ``None``, return a vector
        of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.
    gamma : ``float``, optional (default = None)
        Focal loss[*] focusing parameter ``gamma``, which reduces the relative loss
        for well-classified examples and puts more focus on hard ones. The greater
        ``gamma`` is, the more focus is put on hard examples.
    alpha : ``float`` or ``List[float]``, optional (default = None)
        Focal loss[*] weighting factor ``alpha`` to balance between classes. Can be
        used independently of ``gamma``. If a single ``float`` is provided, the
        binary case is assumed, using ``alpha`` and ``1 - alpha`` for the positive
        and negative classes respectively. If a list of ``float`` is provided, with
        the same length as the number of classes, the weights will match the classes.
        [*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
        Dense Object Detection," 2017 IEEE International Conference on Computer
        Vision (ICCV), Venice, 2017, pp. 2999-3007.
    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``average=="batch"`` or ``average=="token"``, the returned loss is a scalar.
    If ``average is None``, the returned loss is a vector of shape (batch_size,).
    """
    if average not in {None, "token", "batch"}:
        raise ValueError("Got average f{average}, expected one of "
                         "None, 'token', or 'batch'")

    # make sure weights are float
    weights = weights.float()
    # sum all dim except batch
    non_batch_dims = tuple(range(1, len(weights.shape)))
    # shape : (batch_size,)
    weights_batch_sum = weights.sum(dim=non_batch_dims)
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.reshape(-1, 1).long()
    # focal loss coefficient
    if gamma:
        # shape : (batch * sequence_length, num_classes)
        probs_flat = log_probs_flat.exp()
        # shape : (batch * sequence_length,)
        probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
        # shape : (batch * sequence_length,)
        focal_factor = (1.0 - probs_flat)**gamma
        # shape : (batch, sequence_length)
        focal_factor = focal_factor.view(*targets.size())
        weights = weights * focal_factor

    if alpha is not None:
        # shape : () / (num_classes,)
        if isinstance(alpha, (float, int)):

            # shape : (2,)
            alpha_factor = torch.tensor(
                [1.0 - float(alpha), float(alpha)],
                dtype=weights.dtype,
                device=weights.device)

        elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)):

            # shape : (c,)
            alpha_factor = torch.tensor(alpha,
                                        dtype=weights.dtype,
                                        device=weights.device)

            if not alpha_factor.size():
                # shape : (1,)
                alpha_factor = alpha_factor.view(1)
                # shape : (2,)
                alpha_factor = torch.cat([1 - alpha_factor, alpha_factor])
        else:
            raise TypeError(
                ("alpha must be float, list of float, or torch.FloatTensor, "
                 "{} provided.").format(type(alpha)))
        # shape : (batch, max_len)
        alpha_factor = torch.gather(
            alpha_factor, dim=0,
            index=targets_flat.view(-1)).view(*targets.size())
        weights = weights * alpha_factor

    if label_smoothing is not None and label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / num_classes
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
            -1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(
            -1, keepdim=True)
    else:
        # Contribution to the negative log likelihood only comes from the exact indices
        # of the targets, as the target distributions are one-hot. Here we use torch.gather
        # to extract the indices of the num_classes dimension which contribute to the loss.
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = -torch.gather(
            log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(
        *targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights

    if average == "batch":
        # shape : (batch_size,)
        per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (
            weights_batch_sum + 1e-13)
        num_non_empty_sequences = (weights_batch_sum > 0).float().sum() + 1e-13
        return per_batch_loss.sum() / num_non_empty_sequences
    elif average == "token":
        return negative_log_likelihood.sum() / (weights_batch_sum.sum() +
                                                1e-13)
    else:
        # shape : (batch_size,)
        per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (
            weights_batch_sum + 1e-13)
        return per_batch_loss
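A minimal usage sketch of sequence_cross_entropy_with_logits on a padded batch (values are illustrative):

import torch

logits = torch.randn(2, 4, 5)              # (batch, seq_len, num_classes)
targets = torch.randint(0, 5, (2, 4))      # (batch, seq_len)
weights = torch.tensor([[1., 1., 1., 0.],  # trailing steps are padding
                        [1., 1., 0., 0.]])
loss = sequence_cross_entropy_with_logits(logits, targets, weights, average="token")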