# Example no. 1
# 0
    def Train(self):
        """Train a Mask R-CNN model using the settings in ``self.__mParams``.

        Image/mask pairs are read from ``<train_dir>/images`` and
        ``<train_dir>/masks``; the mask of ``name.<ext>`` is ``name.tiff`` and
        images without a mask are skipped. Each pair goes to the validation set
        with probability ``train_to_val_ratio`` (default 0.2), otherwise to the
        training set. When ``eval_dir`` is given and ``use_eval_in_val`` is
        absent or ``"true"``, the pairs under ``eval_dir`` are added to the
        validation set as well. Weights are loaded from ``input_model``
        (excluding the ``mrcnn_class_logits``/``mrcnn_bbox_fc``/``mrcnn_bbox``/
        ``mrcnn_mask`` layers when ``blank_mrcnn`` is ``"true"``), trained once
        per entry of ``epoch_groups`` and saved to ``output_model``.

        Optional, string-valued keys: ``eval_dir``, ``image_size``,
        ``train_to_val_seed``, ``train_to_val_ratio``, ``use_eval_in_val``,
        ``step_ratio``, ``step_num``, ``show_inputs``,
        ``random_augmentation_level``, ``detection_nms_threshold``,
        ``rpn_nms_threshold``.

        Raises:
            KeyError: if ``train_dir``, ``input_model``, ``output_model``,
                ``blank_mrcnn`` or ``epoch_groups`` is missing.
            ValueError: if no training image/mask pair is found.
        """
        params = self.__mParams

        def _collect_image_mask_pairs(rootDir):
            # Return (imagePath, maskPath) pairs under rootDir/images and
            # rootDir/masks, keeping only images that have a "<base>.tiff" mask.
            # The listing is sorted so that a fixed random seed yields the same
            # train/validation split regardless of file-system ordering.
            imagesDir = os.path.join(rootDir, "images")
            masksDir = os.path.join(rootDir, "masks")
            pairs = []
            for imageFile in sorted(os.listdir(imagesDir)):
                imagePath = os.path.join(imagesDir, imageFile)
                if not os.path.isfile(imagePath):
                    continue
                baseName = os.path.splitext(os.path.basename(imageFile))[0]
                maskPath = os.path.join(masksDir, baseName + ".tiff")
                if not os.path.isfile(maskPath):
                    continue
                pairs.append((imagePath, maskPath))
            return pairs

        fixedRandomSeed = None
        trainToValidationChance = 0.2
        includeEvaluationInValidation = True
        stepMultiplier = None
        stepCount = 1000
        showInputs = False
        augmentationLevel = 0
        detNMSThresh = 0.35
        rpnNMSThresh = 0.55
        trainDir = os.path.join(os.curdir, params["train_dir"])
        evalDir = None
        inModelPath = os.path.join(os.curdir, params["input_model"])
        outModelPath = os.path.join(os.curdir, params["output_model"])
        blankInput = params["blank_mrcnn"] == "true"
        maxdim = 1024

        if "eval_dir" in params:
            evalDir = os.path.join(os.curdir, params["eval_dir"])

        if "image_size" in params:
            maxdim = int(params["image_size"])

        if "train_to_val_seed" in params:
            fixedRandomSeed = params["train_to_val_seed"]

        if "train_to_val_ratio" in params:
            trainToValidationChance = float(params["train_to_val_ratio"])

        if "use_eval_in_val" in params:
            includeEvaluationInValidation = params["use_eval_in_val"] == "true"

        if "step_ratio" in params:
            stepMultiplier = float(params["step_ratio"])

        if "step_num" in params:
            stepCount = int(params["step_num"])

        if "show_inputs" in params:
            showInputs = params["show_inputs"] == "true"

        if "random_augmentation_level" in params:
            augmentationLevel = int(params["random_augmentation_level"])

        if "detection_nms_threshold" in params:
            detNMSThresh = float(params["detection_nms_threshold"])

        if "rpn_nms_threshold" in params:
            rpnNMSThresh = float(params["rpn_nms_threshold"])

        rnd = random.Random()
        rnd.seed(fixedRandomSeed)
        trainImagesAndMasks = {}
        validationImagesAndMasks = {}

        # Split the train set into train and validation (one draw per pair).
        for imagePath, maskPath in _collect_image_mask_pairs(trainDir):
            if rnd.random() > trainToValidationChance:
                trainImagesAndMasks[imagePath] = maskPath
            else:
                validationImagesAndMasks[imagePath] = maskPath

        # Add the evaluation set to the validation data.
        if includeEvaluationInValidation and evalDir is not None:
            for imagePath, maskPath in _collect_image_mask_pairs(evalDir):
                validationImagesAndMasks[imagePath] = maskPath

        if not trainImagesAndMasks:
            raise ValueError("Empty train image list")

        # Keep the validation set non-empty by borrowing the first training pair.
        if not validationImagesAndMasks:
            firstImage, firstMask = next(iter(trainImagesAndMasks.items()))
            validationImagesAndMasks[firstImage] = firstMask

        # Training dataset
        dataset_train = mask_rcnn_additional.NucleiDataset()
        dataset_train.initialize(pImagesAndMasks=trainImagesAndMasks,
                                 pAugmentationLevel=augmentationLevel)
        dataset_train.prepare()

        # Validation dataset (never augmented)
        dataset_val = mask_rcnn_additional.NucleiDataset()
        dataset_val.initialize(pImagesAndMasks=validationImagesAndMasks,
                               pAugmentationLevel=0)
        dataset_val.prepare()

        print("training images (with augmentation):", dataset_train.num_images)
        print("validation images (with augmentation):", dataset_val.num_images)

        config = mask_rcnn_additional.NucleiConfig()
        config.IMAGE_MAX_DIM = maxdim
        config.IMAGE_MIN_DIM = maxdim
        config.STEPS_PER_EPOCH = stepCount
        if stepMultiplier is not None:
            # A step ratio, when given, overrides the fixed step count.
            steps = int(float(dataset_train.num_images) * stepMultiplier)
            config.STEPS_PER_EPOCH = steps

        config.VALIDATION_STEPS = dataset_val.num_images
        config.DETECTION_NMS_THRESHOLD = detNMSThresh
        config.RPN_NMS_THRESHOLD = rpnNMSThresh
        # Re-run the constructor so values derived from the overrides are refreshed.
        config.__init__()
        # show config
        config.display()

        # show setup: every non-callable attribute of this object
        for a in dir(self):
            if not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")

        if showInputs:
            # Load and display random samples
            image_ids = numpy.random.choice(dataset_train.image_ids, 20)
            for imageId in image_ids:
                image = dataset_train.load_image(imageId)
                mask, class_ids = dataset_train.load_mask(imageId)
                visualize.display_instances(
                    image=image,
                    masks=mask,
                    class_ids=class_ids,
                    title=dataset_train.image_reference(imageId),
                    boxes=utils.extract_bboxes(mask),
                    class_names=dataset_train.class_names)

        # Same module the inference path imports lazily; importing it here
        # keeps this method working without a module-level import.
        import model

        # Create model in training mode
        mdl = model.MaskRCNN(mode="training",
                             config=config,
                             model_dir=os.path.dirname(outModelPath))

        if blankInput:
            mdl.load_weights(inModelPath,
                             by_name=True,
                             exclude=[
                                 "mrcnn_class_logits", "mrcnn_bbox_fc",
                                 "mrcnn_bbox", "mrcnn_mask"
                             ])
        else:
            mdl.load_weights(inModelPath, by_name=True)

        # `epochs` passed to train() is cumulative: each group continues up to
        # the running total of epochs so far.
        allcount = 0
        for epochgroup in params["epoch_groups"]:
            epochs = int(epochgroup["epochs"])
            if epochs < 1:
                continue
            allcount += epochs
            mdl.train(dataset_train,
                      dataset_val,
                      learning_rate=float(epochgroup["learning_rate"]),
                      epochs=allcount,
                      layers=epochgroup["layers"])

        mdl.keras_model.save_weights(outModelPath)
    def Segment(self,
                pImage,
                pPaddingRatio=0.0,
                pDilationSElem=None,
                pCavityFilling=False,
                pPredictSize=None):
        """Detect object instances in a single image.

        Args:
            pImage: input image, converted with ``kutils.RCNNConvertInputImage``.
            pPaddingRatio: when > 0 the image is padded with
                ``kutils.PadImageR`` before detection and the masks are cropped
                back to the original size afterwards.
            pDilationSElem: optional structuring element; every instance mask
                is dilated with it via ``cv2.dilate``.
            pCavityFilling: when True, holes inside every instance mask are
                filled.
            pPredictSize: requested model input size, rounded up to a multiple
                of 64. A change of this value rebuilds the inference model.

        Returns:
            Tuple ``(kutils.MergeMasks(masks), masks, scores)`` where ``masks``
            is (height, width, N) with values 0/255. If no instance is found
            or the prediction size does not match the input, a uint16
            (height, width) zero image, a (height, width, 0) uint8 array and an
            empty float64 score array are returned.
        """

        def _emptyResult(height, width):
            # Zero-instance result with the same shapes/dtypes as a real one.
            return (numpy.zeros((height, width), numpy.uint16),
                    numpy.zeros((height, width, 0), numpy.uint8),
                    numpy.zeros(0, numpy.float64))

        rebuild = self.__mModel is None

        if pPredictSize is not None:
            # Round up to a multiple of 2**6 = 64 (presumably a constraint of
            # the network's downsampling — confirm).
            maxdim = pPredictSize
            temp = maxdim / 2**6
            if temp != int(temp):
                maxdim = (int(temp) + 1) * 2**6

            if maxdim != self.__mLastMaxDim:
                self.__mLastMaxDim = maxdim
                rebuild = True

        if rebuild:
            import model
            import keras.backend
            # Reset Keras' global state before building a new model.
            keras.backend.clear_session()
            print("Max dim changed (", str(self.__mLastMaxDim),
                  "), rebuilding model")

            self.__mConfig = mask_rcnn_additional.NucleiConfig()
            self.__mConfig.DETECTION_MIN_CONFIDENCE = self.__mConfidence
            self.__mConfig.DETECTION_NMS_THRESHOLD = self.__NMSThreshold
            self.__mConfig.IMAGE_MAX_DIM = self.__mLastMaxDim
            self.__mConfig.IMAGE_MIN_DIM = self.__mLastMaxDim
            self.__mConfig.DETECTION_MAX_INSTANCES = self.__mMaxDetNum
            # Re-run the constructor so derived settings follow the overrides.
            self.__mConfig.__init__()

            self.__mModel = model.MaskRCNN(mode="inference",
                                           config=self.__mConfig,
                                           model_dir=self.__mModelDir)
            self.__mModel.load_weights(self.__mModelPath, by_name=True)

        image = kutils.RCNNConvertInputImage(pImage)
        offsetX = 0
        offsetY = 0
        # Size of the unpadded image; results are returned at this size.
        width = image.shape[1]
        height = image.shape[0]

        if pPaddingRatio > 0.0:
            image, (offsetX, offsetY) = kutils.PadImageR(image, pPaddingRatio)

        results = self.__mModel.detect([image], verbose=0)

        r = results[0]
        masks = r['masks']
        scores = r['scores']

        if masks.shape[0] != image.shape[0] or masks.shape[1] != image.shape[1]:
            print("Invalid prediction")
            return _emptyResult(height, width)

        count = masks.shape[2]
        if count < 1:
            return _emptyResult(height, width)

        if pPaddingRatio > 0.0:
            # Crop the padding away again (as a uint8 copy).
            masks = masks[offsetY:(offsetY + height),
                          offsetX:(offsetX + width), :].astype(numpy.uint8)

        if pDilationSElem is not None:
            for i in range(count):
                masks[:, :, i] = cv2.dilate(masks[:, :, i],
                                            kernel=pDilationSElem)

        if pCavityFilling:
            for i in range(count):
                binary = (masks[:, :, i] > 0).astype(numpy.uint8)
                # [-2] selects the contour list under both the OpenCV 3
                # (image, contours, hierarchy) and OpenCV 4
                # (contours, hierarchy) return conventions.
                contours = cv2.findContours(binary, cv2.RETR_EXTERNAL,
                                            cv2.CHAIN_APPROX_SIMPLE)[-2]
                # Drawing the outer contours filled closes every interior hole.
                filled = numpy.zeros_like(binary)
                cv2.drawContours(filled, contours, -1, 1,
                                 thickness=cv2.FILLED)
                masks[:, :, i] = filled

        # Normalise every instance mask to 0/255.
        masks[...] = numpy.where(masks == 0, 0, 255)

        return kutils.MergeMasks(masks), masks, scores
class Segmentation:
    # Inference model, built lazily by Segment on first use or on a size change.
    __mModel = None
    # Config the current __mModel was built with.
    __mConfig = None
    __mModelDir = ""
    __mModelPath = ""
    # Input size the current model was built for; starts at the config default.
    __mLastMaxDim = mask_rcnn_additional.NucleiConfig().IMAGE_MAX_DIM
    __mConfidence = 0.5
    __NMSThreshold = 0.35
    # Class-level default matching the constructor default, so the attribute
    # read by Segment is always defined like the others above.
    __mMaxDetNum = 512

    def __init__(self,
                 pModelPath,
                 pConfidence=0.5,
                 pNMSThreshold=0.35,
                 pMaxDetNum=512):
        """Store inference settings; the model is built when first needed.

        Args:
            pModelPath: path of the Mask R-CNN weights file. Its directory is
                used as the model directory.
            pConfidence: minimum detection confidence.
            pNMSThreshold: non-maximum-suppression threshold for detections.
            pMaxDetNum: maximum number of detected instances per image.

        Terminates the process via ``sys.exit`` if ``pModelPath`` is not a file.
        """
        if not os.path.isfile(pModelPath):
            sys.exit("Invalid model path: " + pModelPath)

        self.__mConfidence = pConfidence
        self.__NMSThreshold = pNMSThreshold
        self.__mModelPath = pModelPath
        self.__mModelDir = os.path.dirname(pModelPath)
        self.__mMaxDetNum = pMaxDetNum

    def Segment(self, pImage, pPredictSize=None):
        """Detect object instances in a single image.

        Args:
            pImage: input image, converted with ``kutils.RCNNConvertInputImage``.
            pPredictSize: requested model input size, rounded up to a multiple
                of 64. A change of this value rebuilds the inference model.

        Returns:
            Tuple ``(kutils.MergeMasks(masks), masks, scores)`` where ``masks``
            is (height, width, N) with values 0/255. If no instance is found
            or the prediction size does not match the input, a uint16
            (height, width) zero image, a (height, width, 0) uint8 array and an
            empty float64 score array are returned.

        NOTE: this class body defines ``Segment`` more than once; the last
        definition in the class body is the one that takes effect.
        """

        def _emptyResult(height, width):
            # Zero-instance result with the same shapes/dtypes as a real one.
            return (numpy.zeros((height, width), numpy.uint16),
                    numpy.zeros((height, width, 0), numpy.uint8),
                    numpy.zeros(0, numpy.float64))

        rebuild = self.__mModel is None

        if pPredictSize is not None:
            # Round up to a multiple of 2**6 = 64 (presumably a constraint of
            # the network's downsampling — confirm).
            maxdim = pPredictSize
            temp = maxdim / 2**6
            if temp != int(temp):
                maxdim = (int(temp) + 1) * 2**6

            if maxdim != self.__mLastMaxDim:
                self.__mLastMaxDim = maxdim
                rebuild = True

        if rebuild:
            import mrcnn_model
            import keras.backend
            # Reset Keras' global state before building a new model.
            keras.backend.clear_session()
            print("Max dim changed (", str(self.__mLastMaxDim),
                  "), rebuilding model")

            self.__mConfig = mask_rcnn_additional.NucleiConfig()
            self.__mConfig.DETECTION_MIN_CONFIDENCE = self.__mConfidence
            self.__mConfig.DETECTION_NMS_THRESHOLD = self.__NMSThreshold
            self.__mConfig.IMAGE_MAX_DIM = self.__mLastMaxDim
            self.__mConfig.IMAGE_MIN_DIM = self.__mLastMaxDim
            self.__mConfig.DETECTION_MAX_INSTANCES = self.__mMaxDetNum
            # Re-run the constructor so derived settings follow the overrides.
            self.__mConfig.__init__()

            self.__mModel = mrcnn_model.MaskRCNN(mode="inference",
                                                 config=self.__mConfig,
                                                 model_dir=self.__mModelDir)
            self.__mModel.load_weights(self.__mModelPath, by_name=True)

        image = kutils.RCNNConvertInputImage(pImage)
        width = image.shape[1]
        height = image.shape[0]

        results = self.__mModel.detect([image], verbose=0)

        r = results[0]
        masks = r['masks']
        scores = r['scores']

        if masks.shape[0] != image.shape[0] or masks.shape[1] != image.shape[1]:
            print("Invalid prediction")
            return _emptyResult(height, width)

        count = masks.shape[2]
        if count < 1:
            return _emptyResult(height, width)

        # Normalise every instance mask to 0/255.
        masks[...] = numpy.where(masks == 0, 0, 255)

        return kutils.MergeMasks(masks), masks, scores

    def Run(self, imagesDir, outputDir, maxdimValues, subsize_x, subsize_y):
        """Segment every image file in ``imagesDir`` tile by tile.

        Each image is brought to (height, width, 3): grayscale is replicated,
        and channel-first input (``shape[0] < shape[2]``) is moved to channel
        last, repeating the last plane when fewer than 3 exist. The image is
        then cut into overlapping tiles of ``subsize_x`` x ``subsize_y`` and
        segmented once per value of ``maxdimValues``. Tile labels are stitched
        into one label image per size. Labels touching the left/top border of
        a tile reuse the label already written next to them; all other labels
        are offset past the labels assigned so far. The uint16 stack of shape
        (len(maxdimValues), height, width) is saved as
        ``<outputDir>/<image base name>.tiff``.

        NOTE(review): channel-first input with more than 3 planes raises
        IndexError, since the converted image has only 3 channels.
        """

        os.makedirs(name=outputDir, exist_ok=True)
        imageFiles = [
            f for f in os.listdir(imagesDir)
            if os.path.isfile(os.path.join(imagesDir, f))
        ]
        imcount = len(imageFiles)
        for fileIndex, imageFile in enumerate(imageFiles):
            print("Image:", str(fileIndex + 1), "/", str(imcount), "(", imageFile,
                  ")")

            baseName = os.path.splitext(os.path.basename(imageFile))[0]
            imagePath = os.path.join(imagesDir, imageFile)
            image = skimage.io.imread(imagePath)
            if len(image.shape) > 2:
                if image.shape[0] < image.shape[2]:
                    # Channel-first input: move channels last.
                    new_image = numpy.zeros(
                        (image.shape[1], image.shape[2], 3), numpy.uint16)
                    for k in range(image.shape[0]):
                        new_image[:, :, k] = image[k, :, :]
                    for k in range(image.shape[0], 3):
                        new_image[:, :, k] = image[image.shape[0] - 1, :, :]
                    image = new_image
            elif len(image.shape) == 2:
                # Grayscale: replicate into 3 channels.
                new_image = numpy.zeros((image.shape[0], image.shape[1], 3),
                                        numpy.uint16)
                for k in range(3):
                    new_image[:, :, k] = image
                image = new_image

            image_size_x = image.shape[1]
            image_size_y = image.shape[0]

            mask_allScales_allImageParts = numpy.zeros(
                (len(maxdimValues), image_size_y, image_size_x), numpy.uint16)
            # Number of labels already assigned, per scale.
            totalNucleiCount = numpy.zeros(len(maxdimValues), numpy.int64)

            for scaleIndex, maxdim in enumerate(maxdimValues):
                x_minIterator = 0
                y_minIterator = 0
                x_maxIterator = subsize_x
                y_maxIterator = subsize_y
                overlap_x = 0
                overlap_y = 0
                if image_size_x < subsize_x:
                    x_maxIterator = image_size_x
                else:
                    # Spread the surplus of ceil-many tiles evenly as overlap.
                    overlap_x = math.floor(
                        (math.ceil(image_size_x / subsize_x) * subsize_x -
                         image_size_x) / math.floor(image_size_x / subsize_x))
                if image_size_y < subsize_y:
                    y_maxIterator = image_size_y
                else:
                    overlap_y = math.floor(
                        (math.ceil(image_size_y / subsize_y) * subsize_y -
                         image_size_y) / math.floor(image_size_y / subsize_y))
                done = False

                while not done:
                    current_image = image[y_minIterator:y_maxIterator,
                                          x_minIterator:x_maxIterator, :]

                    mask, masks, scores = self.Segment(pImage=current_image,
                                                       pPredictSize=maxdim)
                    currentNucleiCount = masks.shape[2]

                    # Part of the tile that is written into the result: half
                    # of the overlap is dropped on every inner tile border.
                    x_min_combined = x_minIterator
                    x_range_start = 0
                    if image_size_x < subsize_x:
                        x_range_end = image_size_x
                    else:
                        x_range_end = subsize_x
                    if x_minIterator > 0:
                        x_range_start = math.floor(overlap_x / 2)

                    if x_maxIterator != image_size_x:
                        x_range_end = subsize_x - math.ceil(overlap_x / 2)

                    y_min_combined = y_minIterator
                    y_range_start = 0
                    if image_size_y < subsize_y:
                        y_range_end = image_size_y
                    else:
                        y_range_end = subsize_y
                    if y_minIterator > 0:
                        y_range_start = math.floor(overlap_y / 2)
                    if y_maxIterator != image_size_y:
                        y_range_end = subsize_y - math.ceil(overlap_y / 2)

                    # nucleiIds[tile label] = label already used in the result
                    # for the same object (0 = not linked yet).
                    nucleiIds = numpy.zeros(currentNucleiCount + 1,
                                            numpy.uint16)

                    # Link labels on the left border to the pixel just left of it.
                    if x_min_combined > 0:
                        for y in range(y_range_start, y_range_end):
                            if mask[y, x_range_start] > 0:
                                if mask_allScales_allImageParts[
                                        scaleIndex, y + y_min_combined,
                                        x_min_combined + x_range_start -
                                        1] > 0:
                                    nucleiIds[mask[
                                        y,
                                        x_range_start]] = mask_allScales_allImageParts[
                                            scaleIndex, y + y_min_combined,
                                            x_min_combined + x_range_start - 1]

                    # Link labels on the top border to the pixel just above it.
                    if y_min_combined > 0:
                        for x in range(x_range_start, x_range_end):
                            if mask[y_range_start, x] > 0:
                                if mask_allScales_allImageParts[
                                        scaleIndex, y_min_combined + y_range_start -
                                        1, x + x_min_combined] > 0:
                                    nucleiIds[mask[
                                        y_range_start,
                                        x]] = mask_allScales_allImageParts[
                                            scaleIndex,
                                            y_min_combined + y_range_start - 1,
                                            x + x_min_combined]

                    # Paste the kept region in one vectorised step (each result
                    # pixel is written at most once): linked labels reuse the
                    # neighbour's label, others are offset past existing ones.
                    tileLabels = mask[y_range_start:y_range_end,
                                      x_range_start:x_range_end]
                    foreground = tileLabels > 0
                    linked = nucleiIds[tileLabels]
                    shifted = (tileLabels.astype(numpy.int64) +
                               int(totalNucleiCount[scaleIndex]))
                    target = mask_allScales_allImageParts[
                        scaleIndex,
                        y_min_combined + y_range_start:
                        y_min_combined + y_range_end,
                        x_min_combined + x_range_start:
                        x_min_combined + x_range_end]
                    target[foreground] = numpy.where(linked > 0, linked,
                                                     shifted)[foreground]

                    totalNucleiCount[scaleIndex] += currentNucleiCount

                    # Advance one tile to the right; the last tile of a row is
                    # shifted left so that it ends exactly at the image border.
                    x_minIterator_memory = x_minIterator
                    x_minIterator = x_maxIterator - overlap_x
                    x_maxIterator = x_minIterator + subsize_x
                    if x_maxIterator > image_size_x:
                        if x_maxIterator < image_size_x + overlap_x:
                            delta_x = x_maxIterator - image_size_x
                            x_minIterator -= delta_x
                            x_maxIterator -= delta_x
                        else:
                            x_minIterator = x_minIterator_memory

                    # Row finished: go back to the left and advance one row.
                    if x_minIterator_memory == x_minIterator:
                        x_minIterator = 0
                        x_maxIterator = subsize_x

                        y_minIterator_memory = y_minIterator
                        y_minIterator = y_maxIterator - overlap_y
                        y_maxIterator = y_minIterator + subsize_y

                        if y_maxIterator > image_size_y:
                            if y_maxIterator < image_size_y + overlap_y:
                                delta_y = y_maxIterator - image_size_y
                                y_minIterator -= delta_y
                                y_maxIterator -= delta_y
                            else:
                                y_minIterator = y_minIterator_memory

                        if y_minIterator_memory == y_minIterator:
                            done = True

            skimage.io.imsave(os.path.join(outputDir, baseName + ".tiff"),
                              mask_allScales_allImageParts)
    def Segment(self, pImage, pPredictSize=None):
        """Detect object instances in a single image.

        Args:
            pImage: input image, converted with ``kutils.RCNNConvertInputImage``.
            pPredictSize: requested model input size, rounded up to a multiple
                of 64. A change of this value rebuilds the inference model.

        Returns:
            Tuple ``(kutils.MergeMasks(masks), masks, scores)`` where ``masks``
            is (height, width, N) with values 0/255. If no instance is found
            or the prediction size does not match the input, a uint16
            (height, width) zero image, a (height, width, 0) uint8 array and an
            empty float64 score array are returned.

        NOTE: this class body defines ``Segment`` more than once; this later
        definition replaces the earlier one when the class is created.
        """

        def _emptyResult(height, width):
            # Zero-instance result with the same shapes/dtypes as a real one.
            return (numpy.zeros((height, width), numpy.uint16),
                    numpy.zeros((height, width, 0), numpy.uint8),
                    numpy.zeros(0, numpy.float64))

        rebuild = self.__mModel is None

        if pPredictSize is not None:
            # Round up to a multiple of 2**6 = 64 (presumably a constraint of
            # the network's downsampling — confirm).
            maxdim = pPredictSize
            temp = maxdim / 2**6
            if temp != int(temp):
                maxdim = (int(temp) + 1) * 2**6

            if maxdim != self.__mLastMaxDim:
                self.__mLastMaxDim = maxdim
                rebuild = True

        if rebuild:
            import mrcnn_model
            import keras.backend
            # Reset Keras' global state before building a new model.
            keras.backend.clear_session()
            print("Max dim changed (", str(self.__mLastMaxDim),
                  "), rebuilding model")

            self.__mConfig = mask_rcnn_additional.NucleiConfig()
            self.__mConfig.DETECTION_MIN_CONFIDENCE = self.__mConfidence
            self.__mConfig.DETECTION_NMS_THRESHOLD = self.__NMSThreshold
            self.__mConfig.IMAGE_MAX_DIM = self.__mLastMaxDim
            self.__mConfig.IMAGE_MIN_DIM = self.__mLastMaxDim
            self.__mConfig.DETECTION_MAX_INSTANCES = self.__mMaxDetNum
            # Re-run the constructor so derived settings follow the overrides.
            self.__mConfig.__init__()

            self.__mModel = mrcnn_model.MaskRCNN(mode="inference",
                                                 config=self.__mConfig,
                                                 model_dir=self.__mModelDir)
            self.__mModel.load_weights(self.__mModelPath, by_name=True)

        image = kutils.RCNNConvertInputImage(pImage)
        width = image.shape[1]
        height = image.shape[0]

        results = self.__mModel.detect([image], verbose=0)

        r = results[0]
        masks = r['masks']
        scores = r['scores']

        if masks.shape[0] != image.shape[0] or masks.shape[1] != image.shape[1]:
            print("Invalid prediction")
            return _emptyResult(height, width)

        count = masks.shape[2]
        if count < 1:
            return _emptyResult(height, width)

        # Normalise every instance mask to 0/255.
        masks[...] = numpy.where(masks == 0, 0, 255)

        return kutils.MergeMasks(masks), masks, scores
    def Train(self):
        """Train a Mask R-CNN model using the settings in ``self.__mParams``.

        Image/mask pairs are read from ``<train_dir>/images`` and
        ``<train_dir>/masks``. For each image the mask ``<base>.png``,
        ``<base>.tif`` or ``<base>.tiff`` is used, in that order of preference.
        The process exits if none exists. When ``eval_dir`` is given, its pairs
        form the validation set and every train pair is used for training.
        Otherwise the train pairs are split randomly (seed 0): a pair goes to
        validation with probability ``train_to_val_ratio`` (default 0.2).
        Weights are loaded from ``input_model`` without the
        ``mrcnn_class_logits``/``mrcnn_bbox_fc``/``mrcnn_bbox``/``mrcnn_mask``
        layers. They are trained once per entry of ``epoch_groups`` and saved
        to ``<output_dir>/<model_name>.h5``.

        Optional keys: ``eval_dir``, ``image_size`` (otherwise the NucleiConfig
        default is kept), ``train_to_val_ratio``, ``step_num`` (fixed steps
        per epoch), ``step_ratio`` (steps = ratio * number of training images;
        default 1.0 unless ``step_num`` is given), ``random_augmentation_level``,
        ``detection_nms_threshold`` and ``rpn_nms_threshold``.

        NOTE(review): ``__mParams`` is name-mangled to this class and this
        class's ``__init__`` does not set it; the caller must provide it —
        confirm.
        """
        params = self.__mParams

        def _collect_image_mask_pairs(imagesDir, masksDir):
            # Return sorted (imagePath, maskPath) pairs; exit when an image has
            # no png/tif/tiff mask. Sorting makes the seeded split reproducible
            # independently of file-system listing order.
            pairs = []
            for imageFile in sorted(os.listdir(imagesDir)):
                imagePath = os.path.join(imagesDir, imageFile)
                if not os.path.isfile(imagePath):
                    continue
                baseName = os.path.splitext(os.path.basename(imageFile))[0]
                for extension in (".png", ".tif", ".tiff"):
                    maskPath = os.path.join(masksDir, baseName + extension)
                    if os.path.exists(maskPath):
                        break
                else:
                    sys.exit(
                        "The image " + imageFile +
                        " does not have a corresponding mask file ending with png, tif or tiff"
                    )
                if not os.path.isfile(maskPath):
                    continue
                pairs.append((imagePath, maskPath))
            return pairs

        fixedRandomSeed = 0
        trainToValidationChance = 0.2
        includeEvaluationInValidation = True
        stepMultiplier = 1.0
        stepCount = 1000
        augmentationLevel = 0
        detNMSThresh = 0.35
        rpnNMSThresh = 0.55
        # None keeps the image size configured in NucleiConfig.
        maxdim = None
        trainDir = os.path.join(os.curdir, params["train_dir"])
        evalDir = None
        inModelPath = os.path.join(os.curdir, params["input_model"])
        os.makedirs(name=params["output_dir"], exist_ok=True)
        outModelPath = os.path.join(params["output_dir"],
                                    params["model_name"] + ".h5")

        # Always load input_model without the layers listed in `exclude` below.
        blankInput = True

        if "eval_dir" in params:
            evalDir = os.path.join(os.curdir, params["eval_dir"])

        if "image_size" in params:
            maxdim = int(params["image_size"])

        if "train_to_val_ratio" in params:
            trainToValidationChance = float(params["train_to_val_ratio"])

        if "step_num" in params:
            stepCount = int(params["step_num"])
            # An explicit step count must not be overridden by the default ratio.
            stepMultiplier = None

        if "step_ratio" in params:
            stepMultiplier = float(params["step_ratio"])

        if "random_augmentation_level" in params:
            augmentationLevel = int(params["random_augmentation_level"])

        if "detection_nms_threshold" in params:
            detNMSThresh = float(params["detection_nms_threshold"])

        if "rpn_nms_threshold" in params:
            rpnNMSThresh = float(params["rpn_nms_threshold"])

        rnd = random.Random()
        rnd.seed(fixedRandomSeed)
        trainImagesAndMasks = {}
        validationImagesAndMasks = {}

        imagesDir = os.path.join(trainDir, "images")
        masksDir = os.path.join(trainDir, "masks")

        if includeEvaluationInValidation and evalDir is not None:
            # The evaluation set is the validation set; all train pairs train.
            for imagePath, maskPath in _collect_image_mask_pairs(
                    os.path.join(evalDir, "images"),
                    os.path.join(evalDir, "masks")):
                validationImagesAndMasks[imagePath] = maskPath
            for imagePath, maskPath in _collect_image_mask_pairs(
                    imagesDir, masksDir):
                trainImagesAndMasks[imagePath] = maskPath
        else:
            # No evaluation set: split the train data into train and validation.
            for imagePath, maskPath in _collect_image_mask_pairs(
                    imagesDir, masksDir):
                if rnd.random() > trainToValidationChance:
                    trainImagesAndMasks[imagePath] = maskPath
                else:
                    validationImagesAndMasks[imagePath] = maskPath

        if not trainImagesAndMasks:
            sys.exit("Empty train image list")

        # Keep the validation set non-empty by borrowing the first training pair.
        if not validationImagesAndMasks:
            firstImage, firstMask = next(iter(trainImagesAndMasks.items()))
            validationImagesAndMasks[firstImage] = firstMask

        # Training dataset
        dataset_train = mask_rcnn_additional.NucleiDataset()
        dataset_train.initialize(pImagesAndMasks=trainImagesAndMasks,
                                 pAugmentationLevel=augmentationLevel)
        dataset_train.prepare()

        # Validation dataset (never augmented)
        dataset_val = mask_rcnn_additional.NucleiDataset()
        dataset_val.initialize(pImagesAndMasks=validationImagesAndMasks,
                               pAugmentationLevel=0)
        dataset_val.prepare()

        print("training images (with augmentation):", dataset_train.num_images)
        print("validation images (with augmentation):", dataset_val.num_images)

        config = mask_rcnn_additional.NucleiConfig()
        if maxdim is not None:
            config.IMAGE_MAX_DIM = maxdim
            config.IMAGE_MIN_DIM = maxdim
        config.STEPS_PER_EPOCH = stepCount
        if stepMultiplier is not None:
            steps = int(float(dataset_train.num_images) * stepMultiplier)
            config.STEPS_PER_EPOCH = steps

        config.VALIDATION_STEPS = dataset_val.num_images
        config.DETECTION_NMS_THRESHOLD = detNMSThresh
        config.RPN_NMS_THRESHOLD = rpnNMSThresh
        config.MAX_GT_INSTANCES = 512
        # NOTE(review): upstream Mask R-CNN's Config.__init__ recomputes
        # BATCH_SIZE as IMAGES_PER_GPU * GPU_COUNT, so this value is probably
        # overwritten by the __init__ call below — confirm for NucleiConfig.
        config.BATCH_SIZE = 5000
        # Re-run the constructor so derived settings follow the overrides.
        config.__init__()

        # Same module the inference path imports lazily; importing it here
        # keeps this method working without a module-level import.
        import mrcnn_model

        # Create model in training mode
        mdl = mrcnn_model.MaskRCNN(mode="training",
                                   config=config,
                                   model_dir=os.path.dirname(outModelPath))

        if blankInput:
            mdl.load_weights(inModelPath,
                             by_name=True,
                             exclude=[
                                 "mrcnn_class_logits", "mrcnn_bbox_fc",
                                 "mrcnn_bbox", "mrcnn_mask"
                             ])
        else:
            mdl.load_weights(inModelPath, by_name=True)

        # `epochs` passed to train() is cumulative: each group continues up to
        # the running total of epochs so far.
        allcount = 0
        for epochgroup in params["epoch_groups"]:
            epochs = int(epochgroup["epochs"])
            if epochs < 1:
                continue
            allcount += epochs
            mdl.train(dataset_train,
                      dataset_val,
                      learning_rate=float(epochgroup["learning_rate"]),
                      epochs=allcount,
                      layers=epochgroup["layers"])

        mdl.keras_model.save_weights(outModelPath)