def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
    """Roll tensor with different shifts for each row.

    Note:
      We assume `src` is a 3-dimensional tensor and roll the last dimension.

    Example:
      >>> src = torch.arange(15).reshape((1, 3, 5))
      >>> src
      tensor([[[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14]]])
      >>> shift = torch.tensor([[1, 2, 3]])
      >>> shift
      tensor([[1, 2, 3]])
      >>> _roll_by_shifts(src, shift)
      tensor([[[ 4,  0,  1,  2,  3],
               [ 8,  9,  5,  6,  7],
               [12, 13, 14, 10, 11]]])
    """
    assert src.dim() == 3
    (B, T, S) = src.shape
    assert shifts.shape == (B, T)

    index = (
        torch.arange(S, device=src.device)
        .view((1, S))
        .repeat((T, 1))
        .repeat((B, 1, 1))
    )
    index = (index - shifts.reshape(B, T, 1)) % S
    return torch.gather(src, 2, index)
def fit(self, s, a_index, Q, critic_loss_coef, entropy_coef):
    self.net.train()
    s = tensor(s, dtype=float)
    a_index = LongTensor(a_index.reshape((-1, 1)))
    Q = tensor(Q, dtype=float)
    output_V, output_pi = self.net(s.float())  # get state value V and policy π
    log_prob = (output_pi.gather(1, a_index).log()).view(-1)  # log-probability of the taken actions
    adv = Q - output_V.view(-1)  # advantage estimate
    actor_loss = -(adv.detach() * log_prob).mean()  # actor loss from the policy gradient theorem
    critic_loss = critic_loss_coef * adv.pow(2).mean()  # critic loss as mean squared error
    entropy = entropy_coef * (output_pi * output_pi.log()).sum(axis=1).mean()  # policy entropy term (Σ p·log p)
    total_loss = actor_loss + critic_loss - entropy
    self.optim.zero_grad()
    total_loss.backward()
    utils.clip_grad_norm_(self.net.parameters(), 0.5)  # clip gradients to keep updates small
    self.optim.step()
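# A minimal, self-contained sketch of the same actor-critic loss computation on
# dummy tensors. The toy network, value head, shapes, and coefficients below are
# illustrative assumptions, not part of the original class.
import torch
import torch.nn as nn

policy_head = nn.Linear(4, 3)   # toy policy over 3 actions
value_head = nn.Linear(4, 1)    # toy value function
s = torch.randn(8, 4)
a_index = torch.randint(0, 3, (8, 1))
Q = torch.randn(8)

output_pi = torch.softmax(policy_head(s), dim=1)
output_V = value_head(s).view(-1)
log_prob = output_pi.gather(1, a_index).log().view(-1)
adv = Q - output_V
actor_loss = -(adv.detach() * log_prob).mean()
critic_loss = 0.5 * adv.pow(2).mean()
entropy = 0.01 * (output_pi * output_pi.log()).sum(dim=1).mean()
total_loss = actor_loss + critic_loss - entropy
total_loss.backward()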
def forward(self, x: torch.LongTensor):
    # Character-level mask (index 1 is padding); rows that are entirely padding
    # get one valid position so pack_padded_sequence does not fail.
    mask = x != 1
    mask = mask.reshape(-1, mask.shape[-1])
    mask[torch.sum(mask, dim=1) == 0, 0] = 1

    x = self.embedding[x].to(self.device)
    batch_size, seq_len, max_char_num, vector_size = x.shape
    x = x.reshape(-1, max_char_num, vector_size)
    x = self.dropout_layer(x)

    # Run the character encoder over the packed character sequences.
    x = nn.utils.rnn.pack_padded_sequence(
        x, mask.sum(1).int(), batch_first=True, enforce_sorted=False)
    h, _ = self.char_encoder(x, None)
    h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first=True)

    # Sum the forward-direction output at the first timestep with the
    # backward-direction output at the last timestep of the encoder.
    h = h[:, 0, :self.hidden_size] + h[:, -1, self.hidden_size:]
    embed = h.reshape(batch_size, seq_len, -1)
    return embed
def forward(self, scores: _torch.FloatTensor, relevance: _torch.LongTensor,
            n: _torch.LongTensor) -> _torch.FloatTensor:
    """Computes the loss for given batch of samples.

    Args:
        scores: A batch of per-query-document scores.
        relevance: A batch of per-query-document relevance labels.
        n: A batch of per-query number of documents (for padding purposes).
    """
    # Reshape relevance if necessary.
    if relevance.ndimension() == 2:
        relevance = relevance.reshape(
            (relevance.shape[0], relevance.shape[1], 1))
    if scores.ndimension() == 2:
        scores = scores.reshape((scores.shape[0], scores.shape[1], 1))

    # Compute ranking and sort scores and relevance
    ranking = _rank_by_score(scores, n)
    ranking = ranking.view((ranking.shape[0], ranking.shape[1], 1))
    scores = _torch.gather(scores, 1, ranking)
    relevance = _torch.gather(relevance, 1, ranking)

    # Compute pairwise differences for scores and relevances.
    score_pairs = _batch_pairs(scores)
    rel_pairs = _batch_pairs(relevance)

    # Compute loss per doc pair.
    loss_pairs = self._loss_per_doc_pair(score_pairs, rel_pairs, n)

    # Mask out padded documents per query in the batch
    n_grid = n[:, None, None].repeat(1, score_pairs.shape[1],
                                     score_pairs.shape[2])
    arange = _torch.arange(score_pairs.shape[1], device=score_pairs.device)
    range_grid = _torch.max(*_torch.meshgrid([arange, arange]))
    range_grid = range_grid[None, :, :].repeat(n.shape[0], 1, 1)
    loss_pairs[n_grid <= range_grid] = 0.0

    # Reduce final list loss from per doc pair loss to a per query loss.
    loss = self._loss_reduction(loss_pairs)

    # Return loss
    return loss
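# A hedged illustration of the pair-masking step above on dummy data: for each
# query, any document pair (i, j) with max(i, j) >= n (the number of real
# documents) is zeroed out. Shapes and values here are assumptions for the
# sketch, not the library's own helpers.
import torch

batch_size, list_size = 2, 4
n = torch.tensor([4, 2])                                     # real documents per query
loss_pairs = torch.ones(batch_size, list_size, list_size)

arange = torch.arange(list_size)
range_grid = torch.max(*torch.meshgrid([arange, arange]))    # element-wise max(i, j)
range_grid = range_grid[None, :, :].repeat(batch_size, 1, 1)
n_grid = n[:, None, None].repeat(1, list_size, list_size)

loss_pairs[n_grid <= range_grid] = 0.0
print(loss_pairs[1])   # only the top-left 2x2 block stays non-zero for the second query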
def mask_padded_values(xs: _torch.FloatTensor, n: _torch.LongTensor,
                       mask_value: float = -float('inf'),
                       mutate: bool = False):
    """Turns padded values into given mask value.

    Args:
        xs: A tensor of size (batch_size, list_size, 1) containing padded values.
        n: A tensor of size (batch_size) containing list size of each query.
        mask_value: The value to mask with (default: -inf).
        mutate: Whether to mutate the values of xs or return a copy.
    """
    mask = _torch.repeat_interleave(
        _torch.arange(xs.shape[1], device=xs.device).reshape((1, xs.shape[1])),
        xs.shape[0], dim=0)
    n_mask = _torch.repeat_interleave(
        n.reshape((n.shape[0], 1)), xs.shape[1], dim=1)
    if not mutate:
        xs = xs.clone()
    xs[mask >= n_mask] = mask_value
    return xs
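# A small usage sketch for mask_padded_values with dummy data. A 2-D (batch_size,
# list_size) tensor is used here; the boolean indexing above also covers the
# documented (batch_size, list_size, 1) shape.
import torch as _torch

scores = _torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])
n = _torch.tensor([3, 1])          # second query has only one real document
masked = mask_padded_values(scores, n)
# masked -> [[1., 2., 3.],
#            [4., -inf, -inf]]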
def get_loss(
    self,
    rule_probs: torch.FloatTensor,
    target_rules: torch.LongTensor,
    target_mask: torch.FloatTensor,
):
    """
    :param rule_probs (batch_size, target_length, num_rules)
    :param target_rules (batch_size, target_length)
    :param target_mask (batch_size, target_length)
    """
    batch_size, target_length = target_rules.size()

    # Pick the probability of the gold rule at each decoding step.
    rule_probs = torch.gather(
        rule_probs.reshape(-1, self._num_rules),
        dim=1,
        index=target_rules.reshape(-1).unsqueeze(-1).long())
    rule_probs = rule_probs.reshape(batch_size, target_length)

    # Length-normalized negative log-likelihood over non-padded steps.
    rule_log_probs = (rule_probs + 1e-10).log()
    rule_log_probs *= target_mask.float()
    rule_normalize_factor = target_mask.sum(-1)
    rule_normalize_factor[rule_normalize_factor == 0] = 1
    rule_loss = rule_log_probs.sum(-1) / rule_normalize_factor.float()
    rule_loss = -1 * (rule_loss.sum() / batch_size)
    return rule_loss
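# Minimal standalone sketch of the gather-based selection and masked,
# length-normalized NLL used above, on dummy tensors (all shapes and values
# below are illustrative assumptions).
import torch

batch_size, target_length, num_rules = 2, 3, 5
rule_probs = torch.softmax(torch.randn(batch_size, target_length, num_rules), dim=-1)
target_rules = torch.randint(0, num_rules, (batch_size, target_length))
target_mask = torch.tensor([[1., 1., 1.],
                            [1., 1., 0.]])   # second sequence has one padded step

gold_probs = torch.gather(
    rule_probs.reshape(-1, num_rules), dim=1,
    index=target_rules.reshape(-1, 1)).reshape(batch_size, target_length)
nll = -((gold_probs + 1e-10).log() * target_mask).sum(-1) / target_mask.sum(-1)
loss = nll.mean()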
def _unfold_long_sequences(
    self,
    embeddings: torch.FloatTensor,
    mask: torch.LongTensor,
    batch_size: int,
    num_segment_concat_wordpieces: int,
) -> torch.FloatTensor:
    """
    We take 2D segments of a long sequence and flatten them out to get the whole
    sequence representation while removing unnecessary special tokens.

    [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
    -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

    We truncate the start and end tokens for all segments, recombine the segments,
    and manually add back the start and end tokens.

    # Parameters

    embeddings: `torch.FloatTensor`
        Shape: [batch_size * num_segments, self._max_length, embedding_size].
    mask: `torch.LongTensor`
        Shape: [batch_size * num_segments, self._max_length].
        The mask for the concatenated segments of wordpieces. The same as
        `segment_concat_mask` in `forward()`.
    batch_size: `int`
    num_segment_concat_wordpieces: `int`
        The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
        the original `token_ids.size(1)`.

    # Returns:

    embeddings: `torch.FloatTensor`
        Shape: [batch_size, self._num_wordpieces, embedding_size].
    """

    def lengths_to_mask(lengths, max_len, device):
        return torch.arange(max_len, device=device).expand(
            lengths.size(0), max_len
        ) < lengths.unsqueeze(1)

    device = embeddings.device
    num_segments = int(embeddings.size(0) / batch_size)
    embedding_size = embeddings.size(2)

    # We want to remove all segment-level special tokens but maintain sequence-level ones
    num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * self._num_added_tokens

    embeddings = embeddings.reshape(batch_size, num_segments * self._max_length, embedding_size)
    mask = mask.reshape(batch_size, num_segments * self._max_length)
    # We assume that all 1s in the mask precede all 0s, and add an assert for that.
    # Open an issue on GitHub if this breaks for you.
    # Shape: (batch_size,)
    seq_lengths = mask.sum(-1)
    if not (lengths_to_mask(seq_lengths, mask.size(1), device) == mask).all():
        raise ValueError(
            "Long sequence splitting only supports masks with all 1s preceding all 0s."
        )
    # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
    end_token_indices = (
        seq_lengths.unsqueeze(-1) - torch.arange(self._num_added_end_tokens, device=device) - 1
    )

    # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
    start_token_embeddings = embeddings[:, : self._num_added_start_tokens, :]
    # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
    end_token_embeddings = batched_index_select(embeddings, end_token_indices)

    embeddings = embeddings.reshape(batch_size, num_segments, self._max_length, embedding_size)
    embeddings = embeddings[
        :, :, self._num_added_start_tokens : -self._num_added_end_tokens, :
    ]  # truncate segment-level start/end tokens
    embeddings = embeddings.reshape(batch_size, -1, embedding_size)  # flatten

    # Now try to put end token embeddings back which is a little tricky.

    # The number of segments each sequence spans, excluding padding. Mimicking ceiling operation.
    # Shape: (batch_size,)
    num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
    # The number of indices that end tokens should shift back.
    num_removed_non_end_tokens = (
        num_effective_segments * self._num_added_tokens - self._num_added_end_tokens
    )
    # Shape: (batch_size, self._num_added_end_tokens)
    end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
    assert (end_token_indices >= self._num_added_start_tokens).all()
    # Add space for end embeddings
    embeddings = torch.cat([embeddings, torch.zeros_like(end_token_embeddings)], 1)
    # Add end token embeddings back
    embeddings.scatter_(
        1, end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings), end_token_embeddings
    )

    # Now put back start tokens. We can do this before putting back end tokens, but then
    # we need to change `num_removed_non_end_tokens` a little.
    embeddings = torch.cat([start_token_embeddings, embeddings], 1)

    # Truncate to original length
    embeddings = embeddings[:, :num_wordpieces, :]
    return embeddings
def forward(self, indices: torch.LongTensor,
            offsets: Optional[torch.LongTensor] = None,
            per_index_weights: Optional[torch.Tensor] = None):
    """
    Forward process to the embedding bag layer.
    :param indices: Tensor containing bags of indices into the embedding matrix.
    :param offsets: Only used when indices is 1D. offsets determines the starting
        index position of each bag (sequence) in input.
    :param per_index_weights: a tensor of float / double weights, or None to indicate
        all weights should be taken to be 1. If specified, per_index_weights must have
        exactly the same shape as input and is treated as having the same offsets,
        if those are not None.
    :return: an #bag x embedding_dim Tensor.
    """
    # Always move indices to cpu, as we need to get their corresponding minhash
    # values from the table in memory.
    indices = indices.cpu()
    # Check input validation.
    if per_index_weights is not None and indices.size() != per_index_weights.size():
        raise ValueError("embedding_bag: If per_index_weights ({}) is not None, "
                         "then it must have the same shape as the indices ({})"
                         .format(per_index_weights.shape, indices.shape))
    if indices.dim() == 2:
        if offsets is not None:
            raise ValueError("if input is 2D, then offsets has to be None"
                             ", as input is treated as a mini-batch of"
                             " fixed length sequences. However, found "
                             "offsets of type {}".format(type(offsets)))
        offsets = torch.arange(0, indices.numel(), indices.size(1),
                               dtype=torch.long, device=indices.device)
        indices = indices.reshape(-1)
        if per_index_weights is not None:
            per_index_weights = per_index_weights.reshape(-1)
    elif indices.dim() == 1:
        if offsets is None:
            raise ValueError("offsets has to be a 1D Tensor but got None")
        if offsets.dim() != 1:
            raise ValueError("offsets has to be a 1D Tensor")
    else:
        raise ValueError("input has to be 1D or 2D Tensor,"
                         " but got Tensor of dimension {}".format(indices.dim()))

    num_bags = offsets.size(0)

    # Get the min-hash for each category value; note that lsh_weight_index is in cpu memory.
    lsh_weight_index = self._minhash_table[indices]
    # Move the min-hash values to the target device.
    lsh_weight_index = lsh_weight_index.to(self.hashed_weight.device)
    lsh_weight_index %= self.lsh_weight_size

    # indices_embedding_vectors is a |indices| x |embedding_dim| tensor.
    indices_embedding_vectors = self.hashed_weight[lsh_weight_index]

    # Multiply embedding vectors by weights.
    if per_index_weights is not None:
        per_index_weights = per_index_weights.to(indices_embedding_vectors.device)
        indices_embedding_vectors *= per_index_weights[:, None]

    offsets2bag = make_offset2bag(offsets, indices)

    if self._mode == "sum" or self._mode == "mean":
        result = torch.zeros(num_bags, self.embedding_dim,
                             dtype=indices_embedding_vectors.dtype,
                             device=self.hashed_weight.device)
        result.index_add_(0, offsets2bag, indices_embedding_vectors)
        if self._mode == "sum":
            return result
        # self._mode == "mean":
        bag_size = make_bag_size(offsets, indices).to(result.device)
        result /= bag_size[:, None]
        return result
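# Illustration on dummy data of how offsets map flat indices into bags and how
# index_add_ performs the per-bag ("sum" mode) reduction. make_offset2bag is the
# module's own helper, so an equivalent is built inline here purely for the sketch.
import torch

embedding_dim = 3
indices = torch.tensor([4, 1, 7, 2, 9])        # 5 flat category indices
offsets = torch.tensor([0, 2])                 # bag 0 = indices[0:2], bag 1 = indices[2:5]
vectors = torch.randn(indices.numel(), embedding_dim)  # stand-in for the hashed embeddings

# offsets2bag[i] = which bag the i-th index belongs to -> [0, 0, 1, 1, 1]
offsets2bag = torch.zeros(indices.numel(), dtype=torch.long)
offsets2bag[offsets[1:]] = 1
offsets2bag = offsets2bag.cumsum(0)

result = torch.zeros(offsets.numel(), embedding_dim)
result.index_add_(0, offsets2bag, vectors)     # per-bag sum of the embedding vectors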
def forward(self,
            question: Dict[str, torch.LongTensor],
            segment_ids: torch.LongTensor = None,
            label: torch.LongTensor = None,
            binary_labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
    self._debug -= 1
    input_ids = question['tokens']['token_ids']
    batch_size = input_ids.size(0)
    num_choices = input_ids.size(1)
    num_binary_choices = 1

    # question_mask = (input_ids != self._padding_value).long()
    question_mask = question['tokens']['mask']

    if self._debug > 0:
        logger.info(f"batch_size = {batch_size}")
        logger.info(f"num_choices = {num_choices}")
        logger.info(f"question_mask = {question_mask}")
        logger.info(f"input_ids.size() = {input_ids.size()}")
        logger.info(f"input_ids = {input_ids}")
        logger.info(f"segment_ids = {segment_ids}")
        logger.info(f"label = {label}")
        logger.info(f"binary_labels = {binary_labels}")

    # Segment ids are not used by RoBERTa
    transformer_outputs = self._transformer_model(
        input_ids=util.combine_initial_dims(input_ids),
        # token_type_ids=util.combine_initial_dims(segment_ids),
        attention_mask=util.combine_initial_dims(question_mask))
    cls_output = transformer_outputs[0]

    if self._debug > 0:
        logger.info(f"cls_output = {cls_output}")

    label_logits = self._classifier(cls_output)
    label_logits_binary = label_logits.view(-1, num_binary_choices)
    label_logits = label_logits.view(-1, num_choices)

    output_dict = {}
    output_dict['label_logits'] = label_logits
    if self._binary_loss:
        output_dict['label_probs'] = self._sigmoid(label_logits)
    else:
        output_dict['label_probs'] = torch.nn.functional.softmax(label_logits, dim=1)
    output_dict['answer_index'] = label_logits.argmax(1)

    if self._binary_loss and binary_labels is not None:
        labels_float_reshaped = binary_labels.reshape(-1, num_binary_choices).to(label_logits.dtype)
        loss = self._loss(label_logits_binary, labels_float_reshaped)
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss
    elif label is not None:
        loss = self._loss(label_logits, label)
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    if self._debug > 0:
        logger.info(output_dict)

    return output_dict
def sequence_cross_entropy_with_logits(
    logits: torch.FloatTensor,
    targets: torch.LongTensor,
    weights: torch.FloatTensor,
    average: str = "batch",
    label_smoothing: float = None,
    gamma: float = None,
    alpha: Union[float, List[float], torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    average : str, optional (default = "batch")
        If "batch", average the loss across the batches. If "token", average
        the loss across each item in the input. If ``None``, return a vector
        of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.
    gamma : ``float``, optional (default = None)
        Focal loss[*] focusing parameter ``gamma`` to reduce the relative loss for
        well-classified examples and put more focus on hard ones. The greater the
        value of ``gamma``, the more focus is put on hard examples.
    alpha : ``float`` or ``List[float]``, optional (default = None)
        Focal loss[*] weighting factor ``alpha`` to balance between classes. Can be
        used independently of ``gamma``. If a single ``float`` is provided, a binary
        case is assumed, using ``alpha`` and ``1 - alpha`` for positive and negative
        examples respectively. If a list of ``float`` is provided, with the same
        length as the number of classes, the weights will match the classes.

        [*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
        Dense Object Detection," 2017 IEEE International Conference on Computer
        Vision (ICCV), Venice, 2017, pp. 2999-3007.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``average=="batch"`` or ``average=="token"``, the returned loss is a scalar.
    If ``average is None``, the returned loss is a vector of shape (batch_size,).
""" if average not in {None, "token", "batch"}: raise ValueError("Got average f{average}, expected one of " "None, 'token', or 'batch'") # make sure weights are float weights = weights.float() # sum all dim except batch non_batch_dims = tuple(range(1, len(weights.shape))) # shape : (batch_size,) weights_batch_sum = weights.sum(dim=non_batch_dims) # shape : (batch * sequence_length, num_classes) logits_flat = logits.view(-1, logits.size(-1)) # shape : (batch * sequence_length, num_classes) log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1) # shape : (batch * max_len, 1) targets_flat = targets.reshape(-1, 1).long() # focal loss coefficient if gamma: # shape : (batch * sequence_length, num_classes) probs_flat = log_probs_flat.exp() # shape : (batch * sequence_length,) probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat) # shape : (batch * sequence_length,) focal_factor = (1.0 - probs_flat)**gamma # shape : (batch, sequence_length) focal_factor = focal_factor.view(*targets.size()) weights = weights * focal_factor if alpha is not None: # shape : () / (num_classes,) if isinstance(alpha, (float, int)): # shape : (2,) alpha_factor = torch.tensor( [1.0 - float(alpha), float(alpha)], dtype=weights.dtype, device=weights.device) elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)): # shape : (c,) alpha_factor = torch.tensor(alpha, dtype=weights.dtype, device=weights.device) if not alpha_factor.size(): # shape : (1,) alpha_factor = alpha_factor.view(1) # shape : (2,) alpha_factor = torch.cat([1 - alpha_factor, alpha_factor]) else: raise TypeError( ("alpha must be float, list of float, or torch.FloatTensor, " "{} provided.").format(type(alpha))) # shape : (batch, max_len) alpha_factor = torch.gather( alpha_factor, dim=0, index=targets_flat.view(-1)).view(*targets.size()) weights = weights * alpha_factor if label_smoothing is not None and label_smoothing > 0.0: num_classes = logits.size(-1) smoothing_value = label_smoothing / num_classes # Fill all the correct indices with 1 - smoothing value. one_hot_targets = torch.zeros_like(log_probs_flat).scatter_( -1, targets_flat, 1.0 - label_smoothing) smoothed_targets = one_hot_targets + smoothing_value negative_log_likelihood_flat = -log_probs_flat * smoothed_targets negative_log_likelihood_flat = negative_log_likelihood_flat.sum( -1, keepdim=True) else: # Contribution to the negative log likelihood only comes from the exact indices # of the targets, as the target distributions are one-hot. Here we use torch.gather # to extract the indices of the num_classes dimension which contribute to the loss. # shape : (batch * sequence_length, 1) negative_log_likelihood_flat = -torch.gather( log_probs_flat, dim=1, index=targets_flat) # shape : (batch, sequence_length) negative_log_likelihood = negative_log_likelihood_flat.view( *targets.size()) # shape : (batch, sequence_length) negative_log_likelihood = negative_log_likelihood * weights if average == "batch": # shape : (batch_size,) per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / ( weights_batch_sum + 1e-13) num_non_empty_sequences = (weights_batch_sum > 0).float().sum() + 1e-13 return per_batch_loss.sum() / num_non_empty_sequences elif average == "token": return negative_log_likelihood.sum() / (weights_batch_sum.sum() + 1e-13) else: # shape : (batch_size,) per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / ( weights_batch_sum + 1e-13) return per_batch_loss