def outer():
    """Lift ``inner.init`` with state, run ``inner`` and record its new state.

    ``self`` and ``inner`` are not defined here; they are resolved from the
    surrounding scope.

    Returns:
      A ``(out, state)`` pair as produced by ``inner.apply``.
    """
    init_fn, state_updater = lift.lift_with_state(inner.init)
    params, initial_state = init_fn(None)
    # The lifted init is expected to yield no parameters for this inner fn.
    self.assertEmpty(params)
    out, new_state = inner.apply(params, initial_state, None)
    # Hand the post-apply state to the updater returned by the lift.
    state_updater.update(new_state)
    return out, new_state
def test_empty_lift_with_state(self, ignore_update):
    """Lifting a transformed no-op function yields empty params and state.

    Args:
      ignore_update: When truthy the updater's ``ignore_update()`` is called;
        otherwise ``update`` is called with an empty dict.
    """
    noop = transform.transform_with_state(lambda: None)
    lifted_init, state_updater = lift.lift_with_state(noop.init)
    params, state = lifted_init(None)
    self.assertEmpty(params)
    self.assertEmpty(state)
    # Finish with the updater in one of its two supported ways.
    if not ignore_update:
        state_updater.update({})
    else:
        state_updater.ignore_update()
def f():
    """Create an updater, then invoke ``updater_fn`` on it in another transform.

    ``updater_fn`` is not defined here; it is resolved from the surrounding
    scope.
    """
    noop = transform.transform_with_state(lambda: None)
    _, state_updater = lift.lift_with_state(noop.init)
    # The updater is used from inside a second, separately transformed
    # function, whose init is then run.
    nested = transform.transform_with_state(lambda: updater_fn(state_updater))
    nested.init(None)
def __call__(self):
    """Lift ``inner.init`` with state, run ``inner`` and record its new state.

    ``inner`` is not defined here; it is resolved from the surrounding scope.

    Returns:
      A ``(out, state)`` pair as produced by ``inner.apply``.
    """
    init_fn, state_updater = lift.lift_with_state(inner.init)
    params, initial_state = init_fn(None)
    out, new_state = inner.apply(params, initial_state, None)
    # Hand the post-apply state to the updater returned by the lift.
    state_updater.update(new_state)
    return out, new_state
def test_used_multiple_times(self, update_fn1, update_fn2):
    """Using the same updater a second time must raise ``ValueError``.

    Args:
      update_fn1: Callable given the updater first; expected to succeed.
      update_fn2: Callable given the same updater afterwards; expected to
        raise ``ValueError`` matching "must only be used once".
    """
    noop = transform.transform_with_state(lambda: None)
    lifted = lift.lift_with_state(noop.init)
    state_updater = lifted[1]
    update_fn1(state_updater)
    with self.assertRaisesRegex(ValueError, "must only be used once"):
        update_fn2(state_updater)
def f() -> lift.LiftWithStateUpdater:
    """Return the updater obtained by lifting a transformed no-op function."""
    noop = transform.transform_with_state(lambda: None)
    lifted = lift.lift_with_state(noop.init)
    # Only the second element (the updater) is returned; the first is dropped.
    return lifted[1]