def test_config_no_interpolation(d):
    """Check that ``${...}`` references survive a non-interpolating parse.

    ``d`` is the parametrized section/key divider, so both the dotted
    (``${a.b}``) and the colon (``${a:b}``) reference syntaxes are exercised;
    both must parse and resolve. The doubled braces (``{{ }}``) in the
    f-strings below stop Python from treating the references as f-string
    replacement fields.
    """
    c_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}"""
    raw = Config().from_str(c_str, interpolate=False)
    assert not raw.is_interpolated
    # Without interpolation the reference strings are kept verbatim.
    assert raw["c"]["d"] == f"${{a{d}b}}"
    assert raw["c"]["e"] == f'"hello${{a{d}b}}"'
    assert raw["c"]["f"] == "${a}"

    def check_resolved(cfg):
        # Every reference must now be replaced by the value it points to.
        assert cfg.is_interpolated
        assert cfg["c"]["d"] == 1
        assert cfg["c"]["e"] == "hello1"
        assert cfg["c"]["f"] == {"b": 1}

    # Interpolating via a string round trip ...
    check_resolved(Config().from_str(raw.to_str(), interpolate=True))
    # ... and directly on the un-interpolated object must agree.
    check_resolved(raw.interpolate())

    # A reference to a value that can't be serialized must be rejected.
    bad = {"x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}}
    with pytest.raises(ConfigValidationError):
        Config(bad).interpolate()
def substitute_project_variables(
    config: Dict[str, Any],
    overrides: Dict[str, Any] = SimpleFrozenDict(),
    key: str = "vars",
    env_key: str = "env",
) -> Dict[str, Any]:
    """Resolve variable references in a project config using Thinc's config system.

    config (Dict[str, Any]): The project config. It is modified in place:
        empty ``key`` / ``env_key`` sections are added when missing, and each
        value of the env mapping is replaced by the parsed value of the
        environment variable it names.
    overrides (Dict[str, Any]): Optional config overrides.
    key (str): Key containing variables in project config.
    env_key (str): Key containing environment variable mapping in project config.
    RETURNS (Dict[str, Any]): The interpolated project config.
    """
    variables = config.setdefault(key, {})
    env_map = config.setdefault(env_key, {})
    # Swap each environment variable name for its parsed value ("" when unset).
    for config_var, env_var in env_map.items():
        env_map[config_var] = _parse_override(os.environ.get(env_var, ""))
    # Thinc's config system only allows sections at the top level (a list of
    # commands there wouldn't be allowed). So the project is nested under a
    # "project" section, and the variable sections are mirrored at the top
    # level so that references to them can resolve.
    wrapped = Config({"project": config, key: variables, env_key: env_map})
    reparsed = Config().from_str(wrapped.to_str(), overrides=overrides)
    return dict(reparsed.interpolate()["project"])
def test_config_roundtrip_disk():
    """Writing a config to disk and reading it back must reproduce the same text."""
    original = Config().from_str(OPTIMIZER_CFG)
    with make_tempdir() as tmp_dir:
        cfg_file = tmp_dir / "config.cfg"
        original.to_disk(cfg_file)
        reloaded = Config().from_disk(cfg_file)
        assert reloaded.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_overrides():
    """Overrides must apply when building from a Config and when calling spacy.load."""
    overrides_nested = {"nlp": {"lang": "de", "pipeline": ["tagger"]}}
    overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}

    # Overrides passed straight to Config before building the pipeline.
    overridden = load_model_from_config(
        Config().from_str(nlp_config_string, overrides=overrides_dot), auto_fill=True
    )
    assert isinstance(overridden, German)
    assert overridden.pipe_names == ["tagger"]

    # Baseline pipeline built without overrides.
    base_nlp = load_model_from_config(Config().from_str(nlp_config_string), auto_fill=True)
    assert isinstance(base_nlp, English)
    assert base_nlp.pipe_names == ["tok2vec", "tagger"]

    # Serialized round trips: nested and dotted overrides must behave the same,
    # and loading without overrides must give back the baseline pipeline.
    round_trips = [
        ({"config": overrides_nested}, German, ["tagger"]),
        ({"config": overrides_dot}, German, ["tagger"]),
        ({}, English, ["tok2vec", "tagger"]),
    ]
    for load_kwargs, expected_cls, expected_pipes in round_trips:
        with make_tempdir() as d:
            base_nlp.to_disk(d)
            loaded = spacy.load(d, **load_kwargs)
            assert isinstance(loaded, expected_cls)
            assert loaded.pipe_names == expected_pipes
def test_config_deep_merge():
    """Config.merge deep-merges sections and treats "@" registry blocks specially."""
    # Each case: (config, defaults, expected merged contents).
    cases = [
        # Plain nested dicts: config values win, extra default keys are kept.
        (
            {"a": "hello", "b": {"c": "d"}},
            {"a": "world", "b": {"c": "e", "f": "g"}},
            {"a": "hello", "b": {"c": "d", "f": "g"}},
        ),
        # Same registered function on both sides: its arguments are merged.
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}},
            {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100},
            {"a": "hello", "b": {"@test": "x", "foo": 1, "bar": 2}, "c": 100},
        ),
        # Different registered function: the config block replaces the default.
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
        ),
        # Test that leaving out the factory just adds to existing.
        (
            {"a": "hello", "b": {"foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "y", "foo": 1, "bar": 2}, "c": 100},
        ),
    ]
    for config, defaults, expected in cases:
        merged = Config(defaults).merge(config)
        assert len(merged) == len(expected)
        for section, value in expected.items():
            assert merged[section] == value
def test_config_no_interpolation_registry():
    """Check how registry resolution interacts with interpolated and
    un-interpolated configs.

    Section [b] is a "catsie.v1" block (registered elsewhere in this module)
    whose ``evil`` argument references ${a:bad}; [c] references the whole
    [b] block.
    """
    config_str = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}"""
    # Case 1: parse without interpolation; references stay as raw strings.
    config = Config().from_str(config_str, interpolate=False)
    assert not config.is_interpolated
    assert config["b"]["evil"] == "${a:bad}"
    assert config["c"]["d"] == "${b}"
    resolved, filled = my_registry.resolve(config)
    # The resolved values come from calling the registered function ...
    assert resolved["b"] == "scratch!"
    assert resolved["c"]["d"] == "scratch!"
    # ... while the filled config keeps the references un-interpolated, with
    # the function's default arguments (here "cute") added.
    assert filled["b"]["evil"] == "${a:bad}"
    assert filled["b"]["cute"] is True
    assert filled["c"]["d"] == "${b}"
    # Interpolating the filled config substitutes the references afterwards.
    interpolated = filled.interpolate()
    assert interpolated.is_interpolated
    assert interpolated["b"]["evil"] is True
    assert interpolated["c"]["d"] == interpolated["b"]
    # Case 2: parse with interpolation; the filled config holds resolved values.
    config = Config().from_str(config_str, interpolate=True)
    assert config.is_interpolated
    resolved, filled = my_registry.resolve(config)
    assert resolved["b"] == "scratch!"
    assert resolved["c"]["d"] == "scratch!"
    assert filled["b"]["evil"] is True
    assert filled["c"]["d"] == filled["b"]
    # Resolving a non-interpolated filled config
    config = Config().from_str(config_str, interpolate=False)
    assert not config.is_interpolated
    filled = my_registry.fill_config(config)
    # Filling must not interpolate as a side effect.
    assert not filled.is_interpolated
    assert filled["c"]["d"] == "${b}"
    resolved = my_registry.make_from_config(filled)
    assert resolved["c"]["d"] == "scratch!"
def test_config_auto_fill_extra_fields():
    """Unknown keys in [training] must not end up in the auto-filled config."""
    minimal = Config({"nlp": {"lang": "en"}, "training": {}})
    assert load_model_from_config(minimal, auto_fill=True)

    with_extra = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}})
    nlp = load_model_from_config(with_extra, auto_fill=True, validate=False)
    assert "extra" not in nlp.config["training"]
    # The generated config must itself be loadable (i.e. valid).
    load_model_from_config(nlp.config)
def test_config_from_str_invalid_section():
    """A subsection under a key whose value is not a section must be rejected."""
    invalid_configs = (
        """[a]\nb = null\n\n[a.b]\nc = 1""",
        """[a]\nb = null\n\n[a.b.c]\nd = 1""",
    )
    for config_str in invalid_configs:
        with pytest.raises(ConfigValidationError):
            Config().from_str(config_str)
def test_deepcopy_config():
    """Config.copy returns an equal but distinct object, and fails on uncopyable values."""
    original = Config({"a": 1, "b": {"c": 2, "d": 3}})
    duplicate = original.copy()
    assert original == duplicate
    assert original is not duplicate
    # A module object can't be pickled/deep-copied, so copying must raise.
    with pytest.raises(ValueError):
        Config({"a": 1, "b": numpy}).copy()
def test_cant_expand_undefined_block(cfg, is_valid):
    """A block that hasn't been defined yet must not be expandable.

    This typically happens when a section name is mistyped. Allowing
    expansion of undefined blocks would make such typos very hard to report
    with good errors.
    """
    if not is_valid:
        with pytest.raises(ConfigValidationError):
            Config().from_str(cfg)
        return
    Config().from_str(cfg)
def test_config_optional_sections():
    """The optional [pretraining] section must be filled as {} rather than None."""
    merged = DEFAULT_CONFIG.merge(Config().from_str(nlp_config_string))
    assert "pretraining" not in merged
    filled = registry.fill(merged, schema=ConfigSchema, validate=False)
    # A None default would become a top-level key that isn't a section, which
    # would (rightly) fail when re-parsed. This string round trip is also what
    # Config.interpolate does under the hood.
    reparsed = Config().from_str(filled.to_str())
    assert reparsed["pretraining"] == {}
def test_config_to_str_invalid_defaults():
    """Top-level keys that aren't inside a section must raise an error.

    Such keys would otherwise be written as [DEFAULT], whose values are then
    included in *all* other sections.
    """
    with_top_level_value = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}
    with pytest.raises(ConfigValidationError):
        Config(with_top_level_value).to_str()
    with pytest.raises(ConfigValidationError):
        Config().from_str("[DEFAULT]\none = 1")
def main(path: Optional[Path] = None, out_dir: Optional[Path] = None):
    """Train the model described by the config on the AnCora POS-tagging data.

    path (Optional[Path]): Config file to load. If None, the CONFIG string
        defined in this file is used instead.
    out_dir (Optional[Path]): If given, the model is saved there after every
        epoch as "<epoch>.bin". The directory is created up front, so that
        saving checkpoints doesn't depend on it already existing.
    """
    if prefer_gpu():
        print("Using gpu!")
        use_pytorch_for_gpu_memory()
    if out_dir:
        # Create the checkpoint directory before training starts, so that a
        # missing directory doesn't surface only after the first epoch.
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
    # You can edit the CONFIG string within the file, or copy it out to
    # a separate file and pass in the path.
    if path is None:
        config = Config().from_str(CONFIG)
    else:
        config = Config().from_disk(path)
    # make_from_config constructs objects whenever you have blocks with an @ key.
    # In the optimizer block we write @optimizers = "Adam.v1". This tells Thinc
    # to use registry.optimizers to fetch the "Adam.v1" function. You can
    # register your own functions as well and build up trees of objects.
    C = thinc.registry.make_from_config(config)
    words_per_subbatch = C["training"]["words_per_subbatch"]
    n_epoch = C["training"]["n_epoch"]
    batch_size = C["training"]["batch_size"]
    model = C["model"]
    optimizer = C["optimizer"]
    calculate_loss = SequenceCategoricalCrossentropy()
    (train_X, train_Y), (dev_X, dev_Y) = ml_datasets.ud_ancora_pos_tags()
    # Convert the outputs to cupy (if we're using that)
    train_Y = list(map(model.ops.asarray, train_Y))
    dev_Y = list(map(model.ops.asarray, dev_Y))
    # Pass in a small batch of data, to fill in missing shapes
    model.initialize(X=train_X[:5], Y=train_Y[:5])
    for epoch in range(n_epoch):
        # Transformers often learn best with large batch sizes -- larger than
        # fits in GPU memory. But you don't have to backprop the whole batch
        # at once. Here we consider the "logical" batch size (number of examples
        # per update) separately from the physical batch size.
        batches = model.ops.multibatch(batch_size, train_X, train_Y, shuffle=True)
        for outer_batch in tqdm.tqdm(batches, leave=False):
            # For the physical batch size, what we care about is the number
            # of words (considering padding too). We also want to sort by
            # length, for efficiency.
            for batch in minibatch_by_words(outer_batch, words_per_subbatch):
                inputs, truths = zip(*batch)
                guesses, backprop = model(inputs, is_train=True)
                backprop(calculate_loss.get_grad(guesses, truths))
            # At the end of the batch, we call the optimizer with the accumulated
            # gradients, and advance the learning rate schedules.
            model.finish_update(optimizer)
            optimizer.step_schedules()
        # You might want to evaluate more often than once per epoch; that's up
        # to you.
        score = evaluate_sequences(model, dev_X, dev_Y, 128)
        print(epoch, f"{score:.3f}")
        if out_dir:
            model.to_disk(out_dir / f"{epoch}.bin")
def test_config_serialize_custom_sort(section_order, expected_str, expected_keys):
    """A custom section_order must control both serialization and key order."""
    data = {
        "j": {"k": 6},
        "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}},
        "h": {"i": 5},
    }
    default_str = Config(data).to_str()
    # Serialization honours the custom order.
    assert Config(data, section_order=section_order).to_str() == expected_str
    # So does parsing a default-ordered string into an ordered Config ...
    parsed_keys = list(Config(section_order=section_order).from_str(default_str).keys())
    assert parsed_keys == expected_keys
    # ... and constructing one directly from the dict.
    direct_keys = list(Config(data, section_order=section_order).keys())
    assert direct_keys == expected_keys
def test_config_is_interpolated():
    """Test that a config object correctly reports whether it's interpolated."""
    config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}"""
    cfg = Config().from_str(config_str, interpolate=False)
    assert not cfg.is_interpolated
    # Merging in plain values keeps the un-interpolated state ...
    cfg = cfg.merge(Config({"x": {"y": "z"}}))
    assert not cfg.is_interpolated
    # ... and so does wrapping the config in a new Config.
    cfg = Config(cfg)
    assert not cfg.is_interpolated
    cfg = cfg.interpolate()
    assert cfg.is_interpolated
    # Merging un-interpolated content back in clears the flag again.
    cfg = cfg.merge(Config().from_str(config_str, interpolate=False))
    assert not cfg.is_interpolated
def test_config_deep_merge_variables():
    """Merging keeps un-interpolated references intact until interpolate() runs."""
    with_reference = Config().from_str(
        """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}""", interpolate=False
    )
    plain_defaults = Config().from_str("""[a]\nx = 100\n\n[d]\ny = 500""")
    merged = plain_defaults.merge(with_reference)
    assert merged["a"] == {"b": 1, "c": 2, "x": 100}
    assert merged["d"] == {"e": "${a:b}", "y": 500}
    assert merged.interpolate()["d"] == {"e": 1, "y": 500}

    # With variable in defaults: overwritten by new value
    concrete = Config().from_str("""[a]\nb= 1\nc = 2""")
    referencing_defaults = Config().from_str(
        """[a]\nb = 100\nc = ${a:b}""", interpolate=False
    )
    assert referencing_defaults.merge(concrete)["a"]["c"] == 2
def test_positional_args_to_from_string():
    """Check round-tripping, filling and resolving of positional ("*") arguments.

    Positional args can be given inline as a list (``* = [...]``) or as
    ``[a.*.name]`` subsections. The registered functions are defined inside
    this test, before the configs that use them are filled or resolved.
    """
    # Plain (unregistered) sections: "*" survives a parse/serialize round trip,
    # both inline and as subsections.
    cfg = """[a]\nb = 1\n* = ["foo","bar"]"""
    assert Config().from_str(cfg).to_str() == cfg
    cfg = """[a]\nb = 1\n\n[a.*.foo]\ntest = 1\n\n[a.*.bar]\ntest = 2"""
    assert Config().from_str(cfg).to_str() == cfg

    # Function taking variadic positional args plus a keyword default.
    @my_registry.cats("catsie.v666")
    def catsie_666(*args, meow=False):
        return args

    # Inline positional list: filling adds the missing keyword default, and
    # resolving passes the list items as positional args.
    cfg = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]"""
    filled = my_registry.fill_config(Config().from_str(cfg)).to_str()
    assert filled == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false"""
    assert my_registry.make_from_config(Config().from_str(cfg)) == {"a": ("foo", "bar")}
    # Positional arg given as a subsection: it's passed as a dict.
    cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1"""
    filled = my_registry.fill_config(Config().from_str(cfg)).to_str()
    assert filled == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1"""
    assert my_registry.make_from_config(Config().from_str(cfg)) == {"a": ({"x": 1},)}

    @my_registry.cats("catsie.v777")
    def catsie_777(y: int = 1):
        return "meow" * y

    # Positional arg that is itself a registered function: its own defaults
    # are filled in as well ...
    cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\""""
    filled = my_registry.fill_config(Config().from_str(cfg)).to_str()
    expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1"""
    assert filled == expected
    # ... and it's resolved before being passed on as a positional arg.
    cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3"""
    result = my_registry.make_from_config(Config().from_str(cfg))
    assert result == {"a": ("meowmeowmeow",)}
def test_config_to_str_simple_promises():
    """A registry reference without arguments is serialized inline as a dict."""
    source = """[section]\nsubsection = {"@registry":"value"}"""
    parsed = Config().from_str(source)
    assert parsed["section"]["subsection"]["@registry"] == "value"
    assert parsed.to_str() == source
def test_issue5551(textcat_config):
    """Test that after fixing the random seed, the results of the pipeline are
    truly identical.

    The same one-example textcat training is run three times from scratch,
    re-seeding before each run, and the predictions are compared.
    """
    component = "textcat"
    pipe_cfg = Config().from_str(textcat_config)
    results = []
    for i in range(3):
        # Re-seed before building anything so every run starts from the same
        # random state.
        fix_random_seed(0)
        nlp = English()
        text = "Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g."
        # NOTE(review): "Labe1" looks like a typo for "Label1". It's harmless
        # here because the key is only used as a label name.
        annots = {"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}}
        pipe = nlp.add_pipe(component, config=pipe_cfg, last=True)
        for label in set(annots["cats"]):
            pipe.add_label(label)
        # Train
        nlp.initialize()
        doc = nlp.make_doc(text)
        nlp.update([Example.from_dict(doc, annots)])
        # Store the result of each iteration
        result = pipe.model.predict([doc])
        results.append(result[0])
    # All results should be the same because of the fixed seed
    assert len(results) == 3
    ops = get_current_ops()
    # Compare on the host (numpy) side, to within 5 decimal places.
    assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[1]), decimal=5)
    assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[2]), decimal=5)
def test_replace_listeners():
    """Check that nlp.replace_listeners swaps a tagger's Tok2VecListener for a
    copy of the tok2vec model, updates the config to match, and raises
    ValueError on invalid arguments."""
    orig_config = Config().from_str(cfg_string)
    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
    examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})]
    nlp.initialize(lambda: examples)
    tok2vec = nlp.get_pipe("tok2vec")
    tagger = nlp.get_pipe("tagger")
    # Before replacement: the tagger's first layer is a listener registered
    # with the shared tok2vec component.
    assert isinstance(tagger.model.layers[0], Tok2VecListener)
    assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0]
    assert (nlp.config["components"]["tok2vec"]["model"]["@architectures"] ==
            "spacy.Tok2Vec.v2")
    assert (nlp.config["components"]["tagger"]["model"]["tok2vec"]
            ["@architectures"] == "spacy.Tok2VecListener.v1")
    nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
    # After replacement: the listener is gone, and the tagger's config now
    # holds the tok2vec model config.
    assert not isinstance(tagger.model.layers[0], Tok2VecListener)
    t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
    assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
    assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
    # Invalid calls: an unknown source component, a component that presumably
    # isn't in this pipeline (cfg_string defines no "parser" -- verify), an
    # invalid listener path, and a mix of valid and invalid paths.
    with pytest.raises(ValueError):
        nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"])
    with pytest.raises(ValueError):
        nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"])
    with pytest.raises(ValueError):
        nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
    with pytest.raises(ValueError):
        nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])
def test_tok2vec_listener():
    """Check that a tagger using a Tok2VecListener is wired up to the shared
    tok2vec component, and that the listener's output matches doc.tensor."""
    orig_config = Config().from_str(cfg_string)
    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
    assert nlp.pipe_names == ["tok2vec", "tagger"]
    tagger = nlp.get_pipe("tagger")
    tok2vec = nlp.get_pipe("tok2vec")
    tagger_tok2vec = tagger.model.get_ref("tok2vec")
    assert isinstance(tok2vec, Tok2Vec)
    assert isinstance(tagger_tok2vec, Tok2VecListener)
    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
        for tag in t[1]["tags"]:
            tagger.add_label(tag)

    # Check that the Tok2Vec component finds its listeners
    # (listeners are only connected once the pipeline is initialized).
    assert tok2vec.listeners == []
    optimizer = nlp.initialize(lambda: train_examples)
    assert tok2vec.listeners == [tagger_tok2vec]

    for i in range(5):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    # The listener's prediction for a doc must equal the tensor stored on the
    # doc by the full pipeline.
    doc = nlp("Running the pipeline as a whole.")
    doc_tensor = tagger_tok2vec.predict([doc])[0]
    assert_equal(doc.tensor, doc_tensor)

    # TODO: should this warn or error?
    # (Running the tagger while the tok2vec it listens to is disabled.)
    nlp.select_pipes(disable="tok2vec")
    assert nlp.pipe_names == ["tagger"]
    nlp("Running the pipeline with the Tok2Vec component disabled.")
def test_serialize_config_language_specific():
    """Test that config serialization works as expected with language-specific
    factories.

    The factory is registered on English only, so it must be available after
    a round trip through disk, and a config switched to German must fail to load.
    """
    name = "test_serialize_config_language_specific"

    # Registered on English only, not on the base Language class.
    @English.factory(name, default_config={"foo": 20})
    def custom_factory(nlp: Language, name: str, foo: int):
        return lambda doc: doc

    nlp = Language()
    assert not nlp.has_factory(name)
    nlp = English()
    assert nlp.has_factory(name)
    nlp.add_pipe(name, config={"foo": 100}, name="bar")
    pipe_config = nlp.config["components"]["bar"]
    assert pipe_config["foo"] == 100
    assert pipe_config["factory"] == name

    # Round trip through disk: the factory, the pipe name and its config must
    # survive.
    with make_tempdir() as d:
        nlp.to_disk(d)
        nlp2 = spacy.load(d)
    assert nlp2.has_factory(name)
    assert nlp2.pipe_names == ["bar"]
    assert nlp2.get_pipe_meta("bar").factory == name
    pipe_config = nlp2.config["components"]["bar"]
    assert pipe_config["foo"] == 100
    assert pipe_config["factory"] == name

    config = Config().from_str(nlp2.config.to_str())
    config["nlp"]["lang"] = "de"
    with pytest.raises(ValueError):
        # German doesn't have a factory, only English does
        load_model_from_config(config)
def test_serialize_config_missing_pipes():
    """A pipeline entry with no matching [components] block must fail to load."""
    config = Config().from_str(nlp_config_string)
    components = config["components"]
    components.pop("tok2vec")
    # "tok2vec" is still listed in the pipeline but no longer defined.
    assert "tok2vec" in config["nlp"]["pipeline"]
    assert "tok2vec" not in components
    with pytest.raises(ValueError):
        load_model_from_config(config, auto_fill=True)
def test_tok2vec_listeners_textcat():
    """Check that two components (textcat_multilabel and tagger) can both
    listen to one shared tok2vec and still learn their tasks."""
    orig_config = Config().from_str(cfg_string_multi_textcat)
    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
    assert nlp.pipe_names == ["tok2vec", "textcat_multilabel", "tagger"]
    tagger = nlp.get_pipe("tagger")
    textcat = nlp.get_pipe("textcat_multilabel")
    tok2vec = nlp.get_pipe("tok2vec")
    tagger_tok2vec = tagger.model.get_ref("tok2vec")
    textcat_tok2vec = textcat.model.get_ref("tok2vec")
    # Both downstream components embed through listeners, not their own tok2vec.
    assert isinstance(tok2vec, Tok2Vec)
    assert isinstance(tagger_tok2vec, Tok2VecListener)
    assert isinstance(textcat_tok2vec, Tok2VecListener)
    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))

    optimizer = nlp.initialize(lambda: train_examples)
    for i in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    # After training, both the category scores and the tags must match the
    # expected values for these two texts (presumably taken from TRAIN_DATA).
    docs = list(nlp.pipe(["Eat blue ham", "I like green eggs"]))
    cats0 = docs[0].cats
    assert cats0["preference"] < 0.1
    assert cats0["imperative"] > 0.9
    cats1 = docs[1].cats
    assert cats1["preference"] > 0.1
    assert cats1["imperative"] < 0.9
    assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
    assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
def test_issue7055():
    """Test that fill-config doesn't turn sourced components into factories."""
    # A small pipeline that is saved to disk and used as the source.
    source_cfg = {
        "nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger"]},
        "components": {
            "tok2vec": {"factory": "tok2vec"},
            "tagger": {"factory": "tagger"},
        },
    }
    source_nlp = English.from_config(source_cfg)

    with make_tempdir() as dir_path:
        # We need to create a loadable source pipeline
        source_path = dir_path / "test_model"
        source_nlp.to_disk(source_path)
        # Base config: two components sourced from disk, one built by factory.
        base_cfg = {
            "nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]},
            "components": {
                "tok2vec": {"source": str(source_path)},
                "tagger": {"source": str(source_path)},
                "ner": {"factory": "ner"},
            },
        }
        base_cfg = Config(base_cfg)
        base_path = dir_path / "base.cfg"
        base_cfg.to_disk(base_path)
        output_path = dir_path / "config.cfg"
        fill_config(output_path, base_path, silent=True)
        filled_cfg = load_config(output_path)
        # Sourced components keep their "source"; only the factory-built
        # component gets filled in with a "model" block.
        assert filled_cfg["components"]["tok2vec"]["source"] == str(source_path)
        assert filled_cfg["components"]["tagger"]["source"] == str(source_path)
        assert filled_cfg["components"]["ner"]["factory"] == "ner"
        assert "model" in filled_cfg["components"]["ner"]
def test_transformer_pipeline_textcat():
    """Test that a pipeline with just a transformer+textcat runs and trains properly.
    This used to throw an error because of shape inference issues -
    cf https://github.com/explosion/spaCy/issues/6401"""
    orig_config = Config().from_str(cfg_string)
    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
    assert nlp.pipe_names == ["transformer", "textcat"]
    train_examples = []

    for text, annotations in TRAIN_DATA:
        train_examples.append(
            Example.from_dict(nlp.make_doc(text), annotations))
    optimizer = nlp.initialize(get_examples=lambda: train_examples)

    # A couple of updates are enough to exercise training end to end.
    for i in range(2):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    doc = nlp("We're interested at underwater basket weaving.")
    cats1 = doc.cats

    # ensure IO goes OK
    # (the reloaded pipeline must predict exactly the same categories)
    with make_tempdir() as d:
        file_path = d / "trained_nlp"
        nlp.to_disk(file_path)
        nlp2 = spacy.load(file_path)
        doc2 = nlp2("We're interested at underwater basket weaving.")
        cats2 = doc2.cats
        assert cats1 == cats2
def main(pytorch: bool = False, gpu_id: int = -1):
    """Benchmark the forward pass of the model built from CONFIG, per backend.

    pytorch (bool): Whether to run the "pytorch" backend; it's skipped otherwise.
    gpu_id (int): GPU to use. A negative value doesn't request a GPU.
    """
    global CONFIG
    fix_random_seed(0)
    if gpu_id >= 0:
        require_gpu(gpu_id)
        print("Set GPU", gpu_id)
    backends = {"pytorch": pytorch}
    for name, use_backend in backends.items():
        if not use_backend:
            print(f"Skipping {name}")
            continue
        set_backend(name, gpu_id)
        C = registry.resolve(Config().from_str(CONFIG))
        model = C["model"]
        X, Y = get_dummy_data(**C["data"])
        print("Copy to device")
        X = [model.ops.asarray(x) for x in X]
        Y = [model.ops.asarray(y) for y in Y]
        print("Begin init", len(X))
        model.initialize(X=X[:5])
        print("Pre-batch")
        n_words = sum(len(x) for x in X)
        # Run the first layer ahead of time, in batches of 16, and drop it from
        # the model, so the timed loop below only covers the remaining layers.
        X = [model.layers[0].predict(batch) for batch in model.ops.minibatch(16, X)]
        model.layers.pop(0)
        print("Start")
        start_time = timer()
        # Previously nothing ran between the two timer() calls, so the reported
        # time was meaningless. Time the forward pass over the precomputed batches.
        for batch in X:
            model.predict(batch)
        end_time = timer()
        print(name, n_words, end_time - start_time)
def test_config_validate_literal(parser_config_string):
    """The parser's model.state_type only accepts specific values; others must fail validation."""
    nlp = English()
    config = Config().from_str(parser_config_string)
    for state_type, should_fail in (("nonsense", True), ("ner", False)):
        config["model"]["state_type"] = state_type
        if should_fail:
            with pytest.raises(ConfigValidationError):
                nlp.add_pipe("parser", config=config)
        else:
            nlp.add_pipe("parser", config=config)
def test_config_to_str_order():
    """Test that Config.to_str orders the sections.

    Within a section, plain values are written first, followed by its nested
    subsections.
    """
    nested = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}}
    expected = (
        "[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5"
    )
    assert Config(nested).to_str() == expected
def test_read_config():
    """Parsing a config from UTF-8 bytes yields the expected nested values."""
    cfg = Config().from_bytes(EXAMPLE_CONFIG.encode("utf8"))
    optimizer = cfg["optimizer"]
    assert optimizer["beta1"] == 0.9
    assert optimizer["learn_rate"]["initial_rate"] == 0.1
    parser = cfg["pipeline"]["parser"]
    assert parser["factory"] == "parser"
    assert parser["model"]["tok2vec"]["width"] == 128