Code Example #1
File: base.py Project: ipipan/combo
    def _loss(self, pred: torch.Tensor, true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        BATCH_SIZE, _, CLASSES = pred.size()
        valid_positions = mask.sum()
        # Flatten the batch and sequence dimensions for the per-position loss.
        pred = pred.reshape(-1, CLASSES)
        true = true.reshape(-1)
        mask = mask.reshape(-1)
        loss = utils.masked_cross_entropy(pred, true, mask)
        # Reweight each sequence, then average over valid (unmasked) positions.
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions
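
Both `_loss` variants in this listing delegate to `utils.masked_cross_entropy`, whose definition is not part of the excerpt. The call sites imply the shape contract sketched below; the body is only a minimal sketch of what such a helper might look like, not the project's actual implementation:

import torch
import torch.nn.functional as F


def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor,
                         mask: torch.BoolTensor) -> torch.Tensor:
    """Per-position cross-entropy with invalid positions zeroed out.

    pred: (N, C) logits, true: (N,) gold class indices, mask: (N,)
    validity mask. Returns per-position losses of shape (N,), so the
    caller can reweight them (e.g. by sample_weights) before reducing.
    """
    loss = F.cross_entropy(pred, true, reduction="none")  # shape: (N,)
    return loss * mask  # zero out padded / invalid positions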
Code Example #2
File: morpho.py Project: ipipan/combo
    def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        assert pred.size() == true.size()
        BATCH_SIZE, _, MORPHOLOGICAL_FEATURES = pred.size()

        valid_positions = mask.sum()

        pred = pred.reshape(-1, MORPHOLOGICAL_FEATURES)
        true = true.reshape(-1, MORPHOLOGICAL_FEATURES)
        mask = mask.reshape(-1)
        loss = None
        loss_func = utils.masked_cross_entropy
        # Accumulate one cross-entropy term per morphological category,
        # skipping the padding pseudo-categories.
        for cat, cat_indices in self.slices.items():
            if cat not in ["__PAD__", "_"]:
                if loss is None:
                    loss = loss_func(pred[:, cat_indices],
                                     true[:, cat_indices].argmax(dim=1),
                                     mask)
                else:
                    loss += loss_func(pred[:, cat_indices],
                                      true[:, cat_indices].argmax(dim=1),
                                      mask)
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions
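
In this morphological variant, `self.slices` maps each feature category to the column indices it occupies in the concatenated prediction tensor, and one cross-entropy term is accumulated per category. The layout below is a purely hypothetical illustration (the category names and index ranges are invented) of how the indexing in the loop works:

# Hypothetical column layout over 8 output dimensions:
slices = {
    "Case":    [0, 1, 2, 3],  # e.g. Nom/Gen/Dat/Acc one-hot columns
    "Number":  [4, 5],        # e.g. Sing/Plur
    "__PAD__": [6],           # skipped by the loop above
    "_":       [7],           # skipped by the loop above
}
# For each real category, pred[:, cat_indices] is an (N, |cat|) block of
# logits, and true[:, cat_indices].argmax(dim=1) recovers the gold class
# index within that category, so every category contributes an independent
# cross-entropy term to the summed loss.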
Code Example #3
    def _unfold_long_sequences(
        self,
        embeddings: torch.FloatTensor,
        mask: torch.BoolTensor,
        batch_size: int,
        num_segment_concat_wordpieces: int,
    ) -> torch.FloatTensor:
        """
        We take 2D segments of a long sequence and flatten them out to get the whole sequence
        representation while remove unnecessary special tokens.

        [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
        -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

        We truncate the start and end tokens for all segments, recombine the segments,
        and manually add back the start and end tokens.

        # Parameters

        embeddings: `torch.FloatTensor`
            Shape: [batch_size * num_segments, self._max_length, embedding_size].
        mask: `torch.BoolTensor`
            Shape: [batch_size * num_segments, self._max_length].
            The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
            in `forward()`.
        batch_size: `int`
        num_segment_concat_wordpieces: `int`
            The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
            the original `token_ids.size(1)`.

        # Returns

        embeddings: `torch.FloatTensor`
            Shape: [batch_size, self._num_wordpieces, embedding_size].
        """
        def lengths_to_mask(lengths, max_len, device):
            return torch.arange(max_len, device=device).expand(
                lengths.size(0), max_len) < lengths.unsqueeze(1)

        device = embeddings.device
        num_segments = int(embeddings.size(0) / batch_size)
        embedding_size = embeddings.size(2)

        # We want to remove all segment-level special tokens but maintain sequence-level ones
        num_wordpieces = num_segment_concat_wordpieces - (
            num_segments - 1) * self._num_added_tokens

        embeddings = embeddings.reshape(batch_size,
                                        num_segments * self._max_length,
                                        embedding_size)
        mask = mask.reshape(batch_size, num_segments * self._max_length)
        # We assume that all 1s in the mask precede all 0s, and verify this below.
        # Open an issue on GitHub if this breaks for you.
        # Shape: (batch_size,)
        seq_lengths = mask.sum(-1)
        if not (lengths_to_mask(seq_lengths, mask.size(1), device)
                == mask).all():
            raise ValueError(
                "Long sequence splitting only supports masks with all 1s preceding all 0s."
            )
        # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
        end_token_indices = (
            seq_lengths.unsqueeze(-1) -
            torch.arange(self._num_added_end_tokens, device=device) - 1)

        # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
        start_token_embeddings = embeddings[:, :self._num_added_start_tokens, :]
        # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
        end_token_embeddings = batched_index_select(embeddings,
                                                    end_token_indices)

        embeddings = embeddings.reshape(batch_size, num_segments,
                                        self._max_length, embedding_size)
        # Truncate segment-level start/end tokens.
        embeddings = embeddings[
            :, :, self._num_added_start_tokens:-self._num_added_end_tokens, :]
        embeddings = embeddings.reshape(batch_size, -1,
                                        embedding_size)  # flatten

        # Now try to put the end token embeddings back, which is a little tricky.

        # The number of segments each sequence spans, excluding padding.
        # Adding (self._max_length - 1) before the floor division mimics a
        # ceiling operation; integer division keeps the result usable for
        # the index arithmetic below.
        # Shape: (batch_size,)
        num_effective_segments = (seq_lengths + self._max_length -
                                  1) // self._max_length
        # The number of indices that end tokens should shift back.
        num_removed_non_end_tokens = (
            num_effective_segments * self._num_added_tokens -
            self._num_added_end_tokens)
        # Shape: (batch_size, self._num_added_end_tokens)
        end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
        assert (end_token_indices >= self._num_added_start_tokens).all()
        # Add space for end embeddings
        embeddings = torch.cat(
            [embeddings, torch.zeros_like(end_token_embeddings)], 1)
        # Add end token embeddings back
        embeddings.scatter_(
            1,
            end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings),
            end_token_embeddings)

        # Now put back start tokens. We can do this before putting back end tokens, but then
        # we need to change `num_removed_non_end_tokens` a little.
        embeddings = torch.cat([start_token_embeddings, embeddings], 1)

        # Truncate to original length
        embeddings = embeddings[:, :num_wordpieces, :]
        return embeddings
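
To make the shape bookkeeping in `_unfold_long_sequences` concrete, here is a small self-contained illustration with toy sizes. It assumes one added start token and one added end token per segment (so `_num_added_tokens == 2`, matching the [CLS]/[SEP] example in the docstring); all names are local stand-ins for the class attributes:

import torch

batch_size, num_segments, max_length, embedding_size = 2, 3, 5, 4
num_added_start, num_added_end = 1, 1          # one [CLS], one [SEP]
num_added_tokens = num_added_start + num_added_end

# Folded input: each long sequence was split into `num_segments` chunks of
# `max_length` wordpieces before running the transformer.
embeddings = torch.randn(batch_size * num_segments, max_length, embedding_size)

# Length of the concatenated "[CLS] A B C [SEP] [CLS] D E [SEP] ..." input:
num_segment_concat_wordpieces = num_segments * max_length          # 15
# After dropping segment-level specials (one [CLS]/[SEP] pair survives):
num_wordpieces = (num_segment_concat_wordpieces
                  - (num_segments - 1) * num_added_tokens)
assert num_wordpieces == 11                                        # 15 - 2 * 2

# The ceiling-division trick from the method, on toy sequence lengths:
seq_lengths = torch.tensor([13, 9])      # valid wordpieces per sequence
num_effective_segments = (seq_lengths + max_length - 1) // max_length
assert num_effective_segments.tolist() == [3, 2]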
Code Example #4
    def forward(
        self,
        sdf: Callable[[torch.Tensor], torch.Tensor],
        cam_loc: torch.Tensor,
        object_mask: torch.BoolTensor,
        ray_directions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            sdf: A callable that takes a (N, 3) tensor of points and returns
                a tensor of (N,) SDF values.
            cam_loc: A tensor of (B, N, 3) ray origins.
            object_mask: A (B, N) tensor of indicators whether a sampled pixel
                corresponds to the rendered object or background.
            ray_directions: A tensor of (B, N, 3) ray directions.

        Returns:
            curr_start_points: A tensor of (B*N, 3) found intersection points
                with the implicit surface.
            network_object_mask: A tensor of (B*N,) indicators denoting whether
                intersections were found.
            acc_start_dis: A tensor of (B*N,) distances from the ray origins
                to the intersection points.
        """
        batch_size, num_pixels, _ = ray_directions.shape
        device = cam_loc.device

        sphere_intersections, mask_intersect = _get_sphere_intersection(
            cam_loc, ray_directions, r=self.object_bounding_sphere)

        (
            curr_start_points,
            unfinished_mask_start,
            acc_start_dis,
            acc_end_dis,
            min_dis,
            max_dis,
        ) = self.sphere_tracing(
            batch_size,
            num_pixels,
            sdf,
            cam_loc,
            ray_directions,
            mask_intersect,
            sphere_intersections,
        )

        network_object_mask = acc_start_dis < acc_end_dis

        # Non-convergent rays should be handled by the sampler.
        sampler_mask = unfinished_mask_start
        sampler_net_obj_mask = torch.zeros_like(sampler_mask,
                                                dtype=torch.bool,
                                                device=device)
        if sampler_mask.sum() > 0:
            sampler_min_max = torch.zeros((batch_size, num_pixels, 2),
                                          device=device)
            sampler_min_max.reshape(-1, 2)[sampler_mask,
                                           0] = acc_start_dis[sampler_mask]
            sampler_min_max.reshape(-1, 2)[sampler_mask,
                                           1] = acc_end_dis[sampler_mask]

            sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(
                sdf, cam_loc, object_mask, ray_directions, sampler_min_max,
                sampler_mask)

            curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
            acc_start_dis[sampler_mask] = sampler_dists[sampler_mask]
            network_object_mask[sampler_mask] = sampler_net_obj_mask[
                sampler_mask]

        if not self.training:
            return curr_start_points, network_object_mask, acc_start_dis

        # In training mode, we also update curr_start_points and acc_start_dis
        # for rays that did not converge to a surface intersection.

        ray_directions = ray_directions.reshape(-1, 3)
        mask_intersect = mask_intersect.reshape(-1)
        object_mask = object_mask.reshape(-1)

        in_mask = ~network_object_mask & object_mask & ~sampler_mask
        out_mask = ~object_mask & ~sampler_mask

        # pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
        mask_left_out = (in_mask | out_mask) & ~mask_intersect
        if mask_left_out.sum() > 0:
            # For rays that do not intersect the bounding sphere, start from
            # the point on the ray closest to the sphere's center.
            cam_left_out = cam_loc.reshape(-1, 3)[mask_left_out]
            rays_left_out = ray_directions[mask_left_out]
            acc_start_dis[mask_left_out] = -torch.bmm(
                rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3,
                                                                1)).squeeze()
            curr_start_points[mask_left_out] = (
                cam_left_out +
                acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out)

        mask = (in_mask | out_mask) & mask_intersect

        if mask.sum() > 0:
            min_dis[network_object_mask
                    & out_mask] = acc_start_dis[network_object_mask & out_mask]

            min_mask_points, min_mask_dist = self.minimal_sdf_points(
                sdf, cam_loc, ray_directions, mask, min_dis, max_dis)

            curr_start_points[mask] = min_mask_points
            acc_start_dis[mask] = min_mask_dist

        return curr_start_points, network_object_mask, acc_start_dis
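
The helper `_get_sphere_intersection` used at the top of `forward()` is not included in this excerpt. Below is a minimal sketch of the standard ray/bounding-sphere intersection it presumably computes; the signature is inferred from the call site, and the body (solving ||o + t*d||^2 = r^2 for unit directions d, with distances clamped to be non-negative) is an assumption rather than the library's actual implementation:

import torch
from typing import Tuple


def _get_sphere_intersection(
    cam_loc: torch.Tensor,         # (B, N, 3) ray origins
    ray_directions: torch.Tensor,  # (B, N, 3) unit ray directions
    r: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Solve ||o + t d||^2 = r^2, i.e. t^2 + 2 (o.d) t + (||o||^2 - r^2) = 0.
    ray_cam_dot = (ray_directions * cam_loc).sum(-1)             # (B, N)
    disc = ray_cam_dot ** 2 - (cam_loc.pow(2).sum(-1) - r ** 2)  # discriminant
    mask_intersect = disc > 0                    # rays that hit the sphere
    sqrt_disc = torch.sqrt(disc.clamp(min=0.0))
    near = (-ray_cam_dot - sqrt_disc).clamp(min=0.0)  # entry distance
    far = (-ray_cam_dot + sqrt_disc).clamp(min=0.0)   # exit distance
    # (B, N, 2): per-ray [near, far] distances, zeroed where there is no hit.
    sphere_intersections = (
        torch.stack([near, far], dim=-1) * mask_intersect.unsqueeze(-1))
    return sphere_intersections, mask_intersect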