def switch_case(self, branch_selector, branch_callables, name=None):
  """Implements a switch (branch_selector) { case ... } construct."""
  with tf.name_scope('VM.switch_case'):
    with _control_flow_v2():
      return tf.switch_case(branch_selector, branch_callables, name=name)
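# A minimal, self-contained sketch (not part of the VM class above) of the
# tf.switch_case semantics this method wraps: `branch_index` picks exactly
# one callable to run, and `default` fires when the index is out of range.
import tensorflow as tf

def _switch_case_demo(selector):
  branches = [
      lambda: tf.constant(10),  # runs when selector == 0
      lambda: tf.constant(20),  # runs when selector == 1
  ]
  return tf.switch_case(selector, branches, default=lambda: tf.constant(-1))

# _switch_case_demo(tf.constant(1)) evaluates to 20;
# _switch_case_demo(tf.constant(7)) falls through to the default, -1.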
def distort(self, image: tf.Tensor) -> tf.Tensor:
  """Applies the RandAugment policy to `image`.

  Args:
    image: `Tensor` of shape [height, width, 3] representing an image.

  Returns:
    The augmented version of `image`.
  """
  input_image_type = image.dtype

  if input_image_type != tf.uint8:
    image = tf.clip_by_value(image, 0.0, 255.0)
    image = tf.cast(image, dtype=tf.uint8)

  replace_value = [128] * 3
  min_prob, max_prob = 0.2, 0.8

  for _ in range(self.num_layers):
    op_to_select = tf.random.uniform(
        [], maxval=len(self.available_ops) + 1, dtype=tf.int32)

    branch_fns = []
    for (i, op_name) in enumerate(self.available_ops):
      prob = tf.random.uniform([],
                               minval=min_prob,
                               maxval=max_prob,
                               dtype=tf.float32)
      func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
                                         replace_value, self.cutout_const,
                                         self.translate_const)
      branch_fns.append((
          i,
          # pylint:disable=g-long-lambda
          lambda selected_func=func, selected_args=args: selected_func(
              image, *selected_args)))
      # pylint:enable=g-long-lambda

    image = tf.switch_case(branch_index=op_to_select,
                           branch_fns=branch_fns,
                           default=lambda: tf.identity(image))

  image = tf.cast(image, dtype=input_image_type)
  return image
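# A minimal usage sketch. The constructor arguments shown (num_layers,
# magnitude) follow common RandAugment implementations but are assumptions
# about this particular class, not a documented API. Note that because the
# selector above is drawn with maxval=len(available_ops) + 1, the extra
# out-of-range index routes to `default`, which applies the identity op.
import tensorflow as tf

augmenter = RandAugment(num_layers=2, magnitude=10.)
image = tf.random.uniform([224, 224, 3], maxval=255., dtype=tf.float32)
augmented = augmenter.distort(image)  # dtype and shape match the input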
def _loop_build_sub_tree(self,
                         directions,
                         integrator,
                         log_slice_sample,
                         init_energy,
                         iter_,
                         energy_diff_sum_previous,
                         leapfrogs_taken,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         not_divergent_previous,
                         momentum_state_memory):
  """Base case in tree doubling."""
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence.
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)

    # If the tree has not terminated previously, count this leapfrog.
    leapfrogs_taken = tf.where(continue_tree_previous,
                               leapfrogs_taken + 1,
                               leapfrogs_taken)

    # Save state and momentum at odd steps; check for a U turn at even
    # steps. Note that we also write to a placeholder slot at even steps
    # to avoid using tf.cond.
    index = iter_ // 2
    if USE_RAGGED_TENSOR:
      write_index_ = self.write_instruction[index]
    else:
      write_index_ = tf.switch_case(index, self.write_instruction)

    write_index = tf.where(tf.equal(iter_ % 2, 0),
                           write_index_, self.max_tree_depth)

    if USE_TENSORARRAY:
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=[
              old.write(write_index, new) for old, new in
              zip(momentum_state_memory.momentum_swap, next_momentum_parts)
          ],
          state_swap=[
              old.write(write_index, new) for old, new in
              zip(momentum_state_memory.state_swap, next_state_parts)
          ])
    else:
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(momentum_state_memory.momentum_swap,
                                  next_momentum_parts)
          ],
          state_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(momentum_state_memory.state_swap,
                                  next_state_parts)
          ])

    batch_size = prefer_static.size(next_target)
    has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

    if USE_RAGGED_TENSOR:
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
              self.read_instruction, iter_ // 2, directions,
              momentum_state_memory, next_momentum_parts, next_state_parts))
    else:
      f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
          self.read_instruction, int_iter, directions, momentum_state_memory,
          next_momentum_parts, next_state_parts)
      branch_execution = {
          x: functools.partial(f, x)
          for x in range(len(self.read_instruction))
      }
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: tf.switch_case(iter_ // 2, branch_execution))

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    energy = tf.where(tf.math.is_nan(energy),
                      tf.constant(-np.inf, dtype=energy.dtype),
                      energy)
    energy_diff = energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))

    u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                         dtype=log_accept_thresh.dtype,
                                         seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(is_sample_accepted,
                                             prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(is_sample_accepted,
                        next_target,
                        candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(is_sample_accepted,
                                             prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    not_divergent_tokeep = tf.where(continue_tree_previous,
                                    not_divergent,
                                    tf.ones([batch_size], dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)

    return (
        iter_ + 1,
        energy_diff_sum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        momentum_state_memory,
    )
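# An illustrative, standalone sketch (names here are hypothetical, not from
# the library) of the write-index trick used above: instead of wrapping the
# write in tf.cond, the code always scatters, but routes writes that should
# be skipped to a throwaway slot one past the real storage.
import tensorflow as tf

def write_or_discard(memory, value, step, max_depth):
  # Real slots are [0, max_depth); slot max_depth is a scratch slot that is
  # never read back, so writes at odd `step` are effectively discarded.
  write_index = tf.where(tf.equal(step % 2, 0), step // 2, max_depth)
  return tf.tensor_scatter_nd_update(memory, [[write_index]], [value])

memory = tf.zeros([5])  # 4 real slots + 1 scratch slot
memory = write_or_discard(memory, tf.constant(1.), step=2, max_depth=4)
# step=2 writes real slot 1; step=3 would land in the scratch slot 4.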
def _loop_build_sub_tree(self,
                         direction,
                         log_slice_sample,
                         iter_,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         trace_arrays):
  """Base case in tree doubling."""
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence.
    directions_expanded = [
        _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
        for state in prev_tree_state.state
    ]
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(direction, ss, -ss)
            for direction, ss in zip(directions_expanded, self.step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)

    # Save state and momentum at odd steps; check for a U turn at even
    # steps. Note that we also write to a placeholder slot at even steps
    # to avoid using tf.cond.
    index = iter_ // 2
    if USE_RAGGED_TENSOR:
      write_index_ = self.write_instruction[index]
    else:
      write_index_ = tf.switch_case(index, self.write_instruction)

    write_index = tf.where(tf.equal(iter_ % 2, 0),
                           write_index_, self.max_tree_depth)

    if USE_TENSORARRAY:
      trace_arrays = TraceArrays(
          momentum_swap=[
              old.write(write_index, new)
              for old, new in zip(trace_arrays.momentum_swap,
                                  next_momentum_parts)
          ],
          state_swap=[
              old.write(write_index, new)
              for old, new in zip(trace_arrays.state_swap, next_state_parts)
          ])
    else:
      trace_arrays = TraceArrays(
          momentum_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(trace_arrays.momentum_swap,
                                  next_momentum_parts)
          ],
          state_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(trace_arrays.state_swap, next_state_parts)
          ])

    batch_size = prefer_static.size(next_target)
    has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

    if USE_RAGGED_TENSOR:
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
              self.read_instruction, iter_ // 2, directions_expanded,
              trace_arrays, next_momentum_parts, next_state_parts))
    else:
      f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
          self.read_instruction, int_iter, directions_expanded, trace_arrays,
          next_momentum_parts, next_state_parts)
      branch_execution = {
          x: functools.partial(f, x)
          for x in range(len(self.read_instruction))
      }
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: tf.switch_case(iter_ // 2, branch_execution))

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    valid_candidate = log_slice_sample <= energy

    # Uniform sampling on the trajectory within the subtree.
    sample_weight = tf.cast(valid_candidate, TREE_COUNT_DTYPE)
    weight_sum = candidate_tree_state.weight + sample_weight
    log_accept_thresh = tf.math.log(
        tf.cast(sample_weight, tf.float32) /
        tf.cast(weight_sum, tf.float32))
    log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                 tf.zeros([], log_accept_thresh.dtype),
                                 log_accept_thresh)
    u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                         dtype=tf.float32,
                                         seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(is_sample_accepted,
                                             prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(is_sample_accepted,
                        next_target,
                        candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(is_sample_accepted,
                                             prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        weight=weight_sum)

    not_divergent = log_slice_sample - energy < self.max_energy_diff
    continue_tree = not_divergent & no_u_turns_within_tree
    continue_tree_next = continue_tree_previous & continue_tree

    return (
        iter_ + 1,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        trace_arrays,
    )
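# An illustrative check (standalone, not library code) of the log-space
# Bernoulli trick used above: with U ~ Uniform(0, 1), log1p(-U) is the log
# of a Uniform(0, 1) variate, so P(log1p(-U) <= t) = exp(t) for t <= 0,
# i.e. the comparison accepts with probability exp(log_accept_thresh).
import tensorflow as tf

log_accept_thresh = tf.math.log(0.25)  # accept with probability 0.25
u = tf.math.log1p(-tf.random.uniform(shape=[100000]))
accept_rate = tf.reduce_mean(tf.cast(u <= log_accept_thresh, tf.float32))
# accept_rate is ~0.25 up to Monte Carlo error.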