def test_xpdnet():
    """Smoke-test a multicoil XPDNet whose image correction network is an MWCNN.

    A small MWCNN configuration is used, and the model is called once on
    all-zero complex inputs (kspace, mask, sensitivity maps). The test passes if
    this forward pass does not raise.
    """
    n_primal = 2
    n_scales = 3
    # deliberately tiny MWCNN so that the forward pass stays cheap
    mwcnn_kwargs = {
        'n_scales': n_scales,
        'kernel_size': 3,
        'bn': False,
        'n_filters_per_scale': [4, 8, 8],
        'n_convs_per_scale': [2, 2, 2],
        'n_first_convs': 2,
        'first_conv_n_filters': 4,
        'res': False,
        'n_outputs': 2 * n_primal,
    }
    model = XPDNet(
        model_fun=MWCNN,
        model_kwargs=mwcnn_kwargs,
        n_primal=n_primal,
        n_iter=2,
        multicoil=True,
        n_scales=n_scales,
    )
    # (batch, coils, height, width)
    coil_shape = [1, 5, 640, 320]
    kspace = tf.zeros(coil_shape + [1], dtype=tf.complex64)
    mask = tf.zeros(coil_shape, dtype=tf.complex64)
    smaps = tf.zeros(coil_shape, dtype=tf.complex64)
    model([kspace, mask, smaps])
def test_works_in_xpdnet_train(
        model_fun,
        model_kwargs,
        n_scales,
        res,
        n_iter=10,
        multicoil=False,
        use_mixed_precision=False,
        data_consistency_learning=False,
        n_primal=5,
):
    """Check whether one training epoch of an XPDNet runs on the current device.

    An XPDNet is built and compiled with an MAE loss. It is then fitted for one
    epoch on a single all-ones example. For multicoil data the example has 15
    coils; k-space is 640x400 and the target is 320x320.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments used to initialize the image
            correction network.
        n_scales (int): the number of scales of the image correction network.
        res (bool): whether the image correction networks are residual.
        n_iter (int): the number of unrolled iterations. Defaults to 10.
        multicoil (bool): whether to simulate multicoil data. This also turns
            on sensitivity-map refinement. Defaults to False.
        use_mixed_precision (bool): whether to set the global Keras policy to
            'mixed_float16' instead of 'float32'. The policy is set globally
            and is not restored afterwards. Defaults to False.
        data_consistency_learning (bool): whether the dual (k-space) part is
            learned, i.e. ``primal_only`` is False. Defaults to False.
        n_primal (int): the size of the buffer in the image space. Before this
            parameter existed, the function referenced an undefined name here.
            Defaults to 5.

    Returns:
        bool: False if training raised ``ResourceExhaustedError`` or
        ``InternalError`` (typically out of device memory), True otherwise.
    """
    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    mixed_precision.set_global_policy(policy_type)
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': multicoil,
        'res': res,
        'primal_only': not data_consistency_learning,
    }
    model = XPDNet(model_fun, model_kwargs, **run_params)
    default_model_compile(model, lr=1e-3, loss='mae')
    n_coils = 15
    k_shape = (640, 400)
    if multicoil:
        k_shape = (n_coils, *k_shape)
    inputs = [
        tf.ones([1, *k_shape, 1], dtype=tf.complex64),  # kspace
        tf.ones([1, *k_shape], dtype=tf.complex64),  # mask
    ]
    if multicoil:
        inputs.append(tf.ones([1, *k_shape], dtype=tf.complex64))  # smaps
    try:
        model.fit(
            x=inputs,
            y=tf.ones([1, 320, 320, 1]),
            epochs=1,
        )
    except (tf.errors.ResourceExhaustedError, tf.errors.InternalError):
        return False
    return True
def evaluate_xpdnet(
        model_fun,
        model_kwargs,
        run_id,
        multicoil=True,
        brain=False,
        n_epochs=200,
        contrast=None,
        af=4,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        refine_smaps=False,
        refine_big=False,
        n_samples=None,
        cuda_visible_devices='0123',
        equidistant_fake=False,
        mask_type=None,
        primal_only=True,
        n_dual=1,
        n_dual_filters=16,
        multiscale_kspace_learning=False,
):
    """Evaluate a trained XPDNet checkpoint on the fastMRI validation set.

    The network is rebuilt from ``model_fun``/``model_kwargs`` and the run
    parameters, inside a ``tf.distribute.MirroredStrategy`` scope. It is
    initialized with a call on zero inputs and then loaded from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5``. Each
    validation element is predicted with a batch size of 4, and the first
    channel of the targets and predictions is accumulated in ``Metrics``.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments of the image correction network.
        run_id (str): the run id of the checkpoint to evaluate.
        multicoil (bool): whether the data is multicoil. Defaults to True.
        brain (bool): whether to use brain data instead of knee data.
            Defaults to False.
        n_epochs (int): the epoch of the checkpoint to load. Defaults to 200.
        contrast (str or None): restrict the evaluation to this contrast.
        af (int): the acceleration factor. Defaults to 4.
        n_samples (int or None): if given, only this many validation elements
            are evaluated; otherwise a number of volumes derived from the
            dataset constants is used.
        cuda_visible_devices (str): one character per visible GPU index.
        mask_type (str or None): multicoil mask type. If None, it is
            'equidistant_fake'/'equidistant' for brain data (depending on
            ``equidistant_fake``) and 'random' for knee data.
        The remaining arguments are forwarded to ``XPDNet``.

    Returns:
        tuple: ``(METRIC_FUNCS, (means, stddevs))``, where ``means`` and
        ``stddevs`` are lists of the per-metric statistics.
    """
    if not multicoil:
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'
    elif brain:
        val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
    else:
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    # e.g. '0123' -> '0,1,2,3'
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'refine_big': refine_big,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
        'multiscale_kspace_learning': multiscale_kspace_learning,
    }

    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            if not brain:
                mask_type = 'random'
            elif equidistant_fake:
                mask_type = 'equidistant_fake'
            else:
                mask_type = 'equidistant'
        dataset_kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        dataset_kwargs = {}
    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=False,
        scale_factor=1e6,
        **dataset_kwargs,
    )

    if brain:
        if contrast is None:
            n_volumes = brain_n_volumes_validation
        else:
            n_volumes = brain_volumes_per_contrast['validation'][contrast]
    else:
        n_volumes = n_volumes_val
        if contrast is not None:
            n_volumes = n_volumes // 2 + 1
    n_to_evaluate = n_volumes if n_samples is None else n_samples
    val_set = val_set.take(n_to_evaluate)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        kspace_size = [1, 15, 640, 372] if multicoil else [1, 640, 372]
        model = XPDNet(model_fun, model_kwargs, **run_params)
        # a forward pass on zeros creates the variables before loading weights
        dummy_inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            dummy_inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            dummy_inputs.append(tf.constant([[320, 320]]))
        model(dummy_inputs)
        model.load_weights(
            f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')

    metrics = Metrics(METRIC_FUNCS)
    for x, y_true in tqdm(val_set.as_numpy_iterator(), total=n_to_evaluate):
        y_pred = model.predict(x, batch_size=4)
        metrics.push(y_true[..., 0], y_pred[..., 0])
    means = list(metrics.means().values())
    stddevs = list(metrics.stddevs().values())
    return METRIC_FUNCS, (means, stddevs)
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        brain=False,
        af=4,
        contrast=None,
        cuda_visible_devices='0123',
        n_samples=None,
        n_epochs=200,
        checkpoint_epoch=0,
        save_state=False,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        refine_big=False,
        loss='mae',
        original_run_id=None,
        fixed_masks=False,
        n_epochs_original=250,
        equidistant_fake=False,
        multi_gpu=False,
        mask_type=None,
):
    r"""Train an XPDNet network on the fastMRI dataset.

    The training is done with a learning rate of 1e-4, using the RAdam optimizer.
    The validation is performed every 5 epochs on 5 volumes.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to True.
        brain (bool): whether to consider brain data instead of knee. Defaults
            to False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        cuda_visible_devices (str): the GPUs to consider visible, one character
            per device index. Defaults to '0123'.
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Defaults to None.
        n_epochs (int): the number of epochs (i.e. one pass through all the
            volumes/samples) for this training. Defaults to 200.
        checkpoint_epoch (int): the number of epochs used to train the model
            during the first step of the full training. This is typically used
            when on a cluster the training duration exceeds the maximum job
            duration. When non-zero, the model saved under ``original_run_id``
            at that epoch is reloaded and training resumes from it, so
            ``original_run_id`` must then be given. Defaults to 0, which means
            that the training is done without checkpoints.
        save_state (bool): whether you should save the entire model state for
            this training, for example to retrain where left off. Defaults to
            False.
        n_iter (int): the number of iterations for the XPDNet. Defaults to 10.
        res (bool): whether the XPDNet image correction networks should be
            residual. Defaults to True.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network. Defaults to False.
        refine_big (bool): forwarded to XPDNet as ``refine_big``; it also
            appends 'b' to the run id. Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        original_run_id (str or None): run id of the same network trained before
            fine-tuning. If this is present (and ``checkpoint_epoch`` is 0), the
            training is considered fine-tuning of a network trained for
            ``n_epochs_original`` epochs. A learning rate of 1e-7 is then
            applied, and the epoch size is divided in half. If None, the
            training is done normally, without loading weights. Defaults to
            None.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Defaults to False
        n_epochs_original (int): the number of epochs used to pre-train the
            model, only applicable if original_run_id is not None. Defaults
            to 250.
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the undersampling mask type for multicoil data.
            If None, it is 'equidistant_fake' or 'equidistant' for brain data
            (depending on ``equidistant_fake``) and 'random' for knee data.
            Ignored for singlecoil data. Defaults to None.

    Returns:
        - str: the run id of the trained network.
    """
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # paths
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    # e.g. '0123' -> '0,1,2,3'
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)

    # generators
    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    train_set = dataset(
        train_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **kwargs
    )
    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        **kwargs
    )

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
    }

    # run id: network type + non-default options + submodel + timestamp
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    if checkpoint_epoch == 0:
        run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    else:
        # resuming: keep writing under the run id of the interrupted training
        run_id = original_run_id
    final_epoch = checkpoint_epoch + n_epochs
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    if not save_state:
        chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    if checkpoint_epoch == 0:
        model = XPDNet(model_fun, model_kwargs, **run_params)
        if original_run_id is not None:
            # fine-tuning: tiny learning rate and half-size epochs
            lr = 1e-7
            n_steps = brain_volumes_per_contrast['train'].get(contrast, n_volumes)//2
        else:
            lr = 1e-4
            n_steps = n_volumes
        default_model_compile(model, lr=lr, loss=loss)
    else:
        model = load_model(
            f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}',
            custom_objects=CUSTOM_TF_OBJECTS,
        )
        n_steps = n_volumes

    # save only once, at the end of this training step
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=n_epochs*n_steps,
        save_weights_only=not save_state,
    )
    print(run_id)
    if original_run_id is not None and not checkpoint_epoch:
        if os.environ.get('FASTMRI_DEBUG'):
            n_epochs_original = 1
        if multicoil:
            kspace_size = [1, 15, 640, 372]
        else:
            kspace_size = [1, 640, 372]
        # a forward pass on zeros creates the variables before loading weights
        inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            inputs.append(tf.constant([[320, 320]]))
        model(inputs)
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')
    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        initial_epoch=checkpoint_epoch,
        epochs=final_epoch,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id
def generate_postproc_tf_records(
        model_fun,
        model_kwargs,
        run_id,
        brain=False,
        n_epochs=200,
        af=4,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        refine_smaps=False,
        refine_big=False,
        primal_only=True,
        n_dual=1,
        n_dual_filters=16,
        mode='train',
):
    """Write XPDNet reconstructions of multicoil fastMRI files to TFRecords.

    Each ``*.h5`` file of ``{FASTMRI_DATA_DIR}{brain_}multicoil_{mode}`` is
    processed as follows:
      - it is undersampled by the k-space transform (its second output is
        presumably the target image; confirm);
      - it is reconstructed with the checkpoint
        ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5``;
      - the result is written next to the source file as
        ``<stem>_{run_id}.tfrecords``.
    Files whose TFRecord already exists are skipped, so an interrupted run can
    be resumed.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments of the image correction network.
        run_id (str): the run id of the checkpoint to use. It also becomes
            part of the output file names.
        brain (bool): whether to use brain data (equidistant fake masks)
            instead of knee data (random masks). Defaults to False.
        n_epochs (int): the epoch of the checkpoint to load. Defaults to 200.
        af (int): the acceleration factor. Defaults to 4.
        mode (str): the dataset split directory suffix, e.g. 'train' or 'val'.
            Defaults to 'train'.
        The remaining arguments are forwarded to ``XPDNet``.
    """
    # cast first: af is also consumed by the k-space transform below
    af = int(af)
    main_path = Path(FASTMRI_DATA_DIR)
    if brain:
        path = main_path / f'brain_multicoil_{mode}'
    else:
        path = main_path / f'multicoil_{mode}'
    filenames = sorted(path.glob('*.h5'))
    kspace_transform = generic_from_kspace_to_masked_kspace_and_mask(
        AF=af,
        scale_factor=1e6,
        parallel=False,
        fixed_masks=False,
        output_shape_spec=brain,
        mask_type='equidistant_fake' if brain else 'random',
        batch_size=None,
        target_image_size=(640, 400),
    )

    class PreProcModel(tf.keras.models.Model):
        # wraps the k-space transform so it can be run with Model.predict
        def call(self, inputs):
            image, kspace = inputs
            return kspace_transform(image, kspace)

    selection = [
        {'inner_slices': None, 'rand': False},  # slice selection
        {'rand': False, 'keep_dim': False},  # coil selection
    ]
    extension = f'_{run_id}.tfrecords'

    # Model init
    run_params = {
        'n_primal': n_primal,
        'multicoil': True,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'refine_big': refine_big,
        'res': res,
        'output_shape_spec': brain,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
    }
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        preproc_model = PreProcModel()
        model = XPDNet(model_fun, model_kwargs, **run_params)
        # a forward pass on zeros creates the variables before loading weights
        fake_inputs = [
            tf.zeros([1, 15, 640, 372, 1], dtype=tf.complex64),
            tf.zeros([1, 15, 640, 372], dtype=tf.complex64),
            tf.zeros([1, 15, 640, 372], dtype=tf.complex64),
        ]
        if brain:
            fake_inputs.append(tf.constant([[320, 320]]))
        model(fake_inputs)
        model.load_weights(
            f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')

    for filename in tqdm(filenames):
        directory = filename.parent
        filename_tfrecord = directory / (filename.stem + extension)
        if filename_tfrecord.exists():
            continue
        image, kspace, _ = from_multicoil_train_file_to_image_and_kspace_and_contrast(
            filename,
            selection=selection,
        )
        model_inputs, model_outputs = preproc_model.predict([image, kspace])
        # named differently from the `res` argument, which it used to shadow
        reconstruction = model.predict(model_inputs, batch_size=4)
        with tf.io.TFRecordWriter(str(filename_tfrecord)) as writer:
            example = encode_postproc_example([reconstruction], [model_outputs])
            writer.write(example)
def xpdnet_inference(
        model_fun,
        model_kwargs,
        run_id,
        multicoil=True,
        exp_id='xpdnet',
        brain=False,
        challenge=False,
        n_epochs=200,
        contrast=None,
        af=4,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        refine_smaps=False,
        refine_big=False,
        n_samples=None,
        cuda_visible_devices='0123',
        primal_only=True,
        n_dual=1,
        n_dual_filters=16,
        distributed=False,
        manual_saving=False,
):
    """Reconstruct the fastMRI test (or challenge) set with a trained XPDNet.

    Each element of the test dataset is predicted with a batch size of 16. It
    is then written with ``write_result`` under ``exp_id``, using the file name
    returned by ``test_filenames``.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments of the image correction network.
        run_id (str): the run id of the checkpoint to use.
        multicoil (bool): whether the data is multicoil. Defaults to True.
        exp_id (str): the experiment id under which results are written.
            Defaults to 'xpdnet'.
        brain (bool): whether to use brain data. Defaults to False.
        challenge (bool): for brain data, whether to use the challenge set
            instead of the test set. Defaults to False.
        n_epochs (int): the epoch of the checkpoint to load. Defaults to 200.
        contrast (str or None): restrict inference to this contrast.
        af (int): the acceleration factor. Defaults to 4.
        n_samples (int or None): if given, only this many test files are used.
        cuda_visible_devices (str): one character per visible GPU index.
        distributed (bool): if True and ``manual_saving`` is False, a full saved
            model (path without '.hdf5') is loaded instead of rebuilding the
            network and loading weights. Defaults to False.
        manual_saving (bool): see ``distributed``. Defaults to False.
        The remaining arguments are forwarded to ``XPDNet``.
    """
    if brain:
        if challenge:
            test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_challenge/'
        else:
            test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_test/'
    else:
        if multicoil:
            test_path = f'{FASTMRI_DATA_DIR}multicoil_test_v2/'
        else:
            test_path = f'{FASTMRI_DATA_DIR}singlecoil_test/'

    # e.g. '0123' -> '0,1,2,3'
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'refine_big': refine_big,
        'res': res,
        'output_shape_spec': brain,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
    }

    if multicoil:
        ds_fun = multicoil_dataset
        extra_kwargs = dict(output_shape_spec=brain)
    else:
        # previously misspelled `singecoil_dataset`, a NameError on this path
        ds_fun = singlecoil_dataset
        extra_kwargs = {}
    test_set = ds_fun(
        test_path,
        AF=af,
        contrast=contrast,
        scale_factor=1e6,
        n_samples=n_samples,
        **extra_kwargs
    )
    test_set_filenames = test_filenames(
        test_path,
        AF=af,
        contrast=contrast,
        n_samples=n_samples,
    )
    if multicoil:
        fake_kspace_size = [15, 640, 372]
    else:
        fake_kspace_size = [640, 372]

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        if distributed and not manual_saving:
            model = load_model(
                f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}',
                custom_objects=CUSTOM_TF_OBJECTS,
            )
        else:
            model = XPDNet(model_fun, model_kwargs, **run_params)
            # a forward pass on zeros creates the variables before loading weights
            fake_inputs = [
                tf.zeros([1, *fake_kspace_size, 1], dtype=tf.complex64),
                tf.zeros([1, *fake_kspace_size], dtype=tf.complex64),
            ]
            if multicoil:
                fake_inputs.append(
                    tf.zeros([1, *fake_kspace_size], dtype=tf.complex64))
            if brain:
                fake_inputs.append(tf.constant([[320, 320]]))
            model(fake_inputs)
            model.load_weights(
                f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')

    # progress bar length only
    if n_samples is None:
        if not brain:
            if contrast:
                tqdm_total = n_volumes_test[af] // 2
            else:
                tqdm_total = n_volumes_test[af]
        else:
            if contrast:
                tqdm_total = brain_volumes_per_contrast['test'][af][contrast]
            else:
                tqdm_total = brain_n_volumes_test[af]
    else:
        tqdm_total = n_samples
    tqdm_desc = f'{exp_id}_{contrast}_{af}'

    for data_example, filename in tqdm(
            zip(test_set, test_set_filenames),
            total=tqdm_total,
            desc=tqdm_desc,
    ):
        # named differently from the `res` argument, which it used to shadow
        reconstruction = model.predict(data_example, batch_size=16)
        write_result(
            exp_id,
            reconstruction,
            filename.numpy().decode('utf-8'),
            scale_factor=1e6,
            brain=brain,
            challenge=challenge,
            coiltype='multicoil' if multicoil else 'singlecoil',
        )
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        af=4,
        contrast=None,
        cuda_visible_devices='0123',
        n_samples=None,
        n_epochs=200,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        loss='mae',
        original_run_id=None,
        fixed_masks=False,
):
    r"""Train an XPDNet network on the fastMRI dataset.

    The training is done with a learning rate of 1e-4, using the RAdam optimizer.
    The validation is performed every 5 epochs on 5 volumes.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to True.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        cuda_visible_devices (str): the GPUs to consider visible, one character
            per device index. Defaults to '0123'.
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Defaults to None.
        n_epochs (int): the number of epochs (i.e. one pass through all the
            volumes/samples) for this training. Defaults to 200.
        n_iter (int): the number of iterations for the XPDNet. Defaults to 10.
        res (bool): whether the XPDNet image correction networks should be
            residual. Defaults to True.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network. Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        original_run_id (str or None): run id of the same network trained before
            fine-tuning. If this is present, the training is considered
            fine-tuning of a network trained for 250 epochs (1 epoch when the
            FASTMRI_DEBUG environment variable is set). A learning rate of
            1e-7 is then applied, and the epoch size is divided in half. If
            None, the training is done normally, without loading weights.
            Defaults to None.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Defaults to False

    Returns:
        - str: the run id of the trained network.
    """
    # paths
    if multicoil:
        train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    # e.g. '0123' -> '0,1,2,3'
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)

    # generators
    if multicoil:
        dataset = multicoil_dataset
        kwargs = {'parallel': False}
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    train_set = dataset(
        train_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **kwargs
    )
    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        **kwargs
    )

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
    }

    # run id: network type + non-default options + submodel + timestamp
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
    if fixed_masks:
        additional_info += '_fixed_masks'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'

    # weights are saved once, at the last epoch
    chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs, save_weights_only=True)
    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    model = XPDNet(model_fun, model_kwargs, **run_params)
    if original_run_id is not None:
        # fine-tuning: tiny learning rate and half-size epochs
        lr = 1e-7
        n_steps = n_volumes_train//2
    else:
        lr = 1e-4
        n_steps = n_volumes_train
    default_model_compile(model, lr=lr, loss=loss)
    print(run_id)
    if original_run_id is not None:
        if os.environ.get('FASTMRI_DEBUG'):
            n_epochs_original = 1
        else:
            n_epochs_original = 250
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')
    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        brain=False,
        af=4,
        contrast=None,
        n_samples=None,
        batch_size=None,
        n_epochs=200,
        checkpoint_epoch=0,
        save_state=False,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        refine_big=False,
        loss='mae',
        lr=1e-4,
        original_run_id=None,
        fixed_masks=False,
        n_epochs_original=250,
        equidistant_fake=False,
        multi_gpu=False,
        mask_type=None,
        primal_only=True,
        n_dual=1,
        n_dual_filters=16,
        multiscale_kspace_learning=False,
        distributed=False,
        manual_saving=False,
):
    r"""Train an XPDNet network on the fastMRI dataset.

    The training is done with a learning rate of ``lr`` (1e-4 by default),
    using the RAdam optimizer.
    The validation is performed every 5 epochs on 5 volumes.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to True.
        brain (bool): whether to consider brain data instead of knee. Defaults
            to False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Only used in the
            non-distributed setting. Defaults to None.
        batch_size (int or None): the global batch size. The number of steps
            per epoch is divided by it. It is required when ``distributed`` is
            True, because it is split across replicas. Defaults to None.
        n_epochs (int): the number of epochs (i.e. one pass through all the
            volumes/samples) for this training. Defaults to 200.
        checkpoint_epoch (int): the number of epochs used to train the model
            during the first step of the full training. This is typically used
            when on a cluster the training duration exceeds the maximum job
            duration. Defaults to 0, which means that the training is done
            without checkpoints.
        save_state (bool): whether you should save the entire model state for
            this training, for example to retrain where left off. Defaults to
            False.
        n_iter (int): the number of iterations for the XPDNet. Defaults to 10.
        res (bool): whether the XPDNet image correction networks should be
            residual. Defaults to True.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network. Defaults to False.
        refine_big (bool): forwarded to XPDNet as ``refine_big``; it also
            appends 'b' to the run id. Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        lr (float): the learning rate. It is overridden by 1e-7 when
            fine-tuning from ``original_run_id``. Defaults to 1e-4.
        original_run_id (str or None): run id of the same network trained before
            fine-tuning. If this is present, the training is considered
            fine-tuning of a network trained for ``n_epochs_original`` epochs.
            A learning rate of 1e-7 is then applied, and the epoch size is
            divided in half. If None, the training is done normally, without
            loading weights. Defaults to None.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Only used in the non-distributed
            setting. Defaults to False
        n_epochs_original (int): the number of epochs used to pre-train the
            model, only applicable if original_run_id is not None. Defaults
            to 250.
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the undersampling mask type for multicoil data.
            If None, it is 'equidistant_fake' or 'equidistant' for brain data
            (depending on ``equidistant_fake``) and 'random' for knee data.
            Defaults to None.
        primal_only (bool): forwarded to XPDNet. Defaults to True.
        n_dual (int): forwarded to XPDNet. Defaults to 1.
        n_dual_filters (int): forwarded to XPDNet. Defaults to 16.
        multiscale_kspace_learning (bool): forwarded to XPDNet. Defaults to
            False.
        distributed (bool): whether to train with a multi-worker mirrored
            strategy. The cluster is described by SLURM, and NCCL is used for
            communications. Only worker 0 writes checkpoints to
            CHECKPOINTS_DIR. Defaults to False.
        manual_saving (bool): whether the optimizer state is pickled alongside
            '.hdf5' weights, instead of relying on full-model saving. When
            resuming with it, the optimizer state is restored from
            ``{original_run_id}-optimizer.pkl``. Defaults to False.

    Returns:
        - str: the run id of the trained network.

    Raises:
        ValueError: if ``distributed`` is True and ``batch_size`` is None.
    """
    if distributed and batch_size is None:
        # the per-replica batch size is batch_size // num_replicas_in_sync
        raise ValueError('batch_size must be given when distributed is True.')
    if distributed:
        com_options = tf.distribute.experimental.CommunicationOptions(
            implementation=tf.distribute.experimental.CommunicationImplementation.NCCL,
        )
        slurm_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(port_base=15000)
        mirrored_strategy = tf.distribute.MultiWorkerMirroredStrategy(
            cluster_resolver=slurm_resolver,
            communication_options=com_options,
        )
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # paths
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    af = int(af)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)

    # generators
    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    if distributed:
        def _dataset_fn(input_context, mode='train'):
            ds = dataset(
                train_path if mode == 'train' else val_path,
                input_context=input_context,
                AF=af,
                contrast=contrast,
                inner_slices=None,
                rand=True,
                scale_factor=1e6,
                batch_size=batch_size // input_context.num_replicas_in_sync,
                target_image_size=IM_SIZE,
                **kwargs
            )
            # input_context is handed to the dataset builder, which presumably
            # shards per worker itself, so tf.data auto-sharding is turned off
            # (confirm against the dataset implementation)
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
            ds = ds.with_options(options)
            return ds

        train_set = mirrored_strategy.distribute_datasets_from_function(partial(
            _dataset_fn,
            mode='train',
        ))
        val_set = mirrored_strategy.distribute_datasets_from_function(partial(
            _dataset_fn,
            mode='val',
        ))
    else:
        train_set = dataset(
            train_path,
            AF=af,
            contrast=contrast,
            inner_slices=None,
            rand=True,
            scale_factor=1e6,
            n_samples=n_samples,
            fixed_masks=fixed_masks,
            batch_size=batch_size,
            target_image_size=IM_SIZE,
            **kwargs
        )
        val_set = dataset(
            val_path,
            AF=af,
            contrast=contrast,
            inner_slices=None,
            rand=True,
            scale_factor=1e6,
            **kwargs
        )

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
        'multiscale_kspace_learning': multiscale_kspace_learning,
    }

    # run id: network type + non-default options + submodel + timestamp
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    if checkpoint_epoch == 0:
        run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    else:
        # resuming: keep writing under the run id of the interrupted training
        run_id = original_run_id
    final_epoch = checkpoint_epoch + n_epochs
    # only worker 0 writes to the real checkpoint directory; the other workers
    # write their copies to TMP_DIR
    if not distributed or slurm_resolver.task_id == 0:
        chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    else:
        chkpt_path = f'{TMP_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    if not save_state or manual_saving:
        chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    with ExitStack() as stack:
        # the strategy only exists (and its scope is only entered) when the
        # training is distributed
        if distributed:
            stack.enter_context(mirrored_strategy.scope())
        if checkpoint_epoch == 0:
            model = XPDNet(model_fun, model_kwargs, **run_params)
            if original_run_id is not None:
                # fine-tuning: tiny learning rate and half-size epochs
                lr = 1e-7
                n_steps = brain_volumes_per_contrast['train'].get(contrast, n_volumes)//2
            else:
                n_steps = n_volumes
            default_model_compile(model, lr=lr, loss=loss)
        elif not manual_saving:
            model = load_model(
                f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}',
                custom_objects=CUSTOM_TF_OBJECTS,
            )
            n_steps = n_volumes
        else:
            # weights and optimizer state are restored manually further down
            model = XPDNet(model_fun, model_kwargs, **run_params)
            n_steps = n_volumes
            default_model_compile(model, lr=lr, loss=loss)
    if batch_size is not None:
        n_steps //= batch_size

    # save only once, at the end of this training step
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(n_epochs*n_steps),
        save_weights_only=(not save_state and not distributed) or manual_saving,
    )
    print(run_id)
    if original_run_id is not None and (not checkpoint_epoch or manual_saving):
        if os.environ.get('FASTMRI_DEBUG'):
            n_epochs_original = 1
        if manual_saving:
            n_epochs_original = checkpoint_epoch
        if multicoil:
            kspace_size = [1, 15, 640, 372]
        else:
            kspace_size = [1, 640, 372]
        # a forward pass on zeros creates the variables before loading weights
        inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            inputs.append(tf.constant([[320, 320]]))
        with ExitStack() as stack:
            if distributed:
                stack.enter_context(mirrored_strategy.scope())
            model(inputs)
            model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')
        if manual_saving:
            def _model_weight_setting():
                # applying zero gradients once creates the optimizer's
                # variables, so that set_weights has something to overwrite
                grad_vars = model.trainable_weights
                zero_grads = [tf.zeros_like(w) for w in grad_vars]
                model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
                # NOTE: pickle of a file this training wrote itself; never
                # point this at untrusted data
                with open(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-optimizer.pkl', 'rb') as f:
                    weight_values = pickle.load(f)
                model.optimizer.set_weights(weight_values)

            if distributed:
                mirrored_strategy.run(_model_weight_setting)
            else:
                _model_weight_setting()
    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        initial_epoch=checkpoint_epoch,
        epochs=final_epoch,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    if manual_saving:
        # persist the optimizer state next to the '.hdf5' weights
        symbolic_weights = model.optimizer.weights
        weight_values = K.batch_get_value(symbolic_weights)
        with open(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-optimizer.pkl', 'wb') as f:
            pickle.dump(weight_values, f)
    return run_id
def evaluate_xpdnet(
        model_fun,
        model_kwargs,
        run_id,
        multicoil=True,
        n_epochs=200,
        contrast=None,
        af=4,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        refine_smaps=False,
        n_samples=None,
        cuda_visible_devices='0123',
    ):
    """Evaluate a trained XPDNet on the fastMRI validation set.

    The weights are loaded from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5``. The model
    is compiled with a PSNR function as its loss and an SSIM function as its
    metric, so that a single ``model.evaluate`` call yields both. If that call
    raises ``tf.errors.ResourceExhaustedError``, the function falls back to a
    slice-by-slice evaluation accumulated in a ``Metrics`` object.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments used to initialize the image
            correction network.
        run_id (str): the run id of the trained network, used to locate the
            weights file.
        multicoil (bool): whether to evaluate on multicoil data. Defaults to
            True.
        n_epochs (int): the epoch number of the checkpoint to load. Defaults
            to 200.
        contrast (str or None): restrict the evaluation to this contrast. If
            None, all contrasts are used. Defaults to None.
        af (int): the acceleration factor for the retrospective undersampling.
            Defaults to 4.
        n_iter (int): number of XPDNet iterations. Defaults to 10.
        res (bool): whether the image correction networks are residual.
            Defaults to True.
        n_scales (int): number of scales of the image correction network.
            Defaults to 0.
        n_primal (int): size of the image-space buffer. Defaults to 5.
        refine_smaps (bool): whether sensitivity maps are refined by a
            network. Defaults to False.
        n_samples (int or None): number of validation samples to use. If None,
            a fixed number of volumes is evaluated. Defaults to None.
        cuda_visible_devices (str): GPU ids to expose, one character per
            device (the characters are joined with commas). Defaults to '0123'.

    Returns:
        - list: ``model.metrics_names``.
        - list: on the normal path, the values returned by ``model.evaluate``
          (PSNR-based loss, then SSIM); on the fallback path, the mean PSNR
          and mean SSIM computed by ``Metrics``.
    """
    if multicoil:
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
    }

    if multicoil:
        dataset = multicoil_dataset
        kwargs = {'parallel': False}
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=False,
        scale_factor=1e6,
        **kwargs,
    )
    if n_samples is not None:
        val_set = val_set.take(n_samples)

    if multicoil:
        kspace_size = [1, 15, 640, 372]
    else:
        kspace_size = [1, 640, 372]

    model = XPDNet(model_fun, model_kwargs, **run_params)
    # Run the model once on dummy inputs so its weights are created before
    # load_weights is called below.
    inputs = [
        tf.zeros(kspace_size + [1], dtype=tf.complex64),
        tf.zeros(kspace_size, dtype=tf.complex64),
    ]
    if multicoil:
        inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
    model(inputs)

    def tf_psnr(y_true, y_pred):
        # Swap the first and last axes of the rank-4 tensors before calling
        # tf.image.psnr (which reduces over the last three axes).
        # NOTE(review): presumably inputs are (batch, H, W, channel), giving
        # one PSNR per channel over the whole batch -- confirm.
        perm_psnr = [3, 1, 2, 0]
        psnr = tf.image.psnr(
            tf.transpose(y_true, perm_psnr),
            tf.transpose(y_pred, perm_psnr),
            tf.reduce_max(y_true),
        )
        return psnr

    def tf_ssim(y_true, y_pred):
        # SSIM on the tensors as-is, with the ground-truth maximum as dynamic
        # range.
        ssim = tf.image.ssim(
            y_true,
            y_pred,
            tf.reduce_max(y_true),
        )
        return ssim

    model.compile(loss=tf_psnr, metrics=[tf_ssim])
    model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')

    # Hard-coded number of validation volumes; roughly halved when a single
    # contrast is selected.
    n_volumes = 199
    if contrast is not None:
        n_volumes //= 2
        n_volumes += 1

    try:
        eval_res = model.evaluate(val_set, verbose=1, steps=n_volumes if n_samples is None else None)
    except tf.errors.ResourceExhaustedError:
        # Out of memory for batched evaluation: evaluate slice by slice.
        metrics = Metrics(METRIC_FUNCS)
        if n_samples is None:
            val_set = val_set.take(n_volumes)
        for data in val_set:
            y_true = data[1].numpy()
            y_pred = model.predict(data[0], batch_size=1)
            metrics.push(y_true[..., 0], y_pred[..., 0])
        means = metrics.means()
        eval_res = [means['PSNR'], means['SSIM']]
    return model.metrics_names, eval_res
def train_xpdnet_block(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        brain=False,
        af=4,
        contrast=None,
        n_samples=None,
        batch_size=None,
        n_epochs=200,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        refine_big=False,
        loss='mae',
        lr=1e-4,
        fixed_masks=False,
        equidistant_fake=False,
        multi_gpu=False,
        mask_type=None,
        primal_only=True,
        n_dual=1,
        n_dual_filters=16,
        multiscale_kspace_learning=False,
        block_size=10,
        block_overlap=0,
        epochs_per_block_step=None,
    ):
    r"""Train an XPDNet network on the fastMRI dataset, one block at a time.

    Training proceeds in successive block steps. At step ``i`` the model's
    ``blocks_to_train`` attribute is set to
    ``range(i * stride, i * stride + block_size)`` with
    ``stride = block_size - block_overlap``. The model is then recompiled with
    ``default_model_compile`` (learning rate ``lr``) and fitted for
    ``epochs_per_block_step`` epochs. Training stops once ``n_epochs`` epochs
    have been run in total, or when all block steps are done.

    Validation is performed every 5 epochs on 5 batches of the validation set.
    A scale factor of 1e6 is applied to the data. Only the weights are
    checkpointed.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to
            True.
        brain (bool): whether to consider brain data instead of knee.
            Defaults to False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None.
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Defaults to None.
        batch_size (int or None): batch size passed to the training dataset;
            the number of steps per epoch is divided by it. Defaults to None.
        n_epochs (int): the total number of epochs (i.e. passes through all
            the volumes/samples) across all block steps. Defaults to 200.
        n_iter (int): the number of iterations for the XPDNet. Defaults to 10.
        res (bool): whether the XPDNet image correction networks should be
            residual. Defaults to True.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults
            to 5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether to refine the sensitivity maps with a
            neural network. Defaults to False.
        refine_big (bool): forwarded to XPDNet; also appends 'b' to the run
            id. Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020] is used. Defaults to 'mae'.
        lr (float): learning rate passed to ``default_model_compile``.
            Defaults to 1e-4.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Defaults to False.
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI (brain multicoil data only, when ``mask_type`` is None).
            Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet
            training. Defaults to False.
        mask_type (str or None): mask type for multicoil data. If None, it is
            'equidistant' (or 'equidistant_fake') for brain data and 'random'
            for knee data. Defaults to None.
        primal_only (bool): forwarded to XPDNet. Defaults to True.
        n_dual (int): forwarded to XPDNet. Defaults to 1.
        n_dual_filters (int): forwarded to XPDNet. Defaults to 16.
        multiscale_kspace_learning (bool): forwarded to XPDNet. Defaults to
            False.
        block_size (int): number of consecutive block indices trained at each
            block step. Defaults to 10.
        block_overlap (int): number of block indices shared by two consecutive
            block steps. Must be smaller than ``block_size``. Defaults to 0.
        epochs_per_block_step (int or None): number of epochs run at each
            block step. If None, ``n_epochs`` is split evenly (rounded up)
            across the block steps. Defaults to None.

    Returns:
        - str: the run id of the trained network.

    Raises:
        ValueError: if ``block_overlap`` is not smaller than ``block_size``.
    """
    # Validate the block schedule before any dataset or model is built.
    stride = block_size - block_overlap
    if stride <= 0:
        raise ValueError(
            f'block_overlap ({block_overlap}) must be smaller than '
            f'block_size ({block_size})'
        )

    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train

    # paths
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    af = int(af)

    # mixed precision (same non-experimental API as the rest of the file)
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    mixed_precision.set_global_policy(policy_type)

    # generators
    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            if brain:
                mask_type = 'equidistant_fake' if equidistant_fake else 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    train_set = dataset(
        train_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        batch_size=batch_size,
        target_image_size=IM_SIZE,
        **kwargs,
    )
    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        **kwargs,
    )

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
        'multiscale_kspace_learning': multiscale_kspace_learning,
    }

    # run id
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'
    if block_overlap != 0:
        additional_info += f'_blkov{block_overlap}'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    run_id = f'{xpdnet_type}_{additional_info}_bbb_{submodel_info}_{int(time.time())}'

    # callbacks
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    chkpt_path += '.hdf5'
    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    model = XPDNet(model_fun, model_kwargs, **run_params)
    n_steps = n_volumes
    if batch_size is not None:
        n_steps //= batch_size

    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(n_epochs * n_steps),
        save_weights_only=True,
    )
    print(run_id)

    # Block schedule: at least one step, even when block_size exceeds n_iter.
    # NOTE(review): the last block may contain indices >= n_iter; presumably
    # the model ignores them -- confirm.
    n_block_steps = max(1, math.ceil((n_iter - block_size) / stride) + 1)
    if epochs_per_block_step is None:
        epochs_per_block_step = math.ceil(n_epochs / n_block_steps)

    # epochs handling
    remaining_epochs = n_epochs
    start_epoch = 0
    final_epoch = min(epochs_per_block_step, remaining_epochs)
    for i_step in range(n_block_steps):
        first_block_to_train = i_step * stride
        model.blocks_to_train = list(
            range(first_block_to_train, first_block_to_train + block_size))
        # Recompile so the new set of trainable blocks is taken into account.
        default_model_compile(model, lr=lr, loss=loss)
        model.fit(
            train_set,
            steps_per_epoch=n_steps,
            initial_epoch=start_epoch,
            epochs=final_epoch,
            validation_data=val_set,
            validation_steps=5,
            validation_freq=5,
            verbose=0,
            callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
        )
        remaining_epochs -= final_epoch - start_epoch
        if remaining_epochs <= 0:
            break
        start_epoch = final_epoch
        final_epoch += min(epochs_per_block_step, remaining_epochs)
    return run_id
def xpdnet_inference(
        model_fun,
        model_kwargs,
        run_id,
        exp_id='xpdnet',
        brain=False,
        challenge=False,
        n_epochs=200,
        contrast=None,
        af=4,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        refine_smaps=False,
        refine_big=False,
        n_samples=None,
        cuda_visible_devices='0123',
    ):
    """Reconstruct the fastMRI multicoil test (or challenge) set with a trained XPDNet.

    The model is built inside a ``tf.distribute.MirroredStrategy`` scope. It
    is called once on zero inputs so its weights exist, and then loaded from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5``. Each
    reconstruction is written with ``write_result`` under ``exp_id``.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments used to initialize the image
            correction network.
        run_id (str): the run id of the trained network, used to locate the
            weights file.
        exp_id (str): identifier under which results are written. Defaults to
            'xpdnet'.
        brain (bool): whether to use brain data instead of knee. Defaults to
            False.
        challenge (bool): for brain data, use the challenge set instead of the
            test set. Defaults to False.
        n_epochs (int): the epoch number of the checkpoint to load. Defaults
            to 200.
        contrast (str or None): restrict inference to this contrast. Defaults
            to None.
        af (int): the acceleration factor of the test data. Defaults to 4.
        n_iter (int): number of XPDNet iterations. Defaults to 10.
        res (bool): whether the image correction networks are residual.
            Defaults to True.
        n_scales (int): number of scales of the image correction network.
            Defaults to 0.
        n_primal (int): size of the image-space buffer. Defaults to 5.
        refine_smaps (bool): whether sensitivity maps are refined by a
            network. Defaults to False.
        refine_big (bool): forwarded to XPDNet. Defaults to False.
        n_samples (int or None): number of test samples to reconstruct. If
            None, all are used. Defaults to None.
        cuda_visible_devices (str): GPU ids to expose, one character per
            device (the characters are joined with commas). Defaults to '0123'.
    """
    if brain:
        if challenge:
            test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_challenge/'
        else:
            test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_test/'
    else:
        test_path = f'{FASTMRI_DATA_DIR}multicoil_test_v2/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    run_params = {
        'n_primal': n_primal,
        'multicoil': True,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'refine_big': refine_big,
        'res': res,
        'output_shape_spec': brain,
    }

    test_set = test_masked_kspace_dataset_from_indexable(
        test_path,
        AF=af,
        contrast=contrast,
        scale_factor=1e6,
        n_samples=n_samples,
        output_shape_spec=brain,
    )
    test_set_filenames = test_filenames(
        test_path,
        AF=af,
        contrast=contrast,
        n_samples=n_samples,
    )

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = XPDNet(model_fun, model_kwargs, **run_params)
        # Build the weights with dummy inputs before loading them.
        fake_inputs = [
            tf.zeros([1, 15, 640, 372, 1], dtype=tf.complex64),
            tf.zeros([1, 15, 640, 372], dtype=tf.complex64),
            tf.zeros([1, 15, 640, 372], dtype=tf.complex64),
        ]
        if brain:
            fake_inputs.append(tf.constant([[320, 320]]))
        model(fake_inputs)
    model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')

    # Progress-bar length only; it does not limit the iteration.
    if n_samples is None:
        if not brain:
            if contrast:
                tqdm_total = n_volumes_test[af] // 2
            else:
                tqdm_total = n_volumes_test[af]
        else:
            if contrast:
                tqdm_total = brain_volumes_per_contrast['test'][af][contrast]
            else:
                tqdm_total = brain_n_volumes_test[af]
    else:
        tqdm_total = n_samples
    tqdm_desc = f'{exp_id}_{contrast}_{af}'

    # TODO: change when the following issue has been dealt with
    # @tf.function(experimental_relax_shapes=True)
    def predict(t):
        return model(t)

    for data_example, filename in tqdm(
            zip(test_set, test_set_filenames),
            total=tqdm_total,
            desc=tqdm_desc,
    ):
        # Named so as not to shadow the ``res`` hyper-parameter.
        reconstruction = predict(data_example)
        write_result(
            exp_id,
            reconstruction.numpy(),
            filename.numpy().decode('utf-8'),
            scale_factor=1e6,
            brain=brain,
            challenge=challenge,
        )