예제 #1
0
파일: views.py 프로젝트: gheinrich/DIGITS
def to_pretrained(job_id):
    """Create a pretrained-model job from a snapshot of an existing job.

    The snapshot epoch is taken from the ``epoch`` query argument or, when
    that is absent, from the ``snapshot_epoch`` form field. When neither is
    supplied the epoch stays at -1.

    Arguments:
    job_id -- id used to look the source job up in the scheduler

    Raises:
    werkzeug.exceptions.NotFound -- the scheduler has no job with this id
    werkzeug.exceptions.BadRequest -- the supplied epoch is not a number

    Returns a redirect (302) to the home view with tab=3.
    """
    job = scheduler.get_job(job_id)

    if job is None:
        raise werkzeug.exceptions.NotFound('Job not found')

    # GET ?epoch=n takes precedence over POST snapshot_epoch=n (from form)
    raw_epoch = None
    if 'epoch' in flask.request.args:
        raw_epoch = flask.request.args['epoch']
    elif 'snapshot_epoch' in flask.request.form:
        raw_epoch = flask.request.form['snapshot_epoch']

    epoch = -1
    if raw_epoch is not None:
        # The value is client-supplied: answer a malformed epoch with a 400
        # rather than letting float() raise and surface as a server error.
        try:
            epoch = float(raw_epoch)
        except ValueError:
            raise werkzeug.exceptions.BadRequest('Invalid epoch')

    # Job info for the requested epoch
    info = job.json_dict(verbose=False, epoch=epoch)

    task = job.train_task()
    snapshot_filename = task.get_snapshot(epoch)

    # Optional entries stay None when the job info does not carry them
    labels_path = None
    resize_mode = None

    if "labels file" in info:
        labels_path = os.path.join(task.dataset.dir(), info["labels file"])
    if "image resize mode" in info:
        resize_mode = info["image resize mode"]

    # Use a separate name for the new job so the source job stays readable
    pretrained_job = PretrainedModelJob(
        snapshot_filename,
        os.path.join(job.dir(), task.model_file),
        labels_path,
        info["framework"],
        info["image dimensions"][2],
        resize_mode,
        info["image dimensions"][0],
        info["image dimensions"][1],
        username=auth.get_username(),
        name=info["name"]
    )

    scheduler.add_job(pretrained_job)

    return flask.redirect(flask.url_for('digits.views.home', tab=3)), 302
예제 #2
0
def to_pretrained(job_id):
    """Turn a snapshot of an existing job into a new pretrained-model job.

    The epoch is read from the ``epoch`` query argument, falling back to the
    ``snapshot_epoch`` form field; it remains -1 if the request has neither.

    Arguments:
    job_id -- id used to look the source job up in the scheduler

    Raises:
    werkzeug.exceptions.NotFound -- the scheduler has no job with this id
    werkzeug.exceptions.BadRequest -- the supplied epoch is not a number

    Returns a redirect (302) to the home view with tab=3.
    """
    job = scheduler.get_job(job_id)

    if job is None:
        raise werkzeug.exceptions.NotFound('Job not found')

    # GET ?epoch=n wins over POST snapshot_epoch=n (from form)
    raw_epoch = None
    if 'epoch' in flask.request.args:
        raw_epoch = flask.request.args['epoch']
    elif 'snapshot_epoch' in flask.request.form:
        raw_epoch = flask.request.form['snapshot_epoch']

    epoch = -1
    if raw_epoch is not None:
        # Request data is untrusted: a non-numeric epoch is a client error
        # (400), not an unhandled ValueError inside the view.
        try:
            epoch = float(raw_epoch)
        except ValueError:
            raise werkzeug.exceptions.BadRequest('Invalid epoch')

    # Job info for the requested epoch
    info = job.json_dict(verbose=False, epoch=epoch)

    task = job.train_task()
    snapshot_filename = task.get_snapshot(epoch)

    # Optional entries default to None when missing from the job info
    labels_path = None
    resize_mode = None

    if "labels file" in info:
        labels_path = os.path.join(task.dataset.dir(), info["labels file"])
    if "image resize mode" in info:
        resize_mode = info["image resize mode"]

    # Keep the source job and the new job under distinct names
    pretrained_job = PretrainedModelJob(
        snapshot_filename,
        os.path.join(job.dir(), task.model_file),
        labels_path,
        info["framework"],
        info["image dimensions"][2],
        resize_mode,
        info["image dimensions"][0],
        info["image dimensions"][1],
        username=auth.get_username(),
        name=info["name"]
    )

    scheduler.add_job(pretrained_job)

    return flask.redirect(flask.url_for('digits.views.home', tab=3)), 302
예제 #3
0
def create_pretrained_model(job_id, username, epoch):
    """Schedule a PretrainedModelJob built from one snapshot of a job.

    Arguments:
    job_id -- id used to look the source job up in the scheduler
    username -- passed to the new job as its ``username``
    epoch -- epoch handed to ``json_dict`` and ``get_snapshot``

    Raises:
    werkzeug.exceptions.NotFound -- the scheduler has no job with this id
    werkzeug.exceptions.BadRequest -- the job directory holds no model file

    Returns the PretrainedModelJob after it was added to the scheduler.
    """
    source_job = scheduler.get_job(job_id)
    if source_job is None:
        raise werkzeug.exceptions.NotFound('Job not found')

    # Job info for the requested epoch
    info = source_job.json_dict(verbose=False, epoch=epoch)

    train_task = source_job.train_task()
    snapshot_filename = train_task.get_snapshot(epoch)

    # Optional entries resolve to None when the job info does not carry them
    labels_path = (
        os.path.join(train_task.dataset.dir(), info["labels file"])
        if "labels file" in info else None
    )
    mean_path = (
        os.path.join(train_task.dataset.dir(), info["mean file"])
        if "mean file" in info else None
    )
    resize_mode = (
        info["image resize mode"] if "image resize mode" in info else None
    )

    model_file = os.path.join(source_job.dir(), str(train_task.model_file))

    # Jobs that do not contain a model file (too old) cannot be converted
    if not os.path.isfile(model_file):
        raise werkzeug.exceptions.BadRequest('Model file not found in job dir. Job may be too old for conversion.')

    framework = info["framework"]
    dimensions = info["image dimensions"]

    # Dimension entries are passed in the order [2], [0], [1]
    pretrained_job = PretrainedModelJob(
        snapshot_filename,
        model_file,
        labels_path,
        mean_path,
        framework,
        dimensions[2],
        resize_mode,
        dimensions[0],
        dimensions[1],
        username=username,
        name=info["name"]
    )

    scheduler.add_job(pretrained_job)
    return pretrained_job