def train_loop(myhypes=None):
    """Load a hypes file and run the KittiSeg training pipeline.

    Args:
        myhypes: Path to the hypes file, parsed with ``commentjson``. When
            ``None`` (the default), the ``--hypes`` command-line flag is used
            instead.

    Exits the process with status 1 in two cases:
    - the git submodules cannot be imported;
    - no hypes path is available from either ``myhypes`` or ``--hypes``.
    """
    import sys

    utils.set_gpus_to_use()

    # These imports only verify that the git submodules are checked out; the
    # training entry points below are reached through the ``train`` name
    # resolved outside this function.
    try:
        import tensorvision.train  # noqa: F401
        import tensorflow_fcn.utils  # noqa: F401
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute:"
                      "'git submodule update --init --recursive'")
        sys.exit(1)

    # Validate the path we actually open. Previously the flag was checked but
    # ``myhypes`` was opened, so the default ``myhypes=None`` always failed
    # inside ``open(None)``.
    hypes_path = myhypes if myhypes is not None else tf.app.flags.FLAGS.hypes
    if hypes_path is None:
        logging.error("No hype file is given.")
        logging.error("Usage: python train.py --hypes hypes/KittiClass.json")
        sys.exit(1)

    with open(hypes_path, 'r') as f:
        logging.info("f: %s", f)
        hypes = commentjson.load(f)
    utils.load_plugins()

    if tf.app.flags.FLAGS.mod is not None:
        import ast
        # literal_eval accepts only Python literals, so --mod cannot run code.
        mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)
        dict_merge(hypes, mod_dict)

    # Nest runs under a 'KittiSeg' subdirectory. Skip this when it is already
    # nested, so repeated calls in one process don't build .../KittiSeg/KittiSeg.
    runs_dir = os.environ.get('TV_DIR_RUNS')
    if (runs_dir is not None
            and os.path.basename(os.path.normpath(runs_dir)) != 'KittiSeg'):
        os.environ['TV_DIR_RUNS'] = os.path.join(runs_dir, 'KittiSeg')

    # Pass the hypes file that was actually loaded. set_dirs presumably
    # derives directories from this path (TODO confirm).
    utils.set_dirs(hypes, hypes_path)
    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)
    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    logging.info("Start training")
    train.do_training(hypes,
                      trainable_scopes=FLAGS.trainable_scopes,
                      exclude_scopes=FLAGS.checkpoint_exclude_scopes,
                      checkpoint_path=FLAGS.checkpoint_path)
def main(_):
    """Run main function.

    Reads the JSON hypes file named by the ``--hypes`` flag, then does the
    following in order:
    1. configures GPUs, plugins and directories;
    2. prepares the training folder and data;
    3. starts training.

    The positional argument is ignored. When ``--hypes`` is missing, a usage
    hint is logged and the process exits with status 1.
    """
    if FLAGS.hypes is None:
        for message in ("No hypes are given.",
                        "Usage: tv-train --hypes hypes.json"):
            logging.error(message)
        exit(1)

    flags = tf.app.flags.FLAGS
    with open(flags.hypes, 'r') as hypes_file:
        logging.info("f: %s", hypes_file)
        hypes = json.load(hypes_file)

    # Environment / directory setup.
    utils.set_gpus_to_use()
    utils.load_plugins()
    utils.set_dirs(hypes, flags.hypes)

    # Prepare outputs and inputs, then train.
    logging.info("Initialize training folder")
    initialize_training_folder(hypes)
    maybe_download_and_extract(hypes)

    logging.info("Start training")
    do_training(hypes)