コード例 #1
0
def main(_):
    """Train KittiBox from the hypes file given by the ``--hypes`` flag.

    Intended as the ``tf.app.run`` entry point; its positional argument is
    ignored. Exits with status 1 if the TensorVision submodule cannot be
    imported or if ``--hypes`` was not given.
    """
    import sys  # sys.exit: the builtin exit() is only added by the `site` module

    utils.set_gpus_to_use()

    # The import only checks that the git submodules are checked out;
    # training itself goes through the module-level ``train`` name.
    try:
        import tensorvision.train
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        sys.exit(1)

    hypes_path = tf.app.flags.FLAGS.hypes
    if hypes_path is None:
        # Without this guard open(None) fails with an opaque TypeError.
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes.json")
        sys.exit(1)

    with open(hypes_path, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)
    # NOTE: utils.load_plugins() is disabled in this script.

    if 'TV_DIR_RUNS' in os.environ:
        # Keep this project's runs in a 'KittiBox' sub-directory.
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'KittiBox')
    utils.set_dirs(hypes, hypes_path)

    utils._add_paths_to_sys(hypes)

    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    # NOTE: train.maybe_download_and_extract(hypes) is disabled here, so
    # presumably the data must already be available locally.
    logging.info("Start training")
    train.do_training(hypes)
コード例 #2
0
ファイル: train.py プロジェクト: xiaoai119/my_fcn
def main(*, hypes_path='../config/fcn8_seg.json'):
    """Train the model configured by a hypes file.

    Args:
        hypes_path: Path of the hypes file, parsed with ``commentjson``.
            Keyword-only, so an ``argv`` list passed positionally (e.g. by
            ``tf.app.run``) is never mistaken for a path. Defaults to the
            previously hard-coded ``'../config/fcn8_seg.json'``, which is
            resolved against the current working directory.
    """
    with open(hypes_path, 'r') as f:
        logging.info("f: %s", f)
        hypes = commentjson.load(f)

    # set_dirs receives the same path the hypes were loaded from.
    utils.set_dirs(hypes, hypes_path)
    utils._add_paths_to_sys(hypes)
    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    logging.info("Start training")
    print('start')
    train.do_training(hypes)
    print('end')
コード例 #3
0
def main(_):
    """Run KittiSeg training; meant to be the ``tf.app.run`` entry point.

    The positional argument is ignored. The submodule directories are
    appended to ``sys.path`` before importing from them; a missing submodule
    is not caught here and propagates as ``ImportError``. Exits with status 1
    when ``--hypes`` is missing. An optional ``--mod`` flag holding a
    Python-literal dict is merged into the loaded hypes via ``dict_merge``.
    """
    utils.set_gpus_to_use()

    # These entries are relative to the current working directory, so the
    # script is presumably meant to be launched from the repository root.
    for submodule_dir in ("submodules/tensorflow-fcn",
                          "submodules/TensorVision"):
        sys.path.append(submodule_dir)

    import tensorvision.train
    import tensorflow_fcn.utils

    flags = tf.app.flags.FLAGS

    if flags.hypes is None:
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes/KittiClass.json")
        exit(1)

    with open(flags.hypes, 'r') as hypes_file:
        logging.info("f: %s", hypes_file)
        hypes = commentjson.load(hypes_file)
    utils.load_plugins()

    if flags.mod is not None:
        from ast import literal_eval
        dict_merge(hypes, literal_eval(flags.mod))

    # os.environ values are always strings, so "present" == "not None".
    runs_root = os.environ.get('TV_DIR_RUNS')
    if runs_root is not None:
        os.environ['TV_DIR_RUNS'] = os.path.join(runs_root, 'KittiSeg')
    utils.set_dirs(hypes, flags.hypes)
    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)
    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    logging.info("Start training")
    train.do_training(hypes)
コード例 #4
0
ファイル: train.py プロジェクト: heyealex/KittiSeg
def main(_):
    """Train KittiSeg, applying command-line overrides on top of the hypes file.

    Intended as the ``tf.app.run`` entry point; the positional argument is
    ignored. Exits with status 1 if the submodules cannot be imported or
    ``--hypes`` is missing.

    Besides ``--hypes`` and ``--mod``, this variant copies ``FLAGS.dist`` into
    ``hypes['dist']`` and, when truthy, uses ``FLAGS.layers``, ``FLAGS.lr``
    and ``FLAGS.optimizer`` to override ``hypes['arch']['layers']``,
    ``hypes['solver']['learning_rate']`` and ``hypes['solver']['opt']``.
    """
    import sys  # sys.exit: the builtin exit() is only added by the `site` module

    utils.set_gpus_to_use()

    # Imported only to check that the git submodules are checked out.
    try:
        import tensorvision.train
        import tensorflow_fcn.utils
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        sys.exit(1)

    if tf.app.flags.FLAGS.hypes is None:
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes/KittiClass.json")
        sys.exit(1)

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = commentjson.load(f)

    # Command-line overrides; the hypes file no longer needs to be open.
    hypes['dist'] = FLAGS.dist
    # NOTE(review): truthiness checks mean a flag value of 0 / "" counts as
    # "not given" — confirm this matches the flag defaults.
    if FLAGS.layers:
        hypes['arch']['layers'] = FLAGS.layers
    if FLAGS.lr:
        hypes['solver']['learning_rate'] = FLAGS.lr
    if FLAGS.optimizer:
        hypes['solver']['opt'] = FLAGS.optimizer
    utils.load_plugins()

    if tf.app.flags.FLAGS.mod is not None:
        import ast
        # literal_eval only accepts Python literals, so --mod cannot run code.
        mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)
        dict_merge(hypes, mod_dict)

    if 'TV_DIR_RUNS' in os.environ:
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'KittiSeg')
    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)
    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    logging.info("Start training")
    train.do_training(hypes)
コード例 #5
0
ファイル: train.py プロジェクト: afcarl/MediSeg
def main(_):
    """Train MediSeg from the JSON hypes file given by the ``--hypes`` flag.

    Intended as the ``tf.app.run`` entry point; its positional argument is
    ignored. Exits with status 1 when ``--hypes`` was not given.
    """
    import sys  # sys.exit: the builtin exit() is only added by the `site` module

    utils.set_gpus_to_use()

    hypes_path = tf.app.flags.FLAGS.hypes
    if hypes_path is None:
        # Without this guard open(None) fails with an opaque TypeError.
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes.json")
        sys.exit(1)

    with open(hypes_path, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)
    utils.load_plugins()

    if 'TV_DIR_RUNS' in os.environ:
        # Keep this project's runs in a 'MediSeg' sub-directory.
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'MediSeg')
    utils.set_dirs(hypes, hypes_path)

    utils._add_paths_to_sys(hypes)

    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    # The training folder is initialized before the data download here.
    train.maybe_download_and_extract(hypes)
    logging.info("Start training")
    train.do_training(hypes)
コード例 #6
0
def main(_):
    """Train from the ``--hypes`` file using fixed data and run directories.

    Intended as the ``tf.app.run`` entry point; its positional argument is
    ignored. Exits with status 1 if the submodules cannot be imported or
    ``--hypes`` is missing.
    """
    import sys  # sys.exit: the builtin exit() is only added by the `site` module

    utils.set_gpus_to_use()

    # Imported only to check that the git submodules are checked out.
    try:
        import tensorvision.train
        import tensorflow_fcn.utils
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        sys.exit(1)

    if tf.app.flags.FLAGS.hypes is None:
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes/KittiClass.json")
        sys.exit(1)

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = commentjson.load(f)
    utils.load_plugins()

    if tf.app.flags.FLAGS.mod is not None:
        import ast
        # literal_eval only accepts Python literals, so --mod cannot run code.
        mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)
        dict_merge(hypes, mod_dict)

    # These assignments unconditionally replace any TV_DIR_DATA / TV_DIR_RUNS
    # already set in the environment; the paths are relative to the current
    # working directory (presumably consumed by utils.set_dirs — confirm).
    os.environ["TV_DIR_DATA"] = "../../SemSeg_DATA/DATA"
    os.environ["TV_DIR_RUNS"] = "../../SemSeg_DATA/RUNS"

    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)
    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    logging.info("Start training")
    train.do_training(hypes)
コード例 #7
0
def train_and_get_results(to_drop, hypes, encoder_path, runs_dir='RUNS'):
    """Train once with the given drop setting and return the run's metrics.

    Writes the descriptor ``{'encoder_name': 'fcn8_vgg', 'encoder_path': ...,
    'drop': ...}`` to ``hypes['ga_data']`` with ``commentjson``, resets the
    default TensorFlow graph, runs ``train.do_training(hypes)`` and then reads
    the metrics back from ``output.log`` in the newest run directory.

    Args:
        to_drop: Value stored under ``'drop'`` in the descriptor.
        hypes: Hypes dict; ``hypes['ga_data']`` is the descriptor's output path.
        encoder_path: Value stored under ``'encoder_path'``.
        runs_dir: Directory whose sub-directories are the training runs.
            Defaults to ``'RUNS'`` relative to the current working directory
            (the previously hard-coded location).

    Returns:
        ``(duration, maxf1, ap)`` as floats, parsed from the 2nd, 5th and 3rd
        lines from the end of ``output.log`` respectively.

    Raises:
        ValueError: if ``runs_dir`` contains no sub-directory.
        IndexError: if ``output.log`` has fewer than five lines.
    """
    ga_content = {'encoder_name': 'fcn8_vgg',
                  'encoder_path': encoder_path,
                  'drop': to_drop}

    with open(hypes['ga_data'], 'w') as f:
        commentjson.dump(ga_content, f)

    # Clear the global default graph before building the next model.
    tf.reset_default_graph()
    train.do_training(hypes)

    # Newest run = sub-directory with the latest ctime (last metadata change
    # on POSIX, creation time on Windows). Plain files are skipped because
    # 'output.log' cannot live under a file path. Approach adapted from
    # https://stackoverflow.com/questions/39327032/how-to-get-the-latest-file-in-a-folder-using-python
    runs = [path for path in glob.glob(os.path.join(runs_dir, '*'))
            if os.path.isdir(path)]
    if not runs:
        raise ValueError("No run directories found in %r" % (runs_dir,))
    latest_dir = max(runs, key=os.path.getctime)

    log_path = os.path.join(latest_dir, 'output.log')
    with open(log_path) as f:
        lines = f.read().splitlines()
    if len(lines) < 5:
        raise IndexError("%s has only %d lines; expected at least 5"
                         % (log_path, len(lines)))

    def last_field(line):
        # Presumably '<label>: <value>'; take the text after the last colon.
        return float(line.split(':')[-1].strip())

    duration = last_field(lines[-2])
    ap = last_field(lines[-3])
    maxf1 = last_field(lines[-5])
    return duration, maxf1, ap
コード例 #8
0
ファイル: train.py プロジェクト: new-2017/KittiSeg
def main(_):
    """Train KittiSeg from the hypes file given by the ``--hypes`` flag.

    Intended as the ``tf.app.run`` entry point; its positional argument is
    ignored. Exits with status 1 if the submodules cannot be imported or
    ``--hypes`` is missing. An optional ``--mod`` flag holding a
    Python-literal dict is merged into the loaded hypes via ``dict_merge``.
    """
    import sys  # sys.exit: the builtin exit() is only added by the `site` module

    utils.set_gpus_to_use()

    # Imported only to check that the git submodules are checked out.
    try:
        import tensorvision.train
        import tensorflow_fcn.utils
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        sys.exit(1)

    if tf.app.flags.FLAGS.hypes is None:
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes/KittiClass.json")
        sys.exit(1)

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = commentjson.load(f)
    utils.load_plugins()

    if tf.app.flags.FLAGS.mod is not None:
        import ast
        # literal_eval only accepts Python literals, so --mod cannot run code.
        mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)
        dict_merge(hypes, mod_dict)

    if 'TV_DIR_RUNS' in os.environ:
        # Keep this project's runs in a 'KittiSeg' sub-directory.
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'KittiSeg')
    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)
    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    logging.info("Start training")
    train.do_training(hypes)
コード例 #9
0
ファイル: train.py プロジェクト: rodolfolotte/KittiSeg
def main(_):
    """Train a model described by the ``--hypes`` file.

    Meant to be passed to ``tf.app.run``; its positional argument is ignored.
    Exits with status 1 when ``--hypes`` is missing. An optional ``--mod``
    flag holding a Python-literal dict is merged into the loaded hypes via
    ``dict_merge`` before plugins are loaded.
    """
    logging.info(
        "Initializing GPUs, plugins and creating the essential folders")
    utils.set_gpus_to_use()

    if FLAGS.hypes is None:
        usage_lines = ("No hypes are given.",
                       "Usage: python train.py --hypes hypes.json",
                       "   tf: tv-train --hypes hypes.json")
        for usage_line in usage_lines:
            logging.error(usage_line)
        exit(1)

    with open(FLAGS.hypes) as hypes_file:
        logging.info("f: %s", hypes_file)
        hypes = commentjson.load(hypes_file)

    if FLAGS.mod is not None:
        from ast import literal_eval
        dict_merge(hypes, literal_eval(FLAGS.mod))

    logging.info("Loading plugins")
    utils.load_plugins()

    logging.info("Set dirs")
    utils.set_dirs(hypes, FLAGS.hypes)

    logging.info("Add paths to sys")
    utils._add_paths_to_sys(hypes)

    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)

    # Clear the global default graph before training builds its model.
    tf.reset_default_graph()

    logging.info("Start training")
    train.do_training(hypes)