def Train(self):
    """Train a Mask R-CNN model according to the parameter dictionary.

    Mandatory parameter keys: "train_dir", "input_model", "output_model",
    "blank_mrcnn" and "epoch_groups" (a missing one raises KeyError).
    Optional keys: "eval_dir", "image_size", "train_to_val_seed",
    "train_to_val_ratio", "use_eval_in_val", "step_ratio", "step_num",
    "show_inputs", "random_augmentation_level", "detection_nms_threshold"
    and "rpn_nms_threshold". Boolean options are the literal string "true".

    Images are taken from <dir>/images and paired with <dir>/masks/<name>.tiff;
    images without such a mask are skipped. The final weights are written to
    the "output_model" path.

    Raises:
        ValueError: if no training image with a mask was found.
    """
    params = self.__mParams

    # Mandatory settings.
    trainDir = os.path.join(os.curdir, params["train_dir"])
    inModelPath = os.path.join(os.curdir, params["input_model"])
    outModelPath = os.path.join(os.curdir, params["output_model"])
    blankInput = params["blank_mrcnn"] == "true"

    # Optional settings, each falling back to its default when absent.
    evalDir = (os.path.join(os.curdir, params["eval_dir"])
               if "eval_dir" in params else None)
    maxdim = int(params["image_size"]) if "image_size" in params else 1024
    fixedRandomSeed = (params["train_to_val_seed"]
                       if "train_to_val_seed" in params else None)
    trainToValidationChance = (float(params["train_to_val_ratio"])
                               if "train_to_val_ratio" in params else 0.2)
    includeEvaluationInValidation = ((params["use_eval_in_val"] == "true")
                                     if "use_eval_in_val" in params else True)
    stepMultiplier = (float(params["step_ratio"])
                      if "step_ratio" in params else None)
    stepCount = int(params["step_num"]) if "step_num" in params else 1000
    showInputs = ((params["show_inputs"] == "true")
                  if "show_inputs" in params else False)
    augmentationLevel = (int(params["random_augmentation_level"])
                         if "random_augmentation_level" in params else 0)
    detNMSThresh = (float(params["detection_nms_threshold"])
                    if "detection_nms_threshold" in params else 0.35)
    rpnNMSThresh = (float(params["rpn_nms_threshold"])
                    if "rpn_nms_threshold" in params else 0.55)

    rnd = random.Random()
    rnd.seed(fixedRandomSeed)

    def listImageMaskPairs(rootDir):
        """Return (imagePath, maskPath) for each image having a .tiff mask."""
        imagesDir = os.path.join(rootDir, "images")
        masksDir = os.path.join(rootDir, "masks")
        pairs = []
        for fileName in os.listdir(imagesDir):
            imagePath = os.path.join(imagesDir, fileName)
            if not os.path.isfile(imagePath):
                continue
            stem = os.path.splitext(os.path.basename(fileName))[0]
            maskPath = os.path.join(masksDir, stem + ".tiff")
            if os.path.isfile(maskPath):
                pairs.append((imagePath, maskPath))
        return pairs

    trainImagesAndMasks = {}
    validationImagesAndMasks = {}

    # Each usable training image goes to validation with probability
    # trainToValidationChance, otherwise it stays in the training set.
    for imagePath, maskPath in listImageMaskPairs(trainDir):
        if rnd.random() > trainToValidationChance:
            trainImagesAndMasks[imagePath] = maskPath
        else:
            validationImagesAndMasks[imagePath] = maskPath

    # Optionally validate on the evaluation set as well.
    if includeEvaluationInValidation and evalDir is not None:
        validationImagesAndMasks.update(listImageMaskPairs(evalDir))

    if not trainImagesAndMasks:
        raise ValueError("Empty train image list")

    # The validation set must not be empty: borrow the first training pair.
    if not validationImagesAndMasks:
        firstImage = next(iter(trainImagesAndMasks))
        validationImagesAndMasks[firstImage] = trainImagesAndMasks[firstImage]

    # Training dataset
    dataset_train = mask_rcnn_additional.NucleiDataset()
    dataset_train.initialize(pImagesAndMasks=trainImagesAndMasks,
                             pAugmentationLevel=augmentationLevel)
    dataset_train.prepare()

    # Validation dataset (never augmented)
    dataset_val = mask_rcnn_additional.NucleiDataset()
    dataset_val.initialize(pImagesAndMasks=validationImagesAndMasks,
                           pAugmentationLevel=0)
    dataset_val.prepare()

    print("training images (with augmentation):", dataset_train.num_images)
    print("validation images (with augmentation):", dataset_val.num_images)

    config = mask_rcnn_additional.NucleiConfig()
    config.IMAGE_MAX_DIM = maxdim
    config.IMAGE_MIN_DIM = maxdim
    # A step ratio, when given, overrides the fixed step count.
    if stepMultiplier is None:
        config.STEPS_PER_EPOCH = stepCount
    else:
        config.STEPS_PER_EPOCH = int(
            float(dataset_train.num_images) * stepMultiplier)
    config.VALIDATION_STEPS = dataset_val.num_images
    config.DETECTION_NMS_THRESHOLD = detNMSThresh
    config.RPN_NMS_THRESHOLD = rpnNMSThresh
    # Re-run the initializer after the attribute overrides above.
    config.__init__()

    # show config
    config.display()

    # show setup
    for attrName in dir(self):
        if not callable(getattr(self, attrName)):
            print("{:30} {}".format(attrName, getattr(self, attrName)))
    print("\n")

    if showInputs:
        # Load and display random samples
        for imageId in numpy.random.choice(dataset_train.image_ids, 20):
            image = dataset_train.load_image(imageId)
            mask, class_ids = dataset_train.load_mask(imageId)
            visualize.display_instances(
                image=image,
                masks=mask,
                class_ids=class_ids,
                title=dataset_train.image_reference(imageId),
                boxes=utils.extract_bboxes(mask),
                class_names=dataset_train.class_names)

    # Create model in training mode
    mdl = model.MaskRCNN(mode="training",
                         config=config,
                         model_dir=os.path.dirname(outModelPath))

    if blankInput:
        # Load everything except the listed head layers.
        skippedLayers = [
            "mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"
        ]
        mdl.load_weights(inModelPath, by_name=True, exclude=skippedLayers)
    else:
        mdl.load_weights(inModelPath, by_name=True)

    # train() receives the cumulative epoch count, so each group continues
    # from where the previous one stopped.
    totalEpochs = 0
    for group in params["epoch_groups"]:
        groupEpochs = int(group["epochs"])
        if groupEpochs >= 1:
            totalEpochs += groupEpochs
            mdl.train(dataset_train,
                      dataset_val,
                      learning_rate=float(group["learning_rate"]),
                      epochs=totalEpochs,
                      layers=group["layers"])

    mdl.keras_model.save_weights(outModelPath)
def Segment(self,
            pImage,
            pPaddingRatio=0.0,
            pDilationSElem=None,
            pCavityFilling=False,
            pPredictSize=None):
    """Run Mask R-CNN inference on one image and post-process the masks.

    The model is (re)built on first use and whenever pPredictSize, rounded
    up to a multiple of 64, differs from the size used last time.

    Args:
        pImage: input image, converted with kutils.RCNNConvertInputImage.
        pPaddingRatio: if > 0, the image is padded with kutils.PadImageR
            before detection and the masks are cropped back afterwards.
        pDilationSElem: optional structuring element; when given, every
            instance mask is dilated with it.
        pCavityFilling: when true, holes enclosed by an instance mask are
            filled.
        pPredictSize: optional model input size; None keeps the last one.

    Returns:
        A tuple (merged, masks, scores): the result of
        kutils.MergeMasks(masks), the per-instance masks as a
        (height, width, count) array with values 0/255, and the detection
        scores. When the prediction is invalid or empty, an all-zero uint16
        image, an empty uint8 mask stack and an empty float array are
        returned instead.
    """
    rebuild = self.__mModel is None

    if pPredictSize is not None:
        # Round the requested size up to the next multiple of 64.
        maxdim = pPredictSize
        temp = maxdim / 2**6
        if temp != int(temp):
            maxdim = (int(temp) + 1) * 2**6

        if maxdim != self.__mLastMaxDim:
            self.__mLastMaxDim = maxdim
            rebuild = True

    if rebuild:
        import model
        import keras.backend
        keras.backend.clear_session()
        print("Max dim changed (", str(self.__mLastMaxDim),
              "), rebuilding model")

        self.__mConfig = mask_rcnn_additional.NucleiConfig()
        self.__mConfig.DETECTION_MIN_CONFIDENCE = self.__mConfidence
        self.__mConfig.DETECTION_NMS_THRESHOLD = self.__NMSThreshold
        self.__mConfig.IMAGE_MAX_DIM = self.__mLastMaxDim
        self.__mConfig.IMAGE_MIN_DIM = self.__mLastMaxDim
        self.__mConfig.DETECTION_MAX_INSTANCES = self.__mMaxDetNum
        # Re-run the initializer after the attribute overrides above.
        self.__mConfig.__init__()

        self.__mModel = model.MaskRCNN(mode="inference",
                                       config=self.__mConfig,
                                       model_dir=self.__mModelDir)
        self.__mModel.load_weights(self.__mModelPath, by_name=True)

    image = kutils.RCNNConvertInputImage(pImage)
    offsetX = 0
    offsetY = 0
    # Size of the unpadded image; all returned arrays use this size.
    width = image.shape[1]
    height = image.shape[0]

    if pPaddingRatio > 0.0:
        image, (offsetX, offsetY) = kutils.PadImageR(image, pPaddingRatio)

    results = self.__mModel.detect([image], verbose=0)

    r = results[0]
    masks = r['masks']
    scores = r['scores']

    if masks.shape[0] != image.shape[0] or masks.shape[1] != image.shape[1]:
        print("Invalid prediction")
        # numpy.float was removed in NumPy 1.24; float64 is the same dtype.
        return (numpy.zeros((height, width), numpy.uint16),
                numpy.zeros((height, width, 0), numpy.uint8),
                numpy.zeros(0, numpy.float64))

    count = masks.shape[2]
    if count < 1:
        return (numpy.zeros((height, width), numpy.uint16),
                numpy.zeros((height, width, 0), numpy.uint8),
                numpy.zeros(0, numpy.float64))

    if pPaddingRatio > 0.0:
        # Cut the padding off again so the masks match the input image.
        newMasks = numpy.zeros((height, width, count), numpy.uint8)
        newMasks[:, :, :] = masks[offsetY:(offsetY + height),
                                  offsetX:(offsetX + width), :]
        masks = newMasks

    if pDilationSElem is not None:
        for i in range(count):
            masks[:, :, i] = cv2.dilate(masks[:, :, i],
                                        kernel=pDilationSElem)

    if pCavityFilling:
        for i in range(count):
            # findContours needs a contiguous 8-bit single-channel image.
            binary = numpy.ascontiguousarray(masks[:, :, i] != 0,
                                             dtype=numpy.uint8)
            # [-2] selects the contour list on both OpenCV 3 (3 return
            # values) and OpenCV 4 (2 return values).
            contours = cv2.findContours(binary, cv2.RETR_EXTERNAL,
                                        cv2.CHAIN_APPROX_SIMPLE)[-2]
            filled = numpy.zeros(binary.shape, numpy.uint8)
            if len(contours) > 0:
                # Drawing the outer contours filled closes every cavity.
                cv2.drawContours(filled, contours, -1, 255, thickness=-1)
            masks[:, :, i] = filled

    # Normalize every mask to 0/255, one slice at a time to keep the
    # temporary array small.
    for i in range(count):
        masks[:, :, i] = numpy.where(masks[:, :, i] == 0, 0, 255)

    return kutils.MergeMasks(masks), masks, scores
class Segmentation:
    """Mask R-CNN inference wrapper for single images and image folders.

    The model is built lazily by Segment() and rebuilt whenever the
    requested prediction size changes.
    """
    __mModel = None
    __mConfig = None
    __mModelDir = ""
    __mModelPath = ""
    __mLastMaxDim = mask_rcnn_additional.NucleiConfig().IMAGE_MAX_DIM
    __mConfidence = 0.5
    __NMSThreshold = 0.35

    def __init__(self,
                 pModelPath,
                 pConfidence=0.5,
                 pNMSThreshold=0.35,
                 pMaxDetNum=512):
        """Store the inference settings; the model itself is loaded later.

        Args:
            pModelPath: path of the Mask R-CNN weights file. The process
                exits via sys.exit if it is not an existing file.
            pConfidence: value used for DETECTION_MIN_CONFIDENCE.
            pNMSThreshold: value used for DETECTION_NMS_THRESHOLD.
            pMaxDetNum: value used for DETECTION_MAX_INSTANCES.
        """
        if not os.path.isfile(pModelPath):
            sys.exit("Invalid model path: " + pModelPath)

        self.__mConfidence = pConfidence
        self.__NMSThreshold = pNMSThreshold
        self.__mModelPath = pModelPath
        self.__mModelDir = os.path.dirname(pModelPath)
        self.__mMaxDetNum = pMaxDetNum

    def Segment(self, pImage, pPredictSize=None):
        """Run Mask R-CNN inference on one image.

        Args:
            pImage: input image, converted with
                kutils.RCNNConvertInputImage.
            pPredictSize: optional model input size, rounded up to a
                multiple of 64; a change triggers a model rebuild.

        Returns:
            A tuple (merged, masks, scores): the result of
            kutils.MergeMasks(masks), the per-instance masks as a
            (height, width, count) array with values 0/255, and the
            detection scores. For an invalid or empty prediction an
            all-zero uint16 image, an empty uint8 mask stack and an empty
            float array are returned instead.
        """
        rebuild = self.__mModel is None

        if pPredictSize is not None:
            # Round the requested size up to the next multiple of 64.
            maxdim = pPredictSize
            temp = maxdim / 2**6
            if temp != int(temp):
                maxdim = (int(temp) + 1) * 2**6

            if maxdim != self.__mLastMaxDim:
                self.__mLastMaxDim = maxdim
                rebuild = True

        if rebuild:
            import mrcnn_model
            import keras.backend
            keras.backend.clear_session()
            print("Max dim changed (", str(self.__mLastMaxDim),
                  "), rebuilding model")

            self.__mConfig = mask_rcnn_additional.NucleiConfig()
            self.__mConfig.DETECTION_MIN_CONFIDENCE = self.__mConfidence
            self.__mConfig.DETECTION_NMS_THRESHOLD = self.__NMSThreshold
            self.__mConfig.IMAGE_MAX_DIM = self.__mLastMaxDim
            self.__mConfig.IMAGE_MIN_DIM = self.__mLastMaxDim
            self.__mConfig.DETECTION_MAX_INSTANCES = self.__mMaxDetNum
            # Re-run the initializer after the attribute overrides above.
            self.__mConfig.__init__()

            self.__mModel = mrcnn_model.MaskRCNN(mode="inference",
                                                 config=self.__mConfig,
                                                 model_dir=self.__mModelDir)
            self.__mModel.load_weights(self.__mModelPath, by_name=True)

        image = kutils.RCNNConvertInputImage(pImage)
        width = image.shape[1]
        height = image.shape[0]

        results = self.__mModel.detect([image], verbose=0)

        r = results[0]
        masks = r['masks']
        scores = r['scores']

        if masks.shape[0] != image.shape[0] or masks.shape[
                1] != image.shape[1]:
            print("Invalid prediction")
            # numpy.float was removed in NumPy 1.24; float64 is the same
            # dtype.
            return (numpy.zeros((height, width), numpy.uint16),
                    numpy.zeros((height, width, 0), numpy.uint8),
                    numpy.zeros(0, numpy.float64))

        count = masks.shape[2]
        if count < 1:
            return (numpy.zeros((height, width), numpy.uint16),
                    numpy.zeros((height, width, 0), numpy.uint8),
                    numpy.zeros(0, numpy.float64))

        # Normalize every mask to 0/255, one slice at a time to keep the
        # temporary array small.
        for i in range(count):
            masks[:, :, i] = numpy.where(masks[:, :, i] == 0, 0, 255)

        return kutils.MergeMasks(masks), masks, scores

    def Run(self, imagesDir, outputDir, maxdimValues, subsize_x, subsize_y):
        """Segment every image file of a folder tile by tile.

        Each regular file in imagesDir is read with skimage.io.imread,
        brought to a 3-channel layout when it is grayscale or has its
        smaller axis first, and segmented in tiles of at most
        subsize_x by subsize_y pixels once per value in maxdimValues. The
        result, a uint16 array of shape (len(maxdimValues), height, width),
        is saved as outputDir/<basename>.tiff.

        Args:
            imagesDir: folder containing the input images.
            outputDir: folder for the results; created if missing.
            maxdimValues: sequence of prediction sizes passed to Segment.
            subsize_x: tile width in pixels.
            subsize_y: tile height in pixels.
        """
        os.makedirs(name=outputDir, exist_ok=True)

        imageFiles = [
            f for f in os.listdir(imagesDir)
            if os.path.isfile(os.path.join(imagesDir, f))
        ]
        imcount = len(imageFiles)
        for index, imageFile in enumerate(imageFiles):
            print("Image:", str(index + 1), "/", str(imcount), "(",
                  imageFile, ")")

            baseName = os.path.splitext(os.path.basename(imageFile))[0]
            imagePath = os.path.join(imagesDir, imageFile)
            image = skimage.io.imread(imagePath)

            if len(image.shape) > 2:
                if image.shape[0] < image.shape[2]:
                    # First axis is the smaller one: move it to the end and
                    # repeat the last plane until there are 3 channels.
                    new_image = numpy.zeros(
                        (image.shape[1], image.shape[2], 3), numpy.uint16)
                    for k in range(image.shape[0]):
                        new_image[:, :, k] = image[k, :, :]
                    for k in range(image.shape[0], 3):
                        new_image[:, :, k] = image[image.shape[0] - 1, :, :]
                    image = new_image
            elif len(image.shape) == 2:
                # Grayscale: replicate the plane into 3 channels.
                new_image = numpy.zeros((image.shape[0], image.shape[1], 3),
                                        numpy.uint16)
                for k in range(3):
                    new_image[:, :, k] = image
                image = new_image

            image_size_x = image.shape[1]
            image_size_y = image.shape[0]

            mask_allScales_allImageParts = numpy.zeros(
                (len(maxdimValues), image_size_y, image_size_x),
                numpy.uint16)

            # Running label offset per scale, so labels of later tiles do
            # not collide with labels of earlier tiles.
            totalNucleiCount = numpy.zeros(len(maxdimValues))

            for scaleIndex, maxdim in enumerate(maxdimValues):
                x_minIterator = 0
                y_minIterator = 0
                x_maxIterator = subsize_x
                y_maxIterator = subsize_y
                overlap_x = 0
                overlap_y = 0

                # Overlap of neighbouring tiles: the excess of
                # ceil(size / subsize) tiles over the image size, divided
                # by floor(size / subsize).
                if image_size_x < subsize_x:
                    x_maxIterator = image_size_x
                else:
                    overlap_x = math.floor(
                        (math.ceil(image_size_x / subsize_x) * subsize_x -
                         image_size_x) / math.floor(image_size_x / subsize_x))
                if image_size_y < subsize_y:
                    y_maxIterator = image_size_y
                else:
                    overlap_y = math.floor(
                        (math.ceil(image_size_y / subsize_y) * subsize_y -
                         image_size_y) / math.floor(image_size_y / subsize_y))

                done = False
                while not done:
                    current_image = image[y_minIterator:y_maxIterator,
                                          x_minIterator:x_maxIterator, :]
                    mask, masks, scores = self.Segment(pImage=current_image,
                                                       pPredictSize=maxdim)
                    currentNucleiCount = masks.shape[2]

                    # Part of the tile that is copied to the output: half
                    # of the overlap is dropped on every side that touches
                    # another tile.
                    x_min_combined = x_minIterator
                    x_range_start = 0
                    if image_size_x < subsize_x:
                        x_range_end = image_size_x
                    else:
                        x_range_end = subsize_x
                        if x_minIterator > 0:
                            x_range_start = math.floor(overlap_x / 2)
                        if x_maxIterator != image_size_x:
                            x_range_end = subsize_x - math.ceil(
                                overlap_x / 2)

                    y_min_combined = y_minIterator
                    y_range_start = 0
                    if image_size_y < subsize_y:
                        y_range_end = image_size_y
                    else:
                        y_range_end = subsize_y
                        if y_minIterator > 0:
                            y_range_start = math.floor(overlap_y / 2)
                        if y_maxIterator != image_size_y:
                            y_range_end = subsize_y - math.ceil(
                                overlap_y / 2)

                    # Maps a tile-local label to the label already written
                    # directly left of / above the copied region, so an
                    # object crossing a tile seam keeps a single id.
                    nucleiIds = numpy.zeros(
                        numpy.int16(currentNucleiCount) + 1, numpy.uint16)
                    if x_min_combined > 0:
                        for y in range(y_range_start, y_range_end):
                            if mask[y, x_range_start] > 0:
                                if mask_allScales_allImageParts[
                                        scaleIndex, y + y_min_combined,
                                        x_min_combined + x_range_start -
                                        1] > 0:
                                    nucleiIds[mask[
                                        y,
                                        x_range_start]] = mask_allScales_allImageParts[
                                            scaleIndex, y + y_min_combined,
                                            x_min_combined + x_range_start -
                                            1]
                    if y_min_combined > 0:
                        for x in range(x_range_start, x_range_end):
                            if mask[y_range_start, x] > 0:
                                if mask_allScales_allImageParts[
                                        scaleIndex,
                                        y_min_combined + y_range_start - 1,
                                        x + x_min_combined] > 0:
                                    nucleiIds[mask[
                                        y_range_start,
                                        x]] = mask_allScales_allImageParts[
                                            scaleIndex, y_min_combined +
                                            y_range_start - 1,
                                            x + x_min_combined]

                    for y in range(y_range_start, y_range_end):
                        for x in range(x_range_start, x_range_end):
                            if mask[y, x] > 0:
                                if nucleiIds[mask[y, x]] > 0:
                                    mask_allScales_allImageParts[
                                        scaleIndex, y + y_min_combined,
                                        x + x_min_combined] = nucleiIds[mask[
                                            y, x]]
                                else:
                                    mask_allScales_allImageParts[
                                        scaleIndex, y + y_min_combined,
                                        x +
                                        x_min_combined] = mask[
                                            y,
                                            x] + totalNucleiCount[scaleIndex]
                    totalNucleiCount[scaleIndex] += currentNucleiCount

                    # Advance to the next tile of the row. An unchanged
                    # x_minIterator signals that the row is finished.
                    x_minIterator_memory = x_minIterator
                    x_minIterator = x_maxIterator - overlap_x
                    x_maxIterator = x_minIterator + subsize_x
                    if x_maxIterator > image_size_x:
                        if x_maxIterator < image_size_x + overlap_x:
                            # Slightly past the edge: shift the tile back.
                            delta_x = x_maxIterator - image_size_x
                            x_minIterator -= delta_x
                            x_maxIterator -= delta_x
                        else:
                            x_minIterator = x_minIterator_memory
                    if x_minIterator_memory == x_minIterator:
                        # Row finished: restart at the left, go one row
                        # down. An unchanged y_minIterator ends the scan.
                        x_minIterator = 0
                        x_maxIterator = subsize_x
                        y_minIterator_memory = y_minIterator
                        y_minIterator = y_maxIterator - overlap_y
                        y_maxIterator = y_minIterator + subsize_y
                        if y_maxIterator > image_size_y:
                            if y_maxIterator < image_size_y + overlap_y:
                                delta_y = y_maxIterator - image_size_y
                                y_minIterator -= delta_y
                                y_maxIterator -= delta_y
                            else:
                                y_minIterator = y_minIterator_memory
                        if y_minIterator_memory == y_minIterator:
                            done = True

            skimage.io.imsave(os.path.join(outputDir, baseName + ".tiff"),
                              mask_allScales_allImageParts)
def Segment(self, pImage, pPredictSize=None):
    """Run Mask R-CNN inference on one image.

    The model is (re)built on first use and whenever pPredictSize, rounded
    up to a multiple of 64, differs from the size used last time.

    Args:
        pImage: input image, converted with kutils.RCNNConvertInputImage.
        pPredictSize: optional model input size; None keeps the last one.

    Returns:
        A tuple (merged, masks, scores): the result of
        kutils.MergeMasks(masks), the per-instance masks as a
        (height, width, count) array with values 0/255, and the detection
        scores. For an invalid or empty prediction an all-zero uint16 image,
        an empty uint8 mask stack and an empty float array are returned
        instead.
    """
    rebuild = self.__mModel is None

    if pPredictSize is not None:
        # Round the requested size up to the next multiple of 64.
        maxdim = pPredictSize
        temp = maxdim / 2**6
        if temp != int(temp):
            maxdim = (int(temp) + 1) * 2**6

        if maxdim != self.__mLastMaxDim:
            self.__mLastMaxDim = maxdim
            rebuild = True

    if rebuild:
        import mrcnn_model
        import keras.backend
        keras.backend.clear_session()
        print("Max dim changed (", str(self.__mLastMaxDim),
              "), rebuilding model")

        self.__mConfig = mask_rcnn_additional.NucleiConfig()
        self.__mConfig.DETECTION_MIN_CONFIDENCE = self.__mConfidence
        self.__mConfig.DETECTION_NMS_THRESHOLD = self.__NMSThreshold
        self.__mConfig.IMAGE_MAX_DIM = self.__mLastMaxDim
        self.__mConfig.IMAGE_MIN_DIM = self.__mLastMaxDim
        self.__mConfig.DETECTION_MAX_INSTANCES = self.__mMaxDetNum
        # Re-run the initializer after the attribute overrides above.
        self.__mConfig.__init__()

        self.__mModel = mrcnn_model.MaskRCNN(mode="inference",
                                             config=self.__mConfig,
                                             model_dir=self.__mModelDir)
        self.__mModel.load_weights(self.__mModelPath, by_name=True)

    image = kutils.RCNNConvertInputImage(pImage)
    width = image.shape[1]
    height = image.shape[0]

    results = self.__mModel.detect([image], verbose=0)

    r = results[0]
    masks = r['masks']
    scores = r['scores']

    if masks.shape[0] != image.shape[0] or masks.shape[1] != image.shape[1]:
        print("Invalid prediction")
        # numpy.float was removed in NumPy 1.24; float64 is the same dtype.
        return (numpy.zeros((height, width), numpy.uint16),
                numpy.zeros((height, width, 0), numpy.uint8),
                numpy.zeros(0, numpy.float64))

    count = masks.shape[2]
    if count < 1:
        return (numpy.zeros((height, width), numpy.uint16),
                numpy.zeros((height, width, 0), numpy.uint8),
                numpy.zeros(0, numpy.float64))

    # Normalize every mask to 0/255, one slice at a time to keep the
    # temporary array small.
    for i in range(count):
        masks[:, :, i] = numpy.where(masks[:, :, i] == 0, 0, 255)

    return kutils.MergeMasks(masks), masks, scores
def Train(self):
    """Train a Mask R-CNN model according to the parameter dictionary.

    Mandatory parameter keys: "train_dir", "input_model", "output_dir",
    "model_name" and "epoch_groups". Optional keys: "eval_dir",
    "image_size" (default 1024), "train_to_val_ratio", "step_num",
    "random_augmentation_level", "detection_nms_threshold" and
    "rpn_nms_threshold"; their values are used as given, without type
    conversion.

    Images are read from <dir>/images and paired with a mask of the same
    base name in <dir>/masks ending with .png, .tif or .tiff (tried in this
    order). With "eval_dir" set, the evaluation images form the validation
    set and all training images are used for training; without it the
    training images are split randomly using "train_to_val_ratio".

    The process exits via sys.exit when an image has no mask file or when
    no training image was found. The final weights are written to
    <output_dir>/<model_name>.h5.
    """
    fixedRandomSeed = 0
    trainToValidationChance = 0.2
    includeEvaluationInValidation = True
    stepMultiplier = 1.0
    stepCount = 1000
    augmentationLevel = 0
    detNMSThresh = 0.35
    rpnNMSThresh = 0.55
    # Default model input size, used when "image_size" is not given.
    maxdim = 1024
    # Stays None without "eval_dir", which selects the random split below.
    evalDir = None
    trainDir = os.path.join(os.curdir, self.__mParams["train_dir"])
    inModelPath = os.path.join(os.curdir, self.__mParams["input_model"])
    os.makedirs(name=self.__mParams["output_dir"], exist_ok=True)
    outModelPath = os.path.join(self.__mParams["output_dir"],
                                self.__mParams["model_name"] + ".h5")
    blankInput = True

    if "eval_dir" in self.__mParams:
        evalDir = os.path.join(os.curdir, self.__mParams["eval_dir"])

    if "image_size" in self.__mParams:
        maxdim = self.__mParams["image_size"]

    if "train_to_val_ratio" in self.__mParams:
        trainToValidationChance = self.__mParams["train_to_val_ratio"]

    if "step_num" in self.__mParams:
        stepCount = self.__mParams["step_num"]

    if "random_augmentation_level" in self.__mParams:
        augmentationLevel = self.__mParams["random_augmentation_level"]

    if "detection_nms_threshold" in self.__mParams:
        detNMSThresh = self.__mParams["detection_nms_threshold"]

    if "rpn_nms_threshold" in self.__mParams:
        rpnNMSThresh = self.__mParams["rpn_nms_threshold"]

    rnd = random.Random()
    rnd.seed(fixedRandomSeed)

    def findMask(masksDir, imageFile):
        """Return the mask path of imageFile or exit if there is none."""
        baseName = os.path.splitext(os.path.basename(imageFile))[0]
        for extension in (".png", ".tif", ".tiff"):
            candidate = os.path.join(masksDir, baseName + extension)
            if os.path.exists(candidate):
                return candidate
        sys.exit(
            "The image " + imageFile +
            " does not have a corresponding mask file ending with png, tif or tiff"
        )

    def collectPairs(rootDir):
        """Return (imagePath, maskPath) for every image file of rootDir."""
        imagesDir = os.path.join(rootDir, "images")
        masksDir = os.path.join(rootDir, "masks")
        imageFileList = [
            f for f in os.listdir(imagesDir)
            if os.path.isfile(os.path.join(imagesDir, f))
        ]
        pairs = []
        for imageFile in imageFileList:
            imagePath = os.path.join(imagesDir, imageFile)
            maskPath = findMask(masksDir, imageFile)
            # Skip entries that are not regular files (e.g. a directory
            # named like a mask).
            if not os.path.isfile(imagePath) or not os.path.isfile(maskPath):
                continue
            pairs.append((imagePath, maskPath))
        return pairs

    trainImagesAndMasks = {}
    validationImagesAndMasks = {}

    if includeEvaluationInValidation and evalDir is not None:
        # Evaluation set is the validation set; train on everything else.
        for imagePath, maskPath in collectPairs(evalDir):
            validationImagesAndMasks[imagePath] = maskPath
        for imagePath, maskPath in collectPairs(trainDir):
            trainImagesAndMasks[imagePath] = maskPath
    else:
        # No evaluation set: split the training data randomly.
        for imagePath, maskPath in collectPairs(trainDir):
            if rnd.random() > trainToValidationChance:
                trainImagesAndMasks[imagePath] = maskPath
            else:
                validationImagesAndMasks[imagePath] = maskPath

    if len(trainImagesAndMasks) < 1:
        sys.exit("Empty train image list")

    # The validation set must not be empty: borrow the first training pair.
    if len(validationImagesAndMasks) < 1:
        for key, value in trainImagesAndMasks.items():
            validationImagesAndMasks[key] = value
            break

    # Training dataset
    dataset_train = mask_rcnn_additional.NucleiDataset()
    dataset_train.initialize(pImagesAndMasks=trainImagesAndMasks,
                             pAugmentationLevel=augmentationLevel)
    dataset_train.prepare()

    # Validation dataset (never augmented)
    dataset_val = mask_rcnn_additional.NucleiDataset()
    dataset_val.initialize(pImagesAndMasks=validationImagesAndMasks,
                           pAugmentationLevel=0)
    dataset_val.prepare()

    print("training images (with augmentation):", dataset_train.num_images)
    print("validation images (with augmentation):", dataset_val.num_images)

    config = mask_rcnn_additional.NucleiConfig()
    config.IMAGE_MAX_DIM = maxdim
    config.IMAGE_MIN_DIM = maxdim
    config.STEPS_PER_EPOCH = stepCount
    # NOTE(review): stepMultiplier is always 1.0 in this method, so the
    # value from "step_num" is always replaced by the number of training
    # images - confirm this is intended.
    if stepMultiplier is not None:
        steps = int(float(dataset_train.num_images) * stepMultiplier)
        config.STEPS_PER_EPOCH = steps
    config.VALIDATION_STEPS = dataset_val.num_images
    config.DETECTION_NMS_THRESHOLD = detNMSThresh
    config.RPN_NMS_THRESHOLD = rpnNMSThresh
    config.MAX_GT_INSTANCES = 512
    config.BATCH_SIZE = 5000
    # Re-run the initializer after the attribute overrides above.
    config.__init__()

    # Create model in training mode
    mdl = mrcnn_model.MaskRCNN(mode="training",
                               config=config,
                               model_dir=os.path.dirname(outModelPath))

    if blankInput:
        # Load everything except the listed head layers.
        mdl.load_weights(inModelPath,
                         by_name=True,
                         exclude=[
                             "mrcnn_class_logits", "mrcnn_bbox_fc",
                             "mrcnn_bbox", "mrcnn_mask"
                         ])
    else:
        mdl.load_weights(inModelPath, by_name=True)

    # train() receives the cumulative epoch count, so each group continues
    # from where the previous one stopped.
    allcount = 0
    for epochgroup in self.__mParams["epoch_groups"]:
        epochs = int(epochgroup["epochs"])
        if epochs < 1:
            continue
        allcount += epochs
        mdl.train(dataset_train,
                  dataset_val,
                  learning_rate=float(epochgroup["learning_rate"]),
                  epochs=allcount,
                  layers=epochgroup["layers"])

    mdl.keras_model.save_weights(outModelPath)