def test_data_index_field(self):
  """Checks the data index field under the default and a custom name."""
  default_indexed = data.data_variable(
      name="x",
      spec=ValueSpec(a=FieldSpec()),
      data_sequence=TimeSteps(),
      output_fn=lambda step: Value(a=step * step))
  runtime = PythonRuntime(network=Network(variables=[default_indexed]))
  # At step 3 the output field holds 3*3 and the default index field
  # records the step number.
  self.assertEqual(
      runtime.execute(num_steps=3)["x"].as_dict,
      {"a": 9, data.DEFAULT_DATA_INDEX_FIELD: 3})
  custom_indexed = data.data_variable(
      name="y",
      spec=ValueSpec(b=FieldSpec()),
      data_sequence=TimeSteps(),
      output_fn=lambda step: Value(b=step * step),
      data_index_field="twiddle_dum")
  runtime = PythonRuntime(network=Network(variables=[custom_indexed]))
  # data_index_field overrides the name under which the index is stored.
  self.assertEqual(
      runtime.execute(num_steps=3)["y"].as_dict,
      {"b": 9, "twiddle_dum": 3})
def test_sliced_value(self):
  """Checks SlicedValue default slicing, a custom slice_fn, and output_fn."""
  forward = data.data_variable(
      name="x",
      spec=ValueSpec(a=FieldSpec(), b=FieldSpec()),
      data_sequence=SlicedValue(value=Value(a=[1, 2, 3], b=[4, 5, 6])))
  runtime = PythonRuntime(network=Network(variables=[forward]))
  # Default slicing walks each field forward, one element per step.
  for step, (want_a, want_b) in enumerate([(1, 4), (2, 5), (3, 6)]):
    result = runtime.execute(num_steps=step)["x"]
    self.assertEqual(result.get("a"), want_a)
    self.assertEqual(result.get("b"), want_b)
  # A custom slice_fn walks the value backwards.
  backward = data.data_variable(
      name="x",
      spec=ValueSpec(a=FieldSpec(), b=FieldSpec()),
      data_sequence=SlicedValue(
          value=Value(a=[1, 2, 3], b=[4, 5, 6]),
          slice_fn=lambda val, i: val[-1 - i]))
  runtime = PythonRuntime(network=Network(variables=[backward]))
  for step, (want_a, want_b) in enumerate([(3, 6), (2, 5), (1, 4)]):
    result = runtime.execute(num_steps=step)["x"]
    self.assertEqual(result.get("a"), want_a)
    self.assertEqual(result.get("b"), want_b)
  # output_fn can map the sliced fields onto a different spec.
  summed = data.data_variable(
      name="x",
      spec=ValueSpec(c=FieldSpec()),
      data_sequence=SlicedValue(value=Value(a=[1, 2, 3], b=[4, 5, 6])),
      output_fn=lambda val: Value(c=val.get("a") + val.get("b")))
  runtime = PythonRuntime(network=Network(variables=[summed]))
  for step, want_c in enumerate([5, 7, 9]):
    self.assertEqual(runtime.execute(num_steps=step)["x"].get("c"), want_c)
def test_tf_dataset(self, graph_compile):
  """Checks that a TFDataset sequence feeds successive dataset elements."""
  dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  squared = data.data_variable(
      name="x",
      spec=ValueSpec(a=FieldSpec()),
      data_sequence=TFDataset(dataset=dataset),
      output_fn=lambda element: Value(a=element * element))
  runtime = TFRuntime(
      network=Network(variables=[squared]), graph_compile=graph_compile)
  # Step t consumes dataset element t+1 (1-based values), squared by
  # output_fn.
  for step, want in enumerate([1, 4, 9]):
    self.assertEqual(runtime.execute(num_steps=step)["x"].get("a"), want)
def test_time_steps(self):
  """Checks that TimeSteps feeds the step index into output_fn."""
  squared = data.data_variable(
      name="x",
      spec=ValueSpec(a=FieldSpec()),
      data_sequence=TimeSteps(),
      output_fn=lambda step: Value(a=step * step))
  runtime = PythonRuntime(network=Network(variables=[squared]))
  # After t steps the variable holds output_fn(t) == t * t.
  for num_steps in range(4):
    self.assertEqual(
        runtime.execute(num_steps=num_steps)["x"].get("a"),
        num_steps * num_steps)
def chained_rv_test_network(self):
  """Builds a two-field chained random-variable network plus observations.

  The simulation variable follows the sequence
    z[0] = (0., 1.)
    z[t][0] = Normal(loc=z[t-1][0], scale=1)
    z[t][1] = Normal(loc=z[t][0] + 1., scale=2)

  Returns:
    A tuple (z, o, obs_0, obs_1) of the simulation variable, the
    observation data variable, and the two observation tensors.
  """
  obs_0 = tf.constant([0., 1., 2., 3.])
  obs_1 = tf.constant([1., 2., 3., 4.])
  o = data.data_variable(
      name="o",
      spec=ValueSpec(a0=FieldSpec(), a1=FieldSpec()),
      data_sequence=data.SlicedValue(value=Value(a0=obs_0, a1=obs_1)))
  z = Variable(name="z", spec=ValueSpec(a0=FieldSpec(), a1=FieldSpec()))
  z.initial_value = variable.value(
      lambda: Value(a0=ed.Deterministic(loc=0.), a1=ed.Deterministic(loc=1.)))

  def sample_next(prev):
    # a1 is conditioned on the freshly sampled a0, not on the previous step.
    first = ed.Normal(loc=prev.get("a0"), scale=1.)
    second = ed.Normal(loc=first + 1., scale=2.)
    return Value(a0=first, a1=second)

  z.value = variable.value(sample_next, (z.previous,))
  return z, o, obs_0, obs_1
def test_log_probs_from_direct_output(self):
  """Direct-output log probs must match observation-replay log probs."""
  z, _, _, _ = self.chained_rv_test_network()
  online_vars = log_probability.log_prob_variables_from_direct_output([z])
  online_runtime = runtime.TFRuntime(
      network=network_lib.Network(variables=[z] + online_vars))
  online_traj = online_runtime.trajectory(4)
  self.assertSetEqual(set(online_traj.keys()), {"z", "z_log_prob"})
  # Replay the sampled trajectory as observations and recompute offline.
  replay = data.data_variable(
      name="o",
      spec=ValueSpec(a0=FieldSpec(), a1=FieldSpec()),
      data_sequence=data.SlicedValue(value=online_traj["z"]))
  offline_vars = log_probability.log_prob_variables_from_observation(
      [z], [replay])
  offline_runtime = runtime.TFRuntime(
      network=network_lib.Network(variables=[replay] + offline_vars))
  offline_traj = offline_runtime.trajectory(4)
  # Both computations must agree field-by-field.
  for field in ("a0", "a1"):
    self.assertAllClose(
        online_traj["z_log_prob"].get(field),
        offline_traj["z_log_prob"].get(field))
def replay_variables(variables, value_trajectory):
  """Trajectory replay variables for log probability computation.

  Given a sequence of variables and a trajectory of observed values of these
  variables, this function constructs a sequence of observation variables,
  corresponding to the simulation variables, that replay their logged values.

  Args:
    variables: A sequence of `Variable`s defining a dynamic Bayesian network
      (DBN).
    value_trajectory: A trajectory generated from `TFRuntime.trajectory`.

  Returns:
    A sequence of `Variable`.
  """
  # One observation variable per simulation variable; the " obs" suffix keeps
  # each replay variable's name distinct from the variable it shadows.
  return [
      data.data_variable(
          name=var.name + " obs",
          spec=var.spec,
          data_sequence=data.SlicedValue(value=value_trajectory[var.name]),
          data_index_field=_OBSERVATION_INDEX_FIELD)
      for var in variables
  ]
def test_smoke(self):
  """Checks log probabilities of RV-valued and non-RV-valued sequences."""
  o = data.data_variable(
      name="o",
      spec=ValueSpec(a=FieldSpec()),
      data_sequence=data.SlicedValue(
          value=Value(a=tf.constant([0., 1., 2., 3.]))))
  # Log-probability of the random walk
  #   x[0] = 0.
  #   x[t] = Normal(loc=x[t-1], scale=1)
  # against the observation o = [0., 1., 2., 3.].
  x = Variable(name="x", spec=ValueSpec(a=FieldSpec()))
  x.initial_value = variable.value(
      lambda: Value(a=ed.Deterministic(loc=0.)))
  x.value = variable.value(
      lambda x_prev: Value(a=ed.Normal(loc=x_prev.get("a"), scale=1.)),
      (x.previous,))
  for steps, want in enumerate([0., -1.4189385, -2.837877, -4.2568154]):
    self.assertAllClose(
        want,
        log_probability.log_probability(
            variables=[x], observation=[o], num_steps=steps))
  # This is an example of a field value that is not a random variable (y.t).
  # It computes the log-probability of
  #   y[t] = Normal(loc=t, scale=1)
  # against the same observation o = [0., 1., 2., 3.].
  y = data.data_variable(
      name="y",
      spec=ValueSpec(a=FieldSpec()),
      data_sequence=data.TimeSteps(),
      output_fn=lambda t: Value(a=ed.Normal(loc=float(t), scale=1.)))
  for steps, want in enumerate([-0.918939, -1.837877, -2.756815, -3.675754]):
    self.assertAllClose(
        want,
        log_probability.log_probability(
            variables=[y], observation=[o], num_steps=steps))