def export(context, model, output_path):
    """Convert a PyText model snapshot into a Caffe2 model.

    Args:
        context: CLI context whose ``obj.load_config()`` returns the raw config.
        model: snapshot path to export; when falsy, ``config.save_snapshot_path``
            is used instead.
        output_path: destination for the Caffe2 model; when falsy,
            ``config.export_caffe2_path`` is used instead.
    """
    raw_config = context.obj.load_config()
    config = parse_config(Mode.TRAIN, raw_config)

    snapshot = model if model else config.save_snapshot_path
    destination = output_path if output_path else config.export_caffe2_path

    print(f"Exporting {snapshot} to {destination}")
    export_saved_model_to_caffe2(snapshot, destination)
def predict(context, exported_model):
    """Start a REPL that runs JSON examples from stdin through a Caffe2 model.

    Each non-blank stdin line is decoded with ``json.loads``. The result is
    passed to the predictor returned by ``create_predictor``, and the
    predictions are pretty-printed to stdout.

    Args:
        context: CLI context whose ``obj.load_config()`` returns the raw config.
        exported_model: path of the exported model. The log line falls back to
            ``config.export_caffe2_path`` when this is falsy; presumably
            ``create_predictor`` applies the same default (TODO confirm).
    """
    config = parse_config(Mode.TRAIN, context.obj.load_config())
    print(f"Loading model from {exported_model or config.export_caffe2_path}")
    predictor = create_predictor(config, exported_model)
    print("Model loaded, reading example JSON from stdin")
    # Iterate stdin lazily rather than via readlines(). readlines() blocks
    # until EOF, so an interactive user would see no predictions at all until
    # they closed the input stream.
    for line in sys.stdin:
        # A stray empty line should not abort the session inside json.loads.
        if not line.strip():
            continue
        example = json.loads(line)  # named to avoid shadowing builtin input()
        predictions = predictor(example)
        pprint.pprint(predictions)
def test_load_all_configs(self):
    """Try to load every JSON file under pytext/ to make sure the config API still accepts it.

    Paths listed in ``EXCLUDE_JSON`` are skipped. Each remaining file is
    decoded and passed to ``parse_config``, and the result must not be None.
    """
    print()
    candidates = (
        path
        for path in glob.iglob("pytext/**/*.json", recursive=True)
        if path not in EXCLUDE_JSON
    )
    for path in candidates:
        print("--- loading:", path)
        with open(path) as handle:
            raw = json.load(handle)
        self.assertIsNotNone(parse_config(raw))
def train(context):
    """Train a model, save the best snapshot, then evaluate that snapshot.

    Training is single-process when ``config.distributed_world_size`` is 1 and
    distributed otherwise. The saved snapshot is then tested against
    ``config.task.data_handler.test_path``.
    """
    config = parse_config(context.obj.load_config())

    print("\n===Starting training...")
    run_training = (
        train_model
        if config.distributed_world_size == 1
        else train_model_distributed
    )
    run_training(config)

    print("\n=== Starting testing...")
    test_model_from_snapshot_path(
        config.save_snapshot_path,
        config.use_cuda_if_available,
        config.task.data_handler.test_path,
    )
def load_config():
    """Return this invocation's config, loading it once and caching it.

    The result is cached on ``context.obj.config`` (``context``,
    ``config_module``, ``config_file`` and ``config_json`` come from the
    enclosing scope). Sources are tried in order:

    1. ``config_module``: the module's ``config`` attribute is stored as-is,
       without going through ``parse_config``.
    2. ``config_file``: the file is read as JSON.
    3. ``config_json``: the string is decoded as JSON.
    4. stdin, read as JSON.

    JSON sources are passed through ``parse_config`` before caching.
    """
    if hasattr(context.obj, "config"):
        return context.obj.config

    if config_module:
        context.obj.config = import_module(config_module).config
        return context.obj.config

    if config_file:
        with open(config_file) as file:
            config = json.load(file)
    elif config_json:
        config = json.loads(config_json)
    else:
        # Write this diagnostic to stderr so it does not get mixed into
        # whatever the command prints on stdout.
        click.echo("No config file specified, reading from stdin", err=True)
        config = json.load(sys.stdin)

    context.obj.config = parse_config(config)
    return context.obj.config
def train(context):
    """Train a model, save the best snapshot, and evaluate it on the test data.

    Runs distributed training unless ``config.distributed_world_size`` is 1.
    Afterwards it builds a ``TestConfig`` that points at the saved snapshot and
    the data handler's ``test_path``, and runs ``test_model`` on it.
    """
    config = parse_config(Mode.TRAIN, context.obj.load_config())

    print("\n===Starting training...")
    if config.distributed_world_size != 1:
        train_model_distributed(config)
    else:
        train_model(config)

    print("\n=== Starting testing...")
    test_model(
        TestConfig(
            load_snapshot_path=config.save_snapshot_path,
            test_path=config.task.data_handler.test_path,
            use_cuda_if_available=config.use_cuda_if_available,
        )
    )
def load_config():
    """Return this invocation's config, reading it once and caching it on ``context.obj``.

    The sources (taken from the enclosing scope) are tried in order:

    1. ``config_module``: its ``config`` attribute is cached as-is.
    2. ``config_file``: opened through ``PathManager`` and read as JSON.
    3. ``config_json``: decoded as JSON.
    4. stdin, read as JSON.

    For JSON sources, every entry in ``include_dirs`` is registered with
    ``add_include`` (trailing slashes stripped) before ``parse_config`` runs.
    """
    if hasattr(context.obj, "config"):
        return context.obj.config

    if config_module:
        context.obj.config = import_module(config_module).config
        return context.obj.config

    if config_file:
        with PathManager.open(config_file) as file:
            config = json.load(file)
    elif config_json:
        config = json.loads(config_json)
    else:
        eprint("No config file specified, reading from stdin")
        config = json.load(sys.stdin)

    # Custom components must be included before the config is parsed.
    for include_dir in config.get("include_dirs", []):
        add_include(include_dir.rstrip("/"))

    context.obj.config = parse_config(config)
    return context.obj.config
def DISABLED_test_load_all_configs(self):
    """Try to load all the JSON files in the repo to make sure we didn't break the config API.

    Every ``*.json`` file below the current directory is parsed with
    ``parse_config``, with two exceptions:

    * files whose absolute path matches an entry in ``EXCLUDE_JSON`` are skipped;
    * files located inside a directory listed in ``EXCLUDE_DIRS`` are skipped.
    """
    import os

    print()
    exclude_json_path = {get_absolute_path(p) for p in EXCLUDE_JSON}
    # End each excluded directory with a path separator. Otherwise excluding
    # "a/b" would also skip unrelated siblings such as "a/bc/x.json", because
    # the check below is a plain string-prefix test.
    exclude_json_dir = tuple(
        os.path.join(get_absolute_path(p), "") for p in EXCLUDE_DIRS
    )
    for filename in glob.iglob("./**/*.json", recursive=True):
        filepath = get_absolute_path(filename)
        if filepath in exclude_json_path:
            continue
        # str.startswith accepts a tuple; an empty tuple matches nothing.
        if filepath.startswith(exclude_json_dir):
            continue
        print("--- loading:", filepath)
        with open(filepath) as file:
            config_json = json.load(file)
        config = parse_config(config_json)
        self.assertIsNotNone(config)
def test(context, model_snapshot, test_path, use_cuda):
    """Test a trained model snapshot.

    If ``model_snapshot`` is given, the model and its configuration are loaded
    from that snapshot and no config file is read. In that case
    ``--use-cuda``/``--no-cuda`` must be passed explicitly. Otherwise the
    snapshot path and the CUDA setting both come from the loaded config file.

    Raises:
        Exception: ``model_snapshot`` is set but ``use_cuda`` is None.
    """
    if not model_snapshot:
        print("No model snapshot provided, loading from config")
        config = parse_config(context.obj.load_config())
        model_snapshot = config.save_snapshot_path
        use_cuda = config.use_cuda_if_available
        print(f"Configured model snapshot {model_snapshot}")
    else:
        print(f"Loading model snapshot and config from {model_snapshot}")
        if use_cuda is None:
            raise Exception(
                "if --model-snapshot is set --use-cuda/--no-cuda must be set"
            )

    print("\n=== Starting testing...")
    test_model_from_snapshot_path(model_snapshot, use_cuda, test_path)
def test(context):
    """Evaluate a trained model snapshot using the config parsed in TEST mode."""
    test_config = parse_config(Mode.TEST, context.obj.load_config())

    print("\n=== Starting testing...")
    test_model(test_config)