示例#1
0
def _get_entities(story):
  """Traces a story function to capture entity constructors.

  This function ingests a story, runs it under an `ed.trace` tracer, and
  captures every traced call whose first positional argument is an instance of
  the Entity class (or a descendant of it), returning a dictionary of
  entity_name: object handle pairs. Name conflicts are resolved by appending
  the object id to the entity name.

  Args:
    story: an argumentless callable which leads to the creation of objects
      inheriting from Entity.

  Returns:
    A `(sim_vars, entity_handles)` tuple: `sim_vars` is whatever `story`
    returns, and `entity_handles` is a dictionary of name: handle pairs. Each
    name comes from the traced call's `name` keyword argument (None when the
    call does not pass one).
  """
  entity_handles = {}

  def tracer(call, *args, **kwargs):
    # Record only calls whose first positional argument is an Entity instance.
    if args and isinstance(args[0], Entity):
      entity_handle = args[0]
      entity_name = kwargs.get('name')
      if entity_name in entity_handles:
        # Disambiguate a repeated name with the handle's id().
        entity_name = f'{entity_name}_{id(entity_handle)}'
      entity_handles[entity_name] = entity_handle
    # Every call, recorded or not, is forwarded to the next tracer.
    return ed.traceable(call)(*args, **kwargs)

  with ed.trace(tracer):
    sim_vars = story()

  return sim_vars, entity_handles
    def transformed_fn(observed_output, *observed_inputs):
        """Computes per-field log-probabilities of `observed_output` under `fn`.

        `fn` (from the enclosing scope) is traced twice:
        1. Once by `_observed_output_values_by_rv_order`, to pair each observed
           field value with its random variable in construction order.
        2. Once more here, scoring each random variable against its
           observation as it is constructed.

        Args:
          observed_output: A `Value` holding observed values for the
            random-variable fields returned by `fn`.
          *observed_inputs: Inputs to `fn`. Each is passed through
            `data.remove_data_index` (with `_OBSERVATION_INDEX_FIELD`) before
            use.

        Returns:
          A `Value` mapping each random-variable field name to the
          log-probability of its observed value.

        Raises:
          RuntimeError: If the second run of `fn` creates a different number of
            random variables than the first run did.
        """
        cleaned_inputs = list(
            map(
                functools.partial(data.remove_data_index,
                                  data_index_field=_OBSERVATION_INDEX_FIELD),
                observed_inputs))
        # First pass: (field_name, observed_value) pairs in the order in which
        # `fn` constructs the corresponding random variables.
        observed_output_values_by_rv_order = _observed_output_values_by_rv_order(
            fn, observed_output, *cleaned_inputs)
        num_rvs = len(observed_output_values_by_rv_order)

        rv_index = 0
        log_probs = {}

        def log_prob_tracer(rv_constructor, *args, **kwargs):
            nonlocal rv_index
            if rv_index >= num_rvs:
                raise RuntimeError(
                    "function created {} random variables the first time it was called,"
                    " but created more the second time".format(num_rvs))
            field_name, observed_value = observed_output_values_by_rv_order[
                rv_index]
            rv_index += 1
            # Build the variable directly (not forwarded) only to score the
            # observation under its distribution.
            rv = rv_constructor(*args, **kwargs)
            logp = rv.distribution.log_prob(observed_value)
            log_probs[field_name] = logp
            # Pin the value to the observation, then forward the construction.
            kwargs["value"] = observed_value
            # be nice to higher tracers
            return ed.traceable(rv_constructor)(*args, **kwargs)

        # Second pass: re-run `fn`, scoring each random variable as it is built.
        with ed.trace(log_prob_tracer):
            _ = fn(*cleaned_inputs)

        # The tracer only detects *extra* random variables. A run that builds
        # fewer would otherwise silently omit fields from the result.
        if rv_index != num_rvs:
            raise RuntimeError(
                "function created {} random variables the first time it was called,"
                " but created only {} the second time".format(num_rvs, rv_index))

        return Value(**log_probs)
def _observed_output_values_by_rv_order(fn, observed_output, *observed_inputs):
    """Returns observed field values in order of RandomVariable creation.

    For example, suppose `fn(x, y) = {"a": rv1, "b": rv2}` where
    ```
      rv1 = F(x, y)
      rv2 = G(rv1)
    ```
    Then,
    ```
      _observed_output_values_by_rv_order(fn, {"a": ov1, "b": ov2}, x, y)
          = [("a", ov1), ("b", ov2)]
    ```
    Here, `fn` constructs two `ed.RandomVariable` objects, first `rv1` and then
    `rv2`. Because `rv1` was constructed first, the pair for `ov1` appears first
    in the returned sequence.

    Args:
      fn: A `Variable` value function; see `ValueDef.fn`.
      observed_output: A `Value` containing field names matching the value
        returned by `fn`.
      *observed_inputs: The input arguments of `fn`.

    Returns:
      A list of `(field_name, observed_value)` pairs, where `observed_value` is
      `observed_output.get(field_name)`. There is one pair per
      `ed.RandomVariable` field of `fn`'s output, arranged in the order in which
      `fn` constructs those `ed.RandomVariable` objects.

    Raises:
      RuntimeError: If `fn` creates any `ed.RandomVariable` objects that are not
        exposed as fields of its output `Value`.
    """

    rvs_in_order_of_construction = []

    def index_random_variables(rv_constructor, *args, **kwargs):
        # Calls the constructor directly rather than via `ed.traceable`, so
        # tracers further up the stack are not invoked during this pass.
        rv = rv_constructor(*args, **kwargs)
        rvs_in_order_of_construction.append(rv)
        return rv

    with ed.trace(index_random_variables):
        temporary_output = fn(*observed_inputs)

    # Map each random variable exposed in the output to its field name and
    # observed value. Keys are the RV objects themselves.
    rv_to_output_value = {
        value: (field_name, observed_output.get(field_name))
        for field_name, value in temporary_output.as_dict.items()
        if isinstance(value, ed.RandomVariable)
    }
    # Any constructed RV that is not an output field has no observation.
    unobserved = [
        rv for rv in rvs_in_order_of_construction
        if rv not in rv_to_output_value
    ]
    if unobserved:
        raise RuntimeError(
            "unobserved random variables; log-probability cannot be computed: {}"
            .format(unobserved))

    return [rv_to_output_value[rv] for rv in rvs_in_order_of_construction]
示例#4
0
 def testTrace(self, cls, value, kwargs):
   """Checks that a tracer can inject a value into the variable named rv2."""

   def override_rv2(constructor, *ctor_args, **ctor_kwargs):
     # Only the call named "rv2" has its value replaced.
     if ctor_kwargs.get("name") == "rv2":
       ctor_kwargs["value"] = value
     return constructor(*ctor_args, **ctor_kwargs)

   rv1 = cls(value=value, name="rv1", **kwargs)
   with ed.trace(override_rv2):
     rv2 = cls(name="rv2", **kwargs)
   for rv in (rv1, rv2):
     self.assertEqual(rv, value)
示例#5
0
  def testTraceForwarding(self):
    """Checks that two tracers forwarding via `ed.traceable` compose."""

    def set_xy(f, *args, **kwargs):
      name = kwargs.get("name")
      if name == "x":
        kwargs["value"] = 1.
      elif name == "y":
        kwargs["value"] = 0.42
      return ed.traceable(f)(*args, **kwargs)

    def double(f, *args, **kwargs):
      forwarded = ed.traceable(f)(*args, **kwargs)
      return 2. * forwarded

    def model():
      x = ed.Normal(loc=0., scale=1., name="x")
      y = ed.Normal(loc=x, scale=1., name="y")
      return x + y

    # `double` is innermost and forwards each call out to `set_xy`.
    with ed.trace(set_xy), ed.trace(double):
      z = model()

    expected = 2. * 1. + 2. * 0.42
    self.assertAlmostEqual(z, expected, places=5)
    def model_mean_eval(self, inference_data, inference_model):
        """Runs `inference_model` with each random variable's value set to its mean.

        Returns:
          The result of calling `.mean()` on the model's output.
        """
        from edward2 import trace

        def take_mean(f, *args, **kwargs):
            """Tracer which sets each random variable's value to its mean."""
            random_var = f(*args, **kwargs)
            random_var._value = random_var.distribution.mean()
            return random_var

        with trace(take_mean):
            prediction = inference_model(inference_data)

        # NOTE(review): assumes the model output exposes `.mean()`
        # (e.g. a distribution-like object) — confirm against the model.
        return prediction.mean()
示例#7
0
    def testTraceNonForwarding(self):
        """Checks that a tracer calling `f` directly hides calls from outer tracers."""

        def set_xy(f, *args, **kwargs):
            name = kwargs.get("name")
            if name == "x":
                kwargs["value"] = 1.
            elif name == "y":
                kwargs["value"] = 0.42
            return f(*args, **kwargs)

        def double(f, *args, **kwargs):
            # Must never run: the inner `set_xy` does not forward outward.
            self.assertEqual("yes", "no")
            return 2. * f(*args, **kwargs)

        def model():
            x = ed.Normal(loc=0., scale=1., name="x")
            y = ed.Normal(loc=x, scale=1., name="y")
            return x + y

        with ed.trace(double), ed.trace(set_xy):
            z = model()

        z_value = self.evaluate(z)
        self.assertAlmostEqual(z_value, 1. + 0.42, places=5)
示例#8
0
    def testTapeInnerForwarding(self):
        """Checks tape/tracer composition when the tape is the innermost tracer."""

        def double(f, *args, **kwargs):
            result = ed.traceable(f)(*args, **kwargs)
            return 2. * result

        def model():
            x = ed.Normal(loc=0., scale=1., name="x")
            y = ed.Normal(loc=x, scale=1., name="y")
            return x + y

        with ed.trace(double), ed.tape() as model_tape:
            output = model()

        self.assertEqual(list(model_tape.keys()), ["x", "y"])
        self.assertEqual(model_tape["x"] + model_tape["y"], output)
示例#9
0
  def testTraceException(self):
    """Checks that the tracer stack is restored when traced code raises."""

    def raise_not_implemented():
      raise NotImplementedError()

    def passthrough(fn, *fargs, **fkwargs):
      return fn(*fargs, **fkwargs)

    with ed.get_next_tracer() as tracer_before:
      pass

    with self.assertRaises(NotImplementedError):
      with ed.trace(passthrough):
        raise_not_implemented()

    with ed.get_next_tracer() as tracer_after:
      pass

    self.assertEqual(tracer_before, tracer_after)
示例#10
0
    def testTapeOuterForwarding(self):
        """Checks that an outer tape records values before an inner tracer doubles them."""

        def double(f, *args, **kwargs):
            recorded = ed.traceable(f)(*args, **kwargs)
            return 2. * recorded

        def model():
            x = ed.Normal(loc=0., scale=1., name="x")
            y = ed.Normal(loc=x, scale=1., name="y")
            return x + y

        with ed.tape() as model_tape, ed.trace(double):
            output = model()

        doubled_sum = 2. * model_tape["x"] + 2. * model_tape["y"]
        expected_value, actual_value = self.evaluate([doubled_sum, output])
        self.assertEqual(list(model_tape), ["x", "y"])
        self.assertEqual(expected_value, actual_value)
示例#11
0
    def model_mean_eval(self, inference_data, inference_model):
        """Runs `inference_model` with each random variable's value set to its mean.

        Returns:
          A `(regression_output, classification_output)` pair: `.mean()` of the
          model's first output and the model's second output unchanged.
        """
        from edward2 import trace

        def take_mean(f, *args, **kwargs):
            """Tracer which sets each random variable's value to its mean."""
            random_var = f(*args, **kwargs)
            random_var._value = random_var.distribution.mean()
            return random_var

        with trace(take_mean):
            outputs = inference_model(inference_data)

        # NOTE(review): assumes outputs[0] exposes `.mean()` (regression head)
        # and outputs[1] is already a plain array (classification head).
        regression_output = outputs[0].mean()
        classification_output = outputs[1]
        return regression_output, classification_output
示例#12
0
    def testDenseMean(self, layer):
        """Checks the forward pass honors substituted values, e.g. posterior means."""
        tf.keras.backend.set_learning_phase(0)  # test time

        def take_mean(f, *args, **kwargs):
            """Replaces a random variable's value with its distribution mean."""
            random_var = f(*args, **kwargs)
            random_var._value = random_var.distribution.mean()
            return random_var

        inputs = np.random.rand(5, 3, 7).astype(np.float32)
        model = layer(4, activation=tf.nn.relu, use_bias=False)
        sampled = tf.convert_to_tensor(model(inputs))
        with ed.trace(take_mean):
            at_mean = tf.convert_to_tensor(model(inputs))

        self.assertEqual(sampled.shape, (5, 3, 4))
        self.assertNotAllClose(sampled, at_mean)
        if layer == ed.layers.DenseDVI:
            return
        self.assertAllClose(at_mean, np.zeros((5, 3, 4)), atol=1e-4)
示例#13
0
 def transformed_model():
   with ed.trace(trivial_tracer):
     model()