예제 #1
0
def main():
    """Configure optional per-user directories, build hyper-parameters, train.

    Optional command-line arguments:
        sys.argv[1]: a user id. It is stored in ``util.USER_ID`` and appended
            as a sub-directory to each of util's CACHE/RES/LOG/MODEL/SUMMARIES
            directory settings (each keeps a trailing '/').
        sys.argv[2]: stored in ``util.INFER``; only read when a user id is
            also supplied.
    """
    if len(sys.argv) > 1:
        util.USER_ID = sys.argv[1]
        # Namespace every output directory by user id so that runs for
        # different users get separate directories.
        for dir_attr in ("CACHE_DIR", "RES_DIR", "LOG_DIR", "MODEL_DIR",
                         "SUMMARIES_DIR"):
            user_dir = os.path.join(getattr(util, dir_attr), util.USER_ID)
            setattr(util, dir_attr, user_dir + '/')
        if len(sys.argv) > 2:
            util.INFER = sys.argv[2]
    util.check_tensorflow_version()
    # Runs after the per-user paths are set; presumably it creates the
    # directories configured above -- confirm against util.check_and_mkdir.
    util.check_and_mkdir()
    config = load_yaml()
    check_config(config)
    # Build the hyper-parameters exactly once.
    hparams = create_hparams(config)
    log = Log(hparams)
    hparams.logger = log.logger
    print(hparams.values())
    # Imported lazily, after util has been configured above.
    import train
    train.train(hparams)
예제 #2
0
파일: main.py 프로젝트: zwcdp/DeepRec-1
def main():
    """Run the training pipeline.

    Steps: check the TensorFlow version, run util's directory check, load
    and validate the YAML config, build hyper-parameters, attach a logger
    to them, print their values, then train.
    """
    util.check_tensorflow_version()
    util.check_and_mkdir()

    cfg = load_yaml()
    check_config(cfg)

    params = create_hparams(cfg)
    log_holder = Log(params)
    params.logger = log_holder.logger

    print(params.values())
    train.train(params)
예제 #3
0
파일: main.py 프로젝트: ZCROM/HST
def main():
    """Run the training pipeline without echoing the hyper-parameters.

    Steps: check the TensorFlow version, run util's directory check, load
    and validate the YAML config, build hyper-parameters, attach a logger
    to them, then train.
    """
    util.check_tensorflow_version()
    util.check_and_mkdir()

    cfg = load_yaml()
    check_config(cfg)

    params = create_hparams(cfg)
    log_holder = Log(params)
    params.logger = log_holder.logger

    train.train(params)
예제 #4
0
def main():
    """Run the training pipeline.

    Steps: check the TensorFlow version, run util's directory check, load
    and validate the YAML config, build and print the hyper-parameters,
    attach a logger to them, then train.
    """
    util.check_tensorflow_version()
    util.check_and_mkdir()

    cfg = load_yaml()
    check_config(cfg)

    params = create_hparams(cfg)
    # Printed before the logger is attached.
    print(params.values())

    log_holder = Log(params)
    params.logger = log_holder.logger

    train.train(params)
예제 #5
0
def main():
    """Run the training pipeline with CUDA devices hidden.

    Sets CUDA_VISIBLE_DEVICES to "-1" first, then: check the TensorFlow
    version, run util's directory check, load and validate the YAML config,
    build and print the hyper-parameters, attach a logger, and train.
    """
    import os

    # "-1" matches no GPU index, so CUDA exposes no GPUs to this process.
    # NOTE(review): only effective if CUDA has not already been initialized
    # in this process before main() runs.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    util.check_tensorflow_version()
    util.check_and_mkdir()

    cfg = load_yaml()
    check_config(cfg)

    # create_hparams yields an object whose .values() is printed below.
    params = create_hparams(cfg)
    print(params.values())

    log_holder = Log(params)
    params.logger = log_holder.logger

    train.train(params)