コード例 #1
0
    def cropper(self, source, target, resolution=(320, 320)):
        """Reduce a (source, target) volume pair to 2D images and center-crop them.

        The ``"slice"`` axis is collapsed to each tensor's middle slice (the
        hard-coded behaviour) or, in the currently unreachable branch, folded
        into the batch axis. ``source`` is then passed through ``T.modulus``,
        stripped of names and given an extra axis at position 1.

        Args:
            source: Tensor whose dimensions are named after ``target.names``.
                Presumably complex-valued, since it is refined to
                ``self.complex_names()`` later — TODO confirm.
            target: Named tensor that must contain a ``"slice"`` dimension.
            resolution: Crop size forwarded to ``T.center_crop``. When it is
                falsy or every entry equals 0, no cropping is done.

        Returns:
            ``(source_abs, target_abs)``. ``source_abs`` is unnamed with an
            axis inserted at position 1. On the no-crop path the second item
            is the reduced, unnamed ``target``.
        """
        target_names = target.names
        slice_dim = target_names.index("slice")
        source = source.refine_names(*target_names)

        use_center_slice = True
        if not use_center_slice:
            # Fold all slices into the batch axis instead of picking one.
            source = source.flatten(["batch", "slice"], "batch").rename(None)
            target = target.flatten(["batch", "slice"], "batch").rename(None)
        else:
            # Each tensor takes the middle of its OWN depth: source and target
            # can have a different number of slices when trimming in depth.
            source_mid = source.size("slice") // 2
            source = source.select(slice_dim, source_mid).rename(None)
            target_mid = target.size("slice") // 2
            target = target.select("slice", target_mid).rename(None)

        # The slice axis no longer exists, so drop its name from the layout.
        reduced_names = self.complex_names().copy()
        del reduced_names[slice_dim]

        magnitude = T.modulus(source.refine_names(*reduced_names))
        if resolution and not all(size == 0 for size in resolution):
            cropped_source = T.center_crop(magnitude, resolution).rename(None).unsqueeze(1)
            return cropped_source, T.center_crop(target, resolution)

        return magnitude.rename(None).unsqueeze(1), target
コード例 #2
0
ファイル: rim_engine.py プロジェクト: jonasteuwen/direct
    def cropper(self, source, target, resolution=(320, 320)):
        """Reduce a (source, target) volume pair to 2D images and center-crop them.

        The ``"slice"`` axis is collapsed to each tensor's middle slice (the
        hard-coded behaviour) or, in the currently unreachable branch, folded
        into the batch axis. ``source`` is then passed through ``modulus``,
        stripped of names and given an extra axis at position 1.

        Args:
            source: Tensor whose ``"slice"`` axis sits at the same position as
                in ``target``. Presumably complex-valued, since it is refined
                to ``self.complex_names`` later — TODO confirm.
            target: Named tensor that must contain a ``"slice"`` dimension.
            resolution: Crop size forwarded to ``center_crop``. When it is
                falsy or every entry equals 0, no cropping is done.

        Returns:
            ``(source_abs, target_abs)``. ``source_abs`` is unnamed with an
            axis inserted at position 1. On the no-crop path the second item
            is the reduced, unnamed ``target``.
        """
        # Can also do reshaping and compute over the full volume
        slice_index = target.names.index("slice")

        use_center_slice = True
        if use_center_slice:
            # Source and target may have a different number of slices (e.g. when
            # trimming in depth), so each one takes the middle of its own depth.
            # ``slice_index`` is positional, so this works even if ``source``
            # carries no dimension names.
            source = source.select(slice_index, source.size(slice_index) // 2)
            target = target.select("slice", target.size("slice") // 2).rename(None)
        else:
            source = source.refine_names(*target.names)
            source = source.flatten(["batch", "slice"], "batch").rename(None)
            target = target.flatten(["batch", "slice"], "batch").rename(None)

        # The slice axis no longer exists, so drop its name from the layout.
        complex_names = self.complex_names.copy()
        complex_names.pop(slice_index)

        source_abs = modulus(source.refine_names(*complex_names))
        if not resolution or all(size == 0 for size in resolution):
            return source_abs.rename(None).unsqueeze(1), target

        source_abs = center_crop(source_abs, resolution).rename(None).unsqueeze(1)
        target_abs = center_crop(target, resolution)
        return source_abs, target_abs
コード例 #3
0
    def cropper(self, source, target, resolution):
        """Take the modulus of ``source`` and optionally center-crop both tensors.

        Args:
            source: Tensor whose existing names are discarded and replaced by
                ``self.complex_names()`` before ``T.modulus`` is applied.
            target: Named tensor, aligned to ``self.complex_names()`` with
                ``align_to`` and then stripped of names.
            resolution: Crop size for ``T.center_crop``. When it is falsy or
                every entry equals 0, no cropping is done.

        Returns:
            ``(source_abs, target_abs)``. ``source_abs`` is unnamed with an axis
            inserted at position 1. On the no-crop path the second item is the
            aligned, unnamed ``target``.
        """
        unnamed_source = source.rename(None)
        target = target.align_to(*self.complex_names()).rename(None)
        magnitude = T.modulus(unnamed_source.refine_names(*self.complex_names()))

        skip_crop = not resolution or all(size == 0 for size in resolution)
        if skip_crop:
            return magnitude.rename(None).unsqueeze(1), target

        cropped_source = T.center_crop(magnitude, resolution).rename(None).unsqueeze(1)
        return cropped_source, T.center_crop(target, resolution)
コード例 #4
0
    def process_output(self, data, scaling_factors=None, resolution=None):
        """Rescale, take the modulus of, and optionally center-crop an output tensor.

        Args:
            data: Output tensor whose first axis is scaled per element by
                ``scaling_factors``.
            scaling_factors: Optional tensor reshaped to ``(-1, 1, ..., 1)``
                (one entry per element of ``data``'s first axis), moved to
                ``data.device``, and multiplied into ``data``.
            resolution: Optional crop size. An int-like ``r`` crops to
                ``(r, r)``. A tuple/list is passed to ``center_crop`` as the
                crop shape directly. ``None`` disables cropping.

        Returns:
            Unnamed tensor from ``modulus_if_complex``, with a channel axis
            inserted at position 1 when the result is 3D, optionally
            center-cropped and made contiguous.
        """
        if scaling_factors is not None:
            # Reshape to (-1, 1, ..., 1) so each factor broadcasts over its sample.
            broadcast_shape = (-1,) + (1,) * (data.ndim - 1)
            data = data * scaling_factors.view(*broadcast_shape).to(data.device)
        data = modulus_if_complex(data).rename(None)
        if data.ndim == 3:  # (batch, height, width)
            data = data.unsqueeze(1)  # Added channel dimension.

        if resolution is not None:
            # Accept an explicit (height, width) sequence as well as a single
            # size, matching how the cropper helpers pass resolution tuples.
            if isinstance(resolution, (tuple, list)):
                crop_shape = tuple(resolution)
            else:
                crop_shape = (resolution, resolution)
            data = center_crop(data, crop_shape).contiguous()

        return data
コード例 #5
0
def test_center_crop(shape, target_shape, named):
    """Check that ``transforms.center_crop`` produces exactly ``target_shape``.

    ``shape``, ``target_shape`` and ``named`` are presumably supplied by a
    pytest parametrization outside this view — TODO confirm.
    """
    # Avoid shadowing the builtin ``input``.
    input_tensor = create_input(shape, named=named)
    out_torch = transforms.center_crop(input_tensor, target_shape).numpy()
    # Compare as tuples so both list- and tuple-typed target_shape work.
    assert tuple(out_torch.shape) == tuple(target_shape)
コード例 #6
0
 def cropper(source, target, resolution=(320, 320)):
     """Take the modulus of ``source`` and optionally center-crop both tensors.

     Args:
         source: 4D tensor whose dimensions are named
             ``('batch', 'complex', 'height', 'width')`` before ``modulus``.
         target: Tensor cropped with ``center_crop`` when cropping is enabled.
         resolution: Crop size for ``center_crop``. ``None``, an empty
             sequence, or all-zero entries disable cropping.

     Returns:
         ``(source_abs, target_abs)``. ``source_abs`` is unnamed with an axis
         inserted at position 1. Without cropping the second item is
         ``target`` unchanged.
     """
     source_abs = modulus(source.refine_names('batch', 'complex', 'height', 'width'))
     # Check for a missing/empty resolution first so ``None`` is never
     # iterated, and compare by value (``==``), never identity, with 0.
     if not resolution or all(size == 0 for size in resolution):
         # Both outputs are defined on this path too, shaped like the crop path.
         return source_abs.rename(None).unsqueeze(1), target

     source_abs = center_crop(source_abs, resolution).rename(None).unsqueeze(1)
     target_abs = center_crop(target, resolution)
     return source_abs, target_abs