コード例 #1
0
ファイル: dataset_utils.py プロジェクト: bbastardes/TractSeg
def cut_and_scale_img_back_to_original_img(data, t):
    '''
    Undo the transformations done with pad_and_scale_img_to_square_img.

    1. Zoom the image back by 1 / t["zoom"] (nearest neighbour, order=0).
    2. Cut away the zero padding that was added to make the image square.

    data: 3D or 4D image (for 4D only the first three dims are resized and cut)
    t: transformation dict as returned by pad_and_scale_img_to_square_img
       (keys "pad_x", "pad_y", "pad_z", "zoom" and, if present, "original_shape")

    Returns the image cut back to its original spatial size.
    Raises ValueError if data is not 3D or 4D.
    '''
    nr_dims = len(data.shape)
    # Explicit check instead of assert: asserts are stripped under `python -O`,
    # which would leave `new_data` unbound below.
    if nr_dims not in (3, 4):
        raise ValueError("image has to be 3D or 4D, got {} dims".format(nr_dims))

    # Back to old size
    # use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
    if nr_dims == 3:
        new_data = ndimage.zoom(data, (1. / t["zoom"]), order=0)
    else:
        new_data = img_utils.resize_first_three_dims(data, order=0, zoom=(1. / t["zoom"]))

    original_shape = t.get("original_shape")
    cut = []
    for axis, pad in enumerate((t["pad_x"], t["pad_y"], t["pad_z"])):
        start = int(pad)
        if original_shape is not None:
            # Cut exactly the original extent; robust against an off-by-one in the
            # zoomed shape caused by rounding during the two zoom steps.
            stop = start + original_shape[axis]
        else:
            # Legacy dicts without "original_shape": a .5 pad means the extra
            # padding voxel sits at the end -> cut one more voxel there.
            residual = 1 if pad - int(pad) == 0.5 else 0
            stop = new_data.shape[axis] - start - residual
        cut.append(slice(start, stop))

    # Only the first three dims are cut; a 4th dim (if any) is kept whole.
    return new_data[tuple(cut)]
コード例 #2
0
ファイル: slicer.py プロジェクト: yangsenwxy/TractSeg
def _slices_from_all_directions(vol):
    '''Return the 2D slices of the 4D volume `vol` along axis 0, then axis 1, then axis 2.'''
    slices = [vol[i, :, :, :] for i in range(vol.shape[0])]
    slices.extend(vol[:, i, :, :] for i in range(vol.shape[1]))
    slices.extend(vol[:, :, i, :] for i in range(vol.shape[2]))
    return slices


def _create_slices_file(Config, subjects, filename, slice, shuffle=True):
    '''
    Create one dataset of 2D slices (taken along x, y and z) from all subjects.

    Writes `filename + "_data.npy"` (feature slices) and `filename + "_seg.npy"`
    (label slices). When shuffling, data and seg are reordered with the SAME
    permutation so that they stay aligned.

    :param Config: config class; DATASET_FOLDER, FEATURES_FILENAME, DATASET,
                   RESOLUTION and LABELS_TYPE are read (Config is also passed on
                   to img_utils.create_multilabel_mask)
    :param subjects: list of subject ids (iterated twice)
    :param filename: output path prefix
    :param slice: unused; kept for signature compatibility
    :param shuffle: if True, shuffle the slices (a permutation, no duplicates)
    :raises ValueError: if the number of seg slices differs from the number of data slices
    '''
    data_dir = join(C.HOME, Config.DATASET_FOLDER)

    print("\n\nProcessing Data...")
    dwi_slices = []
    for s in subjects:
        print("processing dwi subject {}".format(s))

        dwi = nib.load(join(data_dir, s, Config.FEATURES_FILENAME + ".nii.gz"))
        # np.asanyarray(img.dataobj) is nibabel's documented replacement for the
        # removed get_data(); it keeps the on-disk dtype.
        dwi_data = np.nan_to_num(np.asanyarray(dwi.dataobj))
        dwi_data = dataset_utils.scale_input_to_unet_shape(
            dwi_data, Config.DATASET, Config.RESOLUTION)

        # Use slices from all directions in one dataset
        dwi_slices.extend(_slices_from_all_directions(dwi_data))

    dwi_slices = np.array(dwi_slices)
    nr_data_slices = len(dwi_slices)
    random_idxs = None
    if shuffle:
        # A permutation samples WITHOUT replacement. The former
        # np.random.choice(n, n) drew with replacement, duplicating some slices
        # and silently dropping others.
        random_idxs = np.random.permutation(nr_data_slices)
        dwi_slices = dwi_slices[random_idxs]

    np.save(filename + "_data.npy", dwi_slices)
    del dwi_slices  # free memory

    print("\n\nProcessing Segs...")
    mask_slices = []
    for s in subjects:
        print("processing seg subject {}".format(s))

        mask_data = img_utils.create_multilabel_mask(
            Config, s, labels_type=Config.LABELS_TYPE)
        if Config.RESOLUTION == "2.5mm":
            mask_data = img_utils.resize_first_three_dims(mask_data,
                                                          order=0,
                                                          zoom=0.5)
        mask_data = dataset_utils.scale_input_to_unet_shape(
            mask_data, Config.DATASET, Config.RESOLUTION)

        # Slice along the mask's own shape (previously the shape of the last
        # dwi volume was used here).
        mask_slices.extend(_slices_from_all_directions(mask_data))

    mask_slices = np.array(mask_slices)
    if len(mask_slices) != nr_data_slices:
        raise ValueError("got {} seg slices but {} data slices; data and seg would be misaligned"
                         .format(len(mask_slices), nr_data_slices))
    print("SEG TYPE: {}".format(mask_slices.dtype))
    if shuffle:
        mask_slices = mask_slices[random_idxs]

    np.save(filename + "_seg.npy", mask_slices)
コード例 #3
0
def scale_input_to_original_shape(img4d, dataset, resolution="1.25mm"):
    """
    Scale input image to original resolution and pad/cut image to make it original size.
    This is not generic but optimised for some specific datasets.

    Args:
        img4d:  (x, y, z, classes)
        dataset: HCP|HCP_32g|HCP_2mm|HCP_2.5mm|TRACED|Schizo (which ones are
            supported depends on the resolution, see the branches below)
        resolution: 1.25mm|2mm|2.5mm

    Returns:
        (x_original, y_original, z_original, classes)

    Raises:
        ValueError: if the dataset/resolution combination is not supported.
    """
    if resolution == "1.25mm":
        if dataset in ("HCP", "HCP_32g"):  # (144,144,144)
            # no resize needed
            return img_utils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                               [146, 174, 146, img4d.shape[3]], pad_value=0)  # (146, 174, 146, none)
        if dataset == "Schizo":  # (144,144,144)
            img4d = img_utils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                                [145, 174, 145, img4d.shape[3]], pad_value=0)  # (145, 174, 145, none)
            return img_utils.resize_first_three_dims(img4d, zoom=0.62)  # (91,109,91)

    elif resolution == "2mm":
        if dataset in ("HCP", "HCP_32g", "HCP_2mm"):  # (80,80,80)
            return img_utils.pad_4d_image_left(img4d, np.array([5, 14, 5, 0]),
                                               [90, 108, 90, img4d.shape[3]], pad_value=0)  # (90, 108, 90, none)

    elif resolution == "2.5mm":
        if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):  # (80,80,80)
            img4d = img_utils.pad_4d_image_left(img4d, np.array([0, 4, 0, 0]),
                                                [80, 87, 80, img4d.shape[3]], pad_value=0)  # (80,87,80,none)
            return img4d[4:77, :, 4:77, :]  # (73, 87, 73, none)
        if dataset == "TRACED":  # (80,80,80)
            img4d = img_utils.pad_4d_image_left(img4d, np.array([0, 7, 0, 0]),
                                                [80, 93, 80, img4d.shape[3]], pad_value=0)  # (80,93,80,none)
            return img4d[1:79, :, 3:78, :]  # (78,93,75,none)

    # Every other combination (e.g. TRACED at 1.25mm/2mm, Schizo at 2mm/2.5mm)
    # used to fall through and silently return None. The message keeps the
    # format of the former explicit TRACED errors.
    raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))
コード例 #4
0
ファイル: dataset_utils.py プロジェクト: bbastardes/TractSeg
def scale_input_to_world_shape(img4d, dataset, resolution="1.25mm"):
    '''
    Scale input image to original resolution and pad/cut image to make it original size

    :param img4d: (x, y, z, userdefined)  (userdefined could be gradients or classes)
    :param dataset: HCP|HCP_32g|HCP_2mm|HCP_2.5mm|TRACED|Schizo (which ones are
                    supported depends on the resolution, see the branches below)
    :param resolution: "1.25mm" / "2mm" / "2.5mm"
    :return: img with original size
    :raises ValueError: if the dataset/resolution combination is not supported
    '''
    if resolution == "1.25mm":
        if dataset in ("HCP", "HCP_32g"):  # (144,144,144)
            # no resize needed
            return img_utils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                               [146, 174, 146, img4d.shape[3]], pad_value=0)  # (146, 174, 146, none)
        if dataset == "Schizo":  # (144,144,144)
            img4d = img_utils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                                [145, 174, 145, img4d.shape[3]], pad_value=0)  # (145, 174, 145, none)
            return img_utils.resize_first_three_dims(img4d, zoom=0.62)  # (91,109,91)

    elif resolution == "2mm":
        if dataset in ("HCP", "HCP_32g", "HCP_2mm"):  # (80,80,80)
            return img_utils.pad_4d_image_left(img4d, np.array([5, 14, 5, 0]),
                                               [90, 108, 90, img4d.shape[3]], pad_value=0)  # (90, 108, 90, none)

    elif resolution == "2.5mm":
        if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):  # (80,80,80)
            img4d = img_utils.pad_4d_image_left(img4d, np.array([0, 4, 0, 0]),
                                                [80, 87, 80, img4d.shape[3]], pad_value=0)  # (80,87,80,none)
            return img4d[4:77, :, 4:77, :]  # (73, 87, 73, none)
        if dataset == "TRACED":  # (80,80,80)
            img4d = img_utils.pad_4d_image_left(img4d, np.array([0, 7, 0, 0]),
                                                [80, 93, 80, img4d.shape[3]], pad_value=0)  # (80,93,80,none)
            return img4d[1:79, :, 3:78, :]  # (78,93,75,none)

    # Every other combination (e.g. TRACED at 1.25mm/2mm, Schizo at 2mm/2.5mm)
    # used to fall through and silently return None. The message keeps the
    # format of the former explicit TRACED errors.
    raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))
コード例 #5
0
def pad_and_scale_img_to_square_img(data, target_size=144, nr_cpus=-1):
    '''
    Pad an image with zeros to a cube and scale it to target_size voxels per side.

    Expects 3D or 4D image as input (for 4D only the first three dims are padded/scaled).

    Does
    1. Pad image with 0 to make it square
        (if uneven padding -> adds one more px "behind" img; but resulting img shape will be correct)
    2. Scale image to UNet size (target_size, target_size, target_size), by default (144, 144, 144)

    Returns (new_img, transformation). transformation holds "original_shape",
    "pad_x", "pad_y", "pad_z" (may end in .5 for uneven padding) and "zoom";
    cut_and_scale_img_back_to_original_img uses it to undo this operation.

    Raises ValueError if data is not 3D or 4D.
    '''
    nr_dims = len(data.shape)
    # Explicit check instead of assert (asserts are stripped under `python -O`).
    if nr_dims not in (3, 4):
        raise ValueError("image has to be 3D or 4D, got {} dims".format(nr_dims))

    shape = data.shape
    # Only the spatial dims define the cube; a large 4th dim (e.g. many DWI
    # gradients) must not inflate the padding and thereby lower the resolution.
    biggest_dim = max(shape[:3])

    # Pad to make square
    new_img = np.zeros((biggest_dim,) * 3 + tuple(shape[3:]), dtype=data.dtype)
    pad1 = (biggest_dim - shape[0]) / 2.
    pad2 = (biggest_dim - shape[1]) / 2.
    pad3 = (biggest_dim - shape[2]) / 2.
    new_img[int(pad1):int(pad1) + shape[0],
            int(pad2):int(pad2) + shape[1],
            int(pad3):int(pad3) + shape[2]] = data

    # Scale to right size
    zoom = float(target_size) / biggest_dim
    if nr_dims == 4:
        # use order=0, otherwise does not work for peak images (results would be wrong)
        new_img = img_utils.resize_first_three_dims(new_img,
                                                    order=0,
                                                    zoom=zoom,
                                                    nr_cpus=nr_cpus)
    else:
        new_img = ndimage.zoom(new_img, zoom, order=0)

    transformation = {
        "original_shape": shape,
        "pad_x": pad1,
        "pad_y": pad2,
        "pad_z": pad3,
        "zoom": zoom
    }

    return new_img, transformation
コード例 #6
0
ファイル: slicer.py プロジェクト: yangsenwxy/TractSeg
def save_fusion_nifti_as_npy():
    '''
    Convert the HCP fusion probmaps and bundle masks of all subjects to .npy files.

    For every subject returned by get_all_subjects() this writes into
    C.NETWORK_DRIVE/HCP_fusion_npy_<DIFFUSION_FOLDER>/<subject>/:
      - "<DIFFUSION_FOLDER>_xyz.npy": the probmap, scaled to UNet shape
      - "bundle_masks.npy": the bundle masks, scaled to UNet shape
    '''
    # Can leave this always the same (for 270g and 32g)
    class Config:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        LABELS_FILENAME = "bundle_masks"
        DATASET_FOLDER = "HCP"

    DIFFUSION_FOLDER = "32g_25mm"
    in_dir = join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER)
    out_root = join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER)
    subjects = get_all_subjects()

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing data subject {}".format(s))
        start_time = time.time()
        # np.asanyarray(img.dataobj) is nibabel's documented replacement for the
        # deprecated (removed in nibabel 5) get_data(); it keeps the on-disk dtype.
        data = np.asanyarray(nib.load(join(in_dir, s + "_probmap.nii.gz")).dataobj)
        print("Done Loading")
        data = np.nan_to_num(data)
        data = dataset_utils.scale_input_to_unet_shape(data, Config.DATASET,
                                                       Config.RESOLUTION)
        # cut one voxel at the end of x and z: scale_input_to_world_shape output
        # 146 voxels there -> one too many at the end
        data = data[:-1, :, :-1, :]
        out_dir = join(out_root, s)
        exp_utils.make_dir(out_dir)
        np.save(join(out_dir, DIFFUSION_FOLDER + "_xyz.npy"), data)
        print("Took {}s".format(time.time() - start_time))

        print("processing seg subject {}".format(s))
        start_time = time.time()
        # Alternative source: img_utils.create_multilabel_mask(Config, s, labels_type=Config.LABELS_TYPE)
        seg = np.asanyarray(nib.load(
            join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s,
                 Config.LABELS_FILENAME + ".nii.gz")).dataobj)
        if Config.RESOLUTION == "2.5mm":
            seg = img_utils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = dataset_utils.scale_input_to_unet_shape(seg, Config.DATASET,
                                                      Config.RESOLUTION)
        np.save(join(out_dir, "bundle_masks.npy"), seg)
        print("Took {}s".format(time.time() - start_time))
コード例 #7
0
ファイル: slicer.py プロジェクト: yangsenwxy/TractSeg
def create_one_3D_file():
    '''
    Create one big file which contains all 3D Images (not slices).

    Writes "data.npy" (the features of every subject, scaled to UNet shape and
    stacked along a new first axis) and "seg.npy" (the matching multilabel
    masks, same subject order) to the current working directory.
    '''
    class Config:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        DATASET_FOLDER = "HCP"

    # Fetch the subject list once so data and seg are written in the same order.
    subjects = get_all_subjects()

    print("\n\nProcessing Data...")
    data_all = []
    for s in subjects:
        print("processing data subject {}".format(s))
        # np.asanyarray(img.dataobj) is nibabel's documented replacement for the
        # removed get_data(); it keeps the on-disk dtype.
        data = np.asanyarray(nib.load(
            join(C.HOME, Config.DATASET_FOLDER, s,
                 Config.FEATURES_FILENAME + ".nii.gz")).dataobj)
        data = np.nan_to_num(data)
        data = dataset_utils.scale_input_to_unet_shape(data, Config.DATASET,
                                                       Config.RESOLUTION)
        # Append INSIDE the loop; it used to sit after the loop, so only the last
        # subject was saved. np.array makes a copy (scale_input_to_unet_shape
        # may return a view of the full volume).
        data_all.append(np.array(data))
    np.save("data.npy", np.array(data_all))
    del data_all  # free memory

    print("\n\nProcessing Segs...")
    seg_all = []
    for s in subjects:
        print("processing seg subject {}".format(s))
        seg = img_utils.create_multilabel_mask(Config,
                                               s,
                                               labels_type=Config.LABELS_TYPE)
        if Config.RESOLUTION == "2.5mm":
            seg = img_utils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = dataset_utils.scale_input_to_unet_shape(seg, Config.DATASET,
                                                      Config.RESOLUTION)
        seg_all.append(np.array(seg))
    # Convert before reading .dtype: a plain list has no dtype attribute.
    seg_all = np.array(seg_all)
    print("SEG TYPE: {}".format(seg_all.dtype))
    np.save("seg.npy", seg_all)
コード例 #8
0
def _paste_into_unet_cube(img4d, bg_index, img_index):
    """Return an (80, 80, 80, C) array with img4d[img_index] written at bg_index.

    All other voxels get the value of img4d's corner voxel img4d[0, 0, 0, :], so
    the background has the same value as the background of the original image.
    """
    bg = np.zeros((80, 80, 80, img4d.shape[3])).astype(img4d.dtype)
    # broadcasting adds the last dim of img4d[0, 0, 0, :] to the last dim of bg
    bg = bg + img4d[0, 0, 0, :]
    bg[bg_index] = img4d[img_index]
    return bg


def scale_input_to_unet_shape(img4d, dataset, resolution="1.25mm"):
    """
    Scale input image to right isotropic resolution and pad/cut image to make it square to fit UNet input shape.
    This is not generic but optimised for some specific datasets.

    Args:
        img4d: (x, y, z, classes)
        dataset: HCP|HCP_32g|HCP_2mm|HCP_2.5mm|TRACED|Schizo (which ones are
            supported depends on the resolution, see the branches below)
        resolution: 1.25mm|2mm|2.5mm

    Returns:
        img with dim 1.25mm: (144,144,144,none) or 2mm: (80,80,80,none) or 2.5mm: (80,80,80,none)
        (note: at 2.5mm the image is padded to 80,80,80 with the value of its
        corner voxel img4d[0,0,0,:], not with zeros)

    Raises:
        ValueError: if the dataset/resolution combination is not supported.
    """
    if resolution == "1.25mm":
        if dataset == "HCP":  # (145,174,145)
            # no resize needed
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        if dataset == "HCP_32g":  # (73,87,73)
            img4d = img_utils.resize_first_three_dims(img4d, zoom=2)  # (146,174,146,none)
            img4d = img4d[:-1, :, :-1]  # remove one voxel that came from upsampling   # (145,174,145)
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        if dataset == "Schizo":  # (91,109,91)
            img4d = img_utils.resize_first_three_dims(img4d, zoom=1.60)  # (146,174,146)
            return img4d[1:145, 15:159, 1:145]  # (144,144,144)

    elif resolution == "2mm":
        if dataset in ("HCP", "HCP_32g"):  # (145,174,145)
            img4d = img_utils.resize_first_three_dims(img4d, zoom=0.62)  # (90,108,90)
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        if dataset == "HCP_2mm":  # (90,108,90)
            # no resize needed
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        if dataset == "Schizo":  # (91,109,91)
            return img4d[:, 9:100, :]  # (91,91,91)

    elif resolution == "2.5mm":
        if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):
            if dataset == "HCP":  # (145,174,145) -> needs downsampling first
                img4d = img_utils.resize_first_three_dims(img4d, zoom=0.5)  # (73,87,73,none)
            # (73,87,73,none): pad x and z to 80, cut y to 80
            return _paste_into_unet_cube(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])  # (80,80,80)
        if dataset == "TRACED":  # (78,93,75)
            # no resize needed
            return _paste_into_unet_cube(img4d, np.s_[1:79, :, 3:78, :], np.s_[:, 7:87, :, :])  # (80,80,80)

    # Every other combination (e.g. TRACED at 1.25mm/2mm, Schizo at 2.5mm)
    # used to fall through and silently return None. The message keeps the
    # format of the former explicit TRACED errors.
    raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))
コード例 #9
0
def scale_input_to_unet_shape(img4d, dataset, resolution="1.25mm"):
    '''
    Scale input image to right isotropic resolution and pad/cut image to make it square to fit UNet input shape

    :param img4d: (x, y, z, userdefined)  (userdefined could be gradients or classes)
    :param dataset: HCP|HCP_32g|HCP_2mm|HCP_2.5mm|TRACED|Schizo (which ones are
                    supported depends on the resolution, see the branches below)
    :param resolution: "1.25mm" / "2mm" / "2.5mm"     results in UNet input shape of (144,144,144) or (80,80,80)
    :return: img with dim 1.25mm: (144,144,144,none) or 2mm: (80,80,80,none) or 2.5mm: (80,80,80,none)
                (note: at 2.5mm the image is padded to 80,80,80 with the value of
                its corner voxel img4d[0,0,0,:], not with zeros)
    :raises ValueError: if the dataset/resolution combination is not supported
    '''
    if resolution == "1.25mm":
        if dataset == "HCP":  # (145,174,145)
            # no resize needed
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        if dataset == "HCP_32g":  # (73,87,73)
            img4d = img_utils.resize_first_three_dims(img4d, zoom=2)  # (146,174,146,none)
            # remove one voxel that came from upsampling -> (145,174,145)
            img4d = img4d[:-1, :, :-1]
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        if dataset == "Schizo":  # (91,109,91)
            img4d = img_utils.resize_first_three_dims(img4d, zoom=1.60)  # (146,174,146)
            return img4d[1:145, 15:159, 1:145]  # (144,144,144)

    elif resolution == "2mm":
        if dataset in ("HCP", "HCP_32g"):  # (145,174,145)
            img4d = img_utils.resize_first_three_dims(img4d, zoom=0.62)  # (90,108,90)
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        if dataset == "HCP_2mm":  # (90,108,90)
            # no resize needed
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        if dataset == "Schizo":  # (91,109,91)
            return img4d[:, 9:100, :]  # (91,91,91)

    elif resolution == "2.5mm":
        if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):
            if dataset == "HCP":  # (145,174,145) -> needs downsampling first
                img4d = img_utils.resize_first_three_dims(img4d, zoom=0.5)  # (73,87,73,none)
            # (73,87,73,none)
            bg = np.zeros((80, 80, 80, img4d.shape[3])).astype(img4d.dtype)
            # make bg have same value as bg from original img  (this adds last dim of img4d to last dim of bg)
            bg = bg + img4d[0, 0, 0, :]
            bg[4:77, :, 4:77] = img4d[:, 4:84, :, :]
            return bg  # (80,80,80)
        if dataset == "TRACED":  # (78,93,75)
            # no resize needed
            bg = np.zeros((80, 80, 80, img4d.shape[3])).astype(img4d.dtype)
            bg = bg + img4d[0, 0, 0, :]  # make bg have same value as bg from original img
            bg[1:79, :, 3:78, :] = img4d[:, 7:87, :, :]
            return bg  # (80,80,80)

    # Every other combination (e.g. TRACED at 1.25mm/2mm, Schizo at 2.5mm)
    # used to fall through and silently return None. The message keeps the
    # format of the former explicit TRACED errors.
    raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))