Example #1
0
    def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
        """
        Refine class logits with mean-field message passing.

        Each iteration does the following:

        - filters the current class probabilities with two ``PHLFilter`` passes
          (bilateral features: scaled coordinates + scaled reference values;
          gaussian features: scaled coordinates only);
        - mixes them with ``bilateral_weight`` / ``gaussian_weight``;
        - applies a Potts-model compatibility transform as a convolution;
        - subtracts the result (scaled by ``update_factor``) from the original
          logits, followed by a softmax over dim 1.

        Args:
            input_tensor: tensor containing initial class logits, channel-first,
                i.e. (batch, classes, *spatial) with 1, 2 or 3 spatial dimensions.
            reference_tensor: the reference tensor used to guide the message passing.

        Returns:
            output (torch.Tensor): class probabilities (softmax over dim 1).

        Raises:
            ValueError: if ``input_tensor`` does not have 1, 2 or 3 spatial dimensions.
        """

        # useful values
        spatial_dim = input_tensor.dim() - 2
        if not 1 <= spatial_dim <= 3:
            # Guard the conv lookup below: spatial_dim == 0 would index -1 and
            # silently pick conv3d; spatial_dim > 3 would raise a bare IndexError.
            raise ValueError(
                "input_tensor must have 1, 2 or 3 spatial dimensions "
                f"(shape (batch, classes, *spatial)), got shape {tuple(input_tensor.shape)}."
            )
        class_count = input_tensor.size(1)
        padding = self.compatibility_kernel_range

        # constructing spatial feature tensor
        spatial_features = _create_coordinate_tensor(reference_tensor)

        # constructing final feature tensors for bilateral and gaussian kernel
        bilateral_features = torch.cat(
            [spatial_features / self.bilateral_spatial_sigma, reference_tensor / self.bilateral_color_sigma], dim=1
        )
        gaussian_features = spatial_features / self.gaussian_spatial_sigma

        # compatibility matrix (potts model (1 - diag) for now), moved to the input's device
        compatibility_matrix = _potts_model_weights(class_count).to(device=input_tensor.device)

        # expanding matrix to a convolution kernel; loop-invariant, so built once
        compatibility_kernel = _expand_matrix_to_kernel(compatibility_matrix, spatial_dim, padding)

        # choosing convolution function matching the number of spatial dimensions
        conv = (conv1d, conv2d, conv3d)[spatial_dim - 1]

        # setting up output tensor
        output_tensor = softmax(input_tensor, dim=1)

        # mean field loop
        for _ in range(self.iterations):

            # message passing step for both kernels
            bilateral_output = PHLFilter.apply(output_tensor, bilateral_features)
            gaussian_output = PHLFilter.apply(output_tensor, gaussian_features)

            # combining filter outputs
            combined_output = self.bilateral_weight * bilateral_output + self.gaussian_weight * gaussian_output

            # compatibility convolution; replicate-pad every spatial dim by the kernel range
            # on both sides (presumably so the unpadded conv keeps the spatial size — confirm
            # against _expand_matrix_to_kernel's kernel size)
            combined_output = pad(combined_output, 2 * spatial_dim * [padding], mode="replicate")
            compatibility_update = conv(combined_output, compatibility_kernel)

            # update and normalize
            output_tensor = softmax(input_tensor - self.update_factor * compatibility_update, dim=1)

        return output_tensor
Example #2
0
File: crf.py Project: tuan-cs/MONAI
    def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
        """
        Refine class logits with mean-field message passing.

        Every iteration does the following:

        - filters the current class probabilities with two ``PHLFilter`` passes:
          an appearance (bilateral) pass over scaled coordinates plus scaled
          reference values, and a smoothness (gaussian) pass over scaled
          coordinates only;
        - blends the two passes with ``bilateral_weight`` / ``gaussian_weight``;
        - optionally mixes the class channels through ``compatability_matrix``;
        - adds ``update_factor`` times the result to the original logits before a
          softmax over dim 1.

        NOTE(review): the flatten/reshape below assumes channel-first tensors,
        presumably (batch, classes, *spatial) — confirm with callers.

        Args:
            input_tensor: tensor containing initial class logits.
            reference_tensor: the reference tensor used to guide the message passing.

        Returns:
            output (torch.Tensor): output tensor.
        """
        coords = _create_coordinate_tensor(reference_tensor)

        # feature spaces for the appearance (bilateral) and smoothness (gaussian) kernels
        appearance_features = torch.cat(
            [coords / self.bilateral_spatial_sigma, reference_tensor / self.bilateral_color_sigma],
            dim=1,
        )
        smoothness_features = coords / self.gaussian_spatial_sigma

        compatibility = self.compatability_matrix
        probabilities = softmax(input_tensor, dim=1)

        for _ in range(self.iterations):
            # weighted sum of both message-passing results (bilateral filtered first)
            messages = (
                self.bilateral_weight * PHLFilter.apply(probabilities, appearance_features)
                + self.gaussian_weight * PHLFilter.apply(probabilities, smoothness_features)
            )

            if compatibility is not None:
                # move channels last so the matrix multiplies the class vector at every location
                original_shape = messages.shape
                channels_last = messages.flatten(start_dim=2).transpose(1, 2)
                mixed = torch.matmul(channels_last, compatibility)
                messages = mixed.transpose(1, 2).reshape(original_shape)

            probabilities = softmax(input_tensor + self.update_factor * messages, dim=1)

        return probabilities
Example #3
0
    def test_cpu(self, test_case_description, sigmas, input, features, expected):
        """Check that ``PHLFilter`` on CPU tensors reproduces ``expected`` within an absolute tolerance of 1e-4."""

        def _as_cpu_float(values):
            # array-like test data -> float32 tensor on the CPU
            return torch.from_numpy(np.array(values)).to(dtype=torch.float, device=torch.device("cpu"))

        # input is converted before features (left-to-right argument evaluation)
        filtered = PHLFilter.apply(_as_cpu_float(input), _as_cpu_float(features), sigmas)

        np.testing.assert_allclose(filtered.cpu().numpy(), expected, atol=1e-4)