Beispiel #1
0
 def test_setattr(self):
     """Assigning `w` on the wrapper is visible on the wrapper and its target."""
     wrapper = deferred.Deferred(ExampleModule)
     # Call the wrapper once before touching `w` (presumably this is what
     # makes `w` available -- confirm against ExampleModule).
     wrapper()
     replacement = jnp.ones_like(wrapper.w)
     wrapper.w = replacement
     # Identity, not equality: the exact same object must be returned.
     self.assertIs(wrapper.w, replacement)
     self.assertIs(wrapper.target.w, replacement)  # pytype: disable=attribute-error
Beispiel #2
0
 def test_only_computes_target_once(self):
     """Reading `target` repeatedly must invoke the factory only one time."""
     expected = ExampleModule()
     pending = [expected]
     # `pending.pop` acts as the factory: a second invocation would raise
     # because the list is already empty.
     lazy = deferred.Deferred(pending.pop)  # pytype: disable=wrong-arg-types
     for _ in range(10):
         self.assertIs(lazy.target, expected)
         self.assertEmpty(pending)
Beispiel #3
0
 def test_setattr_on_target(self):
     """Assigning `w` on the target directly is visible through the wrapper."""
     wrapper = deferred.Deferred(ExampleModule)
     wrapper()
     replacement = jnp.ones_like(wrapper.w)
     # Clear the attribute via the wrapper first so the later identity checks
     # can only pass because of the assignment made on the target.
     wrapper.w = None
     wrapper.target.w = replacement
     self.assertIs(wrapper.w, replacement)
     self.assertIs(wrapper.target.w, replacement)
Beispiel #4
0
 def test_str(self):
     """`str()` of the wrapper is `Deferred(<str of the wrapped module>)`."""
     inner = ExampleModule()
     wrapper = deferred.Deferred(lambda: inner)
     expected = "Deferred(%s)" % inner
     self.assertEqual(expected, str(wrapper))
Beispiel #5
0
 def test_alternative_forward_call_type_error(self):
     """Calling the wrapper raises TypeError when only `forward` is a call method."""
     expected_error = "'AlternativeForwardModule' object is not callable"
     wrapper = deferred.Deferred(
         AlternativeForwardModule, call_methods=("forward",))
     with self.assertRaisesRegex(TypeError, expected_error):
         wrapper()
Beispiel #6
0
 def test_alternative_forward(self):
     """`forward()` on the wrapper returns 42 when listed in `call_methods`."""
     wrapper = deferred.Deferred(
         AlternativeForwardModule, call_methods=("forward",))
     result = wrapper.forward()
     self.assertEqual(result, 42)
Beispiel #7
0
 def test_delattr(self):
     """Deleting `w` through the wrapper removes it from the target."""
     wrapper = deferred.Deferred(ExampleModule)
     wrapper()
     # Precondition: the target has `w` after the wrapper has been called.
     self.assertTrue(hasattr(wrapper.target, "w"))
     delattr(wrapper, "w")
     self.assertFalse(hasattr(wrapper.target, "w"))
Beispiel #8
0
 def test_getattr(self):
     """Reading `w` on the wrapper yields the very object held by the target."""
     wrapper = deferred.Deferred(ExampleModule)
     wrapper()
     self.assertIs(wrapper.w, wrapper.target.w)  # pytype: disable=attribute-error
Beispiel #9
0
 def test_attr_forwarding_fails_before_construction(self):
     """Looking up `foo` on a freshly created wrapper raises AttributeError."""
     wrapper = deferred.Deferred(ExampleModule)
     # Callable form of assertRaises: performs getattr(wrapper, "foo") and
     # expects it to raise.
     self.assertRaises(AttributeError, getattr, wrapper, "foo")
Beispiel #10
0
 def test_target(self):
     """`target` returns the exact object produced by the factory."""
     expected = ExampleModule()
     wrapper = deferred.Deferred(lambda: expected)
     self.assertIs(wrapper.target, expected)
Beispiel #11
0
 def __init__(self, name="outer"):
     """Initialise the parent with `name` and attach a deferred ExampleModule.

     Args:
       name: Name forwarded to the parent constructor; defaults to "outer".
     """
     super().__init__(name=name)
     lazy_child = deferred.Deferred(ExampleModule)
     self.deferred = lazy_child
Beispiel #12
0
 def test_deferred_naming_name_scope(self):
     """A wrapper created under name scope "foo" yields a "foo/"-prefixed target."""
     with module.name_scope("foo"):
         lazy = deferred.Deferred(ExampleModule)
     # `target` is deliberately read after the scope has exited; the expected
     # name still carries the "foo" prefix.
     self.assertEqual(lazy.target.module_name, "foo/example_module")