def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
    """
    Refine class logits with iterative mean-field message passing.

    Each iteration filters the current class probabilities with two
    PHLFilter passes: one over spatial + reference features ("bilateral")
    and one over spatial features only ("gaussian"). It blends the two
    results, convolves them with a compatibility kernel and subtracts the
    result from the original logits before re-normalising.

    Args:
        input_tensor: tensor containing initial class logits.
        reference_tensor: the reference tensor used to guide the message passing.

    Returns:
        output (torch.Tensor): output tensor.
    """
    # dim 1 holds the classes; every dim after the first two is treated as spatial
    n_spatial = input_tensor.dim() - 2
    n_classes = input_tensor.size(1)
    halo = self.compatibility_kernel_range

    # per-position coordinates derived from the reference tensor
    coords = _create_coordinate_tensor(reference_tensor)

    # feature spaces for the two filters, each scaled by its own sigma(s)
    bilateral_feats = torch.cat(
        [
            coords / self.bilateral_spatial_sigma,
            reference_tensor / self.bilateral_color_sigma,
        ],
        dim=1,
    )
    gaussian_feats = coords / self.gaussian_spatial_sigma

    # compatibility matrix (Potts model, 1 - diag, for now) broadcast into a conv kernel
    potts = _potts_model_weights(n_classes).to(device=input_tensor.device)
    compat_kernel = _expand_matrix_to_kernel(potts, n_spatial, self.compatibility_kernel_range)

    # pick the convolution matching the number of spatial dims (1 -> conv1d, ...)
    conv_fn = (conv1d, conv2d, conv3d)[n_spatial - 1]

    # pad amount per side, for both sides of every spatial dim
    pad_widths = [halo] * (2 * n_spatial)

    probs = softmax(input_tensor, dim=1)
    for _ in range(self.iterations):
        bilateral_msg = PHLFilter.apply(probs, bilateral_feats)
        gaussian_msg = PHLFilter.apply(probs, gaussian_feats)
        messages = self.bilateral_weight * bilateral_msg + self.gaussian_weight * gaussian_msg

        # replicate-pad before convolving; presumably this keeps the spatial size
        # unchanged, depending on the kernel size _expand_matrix_to_kernel builds
        padded = pad(messages, pad_widths, mode="replicate")
        penalty = conv_fn(padded, compat_kernel)

        # subtract the compatibility penalty from the original logits, then normalise
        probs = softmax(input_tensor - self.update_factor * penalty, dim=1)

    return probs
def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
    """
    Refine class logits with iterative mean-field message passing.

    Each iteration filters the current class probabilities with two
    PHLFilter passes: one over spatial + reference features ("bilateral")
    and one over spatial features only ("gaussian"). It blends the two
    results and optionally mixes classes through ``self.compatability_matrix``.
    It then adds the result to the original logits and re-normalises.

    Args:
        input_tensor: tensor containing initial class logits.
        reference_tensor: the reference tensor used to guide the message passing.

    Returns:
        output (torch.Tensor): output tensor.
    """
    # per-position coordinates derived from the reference tensor
    coords = _create_coordinate_tensor(reference_tensor)

    # feature spaces for the two filters, each scaled by its own sigma(s)
    bilateral_feats = torch.cat(
        [
            coords / self.bilateral_spatial_sigma,
            reference_tensor / self.bilateral_color_sigma,
        ],
        dim=1,
    )
    gaussian_feats = coords / self.gaussian_spatial_sigma

    probs = softmax(input_tensor, dim=1)
    for _ in range(self.iterations):
        bilateral_msg = PHLFilter.apply(probs, bilateral_feats)
        gaussian_msg = PHLFilter.apply(probs, gaussian_feats)
        messages = self.bilateral_weight * bilateral_msg + self.gaussian_weight * gaussian_msg

        # optional class-mixing step; skipped entirely when no matrix is configured
        compat = self.compatability_matrix
        if compat is not None:
            original_shape = messages.shape
            # (dim0, prod(spatial), dim1): dim 1 is assumed to hold classes,
            # as the softmax over dim=1 suggests
            per_position = messages.flatten(start_dim=2).permute(0, 2, 1)
            per_position = torch.matmul(per_position, compat)
            messages = per_position.permute(0, 2, 1).reshape(original_shape)

        # add the weighted messages to the original logits, then normalise
        probs = softmax(input_tensor + self.update_factor * messages, dim=1)

    return probs
def test_cpu(self, test_case_description, sigmas, input, features, expected):
    """Check PHLFilter on CPU against precomputed expected output (abs tol 1e-4)."""
    cpu = torch.device("cpu")

    def as_cpu_float(data):
        # np.array accepts nested Python lists as well as arrays
        return torch.from_numpy(np.array(data)).to(dtype=torch.float, device=cpu)

    input_tensor = as_cpu_float(input)
    feature_tensor = as_cpu_float(features)

    # run the filter and bring the result back as a numpy array
    filtered = PHLFilter.apply(input_tensor, feature_tensor, sigmas)
    result = filtered.cpu().numpy()

    # ensure the result matches the expected values
    np.testing.assert_allclose(result, expected, atol=1e-4)