def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor, sample_weights: torch.Tensor) -> torch.Tensor:
    """Sample-weighted masked cross-entropy, averaged over valid positions.

    Args:
        pred: Logits of shape (batch, seq_len, num_classes).
        true: Gold class indices of shape (batch, seq_len).
        mask: Boolean mask of shape (batch, seq_len); True marks positions
            that contribute to the loss.
        sample_weights: Per-sample weights of shape (batch,).

    Returns:
        Scalar loss: weighted per-position losses summed, then divided by
        the number of unmasked positions (not by batch size).
    """
    batch_size, _, num_classes = pred.size()
    # Normalisation constant: count of unmasked positions across the batch.
    valid_positions = mask.sum()

    flat_logits = pred.reshape(-1, num_classes)
    flat_targets = true.reshape(-1)
    flat_mask = mask.reshape(-1)

    per_position = utils.masked_cross_entropy(flat_logits, flat_targets, flat_mask)
    # Restore the batch dimension so each row can be scaled by its sample weight.
    weighted = per_position.reshape(batch_size, -1) * sample_weights.unsqueeze(-1)
    return weighted.sum() / valid_positions
def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor, sample_weights: torch.Tensor) -> torch.Tensor:
    """Sample-weighted masked cross-entropy summed over morphological categories.

    Each category in `self.slices` maps to a column slice of the feature
    dimension; the loss is computed per category (against the argmax of the
    one-hot gold slice) and accumulated, skipping the padding placeholders.

    Args:
        pred: Logits of shape (batch, seq_len, total_morphological_features).
        true: One-hot gold labels with the same shape as `pred`.
        mask: Boolean mask of shape (batch, seq_len) marking valid positions.
        sample_weights: Per-sample weights of shape (batch,).

    Returns:
        Scalar loss averaged over the number of unmasked positions.
    """
    assert pred.size() == true.size()
    batch_size, _, num_features = pred.size()
    valid_positions = mask.sum()

    flat_pred = pred.reshape(-1, num_features)
    flat_true = true.reshape(-1, num_features)
    flat_mask = mask.reshape(-1)

    total = None
    for category, indices in self.slices.items():
        # Padding pseudo-categories carry no learnable signal.
        if category in ["__PAD__", "_"]:
            continue
        category_loss = utils.masked_cross_entropy(
            flat_pred[:, indices],
            flat_true[:, indices].argmax(dim=1),
            flat_mask,
        )
        total = category_loss if total is None else total + category_loss

    # Restore the batch dimension so each row can be scaled by its sample weight.
    weighted = total.reshape(batch_size, -1) * sample_weights.unsqueeze(-1)
    return weighted.sum() / valid_positions
def _unfold_long_sequences(
    self,
    embeddings: torch.FloatTensor,
    mask: torch.BoolTensor,
    batch_size: int,
    num_segment_concat_wordpieces: int,
) -> torch.FloatTensor:
    """
    We take 2D segments of a long sequence and flatten them out to get the whole sequence
    representation while remove unnecessary special tokens.

    [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
    -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

    We truncate the start and end tokens for all segments, recombine the segments,
    and manually add back the start and end tokens.

    # Parameters

    embeddings: `torch.FloatTensor`
        Shape: [batch_size * num_segments, self._max_length, embedding_size].
    mask: `torch.BoolTensor`
        Shape: [batch_size * num_segments, self._max_length].
        The mask for the concatenated segments of wordpieces. The same as
        `segment_concat_mask` in `forward()`.
    batch_size: `int`
    num_segment_concat_wordpieces: `int`
        The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
        the original `token_ids.size(1)`.

    # Returns:

    embeddings: `torch.FloatTensor`
        Shape: [batch_size, self._num_wordpieces, embedding_size].
    """

    def lengths_to_mask(lengths, max_len, device):
        # Row i is True for the first lengths[i] positions and False afterwards.
        return torch.arange(max_len, device=device).expand(
            lengths.size(0), max_len
        ) < lengths.unsqueeze(1)

    device = embeddings.device
    # Integer floor division: `/` on ints would produce a float in modern Python/torch.
    num_segments = embeddings.size(0) // batch_size
    embedding_size = embeddings.size(2)

    # We want to remove all segment-level special tokens but maintain sequence-level ones.
    num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * self._num_added_tokens

    embeddings = embeddings.reshape(
        batch_size, num_segments * self._max_length, embedding_size
    )
    mask = mask.reshape(batch_size, num_segments * self._max_length)
    # We assume that all 1s in the mask precede all 0s, and add an assert for that.
    # Open an issue on GitHub if this breaks for you.
    # Shape: (batch_size,)
    seq_lengths = mask.sum(-1)
    if not (lengths_to_mask(seq_lengths, mask.size(1), device) == mask).all():
        raise ValueError(
            "Long sequence splitting only supports masks with all 1s preceding all 0s."
        )
    # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op.
    end_token_indices = (
        seq_lengths.unsqueeze(-1)
        - torch.arange(self._num_added_end_tokens, device=device)
        - 1
    )

    # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
    start_token_embeddings = embeddings[:, : self._num_added_start_tokens, :]
    # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
    end_token_embeddings = batched_index_select(embeddings, end_token_indices)

    embeddings = embeddings.reshape(
        batch_size, num_segments, self._max_length, embedding_size
    )
    # Truncate segment-level start/end tokens, then flatten back to one sequence.
    embeddings = embeddings[
        :, :, self._num_added_start_tokens : -self._num_added_end_tokens, :
    ]
    embeddings = embeddings.reshape(batch_size, -1, embedding_size)

    # Now try to put end token embeddings back, which is a little tricky.

    # The number of segments each sequence spans, excluding padding. Mimicking
    # ceiling operation via floor division.
    # BUG FIX: the original used `/`, which is true division on integer tensors
    # in torch >= 1.5 and yields a float tensor; the integer index arithmetic
    # below (in-place subtraction and `scatter_`) requires integer indices.
    # Shape: (batch_size,)
    num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
    # The number of indices that end tokens should shift back.
    num_removed_non_end_tokens = (
        num_effective_segments * self._num_added_tokens - self._num_added_end_tokens
    )
    # Shape: (batch_size, self._num_added_end_tokens)
    end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
    assert (end_token_indices >= self._num_added_start_tokens).all()
    # Add space for end embeddings.
    embeddings = torch.cat([embeddings, torch.zeros_like(end_token_embeddings)], 1)
    # Add end token embeddings back.
    embeddings.scatter_(
        1,
        end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings),
        end_token_embeddings,
    )

    # Now put back start tokens. We can do this before putting back end tokens, but then
    # we need to change `num_removed_non_end_tokens` a little.
    embeddings = torch.cat([start_token_embeddings, embeddings], 1)

    # Truncate to original length.
    embeddings = embeddings[:, :num_wordpieces, :]
    return embeddings
def forward(
    self,
    sdf: Callable[[torch.Tensor], torch.Tensor],
    cam_loc: torch.Tensor,
    object_mask: torch.BoolTensor,
    ray_directions: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Find ray/surface intersections for an implicit surface defined by `sdf`,
    using sphere tracing first and falling back to a sampler for rays that
    did not converge. In training mode, additionally fills in points for
    rays the sampler skipped so every ray has a defined point/distance.

    Args:
        sdf: A callable that takes a (N, 3) tensor of points and returns
            a tensor of (N,) SDF values.
        cam_loc: A tensor of (B, N, 3) ray origins.
        object_mask: A (N, 3) tensor of indicators whether a sampled pixel
            corresponds to the rendered object or background.
        ray_directions: A tensor of (B, N, 3) ray directions.

    Returns:
        curr_start_points: A tensor of (B*N, 3) found intersection points
            with the implicit surface.
        network_object_mask: A tensor of (B*N,) indicators denoting whether
            intersections were found.
        acc_start_dis: A tensor of (B*N,) distances from the ray origins
            to intersrection points.
    """
    batch_size, num_pixels, _ = ray_directions.shape
    device = cam_loc.device

    # Restrict tracing to rays that hit the object's bounding sphere;
    # `sphere_intersections` holds the near/far intersection distances.
    sphere_intersections, mask_intersect = _get_sphere_intersection(
        cam_loc, ray_directions, r=self.object_bounding_sphere)

    # Sphere-trace from both ends of the bounding-sphere interval
    # (presumably front-to-back and back-to-front marches — see sphere_tracing).
    (
        curr_start_points,
        unfinished_mask_start,
        acc_start_dis,
        acc_end_dis,
        min_dis,
        max_dis,
    ) = self.sphere_tracing(
        batch_size,
        num_pixels,
        sdf,
        cam_loc,
        ray_directions,
        mask_intersect,
        sphere_intersections,
    )

    # A ray found the surface if its start-side march stopped before the end-side one.
    network_object_mask = acc_start_dis < acc_end_dis

    # The non convergent rays should be handled by the sampler
    sampler_mask = unfinished_mask_start
    sampler_net_obj_mask = torch.zeros_like(sampler_mask,
                                            dtype=torch.bool,
                                            device=device)
    if sampler_mask.sum() > 0:
        # Per-ray search interval [start, end] for the sampler, flattened to (B*N, 2).
        sampler_min_max = torch.zeros((batch_size, num_pixels, 2),
                                      device=device)
        sampler_min_max.reshape(-1,
                                2)[sampler_mask,
                                   0] = acc_start_dis[sampler_mask]
        sampler_min_max.reshape(-1,
                                2)[sampler_mask,
                                   1] = acc_end_dis[sampler_mask]

        sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(
            sdf, cam_loc, object_mask, ray_directions, sampler_min_max,
            sampler_mask)

        # Overwrite the non-converged rays with the sampler's results (in place).
        curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
        acc_start_dis[sampler_mask] = sampler_dists[sampler_mask]
        network_object_mask[sampler_mask] = sampler_net_obj_mask[
            sampler_mask]

    if not self.training:
        return curr_start_points, network_object_mask, acc_start_dis

    # in case we are training, we are updating curr_start_points and acc_start_dis for
    # rays not handled above, so the loss has a defined point for every ray.
    ray_directions = ray_directions.reshape(-1, 3)
    mask_intersect = mask_intersect.reshape(-1)
    object_mask = object_mask.reshape(-1)

    # in_mask: rays labeled "object" but with no intersection found;
    # out_mask: rays labeled "background"; both exclude sampler-handled rays.
    in_mask = ~network_object_mask & object_mask & ~sampler_mask
    out_mask = ~object_mask & ~sampler_mask

    # pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
    mask_left_out = (in_mask | out_mask) & ~mask_intersect
    if (mask_left_out.sum() > 0
        ):  # project the origin to the not intersect points on the sphere
        cam_left_out = cam_loc.reshape(-1, 3)[mask_left_out]
        rays_left_out = ray_directions[mask_left_out]
        # Distance of the closest point on each ray to the world origin:
        # t = -(d . o), i.e. projection of the origin onto the ray.
        acc_start_dis[mask_left_out] = -torch.bmm(
            rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3,
                                                            1)).squeeze()
        curr_start_points[mask_left_out] = (
            cam_left_out +
            acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out)

    # Rays that do intersect the bounding sphere but still need a point:
    # pick the point of minimal SDF value along each ray within [min_dis, max_dis].
    mask = (in_mask | out_mask) & mask_intersect

    if mask.sum() > 0:
        # For found-but-background rays, do not search past the found surface.
        min_dis[network_object_mask
                & out_mask] = acc_start_dis[network_object_mask & out_mask]

        min_mask_points, min_mask_dist = self.minimal_sdf_points(
            sdf, cam_loc, ray_directions, mask, min_dis, max_dis)

        curr_start_points[mask] = min_mask_points
        acc_start_dis[mask] = min_mask_dist

    return curr_start_points, network_object_mask, acc_start_dis