def route_block_pred(x, nexpert, route, expert):
    """Route ``x`` to one of ``nexpert`` expert towers and return their outputs.

    A routing tower (``route``) is applied to ``x``; ``dense_block`` turns its
    flattened output into ``nexpert`` routing logits. At train time an expert
    index is sampled from ``Categorical(logits=...)``; at test time the arg-max
    index is used. Each expert tower (``expert``) is built from ``x`` and ends
    in ``dense_block(flat, [f, 100])``. ``tf.switch_case`` selects the output of
    the chosen expert.

    Args:
        x: Input tensor fed to the routing tower and to every expert tower.
        nexpert: Number of expert towers to build.
        route: Non-empty iterable of ``(c, f, p)`` tuples passed to ``block``
            for the routing tower; the final ``f`` is reused as the flattened
            width.
        expert: Non-empty iterable of ``(c, f, p)`` tuples describing each
            expert tower (same reuse of the final ``f``).

    Returns:
        ``(train_out, test_out, train_idx, test_idx)``: the sampled-expert
        output, the arg-max-expert output, and the two selected indices.
    """
    blocks = [x]
    for (c, f, p) in route:
        blocks.append(block(blocks[-1], c, f, p))
    # NOTE(review): assumes the last block holds exactly `f` elements
    # (batch of one) so it can be reshaped to [1, f] — confirm against `block`.
    flat = tf.reshape(blocks[-1], [1, f])
    # Named `logits` so the `route` parameter is not shadowed.
    logits = tf.squeeze(dense_block(flat, [f, nexpert]))
    pi = tf.distributions.Categorical(logits=logits)
    train_idx = tf.squeeze(pi.sample(1))
    test_idx = tf.cast(tf.argmax(logits), dtype=tf.int32)

    experts = []
    for _ in range(nexpert):
        expert_blocks = [x]
        for (c, f, p) in expert:
            expert_blocks.append(block(expert_blocks[-1], c, f, p))
        flat = tf.reshape(expert_blocks[-1], [1, f])
        experts.append(dense_block(flat, [f, 100]))

    # Bind `e` as a default argument. A bare `lambda: experts[e]` closes over
    # the loop variable, so every branch would return the LAST expert.
    branch_fns = {e: (lambda e=e: experts[e]) for e in range(nexpert)}
    train_out = tf.switch_case(branch_index=train_idx, branch_fns=branch_fns)
    test_out = tf.switch_case(branch_index=test_idx, branch_fns=branch_fns)
    return train_out, test_out, train_idx, test_idx
def switch_case(index, values, default=None, name='switch_case'):
    """Switch/Case over integer indices, for tensors and plain Python values.

    Parameters
    ----------
    index : () tensor_like
        Index to switch. If it is a tensor, ``tf.switch_case`` is used;
        otherwise the matching branch is selected eagerly in Python.
    values : dict(index: callable) or sequence of callable
        Pairs of index/function to execute. A plain sequence is treated as a
        mapping from position to function, as ``tf.switch_case`` does.
    default : callable, optional
        Default function to call if no key matches ``index``.
    name : str, default='switch_case'
        Name for the operation (tensor path only).

    Returns
    -------
    value : tensor or array
        Result of the selected branch.

    Raises
    ------
    IndexError
        Python path only: ``index`` matches no key and ``default`` is None.
    """
    if tf.is_tensor(index):
        return tf.switch_case(index, values, default, name)
    items = values.items() if hasattr(values, 'items') else enumerate(values)
    for key, item in items:
        if index == key:
            return item()
    # Honour `default` in the eager path too (it was previously ignored).
    if default is not None:
        return default()
    raise IndexError('{} is not a key from `values`.'.format(index))
def _apply_one_layer(self, image):
    """Applies one level of augmentation to the image.

    Every op name in IMAGENET_AUG_OPS becomes one ``tf.switch_case`` branch,
    and a uniformly sampled int32 index picks which one runs. When
    ``self.prob_to_apply`` is set, the augmented image is kept only with that
    probability and the untouched ``image`` is returned otherwise.
    """
    level = self._get_level()

    def _make_branch(op_name):
        # A factory gives each branch its own op / arg-mapper bindings.
        op_fn = augment_ops.NAME_TO_FUNC[op_name]
        to_args = LEVEL_TO_ARG[op_name]

        def _branch(img=image):
            return op_fn(*([img] + list(to_args(level))))

        return _branch

    branches = [_make_branch(op_name) for op_name in IMAGENET_AUG_OPS]
    selector = tf.random.uniform(
        shape=[], maxval=len(branches), dtype=tf.int32)
    augmented = tf.switch_case(selector, branches, default=lambda: image)

    if self.prob_to_apply is None:
        return augmented
    coin = tf.random.uniform(shape=[], dtype=tf.float32)
    return tf.cond(coin < self.prob_to_apply,
                   lambda: augmented,
                   lambda: image)
def _apply_one_layer(self, data):
    """Applies one level of augmentation to the data.

    Each op in IMAGENET_AUG_OPS becomes a ``tf.switch_case`` branch selected
    by a uniformly sampled int32 index. Ops whose final positional parameter
    is named ``replace`` additionally receive ``self.replace`` as their last
    argument. When ``self.prob_to_apply`` is set, the augmentation is kept
    only with that probability; otherwise ``data`` is returned unchanged.
    """
    level = self._get_level()
    branch_fns = []
    for augment_op_name in IMAGENET_AUG_OPS:
        augment_fn = NAME_TO_FUNC[augment_op_name]
        level_to_args_fn = LEVEL_TO_ARG[augment_op_name]
        # Introspect once per op, outside the branch body. `replace` is only
        # appended when it is the op's LAST parameter, so the extra positional
        # argument lines up with it.
        fn_args = inspect.getfullargspec(augment_fn).args
        needs_replace = bool(fn_args) and fn_args[-1] == 'replace'

        def _branch_fn(data=data,
                       augment_fn=augment_fn,
                       level_to_args_fn=level_to_args_fn,
                       needs_replace=needs_replace):
            args = [data] + list(level_to_args_fn(level))
            if needs_replace:
                args.append(self.replace)
            return augment_fn(*args)

        branch_fns.append(_branch_fn)

    branch_index = tf.random.uniform(
        shape=[], maxval=len(branch_fns), dtype=tf.int32)
    aug_data = tf.switch_case(branch_index, branch_fns, default=lambda: data)
    if self.prob_to_apply is not None:
        return tf.cond(
            tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply,
            lambda: aug_data,
            lambda: data)
    return aug_data
def _rotate():
    """Rotation by a random number of quarter turns.

    A count ``k`` is drawn uniformly from {1, 2, 3}. ``data['image']`` and
    ``data['entity_id_mask']`` go through ``tf.image.rot90(..., k=k)``, and
    ``data['groundtruth_boxes']`` goes through ``utilities.rotate_rboxes90``
    with ``rotation_count=k``. ``data`` comes from the enclosing scope.
    TODO(longshangbang): rotate vertices.

    Returns:
        Tuple ``(rotated_image, rotated_boxes, rotated_mask)``.
    """
    k = tf.random.uniform([], 1, 4, dtype=tf.int32)
    height, width, _ = utilities.resolve_shape(data['image'])

    image = tf.image.rot90(data['image'], k=k, name='image_rot90k')

    rotate_boxes = functools.partial(
        utilities.rotate_rboxes90,
        rboxes=data['groundtruth_boxes'],
        image_width=width,
        image_height=height)
    # `k` lies in [1, 3], so `k - 1` is the zero-based branch index; branch j
    # uses rotation_count=j + 1.
    boxes = tf.switch_case(
        k - 1,
        branch_fns=[functools.partial(rotate_boxes, rotation_count=count)
                    for count in (1, 2, 3)])

    mask = tf.image.rot90(data['entity_id_mask'], k=k, name='mask_rot90k')
    return image, boxes, mask
def random_resize_pad(images, height, width):
    """Resize-with-pad ``images`` to ``height`` x ``width`` using a random method.

    One of seven ``tf.image.ResizeMethod`` values is chosen uniformly per call
    and passed to ``tf.image.resize_with_pad``.
    """
    candidates = (
        tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        tf.image.ResizeMethod.BICUBIC,
        tf.image.ResizeMethod.AREA,
        tf.image.ResizeMethod.LANCZOS3,
        tf.image.ResizeMethod.LANCZOS5,
        tf.image.ResizeMethod.MITCHELLCUBIC,
        tf.image.ResizeMethod.GAUSSIAN,
    )
    branch_fns = {
        position: (lambda chosen=chosen: tf.image.resize_with_pad(
            images, height, width, method=chosen))
        for position, chosen in enumerate(candidates)
    }
    # Scale a uniform [0, 1) float by the branch count and truncate to int32.
    selector = tf.cast(
        tf.random.uniform([], 0, 1.0) * len(branch_fns), tf.int32)
    return tf.switch_case(selector, branch_fns=branch_fns)
def application():
    """Run the loan-application flow and dispatch on the CIBIL result.

    Performs KYC, generates a CIBIL score, uploads the bank statement, then
    runs exactly one of the case handlers selected by ``tf.switch_case``.

    Loan tiers (from the original notes):
        650 < cibil < 750  -> 10k to 15k loan
        750 < cibil < 850  -> 25k to 40k
        850 < cibil        -> 40k to 80k

    Returns:
        The result of the selected branch.
    """
    Kyc()
    b = generate_cibil_score()
    upload_docs = upload_bank_statement(b)
    # Branches must be zero-argument callables. Previously the *results* of
    # case0(...)/case1(...)/... were placed here, so every handler ran eagerly
    # regardless of `b` and tf.switch_case received non-callables.
    # NOTE(review): `b` is used directly as the branch index (0-3). If
    # generate_cibil_score returns a raw score (e.g. 700) rather than a bucket,
    # map it to 0-3 first; out-of-range indices fall through to the last
    # branch (loan_not_applicable) — confirm.
    return tf.switch_case(b, branch_fns={
        0: lambda: case0(b, upload_docs),
        1: lambda: case1(b, upload_docs),
        2: lambda: case2(b, upload_docs),
        3: lambda: loan_not_applicable(b),
    })
def switch_case(self, branch_selector, branch_callables, name=None):
    """Implements a switch (branch_selector) { case ... } construct.

    Delegates to ``tf.switch_case`` under a ``VM.switch_case`` name scope with
    the ``_control_flow_v2()`` context active.
    """
    scope = tf.compat.v2.name_scope('VM.switch_case')
    with scope, _control_flow_v2():
        return tf.switch_case(
            branch_index=branch_selector,
            branch_fns=branch_callables,
            name=name)
def apply_with_random_selector(x: tf.Tensor,
                               func: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
                               num_cases: int,
                               seed: tf.Tensor,
                               selected: Optional[int] = None) -> tf.Tensor:
    """Evaluate ``func(x, sel)`` for a selector drawn from ``range(num_cases)``.

    Args:
      x: input Tensor.
      func: Python function to apply; it receives the selector as a Python int.
      num_cases: Python int32, number of cases to sample sel from.
      seed: seed for ``tf.random.stateless_uniform``.
      selected: Python int32, optional fixed selector; when given, no random
        draw is made.

    Returns:
      The output of ``func(x, sel)``, where ``sel`` is a Python int inside
      each branch but the branch itself is chosen dynamically.
    """
    if selected is None:
        selected = tf.random.stateless_uniform(
            [], seed=seed, maxval=num_cases, dtype=tf.int32)

    def _case(case_id):
        # Each branch captures its own Python-int selector.
        return lambda: func(x, case_id)

    return tf.switch_case(selected, [_case(case_id) for case_id in range(num_cases)])
def _apply_ops(self, data, op_indices, op_args, prob_to_apply=None):
    """Apply ``self.num_layers`` pre-selected augmentation ops to ``data`` in sequence.

    Layer ``idx`` runs the op at position ``op_indices[idx]`` of
    IMAGENET_AUG_OPS (via ``tf.switch_case``) with level ``op_args[idx]``;
    each layer's output feeds the next. Ops whose final positional parameter
    is named ``replace`` also receive ``self.replace``. When
    ``prob_to_apply`` is set, each layer is kept only with that probability.

    Args:
        data: Tensor to augment.
        op_indices: Per-layer branch indices.
        op_args: Per-layer augmentation levels.
        prob_to_apply: Optional per-layer keep probability.

    Returns:
        The augmented tensor.
    """
    for idx in range(self.num_layers):
        op_index, op_level = op_indices[idx], op_args[idx]
        branch_fns = []
        for augment_op_name in IMAGENET_AUG_OPS:
            augment_fn = NAME_TO_FUNC[augment_op_name]
            level_to_args_fn = LEVEL_TO_ARG[augment_op_name]
            # Introspect once per op; append `replace` only when it is the
            # op's last parameter so it lines up positionally.
            fn_args = inspect.getfullargspec(augment_fn).args
            needs_replace = bool(fn_args) and fn_args[-1] == 'replace'

            def _branch_fn(data=data,
                           augment_fn=augment_fn,
                           level_to_args_fn=level_to_args_fn,
                           op_level=op_level,
                           needs_replace=needs_replace):
                # Image size is passed for translate-style ops.
                # NOTE(review): assumes data.shape[0] is the image side
                # length — confirm.
                args = [data] + list(
                    level_to_args_fn(op_level, image_size=data.shape[0]))
                if needs_replace:
                    args.append(self.replace)
                return augment_fn(*args)

            branch_fns.append(_branch_fn)

        aug_data = tf.switch_case(op_index, branch_fns, default=lambda: data)
        if prob_to_apply is not None:
            data = tf.cond(
                tf.random.uniform(shape=[], dtype=tf.float32) < prob_to_apply,
                lambda: aug_data,
                lambda: data)
        else:
            data = aug_data
    return data
def call(self, x):
    """Run one of three candidate branches chosen by a hard Gumbel-softmax draw.

    ``gumbel_softmax(self.alpha, tau=1.0, hard=True, return_index=True)``
    yields weights and an index. The weights are passed to every branch built
    by ``self._create_branch_fn``, and the index selects which branch runs.
    """
    hard_weights, chosen = gumbel_softmax(
        self.alpha, tau=1.0, hard=True, return_index=True)
    candidates = []
    for branch_id in range(3):
        candidates.append(self._create_branch_fn(x, branch_id, hard_weights))
    return tf.switch_case(chosen, candidates)
def to_tensor(self, elem: tfds.features.FeaturesDict) -> tf.Tensor:
    """Encode ``elem`` with the input selected by its ``self.label`` field.

    ``elem[self.label]`` is cast to int32 and used as the switch index over
    ``self.mapping``; the matching input's ``to_tensor(elem)`` is returned.
    Unmatched indices hit the default branch, which yields ``'NONE'``.
    """
    selector = tf.dtypes.cast(elem[self.label], tf.int32)

    def _encoder(source):
        # Factory so each branch keeps its own `source`.
        return lambda: source.to_tensor(elem)

    branches = {}
    for key, source in self.mapping.items():
        branches[key] = _encoder(source)
    # NOTE(review): when traced, tf.switch_case needs every branch (including
    # this string default) to produce compatible outputs; confirm each input
    # returns a string tensor.
    return tf.switch_case(
        branch_index=selector, branch_fns=branches, default=lambda: 'NONE')
def distort_each_ops(data):
    """Apply one randomly selected op to ``data["image"]`` and return the result.

    The free names ``available_ops``, ``cls``, ``level``, ``translate_const``
    and ``cutout_const`` come from the enclosing scope. Integer magnitudes are
    ``int((level / cls._MAX_LEVEL) * limit)``; the other ops take their
    arguments from the ``cls._*_level_to_arg`` helpers.
    """
    op_to_select = tf.random.uniform(
        [], maxval=len(available_ops), dtype=tf.int32)
    # NOTE(review): 16 branches are listed below but the index is drawn from
    # [0, len(available_ops)); if that length is not 16, some ops are never
    # selected (or the default is hit) — confirm.
    image = data["image"]

    def scaled(limit):
        # Scale `level` by limit / cls._MAX_LEVEL and truncate to int.
        return int((level / cls._MAX_LEVEL) * limit)

    branch_fns = {
        0: lambda: cls.autocontrast(image),
        1: lambda: cls.equalize(image),
        2: lambda: cls.invert(image),
        3: lambda: cls.rotate(image, *cls._rotate_level_to_arg(level)),
        4: lambda: cls.posterize(image, scaled(4)),
        5: lambda: cls.solarize(image, scaled(256)),
        6: lambda: cls.solarize_add(image, scaled(110)),
        7: lambda: cls.color(image, *cls._enhance_level_to_arg(level)),
        8: lambda: cls.contrast(image, *cls._enhance_level_to_arg(level)),
        9: lambda: cls.brightness(image, *cls._enhance_level_to_arg(level)),
        10: lambda: cls.sharpness(image, *cls._enhance_level_to_arg(level)),
        11: lambda: cls.shear_x(image, *cls._shear_level_to_arg(level)),
        12: lambda: cls.shear_y(image, *cls._shear_level_to_arg(level)),
        13: lambda: cls.translate_x(
            image, *cls._translate_level_to_arg(level, translate_const)),
        14: lambda: cls.translate_y(
            image, *cls._translate_level_to_arg(level, translate_const)),
        15: lambda: cls.cutout(image, scaled(cutout_const)),
    }
    return tf.switch_case(
        op_to_select,
        branch_fns=branch_fns,
        default=lambda: cls.autocontrast(image))
def inference(actor_ids, run_ids, env_outputs, raw_rewards):
    """Run one batched inference step for a set of actors and record results.

    Steps, in order:
      1. Actors whose ``run_ids`` differ from the stored ones are reset:
         ``actor_infos``, ``store``, first/current agent states and ``actions``.
      2. ``actor_infos`` is updated with the new rewards. Entries for actors
         whose ``env_outputs.done`` is set are enqueued on ``info_queue`` and
         reset.
      3. The agent runs on one of ``inference_devices``, picked round-robin
         via the ``inference_iteration`` counter and ``tf.switch_case``.
      4. Outputs are appended to ``store``; completed unrolls go to
         ``unroll_queue``.
      5. Stored agent states and actions are updated.

    NOTE(review): every container/helper used here (actor_run_ids,
    actor_infos, store, agent, encode, decode, FLAGS, ...) is a free name from
    the enclosing scope — presumably a learner-setup closure.

    Args:
        actor_ids: Ids of the actors in this batch.
        run_ids: Current run id per actor; a change triggers a reset.
        env_outputs: Environment outputs; ``.reward`` and ``.done`` are read.
        raw_rewards: Unprocessed rewards, passed to ``actor_infos.add``.

    Returns:
        ``agent_outputs.action`` for the actors in this batch.
    """
    # Reset the actors that had their first run or crashed.
    previous_run_ids = actor_run_ids.read(actor_ids)
    actor_run_ids.replace(actor_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    actors_needing_reset = tf.gather(actor_ids, reset_indices)
    # NOTE(review): only the tf.print is guarded; the resets below run on every
    # call (on an empty id set when nothing needs resetting) — confirm intended.
    if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
        tf.print('Actor ids needing reset:', actors_needing_reset)
    actor_infos.reset(actors_needing_reset)
    store.reset(actors_needing_reset)
    initial_agent_states = agent.initial_state(
        tf.shape(actors_needing_reset)[0])
    first_agent_states.replace(actors_needing_reset, initial_agent_states)
    agent_states.replace(actors_needing_reset, initial_agent_states)
    actions.reset(actors_needing_reset)

    # Update steps and return.
    actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
    done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
    # Flush finished episodes' info, then clear it for the next episode.
    info_queue.enqueue_many(actor_infos.read(done_ids))
    actor_infos.reset(done_ids)
    # Presumably counts the env steps covered by this action — confirm.
    actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

    # Inference.
    prev_actions = actions.read(actor_ids)
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(actor_ids)

    # Factory so each branch binds its own device (avoids late binding).
    def make_inference_fn(inference_device):
        def device_specific_inference_fn():
            with tf.device(inference_device):
                @tf.function
                def agent_inference(*args):
                    return agent(*decode(args))

                return agent_inference(input_, prev_agent_states)

        return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    branch_index = inference_iteration.assign_add(1) % len(inference_devices)
    agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
        i: make_inference_fn(inference_device)
        for i, inference_device in enumerate(inference_devices)
    })

    # Append the latest outputs to the unroll and insert completed unrolls in
    # queue.
    completed_ids, unrolls = store.append(
        actor_ids, (prev_actions, env_outputs, agent_outputs))
    unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
    unroll_queue.enqueue_many(unrolls)
    # agent_states is not yet updated here, so completed actors' next unroll
    # starts from the state used for this step's inference.
    first_agent_states.replace(completed_ids,
                               agent_states.read(completed_ids))

    # Update current state.
    agent_states.replace(actor_ids, curr_agent_states)
    actions.replace(actor_ids, agent_outputs.action)

    # Return environment actions to actors.
    return agent_outputs.action
def color_jitter_rand(self, x):
    """Randomly jitter brightness, contrast, saturation and hue of ``x``.

    All four adjustments are applied once each, in the order given by
    ``tf.random.shuffle(tf.range(4))``. Brightness and hue are skipped when
    their setting is 0; contrast and saturation are skipped when their two
    bounds sum to 0. Finally, unless ``self.clip`` sums to 0, the result is
    clipped to ``[self.clip[0], self.clip[1]]``.
    """
    def jitter_brightness(img):
        if self.brightness == 0:
            return img
        return tf.image.random_brightness(img, self.brightness)

    def jitter_contrast(img):
        if self.contrast[0] + self.contrast[1] == 0:
            return img
        return tf.image.random_contrast(img, self.contrast[0], self.contrast[1])

    def jitter_saturation(img):
        if self.saturation[0] + self.saturation[1] == 0:
            return img
        return tf.image.random_saturation(
            img, self.saturation[0], self.saturation[1])

    def jitter_hue(img):
        if self.hue == 0:
            return img
        return tf.image.random_hue(img, self.hue)

    jitters = (jitter_brightness, jitter_contrast, jitter_saturation, jitter_hue)
    order = tf.random.shuffle(tf.range(4))
    for step in range(4):
        selector = tf.gather(order, step)
        # Bind the current image per branch so each step sees the previous
        # step's output.
        x = tf.switch_case(selector, branch_fns={
            slot: (lambda fn=fn, img=x: fn(img))
            for slot, fn in enumerate(jitters)
        })

    if self.clip[0] + self.clip[1] != 0:
        x = tf.clip_by_value(x, self.clip[0], self.clip[1])
    return x
def tuneLabel(label):
    """Map an integer label 0-5 to a float16 target value.

    Labels 0-4 map to themselves; label 5 maps to 5.5. With no default,
    ``tf.switch_case`` sends out-of-range labels to the last branch.
    """
    # NOTE(review): 5 -> 5.5 (not 5.0) looks deliberate given the name — confirm.
    targets = (0.0, 1.0, 2.0, 3.0, 4.0, 5.5)
    branches = [lambda value=value: tf.constant(value, dtype=tf.float16)
                for value in targets]
    return tf.switch_case(tf.cast(label, dtype=tf.int32), branch_fns=branches)
def _cond(index: TfVal, *operands: TfVal,
          branches: Sequence[core.TypedJaxpr],
          linear: Sequence[bool]):
    """Express a cond over ``branches`` (jaxprs) as ``tf.switch_case``.

    Each branch interprets its jaxpr on the shared ``operands``; ``index``
    selects the branch. ``linear`` is unused.
    """
    del linear
    # tf.switch_case branches take no arguments, so bind jaxpr and operands
    # up front.
    tf_branches = []
    for jaxpr in branches:
        tf_branches.append(functools.partial(_interpret_jaxpr, jaxpr, *operands))
    return tf.switch_case(index, tf_branches)
def _cond(index: TfVal, *operands: TfValOrUnit,
          branches: Sequence[core.TypedJaxpr],
          linear: Sequence[bool]) -> Sequence[TfValOrUnit]:
    """Express a cond over ``branches`` (jaxprs) as ``tf.switch_case``.

    The selected branch's outputs are passed through ``_tfval_add_unit`` with
    the first branch's ``out_avals``. ``linear`` is unused.
    """
    del linear
    # tf.switch_case branches take no arguments, so bind jaxpr and operands
    # up front.
    branches_tf = []
    for branch_jaxpr in branches:
        branches_tf.append(
            functools.partial(_interpret_jaxpr, branch_jaxpr, *operands))
    res_tf: Sequence[TfVal] = tf.switch_case(index, branches_tf)
    # Presumably all branches share an output signature, so branch 0's avals
    # stand for all of them.
    return _tfval_add_unit(res_tf, branches[0].out_avals)
def distort(self, image: tf.Tensor) -> tf.Tensor:
    """Applies the RandAugment policy to `image`.

    Non-uint8 input is clipped to [0, 255] and cast to uint8 for the ops, then
    cast back at the end. Each of ``self.num_layers`` rounds draws an index in
    ``[0, len(self.available_ops)]``. The extra top index hits the identity
    default, so a round may leave the image unchanged. Each op gets its own
    probability drawn from [0.2, 0.8). When ``self.prob_to_apply`` is set, a
    round's result is kept only with that probability.

    Args:
      image: `Tensor` of shape [height, width, 3] representing an image.

    Returns:
      The augmented version of `image`, in the input dtype.
    """
    original_dtype = image.dtype
    if original_dtype != tf.uint8:
        image = tf.cast(tf.clip_by_value(image, 0.0, 255.0), dtype=tf.uint8)

    fill_value = [128] * 3
    low_prob, high_prob = 0.2, 0.8

    def _bind(op_fn, op_args):
        # Branch body applies the op to the current round's `image`.
        return lambda: op_fn(image, *op_args)

    for _ in range(self.num_layers):
        selected = tf.random.uniform(
            [], maxval=len(self.available_ops) + 1, dtype=tf.int32)

        cases = []
        for position, op_name in enumerate(self.available_ops):
            op_prob = tf.random.uniform(
                [], minval=low_prob, maxval=high_prob, dtype=tf.float32)
            op_fn, _, op_args = _parse_policy_info(
                op_name, op_prob, self.magnitude, fill_value,
                self.cutout_const, self.translate_const)
            cases.append((position, _bind(op_fn, op_args)))

        augmented = tf.switch_case(
            branch_index=selected,
            branch_fns=cases,
            default=lambda: tf.identity(image))
        if self.prob_to_apply is not None:
            augmented = tf.cond(
                tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply,
                lambda: tf.identity(augmented),
                lambda: tf.identity(image))
        image = augmented

    return tf.cast(image, dtype=original_dtype)
def route_block(x, nexpert, nidx, route, expert):
    """Route each of ``nidx`` output slots to one of ``nexpert`` expert towers.

    A routing tower built from ``route`` feeds ``dense_block``, producing
    ``nidx * nexpert`` logits reshaped to ``[nidx, nexpert]``. For slot ``i``,
    the train-time expert is sampled from ``Categorical(logits=logits[i])``;
    the test-time expert is the arg-max of that row. The selected expert
    outputs are concatenated along axis 3.

    Args:
        x: Input tensor fed to the routing tower and to every expert tower.
        nexpert: Number of expert towers.
        nidx: Number of routed output slots.
        route: Non-empty iterable of ``(c, f, p)`` tuples for ``block``
            (routing tower); the final ``f`` is reused as the flattened width.
        expert: Iterable of ``(c, f, p)`` tuples for each expert tower.

    Returns:
        ``(train_out, test_out, train_idx, test_idx)``: ``train_idx`` is a
        Python list of ``nidx`` scalar tensors, ``test_idx`` an int32 tensor
        of shape ``[nidx]``.
    """
    blocks = [x]
    for (c, f, p) in route:
        blocks.append(block(blocks[-1], c, f, p))
    flat = tf.reshape(blocks[-1], [1, f])
    dense = tf.squeeze(dense_block(flat, [f, nidx * nexpert]))
    # One row of expert logits per routed slot.
    logits = tf.reshape(dense, [nidx, nexpert])

    train_idx = [None] * nidx
    for i in range(nidx):
        # Row `i` (previously the stale loop variable `p` from the route tower).
        pi = tf.distributions.Categorical(logits=logits[i])
        train_idx[i] = tf.squeeze(pi.sample(1))
    # Arg-max over the expert axis (1), giving one expert index per slot.
    # axis=0 reduced over slots instead.
    test_idx = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)

    experts = []
    for _ in range(nexpert):
        expert_blocks = [x]
        for (c, f, p) in expert:
            expert_blocks.append(block(expert_blocks[-1], c, f, p))
        experts.append(expert_blocks[-1])

    # Bind `e` as a default so each branch returns its own expert (a bare
    # closure would make every branch return the last one).
    branch_fns = {e: (lambda e=e: experts[e]) for e in range(nexpert)}

    train_out = [None] * nidx
    test_out = [None] * nidx
    for i in range(nidx):
        train_out[i] = tf.switch_case(branch_index=train_idx[i],
                                      branch_fns=branch_fns)
        test_out[i] = tf.switch_case(branch_index=test_idx[i],
                                     branch_fns=branch_fns)
    # NOTE(review): axis 3 assumes 4-D expert outputs (e.g. NHWC) — confirm.
    train_out = tf.concat(train_out, axis=3)
    test_out = tf.concat(test_out, axis=3)
    # The original ended with an incomplete `entropy =` statement (a syntax
    # error). It was removed; a routing-entropy term could be built from each
    # slot's Categorical if needed.
    return train_out, test_out, train_idx, test_idx
def _apply_fn(image, ii):
    """Apply the colour op at ``function_order[ii]`` to ``image``.

    The value ``function_order[ii]`` selects ``brightness_fn`` (0),
    ``contrast_fn`` (1) or ``saturation_fn`` (2); these names come from the
    enclosing scope. Returns the transformed image and ``ii + 1``.
    """
    op_id = function_order[ii]
    adjusters = [brightness_fn, contrast_fn, saturation_fn]
    image = tf.switch_case(op_id, [lambda fn=fn: fn(image) for fn in adjusters])
    return image, ii + 1
def rot90(image, bboxes, k=(0, 1, 2, 3)):
    """
    Rotate image and bounding boxes counter-clockwise by random multiple of 90 degrees.

    :param image: 3-D Tensor of shape [height, width, channels]
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes
        in format [ymin, xmin, ymax, xmax], normalized to [0, 1] (the box
        formulas reflect coordinates as ``1.0 - v``)
    :param k: array with multiples of 90 to choose from; the chosen value is
        reduced modulo 4, so e.g. -1 acts as 3
    :return: (rotated image, rotated bounding boxes)
    """
    with tf.name_scope("rot90"):
        # Pick one entry of `k` at random and reduce it to 0..3.
        selected_k = tf.math.floormod(tf.random.shuffle(k)[0], 4)
        image = tf.image.rot90(image, k=selected_k)
        # Branch j maps box corners through j counter-clockwise quarter turns.
        rotate_bboxes = [
            # 0 turns: unchanged.
            lambda: bboxes,
            # 1 turn: (y, x) -> (1 - x, y).
            lambda: tf.stack(
                [
                    tf.math.subtract(1.0, bboxes[:, 3]),
                    bboxes[:, 0],
                    tf.math.subtract(1.0, bboxes[:, 1]),
                    bboxes[:, 2]
                ],
                axis=1,
            ),
            # 2 turns: (y, x) -> (1 - y, 1 - x); min/max corners swap.
            lambda: tf.math.subtract(
                1.0,
                tf.stack(
                    [
                        bboxes[:, 2],
                        bboxes[:, 3],
                        bboxes[:, 0],
                        bboxes[:, 1],
                    ],
                    axis=1,
                ),
            ),
            # 3 turns: (y, x) -> (x, 1 - y).
            lambda: tf.stack(
                [
                    bboxes[:, 1],
                    tf.math.subtract(1.0, bboxes[:, 2]),
                    bboxes[:, 3],
                    tf.math.subtract(1.0, bboxes[:, 0])
                ],
                axis=1,
            ),
        ]
        # Only run the box switch when there is at least one box.
        bboxes = tf.cond(
            tf.greater(tf.shape(bboxes)[0], 0),
            lambda: tf.switch_case(selected_k, rotate_bboxes),
            lambda: bboxes,
        )
    return image, bboxes
def func(image):
    """Apply colour jitter number ``i`` to ``image`` and clip the result to [0, 1].

    ``i`` (enclosing scope) selects 0: brightness, 1: contrast,
    2: saturation, 3: hue. The strengths ``brightness``, ``contrast``,
    ``saturation`` and ``hue`` also come from the enclosing scope.
    """
    def _brightness():
        return tf.image.random_brightness(image, brightness)

    def _contrast():
        return tf.image.random_contrast(image, 1 - contrast, 1 + contrast)

    def _saturation():
        return tf.image.random_saturation(image, 1 - saturation, 1 + saturation)

    def _hue():
        return tf.image.random_hue(image, hue)

    def _clipped(jitter):
        return lambda: tf.clip_by_value(jitter(), 0, 1)

    return tf.switch_case(
        i, [_clipped(j) for j in (_brightness, _contrast, _saturation, _hue)])
def build(self, framework: str) -> None:
    """Set up framework-specific op sampling and dispatch.

    Sets ``self.prob_fn`` to a uniform distribution with ``low=0`` and
    ``high=len(self.ops)``. Sets ``self.invoke_fn`` to a callable
    ``(idx, data, state)`` that runs ``self.ops[idx].forward(data, state)``
    (through ``tf.switch_case`` for TensorFlow).

    Args:
        framework: ``'tf'`` or ``'torch'``.

    Raises:
        ValueError: for any other framework.
    """
    if framework == 'tf':
        self.prob_fn = tfp.distributions.Uniform(low=0, high=len(self.ops))
        # Bind each op as a default argument. The previous
        # `[lambda: op.forward(...) for op in self.ops]` closed over the
        # comprehension variable, so every branch ran the LAST op.
        self.invoke_fn = lambda idx, data, state: tf.switch_case(
            idx, [lambda op=op: op.forward(data, state) for op in self.ops])
    elif framework == 'torch':
        self.prob_fn = torch.distributions.uniform.Uniform(low=0, high=len(
            self.ops))
        self.invoke_fn = lambda idx, data, state: self.ops[idx].forward(
            data, state)
    else:
        raise ValueError("unrecognized framework: {}".format(framework))
def inference_eval(actor_ids, run_ids, env_outputs, raw_rewards):
    """Run one batched eval-mode inference step for a set of actors.

    Actors whose ``run_ids`` changed get fresh agent states and reset actions.
    The agent then runs with ``is_training=False`` on one of
    ``inference_devices``, chosen round-robin via
    ``inference_iteration_eval``. Stored states and actions are updated.
    ``raw_rewards`` is not used in this block.

    NOTE(review): actor_run_ids, agent, agent_states, actions, encode, decode,
    parametric_action_distribution and inference_devices are free names from
    the enclosing scope — presumably a learner-setup closure.

    Returns:
        ``parametric_action_distribution.postprocess(agent_outputs.action)``.
    """
    # Reset the actors that had their first run or crashed.
    previous_run_ids = actor_run_ids.read(actor_ids)
    actor_run_ids.replace(actor_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    actors_needing_reset = tf.gather(actor_ids, reset_indices)
    # NOTE(review): only the tf.print is guarded; the resets below run on every
    # call (on an empty id set when nothing needs resetting) — confirm intended.
    if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
        tf.print('Actor ids needing reset:', actors_needing_reset)
    initial_agent_states = agent.initial_state(
        tf.shape(actors_needing_reset)[0])
    agent_states.replace(actors_needing_reset, initial_agent_states)
    actions.reset(actors_needing_reset)

    # Inference.
    prev_actions = parametric_action_distribution.postprocess(
        actions.read(actor_ids))
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(actor_ids)

    # Factory so each branch binds its own device (avoids late binding).
    def make_inference_fn(inference_device):
        def device_specific_inference_fn():
            with tf.device(inference_device):
                @tf.function
                def agent_inference(*args):
                    return agent(*decode(args), is_training=False,
                                 postprocess_action=False)

                return agent_inference(*input_, prev_agent_states)

        return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    branch_index = inference_iteration_eval.assign_add(1) % len(
        inference_devices)
    agent_outputs, curr_agent_states = tf.switch_case(
        branch_index, {
            i: make_inference_fn(inference_device)
            for i, inference_device in enumerate(inference_devices)
        })

    # Update current state.
    agent_states.replace(actor_ids, curr_agent_states)
    actions.replace(actor_ids, agent_outputs.action)

    # Return environment actions to actors.
    return parametric_action_distribution.postprocess(agent_outputs.action)
def augment(batch):
    """Apply one of eight flip/rotation/transpose variants to ``batch``.

    A variant index is drawn uniformly from [0, 8): identity, transpose of
    axes 1 and 2, up-down flip, left-right flip of a 90-degree rotation,
    left-right flip, and 90/180/270-degree rotations.
    NOTE(review): assumes a 4-D (presumably NHWC) batch — confirm.
    """
    which_aug = tf.random.uniform(shape=(), dtype=tf.dtypes.int32, minval=0, maxval=8)
    variants = [
        lambda: batch,
        lambda: tf.transpose(batch, perm=[0, 2, 1, 3]),
        lambda: tf.image.flip_up_down(batch),
        lambda: tf.image.flip_left_right(tf.image.rot90(batch, k=1)),
        lambda: tf.image.flip_left_right(batch),
        lambda: tf.image.rot90(batch, k=1),
        lambda: tf.image.rot90(batch, k=2),
        lambda: tf.image.rot90(batch, k=3),
    ]
    return tf.switch_case(which_aug, variants)
def combined_labels(self, nb_lesions):
    """Collapse a lesion count into a binary label.

    0 lesions -> 0; 1-4 lesions -> 1. With no default, ``tf.switch_case``
    sends counts outside [0, 4] to the last branch, so they also give 1.
    """
    def no_lesion():
        return tf.constant(0)

    def has_lesion():
        return tf.constant(1)

    # NOTE(review): an unused helper returning tf.constant(2) was removed. If
    # several lesions were meant to map to a separate class 2, wire it in here.
    return tf.switch_case(nb_lesions, branch_fns={
        0: no_lesion,
        1: has_lesion,
        2: has_lesion,
        3: has_lesion,
        4: has_lesion,
    })
def isup_to_smoothed_labels(label):
    """Convert an integer label 0-5 into an order-aware smoothed target vector.

    Each row puts 2/3 of the mass on the true class and spreads the rest over
    neighbouring classes, so nearby classes get more weight; every row sums
    to 1. Returns a float32 tensor of length 6.
    """
    label = tf.cast(label, dtype=tf.int32)
    smoothed_rows = [
        [2/3, 2/9, 1/9, 0, 0, 0],
        [1/6, 2/3, 1/9, 1/18, 0, 0],
        [1/18, 1/9, 2/3, 1/9, 1/18, 0],
        [0, 1/18, 1/9, 2/3, 1/9, 1/18],
        [0, 0, 1/18, 1/9, 2/3, 1/6],
        [0, 0, 0, 1/9, 2/9, 2/3],
    ]
    branches = [lambda row=row: tf.constant(row, dtype=tf.float32)
                for row in smoothed_rows]
    return tf.switch_case(label, branches)
def apply_augmentation_op(image, op_index, op_level, prob_to_apply):
    """Applies one augmentation op to the image.

    ``op_index`` selects an entry of IMAGENET_AUG_OPS through
    ``tf.switch_case``; the op receives ``image`` followed by the arguments
    from ``LEVEL_TO_ARG[name](op_level)``. If ``prob_to_apply`` is not None,
    the augmented image is kept only with that probability.
    """
    def _bind(op_name):
        fn = augment_ops.NAME_TO_FUNC[op_name]
        level_to_args = LEVEL_TO_ARG[op_name]
        return lambda img=image: fn(img, *level_to_args(op_level))

    candidates = [_bind(op_name) for op_name in IMAGENET_AUG_OPS]
    augmented = tf.switch_case(op_index, candidates, default=lambda: image)
    if prob_to_apply is None:
        return augmented
    keep = tf.random.uniform(shape=[], dtype=tf.float32) < prob_to_apply
    return tf.cond(keep, lambda: augmented, lambda: image)
def apply_augmentation_op(data, op_index, op_level, prob_to_apply):
    """Applies one augmentation op to the data.

    Branch ``j`` of the ``tf.switch_case`` runs the ``j``-th op of AUG_OPS on
    ``data`` with the arguments from ``LEVEL_TO_ARG[name](op_level)``.
    ``op_index`` picks the branch. If ``prob_to_apply`` is given, the result
    is kept with that probability and ``data`` is returned otherwise.
    """
    candidates = []
    for name in AUG_OPS:
        op_fn = NAME_TO_FUNC[name]
        arg_fn = LEVEL_TO_ARG[name]

        def _candidate(value=data, op_fn=op_fn, arg_fn=arg_fn):
            return op_fn(*([value] + list(arg_fn(op_level))))

        candidates.append(_candidate)

    result = tf.switch_case(op_index, candidates, default=lambda: data)
    if prob_to_apply is not None:
        coin = tf.random.uniform(shape=[], dtype=tf.float32)
        result = tf.cond(coin < prob_to_apply, lambda: result, lambda: data)
    return result