Beispiel #1
0
def make_embed(project_name):
    """Queue a background job that runs ``make_embed.py`` for a project.

    Query parameters:
        numimgs: value passed to ``make_embed.py`` as ``-m`` (default -1).
        modelid: model whose directory (``./projects/<name>/models/<id>``) is
            passed to ``make_embed.py`` as the output directory ``-o``
            (default: the latest model id).

    Returns:
        ``(json_error, 400)`` when the project does not exist, model 0 has no
        ``best_model.pth``, model 0 is being retrained, or ``modelid`` is
        outside ``[0, latest model id]``; otherwise whatever
        ``pool_run_script`` returns.
    """
    proj = db.session.query(Project).filter_by(name=project_name).first()
    if proj is None:
        return jsonify(error=f"project {project_name} doesn't exist"), 400

    model0_exists = os.path.exists(
        f"./projects/{project_name}/models/0/best_model.pth")
    current_app.logger.info(f'Model 0 (autoencoder) exists = {model0_exists}')
    if not model0_exists:
        return jsonify(
            error=
            "Embedding is not available unless at least a base model is trained. Please make patches and train AE"
        ), 400

    # Per the message below, a missing train_ae_time at iteration 0 is
    # treated as "model 0 is currently being retrained", so its weights on
    # disk are not usable yet.
    if proj.train_ae_time is None and proj.iteration == 0:
        error_message = f'The base model 0 of project {project_name} was overwritten when Retrain Model 0 started.\n ' \
                        f'Please wait until the Retrain Model 0 finishes. '
        # Logger.warn is a deprecated alias; use warning().
        current_app.logger.warning(error_message)
        return jsonify(error=error_message), 400

    current_app.logger.info('Checking if the embeddings are the most recent.')

    # get config options:
    batchsize = config.getint('make_embed', 'batchsize', fallback=32)
    patchsize = config.getint('make_embed', 'patchsize', fallback=256)
    numimgs = request.args.get('numimgs', default=-1, type=int)

    # Look up the latest model id once; it is both the default selection and
    # the upper bound for validation.
    latest_modelID = get_latest_modelid(project_name)
    modelid = request.args.get('modelid', default=latest_modelID, type=int)

    if modelid < 0 or modelid > latest_modelID:
        return jsonify(
            error=
            f"Your selected Embed Model ID is {modelid}. The last model ID is {latest_modelID}. A valid Model ID ranges from 0 to {latest_modelID}."
        ), 400

    outdir = f"./projects/{project_name}/models/{modelid}"

    # get the command:
    full_command = [
        sys.executable, "make_embed.py", project_name, f"-o{outdir}",
        f"-p{patchsize}", f"-b{batchsize}", f"-m{numimgs}"
    ]
    current_app.logger.info(f'Full command = {str(full_command)}')

    # NOTE(review): nothing above modifies the session; this commit looks like
    # a leftover from a disabled embed_iteration update. Confirm before removing.
    db.session.commit()

    # run the command asynchronously:
    command_name = "make_embed"
    return pool_run_script(project_name,
                           command_name,
                           full_command,
                           callback=make_embed_callback)
Beispiel #2
0
def plotembed(project_name):
    """Render the embedding page for a project.

    The ``modelid`` query parameter picks the model to show; it defaults to
    the latest model id. An unknown project renders ``error.html``. A model id
    outside ``[0, latest]`` still renders ``embed.html``, but with
    ``data="None"`` and an ``error_message``.
    """
    current_app.logger.info('Plotting patch embedding:')
    project = Project.query.filter_by(name=project_name).first()
    if not project:
        current_app.logger.error('No project found.')
        return render_template("error.html")

    newest_id = get_latest_modelid(project_name)
    chosen_id = request.args.get('modelid', default=newest_id, type=int)

    # Template context shared by the success and the error rendering.
    page_context = {
        'project_name': project_name,
        'project_iteration': project.iteration,
        'current_modelId': chosen_id,
    }

    if 0 <= chosen_id <= newest_id:
        return render_template("embed.html", **page_context)

    error_message = f"Your selected View Embed Model ID is {chosen_id}. A valid Model ID ranges from 0 to {newest_id}."
    current_app.logger.error(error_message)
    return render_template("embed.html",
                           data="None",
                           error_message=error_message,
                           **page_context)
Beispiel #3
0
def get_embed_csv(project_name):
    """Serve ``embedding.csv`` for one model of a project, with caching disabled.

    Query parameters:
        modelid: model whose embedding is returned (default: latest model id).

    Returns:
        The file response with no-cache headers. A ``(json_error, 400)`` tuple
        is returned instead when the project does not exist, the model id is
        outside ``[0, latest]``, or the embedding file is missing.
    """
    project = Project.query.filter_by(name=project_name).first()
    if project is None:
        # Reject unknown projects before building a filesystem path from the
        # URL-supplied name.
        return jsonify(error=f"project {project_name} doesn't exist"), 400

    latest_modelid = get_latest_modelid(project_name)
    selected_modelid = request.args.get('modelid',
                                        default=latest_modelid,
                                        type=int)

    if selected_modelid > latest_modelid or selected_modelid < 0:
        error_message = f"Your selected View Embed Model ID is {selected_modelid}. A valid Model ID ranges from 0 to {latest_modelid}."
        current_app.logger.error(error_message)
        return jsonify(error=error_message), 400

    fname = f"./projects/{project_name}/models/{selected_modelid}/embedding.csv"
    if not os.path.exists(fname):
        error_message = f'No embedding data available to render for Model {selected_modelid}.'
        current_app.logger.error(error_message)
        return jsonify(error=error_message), 400

    folder, filename = os.path.split(fname)
    response = send_from_directory(folder, filename)

    # The CSV is regenerated in place, so browsers must not reuse a stale copy.
    response.headers[
        'Cache-Control'] = 'no-store, no-cache, must-revalidate, post-check=0, pre-check=0, max-age=0'
    response.headers['Pragma'] = 'no-cache'
    response.headers['Expires'] = '-1'
    return response
Beispiel #4
0
def get_model(project_name):
    """Send a project's ``best_model.pth`` as a file download.

    The ``model`` query parameter selects the model id. When it is absent or
    not an integer, the latest model id is used.
    """
    fallback_id = get_latest_modelid(project_name)
    chosen_id = request.args.get('model', fallback_id, type=int)
    model_dir = f"./projects/{project_name}/models/{chosen_id}/"
    return send_from_directory(model_dir, "best_model.pth", as_attachment=True)
Beispiel #5
0
def get_prediction(project_name, image_name):
    """Produce the DL prediction image for one image of a project.

    Query parameters:
        model: id of the model used for prediction (default: latest model id).

    Returns:
        ``(json_error, 400)`` when the project or image does not exist, or when
        the model id is <= 0 (no trained DL model). Otherwise returns the result
        of ``pool_get_image``, which receives the command and the expected
        output path ``./projects/<name>/pred/<model>/<image>_pred.png``.
    """
    current_app.logger.info(
        f'Getting prediction for project {project_name} and image {image_name}'
    )

    project = Project.query.filter_by(name=project_name).first()
    if project is None:
        return jsonify(error=f"project {project_name} doesn't exist"), 400

    curr_image = Image.query.filter_by(projId=project.id,
                                       name=image_name).first()
    if curr_image is None:
        # Previously this response was built but never returned, so the
        # handler went on to crash on curr_image.id below.
        return jsonify(error=f"Image {image_name} does not exist"), 400

    modelid = request.args.get('model',
                               get_latest_modelid(project_name),
                               type=int)
    current_app.logger.info(f'Model id = {str(modelid)}')

    if modelid <= 0:
        current_app.logger.warning(
            f"No DL model trained for {project_name} -- {image_name} -- {modelid}"
        )
        return jsonify(
            error="No AI model trained, so no AI results available yet."), 400

    upload_folder = f"./projects/{project_name}/pred/{modelid}"
    fname = image_name.replace(".png", "_pred.png")
    full_fname = f"{upload_folder}/{fname}"
    current_app.logger.info('Full filename for prediction = ' + full_fname)

    # Route through the app logger rather than stdout.
    current_app.logger.info('Generating new prediction image:')
    batchsize = config.getint('get_prediction', 'batchsize', fallback=32)
    patchsize = config.getint('get_prediction', 'patchsize', fallback=256)

    # run the command:
    full_command = [
        sys.executable, "make_output_unet_cmd.py", f"-s{batchsize}",
        f"-p{patchsize}",
        f"-m./projects/{project_name}/models/{modelid}/best_model.pth",
        f"-o{upload_folder}",
        f"./projects/{project_name}/{image_name}", "--force"
    ]

    command_name = "generate_prediction"
    return pool_get_image(project_name,
                          command_name,
                          full_command,
                          full_fname,
                          imageid=curr_image.id)
Beispiel #6
0
def get_superpixels(project_name, image_name):
    """Produce the superpixel image for one image of a project.

    Query parameters:
        force: truthy string ("1", "true", "yes", "on", case-insensitive)
            deletes any existing superpixel file before the job is submitted.
        superpixel_run_id: model id used for superpixel generation (default:
            latest model id). A negative id switches to ``make_superpixel.py``.

    Returns:
        ``(json_error, 400)`` when the requested id exceeds the latest model
        id, or when the project or image does not exist. Otherwise returns the
        result of ``pool_get_image``.
    """
    current_app.logger.info(
        f'Getting superpixel for project {project_name} and image {image_name}'
    )
    latest_modelid = get_latest_modelid(project_name)

    # type=bool would make any non-empty value (including "false" and "0")
    # truthy, so parse the flag explicitly.
    force = request.args.get(
        'force',
        False,
        type=lambda value: value.strip().lower() in ('1', 'true', 'yes', 'on'))

    modelidreq = request.args.get('superpixel_run_id',
                                  latest_modelid,
                                  type=int)
    current_app.logger.info(f'Model id = {str(modelidreq)}')
    if modelidreq > latest_modelid:
        return jsonify(
            error=
            f"Requested ModelID {modelidreq} greater than available models {latest_modelid}"
        ), 400

    project = Project.query.filter_by(name=project_name).first()
    if project is None:
        return jsonify(error=f"project {project_name} doesn't exist"), 400
    curr_image = Image.query.filter_by(projId=project.id,
                                       name=image_name).first()
    if curr_image is None:
        return jsonify(error=f"Image {image_name} does not exist"), 400

    # NOTE(review): assumes superpixel_modelid is never None; otherwise the
    # comparison below raises TypeError. Confirm against the Image model.
    superpixel_modelid = curr_image.superpixel_modelid
    current_app.logger.info(
        f'The current superpixel_modelid of {image_name} = {str(superpixel_modelid)}'
    )

    upload_folder = f"./projects/{project_name}/superpixels"
    spixel_fname = image_name.replace(".png", "_superpixels.png")
    full_fname = f"{upload_folder}/{spixel_fname}"
    current_app.logger.info('Full filename for superpixel = ' + full_fname)

    batchsize = config.getint('superpixel', 'batchsize', fallback=32)
    patchsize = config.getint('superpixel', 'patchsize', fallback=256)
    approxcellsize = config.getint('superpixel', 'approxcellsize', fallback=20)
    compactness = config.getfloat('superpixel', 'compactness', fallback=.01)
    command_to_use = config.get("superpixel",
                                'command_to_use',
                                fallback="make_superpixel.py")

    if modelidreq < 0:
        # We are using simple method, since we have no dl model
        current_app.logger.warning(
            f"No DL model trained for {project_name} -- {image_name} -- {modelidreq}, will use simple method"
        )
        command_to_use = "make_superpixel.py"

    full_command = [
        sys.executable, command_to_use, f"-p{patchsize}", f"-x{batchsize}",
        f"-c{compactness}", f"-a{approxcellsize}",
        f"-m./projects/{project_name}/models/{modelidreq}/best_model.pth",
        f"-s./projects/{project_name}/superpixels/",
        f"-o./projects/{project_name}/superpixels_boundary/",
        f"./projects/{project_name}/{image_name}", "--force"
    ]

    current_app.logger.info(
        f'We are running {command_to_use} to generate superpixels for IMAGE {image_name} in PROJECT {project_name} '
    )
    current_app.logger.info(f'Superpixel command = {full_command}')

    command_name = "generate_superpixel"

    if modelidreq > superpixel_modelid or force:
        # Best-effort removal of the existing image (presumably so it gets
        # regenerated). A missing file is expected; other OS errors are
        # logged rather than silently swallowed.
        try:
            os.remove(full_fname)
        except FileNotFoundError:
            pass
        except OSError as err:
            current_app.logger.warning(
                f'Could not remove {full_fname}: {err}')

    return pool_get_image(project_name,
                          command_name,
                          full_command,
                          full_fname,
                          imageid=curr_image.id,
                          callback=get_superpixels_callback)
Beispiel #7
0
def retrain_dl(project_name):
    """Queue a background job that trains a new transfer model for a project.

    Query parameters:
        frommodelid: id of the model to start from (default 0; -1 means the
            latest model id).

    Returns:
        ``(json_error, 400)`` in any of these cases:
            * the project does not exist;
            * the source model is missing;
            * model 0 is being retrained;
            * the train or test image list is missing or empty.
        Otherwise returns the result of ``pool_run_script``. The new model is
        written to ``./projects/<name>/models/<latest + 1>/``.
    """
    proj = Project.query.filter_by(name=project_name).first()
    if proj is None:
        return jsonify(error=f"project {project_name} doesn't exist"), 400
    current_app.logger.info(
        f'About to train a new transfer model for {project_name}')

    frommodelid = request.args.get('frommodelid', default=0, type=int)
    if frommodelid == -1:
        frommodelid = get_latest_modelid(project_name)

    if frommodelid > proj.iteration or not os.path.exists(
            f"./projects/{project_name}/models/{frommodelid}/best_model.pth"):
        return jsonify(
            error=f"Deep learning model {frommodelid} doesn't exist"), 400

    # Per the message below, a missing train_ae_time is treated as "model 0 is
    # currently being retrained", so it cannot be used as a starting point.
    if proj.train_ae_time is None and frommodelid == 0:
        error_message = f'The base model 0 of project {project_name} was overwritten when Retrain Model 0 started.\n ' \
                        f'Please wait until the Retrain Model 0 finishes. '
        current_app.logger.warning(error_message)
        return jsonify(error=error_message), 400

    # todo: make sure there's actually a model in that subdirectory since errors still create the dir before the model is ready
    new_modelid = get_latest_modelid(project_name) + 1
    output_model_path = f"./projects/{project_name}/models/{new_modelid}/"
    current_app.logger.info(f'New model path = {output_model_path}')

    # store the list of test and training images in text files:
    test_file_path = f"projects/{project_name}/test_imgs.txt"
    train_file_path = f"projects/{project_name}/train_imgs.txt"

    current_app.logger.info('Populating project files:')
    populate_training_files(project_name, train_file_path, test_file_path)

    # check if enough data exists (previously the training check mistakenly
    # inspected the test file):
    empty_training = not os.path.exists(train_file_path) or os.stat(
        train_file_path).st_size == 0
    empty_testing = not os.path.exists(test_file_path) or os.stat(
        test_file_path).st_size == 0
    if empty_training or empty_testing:  # TODO can improve this by simply counting ROIs in the db
        error_message = f'Not enough training/test images for project {project_name}. You need at least 1 of each.'
        current_app.logger.warning(error_message)
        return jsonify(error=error_message), 400

    # get config properties:
    num_epochs = config.getint('train_tl', 'numepochs', fallback=1000)
    num_epochs_earlystop = config.getint('train_tl',
                                         'num_epochs_earlystop',
                                         fallback=-1)
    num_min_epochs = config.getint('train_tl', 'num_min_epochs', fallback=300)
    batch_size = config.getint('train_tl', 'batchsize', fallback=32)
    patch_size = config.getint('train_tl', 'patchsize', fallback=256)
    num_workers = config.getint('train_tl', 'numworkers', fallback=0)
    edge_weight = config.getfloat('train_tl', 'edgeweight', fallback=2)
    pclass_weight = config.getfloat('train_tl', 'pclass_weight', fallback=.5)
    fillbatch = config.getboolean('train_tl', 'fillbatch', fallback=False)

    # pclass_weight == -1 means "derive it from the project's positive/negative
    # pixel counts stored in the database".
    if pclass_weight == -1:
        # SUM over zero rows yields None, so coerce to 0.
        proj_ppixel = db.session.query(db.func.sum(
            Image.ppixel)).filter_by(projId=proj.id).scalar() or 0
        proj_npixel = db.session.query(db.func.sum(
            Image.npixel)).filter_by(projId=proj.id).scalar() or 0
        total = proj_npixel + proj_ppixel
        if total > 0:
            pclass_weight = 1 - proj_ppixel / total
        else:
            # No pixel counts yet: fall back to the config default instead of
            # dividing by zero.
            pclass_weight = .5
            current_app.logger.warning(
                f'No pixel counts for project {project_name}; using pclass_weight={pclass_weight}')

    # get the command to retrain the model:
    full_command = [
        sys.executable, "train_model.py", f"-p{patch_size}",
        f"-e{edge_weight}", f"-n{num_epochs}", f"-s{num_epochs_earlystop}",
        f"-l{num_min_epochs}", f"-b{batch_size}", f"-o{output_model_path}",
        f"-w{pclass_weight}", f"-r{num_workers}",
        f"-m./projects/{project_name}/models/{frommodelid}/best_model.pth",
        f"./projects/{project_name}"
    ]

    if fillbatch:
        full_command.append("--fillbatch")

    current_app.logger.info(f'Training command = {full_command}')

    # run the script asynchronously:
    command_name = "retrain_dl"
    return pool_run_script(project_name,
                           command_name,
                           full_command,
                           callback=retrain_dl_callback)