def _get_entities(story):
    """Traces a story function to capture entity constructors.

    This function runs `story` under an edward2 tracer and records every traced
    call whose first positional argument is an instance of `Entity` (presumably
    the `self` argument of a traceable `Entity` constructor -- confirm). It
    returns a dictionary of entity_name: object handle pairs. Name conflicts are
    resolved by appending the object id to the entity name.

    The entity name is taken from the `name` keyword argument of the traced
    call. If no `name` keyword is passed, the key is `None` (or `'None_<id>'` on
    a conflict).

    Args:
      story: an argumentless callable which leads to the creation of objects
        inheriting from Entity.

    Returns:
      A tuple `(sim_vars, entity_handles)`. `sim_vars` is whatever `story()`
      returned (e.g. the simulation variables). `entity_handles` is a dictionary
      of name: handle pairs.
    """
    entity_handles = {}

    def tracer(call, *args, **kwargs):
        # Record calls whose first positional argument is an Entity. Every call,
        # recorded or not, is forwarded to the next tracer via ed.traceable.
        if args and isinstance(args[0], Entity):
            entity_handle = args[0]
            entity_name = kwargs.get('name')
            if entity_name in entity_handles:
                # Disambiguate a duplicate name with the object's id.
                entity_name = f'{entity_name}_{id(entity_handle)}'
            entity_handles[entity_name] = entity_handle
        return ed.traceable(call)(*args, **kwargs)

    with ed.trace(tracer):
        sim_vars = story()
    return sim_vars, entity_handles
def transformed_fn(observed_output, *observed_inputs):
    """Computes per-field log-probabilities of `observed_output` under `fn`.

    `fn` is called once (inside `_observed_output_values_by_rv_order`) to learn
    the order in which it constructs random variables. It is then called a
    second time under a tracer that scores each random variable's distribution
    at the corresponding observed value.

    Args:
      observed_output: A `Value` whose fields hold observed values for the
        random-variable fields produced by `fn`.
      *observed_inputs: Inputs to `fn`. Each is passed through
        `data.remove_data_index` (with `_OBSERVATION_INDEX_FIELD`) before use.

    Returns:
      A `Value` mapping each random-variable field name to the
      `distribution.log_prob` of its observed value.

    Raises:
      RuntimeError: If `fn` creates a different number of random variables on
        the second call than on the first.
    """
    cleaned_inputs = [
        data.remove_data_index(x, data_index_field=_OBSERVATION_INDEX_FIELD)
        for x in observed_inputs
    ]
    observed_output_values_by_rv_order = _observed_output_values_by_rv_order(
        fn, observed_output, *cleaned_inputs)
    num_rvs = len(observed_output_values_by_rv_order)
    rv_index = 0
    log_probs = {}

    def log_prob_tracer(rv_constructor, *args, **kwargs):
        nonlocal rv_index
        if rv_index >= num_rvs:
            raise RuntimeError(
                "function created {} random variables the first time it was called,"
                " but created more the second time".format(num_rvs))
        field_name, observed_value = observed_output_values_by_rv_order[
            rv_index]
        rv_index += 1
        # Construct the random variable directly (not via ed.traceable) only to
        # obtain its distribution for scoring.
        rv = rv_constructor(*args, **kwargs)
        log_probs[field_name] = rv.distribution.log_prob(observed_value)
        # Pin the value to the observation and forward to outer tracers.
        kwargs["value"] = observed_value  # be nice to higher tracers
        return ed.traceable(rv_constructor)(*args, **kwargs)

    with ed.trace(log_prob_tracer):
        _ = fn(*cleaned_inputs)
    # The tracer only guards against *more* random variables. Also reject fewer,
    # which would otherwise leave some fields without a log-probability.
    if rv_index != num_rvs:
        raise RuntimeError(
            "function created {} random variables the first time it was called,"
            " but only {} the second time".format(num_rvs, rv_index))
    return Value(**log_probs)
def _observed_output_values_by_rv_order(fn, observed_output, *observed_inputs):
    """Returns observed values in the order their RandomVariables are created.

    For example, suppose `fn(x, y) = {"a": rv1, "b": rv2}` where

    ```
    rv1 = F(x, y)
    rv2 = G(rv1)
    ```

    Then,

    ```
    _observed_output_values_by_rv_order(fn, {"a": ov1, "b": ov2}, x, y)
        = [("a", ov1), ("b", ov2)]
    ```

    Here, `fn` constructs two `ed.RandomVariable` objects, first `rv1` and then
    `rv2`. Because `rv1` was constructed first, its entry appears first in the
    returned sequence.

    Note that `fn` is actually called once here, under a tracer.

    Args:
      fn: A `Variable` value function; see `ValueDef.fn`.
      observed_output: A `Value` containing field names matching the value
        returned by `fn`.
      *observed_inputs: The input arguments of `fn`.

    Returns:
      A list of `(field_name, observed_value)` pairs, one per
      `ed.RandomVariable` constructed by `fn`, in construction order.
      `field_name` is the output field the random variable is assigned to.
      `observed_value` is `observed_output.get(field_name)`.

    Raises:
      RuntimeError: If `fn` creates any `ed.RandomVariable` objects that are not
        exposed as fields of its output `Value`.
    """
    rvs_in_order_of_construction = []

    def index_random_variables(rv_constructor, *args, **kwargs):
        # Calls rv_constructor directly (not via ed.traceable), so outer tracers
        # are presumably not invoked during this indexing pass.
        rv = rv_constructor(*args, **kwargs)
        rvs_in_order_of_construction.append(rv)
        return rv

    with ed.trace(index_random_variables):
        temporary_output = fn(*observed_inputs)

    # Map each random-variable output field back to (field name, observed value).
    # Uses random variables as dict keys, so relies on ed.RandomVariable being
    # hashable (presumably by identity -- confirm).
    rv_to_output_value = {
        value: (field_name, observed_output.get(field_name))
        for field_name, value in temporary_output.as_dict.items()
        if isinstance(value, ed.RandomVariable)
    }
    unobserved = [
        rv for rv in rvs_in_order_of_construction
        if rv not in rv_to_output_value
    ]
    if unobserved:
        raise RuntimeError(
            "unobserved random variables; log-probability cannot be computed: {}"
            .format(unobserved))
    return [rv_to_output_value[rv] for rv in rvs_in_order_of_construction]
def testTrace(self, cls, value, kwargs):
    """A tracer can override the value of one named random variable."""

    def override_rv2(fn, *fn_args, **fn_kwargs):
        # Only the variable named "rv2" has its value injected.
        if fn_kwargs.get("name", None) == "rv2":
            fn_kwargs["value"] = value
        return fn(*fn_args, **fn_kwargs)

    rv1 = cls(value=value, name="rv1", **kwargs)
    with ed.trace(override_rv2):
        rv2 = cls(name="rv2", **kwargs)
    for random_var in (rv1, rv2):
        self.assertEqual(random_var, value)
def testTraceForwarding(self):
    """Nested forwarding tracers compose: values are set, then doubled."""

    def set_xy(f, *args, **kwargs):
        rv_name = kwargs.get("name")
        if rv_name == "x":
            kwargs["value"] = 1.
        elif rv_name == "y":
            kwargs["value"] = 0.42
        return ed.traceable(f)(*args, **kwargs)

    def double(f, *args, **kwargs):
        return 2. * ed.traceable(f)(*args, **kwargs)

    def model():
        x = ed.Normal(loc=0., scale=1., name="x")
        y = ed.Normal(loc=x, scale=1., name="y")
        return x + y

    with ed.trace(set_xy), ed.trace(double):
        z = model()
    expected = 2. * 1. + 2. * 0.42
    self.assertAlmostEqual(z, expected, places=5)
def model_mean_eval(self, inference_data, inference_model):
    """Runs `inference_model` with every traced random variable set to its mean.

    Args:
      inference_data: Input passed unchanged to `inference_model`.
      inference_model: Callable whose return value exposes a `.mean()` method.

    Returns:
      The `.mean()` of the model output computed under the mean tracer.
    """
    from edward2 import trace

    def take_mean(f, *args, **kwargs):
        """Tracer which sets each random variable's value to its mean."""
        rv = f(*args, **kwargs)
        # Overwrite the private `_value` so downstream computations use the
        # distribution mean.
        rv._value = rv.distribution.mean()
        return rv

    with trace(take_mean):
        y_pred = inference_model(inference_data)
    # NOTE(review): per the original comment, the model output is
    # distribution-like (produced by an output Lambda layer); reduce it to its
    # mean. Confirm.
    return y_pred.mean()
def testTraceNonForwarding(self):
    """An inner tracer that does not forward hides calls from outer tracers."""

    def double(f, *args, **kwargs):
        # Must never run: set_xy calls `f` directly instead of forwarding.
        self.assertEqual("yes", "no")
        return 2. * f(*args, **kwargs)

    def set_xy(f, *args, **kwargs):
        rv_name = kwargs.get("name")
        if rv_name == "x":
            kwargs["value"] = 1.
        elif rv_name == "y":
            kwargs["value"] = 0.42
        return f(*args, **kwargs)

    def model():
        x = ed.Normal(loc=0., scale=1., name="x")
        y = ed.Normal(loc=x, scale=1., name="y")
        return x + y

    with ed.trace(double), ed.trace(set_xy):
        z = model()
    expected = 1. + 0.42
    self.assertAlmostEqual(self.evaluate(z), expected, places=5)
def testTapeInnerForwarding(self):
    """A tape inside a forwarding tracer records the variables it returns."""

    def double(f, *args, **kwargs):
        return 2. * ed.traceable(f)(*args, **kwargs)

    def model():
        x = ed.Normal(loc=0., scale=1., name="x")
        y = ed.Normal(loc=x, scale=1., name="y")
        return x + y

    with ed.trace(double), ed.tape() as model_tape:
        output = model()
    self.assertEqual(list(model_tape.keys()), ["x", "y"])
    self.assertEqual(model_tape["x"] + model_tape["y"], output)
def testTraceException(self):
    """The tracer stack is restored even when the traced body raises."""

    def f():
        raise NotImplementedError()

    def passthrough(fn, *fn_args, **fn_kwargs):
        return fn(*fn_args, **fn_kwargs)

    with ed.get_next_tracer() as old_tracer:
        pass
    with self.assertRaises(NotImplementedError):
        with ed.trace(passthrough):
            f()
    with ed.get_next_tracer() as new_tracer:
        pass
    self.assertEqual(old_tracer, new_tracer)
def testTapeOuterForwarding(self):
    """A tape outside a forwarding tracer sees the tracer's returned values."""

    def double(f, *args, **kwargs):
        return 2. * ed.traceable(f)(*args, **kwargs)

    def model():
        x = ed.Normal(loc=0., scale=1., name="x")
        y = ed.Normal(loc=x, scale=1., name="y")
        return x + y

    with ed.tape() as model_tape, ed.trace(double):
        output = model()
    doubled_sum = 2. * model_tape["x"] + 2. * model_tape["y"]
    expected_value, actual_value = self.evaluate([doubled_sum, output])
    self.assertEqual(list(model_tape.keys()), ["x", "y"])
    self.assertEqual(expected_value, actual_value)
def model_mean_eval(self, inference_data, inference_model):
    """Runs `inference_model` with every traced random variable set to its mean.

    Args:
      inference_data: Input passed unchanged to `inference_model`.
      inference_model: Callable returning an indexable pair. Element 0 exposes a
        `.mean()` method; element 1 is returned as-is.

    Returns:
      A tuple `(y_reg_output, y_cla_output)`. `y_reg_output` is the `.mean()` of
      output 0. `y_cla_output` is output 1 unchanged.
    """
    from edward2 import trace

    def take_mean(f, *args, **kwargs):
        """Tracer which sets each random variable's value to its mean."""
        rv = f(*args, **kwargs)
        # Overwrite the private `_value` so downstream computations use the
        # distribution mean.
        rv._value = rv.distribution.mean()
        return rv

    with trace(take_mean):
        model_outputs = inference_model(inference_data)
    # NOTE(review): per the original comments, output 0 (regression head) is
    # distribution-like, produced by an output Lambda layer. Output 1
    # (classification head) is a numpy array. Confirm.
    y_reg_output = model_outputs[0].mean()
    y_cla_output = model_outputs[1]
    return y_reg_output, y_cla_output
def testDenseMean(self, layer):
    """Checks that a forward pass can use non-sampled values, e.g. the mean."""
    tf.keras.backend.set_learning_phase(0)  # test time

    def use_mean(f, *args, **kwargs):
        """Replaces the random variable's value with its distribution mean."""
        random_var = f(*args, **kwargs)
        random_var._value = random_var.distribution.mean()
        return random_var

    inputs = np.random.rand(5, 3, 7).astype(np.float32)
    model = layer(4, activation=tf.nn.relu, use_bias=False)
    sampled_out = tf.convert_to_tensor(model(inputs))
    with ed.trace(use_mean):
        mean_out = tf.convert_to_tensor(model(inputs))
    expected_shape = (5, 3, 4)
    self.assertEqual(sampled_out.shape, expected_shape)
    self.assertNotAllClose(sampled_out, mean_out)
    if layer != ed.layers.DenseDVI:
        self.assertAllClose(mean_out, np.zeros(expected_shape), atol=1e-4)
def transformed_model():
    """Calls `model()` with `trivial_tracer` installed; the output is discarded."""
    tracing = ed.trace(trivial_tracer)
    with tracing:
        _ = model()