示例#1
0
def test_change_nd_orientation(fake_4dimage_sct):
    """
    Reorient a 4D image to "RPI" and check the resulting orientation string,
    data shape and affine translation.
    """
    source = fake_4dimage_sct.copy()
    workdir = tmp_create(basename="test_reorient")
    source.save(os.path.join(workdir, "src.nii"), mutable=True)

    print(source.orientation, source.data.shape)

    # test-data-specific: extent of the fixture along each anatomical axis
    axis_extent = {"L": 2, "R": 2, "A": 3, "P": 3, "I": 4, "S": 4}

    def orient2shape(orient):
        # three spatial extents followed by the size of the 4th dimension
        spatial = tuple(axis_extent[letter] for letter in orient)
        return spatial + (5,)

    # sanity-check the starting point before reorienting
    assert source.orientation == "LPI"
    translation_before = source.header.get_best_affine()[:3, 3].tolist()
    assert translation_before == [0, 0, 0]

    reoriented = msct_image.change_orientation(source, "RPI")
    assert reoriented.orientation == "RPI"
    assert reoriented.data.shape == orient2shape("RPI")
    translation_after = reoriented.header.get_best_affine()[:3, 3].tolist()
    assert translation_after == [2 - 1, 0, 0]
示例#2
0
def test_sequences(fake_3dimage_sct):
    """
    Check how `absolutepath` evolves across chained save / reorient calls.
    """
    image = fake_3dimage_sct.copy()

    workdir = tmp_create(basename="test_sequences")
    first_path = os.path.join(workdir, 'a.nii')
    second_path = os.path.join(workdir, 'b.nii')

    # a plain save is expected to leave the image without a path
    image.save(first_path)
    assert image.absolutepath is None

    # a mutable save is expected to make the image adopt the target path
    image.save(second_path, mutable=True)
    assert image._path is not None
    assert image.absolutepath is not None
    assert image.absolutepath == os.path.abspath(second_path)

    # chain: plain save, reorient, then mutable save
    chained = image.save(first_path).change_orientation("RPI")
    chained.save(second_path, mutable=True)
    assert image.absolutepath is not None
    assert image.absolutepath == os.path.abspath(second_path)
示例#3
0
def test_change_shape(fake_3dimage_sct):
    """
    Append a singleton dimension with `change_shape`, round-trip both images
    through disk, compare their voxel data, then drop the extra dimension.
    """

    # Add dimension
    original = fake_3dimage_sct
    expanded_shape = tuple(original.data.shape) + (1,)
    expanded = msct_image.change_shape(original, expanded_shape)

    workdir = tmp_create(basename="test_reshape")
    src_path = os.path.join(workdir, "src.nii")
    dst_path = os.path.join(workdir, "dst.nii")
    original.save(src_path)
    expanded.save(dst_path)

    # reload from disk so that the comparison covers the saved files
    original = msct_image.Image(src_path)
    expanded = msct_image.Image(dst_path)
    assert expanded.data.shape == expanded_shape

    voxels_src = original.data
    voxels_dst = expanded.data
    assert (voxels_dst.reshape(voxels_src.shape) == voxels_src).all()

    # Remove dimension
    expanded = expanded.change_shape(original.data.shape)
    assert expanded.data.shape == original.data.shape
示例#4
0
def test_more_change_orientation(fake_3dimage_sct, fake_3dimage_sct_vis):
    """
    Exercise `change_orientation` in its function and method forms on 3D images.

    Checks, in order:

    - a reorientation to "RPI" (shape and affine translation);
    - a round trip through "IRP" using the function form;
    - a round trip through "ASR" using the method form;
    - that the orientation survives a save / reload cycle;
    - a save -> reorient -> save -> reload chain on the "vis" fixture;
    - every orientation string returned by `msct_image.all_refspace_strings()`.

    Every file produced by the test is written under a fresh temporary folder
    instead of the current working directory.
    """
    path_tmp = tmp_create(basename="test_reorient")

    im_src = fake_3dimage_sct.copy()
    im_src.save(os.path.join(path_tmp, "src.nii"), mutable=True)

    print(im_src.orientation, im_src.data.shape)

    def orient2shape(orient):
        # test-data-specific thing: extent of the fixture along each axis
        letter2dim = dict(
            L=7,
            R=7,
            A=8,
            P=8,
            I=9,
            S=9,
        )
        return tuple(letter2dim[x] for x in orient)

    assert im_src.header.get_best_affine()[:3, 3].tolist() == [0, 0, 0]
    im_dst = msct_image.change_orientation(im_src, "RPI")
    print(im_dst.orientation, im_dst.data.shape)
    assert im_dst.data.shape == orient2shape("RPI")
    assert im_dst.header.get_best_affine()[:3, 3].tolist() == [7 - 1, 0, 0]

    # spot check
    orientation = im_src.orientation  # LPI
    im_dst = msct_image.change_orientation(im_src, "IRP")
    print(im_dst.orientation, im_dst.data.shape)
    assert im_dst.data.shape == orient2shape("IRP")

    # to & fro
    im_dst2 = msct_image.change_orientation(im_dst, orientation)
    print(im_dst2.orientation, im_dst2.data.shape)
    assert im_dst2.orientation == im_src.orientation
    assert im_dst2.data.shape == orient2shape(orientation)
    assert (im_dst2.data == im_src.data).all()
    assert np.allclose(im_src.header.get_best_affine(),
                       im_dst2.header.get_best_affine())

    # same round trip, this time through the method form on the image itself
    im_ref = fake_3dimage_sct.copy()
    im_src = fake_3dimage_sct.copy()
    orientation = im_src.orientation
    im_src.change_orientation("ASR").change_orientation(orientation)
    assert im_src.orientation == im_ref.orientation
    assert (im_dst2.data == im_src.data).all()
    assert np.allclose(im_src.header.get_best_affine(),
                       im_ref.header.get_best_affine())

    im_dst2 = msct_image.change_orientation(im_dst, orientation)
    print(im_dst2.orientation, im_dst2.data.shape)
    assert im_dst2.orientation == im_src.orientation
    assert im_dst2.data.shape == orient2shape(orientation)
    assert (im_dst2.data == im_src.data).all()
    assert np.allclose(im_src.header.get_best_affine(),
                       im_dst2.header.get_best_affine())

    # copy
    im_dst = im_src.copy().change_orientation("IRP")
    assert im_dst.data.shape == orient2shape("IRP")
    print(im_dst.orientation, im_dst.data.shape)

    print("Testing orientation persistence")
    img = im_src.copy()
    fn = os.path.join(path_tmp, "pouet.nii")
    img.change_orientation("PIR").save(fn)
    assert img.data.shape == orient2shape("PIR")
    img = msct_image.Image(fn)
    assert img.orientation == "PIR"
    assert img.data.shape == orient2shape("PIR")
    print(img.orientation, img.data.shape)

    # typical pattern: save, reorient with a generated path, save, reload;
    # each step starts from the image reloaded by the previous one
    img = fake_3dimage_sct_vis.copy()
    print(img.header.get_best_affine())
    fn = os.path.join(path_tmp, "vis.nii")
    for target in ("ALS", "RAS", "RPI", "PLI"):
        fn2 = img.save(fn, mutable=True).change_orientation(
            target, generate_path=True).save().absolutepath
        img = msct_image.Image(fn2)
        assert img.orientation == target
        assert img.data.shape == orient2shape(target)
        print(img.header.get_best_affine())

    # exhaustive check over every supported orientation string
    for orientation in msct_image.all_refspace_strings():
        dst = msct_image.change_orientation(im_src, orientation)
        print(orientation, dst.orientation, dst.data.shape, dst.dim)
        assert orientation == dst.orientation
示例#5
0
def test_change_orientation(fake_3dimage_sct, fake_3dimage_sct_vis):
    """
    Check that `change_orientation` keeps voxels at the same physical position.

    Two parts:

    1. Tiny images of shape (1, 1, n) are reoriented to "ASR"; the physical
       coordinates of both ends of the non-trivial axis are compared.
    2. A (7, 8, 9) image with three uniquely-valued voxels and a randomized
       sform is reoriented between every pair of orientation strings returned
       by `msct_image.all_refspace_strings()`; each marker voxel must map to
       the same physical position before and after.

    The fixtures in the signature are not used by the body; they are kept so
    the test's interface is unchanged. Files are written to a temporary folder
    rather than the current working directory.
    """

    path_tmp = tmp_create(basename="test_reorient")

    print("Spot-checking that physical coordinates don't change")
    for shape_is in (1, 2, 3):
        shape = (1, 1, shape_is)
        print("Simple image with shape {}".format(shape))

        data = np.ones(shape, order="F")
        # make the last voxel along the third axis distinguishable
        data[:, :, shape_is - 1] += 1
        im_src = fake_3dimage_sct_custom(data)
        im_dst = msct_image.change_orientation(im_src, "ASR")
        # Basic check
        assert im_dst.orientation == "ASR"
        # Basic data check
        assert im_dst.data.mean() == im_src.data.mean()
        # Basic header check: check that the same voxel
        # remains at the same physical position
        aff_src = im_src.header.get_best_affine()
        aff_dst = im_dst.header.get_best_affine()

        # Take the extremities "a" & "z"...
        # consider the original LPI position
        pta_src = np.array([[0, 0, 0, 1]]).T
        ptz_src = np.array([[0, 0, shape_is - 1, 1]]).T
        # and the position in ASR
        pta_dst = np.array([[0, shape_is - 1, 0, 1]]).T
        ptz_dst = np.array([[0, 0, 0, 1]]).T
        # The physical positions should be:
        posa_src = np.matmul(aff_src, pta_src)
        posa_dst = np.matmul(aff_dst, pta_dst)
        print("A at src {}".format(posa_src.T))
        print("A at dst {}".format(posa_dst.T))
        posz_src = np.matmul(aff_src, ptz_src)
        posz_dst = np.matmul(aff_dst, ptz_dst)
        # and they should be equal
        assert (posa_src == posa_dst).all()
        assert (posz_src == posz_dst).all()
        # keep the images around for inspection, inside the temporary folder
        fn = "".join(str(x) for x in im_src.data.shape)
        im_src.save(os.path.join(path_tmp, "{}-src.nii".format(fn)))
        im_dst.save(os.path.join(path_tmp, "{}-dst.nii".format(fn)))

    # make the random affine below reproducible
    np.random.seed(0)

    print("More checking that physical coordinates don't change")
    shape = (7, 8, 9)
    print("Simple image with shape {}".format(shape))

    # uniform background with three marker voxels holding unique values
    data = np.ones(shape, order="F") * 10
    data[4, 4, 4] = 4
    data[3, 3, 3] = 3
    data[0, 0, 0] = 0
    values = (0, 3, 4)

    im_ref = fake_3dimage_sct_custom(data)
    im_ref.header.set_xyzt_units("mm", "msec")

    import scipy.linalg

    def rand_rot():
        # orthogonal matrix from a QR decomposition; flip one column when
        # needed so that the determinant is positive
        q, _ = scipy.linalg.qr(np.random.randn(3, 3))
        if scipy.linalg.det(q) < 0:
            q[:, 0] = -q[:, 0]
        return q

    # build an arbitrary affine: random rotation + random translation
    affine = im_ref.header.get_best_affine()
    affine[:3, :3] = rand_rot()
    affine[3, :3] = 0.0
    affine[:3, 3] = np.random.random((3))
    affine[3, 3] = 1.0

    # scale a single matrix entry so the linear part is no longer a pure rotation
    affine[0, 0] *= 2
    im_ref.header.set_sform(affine, code='scanner')

    orientations = msct_image.all_refspace_strings()
    for ori_src in orientations:
        for ori_dst in orientations:
            print("{} -> {}".format(ori_src, ori_dst))
            im_src = msct_image.change_orientation(im_ref, ori_src)
            im_dst = msct_image.change_orientation(im_src, ori_dst)

            assert im_src.orientation == ori_src
            assert im_dst.orientation == ori_dst
            assert im_dst.data.mean() == im_src.data.mean()

            # Basic header check: check that the same voxel
            # remains at the same physical position
            aff_src = im_src.header.get_best_affine()
            aff_dst = im_dst.header.get_best_affine()

            data_src = np.array(im_src.data)
            data_dst = np.array(im_dst.data)

            for value in values:
                # voxel index of the marker in each image
                pt_src = np.argwhere(data_src == value)[0]
                pt_dst = np.argwhere(data_dst == value)[0]

                # homogeneous voxel coordinates -> physical coordinates
                pos_src = np.matmul(
                    aff_src,
                    np.hstack((pt_src, [1])).reshape((4, 1)))
                pos_dst = np.matmul(
                    aff_dst,
                    np.hstack((pt_dst, [1])).reshape((4, 1)))
                assert np.allclose(pos_src, pos_dst, atol=1e-3)
示例#6
0
def install_data(url, dest_folder, keep=False):
    """
    Download a data bundle from a URL and install in the destination folder.

    :param url: URL or sequence thereof (if mirrors).
    :param dest_folder: destination directory for the data (to be created).
    :param keep: whether to keep existing data in the destination folder.
        When True, files already present are still replaced by files of the
        same relative path coming from the bundle (a warning is logged).
    :return: None

    .. note::
        The function tries to be smart about the data contents.

        Examples:

        a. If the archive only contains a `README.md`, and the destination folder is `${dst}`,
            `${dst}/README.md` will be created.

            Note: an archive not containing a single folder is commonly known as a "bomb" because
            it puts files anywhere in the current working directory.
            https://en.wikipedia.org/wiki/Tar_(computing)#Tarbomb

        b. If the archive contains a `${dir}/README.md`, and the destination folder is `${dst}`,
            `${dst}/README.md` will be created.

            Note: typically the package will be called `${basename}-${revision}.zip` and contain
            a root folder named `${basename}-${revision}/` under which all the other files will
            be located.
            The right thing to do in this case is to take the files from there and install them
            in `${dst}`.

        - Uses `download_data()` to retrieve the data.
        - Uses `unzip()` to extract the bundle.
        - `__MACOSX` entries found in the archive are ignored.
        - The folder holding the downloaded file and the extraction folder are
          removed when the function exits, including when extraction or
          copying fails.
    """

    if not keep and os.path.exists(dest_folder):
        logger.warning("Removing existing destination folder “%s”",
                       dest_folder)
        shutil.rmtree(dest_folder)
    os.makedirs(dest_folder, exist_ok=True)

    tmp_file = download_data(url)
    download_folder = os.path.dirname(tmp_file)
    extraction_folder = None

    try:
        extraction_folder = tmp_create()

        unzip(tmp_file, extraction_folder)

        # Identify whether we have a proper archive or a tarbomb
        ignored = ("__MACOSX", )
        with os.scandir(extraction_folder) as it:
            entries = [entry for entry in it if entry.name not in ignored]
            single_root = len(entries) == 1 and entries[0].is_dir()

        if single_root:
            # tarball with single-directory -> go under
            bundle_folder = entries[0].path
        else:
            # bomb scenario -> stay here
            bundle_folder = extraction_folder

        # Copy over
        for cwd, ds, fs in os.walk(bundle_folder):
            ds.sort()
            fs.sort()
            # in-place update so that os.walk does not descend into ignored folders
            ds[:] = [d for d in ds if d not in ignored]
            for d in ds:
                srcpath = os.path.join(cwd, d)
                relpath = os.path.relpath(srcpath, bundle_folder)
                dstpath = os.path.join(dest_folder, relpath)
                if os.path.exists(dstpath):
                    # lazy -- we assume existing is a directory, otherwise it will crash safely
                    logger.debug("- d- %s", relpath)
                else:
                    logger.debug("- d+ %s", relpath)
                    os.makedirs(dstpath)

            for f in fs:
                srcpath = os.path.join(cwd, f)
                relpath = os.path.relpath(srcpath, bundle_folder)
                dstpath = os.path.join(dest_folder, relpath)
                if os.path.exists(dstpath):
                    logger.debug("- f! %s", relpath)
                    logger.warning("Updating existing “%s”", dstpath)
                    os.unlink(dstpath)
                else:
                    logger.debug("- f+ %s", relpath)
                shutil.copy(srcpath, dstpath)
    finally:
        # always clean up, so a failed extraction/copy does not leak temp folders
        logger.info("Removing temporary folders...")
        shutil.rmtree(download_folder)
        if extraction_folder is not None:
            shutil.rmtree(extraction_folder)
def main(argv=None):
    """
    Command-line entry point: run the selected testing functions and exit.

    Parses the arguments, makes sure the testing data folder is available,
    then runs the non-parallelizable functions sequentially followed by the
    parallelizable ones (through a `multiprocessing.Pool` when more than one
    job is requested). A status is collected per function and summarized at
    the end.

    The working directory is switched to the testing data folder while the
    functions run, and is restored afterwards even if a function raises or the
    run is interrupted.

    :param argv: command-line arguments handed to the argument parser
        (`None` lets the parser use its default).
    :return: does not return; calls `sys.exit()` with 0 when every function
        reported a zero status, 1 otherwise.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.verbose
    set_global_loglevel(verbose=verbose)

    # initializations
    param = Param()

    param.download = int(arguments.download)
    param.path_data = arguments.path
    functions_to_test = arguments.function
    param.remove_tmp_file = int(arguments.remove_temps)
    jobs = arguments.jobs

    param.verbose = verbose

    start_time = time.time()

    # get absolute path of the testing data folder
    param.path_data = os.path.abspath(param.path_data)

    # check existence of testing data folder
    if not os.path.isdir(param.path_data) or param.download:
        downloaddata(param)

    # display path to data
    printv('\nPath to testing data: ' + param.path_data, param.verbose)

    # create temp folder that will have all results
    path_tmp = os.path.abspath(arguments.execution_folder or tmp_create())

    # (function name, status) pairs; a non-zero status means failure or warning
    list_status = []

    def has_failure():
        # True as soon as one function reported a non-zero status
        return any(s for (_, s) in list_status)

    # go in path data (where all scripts will be run)
    curdir = os.getcwd()
    os.chdir(param.path_data)
    try:
        functions_parallel = list()
        functions_serial = list()
        if functions_to_test:
            for f in functions_to_test:
                if f in get_functions_parallelizable():
                    functions_parallel.append(f)
                elif f in get_functions_nonparallelizable():
                    functions_serial.append(f)
                else:
                    printv(
                        'Command-line usage error: Function "%s" is not part of the list of testing functions'
                        % f,
                        type='error')
            # no point in having more workers than functions to run
            jobs = min(jobs, len(functions_parallel))
        else:
            functions_parallel = get_functions_parallelizable()
            functions_serial = get_functions_nonparallelizable()

        if arguments.continue_from:
            first_func = arguments.continue_from
            if first_func in functions_parallel:
                # serial functions run first, so they are all considered done
                functions_serial = []
                functions_parallel = functions_parallel[functions_parallel.
                                                        index(first_func):]
            elif first_func in functions_serial:
                functions_serial = functions_serial[functions_serial.
                                                    index(first_func):]

        if arguments.check_filesystem and jobs != 1:
            print("Check filesystem used -> jobs forced to 1")
            jobs = 1

        print("Will run through the following tests:")
        if functions_serial:
            print("- sequentially: {}".format(" ".join(functions_serial)))
        if functions_parallel:
            print("- in parallel with {} jobs: {}".format(
                jobs, " ".join(functions_parallel)))

        for name, functions in (
            ("serial", functions_serial),
            ("parallel", functions_parallel),
        ):
            if not functions:
                continue

            if has_failure() and arguments.abort_on_failure:
                break

            # only the parallel group may use a pool, and only with several jobs
            use_pool = name == "parallel" and jobs != 1

            # bound before the try so the cleanup below never hits an unbound name
            pool = None
            try:
                if use_pool:
                    pool = multiprocessing.Pool(processes=jobs)

                    results = list()
                    # submit every function up front; results are collected in order below
                    for f in functions:
                        func_param = copy.deepcopy(param)
                        func_param.path_output = f
                        res = pool.apply_async(process_function_multiproc, (
                            f,
                            func_param,
                        ))
                        results.append(res)

                for idx_function, f in enumerate(functions):
                    print_line('Checking ' + f)
                    if not use_pool:
                        if arguments.check_filesystem:
                            if os.path.exists(os.path.join(path_tmp, f)):
                                shutil.rmtree(os.path.join(path_tmp, f))
                            sig_0 = fs_signature(path_tmp)

                        func_param = copy.deepcopy(param)
                        func_param.path_output = f

                        res = process_function(f, func_param)

                        if arguments.check_filesystem:
                            # compare the folder signature before/after the run
                            sig_1 = fs_signature(path_tmp)
                            fs_ok(sig_0, sig_1, exclude=(f, ))
                    else:
                        res = results[idx_function].get()

                    list_output, list_status_function = res
                    # manage status
                    if any(list_status_function):
                        if 1 in list_status_function:
                            print_fail()
                            status = (f, 1)
                        else:
                            print_warning()
                            status = (f, 99)
                        for output in list_output:
                            for line in output.splitlines():
                                print("   %s" % line)
                    else:
                        print_ok()
                        if param.verbose:
                            for output in list_output:
                                for line in output.splitlines():
                                    print("   %s" % line)
                        status = (f, 0)
                    # append status function to global list of status
                    list_status.append(status)
                    if has_failure() and arguments.abort_on_failure:
                        break
            finally:
                # stop the workers whatever happened (failure, abort, Ctrl-C)
                if pool is not None:
                    pool.terminate()
                    pool.join()

        print('status: ' + str([s for (f, s) in list_status]))
        if has_failure():
            print("Failures: {}".format(" ".join(
                [f for (f, s) in list_status if s])))

        # display elapsed time
        elapsed_time = time.time() - start_time
        printv('Finished! Elapsed time: ' + str(int(np.round(elapsed_time))) +
               's\n')
    finally:
        # come back
        os.chdir(curdir)

    # remove temp files (never remove a folder explicitly chosen by the user)
    if param.remove_tmp_file and arguments.execution_folder is None:
        printv('\nRemove temporary files...', 0)
        rmtree(path_tmp)

    sys.exit(1 if has_failure() else 0)