Beispiel #1
0
def test_xpdnet():
    """Smoke test: build a small multicoil XPDNet around MWCNN and call it once.

    The network is called on all-zero complex inputs (k-space, mask and
    sensitivity maps); the test passes if the forward pass does not raise.
    """
    primal_buffer_size = 2
    scale_count = 3
    # shape shared by the mask and the sensitivity maps; the k-space input
    # has one extra trailing dimension of size 1
    coil_volume_shape = [1, 5, 640, 320]
    mwcnn_config = {
        'n_scales': scale_count,
        'kernel_size': 3,
        'bn': False,
        'n_filters_per_scale': [4, 8, 8],
        'n_convs_per_scale': [2, 2, 2],
        'n_first_convs': 2,
        'first_conv_n_filters': 4,
        'res': False,
        'n_outputs': 2 * primal_buffer_size,
    }
    net = XPDNet(
        model_fun=MWCNN,
        model_kwargs=mwcnn_config,
        n_primal=primal_buffer_size,
        n_iter=2,
        multicoil=True,
        n_scales=scale_count,
    )
    kspace = tf.zeros(coil_volume_shape + [1], dtype=tf.complex64)
    mask = tf.zeros(coil_volume_shape, dtype=tf.complex64)
    smaps = tf.zeros(coil_volume_shape, dtype=tf.complex64)
    net([kspace, mask, smaps])
def test_works_in_xpdnet_train(model_fun,
                               model_kwargs,
                               n_scales,
                               res,
                               n_iter=10,
                               multicoil=False,
                               use_mixed_precision=False,
                               data_consistency_learning=False):
    """Check whether an XPDNet configuration can be fitted for one epoch.

    An XPDNet is built, compiled with ``default_model_compile`` (lr 1e-3,
    'mae' loss) and fitted for a single epoch on all-ones tensors.

    Arguments:
        model_fun: forwarded to ``XPDNet`` as the image correction network
            initializer.
        model_kwargs: forwarded to ``XPDNet`` next to ``model_fun``.
        n_scales: forwarded to ``XPDNet``.
        res: forwarded to ``XPDNet``.
        n_iter (int): forwarded to ``XPDNet``. Defaults to 10.
        multicoil (bool): when True a coil dimension of 15 is prepended to
            the k-space shape, a third input is supplied and
            ``refine_smaps`` is enabled. Defaults to False.
        use_mixed_precision (bool): selects the 'mixed_float16' global
            policy instead of 'float32'. Defaults to False.
        data_consistency_learning (bool): the network is built with
            ``primal_only=not data_consistency_learning``. Defaults to
            False.

    Returns:
        bool: False if the fit raises ``tf.errors.ResourceExhaustedError``
        or ``tf.errors.InternalError``, True otherwise.
    """
    mixed_precision.set_global_policy(
        'mixed_float16' if use_mixed_precision else 'float32')
    # NOTE(review): ``n_primal`` is neither a parameter nor a local of this
    # function; it is presumably a module-level constant -- confirm.
    xpdnet = XPDNet(
        model_fun,
        model_kwargs,
        n_primal=n_primal,
        multicoil=multicoil,
        n_scales=n_scales,
        n_iter=n_iter,
        refine_smaps=multicoil,
        res=res,
        primal_only=not data_consistency_learning,
    )
    default_model_compile(xpdnet, lr=1e-3, loss='mae')
    coil_count = 15
    kspace_shape = (coil_count, 640, 400) if multicoil else (640, 400)
    fit_inputs = [
        tf.ones([1, *kspace_shape, 1], dtype=tf.complex64),
        tf.ones([1, *kspace_shape], dtype=tf.complex64),
    ]
    if multicoil:
        # third input only exists in the multicoil setting
        fit_inputs.append(tf.ones([1, *kspace_shape], dtype=tf.complex64))
    try:
        xpdnet.fit(x=fit_inputs, y=tf.ones([1, 320, 320, 1]), epochs=1)
    except (tf.errors.ResourceExhaustedError, tf.errors.InternalError):
        return False
    return True
Beispiel #3
0
def evaluate_xpdnet(
    model_fun,
    model_kwargs,
    run_id,
    multicoil=True,
    brain=False,
    n_epochs=200,
    contrast=None,
    af=4,
    n_iter=10,
    res=True,
    n_scales=0,
    n_primal=5,
    refine_smaps=False,
    refine_big=False,
    n_samples=None,
    cuda_visible_devices='0123',
    equidistant_fake=False,
    mask_type=None,
    primal_only=True,
    n_dual=1,
    n_dual_filters=16,
    multiscale_kspace_learning=False,
):
    """Evaluate a trained XPDNet on a fastMRI validation set.

    The weights are read from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5`` after the
    network has been called once on zero-valued inputs.

    Arguments:
        model_fun, model_kwargs: forwarded to ``XPDNet``.
        run_id (str): identifier used in the checkpoint file name.
        multicoil (bool): selects the multicoil dataset and validation
            directory. Defaults to True.
        brain (bool): selects the brain validation directory (multicoil
            only) and is forwarded as ``output_shape_spec``. Defaults to
            False.
        n_epochs (int): epoch number used in the checkpoint file name.
            Defaults to 200.
        contrast (str or None): forwarded to the dataset; also changes the
            number of volumes taken. Defaults to None.
        af: acceleration factor, converted with ``int``. Defaults to 4.
        n_iter, res, n_scales, n_primal, refine_smaps, refine_big,
        primal_only, n_dual, n_dual_filters, multiscale_kspace_learning:
            forwarded to ``XPDNet``.
        n_samples (int or None): if given, number of elements taken from the
            validation set instead of the volume count. Defaults to None.
        cuda_visible_devices (str): its characters are joined with commas
            and written to ``CUDA_VISIBLE_DEVICES``. Defaults to '0123'.
        equidistant_fake (bool): with multicoil brain data and no explicit
            ``mask_type``, selects 'equidistant_fake' instead of
            'equidistant'. Defaults to False.
        mask_type (str or None): mask type passed to the multicoil dataset.
            Defaults to None.

    Returns:
        tuple: ``METRIC_FUNCS`` and a pair of lists holding the values of
        ``Metrics.means()`` and ``Metrics.stddevs()``.
    """
    if not multicoil:
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'
    elif brain:
        val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
    else:
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    xpdnet_kwargs = dict(
        n_primal=n_primal,
        multicoil=multicoil,
        n_scales=n_scales,
        n_iter=n_iter,
        refine_smaps=refine_smaps,
        res=res,
        output_shape_spec=brain,
        refine_big=refine_big,
        primal_only=primal_only,
        n_dual=n_dual,
        n_dual_filters=n_dual_filters,
        multiscale_kspace_learning=multiscale_kspace_learning,
    )

    # dataset constructor and its coil-specific options
    if multicoil:
        dataset_fun = multicoil_dataset
        if mask_type is None:
            if not brain:
                mask_type = 'random'
            elif equidistant_fake:
                mask_type = 'equidistant_fake'
            else:
                mask_type = 'equidistant'
        dataset_kwargs = dict(
            parallel=False,
            output_shape_spec=brain,
            mask_type=mask_type,
        )
    else:
        dataset_fun = singlecoil_dataset
        dataset_kwargs = dict()
    val_set = dataset_fun(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=False,
        scale_factor=1e6,
        **dataset_kwargs,
    )

    # number of validation volumes for the selected anatomy / contrast
    if brain:
        if contrast is None:
            n_volumes = brain_n_volumes_validation
        else:
            n_volumes = brain_volumes_per_contrast['validation'][contrast]
    elif contrast is None:
        n_volumes = n_volumes_val
    else:
        n_volumes = n_volumes_val // 2 + 1
    n_evaluated = n_volumes if n_samples is None else n_samples
    val_set = val_set.take(n_evaluated)

    kspace_size = [1, 15, 640, 372] if multicoil else [1, 640, 372]
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = XPDNet(model_fun, model_kwargs, **xpdnet_kwargs)
        dummy_inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            dummy_inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            dummy_inputs.append(tf.constant([[320, 320]]))
        # one call on zeros before the weights are loaded
        model(dummy_inputs)
    weights_file = f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5'
    model.load_weights(weights_file)

    eval_res = Metrics(METRIC_FUNCS)
    progress = tqdm(val_set.as_numpy_iterator(), total=n_evaluated)
    for x, y_true in progress:
        y_pred = model.predict(x, batch_size=4)
        eval_res.push(y_true[..., 0], y_pred[..., 0])
    means = list(eval_res.means().values())
    stddevs = list(eval_res.stddevs().values())
    return METRIC_FUNCS, (means, stddevs)
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        brain=False,
        af=4,
        contrast=None,
        cuda_visible_devices='0123',
        n_samples=None,
        n_epochs=200,
        checkpoint_epoch=0,
        save_state=False,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        refine_big=False,
        loss='mae',
        original_run_id=None,
        fixed_masks=False,
        n_epochs_original=250,
        equidistant_fake=False,
        multi_gpu=False,
        mask_type=None,
    ):
    r"""Train an XPDNet network on the fastMRI dataset.

    The model is compiled through ``default_model_compile`` with a learning
    rate of 1e-4 (1e-7 when fine-tuning); earlier documentation of this
    function states that the RAdam optimizer is used.
    The validation is run every 5 epochs for 5 steps.
    A scale factor of 1e6 is applied to the data.

    Three modes are supported:

    - fresh training (``checkpoint_epoch == 0`` and ``original_run_id`` is
      None): a new timestamped run id is created;
    - fine-tuning (``checkpoint_epoch == 0`` and ``original_run_id`` given):
      a new run id is created and the weights of
      ``{original_run_id}-{n_epochs_original:02d}.hdf5`` are loaded;
    - resuming (``checkpoint_epoch != 0``): the full model saved as
      ``{original_run_id}-{checkpoint_epoch:02d}`` is loaded and
      ``original_run_id`` is reused as the run id.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to
            True.
        brain (bool): whether to consider brain data instead of knee. Defaults
            to False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        cuda_visible_devices (str): the GPUs to consider visible. The
            characters of the string are joined with commas, i.e. '0123'
            becomes '0,1,2,3'. Defaults to '0123'.
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Defaults to None.
        n_epochs (int): the number of epochs (i.e. one pass though all the
            volumes/samples) for this training. Defaults to 200.
        checkpoint_epoch (int): the number of epochs used to train the model
            during the first step of the full training. This is typically used
            when on a cluster the training duration exceeds the maximum job
            duration. Defaults to 0, which means that the training is done
            without checkpoints.
        save_state (bool): whether you should save the entire model state for
            this training, for example to retrain where left off. When
            False, only the weights are saved, in an ``.hdf5`` file.
            Defaults to False.
        n_iter (int): the number of iterations for the XPDNet.
        res (bool): whether the XPDNet image correction networks should be
            residual.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network.
        refine_big (bool): forwarded to the XPDNet; when ``refine_smaps`` is
            also set, a 'b' is appended to the run id. Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        original_run_id (str or None): run id of the same network trained before
            fine-tuning. If this is present, the training is considered
            fine-tuning for a network trained for ``n_epochs_original``
            epochs. It will therefore apply a learning rate of 1e-7 and the
            epoch size will be divided in half. If None, the training is
            done normally, without loading weights. Defaults to None.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Defaults to False
        n_epochs_original (int): the number of epochs used to pre-train the
            model, only applicable if original_run_id is not None. It is
            replaced by 1 when the ``FASTMRI_DEBUG`` environment variable is
            set. Defaults to 250.
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Only read for multicoil brain data when ``mask_type``
            is None. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the mask type passed to the multicoil
            dataset. If None, 'random' is used for knee data and
            'equidistant' (or 'equidistant_fake') for brain data. Not used
            for single-coil data. Defaults to None.

    Returns:
        - str: the run id of the trained network.
    """
    # default number of steps per epoch: one per training volume
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # paths
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'


    # each character of the string is one device id: '0123' -> '0,1,2,3'
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)
    # generators
    if multicoil:
        dataset = multicoil_dataset
        # an explicit mask_type wins; otherwise it is derived from the anatomy
        if mask_type is None:
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    train_set = dataset(
        train_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **kwargs
    )
    # the validation set is built without n_samples and fixed_masks
    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        **kwargs
    )

    # keyword arguments of the XPDNet constructor
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
    }

    # run id: every non-default option leaves a marker in the name
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    # fresh or fine-tuning run: new timestamped id; resumed run: reuse the
    # id passed as original_run_id
    if checkpoint_epoch == 0:
        run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    else:
        run_id = original_run_id
    final_epoch = checkpoint_epoch + n_epochs
    # '{epoch:02d}' is left unformatted on purpose: it is a placeholder for
    # the checkpoint callback
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    if not save_state:
        chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    if checkpoint_epoch == 0:
        model = XPDNet(model_fun, model_kwargs, **run_params)
        if original_run_id is not None:
            # fine-tuning: small learning rate and half-size epochs
            lr = 1e-7
            n_steps = brain_volumes_per_contrast['train'].get(contrast, n_volumes)//2
        else:
            lr = 1e-4
            n_steps = n_volumes
        default_model_compile(model, lr=lr, loss=loss)
    else:
        # resuming: the saved model is loaded as a whole and is not
        # recompiled here
        model = load_model(
            f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}',
            custom_objects=CUSTOM_TF_OBJECTS,
        )
        n_steps = n_volumes

    # presumably save_freq counts batches, so that a single checkpoint is
    # written after n_epochs epochs -- TODO confirm against
    # ModelCheckpointWorkAround
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=n_epochs*n_steps,
        save_weights_only=not save_state,
    )
    print(run_id)
    # fine-tuning only: load the pre-trained weights
    if original_run_id is not None and not checkpoint_epoch:
        if os.environ.get('FASTMRI_DEBUG'):
            n_epochs_original = 1
        if multicoil:
            kspace_size = [1, 15, 640, 372]
        else:
            kspace_size = [1, 640, 372]
        inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            inputs.append(tf.constant([[320, 320]]))
        # the model is called once on zeros before load_weights, presumably
        # so that its variables exist when the hdf5 file is read -- confirm
        model(inputs)
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        initial_epoch=checkpoint_epoch,
        epochs=final_epoch,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id
def generate_postproc_tf_records(
    model_fun,
    model_kwargs,
    run_id,
    brain=False,
    n_epochs=200,
    af=4,
    n_iter=10,
    res=True,
    n_scales=0,
    n_primal=5,
    refine_smaps=False,
    refine_big=False,
    primal_only=True,
    n_dual=1,
    n_dual_filters=16,
    mode='train',
):
    """Write the predictions of a trained multicoil XPDNet to TFRecord files.

    Every ``*.h5`` file of the selected fastMRI multicoil directory is
    preprocessed, passed through the network and stored, together with the
    preprocessing outputs, in ``<stem>_{run_id}.tfrecords`` next to the
    source file. Files whose TFRecord already exists are skipped.

    The weights are read from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5`` after the
    network has been called once on zero-valued inputs.

    Arguments:
        model_fun, model_kwargs: forwarded to ``XPDNet``.
        run_id (str): identifier used in the checkpoint file name and in
            the TFRecord file suffix.
        brain (bool): selects ``brain_multicoil_{mode}`` instead of
            ``multicoil_{mode}``, the 'equidistant_fake' mask type instead
            of 'random', and is forwarded as ``output_shape_spec``.
            Defaults to False.
        n_epochs (int): epoch number used in the checkpoint file name.
            Defaults to 200.
        af: acceleration factor; converted with ``int`` before any use.
            Defaults to 4.
        n_iter, res, n_scales, n_primal, refine_smaps, refine_big,
        primal_only, n_dual, n_dual_filters: forwarded to ``XPDNet``.
        mode (str): suffix of the data directory. Defaults to 'train'.

    Returns:
        None. The results are written to disk.
    """
    # Convert before the first use: the k-space transform below must get
    # the same integer acceleration factor as the rest of the pipeline.
    af = int(af)

    main_path = Path(FASTMRI_DATA_DIR)
    if brain:
        path = main_path / f'brain_multicoil_{mode}'
    else:
        path = main_path / f'multicoil_{mode}'
    filenames = sorted(path.glob('*.h5'))
    kspace_transform = generic_from_kspace_to_masked_kspace_and_mask(
        AF=af,
        scale_factor=1e6,
        parallel=False,
        fixed_masks=False,
        output_shape_spec=brain,
        mask_type='equidistant_fake' if brain else 'random',
        batch_size=None,
        target_image_size=(640, 400),
    )

    class PreProcModel(tf.keras.models.Model):
        # wraps the k-space transform in a Keras model so that it can be
        # run through ``predict``
        def call(self, inputs):
            image, kspace = inputs
            return kspace_transform(image, kspace)

    selection = [
        {
            'inner_slices': None,
            'rand': False
        },  # slice selection
        {
            'rand': False,
            'keep_dim': False
        },  # coil selection
    ]
    extension = f'_{run_id}.tfrecords'

    # Model init
    run_params = {
        'n_primal': n_primal,
        'multicoil': True,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'refine_big': refine_big,
        'res': res,
        'output_shape_spec': brain,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
    }
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        preproc_model = PreProcModel()
        model = XPDNet(model_fun, model_kwargs, **run_params)
        fake_inputs = [
            tf.zeros([1, 15, 640, 372, 1], dtype=tf.complex64),
            tf.zeros([1, 15, 640, 372], dtype=tf.complex64),
            tf.zeros([1, 15, 640, 372], dtype=tf.complex64),
        ]
        if brain:
            fake_inputs.append(tf.constant([[320, 320]]))
        # one call on zeros before the weights are loaded
        model(fake_inputs)
    model.load_weights(
        f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')
    for filename in tqdm(filenames):
        directory = filename.parent
        filename_tfrecord = directory / (filename.stem + extension)
        if filename_tfrecord.exists():
            continue
        image, kspace, _ = from_multicoil_train_file_to_image_and_kspace_and_contrast(
            filename,
            selection=selection,
        )
        model_inputs, model_outputs = preproc_model.predict([image, kspace])
        # named ``prediction`` so that the ``res`` parameter is not
        # overwritten inside the loop
        prediction = model.predict(model_inputs, batch_size=4)
        with tf.io.TFRecordWriter(str(filename_tfrecord)) as writer:
            example = encode_postproc_example([prediction], [model_outputs])
            writer.write(example)
def xpdnet_inference(
    model_fun,
    model_kwargs,
    run_id,
    multicoil=True,
    exp_id='xpdnet',
    brain=False,
    challenge=False,
    n_epochs=200,
    contrast=None,
    af=4,
    n_iter=10,
    res=True,
    n_scales=0,
    n_primal=5,
    refine_smaps=False,
    refine_big=False,
    n_samples=None,
    cuda_visible_devices='0123',
    primal_only=True,
    n_dual=1,
    n_dual_filters=16,
    distributed=False,
    manual_saving=False,
):
    """Run a trained XPDNet on a fastMRI test set and write the results.

    The model is either loaded as a whole from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}`` (when
    ``distributed`` is set and ``manual_saving`` is not), or rebuilt, called
    once on zero-valued inputs and filled with the weights of the matching
    ``.hdf5`` file. Each prediction is handed to ``write_result``.

    Arguments:
        model_fun, model_kwargs: forwarded to ``XPDNet``.
        run_id (str): identifier used in the checkpoint name.
        multicoil (bool): selects the multicoil dataset; for knee data it
            also selects the test directory. Defaults to True.
        exp_id (str): experiment name passed to ``write_result`` and shown
            in the progress bar. Defaults to 'xpdnet'.
        brain (bool): selects the brain test directories and is forwarded
            as ``output_shape_spec``. Defaults to False.
        challenge (bool): with brain data, selects the challenge directory
            instead of the test one; passed to ``write_result``. Defaults
            to False.
        n_epochs (int): epoch number used in the checkpoint name. Defaults
            to 200.
        contrast (str or None): forwarded to the dataset and to the file
            name listing. Defaults to None.
        af: acceleration factor, converted with ``int``. Defaults to 4.
        n_iter, res, n_scales, n_primal, refine_smaps, refine_big,
        primal_only, n_dual, n_dual_filters: forwarded to ``XPDNet``.
        n_samples (int or None): forwarded to the dataset and the file name
            listing; also used as progress bar total when given. Defaults
            to None.
        cuda_visible_devices (str): its characters are joined with commas
            and written to ``CUDA_VISIBLE_DEVICES``. Defaults to '0123'.
        distributed (bool), manual_saving (bool): together select how the
            model is restored (see above). Default to False.

    Returns:
        None. The results are written through ``write_result``.
    """
    if brain:
        if challenge:
            test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_challenge/'
        else:
            test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_test/'
    else:
        if multicoil:
            test_path = f'{FASTMRI_DATA_DIR}multicoil_test_v2/'
        else:
            test_path = f'{FASTMRI_DATA_DIR}singlecoil_test/'

    # each character of the string is one device id: '0123' -> '0,1,2,3'
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'refine_big': refine_big,
        'res': res,
        'output_shape_spec': brain,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
    }
    if multicoil:
        ds_fun = multicoil_dataset
        extra_kwargs = dict(output_shape_spec=brain)
    else:
        # was spelled ``singecoil_dataset``, which raised a NameError for
        # single-coil inference
        ds_fun = singlecoil_dataset
        extra_kwargs = {}
    test_set = ds_fun(test_path,
                      AF=af,
                      contrast=contrast,
                      scale_factor=1e6,
                      n_samples=n_samples,
                      **extra_kwargs)
    test_set_filenames = test_filenames(
        test_path,
        AF=af,
        contrast=contrast,
        n_samples=n_samples,
    )
    if multicoil:
        fake_kspace_size = [15, 640, 372]
    else:
        fake_kspace_size = [640, 372]
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        if distributed and not manual_saving:
            model = load_model(
                f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}',
                custom_objects=CUSTOM_TF_OBJECTS,
            )
        else:
            model = XPDNet(model_fun, model_kwargs, **run_params)
            fake_inputs = [
                tf.zeros([1, *fake_kspace_size, 1], dtype=tf.complex64),
                tf.zeros([1, *fake_kspace_size], dtype=tf.complex64),
            ]
            if multicoil:
                fake_inputs.append(
                    tf.zeros([1, *fake_kspace_size], dtype=tf.complex64))
            if brain:
                fake_inputs.append(tf.constant([[320, 320]]))
            # one call on zeros before the weights are loaded
            model(fake_inputs)
            model.load_weights(
                f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')
    # progress bar total
    if n_samples is None:
        if not brain:
            if contrast:
                tqdm_total = n_volumes_test[af] // 2
            else:
                tqdm_total = n_volumes_test[af]
        else:
            if contrast:
                tqdm_total = brain_volumes_per_contrast['test'][af][contrast]
            else:
                tqdm_total = brain_n_volumes_test[af]
    else:
        tqdm_total = n_samples
    tqdm_desc = f'{exp_id}_{contrast}_{af}'

    for data_example, filename in tqdm(zip(test_set, test_set_filenames),
                                       total=tqdm_total,
                                       desc=tqdm_desc):
        # named ``prediction`` so that the ``res`` parameter is not
        # overwritten inside the loop
        prediction = model.predict(data_example, batch_size=16)
        write_result(
            exp_id,
            prediction,
            filename.numpy().decode('utf-8'),
            scale_factor=1e6,
            brain=brain,
            challenge=challenge,
            coiltype='multicoil' if multicoil else 'singlecoil',
        )
Beispiel #7
0
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        af=4,
        contrast=None,
        cuda_visible_devices='0123',
        n_samples=None,
        n_epochs=200,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        loss='mae',
        original_run_id=None,
        fixed_masks=False,
    ):
    r"""Train an XPDNet on the fastMRI dataset and return its run id.

    The model is compiled through ``default_model_compile`` with a learning
    rate of 1e-4, or 1e-7 when fine-tuning. Validation runs every 5 epochs
    for 5 steps, and the data is scaled by 1e6. The weights are
    checkpointed to
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{epoch:02d}.hdf5``.

    Arguments:
        model_fun (function): initializer of the image correction network
            of the XPDNet; its ``__name__`` is part of the run id.
        model_kwargs (dict): arguments used to initialize the image
            correction network.
        model_size (str or None): description of the network size, appended
            to the sub-model name in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to
            True.
        af (int): acceleration factor of the retrospective undersampling;
            converted with ``int``. Defaults to 4.
        contrast (str or None): contrast used for this training. If None,
            all contrasts are considered. Defaults to None.
        cuda_visible_devices (str): GPUs to make visible; the characters of
            the string are joined with commas. Defaults to '0123'.
        n_samples (int or None): number of training samples to consider.
            Defaults to None.
        n_epochs (int): number of training epochs, also used as checkpoint
            period. Defaults to 200.
        n_iter (int): number of iterations of the XPDNet. Defaults to 10.
        res (bool): whether the image correction networks are residual.
            Defaults to True.
        n_scales (int): number of scales of the image correction network.
            Defaults to 0.
        n_primal (int): size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): selects the 'mixed_float16' policy
            instead of 'float32'. Defaults to False.
        refine_smaps (bool): whether the sensitivity maps are refined by a
            neural network. Defaults to False.
        loss (tf.keras.losses.Loss or str): loss passed to
            ``default_model_compile``. Defaults to 'mae'.
        original_run_id (str or None): run id of a previously trained
            network to fine-tune. When given, the learning rate is 1e-7,
            the epoch size is halved and the weights saved at epoch 250
            (epoch 1 if ``FASTMRI_DEBUG`` is set) are loaded. Defaults to
            None.
        fixed_masks (bool): whether fixed masks are used for the
            retrospective undersampling of the training set. Defaults to
            False.

    Returns:
        - str: the run id of the trained network.
    """
    # data locations
    if not multicoil:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # precision policy
    mixed_precision.set_policy(mixed_precision.Policy(
        'mixed_float16' if use_mixed_precision else 'float32'))

    # datasets
    if multicoil:
        dataset_fun, coil_kwargs = multicoil_dataset, {'parallel': False}
    else:
        dataset_fun, coil_kwargs = singlecoil_dataset, {}
    shared_ds_kwargs = dict(
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
    )
    train_set = dataset_fun(
        train_path,
        **shared_ds_kwargs,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **coil_kwargs
    )
    val_set = dataset_fun(val_path, **shared_ds_kwargs, **coil_kwargs)

    # run id: every non-default option leaves a marker in the name
    xpdnet_type = 'xpdnet_sense_' if multicoil else 'xpdnet_singlecoil_'
    info_parts = [f'af{af}']
    if contrast is not None:
        info_parts.append(f'{contrast}')
    if n_samples is not None:
        info_parts.append(f'{n_samples}')
    if n_iter != 10:
        info_parts.append(f'i{n_iter}')
    if loss != 'mae':
        info_parts.append(f'{loss}')
    if refine_smaps:
        info_parts.append('rf_sm')
    if fixed_masks:
        info_parts.append('fixed_masks')
    additional_info = '_'.join(info_parts)

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info = submodel_info + model_size
    run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'

    # callbacks
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs, save_weights_only=True)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    # model
    model = XPDNet(
        model_fun,
        model_kwargs,
        n_primal=n_primal,
        multicoil=multicoil,
        n_scales=n_scales,
        n_iter=n_iter,
        refine_smaps=refine_smaps,
        res=res,
    )
    fine_tuning = original_run_id is not None
    lr = 1e-7 if fine_tuning else 1e-4
    n_steps = n_volumes_train//2 if fine_tuning else n_volumes_train
    default_model_compile(model, lr=lr, loss=loss)
    print(run_id)
    if fine_tuning:
        n_epochs_original = 1 if os.environ.get('FASTMRI_DEBUG') else 250
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        brain=False,
        af=4,
        contrast=None,
        n_samples=None,
        batch_size=None,
        n_epochs=200,
        checkpoint_epoch=0,
        save_state=False,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        refine_big=False,
        loss='mae',
        lr=1e-4,
        original_run_id=None,
        fixed_masks=False,
        n_epochs_original=250,
        equidistant_fake=False,
        multi_gpu=False,
        mask_type=None,
        primal_only=True,
        n_dual=1,
        n_dual_filters=16,
        multiscale_kspace_learning=False,
        distributed=False,
        manual_saving=False,
    ):
    r"""Train an XPDNet network on the fastMRI dataset.

    The model is compiled through ``default_model_compile`` with the learning
    rate ``lr`` (1e-4 by default) and the requested ``loss``. The original
    documentation states that the optimizer is RAdam; that choice is made
    inside ``default_model_compile`` and is not visible here.
    The validation is performed every 5 epochs on 5 validation steps.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to True.
        brain (bool): whether to consider brain data instead of knee. Only
            changes the data paths when ``multicoil`` is True. Defaults to
            False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. It is cast with ``int``. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. It is not forwarded
            to the dataset when ``distributed`` is True. Defaults to None.
        batch_size (int or None): the batch size forwarded to the training
            dataset. When not None, the number of steps per epoch is
            floor-divided by it. It must be an int when ``distributed`` is
            True, because it is divided by the number of replicas. Defaults to
            None.
        n_epochs (int): the number of epochs (i.e. one pass though all the
            volumes/samples) for this training. Defaults to 200.
        checkpoint_epoch (int): the number of epochs used to train the model
            during the first step of the full training. This is typically used
            when on a cluster the training duration exceeds the maximum job
            duration. When it is not 0, ``original_run_id`` is reused as the
            run id and is the checkpoint name the model is restored from.
            Defaults to 0, which means that the training is done
            without checkpoints.
        save_state (bool): whether you should save the entire model state for
            this training, for example to retrain where left off. Defaults to
            False.
        n_iter (int): the number of iterations for the XPDNet. Defaults to 10.
        res (bool): whether the XPDNet image correction networks should be
            residual. Defaults to True.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network. Defaults to False.
        refine_big (bool): forwarded to the XPDNet constructor; when
            ``refine_smaps`` is also True, a ``'b'`` is appended to the run id.
            Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        lr (float): the learning rate given to ``default_model_compile``. It is
            overridden with 1e-7 when ``original_run_id`` is given and
            ``checkpoint_epoch`` is 0. Defaults to 1e-4.
        original_run_id (str or None): run id of the same network trained before
            fine-tuning. If this is present (and ``checkpoint_epoch`` is 0),
            the training is considered fine-tuning for a network trained for
            ``n_epochs_original`` epochs. It will therefore
            apply a learning rate of 1e-7 and the epoch size will be divided in
            half. If None, the training is done normally, without loading
            weights. Defaults to None.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. It is not forwarded to the dataset
            when ``distributed`` is True. Defaults to False
        n_epochs_original (int): the number of epochs used to pre-train the
            model, only applicable if original_run_id is not None. It is
            replaced by 1 when the ``FASTMRI_DEBUG`` environment variable is
            set, and by ``checkpoint_epoch`` when ``manual_saving`` is True.
            Defaults to 250.
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Only used for multicoil brain data when ``mask_type`` is
            None. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the mask type given to the multicoil dataset.
            If None, it is ``'random'`` for knee data and ``'equidistant'`` or
            ``'equidistant_fake'`` for brain data. Ignored for single coil
            data. Defaults to None.
        primal_only (bool): forwarded to the XPDNet constructor. Defaults to
            True.
        n_dual (int): forwarded to the XPDNet constructor. Defaults to 1.
        n_dual_filters (int): forwarded to the XPDNet constructor. Defaults to
            16.
        multiscale_kspace_learning (bool): forwarded to the XPDNet
            constructor. Defaults to False.
        distributed (bool): whether to train under a
            ``tf.distribute.MultiWorkerMirroredStrategy`` whose workers are
            resolved through SLURM. Defaults to False.
        manual_saving (bool): whether to save weights-only checkpoints and
            pickle the optimizer weights next to them at the end of the
            training, restoring both when resuming. Defaults to False.

    Returns:
        - str: the run id of the trained network.
    """
    if distributed:
        # multi-worker setup: the cluster is described by SLURM and the
        # collective communications are implemented with NCCL
        com_options = tf.distribute.experimental.CommunicationOptions(
            implementation=tf.distribute.experimental.CommunicationImplementation.NCCL,
        )
        slurm_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(port_base=15000)
        mirrored_strategy = tf.distribute.MultiWorkerMirroredStrategy(
            cluster_resolver=slurm_resolver,
            communication_options=com_options,
        )
    # number of training volumes, later used as the number of steps per epoch
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # locations of the training and validation data
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    af = int(af)

    # select the global Keras dtype policy (mixed precision or plain float32)
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)
    # pick the dataset builder and the extra keyword arguments it needs
    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            # default mask type depends on the anatomy
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    if distributed:
        def _dataset_fn(input_context, mode='train'):
            # builds the per-worker dataset; note that n_samples and
            # fixed_masks are not forwarded in this branch
            ds = dataset(
                train_path if mode == 'train' else val_path,
                input_context=input_context,
                AF=af,
                contrast=contrast,
                inner_slices=None,
                rand=True,
                scale_factor=1e6,
                # the global batch size is split between the replicas, so
                # batch_size must not be None here
                batch_size=batch_size // input_context.num_replicas_in_sync,
                target_image_size=IM_SIZE,
                **kwargs
            )
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            ds = ds.with_options(options)
            return ds
        train_set = mirrored_strategy.distribute_datasets_from_function(partial(
            _dataset_fn,
            mode='train',
        ))
        val_set = mirrored_strategy.distribute_datasets_from_function(partial(
            _dataset_fn,
            mode='val',
        ))
    else:
        train_set = dataset(
            train_path,
            AF=af,
            contrast=contrast,
            inner_slices=None,
            rand=True,
            scale_factor=1e6,
            n_samples=n_samples,
            fixed_masks=fixed_masks,
            batch_size=batch_size,
            target_image_size=IM_SIZE,
            **kwargs
        )
        # the validation set is built without batch size, sample limit or
        # target image size
        val_set = dataset(
            val_path,
            AF=af,
            contrast=contrast,
            inner_slices=None,
            rand=True,
            scale_factor=1e6,
            **kwargs
        )

    # keyword arguments forwarded to the XPDNet constructor
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
        'multiscale_kspace_learning': multiscale_kspace_learning,
    }

    # build the run id out of the non-default parts of the configuration;
    # xpdnet_type already ends with an underscore, so the final run id
    # contains a double underscore after it
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    if checkpoint_epoch == 0:
        # fresh training: the run id is time-stamped
        run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    else:
        # resumed training: keep writing under the original run id
        run_id = original_run_id
    final_epoch = checkpoint_epoch + n_epochs
    # only task 0 writes in the real checkpoints directory, the other
    # workers write in TMP_DIR
    if not distributed or slurm_resolver.task_id == 0:
        chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    else:
        chkpt_path = f'{TMP_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    # the .hdf5 extension is used unless the full state is saved without
    # manual saving
    if not save_state or manual_saving:
        chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    with ExitStack() as stack:
        # can't be always used because of https://github.com/tensorflow/tensorflow/issues/46146
        if distributed:
            stack.enter_context(mirrored_strategy.scope())
        if checkpoint_epoch == 0:
            # fresh model, possibly fine-tuned from original_run_id below
            model = XPDNet(model_fun, model_kwargs, **run_params)
            if original_run_id is not None:
                # fine-tuning: tiny learning rate and half-sized epochs; the
                # per-contrast count is looked up in the brain table whatever
                # the value of `brain`
                lr = 1e-7
                n_steps = brain_volumes_per_contrast['train'].get(contrast, n_volumes)//2
            else:
                n_steps = n_volumes
            default_model_compile(model, lr=lr, loss=loss)
        elif not manual_saving:
            # resume from a fully saved model (architecture and, presumably,
            # optimizer state - depends on how it was saved)
            model = load_model(
                f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}',
                custom_objects=CUSTOM_TF_OBJECTS,
            )
            n_steps = n_volumes
        else:
            # manual resume: rebuild the model, the weights and the optimizer
            # state are restored further down
            model = XPDNet(model_fun, model_kwargs, **run_params)
            n_steps = n_volumes
            default_model_compile(model, lr=lr, loss=loss)

    if batch_size is not None:
        n_steps //= batch_size

    # save_freq is n_epochs * n_steps, the number of batches of this whole
    # fit call (presumably so that a single checkpoint is written at the end -
    # depends on ModelCheckpointWorkAround)
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(n_epochs*n_steps),
        save_weights_only=(not save_state and not distributed) or manual_saving,
    )
    print(run_id)
    # weights are loaded from an .hdf5 file either for fine-tuning
    # (checkpoint_epoch == 0) or for a manual resume
    if original_run_id is not None and (not checkpoint_epoch or manual_saving):
        if os.environ.get('FASTMRI_DEBUG'):
            n_epochs_original = 1
        if manual_saving:
            n_epochs_original = checkpoint_epoch
        if multicoil:
            kspace_size = [1, 15, 640, 372]
        else:
            kspace_size = [1, 640, 372]
        # zero-valued inputs used to call the model once before loading the
        # weights (presumably so that its variables get created)
        inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            inputs.append(tf.constant([[320, 320]]))
        with ExitStack() as stack:
            if distributed:
                # see https://github.com/tensorflow/tensorflow/issues/32561#issuecomment-544319907
                stack.enter_context(mirrored_strategy.scope())
            model(inputs)
            model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

        if manual_saving:
            def _model_weight_setting():
                # apply all-zero gradients once before setting the optimizer
                # weights (presumably to make the optimizer create its own
                # variables first)
                grad_vars = model.trainable_weights
                zero_grads = [tf.zeros_like(w) for w in grad_vars]
                model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
                # NOTE(review): pickle.load must only be used on trusted
                # files; this one is written at the end of this function
                with open(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-optimizer.pkl', 'rb') as f:
                    weight_values = pickle.load(f)
                model.optimizer.set_weights(weight_values)
            if distributed:
                mirrored_strategy.run(_model_weight_setting)
            else:
                _model_weight_setting()

    # epochs are counted from checkpoint_epoch so that resumed trainings
    # continue the epoch numbering
    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        initial_epoch=checkpoint_epoch,
        epochs=final_epoch,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )

    if manual_saving:
        # pickle the optimizer weights so that a later manual resume can
        # restore them
        symbolic_weights = getattr(model.optimizer, 'weights')
        weight_values = K.batch_get_value(symbolic_weights)
        with open(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-optimizer.pkl', 'wb') as f:
            pickle.dump(weight_values, f)
    return run_id
def evaluate_xpdnet(
        model_fun,
        model_kwargs,
        run_id,
        multicoil=True,
        n_epochs=200,
        contrast=None,
        af=4,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        refine_smaps=False,
        n_samples=None,
        cuda_visible_devices='0123',
    ):
    """Evaluate the saved weights of an XPDNet on the fastMRI validation data.

    The weights are read from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5``. The model is
    compiled with a PSNR function as loss and an SSIM function as metric, and
    evaluated with ``model.evaluate``. If that raises a
    ``tf.errors.ResourceExhaustedError``, the evaluation falls back to a
    sample-by-sample prediction accumulated in a ``Metrics`` object.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments used to initialize the image
            correction network.
        run_id (str): the run id of the weights to evaluate.
        multicoil (bool): whether to evaluate on the multicoil data. Defaults
            to True.
        n_epochs (int): the epoch number used in the weights file name.
            Defaults to 200.
        contrast (str or None): the contrast given to the dataset. Defaults to
            None.
        af (int): the acceleration factor, cast with ``int``. Defaults to 4.
        n_iter (int): forwarded to the XPDNet constructor. Defaults to 10.
        res (bool): forwarded to the XPDNet constructor. Defaults to True.
        n_scales (int): forwarded to the XPDNet constructor. Defaults to 0.
        n_primal (int): forwarded to the XPDNet constructor. Defaults to 5.
        refine_smaps (bool): forwarded to the XPDNet constructor. Defaults to
            False.
        n_samples (int or None): if not None, only this many elements of the
            validation set are used. Defaults to None.
        cuda_visible_devices (iterable of str): its elements are joined with
            commas and written to the ``CUDA_VISIBLE_DEVICES`` environment
            variable. Defaults to '0123'.

    Returns:
        - tuple: ``model.metrics_names`` and the evaluation results (the output
          of ``model.evaluate``, or the mean PSNR and SSIM of the fallback).
    """
    val_path = (
        f'{FASTMRI_DATA_DIR}multicoil_val/'
        if multicoil
        else f'{FASTMRI_DATA_DIR}singlecoil_val/'
    )

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # dataset builder, its extra arguments and the k-space shape all depend
    # on the coil setting
    if multicoil:
        dataset, dataset_kwargs = multicoil_dataset, {'parallel': False}
        kspace_size = [1, 15, 640, 372]
    else:
        dataset, dataset_kwargs = singlecoil_dataset, {}
        kspace_size = [1, 640, 372]

    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=False,
        scale_factor=1e6,
        **dataset_kwargs,
    )
    if n_samples is not None:
        val_set = val_set.take(n_samples)

    model = XPDNet(
        model_fun,
        model_kwargs,
        n_primal=n_primal,
        multicoil=multicoil,
        n_scales=n_scales,
        n_iter=n_iter,
        refine_smaps=refine_smaps,
        res=res,
    )
    # call the model once on zeros before compiling and loading the weights
    dummy_inputs = [
        tf.zeros(kspace_size + [1], dtype=tf.complex64),
        tf.zeros(kspace_size, dtype=tf.complex64),
    ]
    if multicoil:
        dummy_inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
    model(dummy_inputs)

    # the names of these two functions are kept as they are: Keras derives
    # the metric names from them
    def tf_psnr(y_true, y_pred):
        # axes 0 and 3 are swapped before the PSNR computation
        swapped_axes = [3, 1, 2, 0]
        return tf.image.psnr(
            tf.transpose(y_true, swapped_axes),
            tf.transpose(y_pred, swapped_axes),
            tf.reduce_max(y_true),
        )

    def tf_ssim(y_true, y_pred):
        # identity permutation: the axes are left in place
        same_axes = [0, 1, 2, 3]
        return tf.image.ssim(
            tf.transpose(y_true, same_axes),
            tf.transpose(y_pred, same_axes),
            tf.reduce_max(y_true),
        )

    model.compile(loss=tf_psnr, metrics=[tf_ssim])
    model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')

    # 199 steps for all contrasts, 100 when a single contrast is selected
    n_volumes = 199 if contrast is None else 199 // 2 + 1
    eval_steps = n_volumes if n_samples is None else None
    try:
        eval_res = model.evaluate(val_set, verbose=1, steps=eval_steps)
    except tf.errors.ResourceExhaustedError:
        # fallback: predict the elements one by one and accumulate the metrics
        accumulator = Metrics(METRIC_FUNCS)
        if n_samples is None:
            val_set = val_set.take(n_volumes)
        for sample in val_set:
            target = sample[1].numpy()
            prediction = model.predict(sample[0], batch_size=1)
            accumulator.push(target[..., 0], prediction[..., 0])
        eval_res = [accumulator.means()[key] for key in ('PSNR', 'SSIM')]
    return model.metrics_names, eval_res
def train_xpdnet_block(
    model_fun,
    model_kwargs,
    model_size=None,
    multicoil=True,
    brain=False,
    af=4,
    contrast=None,
    n_samples=None,
    batch_size=None,
    n_epochs=200,
    n_iter=10,
    res=True,
    n_scales=0,
    n_primal=5,
    use_mixed_precision=False,
    refine_smaps=False,
    refine_big=False,
    loss='mae',
    lr=1e-4,
    fixed_masks=False,
    equidistant_fake=False,
    multi_gpu=False,
    mask_type=None,
    primal_only=True,
    n_dual=1,
    n_dual_filters=16,
    multiscale_kspace_learning=False,
    block_size=10,
    block_overlap=0,
    epochs_per_block_step=None,
):
    r"""Train an XPDNet network block by block on the fastMRI dataset.

    The training is split in successive block steps. At each step a window of
    ``block_size`` consecutive indices, moved by
    ``block_size - block_overlap`` from one step to the next, is assigned to
    ``model.blocks_to_train`` (presumably the indices of the unrolled
    iterations to train - defined by ``XPDNet``); the model is then recompiled
    and fitted for at most ``epochs_per_block_step`` epochs. The training
    stops as soon as ``n_epochs`` epochs have been run or all the block steps
    have been done.

    The model is compiled through ``default_model_compile`` with the learning
    rate ``lr`` and the requested ``loss``.
    The validation is performed every 5 epochs on 5 validation steps.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to True.
        brain (bool): whether to consider brain data instead of knee. Only
            changes the data paths when ``multicoil`` is True. Defaults to
            False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. It is cast with ``int``. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Defaults to None.
        batch_size (int or None): the batch size forwarded to the training
            dataset. When not None, the number of steps per epoch is
            floor-divided by it. Defaults to None.
        n_epochs (int): the total number of epochs for this training, all
            block steps included. Defaults to 200.
        n_iter (int): the number of iterations for the XPDNet. Defaults to 10.
        res (bool): whether the XPDNet image correction networks should be
            residual. Defaults to True.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network. Defaults to False.
        refine_big (bool): forwarded to the XPDNet constructor; when
            ``refine_smaps`` is also True, a ``'b'`` is appended to the run id.
            Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        lr (float): the learning rate given to ``default_model_compile`` at
            each block step. Defaults to 1e-4.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Defaults to False
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Only used for multicoil brain data when ``mask_type`` is
            None. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the mask type given to the multicoil dataset.
            If None, it is ``'random'`` for knee data and ``'equidistant'`` or
            ``'equidistant_fake'`` for brain data. Ignored for single coil
            data. Defaults to None.
        primal_only (bool): forwarded to the XPDNet constructor. Defaults to
            True.
        n_dual (int): forwarded to the XPDNet constructor. Defaults to 1.
        n_dual_filters (int): forwarded to the XPDNet constructor. Defaults to
            16.
        multiscale_kspace_learning (bool): forwarded to the XPDNet
            constructor. Defaults to False.
        block_size (int): the number of consecutive indices trained together
            at each block step. Defaults to 10.
        block_overlap (int): the number of indices shared by two successive
            blocks. Must be strictly smaller than ``block_size``. When not 0,
            it is reported in the run id. Defaults to 0.
        epochs_per_block_step (int or None): the maximum number of epochs of
            each block step. If None, ``n_epochs`` is spread evenly (rounded
            up) over the block steps. Defaults to None.

    Returns:
        - str: the run id of the trained network.

    Raises:
        AssertionError: if ``block_overlap`` is not strictly smaller than
            ``block_size``.
    """
    # number of training volumes, later used as the number of steps per epoch
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # locations of the training and validation data
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    af = int(af)

    # select the global Keras dtype policy (mixed precision or plain float32)
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)
    # pick the dataset builder and the extra keyword arguments it needs
    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            # default mask type depends on the anatomy
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    train_set = dataset(train_path,
                        AF=af,
                        contrast=contrast,
                        inner_slices=None,
                        rand=True,
                        scale_factor=1e6,
                        n_samples=n_samples,
                        fixed_masks=fixed_masks,
                        batch_size=batch_size,
                        target_image_size=IM_SIZE,
                        **kwargs)
    val_set = dataset(val_path,
                      AF=af,
                      contrast=contrast,
                      inner_slices=None,
                      rand=True,
                      scale_factor=1e6,
                      **kwargs)

    # keyword arguments forwarded to the XPDNet constructor
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
        'multiscale_kspace_learning': multiscale_kspace_learning,
    }

    # build the run id out of the non-default parts of the configuration
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'
    if block_overlap != 0:
        additional_info += f'_blkov{block_overlap}'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    # the 'bbb' tag distinguishes these runs from the ones of train_xpdnet
    run_id = f'{xpdnet_type}_{additional_info}_bbb_{submodel_info}_{int(time.time())}'
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    model = XPDNet(model_fun, model_kwargs, **run_params)
    n_steps = n_volumes

    if batch_size is not None:
        n_steps //= batch_size

    # save_freq is expressed with the total number of epochs requested, all
    # block steps included
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(n_epochs * n_steps),
        save_weights_only=True,
    )
    print(run_id)
    stride = block_size - block_overlap
    assert stride > 0, 'block_overlap must be strictly smaller than block_size'
    n_block_steps = int(math.ceil((n_iter - block_size) / stride) + 1)
    ## epochs handling
    if epochs_per_block_step is None:
        # the default used to crash in `min(None, n_epochs)`: spread the
        # epochs evenly (rounded up) over the block steps instead
        epochs_per_block_step = max(
            1,
            math.ceil(n_epochs / max(n_block_steps, 1)),
        )
    remaining_epochs = n_epochs
    start_epoch = 0
    final_epoch = min(epochs_per_block_step, remaining_epochs)

    for i_step in range(n_block_steps):
        # window of indices trained during this block step
        first_block_to_train = i_step * stride
        blocks = list(
            range(first_block_to_train, first_block_to_train + block_size))
        model.blocks_to_train = blocks
        # the model is recompiled at each step, after the blocks have changed
        default_model_compile(model, lr=lr, loss=loss)

        model.fit(
            train_set,
            steps_per_epoch=n_steps,
            initial_epoch=start_epoch,
            epochs=final_epoch,
            validation_data=val_set,
            validation_steps=5,
            validation_freq=5,
            verbose=0,
            callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
        )
        # epoch numbering is continued from one block step to the next
        remaining_epochs -= final_epoch - start_epoch
        if remaining_epochs <= 0:
            break
        start_epoch = final_epoch
        final_epoch += min(epochs_per_block_step, remaining_epochs)
    return run_id
Beispiel #11
0
def xpdnet_inference(
        model_fun,
        model_kwargs,
        run_id,
        exp_id='xpdnet',
        brain=False,
        challenge=False,
        n_epochs=200,
        contrast=None,
        af=4,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        refine_smaps=False,
        refine_big=False,
        n_samples=None,
        cuda_visible_devices='0123',
    ):
    """Run a trained multicoil XPDNet on fastMRI test data and write the results.

    The weights are read from
    ``{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5`` and each
    prediction is handed to ``write_result`` together with its file name.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the arguments used to initialize the image
            correction network.
        run_id (str): the run id of the weights to use.
        exp_id (str): the experiment id given to ``write_result`` and used in
            the progress bar description. Defaults to 'xpdnet'.
        brain (bool): whether to use the brain data instead of the
            ``multicoil_test_v2`` data. Defaults to False.
        challenge (bool): for brain data, whether to use the challenge folder
            instead of the test folder. Defaults to False.
        n_epochs (int): the epoch number used in the weights file name.
            Defaults to 200.
        contrast (str or None): the contrast given to the dataset. Defaults to
            None.
        af (int): the acceleration factor, cast with ``int``. Defaults to 4.
        n_iter (int): forwarded to the XPDNet constructor. Defaults to 10.
        res (bool): forwarded to the XPDNet constructor. Defaults to True.
        n_scales (int): forwarded to the XPDNet constructor. Defaults to 0.
        n_primal (int): forwarded to the XPDNet constructor. Defaults to 5.
        refine_smaps (bool): forwarded to the XPDNet constructor. Defaults to
            False.
        refine_big (bool): forwarded to the XPDNet constructor. Defaults to
            False.
        n_samples (int or None): the number of samples given to the dataset;
            when not None it is also the progress bar total. Defaults to None.
        cuda_visible_devices (iterable of str): its elements are joined with
            commas and written to the ``CUDA_VISIBLE_DEVICES`` environment
            variable. Defaults to '0123'.
    """
    if not brain:
        test_path = f'{FASTMRI_DATA_DIR}multicoil_test_v2/'
    elif challenge:
        test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_challenge/'
    else:
        test_path = f'{FASTMRI_DATA_DIR}brain_multicoil_test/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # the data and the matching file names are read by two parallel datasets
    selection = dict(AF=af, contrast=contrast, n_samples=n_samples)
    test_set = test_masked_kspace_dataset_from_indexable(
        test_path,
        scale_factor=1e6,
        output_shape_spec=brain,
        **selection,
    )
    test_set_filenames = test_filenames(test_path, **selection)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = XPDNet(
            model_fun,
            model_kwargs,
            n_primal=n_primal,
            multicoil=True,
            n_scales=n_scales,
            n_iter=n_iter,
            refine_smaps=refine_smaps,
            refine_big=refine_big,
            res=res,
            output_shape_spec=brain,
        )
        # call the model once on zeros before loading the weights
        kspace_shape = [1, 15, 640, 372]
        fake_inputs = [
            tf.zeros(kspace_shape + [1], dtype=tf.complex64),
            tf.zeros(kspace_shape, dtype=tf.complex64),
            tf.zeros(kspace_shape, dtype=tf.complex64),
        ]
        if brain:
            fake_inputs.append(tf.constant([[320, 320]]))
        model(fake_inputs)
    model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{n_epochs:02d}.hdf5')

    # total of the progress bar
    if n_samples is not None:
        tqdm_total = n_samples
    elif brain:
        if contrast:
            tqdm_total = brain_volumes_per_contrast['test'][af][contrast]
        else:
            tqdm_total = brain_n_volumes_test[af]
    elif contrast:
        tqdm_total = n_volumes_test[af] // 2
    else:
        tqdm_total = n_volumes_test[af]
    tqdm_desc = f'{exp_id}_{contrast}_{af}'

    # TODO: change when the following issue has been dealt with
    # https://github.com/tensorflow/tensorflow/issues/38561
    @tf.function(experimental_relax_shapes=True)
    def predict(t):
        return model(t)

    paired_examples = zip(test_set, test_set_filenames)
    for data_example, filename in tqdm(paired_examples, total=tqdm_total, desc=tqdm_desc):
        reconstruction = predict(data_example)
        write_result(
            exp_id,
            reconstruction.numpy(),
            filename.numpy().decode('utf-8'),
            scale_factor=1e6,
            brain=brain,
            challenge=challenge,
        )