Example #1
0
    def forward(self, volume, image_features, projection_indices_3d,
                projection_indices_2d, volume_dims):
        """Fuse projected 2D image features with 3D volume features and
        classify each column of the voxel grid.

        Args:
            volume: 5-D tensor of 3D input features, ``volume.shape[0]``
                is the batch size (rank enforced by the assert below).
            image_features: 4-D tensor of per-image 2D features; each image
                is projected into the 3D grid via ``Projection.apply``.
            projection_indices_3d: per-image 3D index maps for the
                projection; carries ``batch_size * num_images`` entries
                along dim 0.
            projection_indices_2d: per-image 2D index maps, zipped 1:1
                with ``projection_indices_3d``.
            volume_dims: target 3D grid dimensions passed to ``Projection``.

        Returns:
            Tuple ``(semantic_output, scan_output)``:
            ``semantic_output`` -- ``[batch_size, grid_dims[2], num_classes]``.
            ``scan_output`` -- ``[batch_size, grid_dims[2], 3]`` when
            ``self.train_scan_completion`` is set, otherwise ``None``.
        """
        assert len(volume.shape) == 5 and len(image_features.shape) == 4
        batch_size = volume.shape[0]
        # The index tensors stack all views of all batch elements along
        # dim 0, so the per-sample view count is recovered by division.
        num_images = projection_indices_3d.shape[0] // batch_size

        # Project each image's 2D features into the 3D grid, then stack
        # the per-image volumes along a new trailing dimension (dim 4).
        image_features = [
            Projection.apply(ft, ind3d, ind2d, volume_dims)
            for ft, ind3d, ind2d in zip(image_features, projection_indices_3d,
                                        projection_indices_2d)
        ]
        image_features = torch.stack(image_features, dim=4)

        # Flatten the spatial dims so the per-image axis is last, then
        # max-pool across the num_images views of every voxel.
        sz = image_features.shape
        image_features = image_features.view(sz[0], -1,
                                             batch_size * num_images)
        if num_images == self.num_images:
            image_features = self.pooling(image_features)
        else:
            # Non-default view count: pool functionally instead of
            # constructing a throwaway nn.MaxPool1d module on every
            # forward pass (identical result, no per-call allocation).
            image_features = nn.functional.max_pool1d(
                image_features, kernel_size=num_images)
        image_features = image_features.view(sz[0], sz[1], sz[2], sz[3],
                                             batch_size)
        # Move the batch axis to the front: [batch, C, D0, D1, D2].
        image_features = image_features.permute(4, 0, 1, 2, 3)

        # Run the separate 3D and (projected) 2D feature extractors, then
        # the joint trunk on the channel-concatenated result.
        volume = self.features3d(volume)
        image_features = self.features2d(image_features)
        x = torch.cat([volume, image_features], 1)
        x = self.features(x)
        x = x.view(batch_size, self.nf2 * 54)
        semantic_output = self.semanticClassifier(x)
        semantic_output = semantic_output.view(batch_size, self.grid_dims[2],
                                               self.num_classes)
        scan_output = None
        if self.train_scan_completion:
            scan_output = self.scanClassifier(x)
            # scan_output: [batch_size, grid_dims[2], 3] -- the trailing 3
            # is the number of voxel-grid occupancy classes.  (A previous
            # comment claimed [batch_size, 62, 2], contradicting the code.)
            scan_output = scan_output.view(
                batch_size, self.grid_dims[2],
                3)  # 3 represents voxel grid occupancy values
        return semantic_output, scan_output