Example #1
0
def export(context, model, output_path):
    """Export a PyText model snapshot as a Caffe2 model.

    Args:
        context: CLI context; ``context.obj.load_config()`` supplies the raw
            config, which is parsed in ``Mode.TRAIN``.
        model: Snapshot path to export. When falsy, the config's
            ``save_snapshot_path`` is used instead.
        output_path: Destination for the Caffe2 model. When falsy, the
            config's ``export_caffe2_path`` is used instead.
    """
    raw_config = context.obj.load_config()
    config = parse_config(Mode.TRAIN, raw_config)
    snapshot_path = model if model else config.save_snapshot_path
    caffe2_path = output_path if output_path else config.export_caffe2_path
    print(f"Exporting {snapshot_path} to {caffe2_path}")
    export_saved_model_to_caffe2(snapshot_path, caffe2_path)
Example #2
0
def predict(context, exported_model):
    """Start a REPL that runs JSON examples from stdin through a Caffe2 model.

    Each non-blank line read from stdin is decoded as one JSON example and
    passed to the predictor. The predictions are then pretty-printed to
    stdout.

    Args:
        context: CLI context; ``context.obj.load_config()`` supplies the raw
            config, which is parsed in ``Mode.TRAIN``.
        exported_model: Path to the exported Caffe2 model. It is passed to
            ``create_predictor`` unchanged. When it is falsy, the log message
            shows ``config.export_caffe2_path``, presumably because
            ``create_predictor`` falls back to that path (TODO confirm).
    """
    config = parse_config(Mode.TRAIN, context.obj.load_config())
    print(f"Loading model from {exported_model or config.export_caffe2_path}")
    predictor = create_predictor(config, exported_model)

    print("Model loaded, reading example JSON from stdin")
    # Iterate stdin lazily rather than via readlines(). readlines() waits for
    # EOF before returning anything, which makes an interactive REPL useless.
    for line in sys.stdin:
        # A blank line is not valid JSON; skip it instead of crashing.
        if not line.strip():
            continue
        example = json.loads(line)
        predictions = predictor(example)
        pprint.pprint(predictions)
Example #3
0
 def test_load_all_configs(self):
     """Try to load every JSON file under ``pytext/``.

     This makes sure we didn't break the config API. Files listed in
     ``EXCLUDE_JSON`` are skipped. Files are visited in sorted order so the
     test output is deterministic across runs and platforms.
     """
     print()
     for filename in sorted(glob.iglob("pytext/**/*.json", recursive=True)):
         if filename in EXCLUDE_JSON:
             continue
         print("--- loading:", filename)
         # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
         with open(filename, encoding="utf-8") as file:
             config_json = json.load(file)
         config = parse_config(config_json)
         self.assertIsNotNone(config)
Example #4
0
def train(context):
    """Train a model, save the best snapshot, then test that snapshot.

    Single-process training is used when ``distributed_world_size`` is 1;
    otherwise distributed training is used. The saved snapshot is then
    evaluated on the data handler's ``test_path``.
    """
    config = parse_config(context.obj.load_config())
    print("\n===Starting training...")
    run_training = (
        train_model if config.distributed_world_size == 1 else train_model_distributed
    )
    run_training(config)
    print("\n=== Starting testing...")
    snapshot_path = config.save_snapshot_path
    use_cuda = config.use_cuda_if_available
    test_path = config.task.data_handler.test_path
    test_model_from_snapshot_path(snapshot_path, use_cuda, test_path)
Example #5
0
 def load_config():
     """Return the config for this invocation, loading and caching it on first use.

     Sources are tried in this order:

     1. ``config_module``: its ``config`` attribute is used as-is, unparsed.
     2. ``config_file``: a JSON file.
     3. ``config_json``: a JSON string.
     4. stdin.

     JSON sources are passed through ``parse_config``. The result is cached on
     ``context.obj.config``, so later calls return the same object.
     ``context``, ``config_module``, ``config_file`` and ``config_json`` come
     from the enclosing scope (presumably CLI options -- not visible here).
     """
     # Cache the config object so it can be accessed multiple times
     if not hasattr(context.obj, "config"):
         if config_module:
             context.obj.config = import_module(config_module).config
         else:
             if config_file:
                 # JSON text is UTF-8 (RFC 8259); don't depend on the locale.
                 with open(config_file, encoding="utf-8") as file:
                     config = json.load(file)
             elif config_json:
                 config = json.loads(config_json)
             else:
                 click.echo("No config file specified, reading from stdin")
                 config = json.load(sys.stdin)
             context.obj.config = parse_config(config)
     return context.obj.config
Example #6
0
def train(context):
    """Train a model and save its best snapshot, then test that snapshot.

    The CLI config is parsed in ``Mode.TRAIN``. Distributed training is used
    when ``distributed_world_size`` is not 1. The resulting snapshot is
    evaluated through a ``TestConfig`` built from the training config.
    """
    config = parse_config(Mode.TRAIN, context.obj.load_config())
    print("\n===Starting training...")
    if config.distributed_world_size != 1:
        train_model_distributed(config)
    else:
        train_model(config)
    print("\n=== Starting testing...")
    test_model(
        TestConfig(
            load_snapshot_path=config.save_snapshot_path,
            test_path=config.task.data_handler.test_path,
            use_cuda_if_available=config.use_cuda_if_available,
        )
    )
Example #7
0
 def load_config():
     """Return the parsed config, loading and caching it on first use.

     A ``config_module`` supplies a ready-made config object. Otherwise JSON
     is read from ``config_file``, else from ``config_json``, else from stdin.
     Any ``include_dirs`` listed in that JSON are registered via
     ``add_include`` before the JSON is passed to ``parse_config``. The
     result is cached on ``context.obj.config``.
     """
     # Already loaded on an earlier call: reuse the cached object.
     if hasattr(context.obj, "config"):
         return context.obj.config

     if config_module:
         context.obj.config = import_module(config_module).config
         return context.obj.config

     if config_file:
         with PathManager.open(config_file) as file:
             raw_config = json.load(file)
     elif config_json:
         raw_config = json.loads(config_json)
     else:
         eprint("No config file specified, reading from stdin")
         raw_config = json.load(sys.stdin)

     # Include the custom component directories before parsing the config.
     for include_dir in raw_config.get("include_dirs", []):
         add_include(include_dir.rstrip("/"))
     context.obj.config = parse_config(raw_config)
     return context.obj.config
Example #8
0
 def DISABLED_test_load_all_configs(self):
     """Try to load every JSON file under the current directory.

     This makes sure we didn't break the config API. Files in
     ``EXCLUDE_JSON`` are skipped, as is anything under a directory in
     ``EXCLUDE_DIRS``; both are compared after ``get_absolute_path``.
     The ``DISABLED_`` prefix presumably keeps the runner from collecting
     this test.
     """
     import os

     print()
     exclude_json_path = {get_absolute_path(p) for p in EXCLUDE_JSON}
     # Compare against "<dir><sep>" so that excluding "foo" does not also
     # exclude a sibling such as "foobar". This assumes get_absolute_path
     # returns normalized, os.sep-separated paths -- TODO confirm.
     exclude_json_dir = {
         os.path.join(get_absolute_path(p), "") for p in EXCLUDE_DIRS
     }
     for filename in sorted(glob.iglob("./**/*.json", recursive=True)):
         filepath = get_absolute_path(filename)
         if filepath in exclude_json_path:
             continue
         if any(filepath.startswith(prefix) for prefix in exclude_json_dir):
             continue
         print("--- loading:", filepath)
         # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
         with open(filepath, encoding="utf-8") as file:
             config_json = json.load(file)
         config = parse_config(config_json)
         self.assertIsNotNone(config)
Example #9
0
def test(context, model_snapshot, test_path, use_cuda):
    """Test a trained model snapshot.

    If model-snapshot is provided, the models and configuration will then be
    loaded from the snapshot rather than any passed config file. In that case
    --use-cuda/--no-cuda must be given explicitly.

    Otherwise, a config file is loaded and its ``save_snapshot_path`` is
    tested. Its ``use_cuda_if_available`` is used only when
    --use-cuda/--no-cuda was not given (``use_cuda is None``).

    Raises:
        ValueError: if ``model_snapshot`` is set and ``use_cuda`` is None.
    """
    if model_snapshot:
        print(f"Loading model snapshot and config from {model_snapshot}")
        if use_cuda is None:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError(
                "if --model-snapshot is set --use-cuda/--no-cuda must be set"
            )
    else:
        print("No model snapshot provided, loading from config")
        config = parse_config(context.obj.load_config())
        model_snapshot = config.save_snapshot_path
        # Honor an explicit --use-cuda/--no-cuda rather than silently
        # replacing it with the config value.
        if use_cuda is None:
            use_cuda = config.use_cuda_if_available
        print(f"Configured model snapshot {model_snapshot}")
    print("\n=== Starting testing...")
    test_model_from_snapshot_path(model_snapshot, use_cuda, test_path)
Example #10
0
def test(context):
    """Test a trained model snapshot.

    The CLI config is parsed in ``Mode.TEST`` and handed directly to
    ``test_model``.
    """
    test_config = parse_config(Mode.TEST, context.obj.load_config())
    print("\n=== Starting testing...")
    test_model(test_config)