def main():
    """Command-line entry point: run fsl2mrtrix on an FSL bval/bvec pair.

    Both FSL files must exist and the MRtrix output path must pass the
    parser's output policy before the conversion is attempted.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    bval_path, bvec_path = args.fsl_bval, args.fsl_bvec
    out_path = args.mrtrix_enc

    assert_inputs_exist(parser, [bval_path, bvec_path])
    assert_outputs_exist(parser, args, [out_path])

    fsl2mrtrix(bval_path, bvec_path, out_path)
Example #2
0
def main():
    """Command-line entry point: convert FSL bval/bvec files with fsl2mrtrix.

    The gradient filenames are first checked against the 'fsl' and 'mrtrix'
    naming conventions, then input existence and the output policy are
    checked, and finally the conversion writes ``args.mrtrix_enc``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    fsl_files = [args.fsl_bval, args.fsl_bvec]
    mrtrix_file = args.mrtrix_enc

    # Filename-format checks first, then on-disk checks.
    assert_gradients_filenames_valid(parser, fsl_files, 'fsl')
    assert_gradients_filenames_valid(parser, mrtrix_file, 'mrtrix')
    assert_inputs_exist(parser, fsl_files)
    assert_outputs_exist(parser, args, mrtrix_file)

    fsl2mrtrix(args.fsl_bval, args.fsl_bvec, mrtrix_file)
Example #3
0
def test_bvec_bval_tools():
    """Round-trip gradient tables between the dmri, FSL and MRtrix formats.

    Checks that:

    * MRtrix -> FSL -> MRtrix reproduces the original MRtrix encoding exactly;
    * dmri -> MRtrix (direct) matches dmri -> FSL -> MRtrix exactly.

    Generated files are written to a temporary directory so the ``data``
    directory is not polluted and nothing is left behind after the test.
    """
    import os
    import tempfile

    data_dir = "data"
    f_original_mrtrix_encoding = os.path.join(data_dir, "encoding.b")
    f_original_dmri_bval = os.path.join(data_dir, "b.txt")
    f_original_dmri_bvec = os.path.join(data_dir, "grad.txt")

    original_mrtrix_encoding = np.loadtxt(f_original_mrtrix_encoding)

    with tempfile.TemporaryDirectory() as tmp_dir:
        f_generated_fsl_bval = os.path.join(tmp_dir, "gen-bval")
        f_generated_fsl_bvec = os.path.join(tmp_dir, "gen-bvec")
        f_generated_mrtrix_encoding = os.path.join(tmp_dir, "gen-encoding.b")
        f_generated_temp_file = os.path.join(tmp_dir, "temp_file1")

        # MRtrix -> FSL -> MRtrix must give back the original encoding.
        mrtrix2fsl(f_original_mrtrix_encoding,
                   f_generated_fsl_bval, f_generated_fsl_bvec)
        fsl2mrtrix(f_generated_fsl_bval, f_generated_fsl_bvec,
                   f_generated_mrtrix_encoding)
        generated_mrtrix_encoding = np.loadtxt(f_generated_mrtrix_encoding)
        assert_array_equal(original_mrtrix_encoding, generated_mrtrix_encoding)

        # dmri -> MRtrix directly must agree with dmri -> FSL -> MRtrix.
        dmri2fsl(f_original_dmri_bval, f_original_dmri_bvec,
                 f_generated_fsl_bval, f_generated_fsl_bvec)
        dmri2mrtrix(f_original_dmri_bval, f_original_dmri_bvec,
                    f_generated_mrtrix_encoding)
        fsl2mrtrix(f_generated_fsl_bval, f_generated_fsl_bvec,
                   f_generated_temp_file)

        dmri_mrtrix = np.loadtxt(f_generated_mrtrix_encoding)
        dmri_fsl_mrtrix = np.loadtxt(f_generated_temp_file)
        assert_array_equal(dmri_fsl_mrtrix, dmri_mrtrix)
def main():
    """Command-line entry point: convert FSL bval/bvec files with fsl2mrtrix.

    Exits through ``parser.error`` when an input is not an existing file, or
    when the output already exists and ``-f`` (``args.isForce``) was not
    given. With ``-f`` an existing output is reported and then overwritten.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    # isfile (not exists) so that a directory given by mistake is rejected
    # with the same "existing file" message.
    for input_path in (args.fsl_bval, args.fsl_bvec):
        if not os.path.isfile(input_path):
            parser.error('"{0}"'.format(input_path) +
                         " doesn't exist. Please enter an existing file.")

    if os.path.exists(args.mrtrix_enc):
        if args.isForce:
            print('Overwriting "{0}".'.format(args.mrtrix_enc))
        else:
            parser.error('"{0}" already exists! Use -f to overwrite it.'.format(
                args.mrtrix_enc))

    fsl2mrtrix(args.fsl_bval, args.fsl_bvec, args.mrtrix_enc)
Example #5
0
def main():
    """Fit the NODDI model with AMICO and save the results.

    Workflow:

    1. Validate the DWI/bval/bvec inputs (optional mask) and the output
       directory ``<out_dir>/NODDI``.
    2. Write a temporary scheme file (through ``fsl2mrtrix``) in which every
       b-value is replaced by the centroid of the shell it belongs to.
    3. Configure AMICO's NODDI model, generate or reuse kernels, fit, and save
       the results into ``<out_dir>/<model id>``.

    The temporary directory is removed in a ``finally`` clause, so it is
    cleaned up even when fitting raises or ``parser.error`` exits.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.in_mask)
    assert_output_dirs_exist_and_empty(parser, args,
                                       os.path.join(args.out_dir, 'NODDI'),
                                       optional=args.save_kernels)

    # AMICO has some C-level stdout and non-logging print that cannot be
    # easily stopped, so all printed output is redirected manually unless
    # verbose output was requested.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        captured_stdout = io.StringIO()
        redirected_stdout = redirect_stdout(captured_stdout)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    try:
        # Generate a scheme file from the bvals and bvecs files, with every
        # b-value snapped to its shell centroid.
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Compute NODDI with AMICO on {} shells at found '
                      'at {}.'.format(len(shells_centroids), shells_centroids))

        with redirected_stdout:
            # Load the data
            amico.core.setup()
            ae = amico.Evaluation('.', '.')
            ae.load_data(args.in_dwi,
                         tmp_scheme_filename,
                         mask_filename=args.in_mask)
            # Compute the response functions
            ae.set_model("NODDI")

            intra_vol_frac = np.linspace(0.1, 0.99, 12)
            intra_orient_distr = np.hstack((np.array([0.03, 0.06]),
                                            np.linspace(0.09, 0.99, 10)))

            ae.model.set(args.para_diff, args.iso_diff,
                         intra_vol_frac, intra_orient_distr,
                         False)
            ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

            # Kernels go to / come from the user-given directory when
            # requested; otherwise they are generated inside the temporary
            # directory and discarded with it.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           ae.model.id)
                regenerate_kernels = True

            ae.set_config('ATOMS_path', kernels_dir)
            out_model_dir = os.path.join(args.out_dir, ae.model.id)
            ae.set_config('OUTPUT_path', out_model_dir)
            ae.generate_kernels(regenerate=regenerate_kernels)
            ae.load_kernels()

            # Set number of processes
            solver_params = ae.get_config('solver_params')
            solver_params['numThreads'] = args.nbr_processes
            ae.set_config('solver_params', solver_params)

            # Model fit, then save the results
            ae.fit()
            ae.save_results()
    finally:
        tmp_dir.cleanup()
Example #6
0
def main():
    """Fit the FreeWater model with AMICO and save the results in out_dir.

    Workflow:

    1. Validate arguments, inputs (optional mask) and the output directory.
    2. Write a temporary scheme file (through ``fsl2mrtrix``) in which every
       b-value is replaced by the centroid of the shell it belongs to.
    3. Configure AMICO's FreeWater model (Human or Mouse), generate or reuse
       kernels and, unless ``--compute_only``, fit and save the results.

    The temporary directory is removed in a ``finally`` clause, so it is
    cleaned up on every exit path, including the ``--compute_only`` early
    return, exceptions and ``parser.error`` exits.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.mask)

    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    # AMICO has some C-level stdout and non-logging print that cannot be
    # easily stopped, so all printed output is redirected manually unless
    # verbose output was requested.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        captured_stdout = io.StringIO()
        redirected_stdout = redirect_stdout(captured_stdout)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    try:
        # Generate a scheme file from the bvals and bvecs files, with every
        # b-value snapped to its shell centroid.
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename,
                   shells_centroids[indices_shells],
                   newline=' ',
                   fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug(
            'Compute FreeWater with AMICO on {} shells '
            'at found at {}.'.format(len(shells_centroids), shells_centroids))

        with redirected_stdout:
            amico.core.setup()
            ae = amico.Evaluation('.', '.')
            # Load the data
            ae.load_data(args.in_dwi,
                         scheme_filename=tmp_scheme_filename,
                         mask_filename=args.mask)

            # Compute the response functions
            ae.set_model("FreeWater")
            model_type = 'Mouse' if args.mouse else 'Human'

            ae.model.set(args.para_diff,
                         np.linspace(args.perp_diff_min, args.perp_diff_max,
                                     10),
                         [args.iso_diff], model_type)

            ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

            # Kernels go to / come from the user-given directory when
            # requested; otherwise they are generated inside the temporary
            # directory and discarded with it.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           ae.model.id)
                regenerate_kernels = True

            ae.set_config('ATOMS_path', kernels_dir)
            ae.set_config('OUTPUT_path', args.out_dir)
            ae.generate_kernels(regenerate=regenerate_kernels)
            if args.compute_only:
                # Kernels only; the finally clause still removes tmp_dir.
                return

            ae.load_kernels()

            # Set number of processes
            solver_params = ae.get_config('solver_params')
            solver_params['numThreads'] = args.nbr_processes
            ae.set_config('solver_params', solver_params)

            ae.set_config('doNormalizeSignal', True)
            ae.set_config('doKeepb0Intact', False)
            ae.set_config('doComputeNRMSE', True)
            ae.set_config('doSaveCorrectedDWI', True)

            # Model fit, then save the results
            ae.fit()
            ae.save_results()
    finally:
        tmp_dir.cleanup()
Example #7
0
def main():
    """Run COMMIT (optionally followed by COMMIT2) on a tractogram.

    Workflow:

    1. Validate arguments and inputs. For ``.trk`` input, check header
       compatibility with the DWI, since COMMIT does not do it.
    2. For ``.h5`` input, rebuild a temporary ``.trk``. Keep the number of
       streamlines of each connection so COMMIT2 can group them.
    3. Write a scheme file whose b-values are snapped to shell centroids.
    4. Fit the Stick-Zeppelin-Ball (or Ball & Stick) model.
    5. Save the results through ``_save_results_wrapper`` into
       ``commit_1/``, and into ``commit_2/`` when ``--commit2`` is set.

    The input HDF5 file and the temporary directory are released in a
    ``finally`` clause, so they are cleaned up on every exit path, including
    the ``--compute_only`` early return and ``parser.error`` exits.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        captured_stdout = io.StringIO()
        redirected_stdout = redirect_stdout(captured_stdout)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for
            # each connection.
            bundle_groups_len = []
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                streamlines.extend(tmp_streamlines)
                bundle_groups_len.append(len(tmp_streamlines))

            # offsets_list[k]:offsets_list[k + 1] are the streamline indices
            # of connection k in the concatenated tractogram.
            offsets_list = np.cumsum([0] + bundle_groups_len)
            sft = StatefulTractogram(streamlines,
                                     args.in_dwi,
                                     Space.VOX,
                                     origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT
            # internal use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename,
                   shells_centroids[indices_shells],
                   newline=' ',
                   fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            magnitude = 10**np.floor(np.log10(np.mean(data[data > 0])))
            data[data < 0.001 * magnitude] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # Kernels go to / come from the user-given directory when
            # requested; otherwise they are generated inside the temporary
            # directory and discarded with it.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # Kernels only; the finally clause still cleans up.
                return
            mit.load_kernels()
            use_mask = args.in_tracking_mask is not None
            mit.load_dictionary(tmp_dir.name, use_all_voxels_in_mask=use_mask)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
            tol_fun = 1e-2 if args.commit2 else 1e-3
            mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                                  'commit_1/', False)

            if args.commit2:
                # One group of streamline indices per connection. An explicit
                # 1-D object array is required: np.array() on a list of
                # differently-sized index arrays raises ValueError on
                # NumPy >= 1.24.
                nbr_bundles = len(bundle_groups_len)
                group_idx = np.empty(nbr_bundles, dtype=object)
                group_w = np.empty(nbr_bundles, dtype=np.float64)
                for k in range(nbr_bundles):
                    group_idx[k] = np.arange(offsets_list[k],
                                             offsets_list[k + 1])
                    # 1e-12 avoids a division by zero for all-zero groups.
                    group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                        (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)
                prior_on_bundles = commit.solvers.init_regularisation(
                    mit,
                    structureIC=group_idx,
                    weightsIC=group_w,
                    regnorms=[
                        commit.solvers.group_sparsity,
                        commit.solvers.non_negative,
                        commit.solvers.non_negative
                    ],
                    lambdas=[args.lambda_commit_2, 0.0, 0.0])
                mit.fit(tol_fun=1e-3,
                        max_iter=args.nbr_iter,
                        regularisation=prior_on_bundles,
                        verbose=False)
                mit.save_results()
                _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                      offsets_list, 'commit_2/', True)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()
def main():
    """Run COMMIT on a tractogram and save the per-streamline weights.

    Writes the following into ``args.out_dir``:

    * the COMMIT results directory content;
    * ``commit_weights.txt``;
    * for ``.h5`` input, ``decompose_commit.h5``, where each connection group
      gets a ``commit_weights`` dataset and, when ``--threshold_weights`` is
      set, only its streamlines with a weight above the threshold;
    * optionally, essential/nonessential tractograms (when thresholding) and
      the whole tractogram with weights (``--keep_whole_tractogram``).

    The input HDF5 file and the temporary directory are released in a
    ``finally`` clause, so they are cleaned up on every exit path, including
    the ``--compute_only`` early return and ``parser.error`` exits.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        captured_stdout = io.StringIO()
        redirected_stdout = redirect_stdout(captured_stdout)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for
            # each connection.
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            offsets_list = [0]
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                offsets_list.append(len(tmp_streamlines))
                streamlines.extend(tmp_streamlines)

            # offsets_list[i]:offsets_list[i + 1] are the streamline indices
            # of connection i in the concatenated tractogram.
            offsets_list = np.cumsum(offsets_list)

            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT
            # internal use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids),
            shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               gen_trk=False,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            magnitude = 10**np.floor(np.log10(np.mean(data[data > 0])))
            data[data < 0.001 * magnitude] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # Kernels go to / come from the user-given directory when
            # requested; otherwise they are generated inside the temporary
            # directory and discarded with it.
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            # Use the same number of directions as commit.core.setup and
            # trk2dictionary.run above (was hard-coded to 500).
            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # Kernels only; the finally clause still cleans up.
                return
            mit.load_kernels()
            mit.load_dictionary(tmp_dir.name,
                                use_mask=args.in_tracking_mask is not None)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=tmp_dir.name)
            mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
            mit.save_results()

        # Simplifying output for streamlines and cleaning output directory
        commit_results_dir = os.path.join(tmp_dir.name,
                                          'Results_StickZeppelinBall')
        # results.pickle was just written by COMMIT into our own temporary
        # directory, so it is not untrusted input.
        with open(os.path.join(commit_results_dir, 'results.pickle'),
                  'rb') as pk_file:
            commit_output_dict = pickle.load(pk_file)
        nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
        commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
        np.savetxt(os.path.join(commit_results_dir, 'commit_weights.txt'),
                   commit_weights)

        if ext == '.h5':
            new_filename = os.path.join(commit_results_dir,
                                        'decompose_commit.h5')
            with h5py.File(new_filename, 'w') as new_hdf5_file:
                new_hdf5_file.attrs['affine'] = sft.affine
                new_hdf5_file.attrs['dimensions'] = sft.dimensions
                new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
                new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
                # Assign the weights into the hdf5, while respecting the
                # ordering of connections/streamlines
                logging.debug('Adding commit weights to {}.'.format(
                    new_filename))
                for i, key in enumerate(hdf5_keys):
                    _write_commit_hdf5_group(
                        hdf5_file[key],
                        new_hdf5_file.create_group(key),
                        commit_weights[offsets_list[i]:offsets_list[i + 1]],
                        args.threshold_weights)

        for filename in os.listdir(commit_results_dir):
            shutil.move(os.path.join(commit_results_dir, filename),
                        args.out_dir)

        # Save split tractogram (essential/nonessential) and/or saving the
        # tractogram with data_per_streamline updated
        if args.keep_whole_tractogram or args.threshold_weights is not None:
            # Reload is needed because of COMMIT handling its file by itself
            tractogram_file = nib.streamlines.load(args.in_tractogram)
            tractogram = tractogram_file.tractogram
            tractogram.data_per_streamline['commit_weights'] = commit_weights

            if args.threshold_weights is not None:
                essential_ind = np.where(
                    commit_weights > args.threshold_weights)[0]
                nonessential_ind = np.where(
                    commit_weights <= args.threshold_weights)[0]
                logging.debug('{} essential streamlines were kept at '
                              'threshold {}'.format(len(essential_ind),
                                                    args.threshold_weights))
                logging.debug('{} nonessential streamlines were kept at '
                              'threshold {}'.format(len(nonessential_ind),
                                                    args.threshold_weights))

                # TODO PR when Dipy 1.2 is out with sft slicing
                essential_streamlines = tractogram.streamlines[essential_ind]
                essential_dps = tractogram.data_per_streamline[essential_ind]
                essential_dpp = tractogram.data_per_point[essential_ind]
                essential_tractogram = Tractogram(
                    essential_streamlines,
                    data_per_point=essential_dpp,
                    data_per_streamline=essential_dps,
                    affine_to_rasmm=np.eye(4))

                nonessential_streamlines = \
                    tractogram.streamlines[nonessential_ind]
                nonessential_dps = \
                    tractogram.data_per_streamline[nonessential_ind]
                nonessential_dpp = tractogram.data_per_point[nonessential_ind]
                nonessential_tractogram = Tractogram(
                    nonessential_streamlines,
                    data_per_point=nonessential_dpp,
                    data_per_streamline=nonessential_dps,
                    affine_to_rasmm=np.eye(4))

                nib.streamlines.save(essential_tractogram,
                                     os.path.join(args.out_dir,
                                                  'essential_tractogram.trk'),
                                     header=tractogram_file.header)
                nib.streamlines.save(nonessential_tractogram,
                                     os.path.join(
                                         args.out_dir,
                                         'nonessential_tractogram.trk'),
                                     header=tractogram_file.header)
            if args.keep_whole_tractogram:
                output_filename = os.path.join(args.out_dir, 'tractogram.trk')
                logging.debug('Saving tractogram with weights as {}'.format(
                    output_filename))
                nib.streamlines.save(tractogram_file, output_filename)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()


def _write_commit_hdf5_group(old_group, new_group, group_weights, threshold):
    """Copy one connection group into ``new_group`` together with its weights.

    :param old_group: source HDF5 group with 'data'/'offsets'/'lengths' and
        optional extra datasets.
    :param new_group: empty destination HDF5 group.
    :param group_weights: COMMIT weights of this group's streamlines, in the
        same order as the streamlines stored in ``old_group``.
    :param threshold: ``None`` to copy every streamline unchanged, otherwise
        only streamlines whose weight is strictly above it are kept; the
        weights (and per-streamline extra datasets) are subset the same way
        so they stay aligned with the kept streamlines.
    """
    nbr_in_group = len(group_weights)
    if threshold is None:
        kept_ind = None
        for name in ('data', 'offsets', 'lengths'):
            new_group.create_dataset(name, data=old_group[name][()])
    else:
        kept_ind = np.where(group_weights > threshold)[0]
        kept_streamlines = reconstruct_streamlines(old_group['data'],
                                                   old_group['offsets'],
                                                   old_group['lengths'],
                                                   indices=kept_ind)
        new_group.create_dataset('data',
                                 data=kept_streamlines.get_data(),
                                 dtype=np.float32)
        new_group.create_dataset('offsets',
                                 data=kept_streamlines._offsets,
                                 dtype=np.int64)
        new_group.create_dataset('lengths',
                                 data=kept_streamlines._lengths,
                                 dtype=np.int32)
        group_weights = group_weights[kept_ind]

    for dps_key in old_group.keys():
        # Geometry is written above; stale COMMIT weights from the input are
        # replaced by the freshly computed ones written below.
        if dps_key in ('data', 'offsets', 'lengths', 'commit_weights'):
            continue
        dps = old_group[dps_key][()]
        # NOTE(review): an extra dataset is treated as per-streamline when
        # its first dimension equals the group's streamline count; anything
        # else is copied unchanged.
        if (kept_ind is not None and np.ndim(dps) > 0
                and len(dps) == nbr_in_group):
            dps = dps[kept_ind]
        new_group.create_dataset(dps_key, data=dps)
    new_group.create_dataset('commit_weights', data=group_weights)