def main():
    """Train the simple MNIST model described by a JSON config file.

    The config path comes from the command-line arguments (``get_args``)
    and is parsed with ``process_config``. The function then creates the
    TensorBoard-log and checkpoint directories named under
    ``config.callbacks``, builds the data loader, model and trainer, and
    runs training.

    Exits the process with status 1 if the arguments or the config cannot
    be read.
    """
    import sys

    # Capture the config path from the run arguments, then process the JSON
    # configuration file.
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception as err:
        # Catch only ordinary errors: a bare ``except:`` would also trap
        # SystemExit / KeyboardInterrupt. Exit non-zero so the shell sees
        # the failure instead of a successful run.
        print("missing or invalid arguments: {}".format(err))
        sys.exit(1)

    # Create the experiment directories.
    create_dirs([
        config.callbacks.tensorboard_log_dir,
        config.callbacks.checkpoint_dir,
    ])

    print('Create the data generator.')
    data_loader = SimpleMnistDataLoader(config)

    print('Create the model.')
    model = SimpleMnistModel(config)

    print('Create the trainer')
    trainer = SimpleMnistModelTrainer(
        model.model, data_loader.get_train_data(), config)

    print('Start training the model.')
    trainer.train()
def main_train():
    """Train the simple MNIST model.

    Reads the config path from the command line via ``get_train_args``
    (usage: ``python main_train.py -c <config.json>``), loads the training
    and test data, builds the network and trains it, logging each stage to
    stdout.

    If the arguments or the config are invalid, prints the error, the
    parser help (when a parser was created) and a usage example, then
    exits with status 1.
    """
    import sys

    # Single-argument print(...) calls behave identically on Python 2 and 3.
    print('[INFO] 解析配置...')
    parser = None  # stays None if get_train_args() itself raises
    try:
        args, parser = get_train_args()
        config = process_config(args.config)
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] 参考: python main_train.py -c configs/simple_mnist_config.json')
        # Non-zero status so shells/CI can tell that the run failed.
        sys.exit(1)

    print('[INFO] 加载数据...')
    dl = SimpleMnistDL(config=config)

    print('[INFO] 构造网络...')
    model = SimpleMnistModel(config=config)

    print('[INFO] 训练网络...')
    trainer = SimpleMnistTrainer(
        model=model.model,
        data=[dl.get_train_data(), dl.get_test_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成...')
def main():
    """Export the configured Keras model as a frozen TensorFlow graph (.pb).

    Reads the JSON config named on the command line. Builds the data loader
    and model selected by the config's optional ``data_set`` /
    ``model_name`` attributes, then prints the model's input names and
    output op names. If the config has a ``best_checkpoint`` attribute, it
    loads those weights, freezes the session graph with ``freeze_session``
    and writes it, in binary form, next to the checkpoint with a ``.pb``
    suffix.

    Exits the process with status 1 if the arguments or the config cannot
    be read.
    """
    import sys

    # Capture the config path from the run arguments, then process the JSON
    # configuration file.
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception as err:
        # Narrower than a bare ``except:`` (which would also trap
        # SystemExit / KeyboardInterrupt); non-zero exit signals failure.
        print("missing or invalid arguments: {}".format(err))
        sys.exit(1)

    # Create the experiment directories.
    create_dirs([config.tensorboard_log_dir, config.checkpoint_dir, "val_test"])

    print('Create the data generator.')
    # NOTE(review): data_loader is not used by the export below; it is still
    # built in case its construction has side effects — confirm and drop.
    if getattr(config, "data_set", None) == "face_data_77":
        data_loader = FaceLandmark77DataLoader(config)
    else:
        data_loader = SimpleMnistDataLoader(config)

    print('Create the model.')
    if getattr(config, "model_name", None) == "mobile_net":
        model = MobileNetV2Model(config)
    else:
        model = SimpleMnistModel(config)

    # Print the graph endpoints; the output op names are what the frozen
    # graph keeps.
    print(model.model.input_names)
    print([out.op.name for out in model.model.outputs])

    if not hasattr(config, "best_checkpoint"):
        # The .pb is written next to the checkpoint, so there is nothing to
        # export without one.
        return

    model.load(config.best_checkpoint)
    frozen_graph = freeze_session(
        K.get_session(),
        output_names=[out.op.name for out in model.model.outputs])
    ckpt_path = Path(config.best_checkpoint)
    tf.train.write_graph(frozen_graph, str(ckpt_path.parent),
                         ckpt_path.with_suffix(".pb").name, as_text=False)
def main():
    """Evaluate the configured model with ``FaceLandmarkEvaluater``.

    Reads the JSON config named on the command line and creates the
    TensorBoard-log, checkpoint and ``val_test`` directories. Builds the
    data loader and model selected by the config's optional ``data_set`` /
    ``model_name`` attributes, then runs the evaluator over them.

    Exits the process with status 1 if the arguments or the config cannot
    be read.
    """
    import sys

    # Capture the config path from the run arguments, then process the JSON
    # configuration file.
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception as err:
        # Narrower than a bare ``except:`` (which would also trap
        # SystemExit / KeyboardInterrupt); non-zero exit signals failure.
        print("missing or invalid arguments: {}".format(err))
        sys.exit(1)

    # Create the experiment directories.
    create_dirs(
        [config.tensorboard_log_dir, config.checkpoint_dir, "val_test"])

    print('Create the data generator.')
    # A missing attribute falls back to the MNIST defaults.
    if getattr(config, "data_set", None) == "face_data_77":
        data_loader = FaceLandmark77DataLoader(config)
    else:
        data_loader = SimpleMnistDataLoader(config)

    print('Create the model.')
    if getattr(config, "model_name", None) == "mobile_net":
        model = MobileNetV2Model(config)
    else:
        model = SimpleMnistModel(config)

    print("Create the Evaluater.")
    # The evaluator receives the model wrapper, not model.model.
    evaluator = FaceLandmarkEvaluater(model, data_loader, config)
    print("Start evaluate.")
    evaluator.evaluate()
def main():
    """Export the configured Keras model as a frozen TensorFlow graph (.pb).

    Reads the JSON config named on the command line. Builds the data loader
    and model selected by the config's optional ``data_set`` /
    ``model_name`` attributes, then prints the model's input names and
    output op names. If the config has a ``best_checkpoint`` attribute, it
    loads those weights, freezes the session graph with ``freeze_session``
    and writes it, in binary form, next to the checkpoint with a ``.pb``
    suffix.

    Exits the process with status 1 if the arguments or the config cannot
    be read.
    """
    import sys

    # Capture the config path from the run arguments, then process the JSON
    # configuration file.
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception as err:
        # Narrower than a bare ``except:`` (which would also trap
        # SystemExit / KeyboardInterrupt); non-zero exit signals failure.
        print("missing or invalid arguments: {}".format(err))
        sys.exit(1)

    # Create the experiment directories.
    create_dirs(
        [config.tensorboard_log_dir, config.checkpoint_dir, "val_test"])

    print('Create the data generator.')
    # NOTE(review): data_loader is not used by the export below; it is still
    # built in case its construction has side effects — confirm and drop.
    if getattr(config, "data_set", None) == "face_data_77":
        data_loader = FaceLandmark77DataLoader(config)
    else:
        data_loader = SimpleMnistDataLoader(config)

    print('Create the model.')
    if getattr(config, "model_name", None) == "mobile_net":
        model = MobileNetV2Model(config)
    else:
        model = SimpleMnistModel(config)

    # Print the graph endpoints; the output op names are what the frozen
    # graph keeps.
    print(model.model.input_names)
    print([out.op.name for out in model.model.outputs])

    if not hasattr(config, "best_checkpoint"):
        # The .pb is written next to the checkpoint, so there is nothing to
        # export without one.
        return

    model.load(config.best_checkpoint)
    frozen_graph = freeze_session(
        K.get_session(),
        output_names=[out.op.name for out in model.model.outputs])
    ckpt_path = Path(config.best_checkpoint)
    tf.train.write_graph(frozen_graph, str(ckpt_path.parent),
                         ckpt_path.with_suffix(".pb").name, as_text=False)
def main_train(config_path='configs/simple_mnist_config.json', seed=47):
    """Train the simple MNIST model from a JSON config file.

    Args:
        config_path: Path to the JSON config handed to ``process_config``.
            Defaults to ``configs/simple_mnist_config.json``.
        seed: Value passed to ``np.random.seed`` before the data and the
            model are created. ``None`` skips seeding.
    """
    print('[INFO] 解析配置...')
    config = process_config(config_path)

    if seed is not None:
        # NOTE(review): this seeds only NumPy's global RNG; randomness from
        # TensorFlow/Keras or Python's ``random`` module is not fixed here.
        np.random.seed(seed)

    print('[INFO] 加载数据...')
    dl = SimpleMnistDL(config=config)

    print('[INFO] 构造网络...')
    model = SimpleMnistModel(config=config)

    print('[INFO] 训练网络...')
    trainer = SimpleMnistTrainer(
        model=model.model,
        data=[dl.get_train_data(), dl.get_test_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成...')