def _visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8):
    image_recons = complex_abs(c_img_recons)
    image_targets = complex_abs(c_img_targets)
    kspace_recons = make_k_grid(fft2(c_img_recons), smoothing_factor)
    kspace_targets = make_k_grid(fft2(c_img_targets), smoothing_factor)
    image_recons, image_targets, image_deltas = make_grid_triplet(
        image_recons, image_targets)
    return kspace_recons, kspace_targets, image_recons, image_targets, image_deltas
Example #2
def visualize_from_kspace(kspace_recons, kspace_targets, smoothing_factor=4):
    """
    Assumes that all values are on the same scale and have the same shape.
    """
    image_recons = complex_abs(ifft2(kspace_recons))
    image_targets = complex_abs(ifft2(kspace_targets))
    image_recons, image_targets, image_deltas = make_grid_triplet(
        image_recons, image_targets)
    kspace_targets = make_k_grid(kspace_targets, smoothing_factor)
    kspace_recons = make_k_grid(kspace_recons, smoothing_factor)
    return kspace_recons, kspace_targets, image_recons, image_targets, image_deltas
Example #3
    def forward(self, k_output, targets, scales):
        """

        Args:
            k_output (torch.Tensor):
            targets (list):
            scales (list):

        Returns:
            image_recons (list):
            kspace_recons (list):

        """
        assert k_output.size(0) == len(targets) == len(scales)
        image_recons = list()
        kspace_recons = list()

        for k_slice, target, scaling in zip(k_output, targets, scales):

            left = (k_slice.size(-1) - target.size(-1)) // 2
            right = left + target.size(-1)

            k_slice_recon = chw_to_k_slice(k_slice[..., left:right] * scaling)
            i_slice_recon = complex_abs(ifft2(k_slice_recon))

            assert i_slice_recon.shape == target.shape
            image_recons.append(i_slice_recon)
            kspace_recons.append(k_slice_recon)

        return image_recons, kspace_recons
    def __call__(self, masked_kspace, target, attrs, file_name, slice_num):
        assert isinstance(masked_kspace, torch.Tensor), \
            'Masked k-space was expected to be a PyTorch Tensor.'
        if masked_kspace.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            masked_kspace = masked_kspace.expand(1, 1, -1, -1, -1)
        elif masked_kspace.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            masked_kspace = masked_kspace.expand(1, -1, -1, -1, -1)
        elif masked_kspace.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('Masked k-space has invalid shape!')

        if masked_kspace.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            extra_params = {'img_scales': img_scale}
            extra_params.update(attrs)

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {'img_inputs': input_image, 'rss_inputs': input_rss}

        return input_image, targets, extra_params
def test_complex_abs(shape):
    shape = shape + [2]
    tensor = create_tensor(shape)
    out_torch = data_transforms.complex_abs(tensor).numpy()
    tensor_numpy = data_transforms.tensor_to_complex_np(tensor)
    out_numpy = np.abs(tensor_numpy)
    assert np.allclose(out_torch, out_numpy)
    def forward(self, kspace_outputs, targets, extra_params):
        if kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one slice at a time for now.')

        kspace_targets = targets['kspace_targets']

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_outputs.size(-1) - kspace_targets.size(-2)) // 2
        right = left + kspace_targets.size(-2)

        # Cropping width dimension by pad.
        kspace_recons = nchw_to_kspace(kspace_outputs[..., left:right])

        assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (kspace_recons.size(-3) % 2 == 0) and (kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            kspace_recons = kspace_recons / weighting

        if self.replace:  # Replace with original k-space if replace=True
            mask = extra_params['masks']
            kspace_recons = kspace_recons * (1 - mask) + kspace_targets * mask

        cmg_recons = ifft2(kspace_recons)
        img_recons = complex_abs(cmg_recons)
        recons = {
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        return recons  # Returning scaled reconstructions. Not rescaled.
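# --- A minimal sketch of the layout contract behind the nchw_to_kspace/kspace_to_nchw
# calls above. Those helpers are assumed to fold the (real, imag) last dimension into
# CNN channels and back; the *_sketch stand-ins below are hypothetical illustrations,
# not the repo's implementation.
import torch


def kspace_to_nchw_sketch(tensor):
    # (N, C, H, W, 2) complex-last form -> (N, 2*C, H, W) for CNNs.
    n, c, h, w, two = tensor.shape
    assert two == 2, 'Last dimension must hold (real, imag).'
    return tensor.permute(0, 1, 4, 2, 3).reshape(n, 2 * c, h, w)


def nchw_to_kspace_sketch(tensor):
    # (N, 2*C, H, W) -> (N, C, H, W, 2), inverting the conversion above.
    n, c2, h, w = tensor.shape
    assert c2 % 2 == 0, 'Channel dimension must be even.'
    return tensor.reshape(n, c2 // 2, 2, h, w).permute(0, 1, 3, 4, 2)


x = torch.rand(1, 15, 640, 368, 2)
assert torch.equal(nchw_to_kspace_sketch(kspace_to_nchw_sketch(x)), x)  # Round-trip is exact.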
    def forward(self, cmg_output, targets, extra_params):
        if cmg_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        cmg_target = targets['cmg_targets']
        cmg_recon = nchw_to_kspace(cmg_output)
        assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
        assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.residual_acs:  # Adding the semi-k-space of the ACS as a residual. Necessary due to complex cropping.
            raise NotImplementedError('Not ready yet.')
            # cmg_acs = targets['cmg_acss']
            # cmg_recon = cmg_recon + cmg_acs

        kspace_recon = fft2(cmg_recon)
        img_recon = complex_abs(cmg_recon)

        recons = {
            'kspace_recons': kspace_recon,
            'cmg_recons': cmg_recon,
            'img_recons': img_recon
        }

        if self.challenge == 'multicoil':
            rss_recon = center_crop(img_recon, (
                self.resolution, self.resolution)) * extra_params['cmg_scales']
            rss_recon = root_sum_of_squares(rss_recon, dim=1).squeeze()
            recons['rss_recons'] = rss_recon

        return recons  # recons are not rescaled except rss_recons.
Example #8
    def __call__(self, k_slice, target, attrs, file_name, slice_num):
        """
        Args:
            k_slice (numpy.ndarray): Input k-space of shape (num_coils, height, width) for multi-coil
                data or (rows, cols) for single-coil data.
            target (numpy.ndarray): Target (320x320) image. May be None.
            attrs (dict): Acquisition-related information stored in the HDF5 object.
            file_name (str): File name.
            slice_num (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                data (torch.Tensor): kspace data converted to CHW format for CNNs, where C=(2*num_coils).
                    Also has padding in the width axis for auto-encoders, which have down-sampling regions.
                    This requires the data to be divisible by some number (usually 2**num_pooling_layers).
                    Otherwise, concatenation will not work in the decoder due to different sizes.
                    Only the width dimension is padded in this case due to the nature of the dataset.
                    The height is fixed at 640, while the width is variable.
                labels (torch.Tensor): Coil-wise ground truth images. Shape=(num_coils, H, W)
        """
        assert np.iscomplexobj(k_slice), 'kspace must be complex.'
        assert k_slice.shape[-1] % 2 == 0, 'k-space data width must be even.'

        if k_slice.ndim == 2:  # For singlecoil. Makes data processing later on much easier.
            k_slice = np.expand_dims(k_slice, axis=0)
        elif k_slice.ndim != 3:  # Prevents possible errors.
            raise TypeError('Invalid slice type')

        with torch.no_grad():  # Remove unnecessary gradient calculations.
            # Now a Tensor of (num_coils, height, width, 2), where 2 is (real, imag).
            # The data is in the GPU and has been amplified by the amplification factor.
            k_slice = to_tensor(k_slice).to(device=self.device) * self.amp_fac
            # k_slice = to_tensor(k_slice).cuda(self.device) * self.amp_fac
            target_slice = complex_abs(ifft2(k_slice))  # I need cuda here!
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask = apply_mask(k_slice, self.mask_func, seed)

            data_slice = k_slice_to_chw(masked_kspace)
            # assert data_slice.size(-1) % 2 == 0

            margin = (data_slice.shape[-1] % self.divisor)

            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a temporary fix.
                pad = [0, 0]
            # right_pad = self.divisor - left_pad
            # pad = [pad, pad]
            data_slice = F.pad(
                data_slice, pad=pad,
                value=0)  # This pads at the last dimension of a tensor.

            # Using the data acquisition method (fat suppression) may be useful later on.
        # print(1, data_slice.size())
        # print(2, target_slice.size())
        return data_slice, target_slice
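# --- A worked check of the divisor padding used above: the width is padded up to a
# multiple of `divisor` so U-Net style pooling and skip concatenation line up.
# The sizes below are illustrative, not taken from the source.
import torch
import torch.nn.functional as F

divisor, width = 16, 370  # Example values; real fastMRI widths vary per file.
x = torch.rand(30, 640, width)

margin = x.size(-1) % divisor
if margin > 0:
    pad = [(divisor - margin) // 2, (1 + divisor - margin) // 2]
else:
    pad = [0, 0]

x = F.pad(x, pad=pad, value=0)  # A 2-element pad list pads only the last dimension.
assert x.size(-1) % divisor == 0  # 370 -> 384, which is divisible by 16.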
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            image = complex_abs(ifft2(masked_kspace))
            if self.crop_center:
                image = center_crop(image,
                                    shape=(self.resolution, self.resolution))
            else:  # Super-hackish temporary line.  # TODO: Fix this thing later!
                image = center_crop(image, shape=(352, image.size(-1)))

            margin = image.size(-1) % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            # This pads at the last dimension of a tensor with 0.
            image = F.pad(image, pad=pad, value=0)

            # Also a hack!
            img_scale = torch.std(center_crop(image, shape=(self.resolution, self.resolution)))
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Use plurals as keys to reduce confusion.
            targets = {'img_inputs': image}

        return image, targets, extra_params
Example #10
    def _visualize_images(self, recons, targets, epoch, step, training=False):
        mode = 'Training' if training else 'Validation'

        # This numbering scheme seems to have issues for certain numbers.
        # Please check cases when there is no remainder.
        if self.display_interval and (step % self.display_interval == 0):
            img_recon_grid = make_img_grid(recons['img_recons'], self.shrink_scale)

            # The delta image is obtained by subtracting the complex images, not the real-valued magnitude images.
            delta_image = complex_abs(targets['cmg_targets'] - recons['cmg_recons'])
            delta_img_grid = make_img_grid(delta_image, self.shrink_scale)

            kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor, self.shrink_scale)

            self.writer.add_image(f'{mode} k-space Recons/{step}', kspace_recon_grid, epoch, dataformats='HW')
            self.writer.add_image(f'{mode} Image Recons/{step}', img_recon_grid, epoch, dataformats='HW')
            self.writer.add_image(f'{mode} Delta Image/{step}', delta_img_grid, epoch, dataformats='HW')

            # Adding RSS images of reconstructions and targets.
            if 'rss_recons' in recons:
                recon_rss = standardize_image(recons['rss_recons'])
                delta_rss = standardize_image(make_rss_slice(delta_image))
                self.writer.add_image(f'{mode} RSS Recons/{step}', recon_rss, epoch, dataformats='HW')
                self.writer.add_image(f'{mode} RSS Delta/{step}', delta_rss, epoch, dataformats='HW')

            if 'semi_kspace_recons' in recons:
                semi_kspace_recon_grid = make_k_grid(
                    recons['semi_kspace_recons'], self.smoothing_factor, self.shrink_scale)

                self.writer.add_image(
                    f'{mode} semi-k-space Recons/{step}', semi_kspace_recon_grid, epoch, dataformats='HW')

            if epoch == 1:  # Maybe add input images too later on.
                img_target_grid = make_img_grid(targets['img_targets'], self.shrink_scale)
                kspace_target_grid = make_k_grid(targets['kspace_targets'], self.smoothing_factor, self.shrink_scale)

                # Not actually the input but what the input looks like as an image.
                img_grid = make_img_grid(targets['img_inputs'], self.shrink_scale)

                self.writer.add_image(f'{mode} k-space Targets/{step}', kspace_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'{mode} Image Targets/{step}', img_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'{mode} Inputs as Images/{step}', img_grid, epoch, dataformats='HW')

                if 'rss_targets' in targets:
                    target_rss = standardize_image(targets['rss_targets'])
                    self.writer.add_image(f'{mode} RSS Targets/{step}', target_rss, epoch, dataformats='HW')

                if 'semi_kspace_targets' in targets:
                    semi_kspace_target_grid = make_k_grid(targets['semi_kspace_targets'],
                                                          self.smoothing_factor, self.shrink_scale)

                    self.writer.add_image(f'{mode} semi-k-space Targets/{step}',
                                          semi_kspace_target_grid, epoch, dataformats='HW')
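# --- Why the delta above is taken on complex images: |a - b| keeps phase differences
# that ||a| - |b|| discards. A quick check using the (real, imag) last-dim convention;
# complex_abs_sketch is a stand-in for the repo's complex_abs.
import torch


def complex_abs_sketch(t):
    # Magnitude of a tensor whose last dimension holds (real, imag).
    return torch.sqrt((t ** 2).sum(dim=-1))


a = torch.rand(10, 10, 2)
b = torch.rand(10, 10, 2)

delta_complex = complex_abs_sketch(a - b)  # Subtract first, then take the magnitude.
delta_magnitude = (complex_abs_sketch(a) - complex_abs_sketch(b)).abs()  # Magnitudes first.

print(torch.allclose(delta_complex, delta_magnitude))  # Almost surely False.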
    def forward(self, semi_kspace_outputs, targets, extra_params):
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (semi_kspace_outputs.size(-1) -
                semi_kspace_targets.size(-2)) // 2
        right = left + semi_kspace_targets.size(-2)

        # Cropping width dimension by pad.
        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.residual_acs:
            num_low_freqs = extra_params['num_low_frequency']
            acs_mask = find_acs_mask(semi_kspace_recons, num_low_freqs)
            semi_kspace_recons = semi_kspace_recons + acs_mask * semi_kspace_targets

        if self.replace:
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (
                1 - mask) + semi_kspace_targets * mask

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = center_crop(img_recons,
                                     (self.resolution, self.resolution))
            rss_recons = root_sum_of_squares(rss_recons, dim=1).squeeze()
            # The inputs were divided by this value, so it is multiplied back here.
            rss_recons *= extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
Example #12
    def forward(self, tensor,
                out_shape):  # Using out_shape only works for batch size of 1.
        """
        Args:
            tensor (torch.Tensor): Input tensor of shape [batch_size, in_chans, height, width]
            out_shape (tuple): shape [batch_size, num_coils, true_height, true_width].
            Note that in_chans = 2 * num_coils
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        stack = list()
        output = tensor
        # Apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.max_pool2d(output, kernel_size=2)

        output = self.conv(output)

        # Apply up-sampling layers
        for layer in self.up_sample_layers:
            output = F.interpolate(output,
                                   scale_factor=2,
                                   mode='bilinear',
                                   align_corners=False)
            output = torch.cat((output, stack.pop()), dim=1)
            output = layer(output)

        output = self.conv2(output)  # End of learning.

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (output.size(-1) - out_shape[-1]
                ) // 2  # This depends on mini-batch size being 1 to work.
        right = left + out_shape[-1]

        # Previously, cropping was done by  [pad:-pad]. However, this fails catastrophically when pad=0 as
        # all values are wiped out in those cases where [0:0] creates an empty tensor.

        # Cropping width dimension by pad.
        output = output[..., left:right]

        # Processing to k-space form.
        output = nchw_to_kspace(output)

        # Convert to image.
        output = complex_abs(ifft2(output))

        assert output.size() == out_shape  # Checking just in case.
        return output
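# --- The [pad:-pad] pitfall mentioned above, demonstrated: when pad == 0,
# tensor[..., 0:-0] means tensor[..., 0:0], which silently yields an empty tensor.
import torch

x = torch.arange(10)
pad = 0

print(x[pad:-pad].numel())  # 0 -- [0:-0] == [0:0] wipes out all values.

left = pad
right = x.size(-1) - pad
print(x[left:right].numel())  # 10 -- explicit endpoints stay safe when pad == 0.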
    def forward(self, kspace_outputs, targets, extra_params):
        if kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one slice at a time for now.')

        kspace_targets = targets['kspace_targets']

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_outputs.size(-1) - kspace_targets.size(-2)) // 2
        right = left + kspace_targets.size(-2)

        # Cropping width dimension by pad.
        kspace_recons = nchw_to_kspace(kspace_outputs[..., left:right])
        assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (kspace_recons.size(-3) % 2 == 0) and (kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            kspace_recons = kspace_recons / weighting

        if self.residual_acs:
            num_low_freqs = extra_params['num_low_frequency']
            acs_mask = find_acs_mask(kspace_recons, num_low_freqs)
            kspace_recons = kspace_recons + acs_mask * kspace_targets

        if self.replace:  # Replace with original k-space if replace=True
            mask = extra_params['masks']
            kspace_recons = kspace_recons * (1 - mask) + kspace_targets * mask

        cmg_recons = ifft2(kspace_recons)
        img_recons = complex_abs(cmg_recons)
        recons = {
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if img_recons.size(1) == 15:  # If multi-coil (fastMRI multi-coil data have 15 coils).
            top = (img_recons.size(-2) - self.resolution) // 2
            left = (img_recons.size(-1) - self.resolution) // 2
            rss_recon = img_recons[:, :, top:top + self.resolution,
                                   left:left + self.resolution]
            rss_recon = root_sum_of_squares(
                rss_recon, dim=1).squeeze()  # rss_recon is in 2D
            recons['rss_recons'] = rss_recon

        return recons  # Returning scaled reconstructions. Not rescaled.
Example #14
    def forward(self, semi_kspace_output: Tensor, targets: dict,
                extra_params: dict):
        assert semi_kspace_output.dim() == 5 and semi_kspace_output.size(1) == 2, 'Invalid shape!'
        if semi_kspace_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_target = targets['semi_kspace_targets']
        semi_kspace_recon = semi_kspace_output.permute(
            dims=(0, 2, 3, 4, 1))  # Convert back into NCHW2

        if semi_kspace_recon.shape != semi_kspace_target.shape:  # Cropping recon left-right.
            left = (semi_kspace_recon.size(-2) -
                    semi_kspace_target.size(-2)) // 2
            semi_kspace_recon = semi_kspace_recon[
                ..., left:left + semi_kspace_target.size(-2), :]

        assert semi_kspace_recon.shape == semi_kspace_target.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recon.size(-3) % 2 == 0) and (semi_kspace_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:
            semi_kspace_recon = semi_kspace_recon / extra_params['weightings']

        if self.replace:
            mask = extra_params['masks']
            semi_kspace_recon = semi_kspace_target * mask + (
                1 - mask) * semi_kspace_recon

        # kspace_recon = fft1(semi_kspace_recon, direction='height')
        cmg_recon = ifft1(semi_kspace_recon, direction='width')
        img_recon = complex_abs(cmg_recon)

        recons = {
            'semi_kspace_recons': semi_kspace_recon,
            'cmg_recons': cmg_recon,
            'img_recons': img_recon
        }

        if self.challenge == 'multicoil':
            rss_recon = center_crop(
                img_recon,
                (self.resolution, self.resolution)) * extra_params['sk_scales']
            rss_recon = root_sum_of_squares(rss_recon, dim=1).squeeze()
            recons['rss_recons'] = rss_recon

        return recons  # recons are not rescaled except rss_recons.
def _cmg_output(cmg_output, targets, extra_params):
    cmg_target = targets['cmg_targets']
    cmg_recon = nchw_to_kspace(cmg_output)  # Assumes data was cropped already.
    assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
    assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'
    cmg_recon = cmg_recon + targets['cmg_inputs']  # Residual of complex input.
    kspace_recon = fft2(cmg_recon)
    img_recon = complex_abs(cmg_recon)
    recons = {
        'kspace_recons': kspace_recon,
        'cmg_recons': cmg_recon,
        'img_recons': img_recon
    }
    return recons
    def __call__(self, masked_kspace, target, attrs, file_name, slice_num):
        assert isinstance(masked_kspace, torch.Tensor), \
            'Masked k-space was expected to be a PyTorch Tensor.'
        if masked_kspace.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            masked_kspace = masked_kspace.expand(1, 1, -1, -1, -1)
        elif masked_kspace.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            masked_kspace = masked_kspace.expand(1, -1, -1, -1, -1)
        elif masked_kspace.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('Masked k-space has invalid shape!')

        if masked_kspace.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            image = complex_abs(ifft2(masked_kspace))

            if self.crop_center:
                image = center_crop(image,
                                    shape=(self.resolution, self.resolution))

            img_scale = torch.std(image)
            image /= img_scale

            extra_params = {'img_scales': img_scale}
            extra_params.update(attrs)

            # Use plurals as keys to reduce confusion.
            targets = {'img_inputs': image}

            # margin = image.size(-1) % self.divisor
            # if margin > 0:
            #     pad = [(self.divisor - margin) // 2, (1 + self.divisor - margin) // 2]
            # else:  # This is a fix to prevent padding by half the divisor when margin=0.
            #     pad = [0, 0]
            #
            # # This pads at the last dimension of a tensor with 0.
            # inputs = F.pad(image, pad=pad, value=0)

        return image, targets, extra_params
Example #17
    def __call__(self, k_slice, target, attrs, file_name, slice_num):
        assert np.iscomplexobj(k_slice), 'kspace must be complex.'
        # assert k_slice.shape[-1] % 2 == 0, 'k-space data width must be even.'

        if k_slice.ndim == 2:  # For singlecoil. Makes data processing later on much easier.
            k_slice = np.expand_dims(k_slice, axis=0)
        elif k_slice.ndim != 3:  # Prevents possible errors.
            raise RuntimeError(
                'Invalid slice shape. Please check input shape.')

        with torch.no_grad():  # Remove unnecessary gradient calculations.
            # Now a Tensor of (num_coils, height, width, 2), where 2 is (real, imag).
            k_slice = to_tensor(k_slice).to(device=self.device)
            scaling = torch.std(k_slice)  # Pseudo-standard deviation for normalization.
            target_slice = complex_abs(ifft2(k_slice))  # Labels are not standardized.
            k_slice *= (1 / scaling)  # Standardization of CNN inputs.
            # Multiplication by the reciprocal is used because multiplying the whole tensor
            # by 1/scaling is much faster than dividing the whole tensor by scaling.

            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask = apply_mask(k_slice, self.mask_func, seed)

            data_slice = k_slice_to_chw(masked_kspace)

            margin = data_slice.size(-1) % self.divisor

            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a temporary fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            data_slice = F.pad(
                data_slice, pad=pad,
                value=0)  # This pads at the last dimension of a tensor with 0.

        return data_slice, target_slice, scaling  # This has a different output API.
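# --- Whether the reciprocal-multiplication trick above actually wins depends on
# hardware and backend; a rough CPU micro-benchmark sketch (timings indicative only).
import time

import torch

x = torch.rand(15, 640, 368, 2)
scale = torch.std(x)


def bench(fn, iters=100):
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - start


t_div = bench(lambda: x / scale)  # Element-wise division by a scalar tensor.
t_mul = bench(lambda: x * (1 / scale))  # One scalar division, then element-wise multiplication.
print(f'divide: {t_div:.4f}s  multiply-by-reciprocal: {t_mul:.4f}s')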
    def __call__(self, k_slice, target_slice):
        """
        'k_slice' should be in the (C, H, W) format.
        'target_slice' should be (Coil, Height, Width). It should be all real values. C = 2 * Coil
        k_slice should be padded while target_slice should not be padded.
        Singlecoil and Multicoil data both have the same dimensions. Singlecoil has Coil = 1
        """
        # assert k_slice.dim() == target_slice.dim() == 3
        # assert k_slice.size(0) / 2 == target_slice.size(0)

        # Remove padding, etc.
        k_slice = self.k_slice_fn(k_slice, target_slice)

        # Convert to image domain.
        recon_slice = complex_abs(ifft2(k_slice))

        # Image domain post-processing.
        recon_slice = self.img_slice_fn(recon_slice, target_slice)

        assert recon_slice.size() == target_slice.size(), 'Shape conversion went wrong somewhere.'
        return recon_slice
Example #19
def make_k_grid(kspace_recons, smoothing_factor=8, shrink_scale=1):
    """
    Function for making k-space visualizations for Tensorboard.
    """
    # Simple hack. Just use the first element if the input is a list --> batching implementation.
    if isinstance(kspace_recons, list):
        kspace_recons = kspace_recons[0].unsqueeze(dim=0)

    if kspace_recons.size(0) > 1:
        raise NotImplementedError(
            'Mini-batch size greater than 1 has not been implemented yet.')

    # Assumes that the smallest values will be close enough to 0 as to not matter much.
    kspace_grid = complex_abs(kspace_recons.detach()).squeeze(dim=0)
    # Scaling & smoothing.
    # smoothing_factor converted to float32 tensor. expm1 and log1p require float32 tensors.
    # They cannot accept python integers.
    sf = torch.tensor(smoothing_factor, dtype=torch.float32)
    kspace_grid *= torch.expm1(sf) / kspace_grid.max()
    kspace_grid = torch.log1p(kspace_grid)  # Adds 1 to input for natural log.
    kspace_grid /= kspace_grid.max()  # Standardization to 0~1 range.

    if kspace_grid.size(0) == 15:  # Multi-coil data from the fastMRI dataset have 15 coils.
        kspace_grid = torch.cat(
            torch.chunk(kspace_grid.view(-1, kspace_grid.size(-1)), chunks=5, dim=0), dim=1)

    if shrink_scale < 1:
        kspace_grid = F.interpolate(kspace_grid.expand(1, 1, -1, -1),
                                    scale_factor=shrink_scale,
                                    mode='bicubic',
                                    align_corners=True)
    elif shrink_scale > 1:
        # Raising UserWarning here would halt instead of continuing as the message promises;
        # warn and fall through. Requires 'import warnings' at the top of the module.
        warnings.warn('shrink_scale is expected to be at most 1. Using image with the same size as input.')

    return kspace_grid.squeeze().to(device='cpu', non_blocking=True)
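# --- The scaling in make_k_grid is logarithmic compression into [0, 1]: magnitudes are
# rescaled so the maximum becomes expm1(s) = e^s - 1, log1p sends that maximum to s,
# and the final division normalizes to 1. A short numeric check of that mapping:
import torch

s = torch.tensor(8.0)  # smoothing_factor as a float32 tensor.
x = torch.rand(64, 64) * 1e6  # Stand-in for k-space magnitudes with a huge dynamic range.

x = x * torch.expm1(s) / x.max()  # Maximum becomes expm1(s).
x = torch.log1p(x)  # log1p(expm1(s)) = s, so the maximum becomes s.
x = x / x.max()  # Standardize to the [0, 1] range.

assert x.max() == 1.0
print(x.min().item())  # Small values are boosted relative to a linear scale.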
    def forward(self, semi_kspace_outputs, targets, extra_params):
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one batch at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (semi_kspace_outputs.size(-1) -
                semi_kspace_targets.size(-2)) // 2
        right = left + semi_kspace_targets.size(-2)

        # Cropping width dimension by pad.
        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.replace:
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (
                1 - mask) + semi_kspace_targets * mask

        kspace_recons = fft1(semi_kspace_recons, direction=self.direction)
        cmg_recons = ifft1(semi_kspace_recons, direction=self.recon_direction)
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        return recons  # Returning scaled reconstructions. Not rescaled.
Example #21
    def forward(self, k_output, targets, scales):
        """
            Output post-processing for output k-space tensor with batch size of 1.
            This is experimental and is subject to change.
            Planning on taking k-space outputs from CNNs, then transforming them into original k-space shapes.
            No batch size planned yet.

            Args:
                k_output (torch.Tensor): CNN output of k-space. Expected to have batch-size of 1.
                targets (torch.Tensor): Target image domain data.
                scales (torch.Tensor): scaling factor used to divide the input k-space slice data.
            Returns:
                kspace (torch.Tensor): kspace in original shape with batch dimension in-place.
        """

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (k_output.size(-1) - targets.size(-1)
                ) // 2  # This depends on mini-batch size being 1 to work.
        right = left + targets.size(-1)

        # Previously, cropping was done by  [pad:-pad]. However, this fails catastrophically when pad == 0 as
        # all values are wiped out in those cases where [0:0] creates an empty tensor.

        # Cropping width dimension by pad. Multiply by scales to restore the original scaling.
        k_output = k_output[..., left:right] * scales

        # Processing to k-space form. This is where the batch_size == 1 is important.
        kspace_recons = nchw_to_kspace(k_output)

        # Convert to image.
        image_recons = complex_abs(ifft2(kspace_recons))

        assert image_recons.size() == targets.size(
        ), 'Reconstruction and target sizes do not match.'

        return image_recons, kspace_recons
    def forward(self, semi_kspace_outputs, targets, extra_params):
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs)
        semi_kspace_targets = targets['semi_kspace_targets']
        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.residual_acs:  # Adding the semi-k-space of the ACS as a residual. Necessary due to complex cropping.
            semi_kspace_acs = targets['semi_kspace_acss']
            semi_kspace_recons = semi_kspace_recons + semi_kspace_acs

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = root_sum_of_squares(img_recons, dim=1).squeeze()
            rss_recons *= extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
Example #23
import torch
import h5py

from data.data_transforms import ifft2, to_tensor, root_sum_of_squares, center_crop, complex_center_crop, complex_abs

file = '/media/veritas/D/FastMRI/multicoil_val/file1001798.h5'
sdx = 10
with h5py.File(file, mode='r') as hf:
    kspace = hf['kspace'][sdx]
    target = hf['reconstruction_rss'][sdx]

cmg_scale = 2E-5
recon = complex_center_crop(ifft2(to_tensor(kspace) / cmg_scale),
                            shape=(320, 320)) * cmg_scale
recon = root_sum_of_squares(complex_abs(recon))
target = to_tensor(target)

print(torch.allclose(recon, target))
Example #24
import torch
from data.data_transforms import complex_abs

complex_image = torch.rand(10, 20, 30, 2, dtype=torch.float64)
absolute_image = complex_abs(complex_image)
print(absolute_image.shape)
angle_image = torch.atan2(complex_image[..., 1], complex_image[..., 0])
recon_real = absolute_image * torch.cos(angle_image)
recon_imag = absolute_image * torch.sin(angle_image)
recon_image = torch.stack([recon_real, recon_imag], dim=-1)
print(torch.allclose(complex_image, recon_image))
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # img_input is not actually an input but what the input would look like in the image domain.
            img_input = complex_abs(ifft2(masked_kspace))

            semi_kspace = ifft1(masked_kspace, direction='height')

            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting

            # The weighting slope is irrelevant: after standardization the results are identical
            # for any slope. The ordering could be changed to allow a difference, but that would
            # make the inputs non-standardized.
            sk_scale = torch.std(semi_kspace)

            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {
                'sk_scales': sk_scale,
                'masks': mask,
                'weightings': weighting
            }
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            kspace_target /= sk_scale
            cmg_target = ifft2(kspace_target)
            img_target = complex_abs(cmg_target)
            semi_kspace_target = ifft1(kspace_target, direction='height')

            # Use plurals as keys to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input
            }

            if kspace_target.size(1) == 15:  # If multi-coil (fastMRI multi-coil data have 15 coils).
                targets['rss_targets'] = target  # Scaling needed for metric comparison later.

            margin = inputs.size(-1) % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            # This pads at the last dimension of a tensor with 0.
            inputs = F.pad(inputs, pad=pad, value=0)

        return inputs, targets, extra_params
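# --- Two facts the scaling code above relies on. First, the Fourier transform is linear,
# so dividing k-space by a scalar divides the reconstructed image by the same scalar.
# Second, standardization by the standard deviation is scale invariant, which is why the
# weighting slope does not matter. torch.fft is used here as a stand-in for the repo's ifft2.
import torch

k = torch.randn(15, 640, 368, dtype=torch.complex64)
s = 3.7

lhs = torch.fft.ifftn(k / s, dim=(-2, -1))  # Scale k-space first.
rhs = torch.fft.ifftn(k, dim=(-2, -1)) / s  # Scale the image instead.
print(torch.allclose(lhs, rhs, atol=1e-6))

x = torch.rand(100)
c = 42.0
print(torch.allclose((c * x) / torch.std(c * x), x / torch.std(x)))  # The slope c cancels out.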
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # Complex image made from down-sampled k-space.
            complex_image = ifft2(masked_kspace)

            if self.crop_center:
                complex_image = complex_center_crop(complex_image,
                                                    shape=(self.resolution,
                                                           self.resolution))

            cmg_scale = torch.std(complex_image)
            complex_image /= cmg_scale

            extra_params = {'cmg_scales': cmg_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            kspace_target /= cmg_scale
            cmg_target = ifft2(kspace_target)

            if self.crop_center:
                cmg_target = complex_center_crop(cmg_target,
                                                 shape=(self.resolution,
                                                        self.resolution))

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:  # No rotation implemented.
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    # Last dim is real/complex dimension for complex image and target.
                    complex_image = torch.flip(complex_image, dims=(-3, -2))
                    cmg_target = torch.flip(cmg_target, dims=(-3, -2))
                    target = torch.flip(target, dims=(
                        -2, -1))  # Has only two dimensions, height and width.
                    kspace_target = fft2(cmg_target)

                elif flip_ud:
                    complex_image = torch.flip(complex_image, dims=(-3, ))
                    cmg_target = torch.flip(cmg_target, dims=(-3, ))
                    target = torch.flip(target, dims=(-2, ))
                    kspace_target = fft2(cmg_target)

                elif flip_lr:
                    complex_image = torch.flip(complex_image, dims=(-2, ))
                    cmg_target = torch.flip(cmg_target, dims=(-2, ))
                    target = torch.flip(target, dims=(-1, ))
                    kspace_target = fft2(cmg_target)

            # The image target is obtained after flipping the complex image.
            # This removes the need to flip the image target.
            img_target = complex_abs(cmg_target)
            img_inputs = complex_abs(complex_image)

            # Use plurals as keys to reduce confusion.
            targets = {
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'cmg_inputs': complex_image,
                'img_inputs': img_inputs
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

            # Creating concatenated image of real/imag/abs channels.
            concat_image = torch.cat(
                [complex_image, img_inputs.unsqueeze(dim=-1)], dim=-1)

            # Converting to NCHW format for CNN.
            inputs = kspace_to_nchw(concat_image)

        return inputs, targets, extra_params
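# --- Shape sketch for the real/imag/abs concatenation above: appending the magnitude
# gives three values per coil, so the NCHW input has 3 * num_coils channels. Sizes are
# illustrative, and the permute/reshape mirrors the assumed kspace_to_nchw contract.
import torch

complex_image = torch.rand(1, 15, 320, 320, 2)  # (N, C, H, W, 2)
img_inputs = torch.sqrt((complex_image ** 2).sum(dim=-1))  # Magnitude, (N, C, H, W).

# Append the magnitude as a third value per coil: (N, C, H, W, 3).
concat_image = torch.cat([complex_image, img_inputs.unsqueeze(dim=-1)], dim=-1)

# Fold the last dimension into channels, as kspace_to_nchw is assumed to do.
n, c, h, w, k = concat_image.shape
inputs = concat_image.permute(0, 1, 4, 2, 3).reshape(n, c * k, h, w)
print(inputs.shape)  # torch.Size([1, 45, 320, 320]) -- 3 * num_coils channels.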
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            image = complex_abs(ifft2(masked_kspace))
            if self.crop_center:
                image = center_crop(image,
                                    shape=(self.resolution, self.resolution))
            img_scale = torch.std(image)
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            img_target = complex_abs(ifft2(kspace_target))
            if self.crop_center:
                img_target = center_crop(img_target,
                                         shape=(self.resolution,
                                                self.resolution))
            img_target /= img_scale

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    image = torch.flip(image, dims=(-2, -1))
                    img_target = torch.flip(img_target, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    image = torch.flip(image, dims=(-2, ))
                    img_target = torch.flip(img_target, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    image = torch.flip(image, dims=(-1, ))
                    img_target = torch.flip(img_target, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            targets = {'img_targets': img_target, 'img_inputs': image}

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return image, targets, extra_params
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            num_low_freqs = info['num_low_frequency']
            acs_mask = self.find_acs_mask(kspace_target, num_low_freqs)
            acs_kspace = kspace_target * acs_mask
            semi_kspace_acs = fft1(
                complex_center_crop(ifft2(acs_kspace), shape=(self.resolution, self.resolution)),
                direction='width')

            complex_image = ifft2(masked_kspace)
            complex_image = complex_center_crop(complex_image,
                                                shape=(self.resolution,
                                                       self.resolution))
            # img_input is not actually an input but what the input would look like in the image domain.
            img_input = complex_abs(complex_image)

            # Direction is fixed due to challenge conditions.
            semi_kspace = fft1(complex_image, direction='width')
            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting

            sk_scale = torch.std(semi_kspace)
            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {
                'sk_scales': sk_scale,
                'masks': mask,
                'weightings': weighting
            }
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            cmg_target = ifft2(kspace_target)
            cmg_target = complex_center_crop(cmg_target,
                                             shape=(self.resolution,
                                                    self.resolution))
            cmg_target /= sk_scale
            img_target = complex_abs(cmg_target)
            semi_kspace_target = fft1(cmg_target, direction='width')
            kspace_target = fft2(cmg_target)

            # Use plurals as keys to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input,
                'semi_kspace_acss': semi_kspace_acs
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return inputs, targets, extra_params
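# --- fft1/ifft1 above transform along a single axis, producing "semi-k-space" data that
# is k-space along one dimension and image domain along the other. A stand-in sketch
# with torch.fft (the repo versions also handle fftshift and the (real, imag) last
# dimension, which are omitted here):
import torch

x = torch.randn(15, 640, 368, dtype=torch.complex64)

semi_kspace = torch.fft.fft(x, dim=-1)  # Transform along the width axis only.
restored = torch.fft.ifft(semi_kspace, dim=-1)  # Inverse along the same axis.
print(torch.allclose(x, restored, atol=1e-4))

# Transforming the remaining axis yields the full 2D k-space (the 2D FFT is separable).
full_kspace = torch.fft.fft(semi_kspace, dim=-2)
print(torch.allclose(full_kspace, torch.fft.fftn(x, dim=(-2, -1)), atol=1e-2))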
Example #29
import torch
from data.data_transforms import ifft2, fft2, complex_abs

image = torch.rand(10, 20, 30, 2)
lr_flip = torch.flip(image, dims=[-2])
ud_flip = torch.flip(image, dims=[-3])
all_flip = torch.flip(image, dims=[-3, -2])

kspace = fft2(image)
lr_kspace = fft2(lr_flip)
ud_kspace = fft2(ud_flip)
all_kspace = fft2(all_flip)

absolute = torch.sum(complex_abs(kspace))
lr_abs = torch.sum(complex_abs(lr_kspace))
ud_abs = torch.sum(complex_abs(ud_kspace))
all_abs = torch.sum(complex_abs(all_kspace))

a = torch.allclose(absolute, lr_abs)
b = torch.allclose(absolute, ud_abs)
c = torch.allclose(absolute, all_abs)

print(a, b, c)


Example #30
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            # No subtraction by the mean. I do not know if this is a good idea or not.

            if self.use_patch:
                assert self.resolution == 320
                left = torch.randint(low=0,
                                     high=self.resolution - self.patch_size,
                                     size=(1, )).squeeze()
                top = torch.randint(low=0,
                                    high=self.resolution - self.patch_size,
                                    size=(1, )).squeeze()

                input_image = input_image[:, :, top:top + self.patch_size,
                                          left:left + self.patch_size]
                target = target[top:top + self.patch_size,
                                left:left + self.patch_size]

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    input_image = torch.flip(input_image, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

        return input_image, targets, extra_params