Exemplo n.º 1
0
 def test_choose_from(self):
     """Verify `util.choose_from` for in-range and out-of-range indices.

     The test expects indices 0, 1 and 2 to return the matching entry of
     `options`, an index of -10 to return the first entry, and an index of
     10 to return the last entry. Index 2 is passed as a plain Python int;
     all other indices are passed as `tf.constant`s.
     """
     # Each option is a nested [dict, tuple] structure mixing floats and ints.
     options = [[{'value': float(k)}, (float(k + 1), k + 2)]
                for k in (1, 2, 3)]
     # (index handed to choose_from, position in `options` expected back).
     cases = (
         (tf.constant(0), 0),
         (tf.constant(1), 1),
         (2, 2),
         (tf.constant(-10), 0),
         (tf.constant(10), 2),
     )
     # Perform every selection before running any assertion.
     selections = [util.choose_from(index, options) for index, _ in cases]
     for chosen, (_, position) in zip(selections, cases):
         self.assertAllEqualNested(chosen, options[position], check_types=True)
Exemplo n.º 2
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Run one inner-kernel step, reconfiguring adaptation at window starts.

        The counter in `previous_kernel_results.step` is compared against
        window boundaries derived from `self.num_adaptation_steps` through
        `_get_window_sizes`. When the counter sits exactly on a boundary, the
        inner kernel results are rewritten before the inner step is taken;
        otherwise they are handed to the inner kernel unchanged.

        Args:
          current_state: Current chain state, forwarded to the inner kernel.
          previous_kernel_results: Results of the previous call; must expose
            `inner_results` and `step` and support `_replace`.
          seed: Optional seed forwarded to `self.inner_kernel.one_step`.

        Returns:
          A tuple `(new_state, kernel_results)`, where `kernel_results` is
          `previous_kernel_results` with `inner_results` set to the inner
          kernel's new results and `step` incremented by one.
        """
        # Results of `self.inner_kernel` (the mass matrix adaptation layer,
        # per the comments in this method); its own `inner_results` belong to
        # `self.inner_kernel.inner_kernel` (the step size adaptation layer).
        mass_results_prev = previous_kernel_results.inner_results
        step = previous_kernel_results.step
        total_adaptation_steps = tf.cast(self.num_adaptation_steps,
                                         dtype=tf.int32)
        window_sizes = _get_window_sizes(total_adaptation_steps)
        first_window_size, slow_window_size, last_window_size = window_sizes

        def restart_step_size_adaptation(mass_results, window_size):
            # Reset step size adaptation: push the latest `new_step_size` into
            # the innermost results, re-bootstrap the step size layer from
            # them, and give it `window_size` adaptation steps.
            step_size_kernel = self.inner_kernel.inner_kernel
            innermost_results = step_size_kernel.step_size_setter_fn(
                mass_results.inner_results.inner_results,
                mass_results.inner_results.new_step_size)
            step_size_results = step_size_kernel._bootstrap_from_inner_results(  # pylint: disable=protected-access
                current_state, innermost_results)
            return step_size_results._replace(num_adaptation_steps=window_size)

        def at_first_fast_window():
            # Step size adaptation spans the first fast window plus the first
            # slow window.
            step_size_results = mass_results_prev.inner_results._replace(
                num_adaptation_steps=first_window_size + slow_window_size)
            return mass_results_prev._replace(
                inner_results=step_size_results,
                # Skip mass matrix adaptation.
                num_estimation_steps=tf.constant(-1, dtype=tf.int32))

        def at_first_slow_window():
            # Start mass matrix adaptation.
            return mass_results_prev._replace(
                step=tf.constant(0, dtype=tf.int32),
                num_estimation_steps=slow_window_size)

        def at_later_slow_window():
            # At a start offset of (2**i - 1) * slow_window_size past the first
            # window, this evaluates to 2**i * slow_window_size.
            window_size = step - first_window_size + slow_window_size
            # Reset mass matrix adaptation.
            mass_results = self.inner_kernel._bootstrap_from_inner_results(  # pylint: disable=protected-access
                current_state, mass_results_prev.inner_results)
            return mass_results._replace(
                inner_results=restart_step_size_adaptation(
                    mass_results, window_size),
                num_estimation_steps=window_size)

        def at_last_window():
            # Only the step size adaptation is reset here.
            return mass_results_prev._replace(
                inner_results=restart_step_size_adaptation(
                    mass_results_prev, last_window_size))

        starts_first_fast = tf.equal(step, tf.constant(0, dtype=tf.int32))
        starts_first_slow = tf.equal(step, first_window_size)
        # Currently, we use 4 slow windows in the function _get_window_sizes.
        num_slow_windows = 4
        later_slow_offsets = tf.constant(
            [2**i - 1 for i in range(1, num_slow_windows)], dtype=tf.int32)
        starts_later_slow = tf.reduce_any(
            tf.equal(step,
                     first_window_size + slow_window_size * later_slow_offsets))
        starts_last = tf.equal(
            step,
            first_window_size + (2**num_slow_windows - 1) * slow_window_size)

        # Encode which boundary (if any) was hit as an integer code:
        # 0 = none, 1 = first fast, 2 = first slow, 3 = later slow, 4 = last.
        # NOTE(review): if two boundaries ever coincide (e.g. a zero-sized
        # window) the codes add up and a different entry is selected --
        # confirm `_get_window_sizes` never yields zero-sized windows.
        option = tf.cast(starts_first_fast, dtype=tf.int32)
        weighted_flags = (
            (2, starts_first_slow),
            (3, starts_later_slow),
            (4, starts_last),
        )
        for code, flag in weighted_flags:
            option = option + tf.cast(flag, dtype=tf.int32) * code

        # Every candidate is built up front; entry 0 is the untouched results
        # and entries 1-4 line up with the codes above.
        candidates = [
            mass_results_prev,
            at_first_fast_window(),
            at_first_slow_window(),
            at_later_slow_window(),
            at_last_window(),
        ]
        selected_results = mcmc_util.choose_from(option, candidates)

        new_state, new_inner_results = self.inner_kernel.one_step(
            current_state, selected_results, seed=seed)
        return new_state, previous_kernel_results._replace(
            inner_results=new_inner_results, step=step + 1)