def distort(self, image):
  """Applies the RandAugment policy to `image`.

  Args:
    image: `Tensor` of shape `[h, w, 3]` representing an image.

  Returns:
    The augmented version of `image`.
  """
  input_image_type = image.dtype
  image = tf.clip_by_value(image, tf.cast(0, image.dtype),
                           tf.cast(255, image.dtype))
  image = tf.cast(image, dtype=tf.uint8)
  replace_value = [128] * 3
  prob = tf.random.uniform([], 0.2, 0.8, tf.float32)

  for _ in range(self.num_layers):
    op_to_select = tf.random.uniform(
        [], minval=0, maxval=len(self.available_ops), dtype=tf.int32)

    branch_fns = []
    for (i, op_name) in enumerate(self.available_ops):
      func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
                                         replace_value, self.cutout_const,
                                         self.translate_const)

      # Bind `func` and `args` via default arguments so each branch keeps
      # its own op rather than closing over the loop variables.
      def branch_fn(selected_func=func, selected_args=args):
        # Apply the op with probability `prob`; otherwise pass through.
        return tf.cond(
            tf.random.uniform([], 0., 1., prob.dtype) <= prob,
            lambda: selected_func(image, *selected_args),
            lambda: image)

      branch_fns.append((i, branch_fn))

    # Dispatch to exactly one randomly selected augmentation op.
    image = tf.switch_case(branch_index=op_to_select, branch_fns=branch_fns)

  image = tf.cast(image, dtype=input_image_type)
  return image
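# The tf.switch_case dispatch above can be exercised in isolation. A
# minimal, self-contained sketch follows; the toy ops stand in for the
# parsed `_parse_policy_info` results and are purely illustrative.
import tensorflow as tf

# Toy augmentation ops; all branches must return tensors of matching
# shape and dtype.
ops = [
    lambda img: tf.image.flip_left_right(img),
    lambda img: tf.image.rot90(img),
    lambda img: img,
]

image = tf.zeros([8, 8, 3], dtype=tf.uint8)
op_to_select = tf.random.uniform([], maxval=len(ops), dtype=tf.int32)

# Every branch is traced, but only the selected one executes. Binding
# `op` as a default argument keeps each lambda tied to its own op.
image = tf.switch_case(
    branch_index=op_to_select,
    branch_fns=[(i, lambda op=op: op(image)) for i, op in enumerate(ops)])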
def __init__(self, hparams, mode, features):
  """Create the model.

  Args:
    hparams: Hyperparameter configurations.
    mode: TRAIN | EVAL | INFER
    features: a dict of input features.
  """
  # Set params
  self._set_params_initializer(hparams, mode, features)

  if self.mode == contrib_learn.ModeKeys.INFER:
    self.build_train_graph(hparams, 0)
  else:
    # Bucket the batch by its longest sequence (source or target) so a
    # graph specialized for that length range can be selected below.
    src_len = tf.reduce_max(self.features["source_sequence_length"])
    tgt_len = tf.reduce_max(self.features["target_sequence_length"])
    max_len = tf.maximum(src_len, tgt_len)
    max_len = tf.maximum(max_len, 12)
    gradients, global_norm = tf.switch_case(
        tf.cast((max_len - 12) / 6, tf.int32), {
            0: lambda: self.build_train_graph(hparams, 18),
            1: lambda: self.build_train_graph(hparams, 24),
            2: lambda: self.build_train_graph(hparams, 30),
            3: lambda: self.build_train_graph(hparams, 36),
            4: lambda: self.build_train_graph(hparams, 42),
        },
        default=lambda: self.build_train_graph(hparams, 0))

    self.learning_rate = tf.constant(hparams.learning_rate)
    # Warm-up, then decay.
    self.learning_rate = self._get_learning_rate_warmup(hparams)
    self.learning_rate = self._get_learning_rate_decay(hparams)

    if hparams.optimizer == "sgd":
      opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    elif hparams.optimizer == "adam":
      opt = tf.train.AdamOptimizer(self.learning_rate)
    else:
      raise ValueError("Unknown optimizer type %s" % hparams.optimizer)

    gradients, _ = tf.clip_by_global_norm(gradients,
                                          hparams.max_gradient_norm,
                                          global_norm)
    # All-reduce gradients across TPU replicas in bfloat16 to save
    # interconnect bandwidth, then restore float32 for the update.
    gradients = [
        (tf.cast(tf.tpu.cross_replica_sum(tf.cast(g, tf.bfloat16)),
                 tf.float32), v)
        for g, v in zip(gradients, tf.trainable_variables())
    ]
    self.update = opt.apply_gradients(
        gradients, global_step=self.global_step)
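# The cast-and-divide key above is length bucketing: each batch routes to
# a graph built for the smallest bucket that fits its longest sequence. A
# small sketch of just the selector arithmetic; the helper name
# `bucket_index` is ours, introduced only for illustration.
import tensorflow as tf

def bucket_index(max_len):
  """Mirrors the switch_case key used above.

  Lengths clamp to at least 12; every further 6 tokens advance one
  bucket, and indices past 4 fall through to the `default` branch.
  """
  max_len = tf.maximum(max_len, 12)
  return tf.cast((max_len - 12) / 6, tf.int32)

# e.g. max_len 10 -> bucket 0 (graph for length 18),
#      max_len 25 -> bucket 2 (graph for length 30),
#      max_len 50 -> bucket 6 -> `default` (dynamic-length graph).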
from typing import Callable, Optional

import tensorflow as tf


def apply_with_random_selector(
    x: tf.Tensor,
    func: Callable[[tf.Tensor, int], tf.Tensor],
    num_cases: int,
    selected: Optional[int] = None) -> tf.Tensor:
  """Computes func(x, sel), with sel sampled from [0...num_cases-1].

  Args:
    x: input Tensor.
    func: Python function to apply.
    num_cases: Python int32, number of cases to sample sel from.
    selected: Python int32, optional value to use as the selected index.

  Returns:
    The result of func(x, sel), where func receives the value of the
    selector as a python integer, but sel is sampled dynamically.
  """
  if selected is None:
    selected = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
  # Build one branch per case; the default argument freezes `case` so
  # each lambda calls `func` with its own python-int selector.
  branches = [lambda i=case: func(x, i) for case in range(num_cases)]
  return tf.switch_case(selected, branches)
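# A typical call site, in the style of the Inception preprocessing
# pipelines, picks one of several color-distortion orderings at random.
# `distort_color` is hypothetical here: assumed to take an image and an
# integer ordering, and to return a tensor of the same shape and dtype.
image = apply_with_random_selector(
    image,
    lambda x, ordering: distort_color(x, ordering),
    num_cases=4)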