Example #1
0
 def outer():
     """Initialise ``inner`` through a lifted init, then apply it once.

     The lifted init must produce no parameters. The state returned by
     ``inner.apply`` is pushed back through the updater and also returned
     alongside the output.
     """
     init_fn, state_updater = lift.lift_with_state(inner.init)
     inner_params, inner_state = init_fn(None)
     self.assertEmpty(inner_params)
     result, new_state = inner.apply(inner_params, inner_state, None)
     # Propagate the post-apply state to the outer transform.
     state_updater.update(new_state)
     return result, new_state
Example #2
0
 def test_empty_lift_with_state(self, ignore_update):
     """Lifting a trivial transform's init yields empty params and state.

     The updater is then handled in one of two ways, chosen by
     ``ignore_update`` (presumably a test parameter): it is told to ignore
     the update, or it is given an empty state.
     """
     trivial = transform.transform_with_state(lambda: None)
     lifted_init, state_updater = lift.lift_with_state(trivial.init)
     lifted_params, lifted_state = lifted_init(None)
     self.assertEmpty(lifted_params)
     self.assertEmpty(lifted_state)
     if not ignore_update:
         state_updater.update({})
         return
     state_updater.ignore_update()
Example #3
0
 def f():
     """Create an updater, then hand it to ``updater_fn`` inside another init.

     ``updater_fn`` is taken from the enclosing scope, which is not visible
     here.
     """
     trivial = transform.transform_with_state(lambda: None)
     _, state_updater = lift.lift_with_state(trivial.init)
     consumer = transform.transform_with_state(
         lambda: updater_fn(state_updater))
     consumer.init(None)
Example #4
0
 def __call__(self):
     """Initialise ``inner`` via a lifted init, apply it once, and return.

     The state produced by ``inner.apply`` is pushed through the updater and
     returned together with the output.
     """
     init_fn, state_updater = lift.lift_with_state(inner.init)
     inner_params, inner_state = init_fn(None)
     result, new_state = inner.apply(inner_params, inner_state, None)
     # Propagate the post-apply state back to the enclosing transform.
     state_updater.update(new_state)
     return result, new_state
Example #5
0
 def test_used_multiple_times(self, update_fn1, update_fn2):
     """A second use of the same updater raises ``ValueError``."""
     trivial = transform.transform_with_state(lambda: None)
     lifted = lift.lift_with_state(trivial.init)
     state_updater = lifted[1]
     # The first use is not wrapped, so it must not raise.
     update_fn1(state_updater)
     expected_msg = "must only be used once"
     with self.assertRaisesRegex(ValueError, expected_msg):
         update_fn2(state_updater)
Example #6
0
 def f() -> lift.LiftWithStateUpdater:
     """Return only the updater from lifting a trivial transform's init."""
     # Named ``trivial`` rather than ``f`` so it does not shadow this function.
     trivial = transform.transform_with_state(lambda: None)
     lifted = lift.lift_with_state(trivial.init)
     return lifted[1]