Exemplo n.º 1
0
def main(args=None):
    """
    Rewrite a bvecs file with one vector per line, as loaded by dipy.

    :param args: optional list of command-line arguments. When empty or None,
        sys.argv is parsed instead, and the help is shown if no argument was given.
    """
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        arguments = parser.parse_args(
            args=None if sys.argv[1:] else ['--help'])

    fname_in = arguments.bvec
    fname_out = arguments.o
    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level

    # get bvecs in proper orientation (bvals are not needed here)
    from dipy.io import read_bvals_bvecs
    _, bvecs = read_bvals_bvecs(None, fname_in)

    # Write new file. When no output name is given, it is rebuilt from the
    # input name via extract_fname().
    if fname_out == '':
        path_in, file_in, ext_in = extract_fname(fname_in)
        fname_out = path_in + file_in + ext_in
    # `with` guarantees the file is closed even if a write fails
    with open(fname_out, 'w') as fid:
        for row in bvecs:
            fid.write(' '.join(str(i) for i in row) + '\n')

    # display message
    printv('Created file:\n--> ' + fname_out + '\n', verbose, 'info')
Exemplo n.º 2
0
def main(argv=None):
    """
    Compute DTI metrics from a diffusion-weighted image.

    :param argv: list of command-line arguments passed to ``parser.parse_args``
        (``None`` means ``sys.argv[1:]``).
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # Get parser info. Arguments are parsed exactly once, from `argv`: the
    # former second parse from sys.argv silently ignored the caller's `argv`.
    fname_in = arguments.i
    fname_bvals = arguments.bval
    fname_bvecs = arguments.bvec
    prefix = arguments.o
    method = arguments.method
    evecs = arguments.evecs
    # '' is passed to compute_dti when no mask was given
    file_mask = arguments.m if arguments.m is not None else ''

    # compute DTI
    if not compute_dti(fname_in, fname_bvals, fname_bvecs, prefix, method,
                       evecs, file_mask, verbose):
        printv('ERROR in compute_dti()', 1, 'error')
def main(argv=None):
    """
    Rewrite a bvecs file with one vector per line, as loaded by dipy.

    :param argv: list of command-line arguments (``None`` means ``sys.argv[1:]``).
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    fname_in = arguments.bvec
    fname_out = arguments.o

    # get bvecs in proper orientation (bvals are not needed here)
    from dipy.io import read_bvals_bvecs
    _, bvecs = read_bvals_bvecs(None, fname_in)

    # Write new file. When no output name is given, it is rebuilt from the
    # input name via extract_fname().
    if fname_out == '':
        path_in, file_in, ext_in = extract_fname(fname_in)
        fname_out = path_in + file_in + ext_in
    # `with` guarantees the file is closed even if a write fails
    with open(fname_out, 'w') as fid:
        for row in bvecs:
            fid.write(' '.join(str(i) for i in row) + '\n')

    # display message
    printv('Created file:\n--> ' + fname_out + '\n', verbose, 'info')
def main(argv=None):
    """
    Compute MTsat and T1 maps from MT-, PD- and T1-weighted images.

    Repetition times and flip angles that were not given on the command line
    are looked up with fetch_metadata() in the JSON file returned by
    get_json_file_name() for the corresponding image.

    :param argv: list of command-line arguments; when empty or None, the help is shown.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv if argv else ['--help'])
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    printv('Load data...', verbose)
    nii_mt = Image(arguments.mt)
    nii_pd = Image(arguments.pd)
    nii_t1 = Image(arguments.t1)
    nii_b1map = None if arguments.b1map is None else Image(arguments.b1map)

    # (argument to fill, argument holding the image path, metadata field),
    # listed in the order the lookups are performed
    metadata_fallbacks = (
        ('trmt', 'mt', 'RepetitionTime'),
        ('trpd', 'pd', 'RepetitionTime'),
        ('trt1', 't1', 'RepetitionTime'),
        ('famt', 'mt', 'FlipAngle'),
        ('fapd', 'pd', 'FlipAngle'),
        ('fat1', 't1', 'FlipAngle'),
    )
    for dest, source, field in metadata_fallbacks:
        if getattr(arguments, dest) is None:
            json_path = get_json_file_name(getattr(arguments, source), check_exist=True)
            setattr(arguments, dest, fetch_metadata(json_path, field))

    # compute MTsat
    nii_mtsat, nii_t1map = compute_mtsat(
        nii_mt, nii_pd, nii_t1,
        arguments.trmt, arguments.trpd, arguments.trt1,
        arguments.famt, arguments.fapd, arguments.fat1,
        nii_b1map=nii_b1map,
    )

    # Output MTsat and T1 maps
    printv('Generate output files...', verbose)
    nii_mtsat.save(arguments.omtsat)
    nii_t1map.save(arguments.ot1map)

    display_viewer_syntax(
        [arguments.omtsat, arguments.ot1map],
        colormaps=['gray', 'gray'],
        minmax=['-10,10', '0, 3'],
        opacities=['1', '1'],
        verbose=verbose,
    )
Exemplo n.º 5
0
def convert(fname_in, fname_out, squeeze_data=True, dtype=None, verbose=1):
    """
    Load an image, convert it with ``image.convert`` and save it to a new path.

    :param fname_in: path of the input image.
    :param fname_out: path where the converted image is written.
    :param squeeze_data: forwarded to ``image.convert``.
    :param dtype: forwarded to ``image.convert``.
    :param verbose: verbosity level, used for the echoed command and for saving.
    :return: None. (An earlier docstring advertised True/False, but nothing is returned.)
    """
    # Echo the equivalent command line
    printv('sct_convert -i ' + fname_in + ' -o ' + fname_out, verbose, 'code')

    converted = image.convert(image.Image(fname_in),
                              squeeze_data=squeeze_data,
                              dtype=dtype)
    converted.save(fname_out, mutable=True, verbose=verbose)
Exemplo n.º 6
0
def fs_ok(sig_a, sig_b, exclude=()):
    """
    Compare two path -> data mappings and fail if paths were added or modified.

    Paths present in ``sig_a`` but missing from ``sig_b`` are not reported.

    :param sig_a: reference mapping (path -> data).
    :param sig_b: mapping checked against ``sig_a``.
    :param exclude: prefix, or tuple of prefixes, passed to ``str.startswith``;
        differences on matching paths are ignored.
    :raises RuntimeError: if any non-excluded path of ``sig_b`` is absent from
        ``sig_a`` or maps to different data. The message lists the differences.
    """
    errors = []
    for path, data in sig_b.items():
        if path not in sig_a:
            errors.append((path, "added: {}".format(path)))
        elif sig_a[path] != data:
            errors.append((path, "modified: {}".format(path)))
    # Filter after collecting so `startswith` is only called on differing paths
    messages = [msg for (path, msg) in errors if not path.startswith(exclude)]
    if messages:
        for msg in messages:
            # Interpolate the actual message (the literal "%s" used to be printed)
            printv("Error: %s" % msg, 1, type='error')
        raise RuntimeError("; ".join(messages))
Exemplo n.º 7
0
def downloaddata(param):
    """
    Fetch the SCT testing dataset via ``sct_download_data``.

    Parameters
    ----------
    param
        Settings object; only its ``verbose`` attribute is read here.

    Returns
    -------
    None
    """
    verbosity = param.verbose
    printv('\nDownloading testing data...', verbosity)
    download_args = ['-d', 'sct_testing_data']
    sct_download_data.main(download_args)
def main(argv=None):
    """
    Resample an image by a factor (-f), to a size in mm (-mm), to a number of
    voxels (-vox), or onto a reference image (-ref). Exactly one is required.

    :param argv: list of command-line arguments; when empty or None, the help is shown.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv if argv else ['--help'])
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)
    # Forward the requested verbosity to resample_file below. Previously the
    # value already stored in `param` was passed, whatever -v was.
    param.verbose = verbose

    param.fname_data = arguments.i

    # Count every resampling option first. With an if/elif chain only the first
    # one was counted, so extra options were silently ignored and the
    # "ONLY one" check could never trigger. parser.error() prints usage and exits.
    n_modes = sum(value is not None
                  for value in (arguments.f, arguments.mm, arguments.vox, arguments.ref))
    if n_modes == 0:
        parser.error(
            'ERROR: you need to specify one of those four arguments : -f, -mm, -vox or -ref')
    if n_modes > 1:
        parser.error(
            'ERROR: you need to specify ONLY one of those four arguments : -f, -mm, -vox or -ref')

    if arguments.f is not None:
        param.new_size = arguments.f
        param.new_size_type = 'factor'
    elif arguments.mm is not None:
        param.new_size = arguments.mm
        param.new_size_type = 'mm'
    elif arguments.vox is not None:
        param.new_size = arguments.vox
        param.new_size_type = 'vox'
    else:
        param.ref = arguments.ref

    if arguments.o is not None:
        param.fname_out = arguments.o
    if arguments.x is not None:
        # A single character is converted to an int; longer values are passed as-is
        if len(arguments.x) == 1:
            param.interpolation = int(arguments.x)
        else:
            param.interpolation = arguments.x

    spinalcordtoolbox.resampling.resample_file(param.fname_data,
                                               param.fname_out,
                                               param.new_size,
                                               param.new_size_type,
                                               param.interpolation,
                                               param.verbose,
                                               fname_ref=param.ref)
Exemplo n.º 9
0
def run_main():
    """
    Resample an image using arguments read from sys.argv (help is shown when none).

    Exactly one of -f (factor), -mm, -vox or -ref (reference image) is required.
    """
    parser = get_parser()
    arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])
    param.fname_data = arguments.i

    # Count every resampling option first. With an if/elif chain only the first
    # one was counted, so extra options were silently ignored and the
    # "ONLY one" check could never trigger. parser.error() prints usage and exits.
    n_modes = sum(value is not None
                  for value in (arguments.f, arguments.mm, arguments.vox, arguments.ref))
    if n_modes == 0:
        parser.error(
            'ERROR: you need to specify one of those four arguments : -f, -mm, -vox or -ref')
    if n_modes > 1:
        parser.error(
            'ERROR: you need to specify ONLY one of those four arguments : -f, -mm, -vox or -ref')

    if arguments.f is not None:
        param.new_size = arguments.f
        param.new_size_type = 'factor'
    elif arguments.mm is not None:
        param.new_size = arguments.mm
        param.new_size_type = 'mm'
    elif arguments.vox is not None:
        param.new_size = arguments.vox
        param.new_size_type = 'vox'
    else:
        param.ref = arguments.ref

    if arguments.o is not None:
        param.fname_out = arguments.o
    if arguments.x is not None:
        # A single character is converted to an int; longer values are passed as-is
        if len(arguments.x) == 1:
            param.interpolation = int(arguments.x)
        else:
            param.interpolation = arguments.x
    param.verbose = int(arguments.v)
    init_sct(log_level=param.verbose, update=True)  # Update log level

    spinalcordtoolbox.resampling.resample_file(param.fname_data,
                                               param.fname_out,
                                               param.new_size,
                                               param.new_size_type,
                                               param.interpolation,
                                               param.verbose,
                                               fname_ref=param.ref)
def main():
    """
    Compute the MSCC value from the -di, -da and -db arguments and print it.

    Arguments are read from sys.argv; the help is shown when none is given.
    """
    parser = get_parser()
    arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    result = mscc(arguments.di, arguments.da, arguments.db)

    # The result is always displayed (fixed verbosity of 1)
    message = '\nMSCC = ' + str(result) + '\n'
    printv(message, 1, 'info')
Exemplo n.º 11
0
def main(argv=None):
    """
    Compute an MTR map from the -mt1 and -mt0 images and save it (float32) to -o.

    :param argv: list of command-line arguments (``None`` means ``sys.argv[1:]``).
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    output_path = arguments.o

    printv('\nCompute MTR...', verbose)
    mt1_image = Image(arguments.mt1)
    mt0_image = Image(arguments.mt0)
    mtr_image = compute_mtr(nii_mt1=mt1_image, nii_mt0=mt0_image,
                            threshold_mtr=arguments.thr)

    mtr_image.save(output_path, dtype='float32')

    display_viewer_syntax([arguments.mt0, arguments.mt1, output_path])
def main(argv=None):
    """
    Compute the MSCC value from the -di, -da and -db arguments and display it.

    :param argv: list of command-line arguments (``None`` means ``sys.argv[1:]``).
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    result = mscc(arguments.di, arguments.da, arguments.db)

    message = '\nMSCC = ' + str(result) + '\n'
    printv(message, verbose, 'info')
def main():
    """
    Compute an MTR map from the -mt1 and -mt0 images and save it (float32) to -o.

    Arguments are read from sys.argv; the help is shown when none is given.
    """
    parser = get_parser()
    cli = parser.parse_args(args=None if sys.argv[1:] else ['--help'])
    output_path = cli.o
    verbose = cli.v

    printv('\nCompute MTR...', verbose)
    mt1_image = Image(cli.mt1)
    mt0_image = Image(cli.mt0)
    mtr_image = compute_mtr(nii_mt1=mt1_image, nii_mt0=mt0_image,
                            threshold_mtr=cli.thr)

    mtr_image.save(output_path, dtype='float32')

    display_viewer_syntax([cli.mt0, cli.mt1, output_path])
def main(argv=None):
    """
    Concatenate b0/DWI NIfTI files along the 4th dimension and write the
    matching concatenated bval and bvec files.

    :param argv: list of command-line arguments (``None`` means ``sys.argv[1:]``).
    :raises Exception: if ``-i`` and ``-order`` (or ``-bval`` and ``-bvec``)
        differ in length.
    :raises ValueError: if an ``-order`` item is neither ``'b0'`` nor ``'dwi'``.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # check number of input args
    if len(arguments.i) != len(arguments.order):
        raise Exception(
            "Number of items between flags '-i' and '-order' should be the same."
        )
    if len(arguments.bval) != len(arguments.bvec):
        raise Exception(
            "Number of files for bval and bvec should be the same.")

    # Concatenate NIFTI files
    im_list = [Image(fname) for fname in arguments.i]
    im_concat = concat_data(im_list, dim=3, squeeze_data=False)
    im_concat.save(arguments.o)
    printv("Generated file: {}".format(arguments.o))

    # Concatenate bvals and bvecs
    bvals_concat = ''
    bvecs_concat = ['', '', '']  # one line per bvec component
    i_dwi = 0  # counter for DWI files, to read in bvec/bval files
    for fname, item in zip(arguments.i, arguments.order):
        if item == 'b0':
            # one zero bval and one null bvec per volume (size of the 4th dim)
            n_b0 = Image(fname).dim[3]
            bval = np.array([0.0] * n_b0)
            bvec = np.array([[0.0, 0.0, 0.0]] * n_b0)
        elif item == 'dwi':
            # read bval/bvec files
            bval, bvec = read_bvals_bvecs(arguments.bval[i_dwi],
                                          arguments.bvec[i_dwi])
            i_dwi += 1
        else:
            # Previously an unknown item silently reused the previous bval/bvec
            # (or crashed with a NameError if it was the first item).
            raise ValueError(
                "Unknown item '{}' in '-order': expected 'b0' or 'dwi'.".format(item))
        # Concatenate bvals
        bvals_concat += ' '.join(str(v) for v in bval) + ' '
        # Concatenate bvecs, column by column
        for i in (0, 1, 2):
            bvecs_concat[i] += ' '.join('%.16f' % n for n in bvec[:, i]) + ' '
    bvecs_concat = '\n'.join(bvecs_concat)  # transform list into lines of strings

    # Write files (`with` closes them even if a write fails)
    with open(arguments.obval, 'w') as new_f:
        new_f.write(bvals_concat)
    printv("Generated file: {}".format(arguments.obval))
    with open(arguments.obvec, 'w') as new_f:
        new_f.write(bvecs_concat)
    printv("Generated file: {}".format(arguments.obvec))
Exemplo n.º 15
0
def main(args=None):
    """
    Compute DTI metrics from a diffusion-weighted image.

    :param args: optional list of command-line arguments. When empty or None,
        sys.argv is parsed instead, and the help is shown if no argument was
        given. (Previously `args` was ignored and sys.argv was always parsed.)
    """
    # Get parser info
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])
    fname_in = arguments.i
    fname_bvals = arguments.bval
    fname_bvecs = arguments.bvec
    prefix = arguments.o
    method = arguments.method
    evecs = arguments.evecs
    # '' is passed to compute_dti when no mask was given
    file_mask = arguments.m if arguments.m is not None else ''
    param.verbose = arguments.v
    init_sct(log_level=param.verbose, update=True)  # Update log level

    # compute DTI
    if not compute_dti(fname_in, fname_bvals, fname_bvecs, prefix, method,
                       evecs, file_mask):
        printv('ERROR in compute_dti()', 1, 'error')
Exemplo n.º 16
0
def main(script_name=''):
    """
    Check that a script can be imported and that its ``Param`` class is not in debug mode.

    The module name is taken from the ``-i`` option; ``-h`` calls ``usage()``.
    Exits with 2 when the module cannot be imported or when ``Param().debug == 1``,
    and with 0 when the module has no ``Param`` class.

    :param script_name: default module name, used when ``-i`` is not given
        (previously this argument was overwritten with '').
    """
    try:
        opts, _ = getopt.getopt(sys.argv[1:], 'hi:')
    except getopt.GetoptError:
        usage()
        # NOTE(review): usage() presumably exits; this guarantees we never reach
        # the loop below with `opts` undefined.
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            usage()
        elif opt == '-i':  # exact match (`opt in '-i'` was a substring test)
            script_name = arg

    try:
        script_tested = importlib.import_module(script_name)
        print(script_tested)
    except IOError:
        printv(
            "\nException caught: IOerror, can not import " + script_name +
            '\n', 1, 'warning')
        sys.exit(2)
    except ImportError as e:
        printv("\nException caught: ImportError in " + script_name + '\n', 1,
               'warning')
        print(e)
        sys.exit(2)
    else:
        try:
            sct = script_tested.Param()
        except AttributeError:
            print('\nno class param found in script ' + script_name + '\n')
            sys.exit(0)
        else:
            if hasattr(sct, 'debug'):
                if sct.debug == 1:
                    print('\nWarning debug mode on in script ' + script_name +
                          '\n')
                    sys.exit(2)
Exemplo n.º 17
0
def display_voxel(img: Image, verbose: int = 1) -> Sequence[Coordinate]:
    """
    Display all the labels that are contained in the input image.

    Each non-zero voxel is printed with its position and value. A final line
    then lists all coordinates joined with ':' (the "useful syntax").

    :param img: source image
    :param verbose: verbosity level
    :return: the coordinates returned by ``img.getNonZeroCoordinates(sorting='value')``
        (previously nothing was returned, despite the annotation).
    """
    coordinates_input = img.getNonZeroCoordinates(sorting='value')

    label_strings = []
    for coord in coordinates_input:
        printv('Position=(' + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ') -- Value= ' + str(coord.value), verbose=verbose)
        label_strings.append(str(coord))

    printv('All labels (useful syntax):', verbose=verbose)
    printv(':'.join(label_strings), verbose=verbose)
    return coordinates_input
Exemplo n.º 18
0
def main(argv: Sequence[str]):
    """
    Main function. When this script is run via CLI, the main function is called using main(sys.argv[1:]).

    The first matching labelling option in the if/elif chain below is applied.
    The result is saved to ``-o``, except for the display and MSE options, which
    return early after printing.

    :param argv: A list of unparsed arguments, which is passed to ArgumentParser.parse_args()
    :raises DeprecationWarning: if the deprecated ``-create-seg -1,...`` syntax is used.
    """
    # Reject the deprecated '-1' shorthand before argparse sees it
    for i, arg in enumerate(argv):
        if arg == '-create-seg' and len(argv) > i+1 and '-1,' in argv[i+1]:
            raise DeprecationWarning("The use of '-1' for '-create-seg' has been deprecated. Please use "
                                     "'-create-seg-mid' instead.")

    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    input_filename = arguments.i
    output_fname = arguments.o

    img = Image(input_filename)
    dtype = None  # only overridden for continuous vertebral levels

    if arguments.add is not None:
        value = arguments.add
        out = sct_labels.add(img, value)
    elif arguments.create is not None:
        labels = arguments.create
        out = sct_labels.create_labels_empty(img, labels)
    elif arguments.create_add is not None:
        labels = arguments.create_add
        out = sct_labels.create_labels(img, labels)
    elif arguments.create_seg is not None:
        labels = arguments.create_seg
        out = sct_labels.create_labels_along_segmentation(img, labels)
    elif arguments.create_seg_mid is not None:
        labels = [(-1, arguments.create_seg_mid)]
        out = sct_labels.create_labels_along_segmentation(img, labels)
    elif arguments.cubic_to_point:
        out = sct_labels.cubic_to_point(img)
    elif arguments.display:
        display_voxel(img, verbose)
        return
    elif arguments.increment:
        out = sct_labels.increment_z_inverse(img)
    elif arguments.disc is not None:
        ref = Image(arguments.disc)
        out = sct_labels.labelize_from_discs(img, ref)
    elif arguments.vert_body is not None:
        levels = arguments.vert_body
        if len(levels) == 1 and levels[0] == 0:
            levels = None  # all levels
        out = sct_labels.label_vertebrae(img, levels)
    elif arguments.vert_continuous:
        out = sct_labels.continuous_vertebral_levels(img)
        dtype = 'float32'
    elif arguments.MSE is not None:
        ref = Image(arguments.MSE)
        mse = sct_labels.compute_mean_squared_error(img, ref)
        printv(f"Computed MSE: {mse}")
        return
    elif arguments.remove_reference is not None:
        ref = Image(arguments.remove_reference)
        out = sct_labels.remove_missing_labels(img, ref)
    elif arguments.remove_sym is not None:
        # first pass use img as source.
        # The reference is given by remove_sym. `arguments.remove_reference` was
        # read here before, but it is necessarily None in this branch.
        ref = Image(arguments.remove_sym)
        out = sct_labels.remove_missing_labels(img, ref)

        # second pass use previous pass result as reference; overwrites the reference file
        ref_out = sct_labels.remove_missing_labels(ref, out)
        ref_out.save(path=ref.absolutepath)
    elif arguments.remove is not None:
        labels = arguments.remove
        out = sct_labels.remove_labels_from_image(img, labels)
    elif arguments.keep is not None:
        labels = arguments.keep
        out = sct_labels.remove_other_labels_from_image(img, labels)
    elif arguments.create_viewer is not None:
        msg = "" if arguments.msg is None else f"{arguments.msg}\n"
        if arguments.ilabel is not None:
            input_labels_img = Image(arguments.ilabel)
            out = launch_manual_label_gui(img, input_labels_img, parse_num_list(arguments.create_viewer), msg)
        else:
            out = launch_sagittal_viewer(img, parse_num_list(arguments.create_viewer), msg)

    printv("Generating output files...")
    out.save(path=output_fname, dtype=dtype)
    display_viewer_syntax([input_filename, output_fname])

    if arguments.qc is not None:
        generate_qc(fname_in1=input_filename, fname_seg=output_fname, args=argv,
                    path_qc=os.path.abspath(arguments.qc), dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject, process='sct_label_utils')
Exemplo n.º 19
0
def main(argv=None):
    """
    Compute the signal-to-noise ratio (SNR) of a 4D image (oriented to RPI).

    Methods:
      - 'mult': mean / std across the selected volumes. Writes an SNR map and a
        mask of voxels with non-zero std. If a mask is given, also computes the
        mean/std ratio inside it.
      - 'diff': exactly two volumes. The mean in the mask is divided by the std
        of their difference (mask required).

    :param argv: list of command-line arguments (``None`` means ``sys.argv[1:]``).
    :raises ValueError: if 'diff' is used with a number of volumes other than 2.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # Default params
    param = Param()

    # Get parser info
    fname_data = arguments.i
    if arguments.m is not None:
        fname_mask = arguments.m
    else:
        fname_mask = ''
    method = arguments.method
    if arguments.vol is not None:
        index_vol_user = arguments.vol
    else:
        index_vol_user = ''

    # Check parameters. parser.error() exits explicitly. A printv(..., 'error')
    # report is not guaranteed to stop execution, and the 'diff' computation
    # below would then fail on the undefined `mask`.
    if method == 'diff':
        if not fname_mask:
            parser.error('You need to provide a mask with -method diff. Exit.')

    # Load data and orient to RPI
    im_data = Image(fname_data).change_orientation('RPI')
    data = im_data.data
    if fname_mask:
        mask = Image(fname_mask).change_orientation('RPI').data

    # Retrieve selected volumes (default: all volumes along the 4th dimension)
    if index_vol_user:
        index_vol = parse_num_list(index_vol_user)
    else:
        index_vol = range(data.shape[3])

    # Make sure user selected 2 volumes with diff method; stop otherwise
    if method == 'diff':
        if not len(index_vol) == 2:
            raise ValueError(
                'Method "diff" should be used with exactly two volumes (specify with flag "-vol").')

    # Compute SNR
    # NB: "time" is assumed to be the 4th dimension of the variable "data"
    if method == 'mult':
        # Compute mean and STD across time
        data_mean = np.mean(data[:, :, :, index_vol], axis=3)
        data_std = np.std(data[:, :, :, index_vol], axis=3, ddof=1)
        # Generate mask where std is different from 0
        mask_std_nonzero = np.where(data_std > param.almost_zero)
        snr_map = np.zeros_like(data_mean)
        snr_map[mask_std_nonzero] = data_mean[mask_std_nonzero] / data_std[
            mask_std_nonzero]
        # Output SNR map
        fname_snr = add_suffix(fname_data, '_SNR-' + method)
        im_snr = empty_like(im_data)
        im_snr.data = snr_map
        im_snr.save(fname_snr, dtype=np.float32)
        # Output non-zero mask
        fname_stdnonzero = add_suffix(fname_data, '_mask-STD-nonzero' + method)
        im_stdnonzero = empty_like(im_data)
        data_stdnonzero = np.zeros_like(data_mean)
        data_stdnonzero[mask_std_nonzero] = 1
        im_stdnonzero.data = data_stdnonzero
        im_stdnonzero.save(fname_stdnonzero, dtype=np.float32)
        # Compute SNR in ROI (mask-weighted mean / mask-weighted std)
        if fname_mask:
            mean_in_roi = np.average(data_mean[mask_std_nonzero],
                                     weights=mask[mask_std_nonzero])
            std_in_roi = np.average(data_std[mask_std_nonzero],
                                    weights=mask[mask_std_nonzero])
            snr_roi = mean_in_roi / std_in_roi

    elif method == 'diff':
        data_2vol = np.take(data, index_vol, axis=3)
        # Compute mean in ROI
        data_mean = np.mean(data_2vol, axis=3)
        mean_in_roi = np.average(data_mean, weights=mask)
        data_sub = np.subtract(data_2vol[:, :, :, 1], data_2vol[:, :, :, 0])
        _, std_in_roi = weighted_avg_and_std(data_sub, mask)
        # Compute SNR, correcting for Rayleigh noise (see eq. 7 in Dietrich et al.)
        snr_roi = (2 / np.sqrt(2)) * mean_in_roi / std_in_roi

    # Display result
    if fname_mask:
        printv('\nSNR_' + method + ' = ' + str(snr_roi) + '\n', type='info')
Exemplo n.º 20
0
def main(argv=None):
    """
    Compute the signal-to-noise ratio (SNR) of an image.

    Methods:
      - 'mult': mean / std across the selected volumes (4D input). Writes an SNR
        map and a mask of voxels with non-zero std. If a mask is given, also
        computes the mask-weighted mean of the SNR map.
      - 'diff': two volumes (4D input). For each non-empty mask slice, the mean
        is divided by the std of their difference scaled by 1/sqrt(2); the
        slice-wise values are then averaged.
      - 'single': one volume (3D, or 4D with one selected volume). For each
        slice, the mean in the mask is divided by the std in the noise mask,
        averaged over slices, optionally with a Rayleigh correction.

    If ``-o`` is given, the ROI SNR value is written to that text file.

    :param argv: list of command-line arguments (``None`` means ``sys.argv[1:]``).
    :raises ValueError: on inconsistent dimensions, volume selection, or mask slices.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    # Default params
    param = Param()

    # Get parser info
    fname_data = arguments.i
    fname_mask = arguments.m
    fname_mask_noise = arguments.m_noise
    method = arguments.method
    file_name = arguments.o
    rayleigh_correction = arguments.rayleigh

    # Check parameters (parser.error() exits by itself; `raise` is only defensive)
    if method in ['diff', 'single']:
        if not fname_mask:
            raise parser.error(
                f"Argument '-m' must be specified when using '-method {method}'."
            )

    # Load data
    im_data = Image(fname_data)
    data = im_data.data
    dim = len(data.shape)
    nz = data.shape[2]
    if fname_mask:
        mask = Image(fname_mask).data

    # Check dimensionality
    if method in ['diff', 'mult']:
        if dim != 4:
            raise ValueError(
                f"Input data dimension: {dim}. Input dimension for this method should be 4."
            )
    if method in ['single']:
        if dim not in [3, 4]:
            raise ValueError(
                f"Input data dimension: {dim}. Input dimension for this method should be 3 or 4."
            )

    # Check dimensionality of mask
    if fname_mask:
        if len(mask.shape) != 3:
            raise ValueError(
                f"Mask should be a 3D image, but the input mask has shape '{mask.shape}'."
            )

    # Retrieve selected volumes (defaults depend on the method)
    index_vol = parse_num_list(arguments.vol)
    if not index_vol:
        if method == 'mult':
            index_vol = range(data.shape[3])
        elif method == 'diff':
            index_vol = [0, 1]
        elif method == 'single':
            index_vol = [0]

    # Single SNR value in the ROI. It stays None when it cannot be computed
    # ('mult' without a mask).
    snr_roi = None

    # Compute SNR
    # NB: "time" is assumed to be the 4th dimension of the variable "data"
    if method == 'mult':
        # Compute mean and STD across time
        data_mean = np.mean(data[:, :, :, index_vol], axis=3)
        data_std = np.std(data[:, :, :, index_vol], axis=3, ddof=1)
        # Generate mask where std is different from 0
        mask_std_nonzero = np.where(data_std > param.almost_zero)
        snr_map = np.zeros_like(data_mean)
        snr_map[mask_std_nonzero] = data_mean[mask_std_nonzero] / data_std[
            mask_std_nonzero]
        # Output SNR map
        fname_snr = add_suffix(fname_data, '_SNR-' + method)
        im_snr = empty_like(im_data)
        im_snr.data = snr_map
        im_snr.save(fname_snr, dtype=np.float32)
        # Output non-zero mask
        fname_stdnonzero = add_suffix(fname_data, '_mask-STD-nonzero' + method)
        im_stdnonzero = empty_like(im_data)
        data_stdnonzero = np.zeros_like(data_mean)
        data_stdnonzero[mask_std_nonzero] = 1
        im_stdnonzero.data = data_stdnonzero
        im_stdnonzero.save(fname_stdnonzero, dtype=np.float32)
        # Compute SNR in ROI
        if fname_mask:
            snr_roi = np.average(snr_map[mask_std_nonzero],
                                 weights=mask[mask_std_nonzero])

    elif method == 'diff':
        # Check user selected exactly 2 volumes for this method.
        if not len(index_vol) == 2:
            raise ValueError(
                f"Number of selected volumes: {len(index_vol)}. The method 'diff' should be used with "
                f"exactly 2 volumes. You can specify the number of volumes with the flag '-vol'."
            )
        data_2vol = np.take(data, index_vol, axis=3)
        # Compute mean across the two volumes
        data_mean = np.mean(data_2vol, axis=3)
        # Compute mean in ROI for each z-slice, if the slice in the mask is not null
        mean_in_roi = [
            np.average(data_mean[..., iz], weights=mask[..., iz])
            for iz in range(nz) if np.any(mask[..., iz])
        ]
        data_sub = np.subtract(data_2vol[:, :, :, 1], data_2vol[:, :, :, 0])
        # Compute STD in the ROI for each z-slice. The "np.sqrt(2)" results from the variance of the subtraction of two
        # distributions: var(A-B) = var(A) + var(B).
        # More context in: https://github.com/spinalcordtoolbox/spinalcordtoolbox/issues/3481
        std_in_roi = [
            weighted_std(data_sub[..., iz] / np.sqrt(2), weights=mask[..., iz])
            for iz in range(nz) if np.any(mask[..., iz])
        ]
        # Compute SNR
        snr_roi_slicewise = [m / s for m, s in zip(mean_in_roi, std_in_roi)]
        snr_roi = sum(snr_roi_slicewise) / len(snr_roi_slicewise)

    elif method == 'single':
        # Check that the input volume is 3D, or if it is 4D, that the user selected exactly 1 volume for this method.
        if dim == 3:
            data3d = data
        elif dim == 4:
            if not len(index_vol) == 1:
                raise ValueError(
                    f"Selected volumes: {index_vol}. The method 'single' should be used with "
                    f"exactly 1 volume. You can specify the index of the volume with the flag '-vol'."
                )
            data3d = np.squeeze(data[..., index_vol])
        # Check that input noise mask is provided
        if fname_mask_noise:
            mask_noise = Image(fname_mask_noise).data
        else:
            raise parser.error(
                "A noise mask is mandatory with '-method single'.")
        # Check dimensionality of the noise mask (report the noise mask's own
        # dimension, not the data's)
        if len(mask_noise.shape) != 3:
            raise ValueError(
                f"Input noise mask dimension: {len(mask_noise.shape)}. Input dimension for the noise mask should be 3."
            )
        # Check that non-null slices are consistent between mask and mask_noise.
        for iz in range(nz):
            if not np.any(mask[..., iz]) == np.any(mask_noise[..., iz]):
                raise ValueError(
                    f"Slice {iz} is empty in either mask or mask_noise. Non-null slices should be "
                    f"consistent between mask and mask_noise.")
        # Compute mean in ROI for each z-slice, if the slice in the mask is not null
        mean_in_roi = [
            np.average(data3d[..., iz], weights=mask[..., iz])
            for iz in range(nz) if np.any(mask[..., iz])
        ]
        std_in_roi = [
            weighted_std(data3d[..., iz], weights=mask_noise[..., iz])
            for iz in range(nz) if np.any(mask_noise[..., iz])
        ]
        # Compute SNR
        snr_roi_slicewise = [m / s for m, s in zip(mean_in_roi, std_in_roi)]
        snr_roi = sum(snr_roi_slicewise) / len(snr_roi_slicewise)
        if rayleigh_correction:
            # Correcting for Rayleigh noise (see eq. A12 in Dietrich et al.)
            snr_roi *= np.sqrt((4 - np.pi) / 2)

    # Display result
    if fname_mask:
        printv('\nSNR_' + method + ' = ' + str(snr_roi) + '\n', type='info')

    # Added function for text file
    if file_name is not None:
        if snr_roi is None:
            # e.g. '-method mult' without '-m': only the SNR map exists. This
            # used to crash with a NameError on `snr_roi`.
            printv('\nNo SNR value was computed in a ROI (no mask provided with -m): '
                   + file_name + ' was not written.', type='warning')
        else:
            with open(file_name, "w") as f:
                f.write(str(snr_roi))
            printv('\nFile saved to ' + file_name)
Exemplo n.º 21
0
def main(argv=None):
    """
    Run the SCT testing functions and exit with their aggregated status.

    Parses the command line, downloads the testing data when the data folder is
    missing or a download is requested, then runs the selected testing
    functions from inside the data folder: non-parallelizable functions
    sequentially, and parallelizable ones through a ``multiprocessing.Pool``
    when more than one job is requested. Each function gets a status
    (0 = OK, 1 = FAIL, 99 = WARNING); the process exits with code 1 if any
    status is non-zero, 0 otherwise.

    :param argv: list of command-line arguments (None: argparse reads sys.argv).
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.verbose
    set_global_loglevel(verbose=verbose)

    # initializations
    param = Param()

    param.download = int(arguments.download)
    param.path_data = arguments.path
    functions_to_test = arguments.function
    param.remove_tmp_file = int(arguments.remove_temps)
    jobs = arguments.jobs

    param.verbose = verbose

    start_time = time.time()

    # get absolute path to the testing data
    param.path_data = os.path.abspath(param.path_data)

    # download testing data if the folder is missing or a download was requested
    if not os.path.isdir(param.path_data) or param.download:
        downloaddata(param)

    # display path to data
    printv('\nPath to testing data: ' + param.path_data, param.verbose)

    # create temp folder that will have all results
    path_tmp = os.path.abspath(arguments.execution_folder or tmp_create())

    # go in path data (where all scripts will be run)
    curdir = os.getcwd()
    os.chdir(param.path_data)

    functions_parallel = list()
    functions_serial = list()
    if functions_to_test:
        for f in functions_to_test:
            if f in get_functions_parallelizable():
                functions_parallel.append(f)
            elif f in get_functions_nonparallelizable():
                functions_serial.append(f)
            else:
                printv(
                    'Command-line usage error: Function "%s" is not part of the list of testing functions'
                    % f,
                    type='error')
        # never request more workers than there are parallel functions
        jobs = min(jobs, len(functions_parallel))
    else:
        functions_parallel = get_functions_parallelizable()
        functions_serial = get_functions_nonparallelizable()

    if arguments.continue_from:
        # resume the run at the requested function; when it is a parallel
        # function, the serial group is skipped entirely
        first_func = arguments.continue_from
        if first_func in functions_parallel:
            functions_serial = []
            functions_parallel = functions_parallel[functions_parallel.index(first_func):]
        elif first_func in functions_serial:
            functions_serial = functions_serial[functions_serial.index(first_func):]

    if arguments.check_filesystem and jobs != 1:
        print("Check filesystem used -> jobs forced to 1")
        jobs = 1

    print("Will run through the following tests:")
    if functions_serial:
        print("- sequentially: {}".format(" ".join(functions_serial)))
    if functions_parallel:
        print("- in parallel with {} jobs: {}".format(
            jobs, " ".join(functions_parallel)))

    list_status = []
    for name, functions in (
        ("serial", functions_serial),
        ("parallel", functions_parallel),
    ):
        if not functions:
            continue

        if any(s for (_, s) in list_status) and arguments.abort_on_failure:
            break

        # Dispatch on the group name: `functions == functions_parallel`
        # compares list contents rather than which group is being run.
        use_pool = name == "parallel" and jobs != 1
        # Bound before the `try` so the `finally` clause cannot raise
        # UnboundLocalError if creating the pool itself fails.
        pool = None
        try:
            results = list()
            if use_pool:
                pool = multiprocessing.Pool(processes=jobs)
                # submit every test up front; results are collected in order below
                for f in functions:
                    func_param = copy.deepcopy(param)
                    func_param.path_output = f
                    results.append(
                        pool.apply_async(process_function_multiproc, (f, func_param)))

            for idx_function, f in enumerate(functions):
                print_line('Checking ' + f)
                if use_pool:
                    res = results[idx_function].get()
                else:
                    if arguments.check_filesystem:
                        # remove any previous output of this function, then
                        # snapshot the temp tree before running it
                        if os.path.exists(os.path.join(path_tmp, f)):
                            shutil.rmtree(os.path.join(path_tmp, f))
                        sig_0 = fs_signature(path_tmp)

                    func_param = copy.deepcopy(param)
                    func_param.path_output = f

                    res = process_function(f, func_param)

                    if arguments.check_filesystem:
                        # compare snapshots, ignoring this function's own folder
                        sig_1 = fs_signature(path_tmp)
                        fs_ok(sig_0, sig_1, exclude=(f, ))

                list_output, list_status_function = res
                # manage status: 1 anywhere -> FAIL, other non-zero -> WARNING
                if any(list_status_function):
                    if 1 in list_status_function:
                        print_fail()
                        status = (f, 1)
                    else:
                        print_warning()
                        status = (f, 99)
                    for output in list_output:
                        for line in output.splitlines():
                            print("   %s" % line)
                else:
                    print_ok()
                    if param.verbose:
                        for output in list_output:
                            for line in output.splitlines():
                                print("   %s" % line)
                    status = (f, 0)
                # append status function to global list of status
                list_status.append(status)
                if any(s for (_, s) in list_status) and arguments.abort_on_failure:
                    break
        finally:
            if pool is not None:
                pool.terminate()
                pool.join()

    print('status: ' + str([s for (_, s) in list_status]))
    failures = [f for (f, s) in list_status if s]
    if failures:
        print("Failures: {}".format(" ".join(failures)))

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv('Finished! Elapsed time: ' + str(int(np.round(elapsed_time))) +
           's\n')

    # come back
    os.chdir(curdir)

    # remove temp files (only when the folder was created by this run)
    if param.remove_tmp_file and arguments.execution_folder is None:
        printv('\nRemove temporary files...', 0)
        rmtree(path_tmp)

    sys.exit(1 if failures else 0)
Exemplo n.º 22
0
def compute_dti(fname_in, fname_bvals, fname_bvecs, prefix, method, evecs,
                file_mask, verbose):
    """
    Compute DTI.

    Fit a diffusion tensor model (dipy) on a 4D diffusion image and save the
    FA, MD, RD and AD maps as ``<prefix>FA.nii.gz``, ``<prefix>MD.nii.gz``,
    ``<prefix>RD.nii.gz`` and ``<prefix>AD.nii.gz`` (float32).

    :param fname_in: input 4d file.
    :param fname_bvals: bvals txt file
    :param fname_bvecs: bvecs txt file
    :param prefix: output prefix. Example: "dti_"
    :param method: algo for computing dti: 'standard' (TensorModel with its
        default fit method) or 'restore' (RESTORE fit, with a noise sigma
        estimated from the data)
    :param evecs: bool: also output the tensor eigenvectors (V1, V2, V3 as 4d
        data) and eigenvalues (E1, E2, E3)
    :param file_mask: mask file restricting the fit; '' (or None) fits every voxel
    :param verbose: verbosity level passed to printv
    :return: True
    :raises ValueError: if ``method`` is neither 'standard' nor 'restore'.
    """
    # Validate the method before loading anything: an unknown method used to
    # leave `tenfit` unbound and crash only after the data had been read.
    if method not in ('standard', 'restore'):
        raise ValueError(
            "Unknown DTI method '{}': expected 'standard' or 'restore'.".format(method))

    # Open file.
    from spinalcordtoolbox.image import Image
    nii = Image(fname_in)
    data = nii.data
    printv('data.shape (%d, %d, %d, %d)' % data.shape)

    # open bvecs/bvals
    from dipy.io import read_bvals_bvecs
    from dipy.core.gradients import gradient_table
    bvals, bvecs = read_bvals_bvecs(fname_bvals, fname_bvecs)
    gtab = gradient_table(bvals, bvecs)

    # Masking avoids computing tensors on the background of the image.
    # mask=None makes dipy fit every voxel.
    mask = None
    if file_mask:
        printv('Open mask file...', verbose)
        mask = Image(file_mask).data

    # fit tensor model
    printv('Computing tensor using "' + method + '" method...', verbose)
    import dipy.reconst.dti as dti
    if method == 'standard':
        tenmodel = dti.TensorModel(gtab)
    else:  # 'restore'
        import dipy.denoise.noise_estimate as ne
        sigma = ne.estimate_sigma(data)
        tenmodel = dti.TensorModel(gtab, fit_method='RESTORE', sigma=sigma)
    tenfit = tenmodel.fit(data, mask=mask)

    # Compute metrics: each map is written through the input image so the
    # output keeps its header/geometry.
    printv('Computing metrics...', verbose)
    for metric in ('FA', 'MD', 'RD', 'AD'):
        nii.data = getattr(tenfit, metric.lower())
        nii.save(prefix + metric + '.nii.gz', dtype='float32')

    if evecs:
        data_evecs = tenfit.evecs
        data_evals = tenfit.evals
        # output 1st (V1), 2nd (V2) and 3rd (V3) eigenvectors as 4d data,
        # along with the matching eigenvalues (E1, E2, E3)
        for idim in range(3):
            nii.data = data_evecs[:, :, :, :, idim]
            nii.save(prefix + 'V' + str(idim + 1) + '.nii.gz', dtype="float32")
            nii.data = data_evals[:, :, :, idim]
            nii.save(prefix + 'E' + str(idim + 1) + '.nii.gz', dtype="float32")

    return True
Exemplo n.º 23
0
def print_ok():
    """Print a green "[OK]" status tag through printv."""
    tag = "".join(["[", bcolors.OKGREEN, "OK", bcolors.ENDC, "]"])
    printv(tag)
Exemplo n.º 24
0
def print_warning():
    """Print a "[WARNING]" status tag, colored with bcolors.WARNING, through printv."""
    tag = "".join(["[", bcolors.WARNING, "WARNING", bcolors.ENDC, "]"])
    printv(tag)
Exemplo n.º 25
0
def print_fail():
    """Print a "[FAIL]" status tag, colored with bcolors.FAIL, through printv."""
    tag = "".join(["[", bcolors.FAIL, "FAIL", bcolors.ENDC, "]"])
    printv(tag)
def main(argv=None):
    """
    Compute a b-value from the gradient amplitude and diffusion timings.

    Uses b = (2*pi*GYRO * G * delta)^2 * (Delta - delta/3), with G converted
    from mT/m to T/m and both delta (small delta) and Delta (big delta)
    converted from ms to s.

    :param argv: list of command-line arguments.
    :return: the b-value in s/m^2 (the printed value is converted to s/mm^2).
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    GYRO = float(42.576 * 10**6)  # gyromagnetic ratio (in Hz.T^-1)
    gradamp = arguments.g  # mT/m
    bigdelta = arguments.b  # ms
    smalldelta = arguments.d  # ms

    printv('\nCheck parameters:')
    printv('  gradient amplitude ..... ' + str(gradamp) + ' mT/m')
    printv('  big delta .............. ' + str(bigdelta) + ' ms')
    printv('  small delta ............ ' + str(smalldelta) + ' ms')
    printv('  gyromagnetic ratio ..... ' + str(GYRO) + ' Hz/T')
    printv('')

    # 0.001 factors convert mT/m -> T/m and ms -> s, so bvalue is in s/m^2.
    bvalue = (2 * math.pi * GYRO * gradamp * 0.001 * smalldelta *
              0.001)**2 * (bigdelta * 0.001 - smalldelta * 0.001 / 3)

    # 1 m^2 = 10**6 mm^2, so dividing by 10**6 gives the b-value in s/mm^2.
    printv('b-value = ' + str(bvalue / 10**6) + ' s/mm^2\n')
    return bvalue
def main(args=None):
    """
    Apply one labeling operation to an image and save the result.

    Exactly one action flag selects the operation (add, create, remove, keep,
    viewer, ...). ``-display`` and ``-MSE`` only print information and return
    early; every other action saves its output image to ``-o`` and optionally
    generates a QC report.

    :param args: list of command-line arguments (None/empty: use sys.argv,
        showing the help when no argument is given).
    """
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    verbosity = arguments.v
    init_sct(log_level=verbosity, update=True)  # Update log level

    input_filename = arguments.i
    output_fname = arguments.o

    img = Image(input_filename)
    dtype = None

    if arguments.add is not None:
        value = arguments.add
        out = sct_labels.add(img, value)
    elif arguments.create is not None:
        labels = arguments.create
        out = sct_labels.create_labels_empty(img, labels)
    elif arguments.create_add is not None:
        labels = arguments.create_add
        out = sct_labels.create_labels(img, labels)
    elif arguments.create_seg is not None:
        labels = arguments.create_seg
        out = sct_labels.create_labels_along_segmentation(img, labels)
    elif arguments.cubic_to_point:
        out = sct_labels.cubic_to_point(img)
    elif arguments.display:
        display_voxel(img, verbosity)
        return
    elif arguments.increment:
        out = sct_labels.increment_z_inverse(img)
    elif arguments.disc is not None:
        ref = Image(arguments.disc)
        out = sct_labels.labelize_from_discs(img, ref)
    elif arguments.vert_body is not None:
        levels = arguments.vert_body
        if len(levels) == 1 and levels[0] == 0:
            levels = None  # all levels
        out = sct_labels.label_vertebrae(img, levels)
    elif arguments.vert_continuous:
        out = sct_labels.continuous_vertebral_levels(img)
        dtype = 'float32'
    elif arguments.MSE is not None:
        ref = Image(arguments.MSE)
        mse = sct_labels.compute_mean_squared_error(img, ref)
        printv(f"Computed MSE: {mse}")
        return
    elif arguments.remove_reference is not None:
        ref = Image(arguments.remove_reference)
        out = sct_labels.remove_missing_labels(img, ref)
    elif arguments.remove_sym is not None:
        # first pass use img as source. The reference must come from
        # -remove-sym itself: remove_reference is None on this branch.
        ref = Image(arguments.remove_sym)
        out = sct_labels.remove_missing_labels(img, ref)

        # second pass use previous pass result as reference; the result
        # overwrites the reference file in place
        ref_out = sct_labels.remove_missing_labels(ref, out)
        ref_out.save(path=ref.absolutepath)
    elif arguments.remove is not None:
        labels = arguments.remove
        out = sct_labels.remove_labels_from_image(img, labels)
    elif arguments.keep is not None:
        labels = arguments.keep
        out = sct_labels.remove_other_labels_from_image(img, labels)
    elif arguments.create_viewer is not None:
        msg = "" if arguments.msg is None else f"{arguments.msg}\n"
        if arguments.ilabel is not None:
            input_labels_img = Image(arguments.ilabel)
            out = launch_manual_label_gui(img, input_labels_img, parse_num_list(arguments.create_viewer), msg)
        else:
            out = launch_sagittal_viewer(img, parse_num_list(arguments.create_viewer), msg)
    else:
        # Without an action, `out` would be unbound below (UnboundLocalError).
        parser.error("No labeling action specified. Please provide one of the action flags.")

    printv("Generating output files...")
    out.save(path=output_fname, dtype=dtype)
    display_viewer_syntax([input_filename, output_fname])

    if arguments.qc is not None:
        generate_qc(fname_in1=input_filename, fname_seg=output_fname, args=args,
                    path_qc=os.path.abspath(arguments.qc), dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject, process='sct_label_utils')
Exemplo n.º 28
0
def main(argv=None):
    """
    Denoise a 4D diffusion image with dipy's Patch2Self.

    Saves the denoised volume and the absolute difference (denoised - input)
    as NIfTI files. With ``-v 2``, it also shows a before/after/difference
    view of the middle axial slice of the middle volume.

    :param argv: list of command-line arguments.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    model = arguments.model
    if "," in arguments.radius:
        patch_radius = list_type(",", int)(arguments.radius)
    else:
        patch_radius = int(arguments.radius)

    file_to_denoise = arguments.i
    bval_file = arguments.b
    output_file_name = arguments.o

    path, file, ext = extract_fname(file_to_denoise)

    img = nib.load(file_to_denoise)
    bvals = np.loadtxt(bval_file)
    # `get_header()`/`get_data()` were removed from recent nibabel releases;
    # `img.header` and `np.asanyarray(img.dataobj)` are the documented
    # equivalents (same header object, same unscaled-dtype array).
    hdr_0 = img.header
    data = np.asanyarray(img.dataobj)

    printv('Applying Patch2Self Denoising...')
    den = patch2self(data,
                     bvals,
                     patch_radius=patch_radius,
                     model=model,
                     verbose=True)

    if verbose == 2:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(1, 3)
        axial_middle = int(data.shape[2] / 2)
        middle_vol = int(data.shape[3] / 2)
        before = data[:, :, axial_middle, middle_vol].T
        ax[0].imshow(before, cmap='gray', origin='lower')
        ax[0].set_title('before')
        after = den[:, :, axial_middle, middle_vol].T
        ax[1].imshow(after, cmap='gray', origin='lower')
        ax[1].set_title('after')
        difference = np.absolute(after.astype('f8') - before.astype('f8'))
        ax[2].imshow(difference, cmap='gray', origin='lower')
        ax[2].set_title('difference')
        for i in range(3):
            ax[i].set_axis_off()
        plt.show()

    # Save files (float64 difference avoids unsigned-integer wrap-around)
    img_denoise = nib.Nifti1Image(den, None, hdr_0)
    diff_4d = np.absolute(den.astype('f8') - data.astype('f8'))
    img_diff = nib.Nifti1Image(diff_4d, None, hdr_0)
    if output_file_name is not None:
        output_file_name_den = output_file_name
        output_file_name_diff = add_suffix(output_file_name, "_difference")
    else:
        output_file_name_den = file + '_patch2self_denoised' + ext
        output_file_name_diff = file + '_patch2self_difference' + ext
    nib.save(img_denoise, output_file_name_den)
    nib.save(img_diff, output_file_name_diff)

    # Use the name actually written: `output_file_name` is None when -o is
    # omitted, which made this concatenation raise TypeError.
    printv('\nDone! To view results, type:', verbose)
    printv('fsleyes ' + file_to_denoise + ' ' + output_file_name_den + ' & \n',
           verbose, 'info')
def main(argv=None):
    """
    Display the b-vectors of a diffusion acquisition.

    Reads a bvecs file, then counts:
    - the b=0 entries (all three components below ``bzero``);
    - the unique non-b=0 directions.

    It then draws the XY, XZ and YZ projections plus a 3D scatter of the
    non-b=0 vectors, saves the figure to ``bvecs.png`` and shows it.

    :param argv: list of command-line arguments; an empty/None argv shows the help.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv if argv else ['--help'])
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    fname_bvecs = arguments.bvec

    # Read bvecs: the file is passed in the *bvals* slot with fbvecs=None and
    # the first returned element is kept.
    # NOTE(review): presumably this sidesteps dipy's bvecs shape checks — confirm.
    bvecs = read_bvals_bvecs(fname_bvecs, None)[0]
    # if first dimension is not equal to 3 (x,y,z), transpose bvecs file
    if bvecs.shape[0] != 3:
        bvecs = bvecs.transpose()
    x, y, z = bvecs[0], bvecs[1], bvecs[2]

    def _is_b0(i):
        # b=0 when ALL three components are ~0 (previously only x was tested,
        # three times, so any vector with x ~ 0 was wrongly counted as b=0).
        return abs(x[i]) < bzero and abs(y[i]) < bzero and abs(z[i]) < bzero

    # Get total number of directions
    n_dir = len(x)

    # Count b=0 entries and effective (duplicate-free) directions
    n_b0 = 0
    unique_dirs = set()
    for i in range(n_dir):
        if _is_b0(i):
            n_b0 += 1
        else:
            unique_dirs.add((x[i], y[i], z[i]))
    n_dir_eff = len(unique_dirs)

    # Display scatter plot
    fig = plt.figure(facecolor='white', figsize=(9, 8))
    fig.suptitle('Number of b=0: ' + str(n_b0) + ', Number of b!=0: ' + str(n_dir - n_b0) + ', Number of effective directions (without duplicates): ' + str(n_dir_eff))

    # Display three views
    plot_2dscatter(fig_handle=fig, subplot=221, x=x, y=y, xlabel='X', ylabel='Y')
    plot_2dscatter(fig_handle=fig, subplot=222, x=x, y=z, xlabel='X', ylabel='Z')
    plot_2dscatter(fig_handle=fig, subplot=223, x=y, y=z, xlabel='Y', ylabel='Z')

    # 3D view; b=0 entries are not plotted
    ax = fig.add_subplot(224, projection='3d')
    for i in range(n_dir):
        if not _is_b0(i):
            ax.scatter(x[i], y[i], z[i])
    ax.set_xlim3d(-1, 1)
    ax.set_ylim3d(-1, 1)
    ax.set_zlim3d(-1, 1)
    plt.title('3D view (use mouse to rotate)')
    plt.axis('off')

    # Save image
    printv("Saving figure: bvecs.png\n")
    plt.savefig('bvecs.png')
    plt.show()
def main(argv=None):
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    fname_bvecs = arguments.bvec
    fname_bvals = arguments.bval
    # Read bvals and bvecs files (if arguments.bval is not passed, bvals will be None)
    bvals, bvecs = read_bvals_bvecs(fname_bvals, fname_bvecs)
    # if first dimension is not equal to 3 (x,y,z), transpose bvecs file
    if bvecs.shape[0] != 3:
        bvecs = bvecs.transpose()
    # if bvals file was not passed, create dummy unit bvals array (necessary fot scatter plots)
    if bvals is None:
        bvals = np.repeat(1, bvecs.shape[1])
    # multiply unit b-vectors by b-values
    x, y, z = bvecs[0] * bvals, bvecs[1] * bvals, bvecs[2] * bvals

    # Set different color for each shell (bval)
    shell_colors = {}
    # Create iterator with different colors from brg palette
    colors = iter(cm.nipy_spectral(np.linspace(0, 1, len(np.unique(bvals)))))
    for unique_bval in np.unique(bvals):
        # skip b=0
        if unique_bval < BZERO_THRESH:
            continue
        shell_colors[unique_bval] = next(colors)

    # Get total number of directions
    n_dir = len(x)

    # Get effective number of directions
    bvecs_eff = []
    n_b0 = 0
    for i in range(0, n_dir):
        add_direction = True
        # check if b=0
        if abs(x[i]) < BZERO_THRESH and abs(x[i]) < BZERO_THRESH and abs(
                x[i]) < BZERO_THRESH:
            n_b0 += 1
            add_direction = False
        else:
            # loop across bvecs_eff
            for j in range(0, len(bvecs_eff)):
                # if bvalue already present, then do not add to bvecs_eff
                if bvecs_eff[j] == [x[i], y[i], z[i]]:
                    add_direction = False
        if add_direction:
            bvecs_eff.append([x[i], y[i], z[i]])
    n_dir_eff = len(bvecs_eff)

    # Display scatter plot
    fig = plt.figure(facecolor='white', figsize=(9, 8))
    fig.suptitle('Number of b=0: ' + str(n_b0) + ', Number of b!=0: ' +
                 str(n_dir - n_b0) +
                 ', Number of effective directions (without duplicates): ' +
                 str(n_dir_eff))

    # Display three views
    plot_2dscatter(fig_handle=fig,
                   subplot=221,
                   x=x,
                   y=y,
                   xlabel='X',
                   ylabel='Y',
                   bvals=bvals,
                   colors=shell_colors)
    plot_2dscatter(fig_handle=fig,
                   subplot=222,
                   x=x,
                   y=z,
                   xlabel='X',
                   ylabel='Z',
                   bvals=bvals,
                   colors=shell_colors)
    plot_2dscatter(fig_handle=fig,
                   subplot=223,
                   x=y,
                   y=z,
                   xlabel='Y',
                   ylabel='Z',
                   bvals=bvals,
                   colors=shell_colors)

    # 3D
    ax = fig.add_subplot(224, projection='3d')
    # ax.auto_scale_xyz([-1, 1], [-1, 1], [-1, 1])
    for i in range(0, n_dir):
        # x, y, z = bvecs[0], bvecs[1], bvecs[2]
        # if b=0, do not plot
        if not (abs(x[i]) < BZERO_THRESH and abs(x[i]) < BZERO_THRESH
                and abs(x[i]) < BZERO_THRESH):
            ax.scatter(x[i],
                       y[i],
                       z[i],
                       color=shell_colors[bvals[i]],
                       alpha=0.7)
    ax.set_xlim3d(-max(bvals), max(bvals))
    ax.set_ylim3d(-max(bvals), max(bvals))
    ax.set_zlim3d(-max(bvals), max(bvals))
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.title('3D view (use mouse to rotate)')
    plt.axis('on')
    # plt.draw()

    plt.tight_layout()
    # add legend with b-values if bvals file was passed
    if arguments.bval is not None:
        create_custom_legend(fig, shell_colors, bvals)

    # Save image
    printv("Saving figure: bvecs.png\n")
    plt.savefig('bvecs.png')