Example #1
    def distort(self, image):
        """Applies the RandAugment policy to `image`.

        Args:
          image: `Tensor` of shape `h, w, 3` representing an image.

        Returns:
          The augmented version of `image`.
        """
        input_image_type = image.dtype
        image = tf.clip_by_value(
            image, tf.cast(0, image.dtype), tf.cast(255, image.dtype))
        image = tf.cast(image, dtype=tf.uint8)
        replace_value = [128] * 3
        prob = tf.random.uniform([], 0.2, 0.8, tf.float32)
        for _ in range(self.num_layers):
            op_to_select = tf.random.uniform(
                [], minval=0, maxval=len(self.available_ops), dtype=tf.int32)

            branch_fns = []
            for (i, op_name) in enumerate(self.available_ops):
                func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
                                                   replace_value, self.cutout_const,
                                                   self.translate_const)

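                # Bind the chosen func and its args through default arguments
                # so each branch closes over its own op rather than the
                # loop's last value.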
                def branch_fn(selected_func=func, selected_args=args):
                    return tf.cond(tf.random.uniform([], 0., 1., prob.dtype) <= prob,
                                   lambda: selected_func(image, *selected_args),
                                   lambda: image)

                branch_fns.append((i, branch_fn))
            image = tf.switch_case(branch_index=op_to_select, branch_fns=branch_fns)

        image = tf.cast(image, dtype=input_image_type)
        return image
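
A minimal standalone sketch of the same dispatch pattern, with two stock tf.image ops standing in for the parsed RandAugment ops: a random scalar index drives tf.switch_case over (index, callable) pairs, and each callable captures its op through a default argument.

import tensorflow as tf

def apply_random_op(image):
    # Stand-in ops; the example above builds these from _parse_policy_info.
    ops = [tf.image.flip_left_right, tf.image.rot90]
    # Scalar int32 tensor that selects which branch to run.
    op_to_select = tf.random.uniform([], maxval=len(ops), dtype=tf.int32)
    # Default-argument binding keeps each branch tied to its own op.
    branch_fns = [(i, lambda op=op: op(image)) for i, op in enumerate(ops)]
    return tf.switch_case(branch_index=op_to_select, branch_fns=branch_fns)

image = tf.zeros([32, 32, 3], dtype=tf.uint8)
augmented = apply_random_op(image)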
Example #2
    def __init__(self, hparams, mode, features):
        """Create the model.

        Args:
          hparams: Hyperparameter configurations.
          mode: TRAIN | EVAL | INFER
          features: a dict of input features.
        """
        # Set params
        self._set_params_initializer(hparams, mode, features)
        if self.mode == contrib_learn.ModeKeys.INFER:
            self.build_train_graph(hparams, 0)
        else:
            src_len = tf.reduce_max(self.features["source_sequence_length"])
            tgt_len = tf.reduce_max(self.features["target_sequence_length"])
            max_len = tf.maximum(src_len, tgt_len)
            max_len = tf.maximum(max_len, 12)
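            # Dispatch on a length bucket: (max_len - 12) / 6, cast to int32,
            # selects one of the graph variants below; longer sequences fall
            # through to the default branch.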
            gradients, global_norm = tf.switch_case(
                tf.cast((max_len - 12) / 6, tf.int32), {
                    0: lambda: self.build_train_graph(hparams, 18),
                    1: lambda: self.build_train_graph(hparams, 24),
                    2: lambda: self.build_train_graph(hparams, 30),
                    3: lambda: self.build_train_graph(hparams, 36),
                    4: lambda: self.build_train_graph(hparams, 42),
                },
                default=lambda: self.build_train_graph(hparams, 0))
            self.learning_rate = tf.constant(hparams.learning_rate)
            self.learning_rate = self._get_learning_rate_warmup(hparams)
            self.learning_rate = self._get_learning_rate_decay(hparams)
            if hparams.optimizer == "sgd":
                opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            elif hparams.optimizer == "adam":
                opt = tf.train.AdamOptimizer(self.learning_rate)
            else:
                raise ValueError("Unknown optimizer type %s" %
                                 hparams.optimizer)

            gradients, _ = tf.clip_by_global_norm(gradients,
                                                  hparams.max_gradient_norm,
                                                  global_norm)

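            # Sum gradients across TPU replicas, casting to bfloat16 for the
            # cross-replica sum and back to float32 afterwards.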
            gradients = [
                (tf.cast(tf.tpu.cross_replica_sum(tf.cast(g, tf.bfloat16)),
                         tf.float32), v)
                for g, v in zip(gradients, tf.trainable_variables())
            ]

            self.update = opt.apply_gradients(gradients,
                                              global_step=self.global_step)
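
Reduced to its core, the bucketing above works as in the sketch below: the dynamic max_len is mapped to an integer bucket, tf.switch_case dispatches on it with a dict of branches, and out-of-range buckets fall through to default. The constants returned here are only placeholders for the per-bucket graphs that build_train_graph produces.

import tensorflow as tf

def padded_length_for(max_len):
    # Map max_len to a bucket index: [12, 18) -> 0, [18, 24) -> 1, ...
    bucket = tf.cast((max_len - 12) / 6, tf.int32)
    return tf.switch_case(
        bucket,
        branch_fns={
            0: lambda: tf.constant(18),
            1: lambda: tf.constant(24),
            2: lambda: tf.constant(30),
        },
        # Any bucket past the last key runs the default branch.
        default=lambda: tf.constant(0))

length = padded_length_for(tf.constant(20))  # bucket 1 -> 24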
Example #3
from typing import Callable, Optional

import tensorflow as tf


def apply_with_random_selector(x: tf.Tensor,
                               func: Callable[[tf.Tensor, int], tf.Tensor],
                               num_cases: int,
                               selected: Optional[int] = None) -> tf.Tensor:
    """Computes func(x, sel), with sel sampled from [0...num_cases-1].

    Args:
      x: input Tensor.
      func: Python function to apply.
      num_cases: Python int32, number of cases to sample sel from.
      selected: Python int32, optional value to use as the selected index.

    Returns:
      The result of func(x, sel), where func receives the value of the
      selector as a python integer, but sel is sampled dynamically.
    """
    if selected is None:
        selected = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
    branches = [lambda i=case: func(x, i) for case in range(num_cases)]
    return tf.switch_case(selected, branches)
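
A possible use of apply_with_random_selector, loosely modeled on Inception-style color distortion; distort_color and its parameters are illustrative and not part of the example above.

import tensorflow as tf

def distort_color(image, ordering):
    # Illustrative: two orderings of the same random color ops.
    if ordering == 0:
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    else:
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
    return image

image = tf.random.uniform([224, 224, 3], dtype=tf.float32)
# Each branch is built with a fixed Python-int ordering; tf.switch_case then
# picks one of the two branches at run time.
distorted = apply_with_random_selector(
    image, lambda x, ordering: distort_color(x, ordering), num_cases=2)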