# Example 1
    def validation_step(self, batch, batch_idx):
        """Run one validation batch and package tensors for epoch-end metrics."""
        masked_kspace, mask, target, fname, slice_num, max_value, _ = batch

        output = self.forward(masked_kspace, mask)
        target, output = T.center_crop_to_smallest(target, output)

        # Filenames are strings; hash each one to a bounded integer so that
        # PyTorch can concatenate them as a tensor across batches.
        hashed_names = [
            int(hashlib.sha256(name.encode("utf-8")).hexdigest(), 16) % 10**12
            for name in fname
        ]
        fnumber = torch.tensor(hashed_names,
                               dtype=torch.long,
                               device=output.device)

        val_loss = self.loss(output.unsqueeze(1),
                             target.unsqueeze(1),
                             data_range=max_value)
        return {
            "fname": fnumber,
            "slice": slice_num,
            "output": output,
            "target": target,
            "val_loss": val_loss,
        }
# Example 2
    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Uses the legacy pytorch-lightning return convention where the "log"
        dict is forwarded to the logger.
        """
        masked_kspace, mask, target, _, _, max_value, _ = batch

        prediction = self(masked_kspace, mask)

        target, prediction = T.center_crop_to_smallest(target, prediction)
        loss = self.loss(prediction.unsqueeze(1),
                         target.unsqueeze(1),
                         data_range=max_value)

        return {"loss": loss, "log": {"train_loss": loss.item()}}
# Example 3
    def training_step(self, batch, batch_idx):
        """Forward one batch, compute the loss, and log it."""
        masked_kspace, mask, target, _, _, max_value, _ = batch

        prediction = self(masked_kspace, mask)

        # Crop both tensors to the smaller spatial size before the loss.
        target, prediction = transforms.center_crop_to_smallest(
            target, prediction)
        loss = self.loss(prediction.unsqueeze(1),
                         target.unsqueeze(1),
                         data_range=max_value)

        self.log("train_loss", loss)

        return loss
# Example 4
    def training_step(self, batch, batch_idx):
        """Run the model on one batch and return the training loss."""
        prediction = self(batch.masked_kspace, batch.mask,
                          batch.num_low_frequencies)

        # Crop prediction and target to the smaller spatial size.
        target, prediction = transforms.center_crop_to_smallest(
            batch.target, prediction)

        loss = self.loss(prediction.unsqueeze(1),
                         target.unsqueeze(1),
                         data_range=batch.max_value)
        self.log("train_loss", loss)

        return loss
# Example 5
    def training_loss(self, batch):
        """Compute the L1 training loss for one batch.

        Returns:
            Tuple of ``(loss, output, target)`` where both tensors have been
            center-cropped to the smaller spatial size.

        Raises:
            FloatingPointError: if NaN/Inf detection is enabled and the model
                output contains non-finite values.
        """
        output, target = self.predict(batch)
        output, target = transforms.center_crop_to_smallest(output, target)

        if self.args.nan_detection:
            # Fail fast on diverging runs. FloatingPointError is a subclass
            # of Exception, so existing `except Exception` handlers still
            # catch it (previously a bare `Exception` was raised).
            if torch.any(torch.isnan(output)):
                print(output)
                raise FloatingPointError("nan encountered")
            if torch.any(torch.isinf(output)):
                print(output)
                raise FloatingPointError("inf encountered")

        loss = F.l1_loss(output, target)
        return loss, output, target
# Example 6
    def validation_step(self, batch, batch_idx):
        """Validate one batch; metrics use the image domain, loss uses k-space."""
        masked_kspace, mask, target, fname, slice_num, max_value, _ = batch

        kspace_pred = self(masked_kspace, mask)
        # Image-domain reconstruction for the epoch-end image metrics.
        image_pred = fastmri.complex_abs(fastmri.ifft2c(kspace_pred))
        target, image_pred = transforms.center_crop_to_smallest(
            target, image_pred)

        # NOTE(review): val_loss compares predicted k-space against the
        # *masked* input k-space — presumably a data-consistency style
        # objective; confirm this matches the training loss.
        loss = self.loss(kspace_pred, masked_kspace)

        return {
            "batch_idx": batch_idx,
            "fname": fname,
            "slice_num": slice_num,
            "max_value": max_value,
            "output": image_pred,
            "target": target,
            "val_loss": loss,
        }
# Example 7
    def validation_step(self, batch, batch_idx):
        """Run inference on one validation batch and collect the results."""
        masked_kspace, mask, target, fname, slice_num, max_value, _ = batch

        reconstruction = self.forward(masked_kspace, mask)
        target, reconstruction = transforms.center_crop_to_smallest(
            target, reconstruction)

        loss = self.loss(reconstruction.unsqueeze(1),
                         target.unsqueeze(1),
                         data_range=max_value)
        return {
            "batch_idx": batch_idx,
            "fname": fname,
            "slice_num": slice_num,
            "max_value": max_value,
            "output": reconstruction,
            "target": target,
            "val_loss": loss,
        }
# Example 8
    def validation_step(self, batch, batch_idx):
        """Run inference on one validation batch and collect the results."""
        reconstruction = self.forward(batch.masked_kspace, batch.mask,
                                      batch.num_low_frequencies)
        target, reconstruction = transforms.center_crop_to_smallest(
            batch.target, reconstruction)

        loss = self.loss(reconstruction.unsqueeze(1),
                         target.unsqueeze(1),
                         data_range=batch.max_value)
        return {
            "batch_idx": batch_idx,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": reconstruction,
            "target": target,
            "val_loss": loss,
        }
# Example 9
    def compute_stats(self, epoch, loader, setname):
        """Evaluate the model on the dev set and aggregate per-volume metrics.

        This is separate from stats mainly for distributed support.

        Returns:
            Dict with NMSE/PSNR/SSIM aggregates plus a mean SSIM per
            ``<machine>_<acquisition>`` grouping.
        """
        args = self.args
        self.model.eval()
        ndevbatches = len(self.dev_loader)
        logging.info(f"Evaluating {ndevbatches} batches ...")

        recons, gts = defaultdict(list), defaultdict(list)
        acquisition_machine_by_fname = dict()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.dev_loader):
                progress = epoch + batch_idx / ndevbatches
                logging_epoch = batch_idx % args.log_interval == 0
                logging_epoch_info = batch_idx % (2 * args.log_interval) == 0
                log = logging.info if logging_epoch_info else logging.debug

                self.start_of_test_batch_hook(progress, logging_epoch)

                batch = self.preprocess_data(batch)
                output, target = self.predict(batch)
                output = self.unnorm(output, batch)
                target = self.unnorm(target, batch)
                fname, slice = batch.fname, batch.slice

                # Accumulate per-slice predictions on the CPU, keyed by volume.
                for i in range(output.shape[0]):
                    slice_cpu = slice[i].item()
                    recons[fname[i]].append(
                        (slice_cpu, output[i].float().cpu().numpy()))
                    gts[fname[i]].append(
                        (slice_cpu, target[i].float().cpu().numpy()))

                    acquisition_type = batch.attrs_dict['acquisition'][i]
                    machine_type = batch.attrs_dict['system'][i]
                    acquisition_machine_by_fname[
                        fname[i]] = machine_type + '_' + acquisition_type

                if logging_epoch or batch_idx == ndevbatches - 1:
                    gpu_memory_gb = torch.cuda.memory_allocated() / 1000000000
                    host_memory_gb = utils.host_memory_usage_in_gb()
                    # Bug fix: host memory previously printed gpu_memory_gb.
                    log(f"Evaluated {batch_idx+1} of {ndevbatches} (GPU Mem: {gpu_memory_gb:2.3f}gb Host Mem: {host_memory_gb:2.3f}gb)"
                        )
                    sys.stdout.flush()

                if self.args.debug_epoch_stats:
                    break
                del output, target, batch

            logging.debug("Finished evaluating")
            self.end_of_test_epoch_hook()

            # Stack the collected slices into volumes, ordered by slice index.
            recons = {
                fname: np.stack([pred for _, pred in sorted(slice_preds)])
                for fname, slice_preds in recons.items()
            }
            gts = {
                fname: np.stack([pred for _, pred in sorted(slice_preds)])
                for fname, slice_preds in gts.items()
            }

            nmse, psnr, ssims = [], [], []
            ssim_for_acquisition_machine = defaultdict(list)
            # list() already copies the key view; we iterate this snapshot
            # while deleting entries from the dicts below.
            recon_keys = list(recons.keys())
            for fname in recon_keys:
                pred_or, gt_or = recons[fname].squeeze(1), gts[fname].squeeze(
                    1)
                pred, gt = transforms.center_crop_to_smallest(pred_or, gt_or)
                del pred_or, gt_or

                ssim = evaluate.ssim(gt, pred)
                acquisition_machine = acquisition_machine_by_fname[fname]
                ssim_for_acquisition_machine[acquisition_machine].append(ssim)
                ssims.append(ssim)
                nmse.append(evaluate.nmse(gt, pred))
                psnr.append(evaluate.psnr(gt, pred))
                del gt, pred
                # Free volumes eagerly to keep peak host memory bounded.
                del recons[fname], gts[fname]

            if len(nmse) == 0:
                # No volumes evaluated (e.g. empty dev set): report zeros.
                nmse.append(0)
                ssims.append(0)
                psnr.append(0)

            # Bug fix: when no volumes were evaluated, recon_keys is empty and
            # indexing it with argmin over the padded zero entry raised
            # IndexError; report a placeholder instead.
            if recon_keys:
                min_vol = str(recon_keys[np.argmin(ssims)])
            else:
                min_vol = "n/a"
            logging.info(f"Min vol ssims: {min_vol}")
            sys.stdout.flush()

            del recons, gts

            acquisition_machine_losses = dict.fromkeys(
                self.dev_data.system_acquisitions, 0)
            for key, value in ssim_for_acquisition_machine.items():
                acquisition_machine_losses[key] = np.mean(value)

            losses = {
                'NMSE': np.mean(nmse),
                'PSNR': np.mean(psnr),
                'SSIM': np.mean(ssims),
                'SSIM_var': np.var(ssims),
                'SSIM_min': np.min(ssims),
                **acquisition_machine_losses
            }

        return losses
# Example 10
def test_center_crop_to_smallest(x_shape, y_shape, target_shape):
    """Both outputs must share the smallest shape, equal to target_shape."""
    x_out, y_out = transforms.center_crop_to_smallest(
        create_input(x_shape), create_input(y_shape))
    assert x_out.shape == y_out.shape
    assert list(x_out.shape) == target_shape