コード例 #1
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def rqst_gan_types():
     db = get_db()
     ganlist = get_gan_types(db)
     emit("projects-get-gans", ganlist, namespace="/projects")
コード例 #2
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
    def request_available_inferenced_imgs(pid, expid):
        db = get_db()
        img_list = get_output_imgs(db, expid, "INFERENCED")

        emit("inference-get-inferenced-imgs", img_list, namespace='/inference')
コード例 #3
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
    def rqst_del_img(expid, path):
        db = get_db()
        delImgPath(db, expid, path)
        img_list = get_output_imgs(db, expid, "INFERENCED")

        emit("inference-get-inferenced-imgs", img_list, namespace='/inference')
コード例 #4
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def rqst_del_plt_setting(
         pid, expid, plt_values):  # plt_values = [plt_stat_name, plt_id]
     db = get_db()
     del_plt_stat(db, expid, plt_values[0], plt_values[1])
     plt_settings_list = getPlotStats(db, expid)
     emit("plt-get-plt-settings", plt_settings_list, namespace="/plot")
コード例 #5
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def request_available_models(pid, expid):
     db = get_db()
     model_data_list = get_models(db, pid, expid)
     emit("inference-get-available-models",
          model_data_list,
          namespace='/inference')
コード例 #6
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def plotting(msg):
     updateplot(socketio, get_db())
コード例 #7
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def plt_rqst_plt_settings(pid, expid):
     db = get_db()
     plt_settings_list = getPlotStats(db, expid)
     print("plt settings list ", plt_settings_list)
     emit("plt-get-plt-settings", plt_settings_list, namespace="/plot")
コード例 #8
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def load_img_paths(pid, expid):
     print("data load imgs")
     db = get_db()
     img_path_list = getImagePaths(db, expid, "INPUTDATA")
     emit('data-get-img-paths', img_path_list, namespace='/data')
     print("Emitted dataload imgs")
コード例 #9
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def data_request_gan_gen_images(pid, expid):
     db = get_db()
     img_list = get_output_imgs(db, expid, "GENDATA")
     emit('data-get-gen-images', img_list, namespace='/data')
コード例 #10
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def summary_request_editable(pid, expid):
     db = get_db()
     output_dict = get_exp_default_para_info(db, expid)
     print(output_dict)
     emit("summary-get-exp-info", output_dict, namespace='/summary')
     print("emited editable summary")
コード例 #11
0
ファイル: events.py プロジェクト: vlbthambawita/GANExFlask
 def request_default_exp_para(pid):
     db = get_db()
     default_exp_para_list = get_default_exp_para(db, pid)
     emit("get-exp-default-para",
          default_exp_para_list,
          namespace='/experiments')
コード例 #12
0
ファイル: plots.py プロジェクト: vlbthambawita/GANExFlask
def plots(pid, expid):
    db = get_db()
    #col_trainstat = db["trainstats"].find({"expid":expid})
    statlist = getTrainStatsList(db, expid)

    return render_template('run/plots.html', pid=pid, expid=expid, statlist=statlist) #trainstat=col_trainstat
コード例 #13
0
def runexp(pid, expid):

    db = get_db()
    status = getExpState(db, expid)
    print("status:", status)

    if status == None:
        flash("status error")

    if request.method == "POST":
        print("POST request received")
        if request.form["runexp_btn"] == "train":

            try:
                # get the GAN class
                (ganDir, ganFile, ganClass) = getGANInfo(db, expid)
                print("gan Dir=", ganDir)
                print("gan file=", ganFile)
                print("gan class=", ganClass)

                # import gan from gan file
                #* my_module = importlib.import_module("GANEX.fastGAN.{}".format(ganFile))
                #* gan = eval("my_module.{}(db, pid, expid)".format(ganClass))
                gan = create_gan_object(db, pid, expid, ganDir, ganFile,
                                        ganClass)
                # gan.run("BTN_TRAIN")
                gan.run()

                setExpState(db, expid, "RETRAIN")

                #run(get_db(),pid, expid, status)
                print("Training")
                #setExpState(db, expid, "RETRAIN")
                #status = getExpState(db, expid)
            except Exception as e:
                flash(str(e))

        elif request.form["runexp_btn"] == "re-train":
            try:
                # get the GAN class
                (ganDir, ganFile, ganClass) = getGANInfo(db, expid)
                print("gan file=", ganFile)
                print("gan class=", ganClass)

                # import gan from gan file
                #* my_module = importlib.import_module("GANEX.fastGAN.{}".format(ganFile))
                #* gan = eval("my_module.{}(db, pid, expid)".format(ganClass))
                gan = create_gan_object(db, pid, expid, ganDir, ganFile,
                                        ganClass)
                # * gan.run("BTN_RETRAIN")
                gan.rerun()
                setExpState(db, expid, "RETRAIN")

                #run(get_db(),pid, expid, status)
                print("Training")
                #setExpState(db, expid, "RETRAIN")
                #status = getExpState(db, expid)
            except Exception as e:
                flash(str(e))

            print("Retraining")

        elif request.form["runexp_btn"] == "reset":

            print("Reset")
            delTrainStats(db, expid)
            setExpState(db, expid, "TRAIN")
            status = getExpState(db, expid)

            # reset default exp para
            exp_para_list = get_default_exp_para(db, pid)
            for exp_para in exp_para_list:
                temp_dict = {exp_para["para_key"]: exp_para["para_value"]}
                update_exp_info(db, expid, temp_dict)
                print("exp para :", exp_para)

    return render_template('run/runexp.html',
                           pid=pid,
                           expid=expid,
                           status=status)
コード例 #14
0
def create(pid):
    exp_form = CreateExperiment_form()
    db = get_db()

    #
    col_gans = (db["gantypes"].find({},{"_id":0}))
    col_exp = db["experiments"] # experiments table
    col_exp.create_index([("name", pymongo.ASCENDING), ("pid", pymongo.ASCENDING)], unique=True) # name unique index

    all_exps = col_exp.find({"pid":pid})

    gan_types = []

    for g in col_gans:
        print("gggg=", g)
        gan_types.append((g["name"], g["name"]))
        
    print(gan_types)
    exp_form.ganType.choices = gan_types
    error = None
    #all_projects = col.find({})

    # load default hyper params
    all_hyperparams = list(get_default_hyperparams(db, pid))

    
    default_para_list = list(get_default_hyperparams(db, pid))

    #print(all_projects)

    #all_projects = list(all_projects)

    # test_project = {"p1":"test1", "p2":"test2"}

    # if this for loop print outputs -
    # web page will not print outputs
    #for p in all_projects:
     #   print(p)

     
    if exp_form.validate_on_submit():

        try:
            if error is None:
                exp_name = exp_form.expName.data
                exp_gan = exp_form.ganType.data
                exp_pro_path = db.projects.find_one({"_id":ObjectId(pid)})["path"]
                print(exp_pro_path)
                print("exp gan=", exp_gan)

                #paths
                exp_path = os.path.join(exp_pro_path, exp_name)
                exp_models_path = os.path.join(exp_pro_path, exp_name + "/models")
                exp_output_path = os.path.join(exp_pro_path, exp_name + "/output")

                os.mkdir(exp_path)
                os.mkdir(exp_models_path)
                os.mkdir(exp_output_path)


                # initialize exp inforamtion
                exp_dict = {"name":exp_name, "type":exp_gan, "pid": pid, "status": "TRAIN", 
                            "path":exp_path, "models_path":exp_models_path, "output_path": exp_output_path , "iters": 0,
                            "current_epoch": 0, "dataloader_size": 0}

                # modify exp_dict with default parameters
                exp_para_list = get_default_exp_para(db, pid)
                print("exp para list=", exp_para_list)

                for exp_para in exp_para_list:
                    print("exp para000=====", exp_para) 
                    exp_dict.update({exp_para["para_key"]: exp_para["para_value"]})

                # insert exp dict
                x = col_exp.insert_one(exp_dict)

                # add additional data to experiments collection

                #addInfoToExp

                # initialize train settings
                dict_settings = {"num_epochs": 0, "checkpoint_interval": 0, "checkpoint_type":"EPOCH"}
                set_train_settings(db, str(x.inserted_id), dict_settings)

                print(x.inserted_id) #out.inserted_id
                # flash(x.inserted_id) # remove this one, if redirect the page
                #return redirect(url_for('experiments.create'))



        except Exception as e:
            flash(e)

    return render_template('experiments/create.html', form=exp_form, pid=pid, exps=all_exps)