Example #1
0
 def switch_case(self, branch_selector, branch_callables, name=None):
     """Runs the branch of `branch_callables` picked by `branch_selector`.

     Emulates a C-style `switch (branch_selector) { case ... }` statement by
     delegating to `tf.switch_case` inside a 'VM.switch_case' name scope while
     `_control_flow_v2()` is active.

     Args:
       branch_selector: Forwarded to `tf.switch_case` as `branch_index`.
       branch_callables: Forwarded to `tf.switch_case` as `branch_fns`.
       name: Optional op name forwarded to `tf.switch_case`.

     Returns:
       Whatever `tf.switch_case` returns for the selected branch.
     """
     vm_scope = tf.name_scope('VM.switch_case')
     with vm_scope, _control_flow_v2():
         return tf.switch_case(branch_index=branch_selector,
                               branch_fns=branch_callables,
                               name=name)
Example #2
0
    def distort(self, image: tf.Tensor) -> tf.Tensor:
        """Applies the RandAugment policy to `image`.

        For each of `self.num_layers` layers, an index is drawn uniformly from
        `[0, len(self.available_ops)]`. Each index below the op count runs that
        op. The extra top index matches no branch, so the image passes through
        unchanged via the `tf.identity` default.

        Args:
          image: `Tensor` of shape [height, width, 3] representing an image.

        Returns:
          The augmented version of `image`, cast back to its input dtype.
        """
        source_dtype = image.dtype

        if source_dtype != tf.uint8:
            # Clip before casting so out-of-range pixels saturate, not wrap.
            image = tf.cast(tf.clip_by_value(image, 0.0, 255.0), dtype=tf.uint8)

        fill_value = [128] * 3
        prob_low, prob_high = 0.2, 0.8

        def make_branch(op_name):
            """Samples a probability for `op_name` and returns a zero-arg branch."""
            op_prob = tf.random.uniform([],
                                        minval=prob_low,
                                        maxval=prob_high,
                                        dtype=tf.float32)
            op_fn, _, op_args = _parse_policy_info(op_name, op_prob,
                                                   self.magnitude,
                                                   fill_value,
                                                   self.cutout_const,
                                                   self.translate_const)
            # op_fn/op_args are fresh locals per call; `image` is read when the
            # branch is invoked by tf.switch_case.
            return lambda: op_fn(image, *op_args)

        for _ in range(self.num_layers):
            # `maxval` is exclusive, so the largest index selects the default.
            chosen = tf.random.uniform([],
                                       maxval=len(self.available_ops) + 1,
                                       dtype=tf.int32)
            branches = {
                position: make_branch(op_name)
                for position, op_name in enumerate(self.available_ops)
            }
            image = tf.switch_case(branch_index=chosen,
                                   branch_fns=branches,
                                   default=lambda: tf.identity(image))

        return tf.cast(image, dtype=source_dtype)
    def _loop_build_sub_tree(self, directions, integrator, log_slice_sample,
                             init_energy, iter_, energy_diff_sum_previous,
                             leapfrogs_taken, prev_tree_state,
                             candidate_tree_state, continue_tree_previous,
                             not_divergent_previous, momentum_state_memory):
        """Base case in tree doubling: advance the trajectory by one step.

        Calls `integrator` once on `prev_tree_state` and stores the new momentum
        and state in `momentum_state_memory`. It then checks for a U-turn and for
        divergence, and may replace the candidate sample with the new point.

        Args:
          directions: Forwarded to `has_not_u_turn_at_odd_step` for the U-turn
            check (presumably the trajectory direction per batch member --
            TODO confirm).
          integrator: Callable taking `(momentum, state, target,
            target_grad_parts)` and returning the next values of the same four.
          log_slice_sample: Slice threshold compared with `energy_diff`; only
            used when `MULTINOMIAL_SAMPLE` is False.
          init_energy: Reference energy; `energy_diff = energy - init_energy`.
          iter_: Step counter within the subtree. On even `iter_` the new point
            is saved to memory; on odd `iter_` a U-turn check is performed.
          energy_diff_sum_previous: Running sum of `min(1, exp(energy_diff))`,
            accumulated only where `continue_tree` holds.
          leapfrogs_taken: Step counter, incremented where
            `continue_tree_previous` is True.
          prev_tree_state: `TreeDoublingState` at the current trajectory end.
          candidate_tree_state: `TreeDoublingStateCandidate` holding the current
            proposal and its accumulated weight.
          continue_tree_previous: Boolean mask of batch members whose tree was
            still being extended before this step.
          not_divergent_previous: Boolean mask of batch members with no
            divergence recorded so far.
          momentum_state_memory: `MomentumStateSwap` holding saved momentum and
            state parts used by the U-turn check.

        Returns:
          Tuple `(iter_ + 1, energy_diff_sum, leapfrogs_taken, next_tree_state,
          next_candidate_tree_state, continue_tree_next,
          not_divergent_previous & not_divergent_tokeep, momentum_state_memory)`.
          The order matches the order of the corresponding arguments
          (presumably the next loop state -- confirm against the caller).
        """
        with tf.name_scope('loop_build_sub_tree'):
            # Take one integrator step from the current end of the trajectory.
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            # If the tree has not terminated yet, count this leapfrog step.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            # Save state and momentum when `iter_` is even (the 1st, 3rd, ...
            # step) and check for a U-turn when `iter_` is odd. On odd `iter_`
            # the write still happens, but to index `self.max_tree_depth` (a
            # placeholder slot), which avoids wrapping the write in a tf.cond.
            index = iter_ // 2
            if USE_RAGGED_TENSOR:
                write_index_ = self.write_instruction[index]
            else:
                write_index_ = tf.switch_case(index, self.write_instruction)

            write_index = tf.where(tf.equal(iter_ % 2, 0), write_index_,
                                   self.max_tree_depth)

            if USE_TENSORARRAY:
                # TensorArray storage: write the new parts at `write_index`.
                momentum_state_memory = MomentumStateSwap(
                    momentum_swap=[
                        old.write(write_index, new) for old, new in zip(
                            momentum_state_memory.momentum_swap,
                            next_momentum_parts)
                    ],
                    state_swap=[
                        old.write(write_index, new) for old, new in zip(
                            momentum_state_memory.state_swap, next_state_parts)
                    ])
            else:
                # Dense-tensor storage: overwrite row `write_index` of each part.
                momentum_state_memory = MomentumStateSwap(
                    momentum_swap=[
                        tf.tensor_scatter_nd_update(old, [[write_index]],
                                                    [new])
                        for old, new in zip(
                            momentum_state_memory.momentum_swap,
                            next_momentum_parts)
                    ],
                    state_swap=[
                        tf.tensor_scatter_nd_update(old, [[write_index]],
                                                    [new]) for old, new in
                        zip(momentum_state_memory.state_swap, next_state_parts)
                    ])
            batch_size = prefer_static.size(next_target)
            # No U-turn check is done on even `iter_`: use an all-True mask.
            has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

            if USE_RAGGED_TENSOR:
                no_u_turns_within_tree = tf.cond(
                    tf.equal(iter_ % 2, 0),
                    lambda: has_not_u_turn_at_even_step,
                    lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                        self.read_instruction, iter_ // 2, directions,
                        momentum_state_memory, next_momentum_parts,
                        next_state_parts))
            else:
                # Without ragged tensors, dispatch on `iter_ // 2` through
                # tf.switch_case. There is one branch per entry of
                # `self.read_instruction`, each bound to its Python int index.
                f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                    self.read_instruction, int_iter, directions,
                    momentum_state_memory, next_momentum_parts,
                    next_state_parts)
                branch_excution = {
                    x: functools.partial(f, x)
                    for x in range(len(self.read_instruction))
                }
                no_u_turns_within_tree = tf.cond(
                    tf.equal(iter_ % 2,
                             0), lambda: has_not_u_turn_at_even_step,
                    lambda: tf.switch_case(iter_ // 2, branch_excution))

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            # Treat NaN energies as -inf.
            energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype), energy)
            energy_diff = energy - init_energy

            if MULTINOMIAL_SAMPLE:
                # Multinomial sampling: `weight` accumulates energy_diff in log
                # space (assuming `log_add_exp(a, b)` is log(exp(a) + exp(b))).
                # The new point is therefore accepted with probability
                # exp(energy_diff - weight_sum).
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across
                # valid samples: `weight` counts valid points. A valid new point
                # is accepted with probability 1 / weight_sum; an invalid one
                # gets a -inf threshold and is never accepted.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            # u = log(1 - U) with U uniform, so `u <= log_accept_thresh` is
            # True with probability exp(log_accept_thresh).
            u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            # Per batch member, take the new point where accepted; otherwise
            # keep the existing candidate.
            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(is_sample_accepted,
                                                     prefer_static.rank(s0)),
                        s0, s1) for s0, s1 in zip(next_state_parts,
                                                  candidate_tree_state.state)
                ],
                target=tf.where(is_sample_accepted, next_target,
                                candidate_tree_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(
                            is_sample_accepted, prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            next_target_grad_parts,
                            candidate_tree_state.target_grad_parts)
                ],
                weight=weight_sum)

            # The tree keeps growing only while it is neither divergent nor
            # U-turning, and had not already stopped.
            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            # Members that had already stopped get True here, so this step
            # cannot mark them divergent.
            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                tf.ones([batch_size], dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                energy_diff_sum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
Example #4
0
  def _loop_build_sub_tree(
      self, direction, log_slice_sample,
      iter_, prev_tree_state, candidate_tree_state,
      continue_tree_previous, trace_arrays):
    """Base case in tree doubling: advance the trajectory by one step.

    Builds a `SimpleLeapfrogIntegrator` whose step sizes are negated where
    `direction` is False, and advances `prev_tree_state` with one integrator
    call (of `self.unrolled_leapfrog_steps` steps). The new momentum and state
    are recorded in `trace_arrays`. It then checks for a U-turn and for
    divergence, and may replace the candidate sample with the new point.

    Args:
      direction: Boolean tensor. Where True the step size is used as-is;
        where False it is negated.
      log_slice_sample: Slice threshold; a new point is a valid candidate
        when `log_slice_sample <= energy`.
      iter_: Step counter within the subtree. On even `iter_` the new point is
        saved to memory; on odd `iter_` a U-turn check is performed.
      prev_tree_state: `TreeDoublingState` at the current trajectory end.
      candidate_tree_state: `TreeDoublingStateCandidate` holding the current
        proposal and its weight. This step adds 1 to the weight for a valid
        new point.
      continue_tree_previous: Boolean mask of batch members whose tree was
        still being extended before this step.
      trace_arrays: `TraceArrays` holding saved momentum and state parts used
        by the U-turn check.

    Returns:
      Tuple `(iter_ + 1, next_tree_state, next_candidate_tree_state,
      continue_tree_next, trace_arrays)`.
    """
    with tf.name_scope('loop_build_sub_tree'):
      # Take one integrator step along `direction`; divergence is checked below.
      directions_expanded = [
          _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
          for state in prev_tree_state.state]
      # NOTE: the comprehension variable `direction` below shadows the argument;
      # inside it, `direction` is the expanded per-part direction.
      integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
          self.target_log_prob_fn,
          step_sizes=[tf.where(direction, ss, -ss)
                      for direction, ss in zip(
                          directions_expanded, self.step_size)],
          num_steps=self.unrolled_leapfrog_steps)
      [
          next_momentum_parts,
          next_state_parts,
          next_target,
          next_target_grad_parts
      ] = integrator(prev_tree_state.momentum,
                     prev_tree_state.state,
                     prev_tree_state.target,
                     prev_tree_state.target_grad_parts)

      next_tree_state = TreeDoublingState(
          momentum=next_momentum_parts,
          state=next_state_parts,
          target=next_target,
          target_grad_parts=next_target_grad_parts)

      # Save state and momentum when `iter_` is even (the 1st, 3rd, ... step)
      # and check for a U-turn when `iter_` is odd. On odd `iter_` the write
      # still happens, but to index `self.max_tree_depth` (a placeholder slot),
      # which avoids wrapping the write in a tf.cond.
      index = iter_ // 2
      if USE_RAGGED_TENSOR:
        write_index_ = self.write_instruction[index]
      else:
        write_index_ = tf.switch_case(index, self.write_instruction)

      write_index = tf.where(tf.equal(iter_ % 2, 0),
                             write_index_, self.max_tree_depth)

      if USE_TENSORARRAY:
        # TensorArray storage: write the new parts at `write_index`.
        trace_arrays = TraceArrays(
            momentum_swap=[
                old.write(write_index, new) for old, new in
                zip(trace_arrays.momentum_swap, next_momentum_parts)],
            state_swap=[
                old.write(write_index, new) for old, new in
                zip(trace_arrays.state_swap, next_state_parts)])
      else:
        # Dense-tensor storage: overwrite row `write_index` of each part.
        trace_arrays = TraceArrays(
            momentum_swap=[
                tf.tensor_scatter_nd_update(old, [[write_index]], [new])
                for old, new in zip(
                    trace_arrays.momentum_swap, next_momentum_parts)],
            state_swap=[
                tf.tensor_scatter_nd_update(old, [[write_index]], [new])
                for old, new in zip(
                    trace_arrays.state_swap, next_state_parts)])
      batch_size = prefer_static.size(next_target)
      # No U-turn check is done on even `iter_`: use an all-True mask.
      has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

      if USE_RAGGED_TENSOR:
        no_u_turns_within_tree = tf.cond(
            tf.equal(iter_ % 2, 0),
            lambda: has_not_u_turn_at_even_step,
            lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                self.read_instruction, iter_ // 2, directions_expanded,
                trace_arrays, next_momentum_parts, next_state_parts))
      else:
        # Without ragged tensors, dispatch on `iter_ // 2` through
        # tf.switch_case. There is one branch per entry of
        # `self.read_instruction`, each bound to its Python int index.
        f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
            self.read_instruction, int_iter, directions_expanded, trace_arrays,
            next_momentum_parts, next_state_parts)
        branch_excution = {x: functools.partial(f, x)
                           for x in range(len(self.read_instruction))}
        no_u_turns_within_tree = tf.cond(
            tf.equal(iter_ % 2, 0),
            lambda: has_not_u_turn_at_even_step,
            lambda: tf.switch_case(iter_ // 2, branch_excution))

      energy = compute_hamiltonian(next_target, next_momentum_parts)
      valid_candidate = log_slice_sample <= energy

      # Uniform sampling on the trajectory within the subtree: a valid new
      # point is accepted with probability sample_weight / weight_sum.
      sample_weight = tf.cast(valid_candidate, TREE_COUNT_DTYPE)
      weight_sum = candidate_tree_state.weight + sample_weight
      log_accept_thresh = tf.math.log(
          tf.cast(sample_weight, tf.float32) /
          tf.cast(weight_sum, tf.float32))
      # 0 / 0 (no valid point seen yet) yields NaN; map it to a threshold of 0.
      # Because u <= 0 below, this always accepts the new point.
      log_accept_thresh = tf.where(
          tf.math.is_nan(log_accept_thresh),
          tf.zeros([], log_accept_thresh.dtype),
          log_accept_thresh)
      # u = log(1 - U) with U uniform, so `u <= log_accept_thresh` is True
      # with probability exp(log_accept_thresh).
      u = tf.math.log1p(-tf.random.uniform(
          shape=[batch_size],
          dtype=tf.float32,
          seed=self._seed_stream()))
      is_sample_accepted = u <= log_accept_thresh

      # Per batch member, take the new point where accepted; otherwise keep
      # the existing candidate.
      next_candidate_tree_state = TreeDoublingStateCandidate(
          state=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  _expand_dims_under_batch_dim(
                      is_sample_accepted, prefer_static.rank(s0)), s0, s1)
              for s0, s1 in zip(next_state_parts,
                                candidate_tree_state.state)
          ],
          target=tf.where(is_sample_accepted,
                          next_target,
                          candidate_tree_state.target),
          target_grad_parts=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  _expand_dims_under_batch_dim(
                      is_sample_accepted, prefer_static.rank(grad0)),
                  grad0, grad1)
              for grad0, grad1 in zip(next_target_grad_parts,
                                      candidate_tree_state.target_grad_parts)
          ],
          weight=weight_sum)

      # Divergence is measured against the slice threshold. The tree keeps
      # growing only while it is neither divergent nor U-turning, and had not
      # already stopped.
      not_divergent = log_slice_sample - energy < self.max_energy_diff
      continue_tree = not_divergent & no_u_turns_within_tree
      continue_tree_next = continue_tree_previous & continue_tree

      return (
          iter_ + 1,
          next_tree_state,
          next_candidate_tree_state,
          continue_tree_next,
          trace_arrays,
      )