Example #1
 def forward(self, x: Any) -> Any:  # type: ignore
     # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
     # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
     # construct.
     target_device = get_device_from_parameters(self)
     [x] = move_to_device(input_tensors=[x], target_device=target_device)
     x = self.block1(x)
     return self.block2(x) + x if self.use_residual else self.block2(x)
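The helpers get_device_from_parameters and move_to_device are used throughout these examples but are not shown on this page. Below is a minimal sketch of what they plausibly do, inferred from the comments above and from the test in Example #5; the actual implementations in the source repository may differ.

from typing import Iterable, List, Optional

import torch


def get_device_from_parameters(module: torch.nn.Module) -> Optional[torch.device]:
    # Device of the module's first parameter, or None if the module has no parameters
    # (as happens for the replicas created by the PyTorch 1.6 DataParallel).
    parameters = list(module.parameters())
    return parameters[0].device if parameters else None


def move_to_device(input_tensors: List[torch.Tensor],
                   target_device: Optional[torch.device]) -> Iterable[torch.Tensor]:
    # Move every tensor to target_device; when target_device is None, return the tensors unchanged.
    return [t if target_device is None else t.to(target_device) for t in input_tensors]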
Example #2
 def forward(self, x: Any) -> Any:  # type: ignore
     # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
     # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
     # construct.
     [x] = move_to_device([x], target_device=get_device_from_parameters(self))
     return self.upsample_block(x)
Example #3
 def forward(self, x: Any, skip_connection: Any) -> Any:  # type: ignore
     # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
     # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
     # construct.
     [x, skip_connection] = move_to_device(input_tensors=[x, skip_connection],
                                           target_device=get_device_from_parameters(self))
     x = self.conv1(x)
     x += self.conv2(skip_connection)
     x = self.activation_block(x)
     return self.block2(x) + x
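The comment in each forward refers to a model parallel construct: consecutive blocks may be pinned to different GPUs, so each block first moves its input onto the device of its own parameters. A hypothetical two-GPU sketch of such a construct, with illustrative class and block names that are not taken from the source repository:

import torch
import torch.nn as nn


class TwoGpuModel(nn.Module):
    # Hypothetical model-parallel layout: each block lives on its own GPU, so the
    # forward pass must move the data to a block's device before applying it.
    def __init__(self) -> None:
        super().__init__()
        self.block_a = nn.Linear(16, 16).to(torch.device("cuda:0"))
        self.block_b = nn.Linear(16, 8).to(torch.device("cuda:1"))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block_a(x.to(torch.device("cuda:0")))
        return self.block_b(x.to(torch.device("cuda:1")))


if torch.cuda.device_count() > 1:
    model = TwoGpuModel()
    output = model(torch.randn(4, 16))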
Example #4
 def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
     skip_connections: List[torch.Tensor] = list()
     # Unet Encoder and Decoder paths
     for layer_id, layer in enumerate(self._layers):  # type: ignore
         x = layer(x, skip_connections.pop()) if layer.concat else layer(x)
         if layer_id < self.num_downsampling_paths:  # type: ignore
             skip_connections.append(x)
     # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
     # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
     # construct.
     [x] = move_to_device(input_tensors=[x], target_device=get_device_from_parameters(self.output_layer))
     return self.output_layer(x)
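The loop in Example #4 pairs encoder outputs with decoder layers purely through stack (LIFO) ordering: the first num_downsampling_paths layers push their outputs, and each layer with concat set pops the most recent one, so the deepest encoder output feeds the first decoder layer that concatenates. A toy illustration of just that bookkeeping, with made-up layer names:

from typing import List

# Made-up layer names standing in for self._layers in Example #4.
layer_names = ["enc0", "enc1", "enc2", "bottleneck", "dec2", "dec1", "dec0"]
num_downsampling_paths = 3
skip_connections: List[str] = []
for layer_id, name in enumerate(layer_names):
    if name.startswith("dec"):  # stands in for layer.concat being True
        print(f"{name} consumes the skip connection from {skip_connections.pop()}")
    if layer_id < num_downsampling_paths:
        skip_connections.append(name)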
Example #5
def test_move_to_device() -> None:
    def assert_device_matches(tensors: List[Tensor], target_device: torch.device) -> None:
        for tensor in tensors:
            assert tensor.device == target_device

    target_device = torch.device('cuda:0')
    input_tensor_1 = torch.tensor(3, device=torch.device('cpu'))
    input_tensor_2 = torch.tensor(3, device=torch.device('cuda:0'))
    tensors = [input_tensor_1, input_tensor_2]
    moved = list(move_to_device(tensors, target_device=target_device))
    assert_device_matches(moved, target_device)

    if torch.cuda.device_count() > 1:
        target_device = torch.device('cuda:1')
        moved = list(move_to_device(tensors, target_device=target_device))
        assert_device_matches(moved, target_device)

    # Not supplying a target device should leave the tensor untouched
    moved = list(move_to_device(tensors, target_device=None))
    assert moved[0].device == tensors[0].device
    assert moved[1].device == tensors[1].device
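The test above constructs a tensor on cuda:0 unconditionally, so it can only pass on a machine with at least one GPU. If that is not guaranteed by the test environment, a guard along these lines could be used; this is a sketch, not part of the original test:

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires at least one CUDA device")
def test_move_to_device() -> None:
    ...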
Example #6
    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """
        Ignore the actual patches and return a fixed segmentation, explained in make_nesting_rectangles.

        :param patches: Set of patches, of shape (#patches, #image_channels, Z, Y, X). Only the shape
        is used.
        :return: Fixed tensor of shape (#patches, number_of_classes, Z, Y, X).
        """
        output_size: TupleInt3 = (patches.shape[2], patches.shape[3],
                                  patches.shape[4])
        if self.cached_patch_size == output_size:
            patch = self.cached_patch
        else:
            patch = self.make_nest(output_size)
        if patches.shape[0] == 1:
            np_predictions = patch
        else:
            np_predictions = np.broadcast_to(
                patch, (patches.shape[0], *patch.shape[1:]))
        x = torch.tensor(np_predictions, requires_grad=True)
        [x] = move_to_device(input_tensors=[x],
                             target_device=get_device_from_parameters(self))
        return x
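The np.broadcast_to call in Example #6 replicates the single cached prediction across the batch dimension as a read-only view, without copying the data; the subsequent torch.tensor call is what copies it into a new tensor. A minimal standalone illustration with made-up sizes:

import numpy as np

patch = np.zeros((1, 2, 4, 8, 8), dtype=np.float32)  # (1, num_classes, Z, Y, X), sizes are illustrative
batched = np.broadcast_to(patch, (5, *patch.shape[1:]))
print(batched.shape)            # (5, 2, 4, 8, 8)
print(batched.flags.writeable)  # False: a broadcast view, not a copy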