Beispiel #1
0
def ImageInMultiROIOut2D(root_folder,
                         input_shape,
                         batch_size=8,
                         hierarchical_level=0,
                         augment_param=None):
    """Yield endless batches of 2D images paired with a down-sampled ROI pyramid.

    Every ``.h5`` file in ``root_folder`` is one case. It must contain the
    datasets ``input_0`` ... ``input_{k-1}`` (k from ``_GetInputOutputNumber``)
    and exactly one ``output_0`` ROI. 2D inputs get a trailing channel axis.

    For each case:
    - inputs and ROI share one random augmentation;
    - both are cropped to ``input_shape``;
    - the ROI is resized by factors 2, 4, ..., ``2 ** hierarchical_level``.

    Args:
        root_folder: Directory holding the ``.h5`` cases.
        input_shape: Crop shape forwarded to ``_CropDataList2D``.
        batch_size: Number of cases collected before a batch is yielded.
        hierarchical_level: Number of extra down-sampled ROI levels.
        augment_param: Augmentation ranges passed to
            ``AugmentParametersGenerator.RandomParameters``. ``None`` means
            an empty dict.

    Yields:
        ``(inputs, outputs)`` as built by ``_MakeKerasFormat``. ``outputs``
        holds ``hierarchical_level + 1`` entries, ordered from the coarsest
        ROI level to the full-resolution ROI.
    """
    if augment_param is None:
        augment_param = {}

    input_number, output_number = _GetInputOutputNumber(root_folder)
    assert (output_number == 1)
    case_list = os.listdir(root_folder)

    input_list = [[] for _ in range(input_number)]
    output_list = [[] for _ in range(hierarchical_level + 1)]

    param_generator = AugmentParametersGenerator()
    augmentor = DataAugmentor2D()

    while True:
        shuffle(case_list)
        for case in case_list:
            case_path = os.path.join(root_folder, case)
            if not case_path.endswith('.h5'):
                continue

            # Read everything eagerly so the HDF5 handle is released per case.
            input_data_list = []
            with h5py.File(case_path, 'r') as file:
                for input_number_index in range(input_number):
                    temp_data = np.asarray(file['input_' +
                                                str(input_number_index)])
                    if temp_data.ndim == 2:
                        temp_data = temp_data[..., np.newaxis]
                    input_data_list.append(temp_data)
                one_roi = np.asarray(file['output_0'])

            # One random draw per case so inputs and ROI get the same transform.
            param_generator.RandomParameters(augment_param)
            augmentor.SetParameter(param_generator.GetRandomParametersDict())

            input_data_list = _AugmentDataList2D(input_data_list, augmentor)
            one_roi = _AugmentDataList2D([one_roi], augmentor)[0]

            input_data_list = _CropDataList2D(input_data_list, input_shape)
            one_roi = _CropDataList2D([one_roi], input_shape)[0]

            one_roi_list = [one_roi]
            for index in range(1, hierarchical_level + 1):
                factor = 2 ** index
                # cv2.resize takes dsize as (width, height) = (cols, rows).
                temp_roi = cv2.resize(one_roi,
                                      (one_roi.shape[1] // factor,
                                       one_roi.shape[0] // factor),
                                      interpolation=cv2.INTER_LINEAR)
                # NOTE(review): only the down-sampled levels get an explicit
                # channel axis here; confirm one_roi already carries one.
                one_roi_list.insert(0, temp_roi[..., np.newaxis])

            _AddOneSample(input_list, input_data_list)
            _AddOneSample(output_list, one_roi_list)

            if len(input_list[0]) >= batch_size:
                inputs = _MakeKerasFormat(input_list)
                outputs = _MakeKerasFormat(output_list)
                yield inputs, outputs
                input_list = [[] for _ in range(input_number)]
                output_list = [[] for _ in range(hierarchical_level + 1)]
Beispiel #2
0
def ImageInImageOut2D(root_folder, input_shape, batch_size=8, augment_param=None):
    """Collect one augmented, cropped batch of 2D image-in / image-out samples.

    Every ``.h5`` file in ``root_folder`` is one case. It must contain
    ``input_<i>`` and ``output_<j>`` datasets; the counts come from
    ``_GetInputOutputNumber``. 2D datasets get a trailing channel axis.
    Inputs and outputs of a case share one random augmentation and are
    cropped to ``input_shape`` with ``ImageProcess2D.CropDataList2D``.
    Cases that fail to load are reported with ``print`` and skipped.

    Args:
        root_folder: Directory holding the ``.h5`` cases.
        input_shape: Crop shape forwarded to ``CropDataList2D``.
        batch_size: Number of cases collected before returning.
        augment_param: Augmentation ranges passed to
            ``AugmentParametersGenerator.RandomParameters``. ``None`` means
            an empty dict.

    Returns:
        ``(inputs, outputs)`` from ``_MakeKerasFormat`` once ``batch_size``
        cases have been collected.

    NOTE(review): the ``while True`` loop suggests a generator was intended,
    but this function ``return``s the first batch. Callers currently unpack
    a tuple, so switching to ``yield`` would be a breaking change. Confirm
    before changing. If no case can be loaded, the loop never terminates.
    """
    if augment_param is None:
        augment_param = {}

    input_number, output_number = _GetInputOutputNumber(root_folder)
    case_list = os.listdir(root_folder)

    input_list = [[] for _ in range(input_number)]
    output_list = [[] for _ in range(output_number)]

    param_generator = AugmentParametersGenerator()
    augmentor = DataAugmentor2D()
    crop = ImageProcess2D()

    while True:
        shuffle(case_list)
        for case in case_list:
            case_path = os.path.join(root_folder, case)
            if not case_path.endswith('.h5'):
                continue

            input_data_list, output_data_list = [], []
            try:
                # The context manager closes the file even when a read fails.
                with h5py.File(case_path, 'r') as file:
                    for input_number_index in range(input_number):
                        temp_data = np.asarray(file['input_' + str(input_number_index)])
                        if temp_data.ndim == 2:
                            temp_data = temp_data[..., np.newaxis]
                        input_data_list.append(temp_data)
                    for output_number_index in range(output_number):
                        temp_data = np.asarray(file['output_' + str(output_number_index)])
                        if temp_data.ndim == 2:
                            temp_data = temp_data[..., np.newaxis]
                        output_data_list.append(temp_data)
            except Exception as e:
                # Best effort: report the broken case and move on to the next one.
                print(case_path)
                print(e)
                continue

            # One random draw per case so inputs and outputs get the same transform.
            param_generator.RandomParameters(augment_param)
            augmentor.SetParameter(param_generator.GetRandomParametersDict())

            input_data_list = _AugmentDataList2D(input_data_list, augmentor)
            output_data_list = _AugmentDataList2D(output_data_list, augmentor)

            input_data_list = crop.CropDataList2D(input_data_list, input_shape)
            output_data_list = crop.CropDataList2D(output_data_list, input_shape)

            _AddOneSample(input_list, input_data_list)
            _AddOneSample(output_list, output_data_list)

            if len(input_list[0]) >= batch_size:
                inputs = _MakeKerasFormat(input_list)
                outputs = _MakeKerasFormat(output_list)
                return inputs, outputs
def AugmentScript(t2_path=r'H:/data/TZ roi/BIAN ZHONG BEN/t2.nii',
                  roi_path=r'H:/data/TZ roi/BIAN ZHONG BEN/prostate_roi_TrumpetNet.nii',
                  slice_index=10):
    """Interactively preview random 2D augmentations of a T2 slice and its ROI.

    Steps:
    1. Load a T2 volume (float32) and an ROI volume (uint8) with
       ``LoadNiiData`` and show them with ``Imshow3DArray``.
    2. Repeatedly draw random augmentation parameters.
    3. Apply the same transform to slice ``slice_index`` (last axis) of both
       volumes.
    4. Show the normalized augmented image next to the augmented ROI.

    ``plt.show()`` is called on every iteration, so with a blocking matplotlib
    backend the next sample appears once the window is closed. The loop never
    ends on its own.

    Args:
        t2_path: Path of the T2 NIfTI file.
        roi_path: Path of the ROI NIfTI file.
        slice_index: Index along the last axis of the slice to augment.
    """
    from MeDIT.SaveAndLoad import LoadNiiData
    from MeDIT.Visualization import Imshow3DArray
    from MeDIT.Normalize import Normalize01
    from MeDIT.DataAugmentor import DataAugmentor2D, AugmentParametersGenerator
    import matplotlib.pyplot as plt
    import numpy as np

    _, _, t2 = LoadNiiData(t2_path, dtype=np.float32)
    _, _, roi = LoadNiiData(roi_path, dtype=np.uint8)

    Imshow3DArray(t2, ROI=roi)

    t2_slice = t2[..., slice_index]
    roi_slice = roi[..., slice_index]

    param_dict = {
        'stretch_x': 0.1,
        'stretch_y': 0.1,
        'shear': 0.1,
        'rotate_z_angle': 20,
        'horizontal_flip': True
    }

    augment_generator = AugmentParametersGenerator()
    augmentor = DataAugmentor2D()

    while True:
        augment_generator.RandomParameters(param_dict)
        transform_param = augment_generator.GetRandomParametersDict()
        print(transform_param)

        augment_t2 = augmentor.Execute(t2_slice,
                                       aug_parameter=transform_param,
                                       interpolation_method='linear')
        # NOTE(review): linear interpolation on a binary ROI produces
        # fractional edge values; confirm nearest-neighbour is not wanted.
        augment_roi = augmentor.Execute(roi_slice,
                                        aug_parameter=transform_param,
                                        interpolation_method='linear')

        # Show the augmented pair side by side so each draw can be inspected.
        plt.imshow(np.concatenate((Normalize01(augment_t2),
                                   Normalize01(augment_roi)), axis=1),
                   cmap='gray')
        plt.show()
Beispiel #4
0
def ImageInImageOut2D(root_folder, input_shape, batch_size=8, augment_param=None):
    """Yield endless batches of augmented, cropped 2D image-in / image-out samples.

    Every ``.h5`` file in ``root_folder`` is one case. It must contain
    ``input_<i>`` and ``output_<j>`` datasets; the counts come from
    ``GetInputOutputNumber``. Inputs and outputs of a case share one random
    augmentation and are cropped to ``input_shape`` with ``CropDataList2D``.

    Args:
        root_folder: Directory holding the ``.h5`` cases.
        input_shape: Crop shape forwarded to ``CropDataList2D``.
        batch_size: Number of cases per yielded batch.
        augment_param: Augmentation ranges passed to
            ``AugmentParametersGenerator.RandomParameters``. ``None`` means
            an empty dict.

    Yields:
        ``(inputs, outputs)`` as built by ``MakeKerasFormat``.
    """
    import numpy as np

    if augment_param is None:
        augment_param = {}

    input_number, output_number = GetInputOutputNumber(root_folder)
    case_list = os.listdir(root_folder)

    input_list = [[] for _ in range(input_number)]
    output_list = [[] for _ in range(output_number)]

    param_generator = AugmentParametersGenerator()
    augmentor = DataAugmentor2D()

    while True:
        shuffle(case_list)
        for case in case_list:
            case_path = os.path.join(root_folder, case)
            if not case_path.endswith('.h5'):
                continue

            # Materialize the datasets as arrays so the file can be closed
            # here; h5py Dataset objects cannot be read after their file closes.
            with h5py.File(case_path, 'r') as file:
                input_data_list = [np.asarray(file['input_' + str(index)])
                                   for index in range(input_number)]
                output_data_list = [np.asarray(file['output_' + str(index)])
                                    for index in range(output_number)]

            # One random draw per case so inputs and outputs get the same transform.
            param_generator.RandomParameters(augment_param)
            augmentor.SetParameter(param_generator.GetRandomParametersDict())

            input_data_list = AugmentDataList2D(input_data_list, augmentor)
            output_data_list = AugmentDataList2D(output_data_list, augmentor)

            input_data_list = CropDataList2D(input_data_list, input_shape)
            output_data_list = CropDataList2D(output_data_list, input_shape)

            AddOneSample(input_list, input_data_list)
            AddOneSample(output_list, output_data_list)

            if len(input_list[0]) >= batch_size:
                inputs = MakeKerasFormat(input_list)
                outputs = MakeKerasFormat(output_list)
                yield inputs, outputs
                input_list = [[] for _ in range(input_number)]
                output_list = [[] for _ in range(output_number)]
Beispiel #5
0
def GenerateMultiInputOneOutput_From2DMultiSliceTo2D(root_folder,
                                                     input_shape,
                                                     batch_size=8,
                                                     augment_config='',
                                                     is_yield=True):
    """Build one batch of multi-input, multi-slice 2D samples with a 2D target.

    Each case file in ``root_folder`` is read with ``LoadH5``:
    - every ``input_<i>`` dataset is processed slice by slice along its last
      axis; each slice is optionally augmented, then cropped with
      ``ExtractPatch`` to ``input_shape``;
    - ``output_0`` is squeezed, optionally augmented with the same random
      parameters, cropped, and written to channel 0 of the target array.

    Args:
        root_folder: Directory of case files.
        input_shape: Patch size passed to ``ExtractPatch``.
        batch_size: Number of cases in the returned batch.
        augment_config: Path of a JSON file with augmentation ranges. An empty
            string disables augmentation.
        is_yield: Currently unused. NOTE(review): the loop looks like it was
            meant to ``yield`` batches when this is true, but the function
            returns a single batch and callers rely on that tuple.

    Returns:
        ``(input_list, output_array)``:
        - ``input_list`` holds one array per input, each shaped
          ``(batch_size, input_shape[0], input_shape[1], n_slices)``;
        - ``output_array`` is shaped
          ``(batch_size, input_shape[0], input_shape[1], C)``, where C is the
          last dimension of the first case's ``output_0``.

        Returns ``None`` (after printing a message) when the folder is empty
        or when the first case has fewer than two inputs.
    """
    if augment_config:
        with open(augment_config, 'r') as file:
            random_params = json.load(file)
        param_generator = AugmentParametersGenerator()
        aug_generator = DataAugmentor2D()

    case_list = os.listdir(root_folder)
    if not case_list:
        # Same report-and-return style as the multi-input check below,
        # instead of crashing with an IndexError on case_list[0].
        print('No case found in ', root_folder)
        return

    shuffle(case_list)
    one_path = os.path.join(root_folder, case_list[0])
    info = LoadH5InfoForGenerate(one_path)

    if info['input_number'] <= 1:
        print('Need Multi Input ', one_path)
        return

    input_list = [[] for _ in range(info['input_number'])]
    one_output = LoadH5(one_path, tag='output_0')
    # NOTE(review): only channel 0 is filled below; any further channels of
    # this array stay zero. Confirm that is intended.
    output_array = np.zeros(
        (batch_size, input_shape[0], input_shape[1], one_output.shape[-1]))
    current_batch = 0

    while True:
        shuffle(case_list)
        for case in case_list:
            case_path = os.path.join(root_folder, case)

            if augment_config:
                # One random draw per case so every input and the target share it.
                param_generator.RandomParameters(random_params)
                aug_generator.SetParameter(
                    param_generator.GetRandomParametersDict())

            for input_index in range(info['input_number']):
                data = LoadH5(case_path, tag='input_{:d}'.format(input_index))

                crop_data = np.zeros(
                    (input_shape[0], input_shape[1], data.shape[-1]))
                for slice_index in range(data.shape[-1]):
                    one_slice = data[..., slice_index]
                    if augment_config:
                        one_slice = aug_generator.Execute(
                            one_slice, interpolation_method='linear')
                    crop_data[..., slice_index] = ExtractPatch(
                        one_slice, patch_size=input_shape)[0]
                input_list[input_index].append(crop_data)

            target = np.squeeze(LoadH5(case_path, tag='output_0'))
            if augment_config:
                target = aug_generator.Execute(target,
                                               interpolation_method='linear')
            output_array[current_batch, ..., 0] = ExtractPatch(
                target, patch_size=input_shape)[0]

            current_batch += 1
            if current_batch >= batch_size:
                # Returns (not yields) the first full batch; see ``is_yield``.
                return [np.asarray(temp) for temp in input_list], output_array
Beispiel #6
0
def AugmentTrain(train_folder, batch_size, num_classes=None):
    """Yield batches of cropped images and one-hot ROIs, each paired with an augmented copy.

    Every file in ``train_folder`` is opened with h5py and must contain
    ``input_0`` (image, read as float32) and ``output_0`` (label, read as
    uint8).

    For each file, the image and label are augmented with one shared random
    transform. The original and the augmented pair are then cropped to
    240x240:
    - rows/cols 100:340 when the augmented image is 440x440;
    - rows/cols 60:300 otherwise.

    Both pairs are appended in that order. The folder is traversed once;
    samples left over after the last full batch are not yielded.

    Args:
        train_folder: Directory of h5 files.
        batch_size: Number of samples per yielded batch (must be >= 1).
        num_classes: Forwarded to ``to_categorical``. ``None`` (the default)
            infers the class count from each cropped label. NOTE(review): a
            crop without foreground then gets fewer one-hot channels than
            others, which breaks stacking; pass an explicit count to avoid it.

    Yields:
        ``(images, labels)``: ``images`` is shaped (batch_size, 240, 240, 1);
        ``labels`` holds the matching one-hot arrays.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    # Augmentation ranges are identical for every file, so the parameter
    # generator and the 2D augmentor are created once.
    param_dict = {
        'stretch_x': 0.1,
        'stretch_y': 0.1,
        'shear': 0.1,
        'rotate_z_angle': 20,
        'horizontal_flip': True
    }
    augment_generator = AugmentParametersGenerator()
    augmentor = DataAugmentor2D()

    crop_size = 240
    reshape = (crop_size, crop_size, 1)

    image_list = []
    label_list = []
    for file_name in os.listdir(train_folder):
        data_path = os.path.join(train_folder, file_name)
        # Read eagerly so the file handle is closed before augmentation.
        with h5py.File(data_path, 'r') as h5_file:
            image = np.asarray(h5_file['input_0'], dtype=np.float32)
            label = np.asarray(h5_file['output_0'], dtype=np.uint8)

        # One random transform shared by the image and its label.
        augment_generator.RandomParameters(param_dict)
        transform_param = augment_generator.GetRandomParametersDict()
        augment_image = augmentor.Execute(image,
                                          aug_parameter=transform_param,
                                          interpolation_method='linear')
        # NOTE(review): linear interpolation on a label map can create
        # non-integer values before to_categorical; confirm nearest is not wanted.
        augment_roi = augmentor.Execute(label,
                                        aug_parameter=transform_param,
                                        interpolation_method='linear')

        # Fixed 240x240 window: centred for 440x440 images; the 60:300 window
        # presumably targets 360x360 images (TODO confirm).
        start = 100 if np.shape(augment_image) == (440, 440) else 60
        stop = start + crop_size
        crop_image = image[start:stop, start:stop]
        crop_roi = label[start:stop, start:stop]
        crop_augment_image = augment_image[start:stop, start:stop]
        crop_augment_roi = augment_roi[start:stop, start:stop]

        roi_onehot = to_categorical(crop_roi, num_classes=num_classes)
        augment_roi_onehot = to_categorical(crop_augment_roi,
                                            num_classes=num_classes)

        image_list.append(crop_image.reshape(reshape))
        label_list.append(roi_onehot)
        image_list.append(crop_augment_image.reshape(reshape))
        label_list.append(augment_roi_onehot)

        # Samples arrive two at a time. Emit exactly batch_size per batch and
        # carry any remainder over, so an odd batch_size is honoured.
        while len(image_list) >= batch_size:
            yield (np.asarray(image_list[:batch_size]),
                   np.asarray(label_list[:batch_size]))
            del image_list[:batch_size]
            del label_list[:batch_size]