Example #1
  def bootstrap_results(self, init_state):
    """Creates initial `previous_kernel_results` using a supplied `state`."""
    with tf.name_scope(self.name + '.bootstrap_results'):
      if not tf.nest.is_nested(init_state):
        init_state = [init_state]
      # Padding the step_size so it is compatible with the states: a single
      # step size is replicated once per state part; otherwise the counts
      # must match exactly.
      step_size = self.step_size
      if len(step_size) == 1:
        step_size = step_size * len(init_state)
      if len(step_size) != len(init_state):
        raise ValueError('Expected either one step size or {} (size of '
                         '`init_state`), but found {}'.format(
                             len(init_state), len(step_size)))

      # Unit "momentum" is only a placeholder with the same shapes/dtypes as
      # the state; it is used to size the swap memory and to evaluate the
      # initial target log prob and gradients below.
      dummy_momentum = [tf.ones_like(state) for state in init_state]

      def _init(shape_and_dtype):
        """Allocate TensorArray for storing state and momentum."""
        # One leading row per write-instruction slot; each entry matches the
        # shape and dtype of the corresponding state/momentum part.
        return [  # pylint: disable=g-complex-comprehension
            ps.zeros(
                ps.concat([[max(self._write_instruction) + 1], s], axis=0),
                dtype=d) for (s, d) in shape_and_dtype
        ]

      get_shapes_and_dtypes = lambda x: [(ps.shape(x_), x_.dtype)  # pylint: disable=g-long-lambda
                                         for x_ in x]
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=_init(get_shapes_and_dtypes(dummy_momentum)),
          state_swap=_init(get_shapes_and_dtypes(init_state)))
      # Evaluate the target log prob and its gradients at the initial state;
      # the first two returned values are unused here.
      [
          _,
          _,
          current_target_log_prob,
          current_grads_log_prob,
      ] = leapfrog_impl.process_args(self.target_log_prob_fn, dummy_momentum,
                                     init_state)

      # All per-chain statistics start zeroed, shaped like the target log prob.
      return NUTSKernelResults(
          target_log_prob=current_target_log_prob,
          grads_target_log_prob=current_grads_log_prob,
          momentum_state_memory=momentum_state_memory,
          step_size=step_size,
          log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                         name='log_accept_ratio'),
          leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                        dtype=TREE_COUNT_DTYPE,
                                        name='leapfrogs_taken'),
          is_accepted=tf.zeros_like(current_target_log_prob,
                                    dtype=tf.bool,
                                    name='is_accepted'),
          reach_max_depth=tf.zeros_like(current_target_log_prob,
                                        dtype=tf.bool,
                                        name='reach_max_depth'),
          has_divergence=tf.zeros_like(current_target_log_prob,
                                       dtype=tf.bool,
                                       name='has_divergence'),
          energy=compute_hamiltonian(current_target_log_prob, dummy_momentum),
          # Allow room for one_step's seed.
          seed=samplers.zeros_seed(),
      )
Example #2
  def bootstrap_results(self, init_state):
    """Creates initial `previous_kernel_results` using a supplied `state`."""
    with tf.name_scope(self.name + '.bootstrap_results'):
      if not tf.nest.is_nested(init_state):
        init_state = [init_state]
      # Padding the step_size so it is compatible with the states: a single
      # step size is replicated once per state part; otherwise the counts
      # must match exactly.
      step_size = self.step_size
      if len(step_size) == 1:
        step_size = step_size * len(init_state)
        # NOTE(review): stores the padded step size back on the kernel as a
        # side effect; kept for backward compatibility with existing callers.
        self._step_size = step_size
      if len(step_size) != len(init_state):
        raise ValueError('Expected either one step size or {} (size of '
                         '`init_state`), but found {}'.format(
                             len(init_state), len(step_size)))
      # Unit "momentum" is only a placeholder used to evaluate the initial
      # target log prob and its gradients; its values do not matter here.
      dummy_momentum = [tf.ones_like(state) for state in init_state]
      [
          _,
          _,
          current_target_log_prob,
          current_grads_log_prob,
      ] = leapfrog_impl.process_args(self.target_log_prob_fn,
                                     dummy_momentum,
                                     init_state)
      batch_size = prefer_static.size(current_target_log_prob)

      return NUTSKernelResults(
          target_log_prob=current_target_log_prob,
          grads_target_log_prob=current_grads_log_prob,
          leapfrogs_computed=tf.zeros([], dtype=tf.int32,
                                      name='leapfrogs_computed'),
          is_accepted=tf.zeros([batch_size], dtype=tf.bool,
                               name='is_accepted'),
          # Bug fix: this tensor was previously misnamed 'is_accepted',
          # duplicating the name of the tensor above.
          reach_max_depth=tf.zeros([batch_size], dtype=tf.bool,
                                   name='reach_max_depth'),
          )
Example #3
    def bootstrap_results(self, init_state):
        """Creates initial `previous_kernel_results` using a supplied `state`."""
        with tf.name_scope(self.name + '.bootstrap_results'):
            if not tf.nest.is_nested(init_state):
                init_state = [init_state]
            # Unit "momentum" acts only as a placeholder for evaluating the
            # target log prob and its gradients at the initial state.
            dummy_momentum = [tf.ones_like(part) for part in init_state]

            _, _, current_target_log_prob, current_grads_log_prob = (
                leapfrog_impl.process_args(
                    self.target_log_prob_fn, dummy_momentum, init_state))

            # Confirm that the step size is compatible with the state parts.
            _ = _prepare_step_size(self.step_size,
                                   current_target_log_prob.dtype,
                                   len(init_state))

            dtype = current_target_log_prob.dtype

            def _as_step_size(x):  # pylint: disable=g-complex-comprehension
                """Convert one step-size entry to a tensor of the target dtype."""
                return tf.convert_to_tensor(x, dtype=dtype, name='step_size')

            def _bool_stat(stat_name):
                """Zeroed boolean per-chain statistic shaped like the log prob."""
                return tf.zeros_like(current_target_log_prob,
                                     dtype=tf.bool,
                                     name=stat_name)

            return NUTSKernelResults(
                target_log_prob=current_target_log_prob,
                grads_target_log_prob=current_grads_log_prob,
                step_size=tf.nest.map_structure(_as_step_size, self.step_size),
                log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                               name='log_accept_ratio'),
                leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                              dtype=TREE_COUNT_DTYPE,
                                              name='leapfrogs_taken'),
                is_accepted=_bool_stat('is_accepted'),
                reach_max_depth=_bool_stat('reach_max_depth'),
                has_divergence=_bool_stat('has_divergence'),
                energy=compute_hamiltonian(
                    current_target_log_prob,
                    dummy_momentum,
                    shard_axis_names=self.experimental_shard_axis_names),
                # Allow room for one_step's seed.
                seed=samplers.zeros_seed(),
            )
Example #4
  def bootstrap_results(self, init_state):
    """Creates initial `previous_kernel_results` using a supplied `state`."""
    with tf.name_scope(self.name + '.bootstrap_results'):
      if not tf.nest.is_nested(init_state):
        init_state = [init_state]
      # Unit "momentum" is a placeholder with the same shapes/dtypes as the
      # state; it sizes the swap memory and feeds the initial log-prob call.
      dummy_momentum = [tf.ones_like(part) for part in init_state]

      # One leading row per write-instruction slot.
      num_rows = max(self._write_instruction) + 1

      def _alloc_memory(parts):
        """Allocate zeroed swap memory matching each part, with a leading
        `num_rows` axis."""
        return [
            ps.zeros(ps.concat([[num_rows], ps.shape(part)], axis=0),
                     dtype=part.dtype)
            for part in parts
        ]

      momentum_state_memory = MomentumStateSwap(
          momentum_swap=_alloc_memory(dummy_momentum),
          state_swap=_alloc_memory(init_state))

      _, _, current_target_log_prob, current_grads_log_prob = (
          leapfrog_impl.process_args(
              self.target_log_prob_fn, dummy_momentum, init_state))

      # Confirm that the step size is compatible with the state parts.
      _ = _prepare_step_size(
          self.step_size, current_target_log_prob.dtype, len(init_state))

      dtype = current_target_log_prob.dtype

      def _bool_stat(stat_name):
        """Zeroed boolean per-chain statistic shaped like the log prob."""
        return tf.zeros_like(current_target_log_prob,
                             dtype=tf.bool,
                             name=stat_name)

      return NUTSKernelResults(
          target_log_prob=current_target_log_prob,
          grads_target_log_prob=current_grads_log_prob,
          momentum_state_memory=momentum_state_memory,
          step_size=tf.nest.map_structure(
              lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                  x, dtype=dtype, name='step_size'),
              self.step_size),
          log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                         name='log_accept_ratio'),
          leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                        dtype=TREE_COUNT_DTYPE,
                                        name='leapfrogs_taken'),
          is_accepted=_bool_stat('is_accepted'),
          reach_max_depth=_bool_stat('reach_max_depth'),
          has_divergence=_bool_stat('has_divergence'),
          energy=compute_hamiltonian(
              current_target_log_prob, dummy_momentum,
              shard_axis_names=self.experimental_shard_axis_names),
          # Allow room for one_step's seed.
          seed=samplers.zeros_seed(),
      )
Example #5
    def bootstrap_results(self, init_state):
        """Creates initial `previous_kernel_results` using a supplied `state`."""
        with tf.name_scope(self.name + '.bootstrap_results'):
            if not tf.nest.is_nested(init_state):
                init_state = [init_state]
            # Padding the step_size so it is compatible with the states: a
            # single step size is replicated once per state part; otherwise
            # the counts must match exactly.
            step_size = self.step_size
            if len(step_size) == 1:
                step_size = step_size * len(init_state)
            if len(step_size) != len(init_state):
                raise ValueError(
                    'Expected either one step size or {} (size of '
                    '`init_state`), but found {}'.format(
                        len(init_state), len(step_size)))

            # Unit "momentum" is only a placeholder with the same
            # shapes/dtypes as the state; it sizes the swap memory and feeds
            # the initial log-prob evaluation below.
            dummy_momentum = [tf.ones_like(state) for state in init_state]

            def _init(shape_and_dtype):
                """Allocate TensorArray for storing state and momentum."""
                # Either backing store provides max_tree_depth + 1 rows, one
                # per tree level; the TensorArray path keeps entries readable
                # after the first read (clear_after_read=False).
                if USE_TENSORARRAY:
                    return [  # pylint: disable=g-complex-comprehension
                        tf.TensorArray(dtype=d,
                                       size=self.max_tree_depth + 1,
                                       element_shape=s,
                                       clear_after_read=False)
                        for (s, d) in shape_and_dtype
                    ]
                else:
                    return [  # pylint: disable=g-complex-comprehension
                        tf.zeros(tf.TensorShape([self.max_tree_depth + 1
                                                 ]).concatenate(s),
                                 dtype=d) for (s, d) in shape_and_dtype
                    ]

            # NOTE(review): uses static `x_.shape` (not a dynamic shape op),
            # so fully-defined static shapes appear to be assumed — confirm.
            get_shapes_and_dtypes = lambda x: [(x_.shape, x_.dtype)
                                               for x_ in x]
            momentum_state_memory = MomentumStateSwap(
                momentum_swap=_init(get_shapes_and_dtypes(dummy_momentum)),
                state_swap=_init(get_shapes_and_dtypes(init_state)))
            # Evaluate the target log prob and its gradients at the initial
            # state; the first two returned values are unused here.
            [
                _,
                _,
                current_target_log_prob,
                current_grads_log_prob,
            ] = leapfrog_impl.process_args(self.target_log_prob_fn,
                                           dummy_momentum, init_state)
            batch_size = prefer_static.size(current_target_log_prob)

            # All per-chain statistics start zeroed with one entry per chain.
            return NUTSKernelResults(
                target_log_prob=current_target_log_prob,
                grads_target_log_prob=current_grads_log_prob,
                momentum_state_memory=momentum_state_memory,
                step_size=step_size,
                log_accept_ratio=tf.zeros([batch_size],
                                          dtype=current_target_log_prob.dtype,
                                          name='log_accept_ratio'),
                leapfrogs_taken=tf.zeros([batch_size],
                                         dtype=TREE_COUNT_DTYPE,
                                         name='leapfrogs_taken'),
                is_accepted=tf.zeros([batch_size],
                                     dtype=tf.bool,
                                     name='is_accepted'),
                reach_max_depth=tf.zeros([batch_size],
                                         dtype=tf.bool,
                                         name='reach_max_depth'),
                has_divergence=tf.zeros([batch_size],
                                        dtype=tf.bool,
                                        name='has_divergence'),
            )