Exemple #1
0
 def tracer(call, *args, **kwargs):
   """Record a leading `Entity` argument in `entity_handles`, then forward.

   When the first positional argument is an `Entity`, it is stored in the
   `entity_handles` mapping (defined outside this block) under the value of
   the `name` keyword argument. If that key is already present, the key
   `<name>_<id(entity)>` is used instead. Every call, recorded or not, is
   forwarded through `ed.traceable(call)`.

   Args:
     call: Callable being traced.
     *args: Positional arguments forwarded to `call`.
     **kwargs: Keyword arguments forwarded to `call`; `name` is read as the
       registration key (it is `None` when absent).

   Returns:
     Whatever `ed.traceable(call)(*args, **kwargs)` returns.
   """
   if args and isinstance(args[0], Entity):
     entity_handle = args[0]
     entity_name = kwargs.get('name')
     # A key that is already taken is disambiguated with the handle's id().
     if entity_name in entity_handles:
       entity_name = f'{entity_name}_{id(entity_handle)}'
     entity_handles[entity_name] = entity_handle
   # Single forwarding point for both the recorded and the unrecorded case.
   return ed.traceable(call)(*args, **kwargs)
 def log_prob_tracer(rv_constructor, *args, **kwargs):
     nonlocal rv_index
     if rv_index >= num_rvs:
         raise RuntimeError(
             "function created {} random variables the first time it was called,"
             " but created more the second time".format(num_rvs))
     field_name, observed_value = observed_output_values_by_rv_order[
         rv_index]
     rv_index += 1
     rv = rv_constructor(*args, **kwargs)
     logp = rv.distribution.log_prob(observed_value)
     log_probs[field_name] = logp
     kwargs["value"] = observed_value
     # be nice to higher tracers
     return ed.traceable(rv_constructor)(*args, **kwargs)
Exemple #3
0
 def double(f, *args, **kwargs):
     """Forward the call through `ed.traceable(f)` and double its result.

     Args:
       f: Callable being traced.
       *args: Positional arguments forwarded to `f`.
       **kwargs: Keyword arguments forwarded to `f`.

     Returns:
       `2.` multiplied by the value returned from the traceable call.
     """
     traced_result = ed.traceable(f)(*args, **kwargs)
     return 2. * traced_result
Exemple #4
0
 def set_xy(f, *args, **kwargs):
   """Force `value` for calls named "x" or "y", then forward the call.

   A call whose `name` keyword equals "x" gets `value=1.`; one whose `name`
   equals "y" gets `value=0.42`. Any other call is forwarded with its keyword
   arguments untouched.

   Args:
     f: Callable being traced.
     *args: Positional arguments forwarded to `f`.
     **kwargs: Keyword arguments forwarded to `f`; `value` may be overwritten.

   Returns:
     Whatever `ed.traceable(f)(*args, **kwargs)` returns.
   """
   requested_name = kwargs.get("name")
   # Both names are checked, in the same order as separate `if` statements.
   for target_name, forced_value in (("x", 1.), ("y", 0.42)):
     if requested_name == target_name:
       kwargs["value"] = forced_value
   return ed.traceable(f)(*args, **kwargs)
Exemple #5
0
 def trivial_tracer(fn, *args, **kwargs):
   """Pass-through tracer: forward the call unchanged via `ed.traceable`.

   Args:
     fn: Callable being traced.
     *args: Positional arguments forwarded to `fn`.
     **kwargs: Keyword arguments forwarded to `fn`.

   Returns:
     Whatever `ed.traceable(fn)(*args, **kwargs)` returns.
   """
   traced_fn = ed.traceable(fn)
   return traced_fn(*args, **kwargs)