Exemple #1
0
def cut(impath, outdir, cancer, cptac2=False):
    """Tile every slide listed under *impath* into per-slide level folders.

    A ``sum.csv`` with the columns ``id``, ``dir`` and ``sld`` is written to
    *outdir*.  Each slide is then tiled at output levels 1-3 into
    ``<outdir>/<dir>/<sld>/level<m>``.  A slide whose ``<dir>/<sld>`` folder
    already exists is skipped.  A level folder holding fewer than two entries
    after tiling is removed, and a slide folder left with fewer than three
    entries is removed as well.

    :param impath: directory holding the slide image files
    :param outdir: root output directory; created (with parents) if missing
    :param cancer: passed to ``image_ids_in_2`` as ``can``; unused when
        *cptac2* is false
    :param cptac2: when true, slides are listed with ``image_ids_in_2`` and
        carry a fourth ``scan`` field that selects the tiling parameters
    """
    # Output level m -> (ft, level) arguments handed to Slicer.tile.
    # The original if/elif ladder also had an m == 0 case that range(1, 4)
    # could never reach; it is dropped here.
    washu_params = {1: (1, 1), 2: (2, 1), 3: (1, 2)}
    default_params = {1: (2, 0), 2: (1, 1), 3: (2, 1)}

    os.makedirs(outdir, exist_ok=True)
    # load standard image for normalization
    std = staintools.read_image("../colorstandard.png")
    std = staintools.LuminosityStandardizer.standardize(std)
    # cut tiles with coordinates in the name (exclude white)
    start_time = time.time()
    if cptac2:
        CPTAClist = image_ids_in_2(impath, can=cancer)
        CPTACpp = pd.DataFrame(CPTAClist, columns=['id', 'dir', 'sld', 'scan'])
        CPTACpp[['id', 'dir', 'sld']].to_csv(outdir + '/sum.csv', index=False)
    else:
        CPTAClist = image_ids_in(impath)
        CPTACpp = pd.DataFrame(CPTAClist, columns=['id', 'dir', 'sld'])
        CPTACpp.to_csv(outdir + '/sum.csv', index=False)

    for entry in CPTAClist:
        os.makedirs("{}/{}".format(outdir, entry[1]), exist_ok=True)
        outfolder = "{}/{}/{}".format(outdir, entry[1], entry[2])
        try:
            os.mkdir(outfolder)
        except FileExistsError:
            # An existing slide folder means this slide was handled before.
            continue

        is_washu = cptac2 and entry[3] == "Scanned by WashU"
        params = washu_params if is_washu else default_params
        for m in range(1, 4):
            if is_washu:
                print("Alert: 40X")
            tff, level = params[m]
            otdir = "{}/level{}".format(outfolder, str(m))
            os.makedirs(otdir, exist_ok=True)
            try:
                Slicer.tile(image_file=impath + '/' + entry[0],
                            outdir=otdir,
                            level=level,
                            std_img=std,
                            ft=tff)
            except Exception as err:
                # Best effort: report the failure and move on to the next
                # level so one bad slide does not abort the whole run.
                print(type(err))
                print(err)
            if len(os.listdir(otdir)) < 2:
                shutil.rmtree(otdir, ignore_errors=True)
        if len(os.listdir(outfolder)) < 3:
            shutil.rmtree(outfolder, ignore_errors=True)
            print(outfolder + ' has less than 3 levels. Deleted!')

    print("--- %s seconds ---" % (time.time() - start_time))
import staintools
from matplotlib import pyplot as plt
# Demo script: stain-normalize "test.png" against the reference "from.png"
# and display the result with matplotlib.

# Read data
# `target` is the reference image the normalizer is fitted to;
# `to_transform` is the image whose staining is adjusted.
target = staintools.read_image("from.png")
to_transform = staintools.read_image("test.png")

# Standardize brightness (This step is optional but can improve the tissue mask calculation)
# The same standardizer instance is applied to both images.
standardizer = staintools.BrightnessStandardizer()
target = standardizer.transform(target)
to_transform = standardizer.transform(to_transform)

# Stain normalize
# Fit on the reference first, then transform the other image.
normalizer = staintools.StainNormalizer(method='vahadane')
normalizer.fit(target)
transformed = normalizer.transform(to_transform)

# Show the normalized image in a blocking matplotlib window.
plt.imshow(transformed)
plt.show()
#
# colorizer = staintools.ReinhardColorNormalizer()
# colorizer.fit(target)


def transform(image, target):
    """Brightness-standardize a reference image (fragment; body is cut off).

    NOTE(review): in the visible part the ``target`` argument is overwritten
    immediately by the image read from "from.png", and ``image`` is not used;
    "test.png" is read instead. The remainder of the function is not visible
    here, so the intended return value cannot be documented.
    """
    # Read data
    # Both paths are hard-coded; the incoming `target` value is discarded here.
    target = staintools.read_image("from.png")
    to_transform = staintools.read_image("test.png")

    # Standardize brightness (This step is optional but can improve the tissue mask calculation)
    standardizer = staintools.BrightnessStandardizer()
    target = standardizer.transform(target)
Exemple #3
0
def cutter(img, outdirr, imgdir, dp=None, resolution=None):
    """Tile one slide into ``<outdirr>/level1`` .. ``level3``.

    Two parameter schemes exist for the three output levels.  ``resolution``
    40 selects one and ``resolution`` 20 the other.  For any other value the
    first scheme is used when the string "TCGA" occurs in *img*, and the
    second otherwise.  A failure while tiling one level is reported and the
    remaining levels are still attempted.

    :param img: slide file name, passed to ``Slicer.tile`` as ``image_file``
    :param outdirr: output directory; created (with parents) if missing
    :param imgdir: passed to ``Slicer.tile`` as ``path_to_slide``
    :param dp: passed through to ``Slicer.tile`` unchanged
    :param resolution: 20, 40, or anything else to decide from *img*
    """
    os.makedirs(outdirr, exist_ok=True)
    import panoptes
    # load standard image for normalization
    std = staintools.read_image("{}/colorstandard.png".format(panoptes.__path__[0]))
    std = staintools.LuminosityStandardizer.standardize(std)

    if resolution == 40:
        high_res_scheme = True
    elif resolution == 20:
        high_res_scheme = False
    else:
        high_res_scheme = "TCGA" in img

    # (ft, level) for output levels 1, 2, 3.  These are the values the
    # original arithmetic produced:
    #   level = int(m / 3 + 1), ft = int(m / level)  -> first tuple
    #   level = int(m / 2),     ft = int(m % 2 + 1)  -> second tuple
    if high_res_scheme:
        params = ((1, 1), (2, 1), (1, 2))
    else:
        params = ((2, 0), (1, 1), (2, 1))

    for m, (tff, level) in enumerate(params, start=1):
        otdir = "{}/level{}".format(outdirr, str(m))
        os.makedirs(otdir, exist_ok=True)
        try:
            Slicer.tile(image_file=img, outdir=otdir,
                        level=level, std_img=std, ft=tff, dp=dp,
                        path_to_slide=imgdir)
        except Exception as err:
            # Best effort per level, but say what actually went wrong
            # instead of a bare 'Error!'.
            print('Error!')
            print(type(err))
            print(err)
Exemple #4
0
def get_stain_normalizer(path='/path/to/reference/image', method='vahadane'):
    """Build a stain normalizer fitted to the reference image at *path*.

    The reference image is luminosity-standardized before fitting.

    :param path: file path of the reference image
    :param method: normalization method name given to ``StainNormalizer``
    :return: the fitted ``staintools.StainNormalizer``
    """
    reference = staintools.LuminosityStandardizer.standardize(
        staintools.read_image(path))
    fitted = staintools.StainNormalizer(method=method)
    fitted.fit(reference)
    return fitted
Exemple #5
0
def _tile_cohort(slides, prefix, level_params, wanted, std):
    """Tile every slide of one cohort whose ``dir`` field is in *wanted*."""
    for slide in slides:
        if slide[1] not in wanted:
            continue
        os.makedirs("../tiles/{}".format(slide[1]), exist_ok=True)
        for m, (tff, level) in enumerate(level_params):
            otdir = "../tiles/{}/level{}".format(slide[1], str(m))
            os.makedirs(otdir, exist_ok=True)
            try:
                Slicer.tile(image_file=prefix + slide[0],
                            outdir=otdir,
                            level=level,
                            std_img=std,
                            dp=slide[2],
                            ft=tff)
            except Exception as err:
                # Best effort: report and continue with the next level.
                print('Error!')
                print(type(err))
                print(err)
            if len(os.listdir(otdir)) < 2:
                shutil.rmtree(otdir, ignore_errors=True)


def cut():
    """Tile the CPTAC and TCGA slides that appear in the reference table.

    Slides are listed from ``../images/CPTAC/`` and ``../images/TCGA/``.
    Only those whose ``dir`` field matches a ``name`` in
    ``../dummy_His_MUT_joined.csv`` are tiled, at four output levels, into
    ``../tiles/<dir>/level<m>``.  Level folders holding fewer than two
    entries after tiling are removed.  Finally, every folder under
    ``../tiles/`` whose name is not in the reference table is printed.

    Both cohorts now share one loop and one error policy: any exception from
    ``Slicer.tile`` is reported and the run continues.  Previously a CPTAC
    slide raising anything other than ``IndexError`` aborted the whole run,
    while TCGA failures were already skipped.
    """
    # (ft, level) for output levels 0-3 of each cohort.
    cptac_params = ((1, 0), (2, 0), (1, 1), (2, 1))
    tcga_params = ((2, 0), (1, 1), (2, 1), (1, 2))

    # load standard image for normalization
    std = staintools.read_image("../colorstandard.png")
    std = staintools.LuminosityStandardizer.standardize(std)
    CPTACpath = '../images/CPTAC/'
    TCGApath = '../images/TCGA/'
    ref = pd.read_csv('../dummy_His_MUT_joined.csv', header=0)
    # A set gives constant-time membership for the per-slide filter and the
    # final folder check.
    known_names = set(ref['name'].tolist())
    # cut tiles with coordinates in the name (exclude white)
    start_time = time.time()
    CPTAClist = image_ids_in(CPTACpath, 'CPTAC')
    TCGAlist = image_ids_in(TCGApath, 'TCGA')

    # Report the `dir` values that occur more than once in each cohort.
    for listing in (CPTAClist, TCGAlist):
        counts = pd.DataFrame(
            listing, columns=['id', 'dir', 'sld'])['dir'].value_counts()
        print(counts[counts > 1].index.tolist())

    _tile_cohort(CPTAClist, 'CPTAC/', cptac_params, known_names, std)
    _tile_cohort(TCGAlist, 'TCGA/', tcga_params, known_names, std)

    print("--- %s seconds ---" % (time.time() - start_time))
    subfolders = [f.name for f in os.scandir('../tiles/') if f.is_dir()]
    for w in subfolders:
        if w not in known_names:
            print(w)
@author: SENSETIME\yuxian
"""
import os
import staintools
# Batch script: brightness-standardize and stain-normalize every '.jpeg'
# patch in `img_dir`, writing the result under the same file name in
# `save_dir`.
# NOTE(review): no `stain_normalizer.fit(...)` call is visible in this
# fragment - confirm the normalizer is fitted to a reference image before
# `transform` is used below.
stain_normalizer = staintools.StainNormalizer(method='vahadane')
standardizer = staintools.BrightnessStandardizer()  # brightness standardization

# Select the patch set to process by leaving exactly one line uncommented.
#patch_dir = 'PATCHES_NORMAL_TRAIN'
#patch_dir = 'PATCHES_TUMOR_TRAIN'
#patch_dir = 'PATCHES_NORMAL_VALID'
patch_dir = 'PATCHES_TUMOR_VALID'

# Input and output roots are hard-coded absolute paths.
# NOTE(review): `save_dir` is not created here - presumably it must exist
# already; verify.
img_dir = '/mnt/lustre/yuxian/Code/NCRF-master/Data/1024/' + patch_dir
save_dir = '/mnt/lustre/yuxian/Code/NCRF-master/Data/Patch_Stain_Norm/1024/SN/' + patch_dir
img_files = os.listdir(img_dir)

# `num` counts processed '.jpeg' files; `total` counts every directory
# entry, so the printed "num/total" can stop short of total.
num = 0
total = len(img_files)

for img_file in img_files:
    # Substring match: any name containing '.jpeg' is processed.
    if '.jpeg' in img_file:
        num += 1
        print('%d/%d with img name %s' % (num, total, img_file))

        Original_img = staintools.read_image(os.path.join(img_dir, img_file))
        # The bare string below is a no-op statement used as a comment; it
        # reads "illumination information" and labels the brightness step.
        '''光照信息'''
        img_standard = standardizer.transform(Original_img)

        img_standard_normalized = stain_normalizer.transform(img_standard)

        # NOTE(review): `.save` is called on the transform result - confirm
        # the returned object provides it (a plain numpy array would not).
        img_standard_normalized.save(os.path.join(save_dir, img_file))
Exemple #7
0
def main():  # run as main
    """Command-line entry point: tile ``.scn`` slides and export the tiles.

    For every ``.scn`` file in ``--scn_dir`` the slide is tiled with
    ``tile``.  Depending on which output directories are given, the tiles
    are saved as png files (``--png_dir``), the tile table as
    ``<slide>_dict.csv`` (``--dic_dir``), and the tile arrays as a
    ``<slide>.tfrecords`` file (``--tfr_dir``).  Slide-level labels are read
    from ``--slide_lab`` as ``key,value`` lines.  A slide without a label,
    or a run without a label file, is coded as 999.

    :raises AssertionError: if the tiles of one slide differ in shape
    """
    import argparse
    parser = argparse.ArgumentParser(description='Neutrophil MIL data prep')
    parser.add_argument('--scn_dir',
                        type=str,
                        default='.',
                        help='directory of scn files.')
    parser.add_argument('--png_dir',
                        type=str,
                        default=None,
                        help='output directory of png files.')
    parser.add_argument('--tfr_dir',
                        type=str,
                        default=None,
                        help='output directory of TFRecords.')
    parser.add_argument('--dic_dir',
                        type=str,
                        default=None,
                        help='output directory of tile dictionaries.')
    parser.add_argument('--tile_size',
                        type=int,
                        default=299,
                        help='tile size pix')
    parser.add_argument('--overlap',
                        type=int,
                        default=49,
                        help='tile overlap pix')
    parser.add_argument('--standard',
                        type=str,
                        default=None,
                        help='standard image for color normalization')
    parser.add_argument('--slide_lab',
                        type=str,
                        default=None,
                        help='txt file containing slide level labels')

    args = parser.parse_args()
    for arg in vars(args):
        print(arg, getattr(args, arg))

    # create output directories
    for out_root in (args.png_dir, args.dic_dir, args.tfr_dir):
        if out_root:
            os.makedirs(out_root, exist_ok=True)

    if args.standard:  # standard image for color normalization
        std = staintools.read_image(args.standard)
        std = staintools.LuminosityStandardizer.standardize(std)
    else:
        # An empty array is what `tile` receives when no standard is given.
        std = np.asarray([])

    # read slide level file as dictionary, if provided
    # None (rather than an undefined name) marks "no label file".
    s_lab = None
    if args.slide_lab:
        s_lab = {}
        with open(args.slide_lab) as label_file:
            for line in label_file:
                line = line.replace('"', '').strip()
                if not line:
                    # A blank or trailing empty line is not a record.
                    continue
                key, val = line.split(',')
                s_lab[key] = val

    # list all scn files
    scn_ls = [name for name in os.listdir(args.scn_dir)
              if name.endswith('.scn')]

    # iterate through all scn files
    for f in scn_ls:
        f_id = f.split('.')[0]
        print('Slide id: {}'.format(f_id))

        if args.png_dir:  # whether output png files
            out_dir = args.png_dir + '/' + f_id
            os.makedirs(out_dir, exist_ok=True)
        else:
            out_dir = None

        # read slide and get tiles using multiple processing.
        # save png files if asked; save tile info and image arrays in lists.
        n_x, n_y, lowres, residue_x, residue_y, imglist, imlocpd, ct =\
            tile(f, f_id, out_dir=out_dir, std_img=std, path_to_slide=args.scn_dir,
                 tile_size=args.tile_size, overlap=args.overlap)
        print('number columns:{}'.format(n_x))
        print('number of rows: {}'.format(n_y))
        print('total number of tiles in slide {}: {}'.format(f_id, ct))

        if len(imglist) == 0:
            # Nothing to export; the shape checks below would index an
            # empty list.
            print('slide {} produced no tiles: skipped'.format(f_id))
            continue

        dims = list(map(np.shape, imglist))

        print(dims[0])
        print(type(imglist[0]))
        print(imglist[0].shape)
        assert (all(x == dims[0]
                    for x in dims)), "Images are of different dimensions"

        if args.dic_dir:  # if save tile info in csv
            dict_fn = args.dic_dir + '/' + f_id + "_dict.csv"
            imlocpd.to_csv(dict_fn, index=False)
            print('Tile info saved in: ' + dict_fn)

        if args.tfr_dir:  # if output TFRecords files for future training
            tf_fn = args.tfr_dir + '/' + f_id + '.tfrecords'

            if s_lab is None:  # no labels provided
                lab = 999
            else:
                try:  # get slide level label from input dictionary
                    lab = int(s_lab[f_id])
                except KeyError:  # if slide id not in the label list
                    lab = 999  # numeric code for missing value
                    print('slide {} has no label: coded as {}'.format(
                        f_id, str(lab)))
                else:
                    print('slide {} label: {}'.format(f_id, str(lab)))

            writer = tf.python_io.TFRecordWriter(tf_fn)
            try:
                for dim, img in zip(dims, imglist):
                    # tobytes() replaces the deprecated tostring() alias and
                    # produces the same bytes.
                    feature = {
                        'dim':
                        _bytes_feature(
                            tf.compat.as_bytes(np.asarray(dim).tobytes())),
                        'image':
                        _bytes_feature(tf.compat.as_bytes(img.tobytes())),
                        'label':
                        _int64_feature(lab)
                    }
                    example = tf.train.Example(features=tf.train.Features(
                        feature=feature))
                    writer.write(example.SerializeToString())
            finally:
                # Close the record file even if serialisation fails midway.
                writer.close()
 def __init__(self, target_fname, method):
     """Create ``self.normalizer`` and fit it to the image at *target_fname*.

     :param target_fname: file path of the reference image
     :param method: normalization method name given to ``StainNormalizer``
     """
     reference_img = staintools.read_image(target_fname)
     self.normalizer = staintools.StainNormalizer(method=method)
     self.normalizer.fit(reference_img)
 def stain_norm_func(self, target_image_path):
     """Return a 'vahadane' stain normalizer fitted to a reference image.

     The reference image is luminosity-standardized before fitting.
     No instance state is read or written.

     :param target_image_path: file path of the reference image
     :return: the fitted ``staintools.StainNormalizer``
     """
     reference = staintools.LuminosityStandardizer.standardize(
         staintools.read_image(target_image_path))
     vahadane = staintools.StainNormalizer(method='vahadane')
     vahadane.fit(reference)
     return vahadane
Exemple #10
0
import staintools
import datetime

# Demo script: load five images, plot them side by side, then optionally
# standardize their brightness.  (This fragment ends inside the final `if`.)

# Set up
METHOD = 'vahadane'
STANDARDIZE_BRIGHTNESS = True
# Output folder name embeds the current timestamp, so every run gets its own
# folder.  The name contains a space after "results".
# NOTE(review): nothing visible here creates this directory - confirm the
# plotting helper does, or create it before saving.
RESULTS_DIR = './results ' + str(datetime.datetime.now()) + '/'

# Read the images
# i1 is labelled "Target" in the plot below; i2-i5 are labelled "Original".
i1 = staintools.read_image("./data/i1.png")
i2 = staintools.read_image("./data/i2.png")
i3 = staintools.read_image("./data/i3.png")
i4 = staintools.read_image("./data/i4.png")
i5 = staintools.read_image("./data/i5.png")

# Plot
# One title per image, in stack order; the figure is saved rather than shown
# (show=0).
stack = staintools.make_image_stack([i1, i2, i3, i4, i5])
titles = ["Target"] + ["Original"] * 4
staintools.plot_image_stack(stack, width=5, title_list=titles, \
                            save_name=RESULTS_DIR + 'original-images.png', show=0)

# =========================
# Brightness standardization
# (Can skip but can help with tissue mask detection)
# =========================

if STANDARDIZE_BRIGHTNESS:

    # Standardize brightness
    standardizer = staintools.BrightnessStandardizer()
    i1 = standardizer.transform(i1)
def slide_process_C8(model, slide, patch_n_w_l0, patch_n_h_l0, p_s, m_p_s):
    """Classify a slide patch by patch and return the prediction map.

    The slide is walked on a ``patch_n_h_l0`` x ``patch_n_w_l0`` grid.  Each
    patch is read at level 0 and resized to the model input size.  It is then
    stain-normalized and passed to ``model.predict``, but only if both
    detectors in ``det`` accept it.  Patches whose third prediction
    component lies in [0.2, 0.8) are re-evaluated with ``gateway_median``,
    and that result is recorded as well.

    Relies on the module-level ``stain_norm``, ``standardizer`` and ``det``
    objects; ``stain_norm`` is (re)fitted to the standard image on every
    call.

    :param model: object with a ``predict`` method
    :param slide: object with a ``read_region`` method
    :param patch_n_w_l0: number of patches along the width
    :param patch_n_h_l0: number of patches along the height
    :param p_s: patch edge length read from the slide at level 0
    :param m_p_s: patch edge length expected by the model
    :return: float32 array of shape (patch_n_h_l0, patch_n_w_l0, 3)
    """
    # INITIALIZE STAIN NORMALIZER
    st = staintools.read_image('images/standard_he_stain_small.jpg')
    stain_norm.fit(st)

    # CREATE CHUNK FOR MAP WITH PREDICTIONS
    wsi_map_preds = np.zeros((patch_n_h_l0, patch_n_w_l0, 3), dtype=np.float32)

    for hi in range(patch_n_h_l0):
        # The first row/column starts at 0; later ones at index * p_s + 1.
        h = hi * p_s + 1 if hi else 0
        print("Current cycle ", hi + 1, " of ", patch_n_h_l0)
        for wi in range(patch_n_w_l0):
            w = wi * p_s + 1 if wi else 0

            # Generate patch
            work_patch = slide.read_region((w, h), 0, (p_s, p_s))
            work_patch = work_patch.convert('RGB')

            # Resize to model patch size (depends on target magnification).
            # LANCZOS is the filter ANTIALIAS aliased; the ANTIALIAS name
            # was removed in Pillow 10.
            work_patch = work_patch.resize((m_p_s, m_p_s), Image.LANCZOS)
            patch_arr = np.array(work_patch)

            # Control: 1. Is image black? (background of Mirax images is typical black)
            # Control: 2. Tissue detector.
            if not det.tissue_detector(patch_arr):
                continue
            if not det.blue_detector(patch_arr):
                continue

            # Stain normalization (kept in its own array because the
            # gray-zone step below needs the un-scaled image).
            normalized = standardizer.transform(patch_arr)
            normalized = stain_norm.transform(normalized)

            # PREPROCESSING: add a batch axis and scale to [0, 1]
            batch = np.expand_dims(np.float32(normalized), axis=0)
            batch /= 255.

            # prediction from model
            preds = model.predict(batch)
            record_map_preds(wsi_map_preds, hi, wi, preds)

            # if patch in gray zone
            # (Tumor Class probability is between 0.2 and 0.8)
            # Analyse through C8 algorithm and update predictions
            if 0.2 <= preds[0, 2] < 0.8:
                im_sn = Image.fromarray(normalized)
                preds_C8 = gateway_median(model, im_sn)
                record_map_preds(wsi_map_preds, hi, wi, preds_C8)

    return wsi_map_preds  # returns mathematical map with predictions
Exemple #12
0
def gen_imgs(samples,
             crop_size,
             batch_size,
             type,
             shuffle=True,
             color_norm=False,
             target="./data/svs_patches/01_01_0091_12800_22528.png"):
    '''
    Endless batch generator: crops the patches listed in ``samples``, saves
    each slide patch and mask patch as png, and yields them as arrays.

    :param samples: a dataframe with the columns ``id``, ``x`` and ``y``
        giving the top left coordinates of the patches to crop
    :param crop_size: passed to ``crop`` unchanged
    :param batch_size: an int stands for size of the batch; the last batch of
        a pass may be smaller
    :param type: mask type; used in the output folder and file names and
        passed to ``crop``
    :param shuffle: an option whether shuffle samples before every pass
    :param color_norm: an option whether do color normalization
    :param target: the path of the base image to do color normalization
    :return: yields ``(X_train, y_train)`` np.array pairs; ``y_train`` is
        one-hot encoded with two classes and reshaped to (n, 512, 512, 2)
    :raises ValueError: if ``samples`` is empty (the generator could never
        yield)
    '''

    save_svs_patches = './data/svs_patches'
    save_patches = './data/' + type + '_patches'
    num_samples = len(samples)
    if num_samples == 0:
        raise ValueError('samples is empty: no batch can be generated')

    os.makedirs(save_svs_patches, exist_ok=True)
    os.makedirs(save_patches, exist_ok=True)

    # Fit the normalizer once.  The reference image never changes, and the
    # original rebound `target` to the loaded image, so the second patch
    # tried to read an array as a path.
    normalizer = None
    if color_norm:
        reference = staintools.read_image(target)
        reference = staintools.LuminosityStandardizer.standardize(reference)
        normalizer = staintools.StainNormalizer(method='vahadane')
        normalizer.fit(reference)

    while 1:
        if shuffle:
            samples = samples.sample(frac=1)
        # select a sub-dataframe with size of batch size
        for offset in range(0, num_samples, batch_size):
            batch_samples = samples.iloc[offset:offset + batch_size]
            images = []
            masks = []

            # Iterate the rows actually present: the final slice can hold
            # fewer than batch_size rows.
            for a, x, y in zip(batch_samples['id'], batch_samples['x'],
                               batch_samples['y']):
                print('********** Crop in ' + str(a) + ' **********')
                x = int(x)
                y = int(y)
                print(str(x + 512) + '.....' + str(y + 512))
                slide_patch, mask_patch = crop(a, (x, y), crop_size, type)[0:2]

                # color normalization
                if normalizer is not None:
                    slide_patch = staintools.LuminosityStandardizer.standardize(
                        slide_patch)
                    slide_patch = normalizer.transform(slide_patch)

                # save patches
                stem = str(a) + '_' + str(x) + '_' + str(y)
                imsave(osp.join(save_svs_patches, stem + '.png'), slide_patch)
                imsave(osp.join(save_patches, stem + '_' + type + '.png'),
                       mask_patch)

                images.append(slide_patch)
                masks.append(mask_patch)

            # Assemble the arrays once per batch instead of once per patch.
            X_train = np.array(images)
            y_train = np.array(masks)
            print(np.shape(y_train))
            y_train = to_categorical(y_train, num_classes=2).reshape(
                y_train.shape[0], 512, 512, 2)

            yield X_train, y_train
Exemple #13
0
def gen_imgs_random(id_list,
                    crop_size,
                    batch_size,
                    type,
                    color_norm=False,
                    target="./data/svs_patches/01_01_0091_12800_22528.png"):
    '''
    Endless batch generator: crops patches at random positions from randomly
    chosen slides, saves each new slide patch and mask patch as png, and
    yields them as arrays.  A position whose slide patch file already exists
    is drawn again.

    :param id_list: a list contains all images ids, all id has the following format: 01_01_0083
    :param crop_size: a tuple stands for the size for each patch
    :param batch_size: an int stands for size of the batch
    :param type: mask type; used in the mask path and the output names
    :param color_norm: an option whether do the color normalization
    :param target: the path of the base image to do color normalization
    :return: yields ``(X_train, y_train)`` np.array pairs; ``y_train`` is
        one-hot encoded with two classes and reshaped to
        (n, crop_size[0], crop_size[1], 2)
    '''

    save_svs_patches = './data/svs_patches_random'
    save_mask_patches = './data/' + str(type) + '_patches_random'

    # Fit the normalizer once.  The reference image never changes, and the
    # original rebound `target` to the loaded image, so the second patch
    # tried to read an array as a path.
    normalizer = None
    if color_norm:
        reference = staintools.read_image(target)
        reference = staintools.LuminosityStandardizer.standardize(reference)
        normalizer = staintools.StainNormalizer(method='vahadane')
        normalizer.fit(reference)

    while 1:
        images = []
        masks = []
        os.makedirs(save_svs_patches, exist_ok=True)
        os.makedirs(save_mask_patches, exist_ok=True)

        # produce a sample with a fit batch size
        while len(images) < batch_size:
            img_id = random.choice(id_list)
            slide_path = './data/OriginalImage/' + str(img_id) + '.svs'
            if not os.path.exists(slide_path):
                slide_path = './data/OriginalImage/' + str(img_id) + '.SVS'
            mask_path = './data/' + type.capitalize() + 'Mask/' + str(
                img_id) + '_' + type + '.tif'
            mask = io.imread(mask_path)

            # initialize the top left coordinate for each patch
            shape = np.shape(mask)
            start_x = np.random.randint(shape[1] - crop_size[1])
            start_y = np.random.randint(shape[0] - crop_size[0])

            # if the patch is already cropped, drop this patch
            stem = str(img_id) + '_' + str(start_x) + '_' + str(start_y)
            slide_patch_save_path = osp.join(save_svs_patches, stem + '.png')
            if os.path.exists(slide_patch_save_path):
                continue

            # Open the slide only when a patch is really read, and release
            # the handle right after: this loop can run indefinitely.
            slide = openslide.open_slide(slide_path)
            try:
                croped_slide_img = np.array(
                    slide.read_region((start_x, start_y), 0, crop_size))
            finally:
                slide.close()

            # NOTE: a disabled Otsu-threshold filter that dropped patches
            # with less than 50% tissue used to sit here as a string block.

            # if option of color normalization is true, do color normalization
            if normalizer is not None:
                croped_slide_img = staintools.LuminosityStandardizer.standardize(
                    croped_slide_img)
                croped_slide_img = normalizer.transform(croped_slide_img)

            # save patches
            croped_mask_img = mask[start_y:start_y + crop_size[0],
                                   start_x:start_x + crop_size[1]]
            # The saved mask is scaled by 255; the yielded mask keeps the
            # original values.
            imsave(slide_patch_save_path, croped_slide_img)
            imsave(
                osp.join(save_mask_patches, stem + '_' + type + '.png'),
                croped_mask_img * 255)

            images.append(croped_slide_img)
            masks.append(croped_mask_img)

        # Assemble the arrays once per batch instead of once per patch.
        X_train = np.array(images)
        y_train = np.array(masks)
        y_train = to_categorical(y_train, num_classes=2).reshape(
            y_train.shape[0], crop_size[0], crop_size[1], 2)

        yield X_train, y_train