def tracer(call, *args, **kwargs):
    """Register ``Entity`` handles by name, then re-dispatch the call.

    When the first positional argument is an ``Entity``, it is stored in the
    ``entity_handles`` mapping (defined outside this function). The key is
    the call's ``name`` keyword argument, or ``None`` when absent. If that key
    is already present, the handle is stored under ``"<name>_<id(handle)>"``
    instead. In every case the call is then re-dispatched through
    ``ed.traceable(call)``.
    """
    if args and isinstance(args[0], Entity):
        handle = args[0]
        key = kwargs.get('name')
        if key in entity_handles:
            # Disambiguate a repeated name using the handle's identity.
            key = f'{key}_{id(handle)}'
        entity_handles[key] = handle
    return ed.traceable(call)(*args, **kwargs)
def log_prob_tracer(rv_constructor, *args, **kwargs): nonlocal rv_index if rv_index >= num_rvs: raise RuntimeError( "function created {} random variables the first time it was called," " but created more the second time".format(num_rvs)) field_name, observed_value = observed_output_values_by_rv_order[ rv_index] rv_index += 1 rv = rv_constructor(*args, **kwargs) logp = rv.distribution.log_prob(observed_value) log_probs[field_name] = logp kwargs["value"] = observed_value # be nice to higher tracers return ed.traceable(rv_constructor)(*args, **kwargs)
def double(f, *args, **kwargs):
    """Re-dispatch ``f`` through ``ed.traceable`` and scale the result by two."""
    result = ed.traceable(f)(*args, **kwargs)
    # Keep the scalar on the left so the result's __rmul__ is what runs.
    return 2. * result
def set_xy(f, *args, **kwargs):
    """Force fixed values for calls named "x" (1.) and "y" (0.42).

    Any matching call gets ``kwargs["value"]`` overwritten before being
    re-dispatched through ``ed.traceable(f)``. Other calls pass through
    unchanged.
    """
    for target, forced in (("x", 1.), ("y", 0.42)):
        if kwargs.get("name") == target:
            kwargs["value"] = forced
    return ed.traceable(f)(*args, **kwargs)
def trivial_tracer(fn, *args, **kwargs):
    """Pass-through tracer: re-dispatch ``fn`` unchanged via ``ed.traceable``."""
    traced = ed.traceable(fn)
    return traced(*args, **kwargs)