def test_sample_n(self):
    """Seeded draws from ``UniformInteger(0, 10)`` match the stored fixture."""
    tf.random.set_seed(10402302)
    X = UniformInteger(0, 10)
    x = X.sample([3, 3])
    # np.testing.assert_array_equal returns None and raises AssertionError on
    # mismatch, so it must be called directly. Wrapping it in
    # self.assertTrue(...) evaluates assertTrue(None), which fails even when
    # the arrays are equal.
    np.testing.assert_array_equal(
        x, self.fixture, err_msg="UniformInteger sample inconsistent"
    )
def x_star(m, t):
    """Draw num to delete.

    Builds a ``UniformInteger`` for the number of ``topology.target`` events
    to delete from metapopulation ``m`` at time ``t[0]``.

    If ``topology.next`` exists, the draw is capped by the minimum, over times
    ``>= t[0]``, of the running value
    ``initial_state[next] + cumsum(events[target] - events[next])``. The cumsum
    is inclusive, so removing ``x`` target events at ``t[0]`` lowers that value
    by ``x`` at every later time. The cap is further limited to ``n_max``.
    Without ``topology.next``, only the ``n_max`` cap applies.

    Args:
        m: metapopulation index tensor, gathered along axis -3 of ``events``
            (presumably shape ``[1]`` -- TODO confirm).
        t: time index tensor; only ``t[0]`` is used.

    Returns:
        ``UniformInteger(low=0, high=bound + 1, dtype=dtype, name="x_star")``.
    """
    with tf.name_scope("x_star"):
        if topology.next is not None:
            mask = (  # Mask out times prior to t
                # Adds dtype.max at time indices < t[0] (axis -2 of events)
                # so that reduce_min below only considers times >= t[0].
                tf.cast(tf.range(events.shape[-2]) < t[0], events.dtype)
                * events.dtype.max)
            m_events = tf.gather(events, m, axis=-3)
            m_inits = tf.gather(initial_state, m, axis=-2)
            # Running value initial_state[next] + cumsum(N_target - N_next)
            # over time; its minimum over t' >= t[0] is the deletion bound.
            diff = m_events[..., topology.target] - m_events[..., topology.next]
            diff = tf.gather(m_inits, topology.next, axis=-1) + tf.cumsum(
                diff, axis=-1)
            diff = diff + mask
            # NOTE(review): cast is hard-coded to tf.int32 rather than `dtype`;
            # confirm this is intended when dtype != tf.int32.
            bound = tf.cast(tf.reduce_min(diff, axis=-1), dtype=tf.int32)
            # NOTE(review): bound is not clamped at 0 (a tf.maximum(0, bound)
            # was previously left commented out here). If bound < 0, then
            # high = bound + 1 <= 0 and the distribution has empty support;
            # confirm callers guarantee a feasible state.
            bound = tf.minimum(n_max, bound)
        else:
            bound = tf.broadcast_to(n_max, m.shape)
        return UniformInteger(low=0, high=bound + 1, dtype=dtype, name="x_star")
def x_star(t, delta_t):
    """Draw the number of events to move by ``delta_t`` timesteps.

    For each batch row ``i`` (paired with ``tf.range(events.shape[0])``), the
    draw is capped by:
      * the minimum of ``_abscumdiff`` over the ``|delta_t[i]|`` timesteps the
        move crosses,
      * ``target_events[i, t[i]]``, the events available at the source time,
      * ``n_max``.
    The result is uniform on ``[0, max_events]`` (``high`` is exclusive).

    Args:
        t: time index per batch row (presumably shape ``[M]`` -- TODO confirm).
        delta_t: signed time shift per batch row; presumably drawn with
            ``|delta_t| <= d_max``.

    Returns:
        ``UniformInteger(low=0, high=max_events + 1)``.
    """
    with tf.name_scope("x_star"):
        # Compute bounds
        # The limitations of XLA mean that we must calculate bounds for
        # intervals [t, t+delta_t) if delta_t > 0, and [t+delta_t, t) if
        # delta_t is < 0. Both use a fixed-length window of `time_interval`
        # offsets (presumably tf.range(d_max) -- TODO confirm).
        t = t[..., tf.newaxis]
        delta_t = delta_t[..., tf.newaxis]
        bound_times = tf.where(
            delta_t < 0,
            t - time_interval - 1,  # [t+delta_t, t)
            t + time_interval,  # [t, t+delta_t)
        )
        free_events = _abscumdiff(
            events=events,
            initial_state=initial_state,
            topology=topology,
            t=t,
            delta_t=delta_t,
            bound_times=bound_times,
            int_dtype=dtype,
        )

        # Mask out bits of the interval we don't need for our delta_t.
        # one_hot places +inf at offset |delta_t| in each row, giving shape
        # [batch, d_max]. The cumsum must run along the offset axis (-1) so
        # that every offset >= |delta_t| becomes +inf and drops out of
        # reduce_min. tf.cumsum defaults to axis=0, which would accumulate
        # across batch rows and mix different rows' masks together.
        # If |delta_t| == d_max, one_hot is all zeros and nothing is masked.
        inf_mask = tf.cumsum(
            tf.one_hot(
                tf.math.abs(delta_t[:, 0]),
                d_max,
                on_value=tf.constant(np.inf, events.dtype),
                dtype=events.dtype,
            ),
            axis=-1,
        )
        # The mask's off value is 0, so this also clamps negative unmasked
        # entries of free_events up to 0.
        free_events = tf.maximum(inf_mask, free_events)
        free_events = tf.reduce_min(free_events, axis=-1)

        indices = tf.stack(
            [tf.range(events.shape[0], dtype=dtype), t[:, 0]], axis=-1
        )
        available_events = tf.gather_nd(target_events, indices)
        max_events = tf.minimum(free_events, available_events)
        max_events = tf.clip_by_value(
            max_events, clip_value_min=0, clip_value_max=n_max
        )

        # Draw x_star
        return UniformInteger(low=0, high=max_events + 1)
def x_star(m, t):
    """Draw num to add bounded by counting process constraint.

    Builds a ``UniformInteger`` for the number of ``topology.target`` events
    to add to metapopulation ``m`` at time ``t[0]``.

    If ``topology.prev`` exists, the draw is capped by the minimum, over times
    ``>= t[0]``, of the running value
    ``initial_state[target] + cumsum(events[prev] - events[target])``. The
    cumsum is inclusive, so adding ``x`` target events at ``t[0]`` lowers that
    value by ``x`` at every later time. The cap is further limited to
    ``n_max``. Without ``topology.prev``, only the ``n_max`` cap applies.

    ``UniformInteger``'s ``high`` is exclusive, so ``high=bound + 1`` lets the
    draw reach ``bound`` itself. This matches the delete and move proposals,
    and avoids an empty support when ``bound == 0``. Previously ``high=bound``
    was used, so with ``n_max == 1`` and no ``prev`` the draw was always 0.

    Args:
        m: metapopulation index tensor, gathered along axis -3 of ``events``
            (presumably shape ``[1]`` -- TODO confirm).
        t: time index tensor; only ``t[0]`` is used.

    Returns:
        ``UniformInteger(low=0, high=bound + 1, dtype=dtype)``.
    """
    with tf.name_scope("x_star"):
        if topology.prev is not None:
            mask = (  # Mask out times prior to t
                tf.cast(tf.range(events.shape[-2]) < t[0], events.dtype)
                * events.dtype.max
            )
            m_events = tf.gather(events, m, axis=-3)
            m_inits = tf.gather(initial_state, m, axis=-2)
            # Running value initial_state[target] + cumsum(N_prev - N_target);
            # its minimum over t' >= t[0] is the addition bound.
            diff = m_events[..., topology.prev] - m_events[..., topology.target]
            diff = tf.gather(m_inits, topology.target, axis=-1) + tf.cumsum(
                diff, axis=-1
            )
            diff = diff + mask
            bound = tf.cast(tf.reduce_min(diff, axis=-1), dtype=tf.int32)
            # NOTE(review): bound is not clamped at 0; a negative bound means
            # an infeasible state upstream.
            bound = tf.minimum(n_max, bound)
        else:
            bound = tf.broadcast_to(n_max, m.shape)
        # `high` is exclusive: +1 makes `bound` itself drawable.
        return UniformInteger(low=0, high=bound + 1, dtype=dtype)
def test_log_prob(self):
    """Fixture log-probs under ``UniformInteger(0, 10)`` sum to 9 * log(1/10)."""
    dist = UniformInteger(0, 10)
    log_probs = dist.log_prob(self.fixture)
    self.assertSequenceEqual(log_probs.shape, [3, 3])
    total = np.sum(log_probs)
    self.assertAlmostEqual(total, -20.723267, places=6)
def delta_t():
    """Draw one signed time shift per entry of ``events.shape[-3]``.

    Each shift is uniform on ``[-d_max, d_max]`` (``high`` is exclusive).
    """
    with tf.name_scope("delta_t"):
        half_width = tf.broadcast_to(d_max, [events.shape[-3]])
        upper = half_width + 1
        return UniformInteger(low=-half_width, high=upper)
def t():
    """Select a timepoint uniformly from ``[t_range[0], t_range[1])``."""
    with tf.name_scope("t"):
        start = t_range[0]
        stop = t_range[1]
        return UniformInteger(low=[start], high=[stop], dtype=dtype)
def m():
    """Select a metapopulation index uniformly from ``[0, events.shape[0])``."""
    with tf.name_scope("m"):
        lower = [0]
        upper = [events.shape[0]]
        return UniformInteger(low=lower, high=upper, dtype=dtype)