def __getitem__(self, item):
    """Return sample ``item`` either as a torchio ``Subject`` or as tensor(s).

    Behaviour depends on two flags read from ``self``:

    * ``loadASTrain`` true  -> only the first element of the sample (the
      image) is returned.
    * ``loadASTrain`` false -> image and ground truth are returned together.
    * ``torchiosub`` true   -> results are wrapped as ``H5DSImage`` objects
      inside a ``Subject`` under the keys ``'img'`` (and ``'gt'``).
    * ``torchiosub`` false  -> raw tensors with a leading singleton channel
      axis added via ``unsqueeze(0)``.

    NOTE(review): entries of ``self.samples`` presumably hold h5py datasets
    (``H5DS`` naming, ``[()]`` full read) - confirm against the loader.
    """
    train_only = self.loadASTrain
    as_subject = self.torchiosub
    sample = self.samples[item]

    def as_tensor(dataset):
        # ``[()]`` materialises the full array before handing it to torch.
        return torch.from_numpy(dataset[()]).unsqueeze(0)

    def as_image(dataset):
        return H5DSImage(dataset, lazypatch=self.lazypatch)

    if train_only:
        if not as_subject:
            return as_tensor(sample[0])
        return Subject({'img': as_image(sample[0])})

    if not as_subject:
        return (as_tensor(sample[0]), as_tensor(sample[1]))

    if len(sample) == 2:
        return Subject({'img': as_image(sample[0]), 'gt': as_image(sample[1])})

    # TODO: not a (img, gt) pair - the whole entry is used for both keys
    # as a stop-gap (flagged as "dirty" by the original author).
    return Subject({'img': as_image(sample), 'gt': as_image(sample)})
def __getitem__(self, item):
    """Return sample ``item`` as a torchio ``Subject`` or as a tensor.

    When ``self.torchiosub`` is true the sample is wrapped in an
    ``H5DSImage`` stored under the ``'img'`` key of a ``Subject``; otherwise
    the sample is fully read with ``[()]`` and converted to a torch tensor
    with a leading singleton channel axis (``unsqueeze(0)``).
    """
    use_subject = self.torchiosub
    sample = self.samples[item]
    if use_subject:
        return Subject({'img': H5DSImage(sample, lazypatch=self.lazypatch)})
    volume = sample[()]
    return torch.from_numpy(volume).unsqueeze(0)
def plot_subject(subject: Subject, save_plot_path: str):
    """Show every image of ``subject`` and optionally dump per-slice PNGs plus a GIF.

    The images are re-wrapped with a diagonal affine whose spacing per axis
    is ``min(shape) / shape[axis]``, so the displayed volume is scaled by the
    smallest spatial dimension. Label maps are passed through
    ``squeeze_segmentation`` first.

    Args:
        subject: torchio subject whose images are plotted.
        save_plot_path: directory that receives ``000.png``, ``001.png``, ...
            (one image per slice index) and ``<dirname>.gif``. When falsy,
            only the interactive plot is shown and nothing is written.
    """
    if save_plot_path:
        os.makedirs(save_plot_path, exist_ok=True)

    dim_x, dim_y, dim_z = subject.spatial_shape
    smallest = min(dim_x, dim_y, dim_z)
    affine = np.eye(4) * np.array([smallest / dim_x, smallest / dim_y, smallest / dim_z, 1])

    data_dict = {}
    for name, image in subject.get_images_dict(intensity_only=False).items():
        # Each image gets its own copy of the affine so none of them share a buffer.
        if isinstance(image, LabelMap):
            data_dict[name] = LabelMap(tensor=squeeze_segmentation(image), affine=affine.copy())
        else:
            data_dict[name] = ScalarImage(tensor=image.data, affine=affine.copy())
    out_subject = Subject(data_dict)
    out_subject.plot(reorient=False, show=True, figsize=(10, 10))

    if not save_plot_path:
        # Without a target directory the paths below would become "/000.png"
        # (filesystem root) or "None/000.png", so stop after the live plot.
        return

    # NOTE(review): assumes import_mpl_plt returns (matplotlib, matplotlib.pyplot).
    mpl, plt = import_mpl_plt()
    previous_backend = mpl.get_backend()
    was_interactive = plt.isinteractive()
    plt.ioff()
    mpl.use("agg")  # headless backend for writing the slice images
    try:
        shape = out_subject.spatial_shape
        for x in range(max(shape)):
            # Clamp the index per axis so shorter axes repeat their last slice.
            indices = tuple(min(x, dim - 1) for dim in shape)
            out_subject.plot(
                reorient=False,
                indices=indices,
                output_path=os.path.join(save_plot_path, f"{x:03d}.png"),
                show=False,
                figsize=(10, 10),
            )
            plt.close("all")
    finally:
        # Restore the caller's matplotlib state even if plotting failed.
        if was_interactive:
            plt.ion()
        mpl.use(previous_backend)

    # normpath strips a trailing separator so the GIF is not named ".gif".
    gif_name = os.path.basename(os.path.normpath(save_plot_path))
    create_gifs(save_plot_path, os.path.join(save_plot_path, f"{gif_name}.gif"))
def plot_aggregated_image(
    writer: SummaryWriter,
    epoch: int,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,  # type: ignore
    device: torch.device,
    save_path: str,
):
    """Predict one random subject patch-by-patch and plot image, label and prediction.

    A subject is drawn with ``random_subject_from_loader``. Its patches are
    fed through ``model``, and the input, the ``"seg"`` target and the
    thresholded prediction (``sigmoid(logits) > 0.5``) are each stitched back
    together with a ``GridAggregator``. The result is handed to
    ``plot_subject`` under ``<save_path>/<epoch>-<subject_id>``.

    The model is run in eval mode without autograd and its previous
    train/eval mode is restored afterwards. ``writer`` is accepted for
    interface compatibility but is not used here.
    """
    sampler, subject_id = random_subject_from_loader(data_loader)
    aggregator_x = GridAggregator(sampler)
    aggregator_y = GridAggregator(sampler)
    aggregator_y_pred = GridAggregator(sampler)

    # Visualisation must not update BatchNorm statistics or depend on dropout,
    # and needs no autograd graph.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch, locations in batches_from_sampler(sampler, data_loader.batch_size):
                x: torch.Tensor = batch["image"]["data"]
                aggregator_x.add_batch(x, locations)
                y: torch.Tensor = batch["seg"]["data"]
                aggregator_y.add_batch(y, locations)
                logits = model(x.to(device))
                y_pred = (torch.sigmoid(logits) > 0.5).float()
                aggregator_y_pred.add_batch(y_pred, locations)
    finally:
        model.train(was_training)

    plot_subject(
        Subject(
            image=ScalarImage(tensor=aggregator_x.get_output_tensor()),
            true_seg=LabelMap(tensor=aggregator_y.get_output_tensor()),
            pred_seg=LabelMap(tensor=aggregator_y_pred.get_output_tensor()),
        ),
        f"{save_path}/{epoch}-{subject_id}",
    )