Ejemplo n.º 1
0
def train_denoiser(
        model,
        run_id,
        noise_std=30,
        contrast=None,
        n_samples=None,
        n_epochs=200,
        loss='mae',
        lr=1e-4,
        n_steps_per_epoch=973,  # number of volumes in the fastMRI dataset
    ):
    """Train a denoiser on noisy single-coil fastMRI data.

    Arguments:
        model: either a model ready to be compiled, or a tuple of
            specifications that is unpacked into `build_model_from_specs`.
        run_id (str): identifier used to name the checkpoint files and the
            TensorBoard log directory.
        noise_std: noise level forwarded to the datasets. Defaults to 30.
        contrast (str or None): contrast forwarded to the datasets.
            Defaults to None.
        n_samples (int or None): sample limit forwarded to the training
            dataset only. Defaults to None.
        n_epochs (int): number of training epochs. It is also used as the
            checkpoint period. Defaults to 200.
        loss: loss forwarded to `default_model_compile`. Defaults to 'mae'.
        lr (float): learning rate forwarded to `default_model_compile`.
            Defaults to 1e-4.
        n_steps_per_epoch (int): number of training steps per epoch.
            Defaults to 973.

    Returns:
        str: the `run_id` received as input.
    """
    # options shared by the training and the validation datasets
    common_ds_kwargs = dict(
        noise_std=noise_std,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
    )
    train_set = train_noisy_dataset_from_indexable(
        f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/',
        n_samples=n_samples,
        **common_ds_kwargs,
    )
    val_set = train_noisy_dataset_from_indexable(
        f'{FASTMRI_DATA_DIR}singlecoil_val/',
        **common_ds_kwargs,
    )

    # callbacks: weights checkpoint, TensorBoard logs and progress bar
    weights_template = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    checkpoint_callback = ModelCheckpoint(
        weights_template,
        period=n_epochs,
        save_weights_only=True,
    )
    tensorboard_callback = TensorBoard(
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        profile_batch=0,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    progress_callback = TQDMProgressBar()

    # a tuple is interpreted as model specifications rather than a model
    if isinstance(model, tuple):
        model = build_model_from_specs(*model)
    default_model_compile(model, lr=lr, loss=loss)

    model.fit(
        train_set,
        steps_per_epoch=n_steps_per_epoch,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=10,
        verbose=0,
        callbacks=[tensorboard_callback, checkpoint_callback, progress_callback],
    )
    return run_id
def train_denoiser(
        model,
        run_id,
        noise_std=30,
        contrast=None,
        n_samples=None,
        n_epochs=200,
        loss='mae',
        lr=1e-4,
        n_steps_per_epoch=973,  # number of volumes in the fastMRI dataset
    ):
    """Train a denoiser on datasets built by `NoisyFastMRIDatasetBuilder`.

    Arguments:
        model: either a model ready to be compiled, or a tuple of
            specifications that is unpacked into `build_model_from_specs`.
        run_id (str): identifier used to name the checkpoint files and the
            TensorBoard log directory.
        noise_std: value forwarded to the dataset builders as
            `noise_power_spec`. Defaults to 30.
        contrast (str or None): contrast forwarded to the dataset builders.
            Defaults to None.
        n_samples (int or None): sample limit forwarded to the training
            dataset builder only. Defaults to None.
        n_epochs (int): number of training epochs. It is also used as the
            checkpoint period. Defaults to 200.
        loss: loss forwarded to `default_model_compile`. Defaults to 'mae'.
        lr (float): learning rate forwarded to `default_model_compile`.
            Defaults to 1e-4.
        n_steps_per_epoch (int): number of training steps per epoch.
            Defaults to 973.

    Returns:
        str: the `run_id` received as input.
    """
    # options shared by the training and the validation dataset builders
    builder_options = {
        'contrast': contrast,
        'slice_random': True,
        'scale_factor': 1e4,
        'noise_input': False,
        'noise_power_spec': noise_std,
        'noise_mode': 'gaussian',
    }
    train_builder = NoisyFastMRIDatasetBuilder(
        dataset='train',
        n_samples=n_samples,
        **builder_options,
    )
    train_set = train_builder.preprocessed_ds
    val_builder = NoisyFastMRIDatasetBuilder(dataset='val', **builder_options)
    val_set = val_builder.preprocessed_ds

    # callbacks: weights checkpoint, TensorBoard logs and progress bar
    weights_template = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    checkpoint_callback = ModelCheckpoint(
        weights_template,
        period=n_epochs,
        save_weights_only=True,
    )
    tensorboard_callback = TensorBoard(
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        profile_batch=0,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    progress_callback = TQDMProgressBar()

    # a tuple is interpreted as model specifications rather than a model
    if isinstance(model, tuple):
        model = build_model_from_specs(*model)
    default_model_compile(model, lr=lr, loss=loss)

    model.fit(
        train_set,
        steps_per_epoch=n_steps_per_epoch,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=100,
        verbose=0,
        callbacks=[tensorboard_callback, checkpoint_callback, progress_callback],
    )
    return run_id
def train_model(model, space='K', n=1):
    """Train one sub-network of the separated KIKI-net and return it.

    The data generators, `AF`, `n_epochs`, `n_volumes_train` and
    `learning_rate_from_epoch` are read from the enclosing module.

    Arguments:
        model: the compiled Keras model to train.
        space (str): 'K' to train with the k-space generators
            (`train_gen_k`/`val_gen_k`), 'I' to train with image-space
            generators. Defaults to 'K'.
        n (int): index of the sub-network, used in the run id. When `space`
            is 'I', 1 selects `train_gen_i`/`val_gen_i` and 2 selects
            `train_gen_last`/`val_gen_last`. Defaults to 1.

    Returns:
        the trained model (same object as `model`).

    Raises:
        ValueError: if `space` is neither 'K' nor 'I', or if `space` is 'I'
            and `n` is neither 1 nor 2.
    """
    # Fail fast on unsupported configurations: previously they only surfaced
    # as an unbound `train_gen` once all the callbacks had been created.
    if space not in ('K', 'I'):
        raise ValueError(f"space must be 'K' or 'I', got {space!r}")
    if space == 'I' and n not in (1, 2):
        raise ValueError(f"n must be 1 or 2 when space is 'I', got {n!r}")

    # `summary` prints the table itself and returns None, so it must not be
    # wrapped in `print` (that would output a stray "None").
    model.summary(line_length=150)
    run_id = f'kikinet_sep_{space}{n}_af{AF}_{int(time.time())}'
    chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    print(run_id)

    # checkpoints are written twice over the whole training
    chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs // 2)
    log_dir = op.join('logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=True,
        write_images=False,
    )
    lrate_cback = LearningRateScheduler(learning_rate_from_epoch)
    tqdm_cb = TQDMProgressBar()

    # pick the generators matching the sub-network being trained
    if space == 'K':
        train_gen, val_gen = train_gen_k, val_gen_k
    elif n == 2:
        train_gen, val_gen = train_gen_last, val_gen_last
    else:
        train_gen, val_gen = train_gen_i, val_gen_i

    model.fit_generator(
        train_gen,
        steps_per_epoch=n_volumes_train,
        epochs=n_epochs,
        validation_data=val_gen,
        validation_steps=1,
        verbose=0,
        callbacks=[
            tqdm_cb,
            tboard_cback,
            chkpt_cback,
            lrate_cback,
        ],
        # max_queue_size=35,
        use_multiprocessing=True,
        workers=35,
        shuffle=True,
    )
    return model
Ejemplo n.º 4
0
def train_dealiaser(
        model_fun,
        model_kwargs,
        run_id,
        n_scales=0,
        multicoil=False,
        af=4,
        contrast=None,
        cuda_visible_devices='0123',
        n_samples=None,
        n_epochs=200,
        use_mixed_precision=False,
        loss='mae',
        original_run_id=None,
        fixed_masks=False,
        n_steps_per_epoch=973,
    ):
    """Train a de-aliasing network on retrospectively undersampled fastMRI data.

    The network is built with `build_model_from_specs` and wrapped in a
    `MultiscaleComplex` model. Validation runs every 5 epochs on 5 steps and
    weights are checkpointed with a period of `n_epochs`.

    Arguments:
        model_fun: model constructor forwarded to `build_model_from_specs`.
        model_kwargs: arguments forwarded to `build_model_from_specs`.
        run_id (str): prefix of the final run id.
        n_scales (int): forwarded to `MultiscaleComplex`. Defaults to 0.
        multicoil (bool): whether to use the multicoil data and dataset.
            Defaults to False.
        af: acceleration factor, cast to int. Defaults to 4.
        contrast (str or None): contrast forwarded to the datasets; appended
            to the run id when given. Defaults to None.
        cuda_visible_devices (str): each character is a device id; they are
            comma-joined into `CUDA_VISIBLE_DEVICES`. Defaults to '0123'.
        n_samples (int or None): sample limit forwarded to the training
            dataset only; appended to the run id when given.
        n_epochs (int): number of training epochs. Defaults to 200.
        use_mixed_precision (bool): selects the 'mixed_float16' policy
            instead of 'float32'. Defaults to False.
        loss: loss forwarded to `default_model_compile`; appended to the run
            id when different from 'mae'. Defaults to 'mae'.
        original_run_id (str or None): when given, the run is a fine-tuning:
            the learning rate is 1e-7 instead of 1e-4, the number of steps
            per epoch is halved, and the weights of that run are loaded
            (epoch 250, or epoch 1 when `FASTMRI_DEBUG` is set).
        fixed_masks (bool): forwarded to the training dataset only.
            Defaults to False.
        n_steps_per_epoch (int): number of steps per epoch. Defaults to 973.

    Returns:
        str: the full run id (prefix, additional information and timestamp).
    """
    # paths
    if multicoil:
        train_path, val_path = (
            f'{FASTMRI_DATA_DIR}multicoil_train/',
            f'{FASTMRI_DATA_DIR}multicoil_val/',
        )
    else:
        train_path, val_path = (
            f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/',
            f'{FASTMRI_DATA_DIR}singlecoil_val/',
        )

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # numeric precision policy
    policy_type = 'mixed_float16' if use_mixed_precision else 'float32'
    mixed_precision.set_policy(mixed_precision.Policy(policy_type))

    # datasets
    if multicoil:
        dataset, extra_ds_kwargs = multicoil_dataset, {'parallel': False}
    else:
        dataset, extra_ds_kwargs = singlecoil_dataset, {}
    common_ds_kwargs = dict(
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
    )
    train_set = dataset(
        train_path,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **common_ds_kwargs,
        **extra_ds_kwargs,
    )
    val_set = dataset(val_path, **common_ds_kwargs, **extra_ds_kwargs)

    # run id: only the non-default settings are spelled out
    info_parts = [f'af{af}']
    if contrast is not None:
        info_parts.append(f'{contrast}')
    if n_samples is not None:
        info_parts.append(f'{n_samples}')
    if loss != 'mae':
        info_parts.append(f'{loss}')
    if fixed_masks:
        info_parts.append('fixed_masks')
    additional_info = '_'.join(info_parts)
    run_id = f'{run_id}_{additional_info}_{int(time.time())}'

    # callbacks
    weights_template = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    checkpoint_callback = ModelCheckpoint(
        weights_template,
        period=n_epochs,
        save_weights_only=True,
    )
    tensorboard_callback = TensorBoard(
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        profile_batch=0,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    progress_callback = TQDMProgressBar()

    # model
    model = MultiscaleComplex(
        build_model_from_specs(model_fun, model_kwargs, 2),
        res=False,
        n_scales=n_scales,
        fastmri_format=True,
    )
    fine_tuning = original_run_id is not None
    lr = 1e-7 if fine_tuning else 1e-4
    n_steps = n_steps_per_epoch // 2 if fine_tuning else n_steps_per_epoch
    default_model_compile(model, lr=lr, loss=loss)
    print(run_id)
    if fine_tuning:
        # in debug mode the original run is expected to have lasted 1 epoch
        n_epochs_original = 1 if os.environ.get('FASTMRI_DEBUG') else 250
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tensorboard_callback, checkpoint_callback, progress_callback],
    )
    return run_id
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        brain=False,
        af=4,
        contrast=None,
        cuda_visible_devices='0123',
        n_samples=None,
        n_epochs=200,
        checkpoint_epoch=0,
        save_state=False,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        refine_big=False,
        loss='mae',
        original_run_id=None,
        fixed_masks=False,
        n_epochs_original=250,
        equidistant_fake=False,
        multi_gpu=False,
        mask_type=None,
    ):
    r"""Train an XPDNet network on the fastMRI dataset.

    The training is done with a learning rate of 1e-4, using the RAdam optimizer.
    The validation is performed every 5 epochs on 5 volumes.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to True.
        brain (bool): whether to consider brain data instead of knee. Defaults
            to False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        cuda_visible_devices (str): the GPUs to consider visible. Defaults to
            '0123'.
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Defaults to None.
        n_epochs (int): the number of epochs (i.e. one pass though all the
            volumes/samples) for this training. Defaults to 200.
        checkpoint_epoch (int): the number of epochs used to train the model
            during the first step of the full training. This is typically used
            when on a cluster the training duration exceeds the maximum job
            duration. When different from 0, the full model saved at that
            epoch for `original_run_id` is reloaded and the training resumes
            under the same run id. Defaults to 0, which means that the
            training is done without checkpoints.
        save_state (bool): whether you should save the entire model state for
            this training, for example to retrain where left off. Defaults to
            False.
        n_iter (int): the number of iterations for the XPDNet.
        res (bool): whether the XPDNet image correction networks should be
            residual.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network.
        refine_big (bool): forwarded to the XPDNet as `refine_big`; when
            `refine_smaps` is set, it also adds a 'b' to the run id.
            Presumably selects a bigger sensitivity maps refiner (to be
            confirmed against the XPDNet model). Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        original_run_id (str or None): run id of the same network trained
            before. With `checkpoint_epoch` equal to 0, the training is
            considered fine-tuning for a network trained for
            `n_epochs_original` epochs. It will therefore apply a learning
            rate of 1e-7 and the epoch size will be divided in half. With
            `checkpoint_epoch` different from 0, it is the run to resume and
            is mandatory. If None, the training is done normally, without
            loading weights. Defaults to None.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Defaults to False
        n_epochs_original (int): the number of epochs used to pre-train the
            model, only applicable if original_run_id is not None. It is
            overridden to 1 when the `FASTMRI_DEBUG` environment variable is
            set. Defaults to 250.
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Only used for multicoil brain data when `mask_type` is
            None. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the type of undersampling mask forwarded to
            the multicoil dataset. If None, it is 'random' for knee data and
            'equidistant' (or 'equidistant_fake' if `equidistant_fake` is
            set) for brain data. Not used for single coil data. Defaults to
            None.

    Returns:
        - str: the run id of the trained network.

    Raises:
        ValueError: if `checkpoint_epoch` is different from 0 but no
            `original_run_id` is given.
    """
    # Resuming needs the id of the run to resume: without it the run id would
    # be None and the failure would only show up later as an obscure
    # TypeError when building the log directory.
    if checkpoint_epoch != 0 and original_run_id is None:
        raise ValueError(
            'original_run_id must be provided when checkpoint_epoch is not 0'
        )
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # paths
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'


    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)
    # generators
    if multicoil:
        dataset = multicoil_dataset
        # default mask type depends on the anatomy
        if mask_type is None:
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    train_set = dataset(
        train_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **kwargs
    )
    val_set = dataset(
        val_path,
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
        **kwargs
    )

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
    }

    # run id: only the non-default settings are spelled out
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    if checkpoint_epoch == 0:
        run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    else:
        # a resumed training keeps the id of the run it continues
        run_id = original_run_id
    final_epoch = checkpoint_epoch + n_epochs
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    # the '.hdf5' extension is only used for weights-only checkpoints
    if not save_state:
        chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    if checkpoint_epoch == 0:
        model = XPDNet(model_fun, model_kwargs, **run_params)
        if original_run_id is not None:
            # fine-tuning: tiny learning rate and halved epoch size
            lr = 1e-7
            n_steps = brain_volumes_per_contrast['train'].get(contrast, n_volumes)//2
        else:
            lr = 1e-4
            n_steps = n_volumes
        default_model_compile(model, lr=lr, loss=loss)
    else:
        # resuming: the full saved model is reloaded, so it is not recompiled
        model = load_model(
            f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}',
            custom_objects=CUSTOM_TF_OBJECTS,
        )
        n_steps = n_volumes

    # save_freq is expressed in batches: a single save after n_epochs epochs
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=n_epochs*n_steps,
        save_weights_only=not save_state,
    )
    print(run_id)
    if original_run_id is not None and not checkpoint_epoch:
        if os.environ.get('FASTMRI_DEBUG'):
            n_epochs_original = 1
        if multicoil:
            kspace_size = [1, 15, 640, 372]
        else:
            kspace_size = [1, 640, 372]
        # the model is called once on zero inputs before loading the weights
        inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            inputs.append(tf.constant([[320, 320]]))
        model(inputs)
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        initial_epoch=checkpoint_epoch,
        epochs=final_epoch,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id
    'layers_n_non_lins': 2,
}
# total number of epochs, passed to `fit_generator` below
n_epochs = 300
# the run id combines the acceleration factor and a timestamp
run_id = f'unet_af{AF}_{int(time.time())}'
# the `{epoch:02d}` placeholder is kept as-is in the path at this point
chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'

# checkpoint callback with a period of 100 epochs
chkpt_cback = ModelCheckpoint(chkpt_path, period=100)
# TensorBoard logs go to a directory named after the run
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    profile_batch=0,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)
tqdm_cb = TQDMProgressBar()

# U-net taking 320 x 320 single-channel inputs, configured by `run_params`
model = unet(input_size=(320, 320, 1), lr=1e-3, **run_params)
print(model.summary())

model.fit_generator(
    train_gen,
    steps_per_epoch=n_volumes_train,
    epochs=n_epochs,
    validation_data=val_gen,
    validation_steps=1,
    verbose=0,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback],
    # max_queue_size=100,
    use_multiprocessing=True,
    workers=35,
Ejemplo n.º 7
0
def train_updnet(
    multicoil=True,
    af=4,
    contrast=None,
    cuda_visible_devices='0123',
    n_samples=None,
    n_epochs=200,
    n_iter=10,
    use_mixed_precision=False,
    n_layers=3,
    base_n_filter=16,
    non_linearity='relu',
    channel_attention_kwargs=None,
    refine_smaps=False,
    loss='mae',
    original_run_id=None,
    fixed_masks=False,
):
    """Train a UPDNet on retrospectively undersampled fastMRI data.

    Validation runs on 2 steps and weights are checkpointed with a period of
    `n_epochs`.

    Arguments:
        multicoil (bool): whether to use the multicoil data and dataset.
            Defaults to True.
        af: acceleration factor, cast to int. Defaults to 4.
        contrast (str or None): contrast forwarded to the datasets; appended
            to the run id when given. Defaults to None.
        cuda_visible_devices (str): each character is a device id; they are
            comma-joined into `CUDA_VISIBLE_DEVICES`. Defaults to '0123'.
        n_samples (int or None): sample limit forwarded to the training
            dataset only; appended to the run id when given.
        n_epochs (int): number of training epochs. Defaults to 200.
        n_iter (int): forwarded to the UPDNet as `n_iter`. Defaults to 10.
        use_mixed_precision (bool): selects the 'mixed_float16' policy
            instead of 'float32'. Defaults to False.
        n_layers (int): forwarded to the UPDNet; also the number of entries
            of `layers_n_channels`. Defaults to 3.
        base_n_filter (int): number of channels of the first layer; it is
            doubled at each following layer. Defaults to 16.
        non_linearity (str): forwarded to the UPDNet. Defaults to 'relu'.
        channel_attention_kwargs (dict or None): forwarded to the UPDNet.
            Defaults to None.
        refine_smaps (bool): forwarded to the UPDNet. Defaults to False.
        loss: loss forwarded to `default_model_compile`. Defaults to 'mae'.
        original_run_id (str or None): when given, the run is a fine-tuning:
            the learning rate is 1e-7 instead of 1e-4, the number of steps
            per epoch is halved, and the weights of that run are loaded
            (epoch 250, or epoch 1 when `FASTMRI_DEBUG` is set).
        fixed_masks (bool): forwarded to the training dataset only.
            Defaults to False.

    Returns:
        str: the run id generated for this training.
    """
    # paths
    if multicoil:
        train_path, val_path = (
            f'{FASTMRI_DATA_DIR}multicoil_train/',
            f'{FASTMRI_DATA_DIR}multicoil_val/',
        )
    else:
        train_path, val_path = (
            f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/',
            f'{FASTMRI_DATA_DIR}singlecoil_val/',
        )

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # numeric precision policy
    policy_type = 'mixed_float16' if use_mixed_precision else 'float32'
    mixed_precision.set_policy(mixed_precision.Policy(policy_type))

    # datasets
    if multicoil:
        dataset, extra_ds_kwargs = multicoil_dataset, {'parallel': False}
    else:
        dataset, extra_ds_kwargs = singlecoil_dataset, {}
    common_ds_kwargs = dict(
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
    )
    train_set = dataset(
        train_path,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **common_ds_kwargs,
        **extra_ds_kwargs,
    )
    val_set = dataset(val_path, **common_ds_kwargs, **extra_ds_kwargs)

    run_params = dict(
        n_primal=5,
        n_dual=1,
        primal_only=True,
        multicoil=multicoil,
        n_layers=n_layers,
        layers_n_channels=[
            base_n_filter * 2**depth for depth in range(n_layers)
        ],
        non_linearity=non_linearity,
        n_iter=n_iter,
        channel_attention_kwargs=channel_attention_kwargs,
        refine_smaps=refine_smaps,
    )

    # run id: only the non-default settings are spelled out
    updnet_type = 'updnet_sense_' if multicoil else 'updnet_singlecoil_'
    info_parts = [f'af{af}']
    if contrast is not None:
        info_parts.append(f'{contrast}')
    if n_samples is not None:
        info_parts.append(f'{n_samples}')
    if n_iter != 10:
        info_parts.append(f'i{n_iter}')
    if non_linearity != 'relu':
        info_parts.append(f'{non_linearity}')
    if n_layers != 3:
        info_parts.append(f'l{n_layers}')
    if base_n_filter != 16:
        info_parts.append(f'bf{base_n_filter}')
    if loss != 'mae':
        info_parts.append(f'{loss}')
    if channel_attention_kwargs:
        info_parts.append('ca')
    if refine_smaps:
        info_parts.append('rf_sm')
    if fixed_masks:
        info_parts.append('fixed_masks')
    additional_info = '_'.join(info_parts)
    run_id = f'{updnet_type}_{additional_info}_{int(time.time())}'

    # callbacks
    weights_template = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    checkpoint_callback = ModelCheckpoint(
        weights_template,
        period=n_epochs,
        save_weights_only=True,
    )
    tensorboard_callback = TensorBoard(
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        profile_batch=0,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    progress_callback = TQDMProgressBar()

    # model
    model = UPDNet(**run_params)
    fine_tuning = original_run_id is not None
    if fine_tuning:
        lr, n_steps = 1e-7, n_volumes_train // 2
    else:
        lr, n_steps = 1e-4, n_volumes_train
    default_model_compile(model, lr=lr, loss=loss)
    print(run_id)
    if fine_tuning:
        # in debug mode the original run is expected to have lasted 1 epoch
        n_epochs_original = 1 if os.environ.get('FASTMRI_DEBUG') else 250
        model.load_weights(
            f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5'
        )

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=2,
        verbose=0,
        callbacks=[tensorboard_callback, checkpoint_callback, progress_callback],
    )
    return run_id
def train_ncnet_block(
    model,
    n_iter=10,
    run_id=None,
    multicoil=False,
    three_d=False,
    acq_type='radial',
    scale_factor=1e6,
    dcomp=False,
    contrast=None,
    cuda_visible_devices='0123',
    n_samples=None,
    n_epochs=200,
    use_mixed_precision=False,
    loss='mae',
    original_run_id=None,
    checkpoint_epoch=0,
    save_state=False,
    lr=1e-4,
    block_size=10,
    block_overlap=0,
    epochs_per_block_step=None,
    **acq_kwargs,
):
    """Train a non-Cartesian network block by block.

    The blocks of the model are trained by groups of `block_size`
    consecutive blocks, each group starting `block_size - block_overlap`
    blocks after the previous one. For each group, `model.blocks_to_train`
    is set, the model is recompiled and then fitted for
    `epochs_per_block_step` epochs, until `n_epochs` epochs have been run.

    Arguments:
        model: the model to train. Its `blocks_to_train` attribute is
            overwritten at each block step.
        n_iter (int): the number of blocks of the model, used to compute the
            number of block steps. Defaults to 10.
        run_id (str or None): prefix of the generated run id.
        multicoil (bool): whether to use the multicoil fastMRI data. It
            takes precedence over `three_d`. Defaults to False.
        three_d (bool): whether to use the 3D OASIS data. Defaults to False.
        acq_type (str): forwarded to the datasets and used in the run id.
            Defaults to 'radial'.
        scale_factor (float): forwarded to the datasets. Defaults to 1e6.
        dcomp (bool): forwarded to the datasets as `compute_dcomp`.
            Defaults to False.
        contrast (str or None): forwarded to the datasets, except for 3D
            data. Defaults to None.
        cuda_visible_devices (str): each character is a device id; they are
            comma-joined into `CUDA_VISIBLE_DEVICES`. Defaults to '0123'.
        n_samples (int or None): sample limit forwarded to the training
            dataset only. Defaults to None.
        n_epochs (int): the number of epochs to run in this call, all block
            steps included. Defaults to 200.
        use_mixed_precision (bool): selects the 'mixed_float16' policy
            instead of 'float32'. Defaults to False.
        loss: loss forwarded to `default_model_compile`. Defaults to 'mae'.
        original_run_id (str or None): the run to resume. Mandatory when
            `checkpoint_epoch` is not 0, unused otherwise.
        checkpoint_epoch (int): the epoch to resume from. 0 means a fresh
            training. Defaults to 0.
        save_state (bool): whether to pickle the optimizer weights at the
            end of the training. Defaults to False.
        lr (float): learning rate forwarded to `default_model_compile`.
            Defaults to 1e-4.
        block_size (int): number of blocks trained together. Defaults to 10.
        block_overlap (int): number of blocks shared by two consecutive
            block steps; must be smaller than `block_size`. Defaults to 0.
        epochs_per_block_step (int or None): number of epochs for each block
            step. If None, `n_epochs` is split evenly between the block
            steps (at least 1 epoch each). When resuming, pass the value
            used by the original run explicitly, since `n_epochs` is then
            only the number of remaining epochs. Defaults to None.
        **acq_kwargs: additional arguments forwarded to the datasets.

    Returns:
        str: the run id of the training.

    Raises:
        ValueError: if `checkpoint_epoch` is not 0 but no `original_run_id`
            is given.
        AssertionError: if `block_overlap` is not smaller than `block_size`.
    """
    # Resuming needs the id of the run to resume: without it the run id would
    # be None and the failure would only show up later as an obscure
    # TypeError when building the log directory.
    if checkpoint_epoch != 0 and original_run_id is None:
        raise ValueError(
            'original_run_id must be provided when checkpoint_epoch is not 0'
        )
    # if there are 4 blocks, with a block size of 2 and a block overlap of 1
    # we do the following block combinations:
    # 01, 12, 23 -> n block steps = 3
    # if there are 6 blocks with a block size 3 and a block overlap of 2:
    # 012, 123, 234, 345 -> n = 4
    # if there are 6 blocks with a block size 3 and a block overlap of 1:
    # 012, 234, 456 -> n = 3
    stride = block_size - block_overlap
    assert stride > 0, 'block_overlap must be smaller than block_size'
    n_block_steps = int(math.ceil((n_iter - block_size) / stride) + 1)
    # The None default used to crash when computing the checkpoint frequency;
    # the epochs are now split evenly between the block steps in that case.
    if epochs_per_block_step is None:
        epochs_per_block_step = max(n_epochs // max(n_block_steps, 1), 1)

    # paths
    n_volumes_train = n_volumes_train_fastmri
    if multicoil:
        train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    elif three_d:
        train_path = f'{OASIS_DATA_DIR}/train/'
        val_path = f'{OASIS_DATA_DIR}/val/'
        n_volumes_train = n_volumes_train_oasis
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)
    # generators
    if multicoil:
        dataset = multicoil_dataset
        image_size = IM_SIZE
    elif three_d:
        dataset = three_d_dataset
        image_size = VOLUME_SIZE
    else:
        dataset = singlecoil_dataset
        image_size = IM_SIZE
    # the slice-related options are not passed to the 3D dataset
    if not three_d:
        add_kwargs = {
            'contrast': contrast,
            'rand': True,
            'inner_slices': None,
        }
    else:
        add_kwargs = {}
    add_kwargs.update(**acq_kwargs)
    train_set = dataset(train_path,
                        image_size,
                        acq_type=acq_type,
                        compute_dcomp=dcomp,
                        scale_factor=scale_factor,
                        n_samples=n_samples,
                        **add_kwargs)
    val_set = dataset(val_path,
                      image_size,
                      acq_type=acq_type,
                      compute_dcomp=dcomp,
                      scale_factor=scale_factor,
                      **add_kwargs)

    # run id: only the non-default settings are spelled out
    additional_info = f'{acq_type}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if dcomp:
        additional_info += '_dcomp'
    if block_overlap != 0:
        additional_info += f'_blkov{block_overlap}'
    if checkpoint_epoch == 0:
        run_id = f'{run_id}_bbb_{additional_info}_{int(time.time())}'
    else:
        # a resumed training keeps the id of the run it continues
        run_id = original_run_id
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    n_steps = n_volumes_train

    # save_freq is expressed in batches: one save at the end of each block step
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(epochs_per_block_step * n_steps),
        save_optimizer=False,
        save_weights_only=True,
    )
    print(run_id)
    ## epochs handling
    restart_at_block_step = checkpoint_epoch // epochs_per_block_step
    start_epoch = checkpoint_epoch
    final_epoch = checkpoint_epoch + min(epochs_per_block_step, n_epochs)
    for i_step in range(n_block_steps):
        # block steps completed before the checkpoint are skipped
        if i_step < restart_at_block_step:
            continue
        first_block_to_train = i_step * stride
        blocks = list(
            range(first_block_to_train, first_block_to_train + block_size))
        model.blocks_to_train = blocks
        default_model_compile(model, lr=lr, loss=loss)
        # first run of the model to avoid the saving error
        # ValueError: as_list() is not defined on an unknown TensorShape.
        # it can also allow loading of weights
        model(next(iter(train_set))[0])
        if checkpoint_epoch != 0 and i_step == restart_at_block_step:
            model.load_weights(
                f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}.hdf5'
            )
            # the optimizer state is only restored when the checkpoint falls
            # in the middle of a block step
            if checkpoint_epoch % epochs_per_block_step != 0:
                # a zero-gradient update is applied before `set_weights`
                grad_vars = model.trainable_weights
                zero_grads = [tf.zeros_like(w) for w in grad_vars]
                model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
                # NOTE: pickle.load can execute arbitrary code; only resume
                # from optimizer files you trust.
                with open(
                        f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-optimizer.pkl',
                        'rb') as f:
                    weight_values = pickle.load(f)
                model.optimizer.set_weights(weight_values)
        model.fit(
            train_set,
            steps_per_epoch=n_steps,
            initial_epoch=start_epoch,
            epochs=final_epoch,
            validation_data=val_set,
            validation_steps=5,
            verbose=0,
            callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
        )
        # n_epochs tracks the number of epochs left to run
        n_epochs = n_epochs - (final_epoch - start_epoch)
        if n_epochs <= 0:
            break
        start_epoch = final_epoch
        final_epoch += min(epochs_per_block_step, n_epochs)
    if save_state:
        symbolic_weights = model.optimizer.weights
        weight_values = K.batch_get_value(symbolic_weights)
        with open(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-optimizer.pkl',
                  'wb') as f:
            pickle.dump(weight_values, f)
    return run_id
Ejemplo n.º 9
0
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        af=4,
        contrast=None,
        cuda_visible_devices='0123',
        n_samples=None,
        n_epochs=200,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        loss='mae',
        original_run_id=None,
        fixed_masks=False,
    ):
    r"""Fit an XPDNet on retrospectively undersampled fastMRI data.

    The model is compiled through ``default_model_compile`` with a learning
    rate of 1e-4 (1e-7 when fine-tuning). Validation runs every 5 epochs for
    5 steps. The datasets are built with a scale factor of 1e6. Weights are
    checkpointed once, after ``n_epochs`` epochs.

    Arguments:
        model_fun (function): builder of the image correction network used
            inside the XPDNet. Its ``__name__`` is part of the run id.
        model_kwargs (dict): arguments given to ``model_fun``.
        model_size (str or None): free-form size tag appended to the
            sub-model name in the run id. Defaults to None.
        multicoil (bool): whether to train on the multicoil data (otherwise
            the singlecoil data is used). Defaults to True.
        af (int): acceleration factor of the retrospective undersampling.
            Cast with ``int``. Defaults to 4.
        contrast (str or None): contrast to restrict the data to. None means
            all contrasts. Defaults to None.
        cuda_visible_devices (str): GPU ids; the elements are joined with
            commas and written to ``CUDA_VISIBLE_DEVICES``. Defaults to
            '0123'.
        n_samples (int or None): number of training samples to use. None
            means all of them. Defaults to None.
        n_epochs (int): number of training epochs. Defaults to 200.
        n_iter (int): number of unrolled iterations of the XPDNet. Defaults
            to 10.
        res (bool): whether the image correction networks are residual.
            Defaults to True.
        n_scales (int): number of scales of the image correction network.
            Defaults to 0.
        n_primal (int): size of the image-space buffer. Defaults to 5.
        use_mixed_precision (bool): selects the 'mixed_float16' policy
            instead of 'float32'. Reported as not working by the original
            documentation. Defaults to False.
        refine_smaps (bool): whether the sensitivity maps are refined by a
            neural network. Defaults to False.
        loss (tf.keras.losses.Loss or str): training loss, anything accepted
            by ``default_model_compile`` (the original documentation mentions
            'compound_mssim', the compound L1 MSSIM loss inspired by
            [P2020]). Defaults to 'mae'.
        original_run_id (str or None): run id of a previous training of the
            same network. When given, this run is a fine-tuning: the weights
            saved at epoch 250 (epoch 1 if the ``FASTMRI_DEBUG`` environment
            variable is set) are loaded, the learning rate is 1e-7 and the
            number of steps per epoch is halved. Defaults to None.
        fixed_masks (bool): whether fixed masks are used for the
            undersampling of the training set. Defaults to False.

    Returns:
        - str: the run id of the trained network.
    """
    # data locations
    if multicoil:
        train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
    af = int(af)

    # numerical precision policy
    policy_type = 'mixed_float16' if use_mixed_precision else 'float32'
    mixed_precision.set_policy(mixed_precision.Policy(policy_type))

    # datasets: both sets share the undersampling and scaling settings
    if multicoil:
        dataset, kwargs = multicoil_dataset, {'parallel': False}
    else:
        dataset, kwargs = singlecoil_dataset, {}
    common_ds_kwargs = dict(
        AF=af,
        contrast=contrast,
        inner_slices=None,
        rand=True,
        scale_factor=1e6,
    )
    train_set = dataset(
        train_path,
        **common_ds_kwargs,
        n_samples=n_samples,
        fixed_masks=fixed_masks,
        **kwargs
    )
    val_set = dataset(val_path, **common_ds_kwargs, **kwargs)

    run_params = dict(
        n_primal=n_primal,
        multicoil=multicoil,
        n_scales=n_scales,
        n_iter=n_iter,
        refine_smaps=refine_smaps,
        res=res,
    )

    # run id: only the settings that differ from the defaults are spelled out
    xpdnet_type = 'xpdnet_sense_' if multicoil else 'xpdnet_singlecoil_'
    info_parts = [f'af{af}']
    if contrast is not None:
        info_parts.append(f'{contrast}')
    if n_samples is not None:
        info_parts.append(f'{n_samples}')
    if n_iter != 10:
        info_parts.append(f'i{n_iter}')
    if loss != 'mae':
        info_parts.append(f'{loss}')
    if refine_smaps:
        info_parts.append('rf_sm')
    if fixed_masks:
        info_parts.append('fixed_masks')
    additional_info = '_'.join(info_parts)

    if model_size is None:
        submodel_info = model_fun.__name__
    else:
        submodel_info = model_fun.__name__ + model_size
    run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    # the epoch placeholder is left for the checkpoint callback to fill in
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}-{{epoch:02d}}.hdf5'

    # callbacks
    chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs, save_weights_only=True)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    # model and optimisation settings
    model = XPDNet(model_fun, model_kwargs, **run_params)
    fine_tuning = original_run_id is not None
    if fine_tuning:
        lr, n_steps = 1e-7, n_volumes_train//2
    else:
        lr, n_steps = 1e-4, n_volumes_train
    default_model_compile(model, lr=lr, loss=loss)
    print(run_id)
    if fine_tuning:
        n_epochs_original = 1 if os.environ.get('FASTMRI_DEBUG') else 250
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id
def train_xpdnet(
        model_fun,
        model_kwargs,
        model_size=None,
        multicoil=True,
        brain=False,
        af=4,
        contrast=None,
        n_samples=None,
        batch_size=None,
        n_epochs=200,
        checkpoint_epoch=0,
        save_state=False,
        n_iter=10,
        res=True,
        n_scales=0,
        n_primal=5,
        use_mixed_precision=False,
        refine_smaps=False,
        refine_big=False,
        loss='mae',
        lr=1e-4,
        original_run_id=None,
        fixed_masks=False,
        n_epochs_original=250,
        equidistant_fake=False,
        multi_gpu=False,
        mask_type=None,
        primal_only=True,
        n_dual=1,
        n_dual_filters=16,
        multiscale_kspace_learning=False,
        distributed=False,
        manual_saving=False,
    ):
    r"""Train an XPDNet network on the fastMRI dataset.

    The model is compiled through ``default_model_compile`` with the learning
    rate ``lr`` (1e-4 by default, 1e-7 when fine-tuning).
    The validation is performed every 5 epochs on 5 steps.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to
            True.
        brain (bool): whether to consider brain data instead of knee. Defaults
            to False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Not forwarded to
            the dataset when ``distributed`` is True. Defaults to None.
        batch_size (int or None): the batch size given to the training
            dataset. When set, the number of steps per epoch is divided by
            it. Required when ``distributed`` is True, where it is split
            across the replicas. Defaults to None.
        n_epochs (int): the number of epochs (i.e. one pass though all the
            volumes/samples) for this training. Defaults to 200.
        checkpoint_epoch (int): the number of epochs used to train the model
            during the first step of the full training. This is typically used
            when on a cluster the training duration exceeds the maximum job
            duration. When non-zero, ``original_run_id`` must be given and is
            reused as the run id. Defaults to 0, which means that the training
            is done without checkpoints.
        save_state (bool): whether you should save the entire model state for
            this training, for example to retrain where left off. Defaults to
            False.
        n_iter (int): the number of iterations for the XPDNet.
        res (bool): whether the XPDNet image correction networks should be
            residual.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network.
        refine_big (bool): forwarded to the XPDNet constructor; adds a 'b'
            to the run id when ``refine_smaps`` is set. Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        lr (float): the learning rate. Replaced by 1e-7 when
            ``original_run_id`` is given and ``checkpoint_epoch`` is 0.
            Defaults to 1e-4.
        original_run_id (str or None): run id of the same network trained
            before. With ``checkpoint_epoch`` equal to 0 the training is
            considered fine-tuning: the weights saved at
            ``n_epochs_original`` are loaded, a learning rate of 1e-7 is
            applied and the epoch size is divided in half. With a non-zero
            ``checkpoint_epoch`` it identifies the run to resume. If None,
            the training is done normally, without loading weights. Defaults
            to None.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Not forwarded to the dataset when
            ``distributed`` is True. Defaults to False
        n_epochs_original (int): the number of epochs used to pre-train the
            model, only applicable if original_run_id is not None. Replaced
            by 1 when the ``FASTMRI_DEBUG`` environment variable is set and
            by ``checkpoint_epoch`` when ``manual_saving`` is True. Defaults
            to 250.
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Only used for multicoil brain data when ``mask_type`` is
            None. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the mask type given to the multicoil
            dataset. If None, it is 'equidistant' (or 'equidistant_fake') for
            brain data and 'random' otherwise. Defaults to None.
        primal_only (bool): forwarded to the XPDNet constructor. Defaults to
            True.
        n_dual (int): forwarded to the XPDNet constructor. Defaults to 1.
        n_dual_filters (int): forwarded to the XPDNet constructor. Defaults
            to 16.
        multiscale_kspace_learning (bool): forwarded to the XPDNet
            constructor. Defaults to False.
        distributed (bool): whether to train under a
            ``tf.distribute.MultiWorkerMirroredStrategy`` whose cluster is
            resolved from Slurm. Defaults to False.
        manual_saving (bool): whether checkpoints are weights-only ``.hdf5``
            files completed by a pickle of the optimizer weights
            (``{run_id}-optimizer.pkl``) written after training and read back
            when resuming. Defaults to False.

    Returns:
        - str: the run id of the trained network.

    Raises:
        ValueError: if ``checkpoint_epoch`` is non-zero without
            ``original_run_id``, or if ``distributed`` is True without
            ``batch_size``.
    """
    # fail early with a clear message: both situations would otherwise crash
    # later with an obscure TypeError (run id set to None, None // int)
    if checkpoint_epoch != 0 and original_run_id is None:
        raise ValueError(
            'original_run_id must be provided when checkpoint_epoch is not 0.'
        )
    if distributed and batch_size is None:
        raise ValueError('batch_size must be provided when distributed is True.')
    if distributed:
        com_options = tf.distribute.experimental.CommunicationOptions(
            implementation=tf.distribute.experimental.CommunicationImplementation.NCCL,
        )
        slurm_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(port_base=15000)
        mirrored_strategy = tf.distribute.MultiWorkerMirroredStrategy(
            cluster_resolver=slurm_resolver,
            communication_options=com_options,
        )
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # paths
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    af = int(af)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)
    # generators
    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    if distributed:
        # NOTE: n_samples and fixed_masks are not forwarded in this branch
        def _dataset_fn(input_context, mode='train'):
            ds = dataset(
                train_path if mode == 'train' else val_path,
                input_context=input_context,
                AF=af,
                contrast=contrast,
                inner_slices=None,
                rand=True,
                scale_factor=1e6,
                # the global batch is split across the replicas
                batch_size=batch_size // input_context.num_replicas_in_sync,
                target_image_size=IM_SIZE,
                **kwargs
            )
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            ds = ds.with_options(options)
            return ds
        train_set = mirrored_strategy.distribute_datasets_from_function(partial(
            _dataset_fn,
            mode='train',
        ))
        val_set = mirrored_strategy.distribute_datasets_from_function(partial(
            _dataset_fn,
            mode='val',
        ))
    else:
        train_set = dataset(
            train_path,
            AF=af,
            contrast=contrast,
            inner_slices=None,
            rand=True,
            scale_factor=1e6,
            n_samples=n_samples,
            fixed_masks=fixed_masks,
            batch_size=batch_size,
            target_image_size=IM_SIZE,
            **kwargs
        )
        val_set = dataset(
            val_path,
            AF=af,
            contrast=contrast,
            inner_slices=None,
            rand=True,
            scale_factor=1e6,
            **kwargs
        )

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
        'multiscale_kspace_learning': multiscale_kspace_learning,
    }

    # run id: only the settings that differ from the defaults are spelled out
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    if checkpoint_epoch == 0:
        run_id = f'{xpdnet_type}_{additional_info}_{submodel_info}_{int(time.time())}'
    else:
        # resuming: keep writing under the id of the interrupted run
        run_id = original_run_id
    final_epoch = checkpoint_epoch + n_epochs
    # in distributed mode only task 0 writes to the real checkpoints
    # directory, the other tasks write to the temporary one
    if not distributed or slurm_resolver.task_id == 0:
        chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    else:
        chkpt_path = f'{TMP_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    if not save_state or manual_saving:
        chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    with ExitStack() as stack:
        # can't be always used because of https://github.com/tensorflow/tensorflow/issues/46146
        if distributed:
            stack.enter_context(mirrored_strategy.scope())
        if checkpoint_epoch == 0:
            # fresh training, or fine-tuning when original_run_id is given
            model = XPDNet(model_fun, model_kwargs, **run_params)
            if original_run_id is not None:
                lr = 1e-7
                n_steps = brain_volumes_per_contrast['train'].get(contrast, n_volumes)//2
            else:
                n_steps = n_volumes
            default_model_compile(model, lr=lr, loss=loss)
        elif not manual_saving:
            # resuming from a full saved model
            model = load_model(
                f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}',
                custom_objects=CUSTOM_TF_OBJECTS,
            )
            n_steps = n_volumes
        else:
            # resuming from manually saved weights, loaded further down
            model = XPDNet(model_fun, model_kwargs, **run_params)
            n_steps = n_volumes
            default_model_compile(model, lr=lr, loss=loss)

    if batch_size is not None:
        n_steps //= batch_size

    # save_freq is expressed in batches: a single save at the end of this run
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(n_epochs*n_steps),
        save_weights_only=(not save_state and not distributed) or manual_saving,
    )
    print(run_id)
    if original_run_id is not None and (not checkpoint_epoch or manual_saving):
        if os.environ.get('FASTMRI_DEBUG'):
            n_epochs_original = 1
        if manual_saving:
            n_epochs_original = checkpoint_epoch
        if multicoil:
            kspace_size = [1, 15, 640, 372]
        else:
            kspace_size = [1, 640, 372]
        # the model is called once on all-zero inputs before the weights are
        # loaded (presumably to create its variables first)
        inputs = [
            tf.zeros(kspace_size + [1], dtype=tf.complex64),
            tf.zeros(kspace_size, dtype=tf.complex64),
        ]
        if multicoil:
            inputs.append(tf.zeros(kspace_size, dtype=tf.complex64))
        if brain:
            inputs.append(tf.constant([[320, 320]]))
        with ExitStack() as stack:
            if distributed:
                # see https://github.com/tensorflow/tensorflow/issues/32561#issuecomment-544319907
                stack.enter_context(mirrored_strategy.scope())
            model(inputs)
            model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

        if manual_saving:
            def _model_weight_setting():
                # a zero-gradient step is applied before set_weights
                # (presumably so that the optimizer creates its variables)
                grad_vars = model.trainable_weights
                zero_grads = [tf.zeros_like(w) for w in grad_vars]
                model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
                # NOTE(review): pickle.load executes arbitrary code, only
                # resume from checkpoint directories you trust
                with open(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-optimizer.pkl', 'rb') as f:
                    weight_values = pickle.load(f)
                model.optimizer.set_weights(weight_values)
            if distributed:
                mirrored_strategy.run(_model_weight_setting)
            else:
                _model_weight_setting()

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        initial_epoch=checkpoint_epoch,
        epochs=final_epoch,
        validation_data=val_set,
        validation_steps=5,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )

    if manual_saving:
        # the optimizer weights are stored next to the weights checkpoint
        symbolic_weights = getattr(model.optimizer, 'weights')
        weight_values = K.batch_get_value(symbolic_weights)
        with open(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-optimizer.pkl', 'wb') as f:
            pickle.dump(weight_values, f)
    return run_id
Ejemplo n.º 11
0
    def TrainEMG(self):
        """Train a small dense gesture classifier on recorded EMG data and save it.

        The recording is read with ``np.loadtxt`` from
        ``RESULT_PATH + DATA_PATH + '<subject>-<age>.txt'``. Rows are taken
        to be grouped by gesture, 20 consecutive rows per gesture, so row
        ``k`` receives the label ``k // 20``. Rows and labels are shuffled
        together, split 80 % / 20 % into training and validation sets and
        used to fit a ``Dense(8) -> BatchNormalization ->
        Dense(number_of_gestures)`` network whose first layer is declared
        with ``input_dim=8`` (8 values per row). The fitted model is saved
        to ``RESULT_PATH + MODEL_PATH + '<subject>-<age>_model.h5'`` and the
        training history is passed to ``self.SaveModelHistory``.

        Reads ``self.subject``, ``self.age``, ``self.epochs`` and
        ``self.training_batch_size``; sets ``self.number_of_gestures``.

        Raises:
            ValueError: if the number of rows is not a multiple of 20, in
                which case rows and labels cannot be matched.
        """
        print("Loading data from disk!")
        prepare_array = np.loadtxt(RESULT_PATH + DATA_PATH + self.subject + '-' + str(self.age) + '.txt')

        # every gesture is expected to contribute exactly `samples` rows
        samples = 20
        total_samples = prepare_array.shape[0]
        self.number_of_gestures = int(total_samples / samples)
        if self.number_of_gestures * samples != total_samples:
            # with leftover rows there would be fewer labels than rows and
            # the shuffling below would fail with an IndexError
            raise ValueError(
                f"Expected a multiple of {samples} rows of EMG data, got {total_samples}."
            )
        print("Preprocess EMG data of ", self.subject, "with ", samples, " samples per", self.number_of_gestures,
              "exercise, training data with a nr. of ",
              self.training_batch_size, "batch size, for a total of ", self.epochs, "epochs.")

        # one integer label per row: 0 repeated `samples` times, then 1, ...
        labels = np.repeat(np.arange(self.number_of_gestures), samples)
        print("Labels: ", labels, len(labels), type(labels))

        # shuffle rows and labels with the same permutation
        shuffled_order = np.random.permutation(total_samples)
        all_shuffled_data = prepare_array[shuffled_order]
        all_shuffled_labels = labels[shuffled_order]

        # 80 % of the rows for training, the remainder for validation
        number_of_training_samples = int(np.floor(0.8 * total_samples))

        train_data = all_shuffled_data[:number_of_training_samples, :]
        train_labels = all_shuffled_labels[:number_of_training_samples]
        print("Length of train data is ", train_data.shape)

        validation_data = all_shuffled_data[number_of_training_samples:total_samples, :]
        validation_labels = all_shuffled_labels[number_of_training_samples:total_samples]

        print("Building model...")
        model = keras.Sequential([
            # input_dim is the number of input columns, 8 here
            keras.layers.Dense(8, activation=relu, input_dim=8, kernel_regularizer=regularizers.l2(0.1)),
            keras.layers.BatchNormalization(),
            # one output unit per gesture
            keras.layers.Dense(self.number_of_gestures, activation=softmax)])

        adam_optimizer = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0,
                                               amsgrad=False)
        # labels are plain integers, hence the sparse variant of the loss
        model.compile(
            optimizer=adam_optimizer,
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])

        print("Fitting training data to the model...")
        tqdm_callback = TQDMProgressBar(
            show_epoch_progress=False,
            leave_overall_progress=False,
            leave_epoch_progress=False
        )
        history = model.fit(train_data, train_labels, epochs=self.epochs,
                            validation_data=(validation_data, validation_labels),
                            batch_size=self.training_batch_size, verbose=0, callbacks=[tqdm_callback])

        print("Training model successful!")

        save_path = RESULT_PATH + MODEL_PATH + self.subject + '-' + str(self.age) + '_model.h5'
        model.save(save_path)
        print("Saving model for later...")

        self.SaveModelHistory(history)
def train_xpdnet_block(
    model_fun,
    model_kwargs,
    model_size=None,
    multicoil=True,
    brain=False,
    af=4,
    contrast=None,
    n_samples=None,
    batch_size=None,
    n_epochs=200,
    n_iter=10,
    res=True,
    n_scales=0,
    n_primal=5,
    use_mixed_precision=False,
    refine_smaps=False,
    refine_big=False,
    loss='mae',
    lr=1e-4,
    fixed_masks=False,
    equidistant_fake=False,
    multi_gpu=False,
    mask_type=None,
    primal_only=True,
    n_dual=1,
    n_dual_filters=16,
    multiscale_kspace_learning=False,
    block_size=10,
    block_overlap=0,
    epochs_per_block_step=None,
):
    r"""Train an XPDNet network block by block on the fastMRI dataset.

    The unrolled iterations of the network are trained by groups of
    ``block_size`` consecutive iterations: at each block step the indices of
    the group are assigned to ``model.blocks_to_train``, the model is
    recompiled through ``default_model_compile`` and fitted for
    ``epochs_per_block_step`` epochs, until ``n_epochs`` epochs have been run
    in total or all the block steps are done.
    The validation is performed every 5 epochs on 5 steps.
    A scale factor of 1e6 is applied to the data.

    Arguments:
        model_fun (function): the function initializing the image correction
            network of the XPDNet.
        model_kwargs (dict): the set of arguments used to initialize the image
            correction network.
        model_size (str or None): a string describing the size of the network
            used. This is used in the run id. Defaults to None.
        multicoil (bool): whether the input data is multicoil. Defaults to
            True.
        brain (bool): whether to consider brain data instead of knee. Defaults
            to False.
        af (int): the acceleration factor for the retrospective undersampling
            of the data. Defaults to 4.
        contrast (str or None): the contrast used for this specific training.
            If None, all contrasts are considered. Defaults to None
        n_samples (int or None): the number of samples to consider for this
            training. If None, all samples are considered. Defaults to None.
        batch_size (int or None): the batch size given to the training
            dataset. When set, the number of steps per epoch is divided by
            it. Defaults to None.
        n_epochs (int): the total number of epochs (i.e. one pass though all
            the volumes/samples) for this training, all block steps included.
            Defaults to 200.
        n_iter (int): the number of iterations for the XPDNet.
        res (bool): whether the XPDNet image correction networks should be
            residual.
        n_scales (int): the number of scales used in the image correction
            network. Defaults to 0.
        n_primal (int): the size of the buffer in the image space. Defaults to
            5.
        use_mixed_precision (bool): whether to use the mixed precision API for
            training. Currently not working. Defaults to False.
        refine_smaps (bool): whether you want to refine the sensitivity maps
            with a neural network.
        refine_big (bool): forwarded to the XPDNet constructor; adds a 'b'
            to the run id when ``refine_smaps`` is set. Defaults to False.
        loss (tf.keras.losses.Loss or str): the loss function used for the
            training. It should be understandable by the tf.keras loss API,
            or be 'compound_mssim', in which case the compound L1 MSSIM loss
            inspired by [P2020]. Defaults to 'mae'.
        lr (float): the learning rate. Defaults to 1e-4.
        fixed_masks (bool): whether fixed masks should be used for the
            retrospective undersampling. Defaults to False
        equidistant_fake (bool): whether to use fake equidistant masks from
            fastMRI. Only used for multicoil brain data when ``mask_type`` is
            None. Defaults to False.
        multi_gpu (bool): whether to use multiple GPUs for the XPDNet training.
            Defaults to False.
        mask_type (str or None): the mask type given to the multicoil
            dataset. If None, it is 'equidistant' (or 'equidistant_fake') for
            brain data and 'random' otherwise. Defaults to None.
        primal_only (bool): forwarded to the XPDNet constructor. Defaults to
            True.
        n_dual (int): forwarded to the XPDNet constructor. Defaults to 1.
        n_dual_filters (int): forwarded to the XPDNet constructor. Defaults
            to 16.
        multiscale_kspace_learning (bool): forwarded to the XPDNet
            constructor. Defaults to False.
        block_size (int): the number of consecutive iterations trained
            together at each block step. Defaults to 10.
        block_overlap (int): the number of iterations shared by two
            successive blocks. Must be smaller than ``block_size``. Defaults
            to 0.
        epochs_per_block_step (int or None): the number of epochs spent on
            each block step. If None, ``n_epochs`` is split evenly (rounded
            up) over the block steps. Defaults to None.

    Returns:
        - str: the run id of the trained network.
    """
    if brain:
        n_volumes = brain_n_volumes_train
    else:
        n_volumes = n_volumes_train
    # paths
    if multicoil:
        if brain:
            train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
        else:
            train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
            val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    af = int(af)

    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)
    # generators
    if multicoil:
        dataset = multicoil_dataset
        if mask_type is None:
            if brain:
                if equidistant_fake:
                    mask_type = 'equidistant_fake'
                else:
                    mask_type = 'equidistant'
            else:
                mask_type = 'random'
        kwargs = {
            'parallel': False,
            'output_shape_spec': brain,
            'mask_type': mask_type,
        }
    else:
        dataset = singlecoil_dataset
        kwargs = {}
    train_set = dataset(train_path,
                        AF=af,
                        contrast=contrast,
                        inner_slices=None,
                        rand=True,
                        scale_factor=1e6,
                        n_samples=n_samples,
                        fixed_masks=fixed_masks,
                        batch_size=batch_size,
                        target_image_size=IM_SIZE,
                        **kwargs)
    val_set = dataset(val_path,
                      AF=af,
                      contrast=contrast,
                      inner_slices=None,
                      rand=True,
                      scale_factor=1e6,
                      **kwargs)

    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': refine_smaps,
        'res': res,
        'output_shape_spec': brain,
        'multi_gpu': multi_gpu,
        'refine_big': refine_big,
        'primal_only': primal_only,
        'n_dual': n_dual,
        'n_dual_filters': n_dual_filters,
        'multiscale_kspace_learning': multiscale_kspace_learning,
    }

    # run id: only the settings that differ from the defaults are spelled out
    if multicoil:
        xpdnet_type = 'xpdnet_sense_'
        if brain:
            xpdnet_type += 'brain_'
    else:
        xpdnet_type = 'xpdnet_singlecoil_'
    additional_info = f'af{af}'
    if contrast is not None:
        additional_info += f'_{contrast}'
    if n_samples is not None:
        additional_info += f'_{n_samples}'
    if n_iter != 10:
        additional_info += f'_i{n_iter}'
    if loss != 'mae':
        additional_info += f'_{loss}'
    if refine_smaps:
        additional_info += '_rf_sm'
        if refine_big:
            additional_info += 'b'
    if fixed_masks:
        additional_info += '_fixed_masks'
    if block_overlap != 0:
        additional_info += f'_blkov{block_overlap}'

    submodel_info = model_fun.__name__
    if model_size is not None:
        submodel_info += model_size
    run_id = f'{xpdnet_type}_{additional_info}_bbb_{submodel_info}_{int(time.time())}'
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}'
    chkpt_path += '.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    model = XPDNet(model_fun, model_kwargs, **run_params)
    n_steps = n_volumes

    if batch_size is not None:
        n_steps //= batch_size

    # save_freq is expressed in batches and uses the total number of epochs,
    # n_epochs being decremented in the loop below
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(n_epochs * n_steps),
        save_weights_only=True,
    )
    print(run_id)
    # two successive blocks start `stride` iterations apart
    stride = block_size - block_overlap
    assert stride > 0
    n_block_steps = int(math.ceil((n_iter - block_size) / stride) + 1)
    ## epochs handling
    if epochs_per_block_step is None:
        # the default used to crash in min(None, n_epochs): spread the epochs
        # evenly over the block steps instead, with at least one epoch each
        epochs_per_block_step = max(
            1, int(math.ceil(n_epochs / max(1, n_block_steps))))
    start_epoch = 0
    final_epoch = min(epochs_per_block_step, n_epochs)

    for i_step in range(n_block_steps):
        first_block_to_train = i_step * stride
        blocks = list(
            range(first_block_to_train, first_block_to_train + block_size))
        model.blocks_to_train = blocks
        # the model is recompiled each time the set of trained blocks changes
        default_model_compile(model, lr=lr, loss=loss)

        model.fit(
            train_set,
            steps_per_epoch=n_steps,
            initial_epoch=start_epoch,
            epochs=final_epoch,
            validation_data=val_set,
            validation_steps=5,
            validation_freq=5,
            verbose=0,
            callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
        )
        # from here on n_epochs holds the number of epochs left to run
        n_epochs = n_epochs - (final_epoch - start_epoch)
        if n_epochs <= 0:
            break
        start_epoch = final_epoch
        final_epoch += min(epochs_per_block_step, n_epochs)
    return run_id
Ejemplo n.º 13
0
def train_ncnet(
    model,
    run_id=None,
    multicoil=False,
    three_d=False,
    acq_type='radial',
    scale_factor=1e6,
    dcomp=False,
    contrast=None,
    cuda_visible_devices='0123',
    n_samples=None,
    n_epochs=200,
    use_mixed_precision=False,
    loss='mae',
    original_run_id=None,
    **acq_kwargs,
):
    """Compile and fit `model` on the fastMRI or OASIS data, then return the run id.

    Args:
        model: model to train; it is compiled with `default_model_compile`
            and trained with `model.fit`.
        run_id: prefix of the run identifier. The acquisition type, the
            non-default options and the current timestamp are appended to it.
        multicoil: if True, use the fastMRI multicoil folders and
            `multicoil_dataset`.
        three_d: if True (and `multicoil` is False), use the OASIS folders and
            `three_d_dataset`.
        acq_type: forwarded to the dataset builder; also part of the run id.
        scale_factor: forwarded to the dataset builder.
        dcomp: forwarded to the dataset builder as `compute_dcomp`.
        contrast: forwarded to the dataset builder in the non-3D case.
        cuda_visible_devices: iterable of device ids (by default the
            characters of a string) joined with commas into the
            `CUDA_VISIBLE_DEVICES` environment variable.
        n_samples: forwarded to the training dataset builder only.
        n_epochs: number of training epochs, also used as checkpoint period.
        use_mixed_precision: select the 'mixed_float16' policy instead of
            'float32'.
        loss: loss forwarded to `default_model_compile`.
        original_run_id: when given, weights are loaded from the checkpoint
            of that run and the training uses a learning rate of 1e-7 with
            half the number of steps per epoch.
        **acq_kwargs: extra keyword arguments forwarded to the dataset
            builder; they override the non-3D defaults.

    Returns:
        The full run identifier used for the checkpoint and log paths.
    """
    # data folders, dataset builder and number of training volumes
    n_volumes_train = n_volumes_train_fastmri
    if multicoil:
        train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
        dataset, image_size = multicoil_dataset, IM_SIZE
    elif three_d:
        train_path = f'{OASIS_DATA_DIR}/train/'
        val_path = f'{OASIS_DATA_DIR}/val/'
        n_volumes_train = n_volumes_train_oasis
        dataset, image_size = three_d_dataset, VOLUME_SIZE
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'
        dataset, image_size = singlecoil_dataset, IM_SIZE

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)

    # precision policy
    mixed_precision.set_policy(mixed_precision.Policy(
        'mixed_float16' if use_mixed_precision else 'float32'
    ))

    # datasets: the slice-related options only exist for the 2D builders
    if three_d:
        dataset_kwargs = {}
    else:
        dataset_kwargs = {
            'contrast': contrast,
            'rand': True,
            'inner_slices': None,
        }
    dataset_kwargs = {**dataset_kwargs, **acq_kwargs}
    train_set = dataset(
        train_path,
        image_size,
        acq_type=acq_type,
        compute_dcomp=dcomp,
        scale_factor=scale_factor,
        n_samples=n_samples,
        **dataset_kwargs
    )
    val_set = dataset(
        val_path,
        image_size,
        acq_type=acq_type,
        compute_dcomp=dcomp,
        scale_factor=scale_factor,
        **dataset_kwargs
    )

    # run identifier: only non-default options are spelled out
    info_parts = [f'{acq_type}']
    if contrast is not None:
        info_parts.append(f'{contrast}')
    if n_samples is not None:
        info_parts.append(f'{n_samples}')
    if loss != 'mae':
        info_parts.append(f'{loss}')
    if dcomp:
        info_parts.append('dcomp')
    additional_info = '_'.join(info_parts)
    run_id = f'{run_id}_{additional_info}_{int(time.time())}'

    # callbacks
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    chkpt_cback = ModelCheckpoint(
        chkpt_path,
        period=n_epochs,
        save_weights_only=True,
    )
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    # fine-tuning an existing run uses a smaller learning rate and shorter epochs
    fine_tuning = original_run_id is not None
    lr = 1e-7 if fine_tuning else 1e-4
    n_steps = n_volumes_train//2 if fine_tuning else n_volumes_train
    default_model_compile(model, lr=lr, loss=loss)
    print(run_id)
    if fine_tuning:
        # the original run is expected to have a checkpoint at epoch 250
        # (epoch 1 when the FASTMRI_DEBUG environment variable is set)
        n_epochs_original = 1 if os.environ.get('FASTMRI_DEBUG') else 250
        model.load_weights(f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{n_epochs_original:02d}.hdf5')

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=2,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id
Ejemplo n.º 14
0
def train_ncnet(
    model,
    run_id=None,
    multicoil=False,
    three_d=False,
    acq_type='radial',
    scale_factor=1e6,
    dcomp=False,
    contrast=None,
    cuda_visible_devices='0123',
    n_samples=None,
    n_epochs=200,
    use_mixed_precision=False,
    loss='mae',
    original_run_id=None,
    checkpoint_epoch=0,
    save_state=False,
    lr=1e-4,
    **acq_kwargs,
):
    """Compile and fit `model` on the fastMRI or OASIS data, with optional resume.

    When `checkpoint_epoch` is non-zero the training is resumed: the run
    keeps the identifier `original_run_id`, the model weights are loaded from
    the checkpoint of that epoch and the optimizer weights are restored from
    the `<original_run_id>-optimizer.pkl` file of the checkpoints folder.

    Args:
        model: model to train; it is compiled with `default_model_compile`
            and trained with `model.fit`.
        run_id: prefix of the run identifier for a fresh run. The acquisition
            type, the non-default options and the current timestamp are
            appended to it. Ignored when resuming.
        multicoil: if True, use the fastMRI multicoil folders and
            `multicoil_dataset`.
        three_d: if True (and `multicoil` is False), use the OASIS folders and
            `three_d_dataset`.
        acq_type: forwarded to the dataset builder; also part of the run id.
        scale_factor: forwarded to the dataset builder.
        dcomp: forwarded to the dataset builder as `compute_dcomp`.
        contrast: forwarded to the dataset builder in the non-3D case.
        cuda_visible_devices: iterable of device ids (by default the
            characters of a string) joined with commas into the
            `CUDA_VISIBLE_DEVICES` environment variable.
        n_samples: forwarded to the training dataset builder only.
        n_epochs: number of epochs to train in this call.
        use_mixed_precision: select the 'mixed_float16' policy instead of
            'float32'.
        loss: loss forwarded to `default_model_compile`.
        original_run_id: identifier of the run to resume; required when
            `checkpoint_epoch` is non-zero.
        checkpoint_epoch: epoch to resume from; 0 starts a fresh run.
        save_state: if True, pickle the optimizer weights after training so
            that the run can be resumed later.
        lr: learning rate forwarded to `default_model_compile`.
        **acq_kwargs: extra keyword arguments forwarded to the dataset
            builder; they override the non-3D defaults.

    Returns:
        The run identifier used for the checkpoint and log paths.

    Raises:
        ValueError: if `checkpoint_epoch` is non-zero while `original_run_id`
            is None.
    """
    resuming = checkpoint_epoch != 0
    if resuming and original_run_id is None:
        # Fail before touching the environment or building the datasets:
        # without this guard run_id would become None and the log directory
        # construction below would fail with an unrelated-looking TypeError.
        raise ValueError(
            'original_run_id must be provided when checkpoint_epoch is not 0'
        )

    # paths
    n_volumes_train = n_volumes_train_fastmri
    if multicoil:
        train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'
    elif three_d:
        train_path = f'{OASIS_DATA_DIR}/train/'
        val_path = f'{OASIS_DATA_DIR}/val/'
        n_volumes_train = n_volumes_train_oasis
    else:
        train_path = f'{FASTMRI_DATA_DIR}singlecoil_train/singlecoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}singlecoil_val/'

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)

    # trying mixed precision
    policy_type = 'mixed_float16' if use_mixed_precision else 'float32'
    policy = mixed_precision.Policy(policy_type)
    mixed_precision.set_policy(policy)

    # generators
    if multicoil:
        dataset = multicoil_dataset
        image_size = IM_SIZE
    elif three_d:
        dataset = three_d_dataset
        image_size = VOLUME_SIZE
    else:
        dataset = singlecoil_dataset
        image_size = IM_SIZE
    if not three_d:
        # slice-related options only exist for the 2D dataset builders
        add_kwargs = {
            'contrast': contrast,
            'rand': True,
            'inner_slices': None,
        }
    else:
        add_kwargs = {}
    add_kwargs.update(**acq_kwargs)
    train_set = dataset(
        train_path,
        image_size,
        acq_type=acq_type,
        compute_dcomp=dcomp,
        scale_factor=scale_factor,
        n_samples=n_samples,
        **add_kwargs
    )
    val_set = dataset(
        val_path,
        image_size,
        acq_type=acq_type,
        compute_dcomp=dcomp,
        scale_factor=scale_factor,
        **add_kwargs
    )

    # run identifier: a resumed run keeps the identifier of the original run
    if resuming:
        run_id = original_run_id
    else:
        additional_info = f'{acq_type}'
        if contrast is not None:
            additional_info += f'_{contrast}'
        if n_samples is not None:
            additional_info += f'_{n_samples}'
        if loss != 'mae':
            additional_info += f'_{loss}'
        if dcomp:
            additional_info += '_dcomp'
        run_id = f'{run_id}_{additional_info}_{int(time.time())}'
    final_epoch = checkpoint_epoch + n_epochs
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'

    log_dir = op.join(f'{LOGS_DIR}logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()

    n_steps = n_volumes_train

    # save_freq is the total number of steps of this call
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=int(n_epochs * n_steps),
        save_weights_only=True,
    )
    default_model_compile(model, lr=lr, loss=loss)
    # first run of the model to avoid the saving error
    # ValueError: as_list() is not defined on an unknown TensorShape.
    # it can also allow loading of weights
    model(next(iter(train_set))[0])
    if resuming:
        model.load_weights(
            f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-{checkpoint_epoch:02d}.hdf5'
        )
        # apply an all-zero gradient step before restoring the optimizer
        # weights (presumably so that the optimizer creates its internal
        # variables and set_weights has something to fill -- TODO confirm)
        grad_vars = model.trainable_weights
        zero_grads = [tf.zeros_like(w) for w in grad_vars]
        model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
        optimizer_path = f'{CHECKPOINTS_DIR}checkpoints/{original_run_id}-optimizer.pkl'
        # NOTE: unpickling can execute arbitrary code; only resume from a
        # checkpoints folder whose content is trusted.
        with open(optimizer_path, 'rb') as f:
            weight_values = pickle.load(f)
        model.optimizer.set_weights(weight_values)
    print(run_id)

    model.fit(
        train_set,
        steps_per_epoch=n_steps,
        initial_epoch=checkpoint_epoch,
        epochs=final_epoch,
        validation_data=val_set,
        validation_steps=2,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    if save_state:
        # store the optimizer weights next to the checkpoints for later resume
        symbolic_weights = model.optimizer.weights
        weight_values = K.batch_get_value(symbolic_weights)
        with open(f'{CHECKPOINTS_DIR}checkpoints/{run_id}-optimizer.pkl',
                  'wb') as f:
            pickle.dump(weight_values, f)
    return run_id
def train_vnet_postproc(
    original_run_id,
    af=4,
    brain=False,
    n_samples=None,
    n_epochs=200,
    use_mixed_precision=False,
    loss='mae',
    lr=1e-4,
    base_n_filters=16,
    n_scales=4,
    non_linearity='prelu',
):
    """Build a `PostProcessVnet`, fit it on tfrecords data and return the run id.

    Args:
        original_run_id: forwarded to `train_postproc_dataset_from_tfrecords`
            for both the training and the validation datasets.
        af: only used in this function to tag the run identifier (`af{af}`).
        brain: if True, use the fastMRI brain multicoil folders and
            `brain_n_volumes_train` steps per epoch; otherwise the multicoil
            folders and `n_volumes_train`.
        n_samples: forwarded to both dataset builders.
        n_epochs: number of training epochs.
        use_mixed_precision: select the 'mixed_float16' policy instead of
            'float32'.
        loss: loss forwarded to `default_model_compile`.
        lr: learning rate forwarded to `default_model_compile`.
        base_n_filters: number of channels of the first scale; it is doubled
            at each subsequent scale.
        n_scales: number of scales of the network.
        non_linearity: non-linearity name forwarded to the network.

    Returns:
        The run identifier used for the checkpoint and log paths.
    """
    # number of volumes and data folders depend on the anatomy
    if brain:
        n_volumes = brain_n_volumes_train
        train_path = f'{FASTMRI_DATA_DIR}brain_multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}brain_multicoil_val/'
    else:
        n_volumes = n_volumes_train
        train_path = f'{FASTMRI_DATA_DIR}multicoil_train/'
        val_path = f'{FASTMRI_DATA_DIR}multicoil_val/'

    # precision policy
    mixed_precision.set_policy(mixed_precision.Policy(
        'mixed_float16' if use_mixed_precision else 'float32'
    ))

    # datasets, built the same way for training and validation
    train_set, val_set = (
        train_postproc_dataset_from_tfrecords(
            data_path,
            original_run_id,
            n_samples=n_samples,
        )
        for data_path in (train_path, val_path)
    )

    # model
    run_params = {
        'layers_n_channels': [
            base_n_filters*2**i_scale for i_scale in range(n_scales)
        ],
        'layers_n_non_lins': 2,
        'non_linearity': non_linearity,
        'res': True,
    }
    model = PostProcessVnet(None, run_params)
    default_model_compile(model, lr=lr, loss=loss)

    # run identifier: only non-default options are spelled out
    vnet_type = 'vnet_postproc_' + ('brain_' if brain else '')
    info_parts = [f'af{af}']
    if n_samples is not None:
        info_parts.append(f'{n_samples}')
    if loss != 'mae':
        info_parts.append(f'{loss}')
    if base_n_filters != 16:
        info_parts.append(f'bf{base_n_filters}')
    if n_scales != 4:
        info_parts.append(f'sc{n_scales}')
    if non_linearity != 'prelu':
        info_parts.append(f'{non_linearity}')
    additional_info = '_'.join(info_parts)
    run_id = f'{vnet_type}_{additional_info}_{int(time.time())}'

    # callbacks
    chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=op.join(f'{LOGS_DIR}logs', run_id),
        histogram_freq=0,
        write_graph=False,
        write_images=False,
    )
    tqdm_cback = TQDMProgressBar()
    # save_freq is the total number of training steps of this call
    chkpt_cback = ModelCheckpointWorkAround(
        chkpt_path,
        save_freq=n_epochs*n_volumes,
        save_weights_only=True,
    )
    print(run_id)

    model.fit(
        train_set,
        steps_per_epoch=n_volumes,
        epochs=n_epochs,
        validation_data=val_set,
        validation_steps=10,
        validation_freq=5,
        verbose=0,
        callbacks=[tboard_cback, chkpt_cback, tqdm_cback],
    )
    return run_id