def predict(argv):
    """Run the object-counter prediction job on Cytomine annotations or images.

    Parses command-line style arguments and loads a pickled estimator. The
    estimator comes either from ``--model_file`` or from
    ``<working_path>/models/<software id>/<job id>.pkl`` for the training job
    given by ``--model_id_job``. It then dumps the requested annotation crops
    (``--annotation``) or image instances (``--image``), predicts on each
    sample, post-processes the prediction and uploads the resulting
    annotations.

    Parameters
    ----------
    argv : list of str
        Command-line arguments. Unknown arguments are ignored
        (``parse_known_args``).

    Raises
    ------
    ValueError
        If neither ``--model_file`` nor ``--model_id_job`` is provided.
    """
    parser = ArgumentParser(prog="Extra-Trees Object Counter Predictor")

    # Cytomine
    parser.add_argument('--cytomine_host', dest='cytomine_host',
                        default='demo.cytomine.be', help="The Cytomine host")
    parser.add_argument('--cytomine_public_key', dest='cytomine_public_key',
                        help="The Cytomine public key")
    parser.add_argument('--cytomine_private_key', dest='cytomine_private_key',
                        help="The Cytomine private key")
    parser.add_argument('--cytomine_base_path', dest='cytomine_base_path',
                        default='/api/', help="The Cytomine base path")
    parser.add_argument('--cytomine_working_path', dest='cytomine_working_path',
                        default=None, help="The working directory (eg: /tmp)")
    parser.add_argument('--cytomine_id_software', dest='cytomine_software', type=int,
                        help="The Cytomine software identifier")
    parser.add_argument('--cytomine_id_project', dest='cytomine_project', type=int,
                        help="The Cytomine project identifier")

    # Objects
    parser.add_argument('--cytomine_object_term', dest='cytomine_object_term', type=int,
                        help="The Cytomine identifier of object term")

    # Post-processing
    parser.add_argument('--post_threshold', dest='post_threshold', type=float,
                        help="Post-processing discarding threshold")
    parser.add_argument('--post_sigma', dest='post_sigma', type=float,
                        help="Std-dev of Gauss filter applied to smooth prediction")
    parser.add_argument('--post_min_dist', dest='post_min_dist', type=int,
                        help="Minimum distance between two peaks")

    # ROI
    parser.add_argument('--annotation', dest='annotation', type=str, action='append', default=[])
    parser.add_argument('--image', dest='image', type=str, action='append', default=[])

    # Execution
    parser.add_argument('--n_jobs', dest='n_jobs', type=int, default=1, help="Number of jobs")
    parser.add_argument('--verbose', '-v', dest='verbose', type=int, default=0, help="Level of verbosity")
    parser.add_argument('--model_id_job', dest='model_id_job', type=str, default=None, help="Model job ID")
    parser.add_argument('--model_file', dest="model_file", type=str, default=None, help="Model file")

    params, other = parser.parse_known_args(argv)
    if params.cytomine_working_path is None:
        params.cytomine_working_path = os.path.join(tempfile.gettempdir(), "cytomine")
    make_dirs(params.cytomine_working_path)

    params.model_id_job = str2int(params.model_id_job)
    params.image = [str2int(i) for i in params.image]
    params.annotation = [str2int(i) for i in params.annotation]

    # Initialize logger
    logger = StandardOutputLogger(params.verbose)
    # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for key, val in sorted(vars(params).items()):
        logger.info("[PARAMETER] {}: {}".format(key, val))

    # Start job
    with CytomineJob(params.cytomine_host,
                     params.cytomine_public_key,
                     params.cytomine_private_key,
                     params.cytomine_software,
                     params.cytomine_project,
                     parameters=vars(params),
                     working_path=params.cytomine_working_path,
                     base_path=params.cytomine_base_path,
                     verbose=(params.verbose >= Logger.DEBUG)) as job:
        # The job object is used directly as the Cytomine client below.
        cytomine = job
        cytomine.update_job_status(job.job, status_comment="Starting...", progress=0)

        cytomine.update_job_status(job.job, status_comment="Loading model...", progress=1)
        logger.i("Loading model...")
        if params.model_file:
            model_file = params.model_file
        elif params.model_id_job is not None:
            model_job = cytomine.get_job(params.model_id_job)
            model_file = os.path.join(params.cytomine_working_path, "models", str(model_job.software),
                                      "{}.pkl".format(model_job.id))
        else:
            raise ValueError("Either --model_file or --model_id_job must be provided.")

        # NOTE(review): pickle.load can execute arbitrary code; only load model
        # files produced by a trusted training job.
        with open(model_file, 'rb') as f:
            estimator = pickle.load(f)
            # ROI selections are not estimator parameters; strip them before set_params.
            predict_params = vars(params).copy()
            predict_params.pop("image", None)
            predict_params.pop("annotation", None)
            estimator.set_params(**predict_params)

        cytomine.update_job_status(job.job, status_comment="Dumping annotations/images to predict...", progress=3)
        logger.i("Dumping annotations/images to predict...")
        # Guard on emptiness first: both lists default to [] when the option is absent.
        if params.annotation and params.annotation[0] is not None:
            annots = [cytomine.get_annotation(annot_id) for annot_id in params.annotation]
            annots_collection = AnnotationCollection()
            annots_collection._data = annots
            crops = cytomine.dump_annotations(annotations=annots_collection,
                                              dest_path=os.path.join(params.cytomine_working_path, "crops",
                                                                     str(params.cytomine_project)),
                                              desired_zoom=0,
                                              get_image_url_func=Annotation.get_annotation_alpha_crop_url)
            X = crops.data()
        elif params.image and params.image[0] is not None:
            image_instances = [cytomine.get_image_instance(image_id) for image_id in params.image]
            # NOTE(review): unlike the crops path above, this dumps to an absolute
            # path at the filesystem root rather than under cytomine_working_path;
            # confirm this is intended.
            image_instances = cytomine.dump_project_images(id_project=params.cytomine_project,
                                                           dest_path="/imageinstances/",
                                                           image_instances=image_instances,
                                                           max_size=True)
            X = image_instances
        else:
            X = []

        n_samples = len(X)
        logger.d("X size: {} samples".format(n_samples))

        for i, x in enumerate(X):
            logger.i("Predicting ID {}...".format(x.id))
            # Map the sample index linearly onto the 5..100 progress range
            # (float arithmetic so Python 2 integer division cannot pin it at 5).
            progress = 5 + int(np.ceil(95. * i / n_samples))
            cytomine.update_job_status(job.job, status_comment="Predicting ID {}...".format(x.id),
                                       progress=progress)
            y = estimator.predict([x.filename])
            y = estimator.postprocessing([y], **estimator.filter_sk_params(estimator.postprocessing))

            logger.i("Uploading annotations...")
            cytomine.update_job_status(job.job, status_comment="Uploading annotations...")
            upload_annotations(cytomine, x, y, term=params.cytomine_object_term)

        logger.i("Finished.")
        cytomine.update_job_status(job.job, status_comment="Finished.", progress=100)
# Example 2
# 0
def train(argv):
    """Train an Extra-Trees / Random-Forest object counter for a Cytomine job.

    Parses command-line style arguments and fills in defaults for unset
    options. It then downloads the training set from the Cytomine project,
    fits a ``CellCountRandomizedTrees`` estimator and saves it to
    ``<working_path>/models/<software id>/<job id>.pkl``.

    Parameters
    ----------
    argv : list of str
        Command-line arguments. Unknown arguments are ignored
        (``parse_known_args``).
    """
    parser = ArgumentParser(prog="Extra-Trees Object Counter Model Builder")

    # Boolean options are parsed with str2bool: argparse's ``type=bool`` would map
    # any non-empty string, including "False", to True. argparse only applies
    # ``type`` to strings, so non-string defaults (True / None) pass through unchanged.

    # Cytomine
    parser.add_argument('--cytomine_host', dest='cytomine_host',
                        default='demo.cytomine.be', help="The Cytomine host")
    parser.add_argument('--cytomine_public_key', dest='cytomine_public_key',
                        help="The Cytomine public key")
    parser.add_argument('--cytomine_private_key', dest='cytomine_private_key',
                        help="The Cytomine private key")
    parser.add_argument('--cytomine_base_path', dest='cytomine_base_path',
                        default='/api/', help="The Cytomine base path")
    parser.add_argument('--cytomine_working_path', dest='cytomine_working_path',
                        default=None, help="The working directory (eg: /tmp)")
    parser.add_argument('--cytomine_id_software', dest='cytomine_software', type=int,
                        help="The Cytomine software identifier")
    parser.add_argument('--cytomine_id_project', dest='cytomine_project', type=int,
                        help="The Cytomine project identifier")
    parser.add_argument('--cytomine_force_download', dest='cytomine_force_download', type=str2bool, default=True,
                        help="Force download from Cytomine or not")

    # Objects
    parser.add_argument('--cytomine_object_term', dest='cytomine_object_term', type=int,
                        help="The Cytomine identifier of object term")
    parser.add_argument('--cytomine_object_user', dest='cytomine_object_user', type=int,
                        help="The Cytomine identifier of object owner")
    parser.add_argument('--cytomine_object_reviewed_only', dest='cytomine_object_reviewed_only', type=str2bool,
                        help="Whether objects have to be reviewed or not")

    # ROI
    parser.add_argument('--cytomine_roi_term', dest='cytomine_roi_term', type=int, default=None,
                        help="The Cytomine identifier of region of interest term")
    parser.add_argument('--cytomine_roi_user', dest='cytomine_roi_user', type=int,
                        help="The Cytomine identifier of ROI owner")
    parser.add_argument('--cytomine_roi_reviewed_only', dest='cytomine_roi_reviewed_only', type=str2bool,
                        help="Whether ROIs have to be reviewed or not")

    # Pre-processing
    parser.add_argument('--mean_radius', dest='mean_radius', type=int, required=True,
                        help="The mean radius of object to detect")
    parser.add_argument('--pre_transformer', dest='pre_transformer',
                        default=None, choices=['edt', 'euclidean_distance_transform', 'density', None, 'None'],
                        help="Scoremap transformer (None, edt, euclidean_distance_transform, density)")
    parser.add_argument('--pre_alpha', dest='pre_alpha', action='append', type=int,
                        help="Exponential decrease rate of distance (if EDT)")

    # Subwindows
    parser.add_argument('--sw_input_size', dest='sw_input_size', action='append', type=int,
                        help="Size of input subwindow")
    parser.add_argument('--sw_output_size', dest='sw_output_size', action='append', type=int,
                        help="Size of output subwindow (ignored for FCRN)")
    parser.add_argument('--sw_extr_mode', dest='sw_extr_mode', choices=['random', 'sliding', 'scoremap_constrained'],
                        help="Mode of extraction (random, scoremap_constrained)")
    parser.add_argument('--sw_extr_score_thres', dest='sw_extr_score_thres', action='append', type=float,
                        help="Minimum threshold to be foreground in subwindows extraction"
                             "(if 'scoremap_constrained' mode)")
    parser.add_argument('--sw_extr_ratio', dest='sw_extr_ratio', action='append', type=float,
                        help="Ratio of background subwindows extracted in subwindows "
                             "extraction (if 'scoremap_constrained' mode)")
    parser.add_argument('--sw_extr_npi', dest="sw_extr_npi", action='append', type=int,
                        help="Number of extracted subwindows per image (if 'random' mode)")
    parser.add_argument('--sw_colorspace', dest="sw_colorspace", type=str, default='RGB__rgb',
                        help="List of colorspace features")

    # Forest
    parser.add_argument('--forest_method', dest='forest_method', type=str,
                        action='append', choices=['ET-clf', 'ET-regr', 'RF-clf', 'RF-regr'],
                        help="Type of forest method")
    parser.add_argument('--forest_n_estimators', dest='forest_n_estimators', action='append', type=int,
                        help="Number of trees in forest")
    parser.add_argument('--forest_min_samples_split', dest='forest_min_samples_split', action='append', type=int,
                        help="Minimum number of samples for further splitting")
    parser.add_argument('--forest_max_features', dest='forest_max_features', action='append',
                        help="Max features")

    # Dataset augmentation
    parser.add_argument('--augmentation', dest='augmentation', type=str2bool)
    parser.add_argument('--aug_rotation_range', dest='rotation_range', type=float)
    parser.add_argument('--aug_width_shift_range', dest='width_shift_range', type=float)
    parser.add_argument('--aug_height_shift_range', dest='height_shift_range', type=float)
    parser.add_argument('--aug_zoom_range', dest='zoom_range', type=float)
    parser.add_argument('--aug_fill_mode', dest='fill_mode', type=str)
    parser.add_argument('--aug_horizontal_flip', dest='horizontal_flip', type=str2bool)
    parser.add_argument('--aug_vertical_flip', dest='vertical_flip', type=str2bool)
    parser.add_argument('--aug_featurewise_center', dest='featurewise_center', type=str2bool)
    parser.add_argument('--aug_featurewise_std_normalization', dest='featurewise_std_normalization', type=str2bool)

    # Execution
    parser.add_argument('--n_jobs', dest='n_jobs', type=int, default=1, help="Number of jobs")
    # type=int so a CLI-supplied level compares numerically against Logger.DEBUG below.
    parser.add_argument('--verbose', '-v', dest='verbose', type=int, default=0, help="Level of verbosity")

    params, other = parser.parse_known_args(argv)
    if params.cytomine_working_path is None:
        params.cytomine_working_path = os.path.join(tempfile.gettempdir(), "cytomine")
    make_dirs(params.cytomine_working_path)

    # Fill in defaults for options that were not provided (multi-valued ones are lists).
    params.pre_transformer = check_default(params.pre_transformer, None, return_list=False)
    params.pre_alpha = check_default(params.pre_alpha, 5)
    params.forest_method = check_default(params.forest_method, 'ET-regr')
    params.forest_n_estimators = check_default(params.forest_n_estimators, 1)
    params.forest_min_samples_split = check_default(params.forest_min_samples_split, 2)
    params.forest_max_features = check_default(params.forest_max_features, 'sqrt')
    params.forest_max_features = check_max_features(params.forest_max_features)
    params.sw_input_size = check_default(params.sw_input_size, 4)
    params.sw_input_size = [(s, s) for s in params.sw_input_size]
    params.sw_output_size = check_default(params.sw_output_size, 1)
    params.sw_output_size = [(s, s) for s in params.sw_output_size]
    params.sw_extr_mode = check_default(params.sw_extr_mode, 'scoremap_constrained', return_list=False)
    params.sw_extr_ratio = check_default(params.sw_extr_ratio, 0.5)
    params.sw_extr_score_thres = check_default(params.sw_extr_score_thres, 0.4)
    params.sw_extr_npi = check_default(params.sw_extr_npi, 100)
    params.sw_colorspace = params.sw_colorspace.split(' ')

    params.augmentation = check_default(params.augmentation, False, return_list=False)
    if params.augmentation:
        params.rotation_range = check_default(params.rotation_range, 30., return_list=False)
        params.width_shift_range = check_default(params.width_shift_range, 0.3, return_list=False)
        params.height_shift_range = check_default(params.height_shift_range, 0.3, return_list=False)
        params.zoom_range = check_default(params.zoom_range, 0.3, return_list=False)
        params.fill_mode = check_default(params.fill_mode, 'constant', return_list=False)
        params.horizontal_flip = check_default(params.horizontal_flip, True, return_list=False)
        params.vertical_flip = check_default(params.vertical_flip, True, return_list=False)
        params.featurewise_center = check_default(params.featurewise_center, False, return_list=False)
        params.featurewise_std_normalization = check_default(params.featurewise_std_normalization, False,
                                                             return_list=False)
    else:
        # Augmentation disabled: force neutral values regardless of what was passed.
        params.rotation_range = 0.
        params.width_shift_range = 0.
        params.height_shift_range = 0.
        params.zoom_range = 0.
        params.fill_mode = 'reflect'
        params.horizontal_flip = False
        params.vertical_flip = False
        params.featurewise_center = False
        params.featurewise_std_normalization = False

    params = params_remove_list(params)

    # Initialize logger
    logger = StandardOutputLogger(params.verbose)
    # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for key, val in sorted(vars(params).items()):
        logger.info("[PARAMETER] {}: {}".format(key, val))

    # Initialize Cytomine client
    cytomine = Cytomine(
        params.cytomine_host,
        params.cytomine_public_key,
        params.cytomine_private_key,
        working_path=params.cytomine_working_path,
        base_path=params.cytomine_base_path,
        verbose=(params.verbose >= Logger.DEBUG)
    )

    # Start job
    with CytomineJob(cytomine,
                     params.cytomine_software,
                     params.cytomine_project,
                     parameters=vars(params_remove_none(params))) as job:
        cytomine.update_job_status(job.job, status_comment="Starting...", progress=0)

        cytomine.update_job_status(job.job, status_comment="Loading training set...", progress=1)
        X, y = get_dataset(cytomine, params.cytomine_working_path, params.cytomine_project, params.cytomine_object_term,
                           params.cytomine_roi_term, params.cytomine_object_user, params.cytomine_object_reviewed_only,
                           params.cytomine_roi_user, params.cytomine_roi_reviewed_only, params.cytomine_force_download)
        logger.d("X size: {} samples".format(len(X)))
        logger.d("y size: {} samples".format(len(y)))

        cytomine.update_job_status(job.job, status_comment="Training forest...", progress=5)
        estimator = CellCountRandomizedTrees(logger=logger, **vars(params))
        estimator.fit(np.asarray(X), np.asarray(y))

        cytomine.update_job_status(job.job, status_comment="Saving (best) model", progress=95)
        model_path = os.path.join(params.cytomine_working_path, "models", str(params.cytomine_software))
        model_file = os.path.join(model_path, "{}.pkl".format(job.job.id))
        make_dirs(model_path)
        estimator.save(model_file)

        cytomine.update_job_status(job.job, status_comment="Finished.", progress=100)
# Example 3
# 0
def train(argv):
    """Train an FCRN (fully convolutional regression network) object counter for a Cytomine job.

    Parses command-line style arguments and normalises boolean and size
    options. It then downloads the training set from the Cytomine project,
    fits an ``FCRN`` estimator with a learning-rate scheduler callback and
    saves it to ``<working_path>/models/<software id>/<job id>.h5``.

    Parameters
    ----------
    argv : list of str
        Command-line arguments. Unknown arguments are ignored
        (``parse_known_args``).
    """
    parser = ArgumentParser(prog="CNN Object Counter Model Builder")

    # Cytomine
    parser.add_argument('--cytomine_host', dest='cytomine_host',
                        default='demo.cytomine.be', help="The Cytomine host")
    parser.add_argument('--cytomine_public_key', dest='cytomine_public_key',
                        help="The Cytomine public key")
    parser.add_argument('--cytomine_private_key', dest='cytomine_private_key',
                        help="The Cytomine private key")
    parser.add_argument('--cytomine_base_path', dest='cytomine_base_path',
                        default='/api/', help="The Cytomine base path")
    parser.add_argument('--cytomine_working_path', dest='cytomine_working_path',
                        default=None, help="The working directory (eg: /tmp)")
    parser.add_argument('--cytomine_id_software', dest='cytomine_software', type=int,
                        help="The Cytomine software identifier")
    parser.add_argument('--cytomine_id_project', dest='cytomine_project', type=int,
                        help="The Cytomine project identifier")
    parser.add_argument('--cytomine_force_download', dest='cytomine_force_download', type=str, default=True,
                        help="Force download from Cytomine or not")

    # Objects
    parser.add_argument('--cytomine_object_term', dest='cytomine_object_term', type=int,
                        help="The Cytomine identifier of object term")
    parser.add_argument('--cytomine_object_user', dest='cytomine_object_user', type=str,
                        help="The Cytomine identifier of object owner")
    parser.add_argument('--cytomine_object_reviewed_only', dest='cytomine_object_reviewed_only', type=str,
                        help="Whether objects have to be reviewed or not")

    # ROI
    parser.add_argument('--cytomine_roi_term', dest='cytomine_roi_term', type=int, default=None,
                        help="The Cytomine identifier of region of interest term")
    parser.add_argument('--cytomine_roi_user', dest='cytomine_roi_user', type=str,
                        help="The Cytomine identifier of ROI owner")
    parser.add_argument('--cytomine_roi_reviewed_only', dest='cytomine_roi_reviewed_only', type=str,
                        help="Whether ROIs have to be reviewed or not")

    # Pre-processing
    parser.add_argument('--pre_transformer', dest='pre_transformer',
                        default='density', choices=['edt', 'euclidean_distance_transform', 'density', None, 'None'],
                        help="Scoremap transformer (None, edt, euclidean_distance_transform, density)")
    parser.add_argument('--pre_alpha', dest='pre_alpha', type=int, default=3,
                        help="Exponential decrease rate of distance (if EDT)")

    # Subwindows for training
    parser.add_argument('--sw_input_size', dest='sw_input_size', type=int, default=128,
                        help="Size of input subwindow")
    parser.add_argument('--sw_colorspace', dest="sw_colorspace", type=str, default='RGB__rgb',
                        help="List of colorspace features")
    parser.add_argument('--sw_extr_npi', dest="sw_extr_npi", type=int, default=100,
                        help="Number of extracted subwindows per image (if 'random' mode)")

    # CNN
    parser.add_argument('--cnn_architecture', '--architecture', dest='cnn_architecture',
                        type=str, choices=['FCRN-A', 'FCRN-B'], default='FCRN-A')
    parser.add_argument('--cnn_initializer', '--initializer', dest='cnn_initializer', type=str, default='orthogonal')
    parser.add_argument('--cnn_batch_normalization', '--batch_normalization', dest='cnn_batch_normalization', type=str,
                        default=True)
    parser.add_argument('--cnn_learning_rate', '--learning_rate', '--lr', dest='cnn_learning_rate', type=float,
                        default=0.01)
    parser.add_argument('--cnn_momentum', '--momentum', dest='cnn_momentum', type=float, default=0.9)
    parser.add_argument('--cnn_nesterov', '--nesterov', dest='cnn_nesterov', type=str, default=True)
    parser.add_argument('--cnn_decay', '--decay', dest='cnn_decay', type=float, default=0.0)
    parser.add_argument('--cnn_epochs', '--epochs', dest='cnn_epochs', type=int, default=24)
    parser.add_argument('--cnn_batch_size', '--batch_size', dest='cnn_batch_size', type=int, default=16)

    # Dataset augmentation
    parser.add_argument('--augmentation', dest='augmentation', type=str, default=True)
    parser.add_argument('--aug_rotation_range', dest='rotation_range', type=float, default=0.)
    parser.add_argument('--aug_width_shift_range', dest='width_shift_range', type=float, default=0.)
    parser.add_argument('--aug_height_shift_range', dest='height_shift_range', type=float, default=0.)
    parser.add_argument('--aug_zoom_range', dest='zoom_range', type=float, default=0.)
    parser.add_argument('--aug_fill_mode', dest='fill_mode', type=str, default="reflect")
    # str2bool instead of type=bool: bool("False") is True. argparse only applies
    # ``type`` to string values, so the False defaults are left untouched.
    parser.add_argument('--aug_horizontal_flip', dest='horizontal_flip', type=str2bool, default=False)
    parser.add_argument('--aug_vertical_flip', dest='vertical_flip', type=str2bool, default=False)
    parser.add_argument('--aug_featurewise_center', dest='featurewise_center', type=str2bool, default=False)
    parser.add_argument('--aug_featurewise_std_normalization', dest='featurewise_std_normalization', type=str2bool,
                        default=False)

    # Execution
    parser.add_argument('--n_jobs', dest='n_jobs', type=int, default=1, help="Number of jobs")
    parser.add_argument('--verbose', '-v', dest='verbose', type=int, default=0, help="Level of verbosity")

    params, other = parser.parse_known_args(argv)
    if params.cytomine_working_path is None:
        params.cytomine_working_path = os.path.join(tempfile.gettempdir(), "cytomine")
    make_dirs(params.cytomine_working_path)

    # These options are parsed as str (or keep their bool default) and are normalised here.
    params.cytomine_force_download = str2bool(params.cytomine_force_download)
    params.cytomine_object_reviewed_only = str2bool(params.cytomine_object_reviewed_only)
    params.cytomine_roi_reviewed_only = str2bool(params.cytomine_roi_reviewed_only)
    params.cnn_batch_normalization = str2bool(params.cnn_batch_normalization)
    params.cnn_nesterov = str2bool(params.cnn_nesterov)
    params.augmentation = str2bool(params.augmentation)

    # Round the square subwindow side up to a multiple of d (8 for FCRN-A, 4 otherwise).
    d = 8. if params.cnn_architecture == 'FCRN-A' else 4.
    side = int(np.ceil(params.sw_input_size / d) * d)
    params.sw_size = (side, side)
    params.sw_input_size = params.sw_size
    params.sw_output_size = params.sw_size
    params.sw_colorspace = params.sw_colorspace.split(' ')
    params.sw_extr_mode = 'random'
    params.cnn_regularizer = None
    params.mean_radius = 2
    params.k_factor = 100

    if params.augmentation:
        # NOTE(review): the argparse defaults above are 0./False/"reflect" rather than
        # None, so if check_default only substitutes for None these fallbacks never
        # apply. Confirm check_default's semantics.
        params.rotation_range = check_default(params.rotation_range, 30., return_list=False)
        params.width_shift_range = check_default(params.width_shift_range, 0.3, return_list=False)
        params.height_shift_range = check_default(params.height_shift_range, 0.3, return_list=False)
        params.zoom_range = check_default(params.zoom_range, 0.3, return_list=False)
        params.fill_mode = check_default(params.fill_mode, 'constant', return_list=False)
        params.horizontal_flip = check_default(params.horizontal_flip, True, return_list=False)
        params.vertical_flip = check_default(params.vertical_flip, True, return_list=False)
        params.featurewise_center = check_default(params.featurewise_center, False, return_list=False)
        params.featurewise_std_normalization = check_default(params.featurewise_std_normalization, False,
                                                             return_list=False)

    # Initialize logger
    logger = StandardOutputLogger(params.verbose)
    # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for key, val in sorted(vars(params).items()):
        logger.info("[PARAMETER] {}: {}".format(key, val))

    # Start job
    with CytomineJob(params.cytomine_host,
                     params.cytomine_public_key,
                     params.cytomine_private_key,
                     params.cytomine_software,
                     params.cytomine_project,
                     parameters=vars(params),
                     working_path=params.cytomine_working_path,
                     base_path=params.cytomine_base_path,
                     verbose=(params.verbose >= Logger.DEBUG)) as job:
        # The job object is used directly as the Cytomine client below.
        cytomine = job
        cytomine.update_job_status(job.job, status_comment="Starting...", progress=0)

        cytomine.update_job_status(job.job, status_comment="Loading training set...", progress=1)
        logger.i("Loading training set...")
        X, y = get_dataset(cytomine, params.cytomine_working_path, params.cytomine_project, params.cytomine_object_term,
                           params.cytomine_roi_term, params.cytomine_object_user, params.cytomine_object_reviewed_only,
                           params.cytomine_roi_user, params.cytomine_roi_reviewed_only, params.cytomine_force_download)
        logger.d("X size: {} samples".format(len(X)))
        logger.d("y size: {} samples".format(len(y)))

        # Rename parameters to the keyword names passed to FCRN via **vars(params).
        params.architecture = params.cnn_architecture
        params.initializer = params.cnn_initializer
        params.regularizer = params.cnn_regularizer
        params.batch_normalization = params.cnn_batch_normalization
        params.learning_rate = params.cnn_learning_rate
        params.momentum = params.cnn_momentum
        params.nesterov = params.cnn_nesterov
        params.decay = params.cnn_decay
        params.epochs = params.cnn_epochs
        params.batch_size = params.cnn_batch_size

        model_path = os.path.join(params.cytomine_working_path, "models", str(params.cytomine_software))
        model_file = os.path.join(model_path, "{}.h5".format(job.job.id))
        make_dirs(model_path)

        # Callbacks
        lr_callback = LearningRateScheduler(lr_scheduler)
        callbacks = [lr_callback]

        logger.i("Training FCRN...")
        cytomine.update_job_status(job.job, status_comment="Training FCRN...", progress=5)
        estimator = FCRN(FCRN.build_fcrn, callbacks, **vars(params))
        estimator.fit(np.asarray(X), np.asarray(y))

        logger.i("Saving model...")
        cytomine.update_job_status(job.job, status_comment="Saving (best) model", progress=95)
        estimator.save(model_file)

        logger.i("Finished.")
        cytomine.update_job_status(job.job, status_comment="Finished.", progress=100)