def eval(self, *metrics: TorchLoss, datasets: List[TorchDataset], atlas) -> Tuple[TorchLoss]:
    """Evaluate ``self.network`` on ``datasets`` and accumulate the given metrics.

    Args:
        *metrics: Metric/loss objects. Each is ``reset()`` first and then fed every batch through
            ``cumulate(preds, outputs)``.
        datasets: Datasets combined into one sequentially sampled ``MultiDataLoader``.
        atlas: Passed as the second argument to every ``self.network(inputs, atlas)`` call.

    Returns:
        The ``metrics`` tuple: the same objects, now holding the accumulated values.

    The network is switched to (and left in) eval mode. Forward passes run under ``torch.no_grad()``
    because evaluation never back-propagates.
    """
    import torch  # local import: only needed for the no-grad context

    val_data_loader = MultiDataLoader(
        *datasets,
        batch_size=self.config.batch_size,
        sampler_cls="sequential",
        num_workers=self.config.num_workers,
        pin_memory=self.config.pin_memory,
    )
    self.network.eval()
    for metric in metrics:
        metric.reset()
    with torch.no_grad():
        for inputs, outputs in val_data_loader:
            inputs = prepare_tensors(inputs, self.config.gpu, self.config.device)
            preds = self.network(inputs, atlas)
            # Device placement is the same for every metric, so do it once per batch.
            preds = prepare_tensors(preds, self.config.gpu, self.config.device)
            outputs = prepare_tensors(outputs, self.config.gpu, self.config.device)
            for metric in metrics:
                metric.cumulate(preds, outputs)
    return metrics
def inference(self, epoch: int):
    """Run ``self.inference_func`` on a random sample of images from every validation set.

    The datasets in ``self.validation_sets`` are processed first, then those in
    ``self.extra_validation_sets`` (which may be empty or ``None``). For each dataset, up to
    ``self.config.n_inference`` distinct indices are drawn with ``np.random.choice``.

    For every drawn index, the following are passed to
    ``self.inference_func((image, template), label, image_path, self.network, sample_dir)``:

    * the image, with a leading batch axis and moved to the configured device;
    * the dataset's ``template`` (prepared once per dataset);
    * the label and the image path.

    ``sample_dir`` is ``<experiment_dir>/inference/CP_<epoch>/<dataset name>/<image parent dir name>``
    and is created if missing.
    """
    output_dir = self.config.experiment_dir.joinpath("inference").joinpath("CP_{}".format(epoch))
    datasets = list(self.validation_sets) + list(self.extra_validation_sets or [])
    for val in datasets:
        n_images = len(val.image_paths)
        # Sample without replacement: a duplicate index would only redo inference into the same folder.
        n_samples = min(self.config.n_inference, n_images)
        if n_samples <= 0:
            continue
        indices = np.random.choice(n_images, n_samples, replace=False)

        # The template is the same for every sample of this dataset, so prepare it once.
        template = torch.unsqueeze(torch.from_numpy(val.template).float(), 0)
        template = prepare_tensors(template, self.config.gpu, self.config.device)

        for idx in indices:
            image_path = val.image_paths[idx]
            sample_dir = output_dir.joinpath(val.name, image_path.parent.stem)
            sample_dir.mkdir(exist_ok=True, parents=True)
            image = val.get_image_tensor_from_index(idx)
            label = val.get_label_tensor_from_index(idx)
            image = prepare_tensors(torch.unsqueeze(image, 0), self.config.gpu, self.config.device)
            self.logger.info("Inferencing for {} dataset, image {}.".format(val.name, idx))
            self.inference_func((image, template), label, image_path, self.network, sample_dir)
def update_atlas(self, train_data_loader, atlas_label, atlas, eta: float = 0.01):
    """Update ``atlas`` with the mean warped image and label over the training set.

    Runs ``self.network(inputs, atlas_label)`` on every batch of ``train_data_loader``. It sums, over
    the batch dimension:

    * ``predicted[-5]``, the image, with its channel axis squeezed;
    * ``predicted[-4]``, the label.

    After the pass it calls ``atlas.update(mean_image, mean_label, eta=eta)`` with the per-sample means.
    Which quantities ``predicted[-5]`` and ``predicted[-4]`` hold depends on the network's forward().
    They are presumably the warped image and warped label in atlas space (TODO confirm).

    Returns:
        The same ``atlas`` object, after ``atlas.update`` has been called. If the loader yields no
        samples, ``atlas`` is returned without calling ``update``.
    """
    image_sum = None
    label_sum = None
    n_samples = 0
    # Forward passes only: no autograd graph is needed while averaging.
    with torch.no_grad():
        for inputs, _outputs in tqdm(train_data_loader):
            inputs = prepare_tensors(inputs, self.config.gpu, self.config.device)
            predicted = self.network(inputs, atlas_label)  # pass updated atlas in
            batch_label = torch.sum(predicted[-4], dim=0).cpu().detach().numpy()
            batch_image = np.squeeze(torch.sum(predicted[-5], dim=0).cpu().detach().numpy(), axis=0)
            n_samples += predicted[-5].shape[0]
            # Accumulators are float64 (np.zeros default), independent of the network's dtype.
            if image_sum is None:
                image_sum = np.zeros(batch_image.shape)
            if label_sum is None:
                label_sum = np.zeros(batch_label.shape)
            image_sum += batch_image
            label_sum += batch_label
    if n_samples == 0:
        # Nothing to average; previously this divided None by 0 and crashed.
        return atlas
    atlas.update(image_sum / n_samples, label_sum / n_samples, eta=eta)
    return atlas
def main():
    """Segment one volume with a trained ``UNet`` and hand the result to ``inference``.

    Configuration comes from ``args.network_conf_path`` if given, otherwise ``TRAIN_CONF_PATH``. The
    image and label are located under ``args.input_dir`` using the first training dataset's
    ``image_label_format`` for ``args.phase``. The checkpoint at ``args.model_path`` is loaded onto
    ``args.device``. The network is put in eval mode, and ``inference`` runs without autograd.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    if args.network_conf_path is not None:
        conf_path = Path(args.network_conf_path)
    else:
        conf_path = TRAIN_CONF_PATH
    train_conf = ConfigFactory.parse_file(str(conf_path))
    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.training_datasets[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode,
    )
    image_label_format = dataset_config.image_label_format
    input_path = Path(args.input_dir).joinpath(image_label_format.image.format(phase=args.phase))
    label_path = input_path.parent.joinpath(image_label_format.label.format(phase=args.phase))
    output_dir = Path(args.output_dir)

    def network_conf(key):
        """Look up ``key`` in the ``network`` group of the training configuration."""
        return get_conf(train_conf, group="network", key=key)

    checkpoint = torch.load(str(model_path), map_location=torch.device(args.device))
    network = UNet(
        in_channels=network_conf("in_channels"),
        n_classes=network_conf("n_classes"),
        n_filters=network_conf("n_filters"),
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # Inference must not run in train mode, where layers such as batch norm/dropout (if present)
    # behave differently.
    network.eval()

    dataset = Torch2DSegmentationDataset(
        name=dataset_config.name,
        image_paths=[input_path],
        label_paths=[label_path],
        feature_size=network_conf("feature_size"),
        n_slices=network_conf("n_slices"),
        is_3d=True,
    )
    image = dataset.get_image_tensor_from_index(0)
    image = prepare_tensors(torch.unsqueeze(image, 0), True, args.device)
    label = dataset.get_label_tensor_from_index(0)
    with torch.no_grad():
        inference(
            image=image,
            label=label,
            image_path=input_path,
            network=network,
            output_dir=output_dir,
        )
def forward(self, inputs, atlas):
    """Segment the image, register the atlas to it, and map the image/label back to atlas space.

    Args:
        inputs: ``(image, label)`` pair.
        atlas: Atlas volume. It is stacked ``image.shape[0]`` times to form the batch, unless the cached
            ``self.batch_atlas`` can be used. That requires the batch size to equal ``self.batch_size``
            and the cache to be filled (not ``None``).

    Returns:
        ``(affine_warped_template, warped_template, pred_maps, preint_flow, warped_image, warped_label,
        affine_warped_label, batch_atlas, affine_warped_image)``. ``pred_maps`` are sigmoid outputs of
        the segmentation decoder.
    """
    image, label = inputs
    img_down1, img_down2, img_down3, img_down4, img_bridge = self.image_encoder(image)
    pred_maps = self.seg_decoder(img_down1, img_down2, img_down3, img_down4, img_bridge)
    pred_maps = torch.sigmoid(pred_maps)

    n_batch = image.shape[0]
    if n_batch == self.batch_size and self.batch_atlas is not None:
        batch_atlas = self.batch_atlas
    else:
        # Also covers an unfilled cache (self.batch_atlas is None), which used to crash the template encoder.
        batch_atlas = torch.stack([atlas for _ in range(n_batch)], dim=0)
    if n_batch == self.batch_size:
        batch_affine_added = self.batch_affine_added
    else:
        batch_affine_added = prepare_tensors(
            torch.stack([torch.from_numpy(np.array([[0, 0, 0, 1]])) for _ in range(n_batch)], dim=0),
            self.gpu,
            self.device,
        )

    # Affine stage: only the bridge features of the un-warped atlas feed the regressor.
    _, _, _, _, temp_bridge = self.template_encoder(batch_atlas)
    affine_params = self.affine_regressor(img_bridge, temp_bridge)
    affine_warped_template = self.affine_transformer(batch_atlas, affine_params)

    # Deformable stage on the affinely aligned atlas.
    temp_down1, temp_down2, temp_down3, temp_down4, temp_bridge = self.template_encoder(affine_warped_template)
    warped_template, preint_flow, _pos_flow, neg_flow = self.decoder_vxm(
        affine_warped_template, None,
        img_down1, img_down2, img_down3, img_down4, img_bridge,
        temp_down1, temp_down2, temp_down3, temp_down4, temp_bridge,
    )

    # Inverse transform of label to template space: inverse affine, then warp with neg_flow.
    inverse_affine_par = inverse_affine_params(affine_params, batch_affine_added)
    affine_warped_label = self.affine_transformer(label, inverse_affine_par)
    warped_label = self.decoder_vxm.transformer(affine_warped_label, neg_flow)

    # Inverse transform of image to template space.
    affine_warped_image = self.affine_transformer(image, inverse_affine_par)
    warped_image = self.decoder_vxm.transformer(affine_warped_image, neg_flow)
    return affine_warped_template, warped_template, pred_maps, preint_flow, warped_image, warped_label, \
        affine_warped_label, batch_atlas, affine_warped_image
def update_atlas(self, train_data_loader, atlas_label, atlas):
    """Collect per-batch network outputs over the training set and pass them to ``atlas.update``.

    Runs ``self.network(inputs, atlas_label)`` on every batch of ``train_data_loader`` and gathers two
    NumPy arrays per batch:

    * ``predicted[-3]``, used as the label;
    * ``predicted[-4]``, used as the image, with axis 1 squeezed.

    Which network outputs these are depends on the network's forward() (TODO confirm they are the
    atlas-space label and image). ``atlas.update`` receives two lists holding one array per batch.
    """
    import torch  # local import: only needed for the no-grad context below

    warped_labels = []
    warped_images = []
    # Forward passes only: disable autograd so no graph is built while sweeping the training set.
    with torch.no_grad():
        for inputs, _outputs in tqdm(train_data_loader):
            inputs = prepare_tensors(inputs, self.config.gpu, self.config.device)
            predicted = self.network(inputs, atlas_label)  # pass updated atlas in
            warped_labels.append(predicted[-3].cpu().detach().numpy())
            warped_images.append(np.squeeze(predicted[-4].cpu().detach().numpy(), axis=1))
    atlas.update(warped_images, warped_labels)
def forward(self, inputs, atlas):
    """Predict segmentation logits, register the atlas with the STN, and map image/label to atlas space.

    Args:
        inputs: ``(image, label)`` pair.
        atlas: Atlas volume. It is stacked ``image.shape[0]`` times to form the batch, unless the cached
            ``self.batch_atlas`` can be used. That requires the batch size to equal ``self.batch_size``
            and the cache to be filled (not ``None``).

    Returns:
        ``(affine_warped_atlas, warped_atlas, pred_maps_logits, preint_flow, warped_image, warped_label,
        affine_warped_label, batch_atlas, affine_warped_image)``. ``pred_maps_logits`` are raw decoder
        outputs; no sigmoid is applied here.
    """
    image, label = inputs
    n_batch = image.shape[0]
    if n_batch == self.batch_size and self.batch_atlas is not None:
        batch_atlas = self.batch_atlas
    else:
        # Also covers an unfilled cache (self.batch_atlas is None), which used to crash the STN.
        batch_atlas = torch.stack([atlas for _ in range(n_batch)], dim=0)
    if n_batch == self.batch_size:
        batch_affine_added = self.batch_affine_added
    else:
        batch_affine_added = prepare_tensors(
            torch.stack([torch.from_numpy(np.array([[0, 0, 0, 1]])) for _ in range(n_batch)], dim=0),
            self.gpu,
            self.device,
        )

    img_down1, img_down2, img_down3, img_down4, img_bridge = self.seg_encoder(image)
    pred_maps_logits = self.seg_decoder(img_down1, img_down2, img_down3, img_down4, img_bridge)
    flow, affine_params = self.stn(pred_maps_logits, batch_atlas)

    # Inverse transform of label to template space.
    inverse_affine_par = inverse_affine_params(affine_params, batch_affine_added)
    affine_warped_atlas = self.affine_transformer(batch_atlas, affine_params)
    affine_warped_label = self.affine_transformer(label, inverse_affine_par)
    warped_atlas, warped_label, preint_flow, _pos_flow, neg_flow = self.ffd_transformer(
        affine_warped_atlas, affine_warped_label, flow, bidir=True, resize=False)

    # Inverse transform of image to template space.
    affine_warped_image = self.affine_transformer(image, inverse_affine_par)
    warped_image = self.ffd_transformer.transformer(affine_warped_image, neg_flow)
    return affine_warped_atlas, warped_atlas, pred_maps_logits, preint_flow, warped_image, warped_label, \
        affine_warped_label, batch_atlas, affine_warped_image
def run(self, image: np.ndarray) -> np.ndarray:
    """Segment ``image`` with ``self.model`` and return a label map.

    A leading batch axis is added before the forward pass. The ``shape[2:5]`` indexing below requires
    the batched input to be 5-D, so ``image`` is presumably (C, Z, X, Y) (TODO confirm).

    Each output channel goes through a sigmoid and is thresholded at 0.5. Voxels of channel ``i`` get
    label ``i + 1``, and later channels overwrite earlier ones where they overlap. The resulting
    (Z, X, Y) map is returned transposed to (X, Y, Z).

    The forward pass runs under ``torch.no_grad()``, so no autograd graph is kept.
    """
    tensor = torch.unsqueeze(torch.from_numpy(image).float(), 0)
    tensor = prepare_tensors(tensor, True, self.device)
    with torch.no_grad():
        predicted = torch.sigmoid(self.model(tensor))
        predicted = (predicted > 0.5).float()
    predicted = np.squeeze(predicted.cpu().detach().numpy(), axis=0)

    # map back to original size
    final_predicted = np.zeros((tensor.shape[2], tensor.shape[3], tensor.shape[4]))
    for i in range(predicted.shape[0]):
        final_predicted[predicted[i, :, :, :] > 0.5] = i + 1
    return np.transpose(final_predicted, [1, 2, 0])
def __init__(self, in_channels, n_classes, n_slices, feature_size, n_filters, batch_norm: bool, group_norm, bidir,
             int_downsize, batch_size, gpu, device):
    """Build the image/template encoders, affine regressor, dense deformation decoder and segmentation decoder.

    All spatial modules work on volumes of shape ``(n_slices, feature_size, feature_size)``.

    Note: ``bidir`` is accepted but unused, because ``DecoderVxmDense`` is always built with
    ``bidir=False``.

    ``batch_atlas`` starts as ``None``. ``batch_affine_added`` holds ``batch_size`` stacked copies of
    the row ``[[0, 0, 0, 1]]``, prepared for ``gpu``/``device``.
    """
    super().__init__()
    volume_shape = (n_slices, feature_size, feature_size)
    norm_kwargs = dict(batch_norm=batch_norm, group_norm=group_norm)

    # Submodules are created in the original order so seeded initialisation is unchanged.
    self.image_encoder = Encoder(in_channels, n_filters, **norm_kwargs)
    self.template_encoder = Encoder(n_classes, n_filters, **norm_kwargs)
    self.affine_regressor = AffineRegressor(group_norm=group_norm)
    self.affine_transformer = AffineSpatialTransformer(size=volume_shape, mode="bilinear")
    self.decoder_vxm = DecoderVxmDense(inshape=volume_shape, n_filters=n_filters, int_downsize=int_downsize,
                                       bidir=False, **norm_kwargs)
    self.seg_decoder = SegDecoder(n_filters, batch_norm, group_norm, n_classes)

    self.batch_size = batch_size
    self.batch_atlas = None
    self.gpu = gpu
    self.device = device

    homogeneous_row = np.array([[0, 0, 0, 1]])
    stacked_rows = torch.stack([torch.from_numpy(homogeneous_row) for _ in range(batch_size)], dim=0)
    self.batch_affine_added = prepare_tensors(stacked_rows, gpu, device)
def __init__(self, in_channels, n_classes, n_slices, feature_size, n_filters, batch_norm: bool, group_norm, bidir,
             int_downsize, batch_size, gpu, device):
    """Build the segmentation (ITN) and registration (STN) branches plus the spatial transformers.

    Transformers work on volumes of shape ``(n_slices, feature_size, feature_size)``. ``bidir`` is
    forwarded to ``FFDTransformer``.

    Note: ``int_downsize`` is accepted but unused, because the FFD transformer is always built with
    ``int_downsize=2``.

    ``batch_atlas`` starts as ``None``. ``batch_affine_added`` holds ``batch_size`` stacked copies of
    the row ``[[0, 0, 0, 1]]``, prepared for ``gpu``/``device``.
    """
    super().__init__()
    volume_shape = (n_slices, feature_size, feature_size)

    # ITN
    self.seg_encoder = Encoder(in_channels, n_filters, batch_norm=batch_norm, group_norm=group_norm)
    self.seg_decoder = SegDecoder(n_filters, batch_norm, group_norm, n_classes)
    # STN
    self.stn = STN(num_filters=n_filters, group_norm=group_norm)

    self.batch_size, self.batch_atlas = batch_size, None
    self.gpu, self.device = gpu, device

    # configure transformer
    self.ffd_transformer = FFDTransformer(inshape=volume_shape, int_downsize=2, bidir=bidir, mode="bilinear")
    self.affine_transformer = AffineSpatialTransformer(size=volume_shape, mode="bilinear")

    homogeneous_row = torch.from_numpy(np.array([[0, 0, 0, 1]]))
    # torch.stack copies its inputs, so repeating one tensor yields the same result as fresh copies.
    self.batch_affine_added = prepare_tensors(torch.stack([homogeneous_row] * batch_size, dim=0), gpu, device)
def inference(image: np.ndarray, label: torch.Tensor, image_path: Path, network: torch.nn.Module, output_dir: Path,
              gpu, device, n_slices: int = 96, size_multiple: int = 32):
    """Segment one (Z, X, Y) volume and write ``seg.nii.gz``, ``image.nii.gz`` and ``label.nii.gz``.

    Preprocessing before the network:

    * Z is centre-cropped, or zero-padded where the volume is too short, to ``n_slices`` slices.
    * X and Y are zero-padded (split evenly before/after) up to the next multiple of ``size_multiple``.
    * The result is fed as a (1, 1, n_slices, X', Y') tensor.

    Postprocessing:

    * The prediction (sigmoid > 0.5 per channel) is mapped back onto the original (Z, X, Y) grid.
      Slices outside the cropped Z window are labelled 0.
    * It is collapsed to a label map where channel ``i`` becomes value ``i + 1``.

    Every file is saved in (X, Y, Z) order with the affine and pixdim of the NIfTI at ``image_path``.

    Args:
        image: Volume of shape (Z, X, Y).
        label: Per-class label tensor, presumably (n_classes, Z, X, Y) with 1.0 marking membership
            (TODO confirm).
        image_path: Source NIfTI whose affine/pixdim are copied to the outputs.
        network: Model applied to the preprocessed tensor. The caller is responsible for eval mode.
        output_dir: Destination directory, created if missing.
        gpu, device: Forwarded to ``prepare_tensors``.
        n_slices: Number of Z slices fed to the network. The default 96 is the previously hard-coded value.
        size_multiple: In-plane sizes are padded up to a multiple of this. The default 32 is the previous value.
    """
    original_image = image
    Z, X, Y = image.shape
    X2 = -(-X // size_multiple) * size_multiple  # integer ceil to a multiple
    Y2 = -(-Y // size_multiple) * size_multiple
    x_pre, y_pre = (X2 - X) // 2, (Y2 - Y) // 2
    x_post, y_post = (X2 - X) - x_pre, (Y2 - Y) - y_pre

    # Z window [z1, z2) of exactly n_slices slices centred on the volume; it may extend past [0, Z).
    z1 = Z // 2 - n_slices // 2
    z2 = z1 + n_slices
    z1_, z2_ = max(z1, 0), min(z2, Z)
    padded = np.pad(image[z1_:z2_, :, :], ((z1_ - z1, z2 - z2_), (x_pre, x_post), (y_pre, y_post)), 'constant')

    tensor = torch.unsqueeze(torch.from_numpy(np.expand_dims(padded, 0)).float(), 0)
    tensor = prepare_tensors(tensor, gpu, device)
    with torch.no_grad():
        predicted = (torch.sigmoid(network(tensor)) > 0.5).float()
    predicted = np.squeeze(predicted.cpu().detach().numpy(), axis=0)

    # Undo the crop/pad: index k of the network Z axis is original slice z1 + k.
    predicted = np.pad(predicted, ((0, 0), (z1_, Z - z2_), (0, 0), (0, 0)), "constant")
    predicted = predicted[:, z1_ - z1:z1_ - z1 + Z, x_pre:x_pre + X, y_pre:y_pre + Y]

    nim = nib.load(str(image_path))
    output_dir.mkdir(parents=True, exist_ok=True)

    def to_label_map(masks):
        """Collapse per-channel boolean masks (C, Z, X, Y) into a (X, Y, Z) map with channel i -> i + 1."""
        label_map = np.zeros(masks.shape[1:])
        for channel, mask in enumerate(masks):
            label_map[mask] = channel + 1
        return np.transpose(label_map, [1, 2, 0])

    def save(volume, filename):
        """Save ``volume`` to ``output_dir/filename`` with the source image's affine and pixdim."""
        out = nib.Nifti1Image(volume, nim.affine)
        out.header['pixdim'] = nim.header['pixdim']
        nib.save(out, '{0}/{1}'.format(str(output_dir), filename))

    save(to_label_map(predicted > 0.5), 'seg.nii.gz')
    save(np.transpose(original_image, [1, 2, 0]), 'image.nii.gz')
    save(to_label_map(label.cpu().detach().numpy() == 1.0), 'label.nii.gz')
def main():
    """Segment one volume with a trained ``FCN2DSegmentationModel`` and save the results as NIfTI.

    Setup:

    * Configuration comes from ``args.network_conf_path`` if given, otherwise ``TRAIN_CONF_PATH``.
    * The image and label are located under ``args.input_dir`` using the first configured dataset's
      ``image_label_format`` for ``args.phase``.
    * Both are read with ``Torch2DSegmentationDataset``, honouring ``args.crop_image``.

    The network is evaluated in eval mode without autograd. ``seg.nii.gz`` (per-channel
    sigmoid > 0.5, channel ``i`` -> ``i + 1``), ``image.nii.gz`` and ``label.nii.gz`` are then written
    to ``args.output_dir`` with the source image's affine and pixdim.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    if args.network_conf_path is not None:
        conf_path = Path(args.network_conf_path)
    else:
        conf_path = TRAIN_CONF_PATH
    train_conf = ConfigFactory.parse_file(str(conf_path))
    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.dataset_names[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode,
    )
    image_label_format = dataset_config.image_label_format
    input_path = Path(args.input_dir).joinpath(image_label_format.image.format(phase=args.phase))
    label_path = input_path.parent.joinpath(image_label_format.label.format(phase=args.phase))
    output_dir = Path(args.output_dir)

    def network_conf(key):
        """Look up ``key`` in the ``network`` group of the training configuration."""
        return get_conf(train_conf, group="network", key=key)

    feature_size = network_conf("feature_size")
    in_channels = network_conf("in_channels")

    checkpoint = torch.load(str(model_path), map_location=torch.device(args.device))
    network = FCN2DSegmentationModel(
        in_channels=in_channels,
        n_classes=network_conf("n_classes"),
        n_filters=network_conf("n_filters"),
        up_conv_filter=network_conf("up_conv_filter"),
        final_conv_filter=network_conf("final_conv_filter"),
        feature_size=feature_size,
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # Inference must not run in train mode, where layers such as batch norm/dropout (if present)
    # behave differently.
    network.eval()

    image = Torch2DSegmentationDataset.read_image(input_path, feature_size, in_channels, crop=args.crop_image)
    image = torch.from_numpy(np.expand_dims(image, 0)).float()
    image = prepare_tensors(image, gpu=True, device=args.device)
    with torch.no_grad():
        predicted = (torch.sigmoid(network(image)) > 0.5).float()
    predicted = np.squeeze(predicted.cpu().detach().numpy(), axis=0)

    nim = nib.load(str(input_path))
    output_dir.mkdir(parents=True, exist_ok=True)

    def to_label_map(masks):
        """Collapse per-channel boolean masks (C, Z, X, Y) into a (X, Y, Z) map with channel i -> i + 1."""
        label_map = np.zeros(masks.shape[1:])
        for channel, mask in enumerate(masks):
            label_map[mask] = channel + 1
        return np.transpose(label_map, [1, 2, 0])

    def save(volume, filename):
        """Save ``volume`` to ``output_dir/filename`` with the source image's affine and pixdim."""
        out = nib.Nifti1Image(volume, nim.affine)
        out.header['pixdim'] = nim.header['pixdim']
        nib.save(out, '{0}/{1}'.format(str(output_dir), filename))

    save(to_label_map(predicted > 0.5), 'seg.nii.gz')

    final_image = np.squeeze(image.cpu().detach().numpy(), 0)
    save(np.transpose(final_image, [1, 2, 0]), 'image.nii.gz')

    label = Torch2DSegmentationDataset.read_label(
        label_path,
        feature_size=feature_size,
        n_slices=in_channels,
        crop=args.crop_image,
    )
    save(to_label_map(label == 1.0), 'label.nii.gz')
def train(self):
    """Train the network while refining the atlas after every epoch.

    The initial atlas is built with ``Atlas.from_data_loader`` and saved to ``<experiment_dir>/atlas/init``.
    Each epoch then:

    1. Pushes the current atlas label to the network (``update_batch_atlas``) and runs one optimisation
       pass over ``self.training_sets``.
    2. Updates the atlas from the training set via ``self.update_atlas`` and saves it to
       ``<experiment_dir>/atlas/epoch_<epoch>``.
    3. Evaluates on ``self.validation_sets`` and on each of ``self.extra_validation_sets``, logging
       to the logger and TensorBoard.
    4. Saves ``<experiment_dir>/checkpoints/CP_<epoch>.pth``.
    5. Runs ``self.inference`` if ``self.inference_func`` is set.

    Validation and inference use the atlas label tensor from the start of the epoch, i.e. the same
    one the network was trained against. The atlas update only takes effect in the next epoch.
    """
    self.network.train()
    train_data_loader = MultiDataLoader(
        *self.training_sets,
        batch_size=self.config.batch_size,
        sampler_cls="random",
        num_workers=self.config.num_workers,
        pin_memory=self.config.pin_memory,
    )
    atlas = Atlas.from_data_loader(train_data_loader)
    atlas.save(output_dir=self.config.experiment_dir.joinpath("atlas").joinpath("init"))

    for epoch in range(self.config.num_epochs):
        self.network.train()
        self.logger.info("{}: starting epoch {}/{}".format(datetime.now(), epoch, self.config.num_epochs))
        self.loss.reset()
        pbar = tqdm(enumerate(train_data_loader))
        atlas_label = prepare_tensors(torch.from_numpy(atlas.label()), self.config.gpu, self.config.device)
        self.network.update_batch_atlas(atlas_label)
        for idx, (inputs, outputs) in pbar:
            inputs = prepare_tensors(inputs, self.config.gpu, self.config.device)
            outputs = prepare_tensors(outputs, self.config.gpu, self.config.device)
            predicted = self.network(inputs, atlas_label)  # pass updated atlas in
            loss = self.loss.cumulate(predicted, outputs)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            pbar.set_description("{:.2f} --- {}".format(
                (idx + 1) / len(train_data_loader), self.loss.description()))

        # update atlas
        self.network.eval()
        self.update_atlas(train_data_loader, atlas_label, atlas)
        atlas.save(output_dir=self.config.experiment_dir.joinpath("atlas").joinpath("epoch_{}".format(epoch)))
        self.logger.info("Epoch finished !")

        val_metrics = self.eval(self.loss.new(), *self.other_validation_metrics,
                                datasets=self.validation_sets, atlas=atlas_label)
        self.logger.info("Validation loss: {}".format(val_metrics[0].description()))
        if val_metrics[1:]:
            self.logger.info("Other metrics on validation set.")
            for metric in val_metrics[1:]:
                self.logger.info("{}".format(metric.description()))
        self.tensor_board.add_scalar(
            "loss/training/{}".format(self.loss.document()), self.loss.log(), epoch)
        self.tensor_board.add_scalar(
            "loss/validation/loss_{}".format(val_metrics[0].document()), val_metrics[0].log(), epoch)
        for metric in val_metrics[1:]:
            self.tensor_board.add_scalar(
                "other_metrics/validation/{}".format(metric.document()), metric.avg(), epoch)

        # eval extra validation sets
        if self.extra_validation_sets:
            for val in self.extra_validation_sets:
                val_metrics = self.eval(self.loss.new(), *self.other_validation_metrics,
                                        datasets=[val], atlas=atlas_label)
                self.logger.info("Extra Validation loss on dataset {}: {}".format(
                    val.name, val_metrics[0].description()))
                if val_metrics[1:]:
                    self.logger.info("Other metrics on extra validation set.")
                    for metric in val_metrics[1:]:
                        self.logger.info("{}".format(metric.description()))
                self.tensor_board.add_scalar(
                    "loss/extra_validation_{}/loss_{}".format(val.name, val_metrics[0].document()),
                    val_metrics[0].log(), epoch)
                for metric in val_metrics[1:]:
                    self.tensor_board.add_scalar(
                        "other_metrics/extra_validation_{}/{}".format(val.name, metric.document()),
                        metric.avg(), epoch)

        checkpoint_dir = self.config.experiment_dir.joinpath("checkpoints")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        output_path = checkpoint_dir.joinpath("CP_{}.pth".format(epoch))
        torch.save(self.network.state_dict(), str(output_path))
        self.logger.info("Checkpoint {} saved at {}!".format(epoch, str(output_path)))
        if self.inference_func is not None:
            self.inference(epoch, atlas_label)
def inference(image: np.ndarray, label: np.ndarray, network: torch.nn.Module, output_dir: Path, atlas, output_size,
              gpu, device):
    """Run the registration/segmentation network on one volume and dump its outputs as NIfTI files.

    ``image`` (presumably (Z, X, Y)) and ``label`` (presumably (n_classes, Z, X, Y); TODO confirm) are
    processed as follows:

    1. Cropped/padded to ``output_size`` with ``central_crop_with_padding``.
    2. Passed through ``network((image, label), atlas)`` under ``torch.no_grad()``.
    3. Network outputs are cropped/padded back to ``image.shape``.

    Label-like volumes are collapsed so channel ``i`` becomes value ``i + 1``. Everything is written in
    (X, Y, Z) order into ``output_dir`` (created if missing), without an affine. The files are:

    * ``warped_template_seg``, ``affine_warped_template_seg``, ``pred_maps_seg``: outputs thresholded at 0.5;
    * ``image``, ``cropped_image``, ``padded_cropped_image``;
    * ``padded_warped_image``, ``padded_affine_warped_image``;
    * ``padded_affine_warped_label``, ``padded_warped_label``, ``cropped_label``, ``label``.

    All files use the ``.nii.gz`` suffix.
    """
    np_cropped_image, np_cropped_label = central_crop_with_padding(image, label, output_size=output_size)

    cropped_image = torch.unsqueeze(torch.from_numpy(np.expand_dims(np_cropped_image, 0)).float(), 0)
    cropped_label = torch.unsqueeze(torch.from_numpy(np_cropped_label).float(), 0)
    cropped_image = prepare_tensors(cropped_image, gpu, device)
    cropped_label = prepare_tensors(cropped_label, gpu, device)

    with torch.no_grad():
        (affine_warped_template, warped_template, pred_maps, _flow, warped_image, warped_label,
         affine_warped_label, _batch_atlas, affine_warped_image) = network((cropped_image, cropped_label), atlas)

    output_dir.mkdir(parents=True, exist_ok=True)

    def save(volume, filename):
        """Save ``volume`` as ``output_dir/filename`` with no affine."""
        nib.save(nib.Nifti1Image(volume, None), '{}/{}'.format(str(output_dir), filename))

    def to_label_map(masks):
        """Collapse per-channel boolean masks (C, Z, X, Y) into a (X, Y, Z) map with channel i -> i + 1."""
        label_map = np.zeros(masks.shape[1:])
        for channel, mask in enumerate(masks):
            label_map[mask] = channel + 1
        return np.transpose(label_map, [1, 2, 0])

    def to_numpy(tensor):
        """Move a network output to NumPy and drop its batch axis."""
        return np.squeeze(tensor.cpu().detach().numpy(), axis=0)

    # Template / segmentation outputs, thresholded and restored to the original image grid.
    for prefix, predicted in zip(
        ["warped_template", "affine_warped_template", "pred_maps"],
        [warped_template, affine_warped_template, pred_maps],
    ):
        predicted = to_numpy((predicted > 0.5).float())
        _, predicted = central_crop_with_padding(None, predicted, image.shape)
        save(to_label_map(predicted > 0.5), '{}_seg.nii.gz'.format(prefix))

    # Input image, its cropped version, and the crop padded back to the original size.
    save(np.transpose(image, [1, 2, 0]), 'image.nii.gz')
    save(np.transpose(np_cropped_image, [1, 2, 0]), 'cropped_image.nii.gz')
    padded_cropped_image, _ = central_crop_with_padding(np_cropped_image, np_cropped_label, image.shape)
    save(np.transpose(padded_cropped_image, [1, 2, 0]), 'padded_cropped_image.nii.gz')

    # Images warped into atlas space; the extra squeeze drops the single channel axis.
    for name, volume in (("padded_warped_image", warped_image),
                         ("padded_affine_warped_image", affine_warped_image)):
        volume = np.squeeze(to_numpy(volume), axis=0)
        volume, _ = central_crop_with_padding(volume, None, image.shape)
        save(np.transpose(volume, [1, 2, 0]), '{}.nii.gz'.format(name))

    # Labels warped into atlas space, binarised at 0.5.
    for name, volume in (("padded_affine_warped_label", affine_warped_label),
                         ("padded_warped_label", warped_label)):
        masks = to_numpy(volume) > 0.5
        _, masks = central_crop_with_padding(None, masks, image.shape)
        save(to_label_map(masks == 1.0), '{}.nii.gz'.format(name))

    save(to_label_map(np_cropped_label == 1.0), 'cropped_label.nii.gz')
    save(to_label_map(label == 1.0), 'label.nii.gz')