コード例 #1
0
 def forward(self, x: Any) -> Any:  # type: ignore
     """
     Apply block1, then block2, optionally adding a residual connection around block2.

     :param x: Input; it is first moved to the device of this module's parameters.
     :return: block2(h) + h when self.use_residual is set, otherwise block2(h), where h = block1(x).
     """
     # With the DataParallel of PyTorch 1.6, self.parameters is empty and the tensors must not be moved.
     # When parameters are present, the module is used inside a model-parallel construct, so the input is
     # moved to the parameters' device. The empty case is presumably handled inside
     # get_device_from_parameters / move_to_device — confirm.
     device = get_device_from_parameters(self)
     x, = move_to_device(input_tensors=[x], target_device=device)
     hidden = self.block1(x)
     if self.use_residual:
         return self.block2(hidden) + hidden
     return self.block2(hidden)
コード例 #2
0
 def forward(self, x: Any) -> Any:  # type: ignore
     """
     Apply the upsampling block to the input.

     :param x: Input; it is first moved to the device of this module's parameters.
     :return: The output of self.upsample_block.
     """
     # With the DataParallel of PyTorch 1.6, self.parameters is empty and the tensors must not be moved.
     # When parameters are present, the module is used inside a model-parallel construct, so the input is
     # moved to the parameters' device. The empty case is presumably handled inside
     # get_device_from_parameters / move_to_device — confirm.
     device = get_device_from_parameters(self)
     x, = move_to_device(input_tensors=[x], target_device=device)
     return self.upsample_block(x)
コード例 #3
0
 def forward(self, x: Any, skip_connection: Any) -> Any:  # type: ignore
     """
     Merge the input with a skip connection, then apply block2 with a residual connection.

     :param x: Main input; passed through self.conv1.
     :param skip_connection: Skip-connection input; passed through self.conv2 and added to conv1's output.
     :return: block2(h) + h, where h = activation_block(conv1(x) + conv2(skip_connection)).
     """
     # With the DataParallel of PyTorch 1.6, self.parameters is empty and the tensors must not be moved.
     # When parameters are present, the module is used inside a model-parallel construct, so the inputs are
     # moved to the parameters' device. The empty case is presumably handled inside
     # get_device_from_parameters / move_to_device — confirm.
     device = get_device_from_parameters(self)
     x, skip_connection = move_to_device(input_tensors=[x, skip_connection], target_device=device)
     merged = self.conv1(x)
     # In-place add into the tensor returned by conv1.
     merged += self.conv2(skip_connection)
     hidden = self.activation_block(merged)
     return self.block2(hidden) + hidden
コード例 #4
0
 def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
     """
     Run the UNet encoder and decoder layers in order, then the output layer.

     A layer whose ``concat`` flag is set receives, as its second argument, the most recently stored skip
     connection, which is removed from the store. The outputs of the first self.num_downsampling_paths
     layers are stored as skip connections.

     :param x: Network input.
     :return: The output of self.output_layer, with its input moved to the output layer's parameter device.
     """
     # Last-in, first-out store of skip connections; layers with ``concat`` set consume from the top.
     pending_skips: List[torch.Tensor] = []
     for index, unet_layer in enumerate(self._layers):  # type: ignore
         if unet_layer.concat:
             x = unet_layer(x, pending_skips.pop())
         else:
             x = unet_layer(x)
         if index < self.num_downsampling_paths:  # type: ignore
             pending_skips.append(x)
     # With the DataParallel of PyTorch 1.6, self.parameters is empty and the tensors must not be moved.
     # When parameters are present, the module is used inside a model-parallel construct, so the tensor is
     # moved to the output layer's device. The empty case is presumably handled inside
     # get_device_from_parameters / move_to_device — confirm.
     output_device = get_device_from_parameters(self.output_layer)
     x, = move_to_device(input_tensors=[x], target_device=output_device)
     return self.output_layer(x)
コード例 #5
0
    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """
        Ignore the actual patches and return a fixed segmentation, explained in make_nesting_rectangles.

        :param patches: Set of patches, of shape (#patches, #image_channels, Z, Y, X). Only the shape
        is used.
        :return: Fixed tensor of shape (#patches, number_of_classes, Z, Y, X). It is created with
        requires_grad=True and then moved to the device of this module's parameters.
        """
        # The spatial size (Z, Y, X) of the patches selects which fixed segmentation is produced.
        output_size: TupleInt3 = (patches.shape[2], patches.shape[3], patches.shape[4])
        # NOTE(review): cached_patch / cached_patch_size are only read here, never written; presumably
        # they are populated elsewhere (constructor or make_nest) — confirm, otherwise the cache never hits.
        if self.cached_patch_size == output_size:
            patch = self.cached_patch
        else:
            patch = self.make_nest(output_size)
        if patches.shape[0] == 1:
            np_predictions = patch
        else:
            # Repeat the prediction along the first axis, one copy per patch. np.broadcast_to requires
            # patch's leading axis to be 1 (or already equal to #patches); the result is a read-only view.
            np_predictions = np.broadcast_to(patch, (patches.shape[0], *patch.shape[1:]))
        # torch.tensor always copies its input, so the read-only broadcast view is materialized here.
        x = torch.tensor(np_predictions, requires_grad=True)
        [x] = move_to_device(input_tensors=[x],
                             target_device=get_device_from_parameters(self))
        return x