コード例 #1
0
ファイル: trace_test.py プロジェクト: colinsongf/google-qanet
 def testTrace(self, cls, value, kwargs):
   def tracer(f, *fargs, **fkwargs):
     name = fkwargs.get("name", None)
     if name == "rv2":
       fkwargs["value"] = value
     return f(*fargs, **fkwargs)
   rv1 = cls(value=value, name="rv1", **kwargs)
   with ed.trace(tracer):
     rv2 = cls(name="rv2", **kwargs)
   rv1_value, rv2_value = self.evaluate([rv1.value, rv2.value])
   self.assertEqual(rv1_value, value)
   self.assertEqual(rv2_value, value)
コード例 #2
0
ファイル: trace_test.py プロジェクト: colinsongf/google-qanet
  def testTraceForwarding(self):
    def double(f, *args, **kwargs):
      return 2. * ed.traceable(f)(*args, **kwargs)

    def set_xy(f, *args, **kwargs):
      if kwargs.get("name") == "x":
        kwargs["value"] = 1.
      if kwargs.get("name") == "y":
        kwargs["value"] = 0.42
      return ed.traceable(f)(*args, **kwargs)

    def model():
      x = ed.Normal(loc=0., scale=1., name="x")
      y = ed.Normal(loc=x, scale=1., name="y")
      return x + y

    with ed.trace(set_xy):
      with ed.trace(double):
        z = model()

    value = 2. * 1. + 2. * 0.42
    z_value = self.evaluate(z)
    self.assertAlmostEqual(z_value, value, places=5)
コード例 #3
0
ファイル: trace_test.py プロジェクト: colinsongf/google-qanet
  def testTapeInnerForwarding(self):
    def double(f, *args, **kwargs):
      return 2. * ed.traceable(f)(*args, **kwargs)

    def model():
      x = ed.Normal(loc=0., scale=1., name="x")
      y = ed.Normal(loc=x, scale=1., name="y")
      return x + y

    with ed.trace(double):
      with ed.tape() as model_tape:
        output = model()

    expected_value, actual_value = self.evaluate([
        model_tape["x"] + model_tape["y"], output])
    self.assertEqual(list(six.iterkeys(model_tape)), ["x", "y"])
    self.assertEqual(expected_value, actual_value)
コード例 #4
0
ファイル: trace_test.py プロジェクト: colinsongf/google-qanet
  def testTraceException(self):
    def f():
      raise NotImplementedError()
    def tracer(f, *fargs, **fkwargs):
      return f(*fargs, **fkwargs)

    with ed.get_next_tracer() as top_tracer:
      old_tracer = top_tracer

    with self.assertRaises(NotImplementedError):
      with ed.trace(tracer):
        f()

    with ed.get_next_tracer() as top_tracer:
      new_tracer = top_tracer

    self.assertEqual(old_tracer, new_tracer)
コード例 #5
0
ファイル: trace_test.py プロジェクト: colinsongf/google-qanet
 def transformed_model():
   with ed.trace(trivial_tracer):
     model()