def test_state_lazy_load(self, tmpdir):
    """`lazy_load` inflates a state for the duration of the context only.

    Restoring with the default ``slim=True`` yields deflated states, and
    restoring with ``slim=False`` yields fully loaded ones. In both cases the
    state is fully loaded inside ``state.lazy_load()`` and deflated after
    leaving it.
    """
    exp = Experiment(directory=tmpdir, state_class=EmptyState)
    s = EmptyState(experiment_object=exp)
    exp.root = s
    s.save()
    # Generate a new child state from state `s`
    r = EmptyState.new_state(s)
    r.save()
    exp.save()

    # Reload the two-state experiment. By default, slim=True.
    exp = Experiment.restore(directory=tmpdir)
    # Guard against the loop below passing vacuously on an empty graph.
    assert len(exp.graph.node_map) == 2  # root + child
    for state in exp.graph.node_map.values():
        assert state.slim_loaded  # deflated
        with state.lazy_load():
            assert not state.slim_loaded  # fully loaded
        assert state.slim_loaded  # deflated

    # Check behavior change as slim is set to False when exp is restored
    exp = Experiment.restore(directory=tmpdir, slim=False)
    assert len(exp.graph.node_map) == 2  # root + child
    for state in exp.graph.node_map.values():
        assert not state.slim_loaded  # fully loaded (restored with slim=False)
        with state.lazy_load():
            assert not state.slim_loaded  # fully loaded
        # Note: lazy_load deflates the state even if it was initially
        # fully loaded!
        assert state.slim_loaded  # deflated
def test_same_recipe(self, tmpdir):
    """Identical deterministic recipes on the same parent collapse to one state.

    Applying the same non-stochastic ``OpRecipe`` to the root three times
    registers three delayed leaves, but running the experiment must persist
    only a single child state. Replaying the recipe by hand in a fresh
    experiment over the same directory must then restore from cache.
    """
    exp = Experiment(directory=tmpdir, state_class=EmptyState)
    exp_args = {"a": 10.0, "B": "notused"}
    root = exp.spawn_new_tree(**exp_args)
    op = OpRecipe(mul, 0.4, stochastic=False)
    # Apply the same op three times; only the leaves registered on the
    # experiment matter here, not the returned handles.
    for _ in range(3):
        op(root)

    # make sure this creates three new dask delayed states
    assert len(exp.leaves) == 3
    for leaf in exp.leaves.values():
        assert isinstance(leaf, Delayed)

    # Since this is the same op on the root, these three ops would result
    # in 3 identical states. Check, therefore, that only one state is
    # created
    exp.run()
    exp = Experiment.restore(directory=tmpdir, state_class=EmptyState)
    assert len(exp.graph.node_map) == 2  # root + new state

    # Test actual restoring from cache by hand by replicating what
    # happens in `run_recipe`
    exp1 = Experiment(directory=tmpdir, state_class=EmptyState)
    root1 = exp1.spawn_new_tree(**exp_args)
    assert root1.get().from_cache  # same root
    new_state1 = root1.get().new_state(op)
    assert new_state1.restore()
def test_new_state(self, tmpdir):
    """`new_state` must link a freshly generated child state to its parent.

    The link is visible directly as the child's ``parent_sha``. After the
    experiment is saved and restored, it also appears as an edge from the
    parent's sha to the child's sha in the experiment graph.
    """
    experiment = Experiment(directory=tmpdir, state_class=EmptyState)
    parent = EmptyState(experiment_object=experiment)
    experiment.root = parent  # the parent is the experiment's root
    parent.save()

    child = EmptyState.new_state(parent)
    assert child.parent_sha == parent.sha()
    assert child.experiment_object == experiment
    child.save()
    experiment.save()

    restored = Experiment.restore(directory=tmpdir)
    # Exactly two nodes, joined by a single parent -> child edge.
    assert len(restored.graph.nodes) == 2
    assert restored.graph.edge_map[parent.sha()] == {child.sha()}
def test_function_safe_op(self, tmpdir):
    """`Function._safe_op` leaves the state deflated on both failure and success.

    A restored (slim) root is passed through one callable that raises
    ``AttributeError`` and one that completes normally. In both cases the
    root must still be slim-loaded afterwards.
    """
    experiment = Experiment(directory=tmpdir, state_class=EmptyState)
    tree_args = {"a": 1, "B": 2}
    experiment.spawn_new_tree(**tree_args)
    experiment.run()

    experiment = Experiment.restore(directory=tmpdir, state_class=EmptyState)
    assert experiment.root.slim_loaded

    # `d` was never set on the state, so this callable raises.
    failing = Function(lambda state: print("d = {}".format(state.d)))
    with pytest.raises(AttributeError):
        failing._safe_op(experiment.root)
    assert experiment.root.slim_loaded

    # `a` was passed to spawn_new_tree, so this callable succeeds.
    succeeding = Function(lambda state: print("a = {}".format(state.a)))
    succeeding._safe_op(experiment.root)
    assert experiment.root.slim_loaded
def test_tag_filtering(self, tmpdir):
    """Tag filters on the graph's node set support globs, negation and set ops.

    The graph built here is root -> x1 -> y1 and root -> x2 -> y2, where
    x1/x2 carry the ``ops`` and ``phase:mul`` tags and y1/y2 carry the
    ``ops`` and ``phase:add`` tags. The root is untagged.
    """
    experiment = Experiment(directory=tmpdir, state_class=EmptyState)
    tree_args = {"a": 1.0, "B": 2.0, "c": 3.0}
    root = experiment.spawn_new_tree(**tree_args)
    add_recipe = OpRecipe(add, 1.2)
    with experiment.tag("ops"):
        with experiment.tag("phase:mul"):
            first = OpRecipe(mul, 0.4)(root)
            second = OpRecipe(mul, 0.5)(root)
        with experiment.tag("phase:add"):
            add_recipe(first)
            add_recipe(second)
    experiment.run()

    restored = Experiment.restore(directory=tmpdir, state_class=EmptyState)

    def pick(pattern):
        # Filter the restored graph's nodes by a tag pattern.
        return restored.graph.nodes.filter(pattern)

    assert len(pick("op*")) == 4
    assert len(pick("phase:mul") | pick("phase:add")) == 4
    assert len(pick("!phase:mul")) == 3
    assert len(pick("ops") & pick("!phase:add")) == 2

    # Cannot compose other objects with a nodeset
    with pytest.raises(TypeError):
        pick("phase:mul") | "hi"
    with pytest.raises(TypeError):
        pick("phase:*") & "!hi"