def testTrace(self, cls, value, kwargs):
    """Check that a tracer can force the value of the variable named "rv2".

    One object is built with `value` passed explicitly; a second is built
    inside `ed.trace(tracer)` without a value, and the tracer injects the
    same `value` for it. Both evaluated `.value` attributes must equal
    `value`.

    Args:
        cls: Callable invoked with `name`, an optional `value`, and
            `**kwargs`; its result must expose a `.value` attribute.
            Presumably an `ed` random-variable constructor supplied by the
            test's parameterization -- confirm against the enclosing class.
        value: Value passed explicitly to "rv1" and injected into "rv2".
        kwargs: Mapping of extra keyword arguments forwarded to `cls`.
    """

    def tracer(f, *fargs, **fkwargs):
        # Only the call whose `name` keyword is "rv2" gets its value forced.
        if fkwargs.get("name") == "rv2":
            fkwargs["value"] = value
        return f(*fargs, **fkwargs)

    explicit_rv = cls(value=value, name="rv1", **kwargs)
    with ed.trace(tracer):
        traced_rv = cls(name="rv2", **kwargs)

    explicit_result, traced_result = self.evaluate(
        [explicit_rv.value, traced_rv.value])
    self.assertEqual(explicit_result, value)
    self.assertEqual(traced_result, value)
def testTraceForwarding(self):
    """Check the output of a model executed under two stacked tracers.

    `set_xy` is installed first and `double` second. `set_xy` supplies the
    value 1. for the call named "x" and 0.42 for the call named "y";
    `double` multiplies the result of every call it sees by two. The test
    expects `x + y` to evaluate to `2 * 1 + 2 * 0.42`.

    Both tracers re-wrap `f` with `ed.traceable` before calling it --
    presumably so the call is forwarded to the next tracer on the stack;
    confirm against the `ed` documentation.
    """

    def double(f, *args, **kwargs):
        forwarded = ed.traceable(f)
        return 2. * forwarded(*args, **kwargs)

    def set_xy(f, *args, **kwargs):
        # "name" is never modified here, so one lookup serves both checks.
        name = kwargs.get("name")
        if name == "x":
            kwargs["value"] = 1.
        if name == "y":
            kwargs["value"] = 0.42
        forwarded = ed.traceable(f)
        return forwarded(*args, **kwargs)

    def model():
        x = ed.Normal(loc=0., scale=1., name="x")
        y = ed.Normal(loc=x, scale=1., name="y")
        return x + y

    # A multi-manager `with` enters left to right, exactly like nesting.
    with ed.trace(set_xy), ed.trace(double):
        traced_sum = model()

    expected_sum = 2. * 1. + 2. * 0.42
    evaluated_sum = self.evaluate(traced_sum)
    self.assertAlmostEqual(evaluated_sum, expected_sum, places=5)
def testTapeInnerForwarding(self):
    """Check what a tape records when opened inside a doubling tracer.

    The model runs with `ed.trace(double)` as the outer context and
    `ed.tape()` as the inner one. The test asserts that the tape's keys are
    exactly ["x", "y"], in that order, and that the sum of the two taped
    entries evaluates to the same value as the model's returned `x + y`.
    """

    def double(f, *args, **kwargs):
        forwarded = ed.traceable(f)
        return 2. * forwarded(*args, **kwargs)

    def model():
        x = ed.Normal(loc=0., scale=1., name="x")
        y = ed.Normal(loc=x, scale=1., name="y")
        return x + y

    # A multi-manager `with` enters left to right, exactly like nesting.
    with ed.trace(double), ed.tape() as model_tape:
        output = model()

    taped_sum = model_tape["x"] + model_tape["y"]
    # Evaluate both in a single call, as the values are compared for equality.
    expected_value, actual_value = self.evaluate([taped_sum, output])

    recorded_names = list(six.iterkeys(model_tape))
    self.assertEqual(recorded_names, ["x", "y"])
    self.assertEqual(expected_value, actual_value)
def testTraceException(self):
    """Check that `ed.get_next_tracer()` is unchanged by a failing trace.

    The tracer yielded by `ed.get_next_tracer()` is captured before and
    after a `NotImplementedError` escapes an `ed.trace(tracer)` block; the
    two captured tracers must compare equal.
    """

    def raise_not_implemented():
        raise NotImplementedError()

    def tracer(f, *fargs, **fkwargs):
        return f(*fargs, **fkwargs)

    with ed.get_next_tracer() as tracer_before:
        pass

    # `assertRaises` is outermost so it observes the error after `ed.trace`
    # has exited.
    with self.assertRaises(NotImplementedError), ed.trace(tracer):
        raise_not_implemented()

    with ed.get_next_tracer() as tracer_after:
        pass

    self.assertEqual(tracer_before, tracer_after)
def transformed_model():
    """Run `model()` inside an `ed.trace(trivial_tracer)` context.

    `model` and `trivial_tracer` are taken from the enclosing scope. The
    result of `model()` is discarded, so this function returns None.
    """
    tracing_scope = ed.trace(trivial_tracer)
    with tracing_scope:
        model()