def wrapped(*args, **kwargs):
  """Split reaped random variables into returned outputs and re-sown latents.

  Runs `conditional(f, input_names)` under `harvest.reap` with the
  RANDOM_VARIABLE tag and collects every value sown under that tag. Values
  whose names appear in `output_names` are returned, in the order given by
  `output_names`. Every other reaped value is sown again with the same tag
  and name, using `mode='strict'`.

  Returns:
    `primitive.tie_in(latents, outputs)`. Here `latents` is the dict of
    re-sown values. `outputs` is the list of selected values, or only its
    first element when `single_output` is true.
  """
  reaper = harvest.reap(conditional(f, input_names), tag=RANDOM_VARIABLE)
  reaped = reaper(*args, **kwargs)
  selected = [reaped[key] for key in output_names]
  # Sow the non-output values again so that an outer harvest can still see
  # them. Iteration follows `reaped.items()`, the same order as the original.
  resown = {}
  for key, val in reaped.items():
    if key in output_names:
      continue
    resown[key] = harvest.sow(val, tag=RANDOM_VARIABLE, name=key, mode='strict')
  result = selected[0] if single_output else selected
  return primitive.tie_in(resown, result)
def test_sow_happens_in_forward_pass(self):
  """Reaping the inverse of `f` still collects the value `f` sows under 'foo'."""
  def forward(x, y):
    # The sown value feeds the second output, so inverting `forward` has to
    # evaluate this sow on the way through.
    tagged = harvest.sow(x, name='x', tag='foo')
    return x, tagged * y

  inverted = core.inverse(forward)
  reaped = harvest.reap(inverted, tag='foo')(1., 1.)
  self.assertDictEqual(reaped, {'x': 1.})
def joint_sample(f: Program) -> Program:
  """Wrap `f` so that calling it yields the latent random variables it samples.

  Args:
    f: a program that sows its random variables under the RANDOM_VARIABLE tag.

  Returns:
    A program that takes the same arguments as `f`. It returns the dictionary
    of values that `harvest.reap` collects under the RANDOM_VARIABLE tag.
  """
  sampler = harvest.reap(f, tag=RANDOM_VARIABLE)
  return sampler
def init_f(init_key, *args, **kwargs):
  """Call `f_` and return what `harvest.reap` collects under `module.VARIABLE`.

  `init_key` is passed as the first positional argument to `f_`. The reap is
  done with `exclusive=True`.
  # NOTE(review): confirm the exact meaning of `exclusive` against the
  # harvest API before relying on it.
  """
  collector = harvest.reap(f_, tag=module.VARIABLE, exclusive=True)
  return collector(init_key, *args, **kwargs)