def _decode_areas(parsed_tensors):
    xmin = parsed_tensors['image/object/bbox/xmin']
    xmax = parsed_tensors['image/object/bbox/xmax']
    ymin = parsed_tensors['image/object/bbox/ymin']
    ymax = parsed_tensors['image/object/bbox/ymax']
    return tf.cond(
        tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0),
        lambda: parsed_tensors['image/object/area'],
        lambda: (xmax - xmin) * (ymax - ymin))
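
A minimal, self-contained sketch of the same fallback pattern (toy tensors, not the real decoder): when no explicit areas were serialized, the area is recomputed from the box corners.

import tensorflow as tf

def areas_or_box_areas(areas, xmin, xmax, ymin, ymax):
  # Fall back to (xmax - xmin) * (ymax - ymin) when `areas` is empty.
  return tf.cond(
      tf.greater(tf.shape(areas)[0], 0),
      lambda: areas,
      lambda: (xmax - xmin) * (ymax - ymin))

# Toy usage: no explicit areas recorded, so areas are derived from the boxes.
print(areas_or_box_areas(
    tf.zeros([0]), tf.constant([0.1]), tf.constant([0.5]),
    tf.constant([0.2]), tf.constant([0.4])))  # ~[0.08]
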
Example #2
def _stability_limit_tensor(total_count, dtype):
  limit = tf.cast(BATES_TOTAL_COUNT_STABILITY_LIMITS[dtype], dtype)
  return tf.cond(
      tf.math.reduce_any(total_count > limit),
      # pylint: disable=g-long-lambda
      lambda: tf.print(
          'WARNING: Bates PDF/CDF is unstable for `total_count` >', limit,
          output_stream=sys.stderr),
      tf.no_op)
Example #3
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition."""
  v = get_static_value(pred)
  if v is None:
    return tf.cond(pred, true_fn, false_fn)
  if v:
    return true_fn()
  else:
    return false_fn()
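
A hedged usage sketch (assuming the `cond` helper and its `get_static_value` dependency are in scope): with a statically known predicate only the chosen branch runs; a predicate whose value cannot be determined statically falls back to a regular `tf.cond`.

import tensorflow as tf

x = tf.constant(3.0)
# Predicate with a known static value: the chosen branch is called directly.
y1 = cond(tf.constant(True), lambda: x + 1.0, lambda: x - 1.0)   # 4.0
# Data-dependent predicate (no static value, e.g. inside a tf.function over
# symbolic inputs): the helper defers to tf.cond.
y2 = cond(x > 0.0, lambda: x + 1.0, lambda: x - 1.0)             # 4.0
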
Example #4
  def update_step(self, gradient, variable):
    """Update step given gradient and the associated model variable."""
    var_dtype = variable.dtype
    lr = tf.cast(self.learning_rate, var_dtype)
    local_step = tf.cast(self.iterations + 1, var_dtype)
    next_step = tf.cast(self.iterations + 2, var_dtype)
    decay = tf.cast(0.96, var_dtype)
    beta_1 = tf.cast(self.beta_1, var_dtype)
    beta_2 = tf.cast(self.beta_2, var_dtype)
    u_t = beta_1 * (1. - 0.5 * (tf.pow(decay, local_step)))
    u_t_1 = beta_1 * (1. - 0.5 * (tf.pow(decay, next_step)))
    def get_cached_u_product():
      return self._u_product

    def compute_new_u_product():
      u_product_t = self._u_product * u_t
      self._u_product.assign(u_product_t)
      self._u_product_counter += 1
      return u_product_t

    u_product_t = tf.cond(
        self._u_product_counter == (self.iterations + 2),
        true_fn=get_cached_u_product,
        false_fn=compute_new_u_product)
    u_product_t_1 = u_product_t * u_t_1
    beta_2_power = tf.pow(beta_2, local_step)

    var_key = self._var_key(variable)
    m = self._momentums[self._index_dict[var_key]]
    v = self._velocities[self._index_dict[var_key]]

    if isinstance(gradient, tf.IndexedSlices):
      # Sparse gradients.
      m.assign_add(-m * (1 - beta_1))
      m.scatter_add(
          tf.IndexedSlices(gradient.values * (1 - beta_1),
                           gradient.indices))
      v.assign_add(-v * (1 - beta_2))
      v.scatter_add(
          tf.IndexedSlices(
              tf.square(gradient.values) * (1 - beta_2), gradient.indices))
      m_hat = (
          u_t_1 * m / (1 - u_product_t_1) + (1 - u_t) * gradient /
          (1 - u_product_t))
      v_hat = v / (1 - beta_2_power)

      variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))
    else:
      # Dense gradients.
      m.assign_add((gradient - m) * (1 - beta_1))
      v.assign_add((tf.square(gradient) - v) * (1 - beta_2))
      m_hat = (
          u_t_1 * m / (1 - u_product_t_1) + (1 - u_t) * gradient /
          (1 - u_product_t))
      v_hat = v / (1 - beta_2_power)

      variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))
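
A standalone sketch (toy variables and names invented here, not the optimizer internals above) of the caching idiom used for `u_product_t`: the running product is recomputed at most once per step and the cached value is reused when the counter already matches the step.

import tensorflow as tf

u_product = tf.Variable(1.0)
cached_step = tf.Variable(0, dtype=tf.int64)

def u_product_for_step(step, u_t):
  """Return the cached product if `step` was already processed, else update."""
  def use_cache():
    return u_product.value()
  def recompute():
    new_value = u_product * u_t
    u_product.assign(new_value)
    cached_step.assign(step)
    return new_value
  return tf.cond(tf.equal(cached_step, step), use_cache, recompute)
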
Example #5
def _do_scale(image, size):
  """Rescale the image by scaling the smaller spatial dimension to `size`."""
  shape = tf.cast(tf.shape(image), tf.float32)
  w_greater = tf.greater(shape[0], shape[1])
  shape = tf.cond(w_greater,
                  lambda: tf.cast([shape[0] / shape[1] * size, size], tf.int32),
                  lambda: tf.cast([size, shape[1] / shape[0] * size], tf.int32))

  return tf.image.resize([image], shape, method='bicubic')[0]
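
Toy usage (assuming `_do_scale` is in scope): the shorter side of a 20x40 image is rescaled to 10 while the aspect ratio is preserved.

import tensorflow as tf

image = tf.random.uniform([20, 40, 3])
scaled = _do_scale(image, 10)
# In eager execution: scaled.shape == (10, 20, 3)
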
Example #6
def process_source_id(source_id):
    """Processes source_id to the right format."""
    if source_id.dtype == tf.string:
        source_id = tf.cast(tf.strings.to_number(source_id), tf.int64)
    with tf.control_dependencies([source_id]):
        source_id = tf.cond(pred=tf.equal(tf.size(input=source_id), 0),
                            true_fn=lambda: tf.cast(tf.constant(-1), tf.int64),
                            false_fn=lambda: tf.identity(source_id))
    return source_id
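
Toy usage of `process_source_id` (assuming it is in scope): an empty id tensor is mapped to -1 and a numeric string is parsed to int64.

import tensorflow as tf

print(process_source_id(tf.constant([], dtype=tf.string)))  # -> -1
print(process_source_id(tf.constant('12345')))              # -> 12345
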
  def decode(self, serialized_example):
    """Decode the serialized example.

    Args:
      serialized_example: a single serialized tf.Example string.

    Returns:
      decoded_tensors: a dictionary of tensors with the following fields:
        - image: a uint8 tensor of shape [None, None, 3].
        - source_id: a string scalar tensor.
        - height: an integer scalar tensor.
        - width: an integer scalar tensor.
        - groundtruth_classes: an int64 tensor of shape [None].
        - groundtruth_is_crowd: a bool tensor of shape [None].
        - groundtruth_area: a float32 tensor of shape [None].
        - groundtruth_boxes: a float32 tensor of shape [None, 4].
        - groundtruth_instance_masks: a float32 tensor of shape
            [None, None, None].
        - groundtruth_instance_masks_png: a string tensor of shape [None].
    """
    parsed_tensors = tf.io.parse_single_example(
        serialized=serialized_example, features=self._keys_to_features)
    for k in parsed_tensors:
      if isinstance(parsed_tensors[k], tf.SparseTensor):
        if parsed_tensors[k].dtype == tf.string:
          parsed_tensors[k] = tf.sparse.to_dense(
              parsed_tensors[k], default_value='')
        else:
          parsed_tensors[k] = tf.sparse.to_dense(
              parsed_tensors[k], default_value=0)

    image = self._decode_image(parsed_tensors)
    boxes = self._decode_boxes(parsed_tensors)
    areas = self._decode_areas(parsed_tensors)
    is_crowds = tf.cond(
        tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
        lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
        lambda: tf.zeros_like(parsed_tensors['image/object/class/label'], dtype=tf.bool))  # pylint: disable=line-too-long
    if self._include_mask:
      masks = self._decode_masks(parsed_tensors)

    decoded_tensors = {
        'image': image,
        'source_id': parsed_tensors['image/source_id'],
        'height': parsed_tensors['image/height'],
        'width': parsed_tensors['image/width'],
        'groundtruth_classes': parsed_tensors['image/object/class/label'],
        'groundtruth_is_crowd': is_crowds,
        'groundtruth_area': areas,
        'groundtruth_boxes': boxes,
    }
    if self._include_mask:
      decoded_tensors.update({
          'groundtruth_instance_masks': masks,
          'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
      })
    return decoded_tensors
        def body_fn(i, written_count, current_vol, current_log_spot, vol_paths,
                    log_spot_paths):
            """Simulate Heston process to the next time point."""
            time_step = dt[i]

            if normal_draws is None:
                normals = random.mv_normal_sample(
                    (num_samples, ),
                    mean=tf.zeros([2], dtype=mean_reversion.dtype),
                    seed=seed)
            else:
                normals = normal_draws[i]

            def _next_vol_fn():
                return _update_variance(mean_reversion[i], theta[i], volvol[i],
                                        rho[i], current_vol, time_step,
                                        normals[..., 0])

            # Only update the variance if `time_step > tolerance`
            next_vol = tf.cond(time_step > tolerance, _next_vol_fn,
                               lambda: current_vol)

            def _next_log_spot_fn():
                return _update_log_spot(mean_reversion[i], theta[i], volvol[i],
                                        rho[i], current_vol, next_vol,
                                        current_log_spot, time_step,
                                        normals[..., 1])

            # Only update the log-spot state if `time_step > tolerance`
            next_log_spot = tf.cond(time_step > tolerance, _next_log_spot_fn,
                                    lambda: current_log_spot)

            if record_samples:
                # Update volatility paths
                vol_paths = vol_paths.write(written_count, next_vol)
                # Update log-spot paths
                log_spot_paths = log_spot_paths.write(written_count,
                                                      next_log_spot)
            else:
                vol_paths = next_vol
                log_spot_paths = next_log_spot
            written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
            return (i + 1, written_count, next_vol, next_log_spot, vol_paths,
                    log_spot_paths)
Example #9
 def body_fn(i, written_count, current_vol, current_log_spot, vol_paths,
             log_spot_paths):
   """Simulate Heston process to the next time point."""
   time_step = dt[i]
   if normal_draws is None:
     normals = random.mv_normal_sample(
         (num_samples,),
         mean=tf.zeros([3], dtype=kappa.dtype), seed=seed)
   else:
     normals = normal_draws[i]
   def _next_vol_fn():
     return _update_variance(
         kappa[i], theta[i], epsilon[i], rho[i],
         current_vol, time_step, normals[..., :2])
   # Only update the variance if `time_step > tolerance`
   next_vol = tf.cond(time_step > tolerance,
                      _next_vol_fn,
                      lambda: current_vol)
   def _next_log_spot_fn():
     return _update_log_spot(
         kappa[i], theta[i], epsilon[i], rho[i],
         current_vol, next_vol, current_log_spot, time_step,
         normals[..., -1])
   # Only update the log-spot state if `time_step > tolerance`
   next_log_spot = tf.cond(time_step > tolerance,
                           _next_log_spot_fn,
                           lambda: current_log_spot)
   # Update volatility paths
   vol_paths = utils.maybe_update_along_axis(
       tensor=vol_paths,
       do_update=keep_mask[i + 1],
       ind=written_count,
       axis=1,
       new_tensor=tf.expand_dims(next_vol, axis=1))
   # Update log-spot paths
   log_spot_paths = utils.maybe_update_along_axis(
       tensor=log_spot_paths,
       do_update=keep_mask[i + 1],
       ind=written_count,
       axis=1,
       new_tensor=tf.expand_dims(next_log_spot, axis=1))
   written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
   return (i + 1, written_count,
           next_vol, next_log_spot, vol_paths, log_spot_paths)
Example #10
    def _get_reset_state(self, observation, done, default_state):
        """Resets the state wherever marked in `done` tensor.

    Consider the following example with num_timesteps=2, batch_size=3,
    state_size=1:
      default_state (batch_size, state_size) = [[5.], [5.], [5.]]
      done (num_timesteps, batch_size) = [[True, True, False],
                                          [False, True, False]]
      observation (num_timesteps, batch_size, 1) = [[[1.], [2.], [3.]],
                                                    [[4.], [5.], [6.]]]
      self.get_initial_state implements `observation + 10`.
    Then the returned tensor will be of shape (num_timesteps, batch_size,
    state_size) and its value will be:
      [[[11.], [12.], [0.]],
       [[0.],  [15.], [0.]]]
    where state values are replaced by a call to `self.get_initial_state` wherever
    done=True. Note that the state values where done=False are set to zeros and
    are expected not to be used by the caller.

    Args:
      observation: A nested structure with individual tensors that have first
        two dimensions equal to [num_timesteps, batch_size].
      done: A boolean tensor of shape  [num_timesteps, batch_size].
      default_state: A tensor or nested structure with individual tensors that
        have first dimension equal to batch_size and no time dimension.

    Returns:
      A structure similar to `default_state` except that all tensors in the
      returned structure have an additional leading dimension equal to
      num_timesteps.
    """
        reset_indices = tf.compat.v1.where(tf.equal(done, True))

        def _get_reset_state_indices():
            reset_indices_obs = tf.nest.map_structure(
                lambda t: tf.gather_nd(t, reset_indices), observation)
            # shape: [num_indices_to_reset, ...]
            reset_indices_state = self.get_initial_state(
                reset_indices_obs, batch_size=tf.shape(reset_indices)[0])
            # Scatter tensors in `reset_indices_state` to shape: [num_timesteps,
            # batch_size, ...]
            return tf.nest.map_structure(
                lambda reset_tensor: tf.scatter_nd(
                    indices=reset_indices,
                    updates=reset_tensor,
                    shape=done.shape.as_list() +
                    reset_tensor.shape.as_list()[1:]),
                reset_indices_state)

        # A minor optimization wherein if all elements in `done` are False, we
        # simply return a structure with zeros tensors of correct shape.
        return tf.cond(
            tf.greater(tf.size(reset_indices), 0),
            _get_reset_state_indices,
            lambda: tf.nest.map_structure(
                lambda t: tf.zeros(
                    shape=done.shape.as_list() + t.shape.as_list()[1:],
                    dtype=t.dtype),
                default_state))
Example #11
    def state_y(self, t):
        """Computes the state variable `y(t)` for tha Gaussian HJM Model.

    For Gaussian HJM model, the state parameter y(t), can be analytically
    computed as follows:

    y_ij(t) = exp(-k_i * t) * exp(-k_j * t) * (
              int_0^t rho_ij * sigma_i(u) * sigma_j(u) * exp((k_i + k_j) * u) du)

    Args:
      t: A rank 1 real `Tensor` of shape `[num_times]` specifying the time `t`.

    Returns:
      A real `Tensor` of shape [self._factors, self._factors, num_times]
      containing the computed y_ij(t).
    """
        t = tf.convert_to_tensor(t, dtype=self._dtype)
        t_shape = tf.shape(t)
        t = tf.broadcast_to(t, tf.concat([[self._dim], t_shape], axis=0))
        time_index = tf.searchsorted(self._jump_locations, t)
        # create a matrix k2(i,j) = k(i) + k(j)
        mr2 = tf.expand_dims(self._mean_reversion, axis=-1)
        # Add a dimension corresponding to `num_times`
        mr2 = tf.expand_dims(mr2 + tf.transpose(mr2), axis=-1)

        def _integrate_volatility_squared(vol, l_limit, u_limit):
            # create sigma2_ij = sigma_i * sigma_j
            vol = tf.expand_dims(vol, axis=-2)
            vol_squared = tf.expand_dims(
                self._rho, axis=-1) * (vol * tf.transpose(vol, perm=[1, 0, 2]))
            return vol_squared / mr2 * (tf.math.exp(mr2 * u_limit) -
                                        tf.math.exp(mr2 * l_limit))

        is_constant_vol = tf.math.equal(tf.shape(self._jump_values_vol)[-1], 0)
        v_squared_between_vol_knots = tf.cond(
            is_constant_vol,
            lambda: tf.zeros(shape=(self._dim, self._dim, 0),
                             dtype=self._dtype),
            lambda: _integrate_volatility_squared(  # pylint: disable=g-long-lambda
                self._jump_values_vol, self._padded_knots, self._jump_locations
            ))
        v_squared_at_vol_knots = tf.concat(
            [tf.zeros((self._dim, self._dim, 1), dtype=self._dtype),
             utils.cumsum_using_matvec(v_squared_between_vol_knots)],
            axis=-1)

        vn = tf.concat([self._zero_padding, self._jump_locations], axis=1)

        v_squared_t = _integrate_volatility_squared(
            self._volatility(t), tf.gather(vn, time_index, batch_dims=1), t)
        v_squared_t += tf.gather(v_squared_at_vol_knots,
                                 time_index,
                                 batch_dims=-1)

        return tf.math.exp(-mr2 * t) * v_squared_t
Example #12
def _apply_func_with_prob(func: Any, image: tf.Tensor, args: Any, prob: float):
    """Apply `func` to image w/ `args` as input with probability `prob`."""
    assert isinstance(args, tuple)

    # Apply the function with probability `prob`.
    should_apply_op = tf.cast(
        tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
    augmented_image = tf.cond(should_apply_op, lambda: func(image, *args),
                              lambda: image)
    return augmented_image
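
Toy usage (assuming `_apply_func_with_prob` is in scope): flip an image left-right with probability 0.5, otherwise pass it through unchanged.

import tensorflow as tf

image = tf.random.uniform([8, 8, 3])
augmented = _apply_func_with_prob(
    lambda img: tf.image.flip_left_right(img), image, (), 0.5)
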
Example #13
    def alpha(self, value):
        value = tf.convert_to_tensor(value, self.dtype)

        def get_logit_alpha():
            a = tf.clip_by_value(value / 4., 0., 1.)
            logit_alpha = tf.math.log(a / (1. - a))
            return logit_alpha

        self._logit_alpha.assign(
            tf.cond(value < 0, lambda: self._logit_alpha, get_logit_alpha))
Example #14
def random_apply(func, p, x):
    """Randomly apply function func to x with probability p."""
    return tf.cond(
        tf.less(
            tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
            tf.cast(p, tf.float32),
        ),
        lambda: func(x),
        lambda: x,
    )
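
Toy usage of `random_apply`: convert an image to grayscale (tiled back to three channels so both branches produce the same shape) with probability 0.2.

import tensorflow as tf

image = tf.random.uniform([8, 8, 3])
maybe_gray = random_apply(
    lambda img: tf.tile(tf.image.rgb_to_grayscale(img), [1, 1, 3]), 0.2, image)
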
Example #15
  def update_state(self, labels, probabilities, **kwargs):
    """Updates this metric.

    This will flatten the labels and probabilities, and then compute the ECE
    over all predictions.

    Args:
      labels: Tensor of shape [..., ] of class labels in [0, k-1].
      probabilities: Tensor of shape [..., ], [..., 1] or [..., k] of normalized
        probabilities associated with the True class in the binary case, or with
        each of k classes in the multiclass case.
      **kwargs: Other potential keywords, which will be ignored by this method.
    """
    del kwargs  # unused
    labels = tf.convert_to_tensor(labels)
    probabilities = tf.cast(probabilities, self.dtype)

    # Flatten labels to [N, ] and probabilities to [N, 1] or [N, k].
    if tf.rank(labels) != 1:
      labels = tf.reshape(labels, [-1])
    if tf.rank(probabilities) != 2 or (tf.shape(probabilities)[0] !=
                                       tf.shape(labels)[0]):
      probabilities = tf.reshape(probabilities, [tf.shape(labels)[0], -1])
    # Extend any probabilities of shape [N, 1] to shape [N, 2].
    # NOTE: XLA does not allow for different shapes in the branches of a
    # conditional statement. Therefore, explicit indexing is used.
    given_k = tf.shape(probabilities)[-1]
    k = tf.math.maximum(2, given_k)
    probabilities = tf.cond(
        given_k < 2,
        lambda: tf.concat([1. - probabilities, probabilities], axis=-1)[:, -k:],
        lambda: probabilities)

    pred_labels = tf.math.argmax(probabilities, axis=-1)
    pred_probs = tf.math.reduce_max(probabilities, axis=-1)
    correct_preds = tf.math.equal(pred_labels,
                                  tf.cast(labels, pred_labels.dtype))
    correct_preds = tf.cast(correct_preds, self.dtype)

    bin_indices = tf.histogram_fixed_width_bins(
        pred_probs, tf.constant([0., 1.], self.dtype), nbins=self.num_bins)
    batch_correct_sums = tf.math.unsorted_segment_sum(
        data=tf.cast(correct_preds, self.dtype),
        segment_ids=bin_indices,
        num_segments=self.num_bins)
    batch_prob_sums = tf.math.unsorted_segment_sum(data=pred_probs,
                                                   segment_ids=bin_indices,
                                                   num_segments=self.num_bins)
    batch_counts = tf.math.unsorted_segment_sum(data=tf.ones_like(bin_indices),
                                                segment_ids=bin_indices,
                                                num_segments=self.num_bins)
    batch_counts = tf.cast(batch_counts, self.dtype)
    self.correct_sums.assign_add(batch_correct_sums)
    self.prob_sums.assign_add(batch_prob_sums)
    self.counts.assign_add(batch_counts)
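
A standalone sketch of the binning step used above: each predicted probability is assigned to one of `num_bins` equal-width bins and per-bin counts are accumulated with a segment sum.

import tensorflow as tf

num_bins = 5
pred_probs = tf.constant([0.05, 0.55, 0.58, 0.95])
bin_indices = tf.histogram_fixed_width_bins(
    pred_probs, tf.constant([0., 1.]), nbins=num_bins)   # [0, 2, 2, 4]
bin_counts = tf.math.unsorted_segment_sum(
    tf.ones_like(pred_probs), bin_indices,
    num_segments=num_bins)                                # [1., 0., 2., 0., 1.]
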
Example #16
        def update_if_finite_grads():
            """Update assuming the gradients are finite."""
            def incr_loss_scale():
                new_loss_scale = self.current_loss_scale * self.multiplier
                return tf.group(
                    _assign_if_finite(self.current_loss_scale, new_loss_scale),
                    self.counter.assign(0))

            return tf.cond(
                self.counter + 1 >= self.growth_steps, incr_loss_scale,
                lambda: _op_in_graph_mode(self.counter.assign_add(1)))
Example #17
def apply_randomization(features, label, randomize_prob):
  """Randomize each categorical feature with some probability."""
  rnd_tok = lambda: tf.as_string(tf.random.uniform([], 0, 99999999, tf.int32))

  for idx in CAT_FEATURE_INDICES:
    key = feature_name(idx)
    # Ignore lint since tf.cond should evaluate lambda immediately.
    features[key] = tf.cond(tf.random.uniform([]) < randomize_prob,
                            rnd_tok,
                            lambda: features[key])  # pylint: disable=cell-var-from-loop
  return features, label
        def body_fn(i, written_count, current_vol, current_log_spot, vol_paths,
                    log_spot_paths):
            """Simulate Heston process to the next time point."""
            time_step = dt[i]
            if normal_draws is None:
                normals = random.mv_normal_sample(
                    (num_samples, ),
                    mean=tf.zeros([3], dtype=kappa.dtype),
                    seed=seed)
            else:
                normals = normal_draws[i]

            def _next_vol_fn():
                return _update_variance(kappa[i], theta[i], epsilon[i], rho[i],
                                        current_vol, time_step,
                                        normals[..., :2])

            # Only update the variance if `time_step > tolerance`
            next_vol = tf.cond(time_step > tolerance, _next_vol_fn,
                               lambda: current_vol)

            def _next_log_spot_fn():
                return _update_log_spot(kappa[i], theta[i], epsilon[i], rho[i],
                                        current_vol, next_vol,
                                        current_log_spot, time_step,
                                        normals[..., -1])

            # Only update the log-spot state if `time_step > tolerance`
            next_log_spot = tf.cond(time_step > tolerance, _next_log_spot_fn,
                                    lambda: current_log_spot)
            vol_paths = tf.cond(
                keep_mask[i + 1],
                lambda: vol_paths.write(written_count, next_vol),
                lambda: vol_paths)
            log_spot_paths = tf.cond(
                keep_mask[i + 1],
                lambda: log_spot_paths.write(written_count, next_log_spot),
                lambda: log_spot_paths)
            written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
            return (i + 1, written_count, next_vol, next_log_spot, vol_paths,
                    log_spot_paths)
def resize_and_extract(image, target_size, random_centering):
    """Upscale image to target_size (>image.size), extract original size crop."""
    original_shape = image.shape
    size = tf.reshape(target_size, [1])
    size = tf.concat([size, size], axis=0)
    image = tf.image.resize(image, size=size)
    pad_size = target_size - original_shape[1]
    pad_size_left, pad_size_right = _make_padding_sizes(
        pad_size, random_centering)
    if len(original_shape) == 3:
        image = tf.expand_dims(image, 0)
    image = tf.cond(pad_size_right > 0,
                    lambda: image[:, pad_size_left:-pad_size_right, :, :],
                    lambda: image[:, pad_size_left:, :, :])
    image = tf.cond(pad_size_right > 0,
                    lambda: image[:, :, pad_size_left:-pad_size_right, :],
                    lambda: image[:, :, pad_size_left:, :])
    if len(original_shape) == 3:
        image = tf.squeeze(image, 0)
    image.set_shape(original_shape)
    return image
Example #20
 def __call__(self, step):
     with tf.name_scope(self.name or 'WarmUp') as name:
         # Implements linear warmup. i.e., if global_step < warmup_steps, the
         # learning rate will be `global_step/num_warmup_steps * init_lr`.
         global_step_float = tf.cast(step, tf.float32)
         warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
         warmup_percent_done = global_step_float / warmup_steps_float
         warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
         return tf.cond(global_step_float < warmup_steps_float,
                        lambda: warmup_learning_rate,
                        lambda: self.decay_schedule_fn(step),
                        name=name)
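
A minimal sketch of the same warmup pattern with concrete numbers (a toy schedule, not the class above): linear warmup for 1000 steps, then hand off to a cosine decay schedule.

import tensorflow as tf

decay = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3, decay_steps=10000)

def lr_at(step, warmup_steps=1000, init_lr=1e-3):
  step_f = tf.cast(step, tf.float32)
  warmup_f = tf.cast(warmup_steps, tf.float32)
  return tf.cond(step_f < warmup_f,
                 lambda: init_lr * step_f / warmup_f,
                 lambda: decay(step))

print(lr_at(100))   # warmup region: 1e-4
print(lr_at(5000))  # cosine decay region
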
Example #21
    def __call__(self, step: int):
        """Compute learning rate at given step."""
        def warmup_lr():
            return self._rescaled_lr * (
                step / tf.cast(self._warmup_steps, tf.float32))

        def piecewise_lr():
            return tf.compat.v1.train.piecewise_constant(
                tf.cast(step, tf.float32), self._step_boundaries,
                self._lr_values)

        return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr)
Example #22
    def __call__(self, step):
        with tf.name_scope(self.name or "SGDRDecay") as name:
            initial_learning_rate = tf.convert_to_tensor(
                self.initial_learning_rate, name="initial_learning_rate"
            )
            dtype = initial_learning_rate.dtype
            first_decay_steps = tf.cast(self.first_decay_steps, dtype)
            alpha = tf.cast(self.alpha, dtype)
            t_mul = tf.cast(self._t_mul, dtype)
            m_mul = tf.cast(self._m_mul, dtype)

            global_step_recomp = tf.cast(step, dtype)
            completed_fraction = global_step_recomp / first_decay_steps

            def compute_step(completed_fraction, geometric=False):
                """Helper for `cond` operation."""
                if geometric:
                    i_restart = tf.floor(
                        tf.math.log(1.0 - completed_fraction * (1.0 - t_mul))
                        / tf.math.log(t_mul)
                    )

                    sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
                    completed_fraction = (
                        completed_fraction - sum_r
                    ) / t_mul**i_restart

                else:
                    i_restart = tf.floor(completed_fraction)
                    completed_fraction -= i_restart

                return i_restart, completed_fraction

            i_restart, completed_fraction = tf.cond(
                tf.equal(t_mul, 1.0),
                lambda: compute_step(completed_fraction, geometric=False),
                lambda: compute_step(completed_fraction, geometric=True),
            )

            m_fac = m_mul**i_restart
            cosine_decayed = (
                0.5
                * m_fac
                * (
                    1.0
                    + tf.cos(
                        tf.constant(math.pi, dtype=dtype) * completed_fraction
                    )
                )
            )
            decayed = (1 - alpha) * cosine_decayed + alpha

            return tf.multiply(initial_learning_rate, decayed, name=name)
Example #23
def select_and_apply_random_policy(policies: Any, image: tf.Tensor):
    """Select a random policy from `policies` and apply it to `image`."""
    policy_to_select = tf.random.uniform([],
                                         maxval=len(policies),
                                         dtype=tf.int32)
    # Note that using tf.case instead of tf.conds would result in significantly
    # larger graphs and would even break export for some larger policies.
    for (i, policy) in enumerate(policies):
        image = tf.cond(tf.equal(i, policy_to_select),
                        lambda selected_policy=policy: selected_policy(image),
                        lambda: image)
    return image
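
Toy usage (assuming `select_and_apply_random_policy` is in scope): two candidate "policies", exactly one of which is applied per call.

import tensorflow as tf

policies = [
    lambda img: tf.image.flip_left_right(img),
    lambda img: img * 0.5,
]
image = tf.random.uniform([8, 8, 3])
augmented = select_and_apply_random_policy(policies, image)
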
Example #24
    def update(self, grads):
        """Updates the value of the loss scale.

    Args:
      grads: A nested structure of unscaled gradients, each of which is an
        all-reduced gradient of the loss with respect to a weight.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
        grads = tf.nest.flatten(grads)
        if (tf.distribute.has_strategy() and
                tf.distribute.in_cross_replica_context()):
            distribution = tf.distribute.get_strategy()
            is_finite_per_replica = distribution.extended.call_for_each_replica(
                _is_all_finite, args=(grads,))
            # Each replica computed the same `is_finite` value, since `grads`
            # is all-reduced across replicas. Arbitrarily take `is_finite`
            # from the first replica.
            is_finite = distribution.experimental_local_results(
                is_finite_per_replica)[0]
        else:
            is_finite = _is_all_finite(grads)

        def update_if_finite_grads():
            """Update assuming the gradients are finite."""
            def incr_loss_scale():
                new_loss_scale = self.current_loss_scale * self.multiplier
                return tf.group(
                    _assign_if_finite(self.current_loss_scale, new_loss_scale),
                    self.counter.assign(0))

            return tf.cond(
                self.counter + 1 >= self.growth_steps, incr_loss_scale,
                lambda: _op_in_graph_mode(self.counter.assign_add(1)))

        def update_if_not_finite_grads():
            """Update assuming the gradients are nonfinite."""

            new_loss_scale = tf.maximum(
                self.current_loss_scale / self.multiplier, 1)
            return tf.group(self.counter.assign(0),
                            self.current_loss_scale.assign(new_loss_scale))

        update_op = tf.cond(is_finite, update_if_finite_grads,
                            update_if_not_finite_grads)
        should_apply_gradients = is_finite
        return update_op, should_apply_gradients
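
A self-contained sketch of the same dynamic loss-scale rule without the distribution-strategy plumbing or the private helpers (`_is_all_finite` and friends are assumed, not shown): grow the scale after `growth_steps` consecutive finite updates, shrink it (with a floor of 1) as soon as a non-finite gradient appears.

import tensorflow as tf

loss_scale = tf.Variable(2.0 ** 15)
counter = tf.Variable(0)
growth_steps, multiplier = 2000, 2.0

def update(grads_are_finite):
  def on_finite():
    def grow():
      loss_scale.assign(loss_scale * multiplier)
      return counter.assign(0)
    return tf.cond(counter + 1 >= growth_steps,
                   grow,
                   lambda: counter.assign_add(1))
  def on_nonfinite():
    loss_scale.assign(tf.maximum(loss_scale / multiplier, 1.0))
    return counter.assign(0)
  return tf.cond(grads_are_finite, on_finite, on_nonfinite)

update(tf.constant(True))   # counter -> 1, scale unchanged
update(tf.constant(False))  # counter -> 0, scale halved
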
Example #25
File: lstm.py  Project: huaxz1986/keras
  def gpu_lstm_with_fallback(inputs, init_h, init_c, kernel, recurrent_kernel,
                             bias, mask, time_major, go_backwards,
                             sequence_lengths, zero_output_for_mask,
                             return_sequences):
    """Use cuDNN kernel when mask is none or strictly right padded."""
    if mask is None:
      return gpu_lstm(
          inputs=inputs,
          init_h=init_h,
          init_c=init_c,
          kernel=kernel,
          recurrent_kernel=recurrent_kernel,
          bias=bias,
          mask=mask,
          time_major=time_major,
          go_backwards=go_backwards,
          sequence_lengths=sequence_lengths,
          return_sequences=return_sequences)

    def cudnn_lstm_fn():
      return gpu_lstm(
          inputs=inputs,
          init_h=init_h,
          init_c=init_c,
          kernel=kernel,
          recurrent_kernel=recurrent_kernel,
          bias=bias,
          mask=mask,
          time_major=time_major,
          go_backwards=go_backwards,
          sequence_lengths=sequence_lengths,
          return_sequences=return_sequences)

    def standard_lstm_fn():
      return standard_lstm(
          inputs=inputs,
          init_h=init_h,
          init_c=init_c,
          kernel=kernel,
          recurrent_kernel=recurrent_kernel,
          bias=bias,
          mask=mask,
          time_major=time_major,
          go_backwards=go_backwards,
          sequence_lengths=sequence_lengths,
          zero_output_for_mask=zero_output_for_mask,
          return_sequences=return_sequences)

    return tf.cond(
        gru_lstm_utils.is_cudnn_supported_inputs(mask, time_major),
        true_fn=cudnn_lstm_fn,
        false_fn=standard_lstm_fn)
Example #26
        def body_fn(i, written_count, current_var, current_log_spot, vol_paths,
                    log_spot_paths):
            """Simulate Heston process to the next time point."""
            time_step = dt[i]

            def _next_vol_fn():
                return _update_variance(i, kappa[i], theta[i], epsilon[i],
                                        rho[i], current_var, time_step,
                                        num_samples, random_type, seed)

            # Only update the variance if `time_step > tolerance`
            next_vol = tf.cond(
                time_step > tolerance,
                lambda: _next_vol_fn(),  # pylint: disable=unnecessary-lambda
                lambda: current_var)

            def _next_log_spot_fn():
                return _update_log_spot(i, kappa[i], theta[i], epsilon[i],
                                        rho[i], current_var, next_vol,
                                        current_log_spot, time_step,
                                        num_samples, random_type, seed)

            # Only update the log-spot state if `time_step > tolerance`
            next_log_spot = tf.cond(
                time_step > tolerance,
                lambda: _next_log_spot_fn(),  # pylint: disable=unnecessary-lambda
                lambda: current_log_spot)

            vol_paths = tf.cond(
                keep_mask[i + 1],
                lambda: vol_paths.write(written_count, next_vol),
                lambda: vol_paths)
            log_spot_paths = tf.cond(
                keep_mask[i + 1],
                lambda: log_spot_paths.write(written_count, next_log_spot),
                lambda: log_spot_paths)
            written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
            return (i + 1, written_count, next_vol, next_log_spot, vol_paths,
                    log_spot_paths)
Example #27
 def __call__(self, step: int):
     lr = self._lr_schedule(step)
     if self._warmup_steps:
         initial_learning_rate = tf.convert_to_tensor(
             self._lr_schedule.initial_learning_rate,
             name="initial_learning_rate")
         dtype = initial_learning_rate.dtype
         global_step_recomp = tf.cast(step, dtype)
         warmup_steps = tf.cast(self._warmup_steps, dtype)
         warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
         lr = tf.cond(global_step_recomp < warmup_steps, lambda: warmup_lr,
                      lambda: lr)
     return lr
  def maybe_run_update_step(self):
    """Creates TensorFlow update op for compression."""

    def maybe_update_alpha():
      """Maybe update the alpha param.

      Checks if global_step is between begin_compression_step and
      end_compression_step, and if the current training step is a
      compression step.

      Returns:
        A boolean tensor indicating whether the training step is a compression step.
      """
      is_step_within_compression_range = tf.logical_and(
          tf.greater_equal(
              tf.cast(self._global_step, tf.int32),
              self._spec.begin_compression_step),
          tf.logical_or(
              tf.less_equal(
                  tf.cast(self._global_step, tf.int32),
                  self._spec.end_compression_step),
              tf.less(self._spec.end_compression_step, 0)))
      is_compression_step = tf.less_equal(
          tf.add(self.last_alpha_update_step, self._spec.compression_frequency),
          tf.cast(self._global_step, tf.int32))
      return tf.logical_and(is_step_within_compression_range,
                            is_compression_step)

    def no_update_op():
      pass

    def compressor_and_alpha_update_op_fn():
      return self._compressor_and_alpha_update_op()

    tf.cond(
        pred=maybe_update_alpha(),
        true_fn=compressor_and_alpha_update_op_fn,
        false_fn=no_update_op)
    return
Example #29
        def apply_transform(i, x):
            """Apply the i-th transformation."""
            def brightness_foo():
                if brightness == 0:
                    return x
                else:
                    return random_brightness(x,
                                             max_delta=brightness,
                                             impl=impl)

            def contrast_foo():
                if contrast == 0:
                    return x
                else:
                    return tf.image.random_contrast(x,
                                                    lower=1 - contrast,
                                                    upper=1 + contrast)

            def saturation_foo():
                if saturation == 0:
                    return x
                else:
                    return tf.image.random_saturation(x,
                                                      lower=1 - saturation,
                                                      upper=1 + saturation)

            def hue_foo():
                if hue == 0:
                    return x
                else:
                    return tf.image.random_hue(x, max_delta=hue)

            x = tf.cond(
                tf.less(i, 2),
                lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
                lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo),
            )
            return x
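
A toy dispatch in the same style as `apply_transform` above: a random index in [0, 4) selects exactly one of four ops through nested tf.cond calls.

import tensorflow as tf

x = tf.random.uniform([8, 8, 3])
i = tf.random.uniform([], maxval=4, dtype=tf.int32)
y = tf.cond(
    tf.less(i, 2),
    lambda: tf.cond(tf.less(i, 1), lambda: x + 0.1, lambda: x * 0.9),
    lambda: tf.cond(tf.less(i, 3), lambda: 1.0 - x,
                    lambda: tf.image.flip_left_right(x)))
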
Example #30
    def decode_batch_example(self, tfexample_data):
        """Decode multiple features batched in a single tf.Tensor.

    This function is used to decode features wrapped in
    `tfds.features.Sequence()`.
    By default, this function applies `decode_example` to each individual
    element using `tf.map_fn`. However, for optimization, features can
    override this method to apply a custom batch decoding.

    Args:
      tfexample_data: Same `tf.Tensor` inputs as `decode_example`, but with
        an additional first dimension for the sequence length.

    Returns:
      tensor_data: A tensor or dictionary of tensors, the output of the
        tf.data.Dataset object.
    """
        ex = tfexample_data

        # Note: This all works fine in Eager mode (without tf.function) because
        # tf.data pipelines are always executed in Graph mode.

        # Apply the decoding to each of the individual distributed features.
        decode_map_fn = functools.partial(
            tf.map_fn,
            self.decode_example,
            fn_output_signature=self.dtype,
            parallel_iterations=10,
            name='sequence_decode',
        )

        if (
                # input/output could potentially be a `dict` for custom feature
                # connectors. Empty length not supported for those for now.
                isinstance(ex, dict) or isinstance(self.shape, dict)
                or not _has_shape_ambiguity(in_shape=ex.shape,
                                            out_shape=self.shape)):
            return decode_map_fn(ex)
        else:
            # `tf.map_fn` cannot resolve ambiguity when decoding an empty sequence
            # with unknown output shape (e.g. decode images `tf.string`):
            # `(0,)` -> `(0, None, None, 3)`.
            # Instead, we arbitrarily set unknown shape to `0`:
            # `(0,)` -> `(0, 0, 0, 3)`
            return tf.cond(
                tf.equal(tf.shape(ex)[0], 0),  # Empty sequence
                lambda: _make_empty_seq_output(shape=self.shape,
                                               dtype=self.dtype),
                lambda: decode_map_fn(ex),
            )