def construct_training_validation_dataset(
    data_config: DataConfig,
    feature_size: int,
    n_slices: int,
    output_dir: Path,
    is_3d: bool = False,
    augmentation_config: AugmentationConfig = None,
    seed: int = None,
    template_path: Path = None,
) -> Tuple[List["Torch2DSegmentationDataset"], List["Torch2DSegmentationDataset"], List["Torch2DSegmentationDataset"]]:
    """Build the training, validation and extra-validation dataset lists.

    Each name in ``data_config.training_datasets`` is resolved to a
    ``DatasetConfig`` and split by ``train_val_dataset_from_config`` into a
    training part and a validation part. Only the training splits receive the
    augmentation settings. Each name in ``data_config.extra_validation_datasets``
    contributes only its validation part (``only_val=True``).

    Every dataset config (training and extra-validation) is resolved before
    any dataset is built.

    The third value returned by each training split call is rebound to
    ``template_path``. It is therefore passed to every later call, including
    the extra-validation calls.

    Returns:
        ``(training_sets, validation_sets, extra_val_sets)``. Each list
        follows the order of the corresponding dataset names in
        ``data_config``.
    """
    def resolve(names):
        # Same data mode / mount prefix for every dataset in this run.
        return [
            DatasetConfig.from_conf(
                name, mode=data_config.data_mode, mount_prefix=data_config.mount_prefix
            )
            for name in names
        ]

    train_configs = resolve(data_config.training_datasets)
    extra_configs = resolve(data_config.extra_validation_datasets)

    training_sets, validation_sets = [], []
    for cfg in train_configs:
        train_part, val_part, template_path = train_val_dataset_from_config(
            dataset_config=cfg,
            augmentation_config=augmentation_config,
            augmentation_prob=data_config.augmentation_prob,
            validation_split=data_config.validation_split,
            feature_size=feature_size,
            n_slices=n_slices,
            is_3d=is_3d,
            renew_dataframe=data_config.renew_dataframe,
            seed=seed,
            output_dir=output_dir,
            template_path=template_path,
        )
        training_sets.append(train_part)
        validation_sets.append(val_part)

    def validation_only(cfg):
        # Extra datasets are used purely for validation, without augmentation.
        _, val_part, _ = train_val_dataset_from_config(
            dataset_config=cfg,
            validation_split=data_config.validation_split,
            feature_size=feature_size,
            n_slices=n_slices,
            is_3d=is_3d,
            only_val=True,
            renew_dataframe=data_config.renew_dataframe,
            seed=seed,
            output_dir=output_dir,
            template_path=template_path,
        )
        return val_part

    extra_val_sets = [validation_only(cfg) for cfg in extra_configs]
    return training_sets, validation_sets, extra_val_sets
def main():
    """Run a trained UNet on a single image and hand the result to ``inference``.

    Steps:

    - Read the network configuration from ``args.network_conf_path`` when
      given, otherwise from ``TRAIN_CONF_PATH``.
    - Locate the image (and its label) through the first entry of
      ``data_config.training_datasets``.
    - Load the checkpoint into a ``UNet`` on ``args.device``.
    - Build a one-item ``Torch2DSegmentationDataset``.
    - Call ``inference`` with the prepared image, the label, the network and
      ``output_dir``.
    """
    args = parse_args()
    model_path = Path(args.model_path)

    conf_path = TRAIN_CONF_PATH if args.network_conf_path is None else Path(args.network_conf_path)
    train_conf = ConfigFactory.parse_file(str(conf_path))

    def network_setting(key):
        # Every network hyper-parameter lives under the "network" group.
        return get_conf(train_conf, group="network", key=key)

    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.training_datasets[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode,
    )
    input_path = Path(args.input_dir) / dataset_config.image_label_format.image.format(phase=args.phase)
    output_dir = Path(args.output_dir)

    checkpoint = torch.load(str(model_path), map_location=torch.device(args.device))
    # The value is discarded; the lookup is kept to preserve the original call sequence.
    network_setting("experiment_dir")
    network = UNet(
        in_channels=network_setting("in_channels"),
        n_classes=network_setting("n_classes"),
        n_filters=network_setting("n_filters"),
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)

    # The label file sits next to the image, named by the dataset's label format.
    label_path = input_path.parent / dataset_config.image_label_format.label.format(phase=args.phase)
    dataset = Torch2DSegmentationDataset(
        name=dataset_config.name,
        image_paths=[input_path],
        label_paths=[label_path],
        feature_size=network_setting("feature_size"),
        n_slices=network_setting("n_slices"),
        is_3d=True,
    )

    # Add a leading batch axis before moving the tensor to the device.
    image = prepare_tensors(torch.unsqueeze(dataset.get_image_tensor_from_index(0), 0), True, args.device)
    label = dataset.get_label_tensor_from_index(0)
    inference(
        image=image,
        label=label,
        image_path=input_path,
        network=network,
        output_dir=output_dir,
    )
def _save_nifti(data, reference, path: Path) -> None:
    """Save ``data`` as a NIfTI file at ``path``, copying ``reference``'s geometry.

    The affine and the ``pixdim`` header field are taken from ``reference`` so
    the written volume lines up with the input image.
    """
    nim = nib.Nifti1Image(data, reference.affine)
    nim.header['pixdim'] = reference.header['pixdim']
    nib.save(nim, str(path))


def _masks_to_label_map(masks, shape) -> np.ndarray:
    """Collapse per-class boolean masks into a single integer-valued label volume.

    ``masks[i]`` marks the voxels of class ``i``; they receive label ``i + 1``.
    Where masks overlap, the later class wins. Voxels covered by no mask stay 0.
    The result is transposed with axes ``[1, 2, 0]``, i.e. the first axis of
    ``shape`` is moved to the end.
    """
    label_map = np.zeros(shape)
    for class_index, mask in enumerate(masks):
        label_map[mask] = class_index + 1
    return np.transpose(label_map, [1, 2, 0])


def main():
    """Segment one image with a trained FCN2DSegmentationModel and write NIfTI files.

    Steps:

    - Read the network configuration from ``args.network_conf_path`` when
      given, otherwise from ``TRAIN_CONF_PATH``.
    - Locate the image and label via the first entry of
      ``data_config.dataset_names``.
    - Run the network in eval mode without gradient tracking.
    - Threshold the sigmoid output at 0.5.

    Writes ``seg.nii.gz`` (predicted labels), ``image.nii.gz`` (the network
    input) and ``label.nii.gz`` (ground truth) into ``args.output_dir``. Each
    file uses the affine/pixdim of the original input image.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    if args.network_conf_path is not None:
        conf_path = Path(args.network_conf_path)
    else:
        conf_path = TRAIN_CONF_PATH
    train_conf = ConfigFactory.parse_file(str(conf_path))

    def network_conf(key):
        return get_conf(train_conf, group="network", key=key)

    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.dataset_names[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode,
    )
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    label_path = input_path.parent.joinpath(
        dataset_config.image_label_format.label.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    feature_size = network_conf("feature_size")
    in_channels = network_conf("in_channels")

    checkpoint = torch.load(str(model_path), map_location=torch.device(args.device))
    network = FCN2DSegmentationModel(
        in_channels=in_channels,
        n_classes=network_conf("n_classes"),
        n_filters=network_conf("n_filters"),
        up_conv_filter=network_conf("up_conv_filter"),
        final_conv_filter=network_conf("final_conv_filter"),
        feature_size=feature_size,
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # Switch to inference behaviour (affects any Dropout/BatchNorm layers);
    # without this the forward pass runs in training mode.
    network.eval()

    image = Torch2DSegmentationDataset.read_image(
        input_path,
        feature_size,
        in_channels,
        crop=args.crop_image,
    )
    image = np.expand_dims(image, 0)  # leading batch axis of size 1
    image = torch.from_numpy(image).float()
    image = prepare_tensors(image, gpu=True, device=args.device)

    # Forward-only pass: no autograd graph is needed.
    with torch.no_grad():
        predicted = torch.sigmoid(network(image))
    print("sigmoid", torch.mean(predicted).item(), torch.max(predicted).item())
    predicted = (predicted > 0.5).float()
    print("0.5", torch.mean(predicted).item(), torch.max(predicted).item())
    predicted = np.squeeze(predicted.cpu().detach().numpy(), axis=0)

    nim = nib.load(str(input_path))
    # image is (1, D0, D1, D2): the output volumes drop the batch axis.
    volume_shape = tuple(image.shape[1:4])
    output_dir.mkdir(parents=True, exist_ok=True)

    final_predicted = _masks_to_label_map(predicted > 0.5, volume_shape)
    _save_nifti(final_predicted, nim, output_dir / "seg.nii.gz")

    final_image = np.squeeze(image.cpu().detach().numpy(), 0)
    final_image = np.transpose(final_image, [1, 2, 0])
    _save_nifti(final_image, nim, output_dir / "image.nii.gz")

    label = Torch2DSegmentationDataset.read_label(
        label_path,
        feature_size=feature_size,
        n_slices=in_channels,
        crop=args.crop_image,
    )
    final_label = _masks_to_label_map(label == 1.0, volume_shape)
    _save_nifti(final_label, nim, output_dir / "label.nii.gz")
import cv2
import numpy as np
from pathlib import Path
from CMRSegment.common.config import DatasetConfig
from CMRSegment.common.data_table import DataTable
from CMRSegment.common.nn.torch.data import Torch2DSegmentationDataset
from tqdm import tqdm

# Dataset whose dataframe lists the image/label file pairs handled by this script.
# NOTE(review): mount_prefix is a machine-specific absolute path; the
# commented-out line below is the alternative used on a Windows checkout.
config = DatasetConfig.from_conf(
    name="RBH_3D_atlases", mode="3D", mount_prefix=Path("/mnt/storage/home/suruli/"))
# config = DatasetConfig.from_conf(name="RBH_3D_atlases", mode="3D", mount_prefix=Path("D:/surui/rbh/"))


def read_dataframe(dataframe_path: Path):
    """Read the image and label file paths listed in a dataset CSV.

    Args:
        dataframe_path: CSV file loaded with ``DataTable.from_csv``; must
            provide ``image_path`` and ``label_path`` columns.

    Returns:
        ``(image_paths, label_paths)`` as two lists of ``Path`` objects, in
        the row order returned by ``select_column``.
    """
    data_table = DataTable.from_csv(dataframe_path)
    image_paths = data_table.select_column("image_path")
    label_paths = data_table.select_column("label_path")
    image_paths = [Path(path) for path in image_paths]
    label_paths = [Path(path) for path in label_paths]
    return image_paths, label_paths


image_paths, label_paths = read_dataframe(config.dataframe_path)
# NOTE(review): presumably filled by code after this chunk (not visible here).
images = []
labels = []
# Outputs go to an "output" directory next to this script, with separate
# "image" and "label" subdirectories created up front.
output_dir = Path(__file__).parent.joinpath("output")
output_dir.joinpath("image").mkdir(parents=True, exist_ok=True)
output_dir.joinpath("label").mkdir(parents=True, exist_ok=True)
# (image_path, label_path) pairs; the name suggests it is later wrapped in a
# tqdm progress bar — confirm against the code that follows.
pbar = list(zip(image_paths, label_paths))