Ejemplo n.º 1
0
    def test_sample_n(self):
        """Check that a seeded 3x3 draw from UniformInteger(0, 10) equals the fixture."""
        tf.random.set_seed(10402302)
        X = UniformInteger(0, 10)
        x = X.sample([3, 3])

        # np.testing.assert_array_equal raises AssertionError on mismatch and
        # returns None on success; wrapping it in assertTrue would make the
        # test fail unconditionally (assertTrue(None) is a failure).
        np.testing.assert_array_equal(
            x, self.fixture, err_msg="UniformInteger sample inconsistent")
Ejemplo n.º 2
0
    def x_star(m, t):
        """Draw the number of events to delete.

        Args:
            m: metapopulation index/indices, gathered from `events` along
                axis -3 and from `initial_state` along axis -2.
            t: timepoint tensor; only `t[0]` is read.

        Returns:
            A `UniformInteger` named "x_star" with `low=0` and
            `high=bound + 1`, where `bound` is capped at `n_max`.
        """
        with tf.name_scope("x_star"):
            if topology.next is None:
                # No downstream transition, so only the n_max cap applies.
                bound = tf.broadcast_to(n_max, m.shape)
            else:
                # Times before t[0] get events.dtype.max added so that they
                # cannot be selected by the reduce_min below.
                before_t = tf.range(events.shape[-2]) < t[0]
                time_mask = tf.cast(before_t, events.dtype) * events.dtype.max
                events_m = tf.gather(events, m, axis=-3)
                inits_m = tf.gather(initial_state, m, axis=-2)
                # Initial count indexed by topology.next plus the running sum
                # of (target - next) events over the time axis.
                net_events = (events_m[..., topology.target] -
                              events_m[..., topology.next])
                headroom = (tf.gather(inits_m, topology.next, axis=-1) +
                            tf.cumsum(net_events, axis=-1) + time_mask)
                smallest = tf.reduce_min(headroom, axis=-1)
                # NOTE(review): bound is not clamped below at 0; confirm a
                # negative value cannot reach UniformInteger.
                bound = tf.minimum(n_max, tf.cast(smallest, dtype=tf.int32))

            return UniformInteger(low=0,
                                  high=bound + 1,
                                  dtype=dtype,
                                  name="x_star")
Ejemplo n.º 3
0
    def x_star(t, delta_t):
        """Draw x_star, bounded by available events at t and free events near t.

        Args:
            t: integer timepoints, one per row; row i is paired with index i
                of `events.shape[0]` when gathering from `target_events`.
            delta_t: integer offsets with the same shape as `t`.

        Returns:
            A `UniformInteger` with `low=0` and `high=max_events + 1`, where
            `max_events` is clipped to `[0, n_max]`.
        """
        with tf.name_scope("x_star"):
            # Compute bounds.
            # The limitations of XLA mean that we compute bounds over a
            # fixed-length window whose direction depends on the sign of
            # delta_t: [t+delta_t, t) if delta_t < 0, and [t, t+delta_t)
            # otherwise.
            t = t[..., tf.newaxis]
            delta_t = delta_t[..., tf.newaxis]
            # NOTE(review): assumes time_interval holds offsets 0..d_max-1;
            # confirm against its definition.
            bound_times = tf.where(
                delta_t < 0,
                t - time_interval - 1,  # delta_t < 0: [t+delta_t, t)
                t + time_interval,  # delta_t >= 0: [t, t+delta_t)
            )
            free_events = _abscumdiff(
                events=events,
                initial_state=initial_state,
                topology=topology,
                t=t,
                delta_t=delta_t,
                bound_times=bound_times,
                int_dtype=dtype,
            )

            # Mask out bits of the interval we don't need for our delta_t.
            # The one-hot places inf at offset |delta_t|. The cumsum must run
            # along the interval axis (axis=-1) so that every offset
            # >= |delta_t| in the same row becomes inf. The default axis=0
            # would instead accumulate across rows. tf.maximum then removes
            # those offsets from the reduce_min. With |delta_t| == d_max the
            # one-hot is all zeros, so the whole window is used.
            inf_mask = tf.cumsum(
                tf.one_hot(
                    tf.math.abs(delta_t[:, 0]),
                    d_max,
                    on_value=tf.constant(np.inf, events.dtype),
                    dtype=events.dtype,
                ),
                axis=-1,
            )
            free_events = tf.maximum(inf_mask, free_events)
            free_events = tf.reduce_min(free_events, axis=-1)
            # Events present at (row i, time t_i) cap x_star as well.
            indices = tf.stack(
                [tf.range(events.shape[0], dtype=dtype), t[:, 0]], axis=-1)
            available_events = tf.gather_nd(target_events, indices)
            max_events = tf.minimum(free_events, available_events)
            max_events = tf.clip_by_value(max_events,
                                          clip_value_min=0,
                                          clip_value_max=n_max)
            # Draw x_star
            return UniformInteger(low=0, high=max_events + 1)
Ejemplo n.º 4
0
    def x_star(m, t):
        """Draw the number of events to add, bounded by the counting-process constraint.

        Args:
            m: metapopulation index/indices, gathered from `events` along
                axis -3 and from `initial_state` along axis -2.
            t: timepoint tensor; only `t[0]` is read.

        Returns:
            A `UniformInteger` with `low=0` and `high=bound + 1`, i.e. a draw
            in `[0, bound]`, where `bound` is capped at `n_max`.
        """
        if topology.prev is not None:
            mask = (  # Mask out times prior to t
                tf.cast(tf.range(events.shape[-2]) < t[0], events.dtype) *
                events.dtype.max)
            m_events = tf.gather(events, m, axis=-3)
            m_inits = tf.gather(initial_state, m, axis=-2)
            # Initial count indexed by topology.target plus the running sum of
            # (prev - target) events; its minimum from t onwards is the bound.
            diff = m_events[..., topology.prev] - m_events[...,
                                                           topology.target]
            diff = tf.gather(m_inits, topology.target, axis=-1) + tf.cumsum(
                diff, axis=-1)
            diff = diff + mask
            bound = tf.cast(tf.reduce_min(diff, axis=-1), dtype=tf.int32)
            # NOTE(review): bound is not clamped below at 0; confirm a negative
            # value cannot reach UniformInteger.
            bound = tf.minimum(n_max, bound)
        else:
            bound = tf.broadcast_to(n_max, m.shape)

        # UniformInteger's `high` is exclusive (UniformInteger(0, 10) has ten
        # outcomes). `bound + 1` therefore makes `bound` itself reachable and
        # keeps the support non-empty when bound == 0.
        return UniformInteger(low=0, high=bound + 1, dtype=dtype)
Ejemplo n.º 5
0
 def test_log_prob(self):
     """Check shape and summed value of UniformInteger(0, 10).log_prob on the fixture."""
     dist = UniformInteger(0, 10)
     log_probs = dist.log_prob(self.fixture)
     self.assertSequenceEqual(log_probs.shape, [3, 3])
     # Nine entries, each presumably log(1/10) ~= -2.302585.
     # NOTE(review): exact 9 * log(1/10) is about -20.7232658, so the literal
     # below appears to encode float32 rounding; this check depends on the
     # dtype log_prob returns.
     total = np.sum(log_probs)
     self.assertAlmostEqual(total, -20.723267, places=6)
Ejemplo n.º 6
0
 def delta_t():
     """Draw a signed offset in [-d_max, d_max] for each index of events axis -3."""
     with tf.name_scope("delta_t"):
         num_units = events.shape[-3]
         half_width = tf.broadcast_to(d_max, [num_units])
         # high is exclusive, hence the +1 to include +d_max.
         return UniformInteger(low=-half_width, high=half_width + 1)
Ejemplo n.º 7
0
 def t():
     """Draw one timepoint from UniformInteger(t_range[0], t_range[1])."""
     with tf.name_scope("t"):
         lower, upper = [t_range[0]], [t_range[1]]
         return UniformInteger(low=lower, high=upper, dtype=dtype)
Ejemplo n.º 8
0
 def m():
     """Draw one metapopulation index from UniformInteger(0, events.shape[0])."""
     with tf.name_scope("m"):
         num_metapops = events.shape[0]
         return UniformInteger(low=[0], high=[num_metapops], dtype=dtype)