def _decode_areas(parsed_tensors):
  xmin = parsed_tensors['image/object/bbox/xmin']
  xmax = parsed_tensors['image/object/bbox/xmax']
  ymin = parsed_tensors['image/object/bbox/ymin']
  ymax = parsed_tensors['image/object/bbox/ymax']
  return tf.cond(
      tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0),
      lambda: parsed_tensors['image/object/area'],
      lambda: (xmax - xmin) * (ymax - ymin))
def _stability_limit_tensor(total_count, dtype):
  limit = tf.cast(BATES_TOTAL_COUNT_STABILITY_LIMITS[dtype], dtype)
  return tf.cond(
      tf.math.reduce_any(total_count > limit),
      # pylint: disable=g-long-lambda
      lambda: tf.print(
          'WARNING: Bates PDF/CDF is unstable for `total_count` >',
          limit,
          output_stream=sys.stderr),
      tf.no_op)
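# The snippet above uses `tf.cond` purely for a side effect: one branch is a
# `tf.print` warning op, the other is `tf.no_op`. A self-contained sketch of
# the same warn-if-exceeded idiom (the threshold value here is made up for
# illustration):
import sys

import tensorflow as tf

def warn_if_exceeds(values, limit=100.):
  return tf.cond(
      tf.math.reduce_any(values > limit),
      lambda: tf.print('WARNING: value exceeds', limit,
                       output_stream=sys.stderr),
      tf.no_op)

warn_if_exceeds(tf.constant([1., 200.]))  # prints the warning to stderr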
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition statically."""
  v = get_static_value(pred)
  if v is None:
    return tf.cond(pred, true_fn, false_fn)
  if v:
    return true_fn()
  else:
    return false_fn()
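# Usage sketch for the static-aware `cond` above, assuming `get_static_value`
# behaves like `tf.get_static_value`. With a statically known predicate only
# the taken branch executes; inside a `tf.function`, a data-dependent
# predicate has no static value and falls back to a regular `tf.cond`.
import tensorflow as tf

get_static_value = tf.get_static_value

a = cond(tf.constant(True),
         lambda: tf.constant(1),
         lambda: tf.constant(2))  # true_fn() is called directly

@tf.function
def dynamic(p):
  return cond(p > 0.5,            # static value is None while tracing
              lambda: tf.constant(1),
              lambda: tf.constant(2))

dynamic(tf.random.uniform([]))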
def update_step(self, gradient, variable):
  """Update step given gradient and the associated model variable."""
  var_dtype = variable.dtype
  lr = tf.cast(self.learning_rate, var_dtype)
  local_step = tf.cast(self.iterations + 1, var_dtype)
  next_step = tf.cast(self.iterations + 2, var_dtype)
  decay = tf.cast(0.96, var_dtype)
  beta_1 = tf.cast(self.beta_1, var_dtype)
  beta_2 = tf.cast(self.beta_2, var_dtype)
  u_t = beta_1 * (1. - 0.5 * (tf.pow(decay, local_step)))
  u_t_1 = beta_1 * (1. - 0.5 * (tf.pow(decay, next_step)))

  def get_cached_u_product():
    return self._u_product

  def compute_new_u_product():
    u_product_t = self._u_product * u_t
    self._u_product.assign(u_product_t)
    self._u_product_counter += 1
    return u_product_t

  u_product_t = tf.cond(
      self._u_product_counter == (self.iterations + 2),
      true_fn=get_cached_u_product,
      false_fn=compute_new_u_product)
  u_product_t_1 = u_product_t * u_t_1

  beta_2_power = tf.pow(beta_2, local_step)
  var_key = self._var_key(variable)
  m = self._momentums[self._index_dict[var_key]]
  v = self._velocities[self._index_dict[var_key]]

  if isinstance(gradient, tf.IndexedSlices):
    # Sparse gradients.
    m.assign_add(-m * (1 - beta_1))
    m.scatter_add(
        tf.IndexedSlices(gradient.values * (1 - beta_1), gradient.indices))
    v.assign_add(-v * (1 - beta_2))
    v.scatter_add(
        tf.IndexedSlices(
            tf.square(gradient.values) * (1 - beta_2), gradient.indices))
    m_hat = (
        u_t_1 * m / (1 - u_product_t_1) +
        (1 - u_t) * gradient / (1 - u_product_t))
    v_hat = v / (1 - beta_2_power)
    variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))
  else:
    # Dense gradients.
    m.assign_add((gradient - m) * (1 - beta_1))
    v.assign_add((tf.square(gradient) - v) * (1 - beta_2))
    m_hat = (
        u_t_1 * m / (1 - u_product_t_1) +
        (1 - u_t) * gradient / (1 - u_product_t))
    v_hat = v / (1 - beta_2_power)
    variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))
def _do_scale(image, size):
  """Rescale the image by scaling the smaller spatial dimension to `size`."""
  shape = tf.cast(tf.shape(image), tf.float32)
  w_greater = tf.greater(shape[0], shape[1])
  shape = tf.cond(
      w_greater,
      lambda: tf.cast([shape[0] / shape[1] * size, size], tf.int32),
      lambda: tf.cast([size, shape[1] / shape[0] * size], tf.int32))
  return tf.image.resize([image], shape, method='bicubic')[0]
def process_source_id(source_id):
  """Processes source_id to the right format."""
  if source_id.dtype == tf.string:
    source_id = tf.cast(tf.strings.to_number(source_id), tf.int64)
  with tf.control_dependencies([source_id]):
    source_id = tf.cond(
        pred=tf.equal(tf.size(input=source_id), 0),
        true_fn=lambda: tf.cast(tf.constant(-1), tf.int64),
        false_fn=lambda: tf.identity(source_id))
  return source_id
def decode(self, serialized_example):
  """Decode the serialized example.

  Args:
    serialized_example: a single serialized tf.Example string.

  Returns:
    decoded_tensors: a dictionary of tensors with the following fields:
      - image: a uint8 tensor of shape [None, None, 3].
      - source_id: a string scalar tensor.
      - height: an integer scalar tensor.
      - width: an integer scalar tensor.
      - groundtruth_classes: an int64 tensor of shape [None].
      - groundtruth_is_crowd: a bool tensor of shape [None].
      - groundtruth_area: a float32 tensor of shape [None].
      - groundtruth_boxes: a float32 tensor of shape [None, 4].
      - groundtruth_instance_masks: a float32 tensor of shape
          [None, None, None].
      - groundtruth_instance_masks_png: a string tensor of shape [None].
  """
  parsed_tensors = tf.io.parse_single_example(
      serialized=serialized_example, features=self._keys_to_features)
  for k in parsed_tensors:
    if isinstance(parsed_tensors[k], tf.SparseTensor):
      if parsed_tensors[k].dtype == tf.string:
        parsed_tensors[k] = tf.sparse.to_dense(
            parsed_tensors[k], default_value='')
      else:
        parsed_tensors[k] = tf.sparse.to_dense(
            parsed_tensors[k], default_value=0)

  image = self._decode_image(parsed_tensors)
  boxes = self._decode_boxes(parsed_tensors)
  areas = self._decode_areas(parsed_tensors)
  is_crowds = tf.cond(
      tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
      lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
      lambda: tf.zeros_like(
          parsed_tensors['image/object/class/label'], dtype=tf.bool))
  if self._include_mask:
    masks = self._decode_masks(parsed_tensors)

  decoded_tensors = {
      'image': image,
      'source_id': parsed_tensors['image/source_id'],
      'height': parsed_tensors['image/height'],
      'width': parsed_tensors['image/width'],
      'groundtruth_classes': parsed_tensors['image/object/class/label'],
      'groundtruth_is_crowd': is_crowds,
      'groundtruth_area': areas,
      'groundtruth_boxes': boxes,
  }
  if self._include_mask:
    decoded_tensors.update({
        'groundtruth_instance_masks': masks,
        'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
    })
  return decoded_tensors
def body_fn(i, written_count, current_vol, current_log_spot, vol_paths,
            log_spot_paths):
  """Simulate Heston process to the next time point."""
  time_step = dt[i]
  if normal_draws is None:
    normals = random.mv_normal_sample(
        (num_samples,),
        mean=tf.zeros([2], dtype=mean_reversion.dtype),
        seed=seed)
  else:
    normals = normal_draws[i]

  def _next_vol_fn():
    return _update_variance(
        mean_reversion[i], theta[i], volvol[i], rho[i],
        current_vol, time_step, normals[..., 0])

  # Update variance only if `time_step > tolerance`
  next_vol = tf.cond(time_step > tolerance,
                     _next_vol_fn,
                     lambda: current_vol)

  def _next_log_spot_fn():
    return _update_log_spot(
        mean_reversion[i], theta[i], volvol[i], rho[i],
        current_vol, next_vol, current_log_spot,
        time_step, normals[..., 1])

  # Update the log-spot state only if `time_step > tolerance`
  next_log_spot = tf.cond(time_step > tolerance,
                          _next_log_spot_fn,
                          lambda: current_log_spot)

  if record_samples:
    # Update volatility paths
    vol_paths = vol_paths.write(written_count, next_vol)
    # Update log-spot paths
    log_spot_paths = log_spot_paths.write(written_count, next_log_spot)
  else:
    vol_paths = next_vol
    log_spot_paths = next_log_spot
  written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
  return (i + 1, written_count, next_vol, next_log_spot,
          vol_paths, log_spot_paths)
def body_fn(i, written_count, current_vol, current_log_spot, vol_paths,
            log_spot_paths):
  """Simulate Heston process to the next time point."""
  time_step = dt[i]
  if normal_draws is None:
    normals = random.mv_normal_sample(
        (num_samples,), mean=tf.zeros([3], dtype=kappa.dtype), seed=seed)
  else:
    normals = normal_draws[i]

  def _next_vol_fn():
    return _update_variance(
        kappa[i], theta[i], epsilon[i], rho[i],
        current_vol, time_step, normals[..., :2])

  # Update variance only if `time_step > tolerance`
  next_vol = tf.cond(time_step > tolerance,
                     _next_vol_fn,
                     lambda: current_vol)

  def _next_log_spot_fn():
    return _update_log_spot(
        kappa[i], theta[i], epsilon[i], rho[i],
        current_vol, next_vol, current_log_spot,
        time_step, normals[..., -1])

  # Update the log-spot state only if `time_step > tolerance`
  next_log_spot = tf.cond(time_step > tolerance,
                          _next_log_spot_fn,
                          lambda: current_log_spot)

  # Update volatility paths
  vol_paths = utils.maybe_update_along_axis(
      tensor=vol_paths,
      do_update=keep_mask[i + 1],
      ind=written_count,
      axis=1,
      new_tensor=tf.expand_dims(next_vol, axis=1))
  # Update log-spot paths
  log_spot_paths = utils.maybe_update_along_axis(
      tensor=log_spot_paths,
      do_update=keep_mask[i + 1],
      ind=written_count,
      axis=1,
      new_tensor=tf.expand_dims(next_log_spot, axis=1))
  written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
  return (i + 1, written_count, next_vol, next_log_spot,
          vol_paths, log_spot_paths)
def _get_reset_state(self, observation, done, default_state):
  """Resets the state wherever marked in `done` tensor.

  Consider the following example with num_timesteps=2, batch_size=3,
  state_size=1:
    default_state (batch_size, state_size) = [[5.], [5.], [5.]]
    done (num_timesteps, batch_size) = [[True, True, False],
                                        [False, True, False]]
    observation (num_timesteps, batch_size, 1) = [[[1.], [2.], [3.]],
                                                  [[4.], [5.], [6.]]]
    self.get_initial_state implements `observation + 10`.

  Then the returned tensor will be of shape (num_timesteps, batch_size,
  state_size) and its value will be:
    [[[11.], [12.], [0.]],
     [[0.], [15.], [0.]]]
  where state values are replaced by the output of `self.get_initial_state`
  wherever done=True. Note that the state values where done=False are set to
  zeros and are expected not to be used by the caller.

  Args:
    observation: A nested structure with individual tensors that have first
      two dimensions equal to [num_timesteps, batch_size].
    done: A boolean tensor of shape [num_timesteps, batch_size].
    default_state: A tensor or nested structure with individual tensors that
      have first dimension equal to batch_size and no time dimension.

  Returns:
    A structure similar to `default_state` except that all tensors in the
    returned structure have an additional leading dimension equal to
    num_timesteps.
  """
  reset_indices = tf.compat.v1.where(tf.equal(done, True))

  def _get_reset_state_indices():
    reset_indices_obs = tf.nest.map_structure(
        lambda t: tf.gather_nd(t, reset_indices), observation)
    # shape: [num_indices_to_reset, ...]
    reset_indices_state = self.get_initial_state(
        reset_indices_obs, batch_size=tf.shape(reset_indices)[0])
    # Scatter tensors in `reset_indices_state` to shape: [num_timesteps,
    # batch_size, ...]
    return tf.nest.map_structure(
        lambda reset_tensor: tf.scatter_nd(
            indices=reset_indices,
            updates=reset_tensor,
            shape=done.shape.as_list() + reset_tensor.shape.as_list()[1:]),
        reset_indices_state)

  # A minor optimization wherein if all elements in `done` are False, we
  # simply return a structure with zeros tensors of correct shape.
  return tf.cond(
      tf.greater(tf.size(reset_indices), 0),
      _get_reset_state_indices,
      lambda: tf.nest.map_structure(
          lambda t: tf.zeros(
              shape=done.shape.as_list() + t.shape.as_list()[1:],
              dtype=t.dtype),
          default_state))
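# Standalone illustration of the `tf.scatter_nd` step above, reproducing the
# docstring example: values computed for the done=True positions are
# scattered back into a zeros tensor of the full
# [num_timesteps, batch_size, state_size] shape.
import tensorflow as tf

done = tf.constant([[True, True, False], [False, True, False]])
reset_indices = tf.where(done)                # [[0, 0], [0, 1], [1, 1]]
updates = tf.constant([[11.], [12.], [15.]])  # one row per True entry
full = tf.scatter_nd(reset_indices, updates, shape=[2, 3, 1])
# full == [[[11.], [12.], [0.]], [[0.], [15.], [0.]]]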
def state_y(self, t):
  """Computes the state variable `y(t)` for the Gaussian HJM model.

  For the Gaussian HJM model, the state parameter y(t) can be analytically
  computed as follows:

  y_ij(t) = exp(-k_i * t) * exp(-k_j * t) * (
      int_0^t rho_ij * sigma_i(u) * sigma_j(u) * exp((k_i + k_j) * u) * du)

  Args:
    t: A rank 1 real `Tensor` of shape `[num_times]` specifying the time `t`.

  Returns:
    A real `Tensor` of shape [self._factors, self._factors, num_times]
    containing the computed y_ij(t).
  """
  t = tf.convert_to_tensor(t, dtype=self._dtype)
  t_shape = tf.shape(t)
  t = tf.broadcast_to(t, tf.concat([[self._dim], t_shape], axis=0))
  time_index = tf.searchsorted(self._jump_locations, t)
  # Create a matrix k2(i, j) = k(i) + k(j)
  mr2 = tf.expand_dims(self._mean_reversion, axis=-1)
  # Add a dimension corresponding to `num_times`
  mr2 = tf.expand_dims(mr2 + tf.transpose(mr2), axis=-1)

  def _integrate_volatility_squared(vol, l_limit, u_limit):
    # Create sigma2_ij = sigma_i * sigma_j
    vol = tf.expand_dims(vol, axis=-2)
    vol_squared = tf.expand_dims(self._rho, axis=-1) * (
        vol * tf.transpose(vol, perm=[1, 0, 2]))
    return vol_squared / mr2 * (
        tf.math.exp(mr2 * u_limit) - tf.math.exp(mr2 * l_limit))

  is_constant_vol = tf.math.equal(tf.shape(self._jump_values_vol)[-1], 0)
  v_squared_between_vol_knots = tf.cond(
      is_constant_vol,
      lambda: tf.zeros(shape=(self._dim, self._dim, 0), dtype=self._dtype),
      lambda: _integrate_volatility_squared(  # pylint: disable=g-long-lambda
          self._jump_values_vol, self._padded_knots, self._jump_locations))
  v_squared_at_vol_knots = tf.concat(
      [tf.zeros((self._dim, self._dim, 1), dtype=self._dtype),
       utils.cumsum_using_matvec(v_squared_between_vol_knots)],
      axis=-1)

  vn = tf.concat([self._zero_padding, self._jump_locations], axis=1)
  v_squared_t = _integrate_volatility_squared(
      self._volatility(t), tf.gather(vn, time_index, batch_dims=1), t)
  v_squared_t += tf.gather(v_squared_at_vol_knots, time_index, batch_dims=-1)

  return tf.math.exp(-mr2 * t) * v_squared_t
def _apply_func_with_prob(func: Any, image: tf.Tensor, args: Any, prob: float):
  """Apply `func` to image w/ `args` as input with probability `prob`."""
  assert isinstance(args, tuple)

  # Apply the function with probability `prob`.
  should_apply_op = tf.cast(
      tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
  augmented_image = tf.cond(
      should_apply_op,
      lambda: func(image, *args),
      lambda: image)
  return augmented_image
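# Why the floor trick above is a Bernoulli(prob) draw: with u ~ U[0, 1),
# floor(u + prob) equals 1 exactly when u >= 1 - prob, which happens with
# probability prob. A quick empirical check (illustrative only):
import tensorflow as tf

u = tf.random.uniform([100000])
print(float(tf.reduce_mean(tf.floor(u + 0.3))))  # ~0.3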
def alpha(self, value):
  value = tf.convert_to_tensor(value, self.dtype)

  def get_logit_alpha():
    a = tf.clip_by_value(value / 4., 0., 1.)
    logit_alpha = tf.math.log(a / (1. - a))
    return logit_alpha

  self._logit_alpha.assign(
      tf.cond(value < 0, lambda: self._logit_alpha, get_logit_alpha))
def random_apply(func, p, x):
  """Randomly apply function func to x with probability p."""
  return tf.cond(
      tf.less(
          tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
          tf.cast(p, tf.float32)),
      lambda: func(x),
      lambda: x)
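# Example call for `random_apply` above: flip an image with probability 0.5.
import tensorflow as tf

image = tf.random.uniform([8, 8, 3])
maybe_flipped = random_apply(tf.image.flip_left_right, p=0.5, x=image)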
def update_state(self, labels, probabilities, **kwargs):
  """Updates this metric.

  This will flatten the labels and probabilities, and then compute the ECE
  over all predictions.

  Args:
    labels: Tensor of shape [..., ] of class labels in [0, k-1].
    probabilities: Tensor of shape [..., ], [..., 1] or [..., k] of
      normalized probabilities associated with the True class in the binary
      case, or with each of k classes in the multiclass case.
    **kwargs: Other potential keywords, which will be ignored by this method.
  """
  del kwargs  # unused
  labels = tf.convert_to_tensor(labels)
  probabilities = tf.cast(probabilities, self.dtype)

  # Flatten labels to [N, ] and probabilities to [N, 1] or [N, k].
  if tf.rank(labels) != 1:
    labels = tf.reshape(labels, [-1])
  if tf.rank(probabilities) != 2 or (tf.shape(probabilities)[0] !=
                                     tf.shape(labels)[0]):
    probabilities = tf.reshape(probabilities, [tf.shape(labels)[0], -1])
  # Extend any probabilities of shape [N, 1] to shape [N, 2].
  # NOTE: XLA does not allow for different shapes in the branches of a
  # conditional statement. Therefore, explicit indexing is used.
  given_k = tf.shape(probabilities)[-1]
  k = tf.math.maximum(2, given_k)
  probabilities = tf.cond(
      given_k < 2,
      lambda: tf.concat([1. - probabilities, probabilities], axis=-1)[:, -k:],
      lambda: probabilities)

  pred_labels = tf.math.argmax(probabilities, axis=-1)
  pred_probs = tf.math.reduce_max(probabilities, axis=-1)
  correct_preds = tf.math.equal(pred_labels,
                                tf.cast(labels, pred_labels.dtype))
  correct_preds = tf.cast(correct_preds, self.dtype)

  bin_indices = tf.histogram_fixed_width_bins(
      pred_probs, tf.constant([0., 1.], self.dtype), nbins=self.num_bins)
  batch_correct_sums = tf.math.unsorted_segment_sum(
      data=tf.cast(correct_preds, self.dtype),
      segment_ids=bin_indices,
      num_segments=self.num_bins)
  batch_prob_sums = tf.math.unsorted_segment_sum(
      data=pred_probs, segment_ids=bin_indices, num_segments=self.num_bins)
  batch_counts = tf.math.unsorted_segment_sum(
      data=tf.ones_like(bin_indices),
      segment_ids=bin_indices,
      num_segments=self.num_bins)
  batch_counts = tf.cast(batch_counts, self.dtype)
  self.correct_sums.assign_add(batch_correct_sums)
  self.prob_sums.assign_add(batch_prob_sums)
  self.counts.assign_add(batch_counts)
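# Standalone check of the binary padding step above: a [N, 1] tensor of
# p(positive class) is extended to the [N, 2] columns [1 - p, p] before
# taking argmax/max, so the binary and multiclass cases share one code path.
import tensorflow as tf

p = tf.constant([[0.8], [0.3]])
both = tf.concat([1. - p, p], axis=-1)[:, -2:]
# both == [[0.2, 0.8], [0.7, 0.3]]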
def update_if_finite_grads():
  """Update assuming the gradients are finite."""

  def incr_loss_scale():
    new_loss_scale = self.current_loss_scale * self.multiplier
    return tf.group(
        _assign_if_finite(self.current_loss_scale, new_loss_scale),
        self.counter.assign(0))

  return tf.cond(
      self.counter + 1 >= self.growth_steps,
      incr_loss_scale,
      lambda: _op_in_graph_mode(self.counter.assign_add(1)))
def apply_randomization(features, label, randomize_prob):
  """Randomize each categorical feature with some probability."""
  rnd_tok = lambda: tf.as_string(tf.random.uniform([], 0, 99999999, tf.int32))
  for idx in CAT_FEATURE_INDICES:
    key = feature_name(idx)
    # Ignore lint since tf.cond should evaluate lambda immediately.
    features[key] = tf.cond(
        tf.random.uniform([]) < randomize_prob,
        rnd_tok,
        lambda: features[key])  # pylint: disable=cell-var-from-loop
  return features, label
def body_fn(i, written_count, current_vol, current_log_spot, vol_paths,
            log_spot_paths):
  """Simulate Heston process to the next time point."""
  time_step = dt[i]
  if normal_draws is None:
    normals = random.mv_normal_sample(
        (num_samples,), mean=tf.zeros([3], dtype=kappa.dtype), seed=seed)
  else:
    normals = normal_draws[i]

  def _next_vol_fn():
    return _update_variance(
        kappa[i], theta[i], epsilon[i], rho[i],
        current_vol, time_step, normals[..., :2])

  # Update variance only if `time_step > tolerance`
  next_vol = tf.cond(time_step > tolerance,
                     _next_vol_fn,
                     lambda: current_vol)

  def _next_log_spot_fn():
    return _update_log_spot(
        kappa[i], theta[i], epsilon[i], rho[i],
        current_vol, next_vol, current_log_spot,
        time_step, normals[..., -1])

  # Update the log-spot state only if `time_step > tolerance`
  next_log_spot = tf.cond(time_step > tolerance,
                          _next_log_spot_fn,
                          lambda: current_log_spot)

  # Write to the output arrays only at the time points marked in `keep_mask`.
  vol_paths = tf.cond(
      keep_mask[i + 1],
      lambda: vol_paths.write(written_count, next_vol),
      lambda: vol_paths)
  log_spot_paths = tf.cond(
      keep_mask[i + 1],
      lambda: log_spot_paths.write(written_count, next_log_spot),
      lambda: log_spot_paths)
  written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
  return (i + 1, written_count, next_vol, next_log_spot,
          vol_paths, log_spot_paths)
def resize_and_extract(image, target_size, random_centering):
  """Upscale image to target_size (>image.size), extract original size crop."""
  original_shape = image.shape
  size = tf.reshape(target_size, [1])
  size = tf.concat([size, size], axis=0)
  image = tf.image.resize(image, size=size)
  pad_size = target_size - original_shape[1]
  pad_size_left, pad_size_right = _make_padding_sizes(
      pad_size, random_centering)
  if len(original_shape) == 3:
    image = tf.expand_dims(image, 0)
  image = tf.cond(pad_size_right > 0,
                  lambda: image[:, pad_size_left:-pad_size_right, :, :],
                  lambda: image[:, pad_size_left:, :, :])
  image = tf.cond(pad_size_right > 0,
                  lambda: image[:, :, pad_size_left:-pad_size_right, :],
                  lambda: image[:, :, pad_size_left:, :])
  if len(original_shape) == 3:
    image = tf.squeeze(image, 0)
  image.set_shape(original_shape)
  return image
def __call__(self, step):
  with tf.name_scope(self.name or 'WarmUp') as name:
    # Implements linear warmup, i.e., if global_step < warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    global_step_float = tf.cast(step, tf.float32)
    warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
    warmup_percent_done = global_step_float / warmup_steps_float
    warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
    return tf.cond(
        global_step_float < warmup_steps_float,
        lambda: warmup_learning_rate,
        lambda: self.decay_schedule_fn(step),
        name=name)
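# Hedged wiring sketch for the warmup wrapper above: `WarmUp` is assumed to
# be a `tf.keras.optimizers.schedules.LearningRateSchedule` subclass whose
# constructor stores the attributes used in `__call__` (argument names below
# are inferred from those attributes). Past `warmup_steps`, the wrapped decay
# schedule takes over.
import tensorflow as tf

decay_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, end_learning_rate=0.)
schedule = WarmUp(initial_learning_rate=1e-4,
                  decay_schedule_fn=decay_fn,
                  warmup_steps=1000)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)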
def __call__(self, step: int):
  """Compute learning rate at given step."""

  def warmup_lr():
    return self._rescaled_lr * (
        step / tf.cast(self._warmup_steps, tf.float32))

  def piecewise_lr():
    return tf.compat.v1.train.piecewise_constant(
        tf.cast(step, tf.float32), self._step_boundaries, self._lr_values)

  return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr)
def __call__(self, step):
    with tf.name_scope(self.name or "SGDRDecay") as name:
        initial_learning_rate = tf.convert_to_tensor(
            self.initial_learning_rate, name="initial_learning_rate"
        )
        dtype = initial_learning_rate.dtype
        first_decay_steps = tf.cast(self.first_decay_steps, dtype)
        alpha = tf.cast(self.alpha, dtype)
        t_mul = tf.cast(self._t_mul, dtype)
        m_mul = tf.cast(self._m_mul, dtype)

        global_step_recomp = tf.cast(step, dtype)
        completed_fraction = global_step_recomp / first_decay_steps

        def compute_step(completed_fraction, geometric=False):
            """Helper for `cond` operation."""
            if geometric:
                i_restart = tf.floor(
                    tf.math.log(1.0 - completed_fraction * (1.0 - t_mul))
                    / tf.math.log(t_mul)
                )
                sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
                completed_fraction = (
                    completed_fraction - sum_r
                ) / t_mul**i_restart
            else:
                i_restart = tf.floor(completed_fraction)
                completed_fraction -= i_restart
            return i_restart, completed_fraction

        i_restart, completed_fraction = tf.cond(
            tf.equal(t_mul, 1.0),
            lambda: compute_step(completed_fraction, geometric=False),
            lambda: compute_step(completed_fraction, geometric=True),
        )

        m_fac = m_mul**i_restart
        cosine_decayed = (
            0.5
            * m_fac
            * (
                1.0
                + tf.cos(
                    tf.constant(math.pi, dtype=dtype) * completed_fraction
                )
            )
        )
        decayed = (1 - alpha) * cosine_decayed + alpha

        return tf.multiply(initial_learning_rate, decayed, name=name)
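# The logic above is SGDR (cosine decay with warm restarts), which Keras
# ships as `tf.keras.optimizers.schedules.CosineDecayRestarts`: restart
# periods grow by `t_mul`, each restart's peak shrinks by `m_mul`, and
# `alpha` sets the floor as a fraction of the initial rate. An illustrative
# instantiation:
import tensorflow as tf

schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=0.1, first_decay_steps=1000,
    t_mul=2.0, m_mul=0.9, alpha=1e-4)
print(float(schedule(0)), float(schedule(999)))  # 0.1 at start, near the floor at step 999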
def select_and_apply_random_policy(policies: Any, image: tf.Tensor):
  """Select a random policy from `policies` and apply it to `image`."""
  policy_to_select = tf.random.uniform(
      [], maxval=len(policies), dtype=tf.int32)
  # Note that using tf.case instead of tf.conds would result in significantly
  # larger graphs and would even break export for some larger policies.
  for (i, policy) in enumerate(policies):
    image = tf.cond(
        tf.equal(i, policy_to_select),
        lambda selected_policy=policy: selected_policy(image),
        lambda: image)
  return image
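# The `selected_policy=policy` default argument above is load-bearing: Python
# closures capture loop variables late, so plain lambdas would all call the
# last policy. A pure-Python illustration:
fns = [lambda: i for i in range(3)]
print([f() for f in fns])            # [2, 2, 2] -- late binding
fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])            # [0, 1, 2] -- bound at definition time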
def update(self, grads):
  """Updates the value of the loss scale.

  Args:
    grads: A nested structure of unscaled gradients, each of which is an
      all-reduced gradient of the loss with respect to a weight.

  Returns:
    update_op: In eager mode, None. In graph mode, an op to update the loss
      scale.
    should_apply_gradients: Either a bool or a scalar boolean tensor. If
      False, the caller should skip applying `grads` to the variables this
      step.
  """
  grads = tf.nest.flatten(grads)
  if tf.distribute.has_strategy() and tf.distribute.in_cross_replica_context():
    distribution = tf.distribute.get_strategy()
    is_finite_per_replica = distribution.extended.call_for_each_replica(
        _is_all_finite, args=(grads,))
    # Each replica computed the same `is_finite` value, since `grads` is
    # all-reduced across replicas. Arbitrarily take `is_finite` from the
    # first replica.
    is_finite = (
        distribution.experimental_local_results(is_finite_per_replica)[0])
  else:
    is_finite = _is_all_finite(grads)

  def update_if_finite_grads():
    """Update assuming the gradients are finite."""

    def incr_loss_scale():
      new_loss_scale = self.current_loss_scale * self.multiplier
      return tf.group(
          _assign_if_finite(self.current_loss_scale, new_loss_scale),
          self.counter.assign(0))

    return tf.cond(
        self.counter + 1 >= self.growth_steps,
        incr_loss_scale,
        lambda: _op_in_graph_mode(self.counter.assign_add(1)))

  def update_if_not_finite_grads():
    """Update assuming the gradients are nonfinite."""
    new_loss_scale = tf.maximum(
        self.current_loss_scale / self.multiplier, 1)
    return tf.group(
        self.counter.assign(0),
        self.current_loss_scale.assign(new_loss_scale))

  update_op = tf.cond(
      is_finite, update_if_finite_grads, update_if_not_finite_grads)
  should_apply_gradients = is_finite
  return update_op, should_apply_gradients
def gpu_lstm_with_fallback(inputs, init_h, init_c, kernel, recurrent_kernel,
                           bias, mask, time_major, go_backwards,
                           sequence_lengths, zero_output_for_mask,
                           return_sequences):
  """Use cuDNN kernel when mask is none or strictly right padded."""
  if mask is None:
    return gpu_lstm(
        inputs=inputs,
        init_h=init_h,
        init_c=init_c,
        kernel=kernel,
        recurrent_kernel=recurrent_kernel,
        bias=bias,
        mask=mask,
        time_major=time_major,
        go_backwards=go_backwards,
        sequence_lengths=sequence_lengths,
        return_sequences=return_sequences)

  def cudnn_lstm_fn():
    return gpu_lstm(
        inputs=inputs,
        init_h=init_h,
        init_c=init_c,
        kernel=kernel,
        recurrent_kernel=recurrent_kernel,
        bias=bias,
        mask=mask,
        time_major=time_major,
        go_backwards=go_backwards,
        sequence_lengths=sequence_lengths,
        return_sequences=return_sequences)

  def standard_lstm_fn():
    return standard_lstm(
        inputs=inputs,
        init_h=init_h,
        init_c=init_c,
        kernel=kernel,
        recurrent_kernel=recurrent_kernel,
        bias=bias,
        mask=mask,
        time_major=time_major,
        go_backwards=go_backwards,
        sequence_lengths=sequence_lengths,
        zero_output_for_mask=zero_output_for_mask,
        return_sequences=return_sequences)

  return tf.cond(
      gru_lstm_utils.is_cudnn_supported_inputs(mask, time_major),
      true_fn=cudnn_lstm_fn,
      false_fn=standard_lstm_fn)
def body_fn(i, written_count, current_var, current_log_spot, vol_paths,
            log_spot_paths):
  """Simulate Heston process to the next time point."""
  time_step = dt[i]

  def _next_vol_fn():
    return _update_variance(
        i, kappa[i], theta[i], epsilon[i], rho[i],
        current_var, time_step, num_samples, random_type, seed)

  # Update variance only if `time_step > tolerance`
  next_vol = tf.cond(time_step > tolerance,
                     _next_vol_fn,
                     lambda: current_var)

  def _next_log_spot_fn():
    return _update_log_spot(
        i, kappa[i], theta[i], epsilon[i], rho[i],
        current_var, next_vol, current_log_spot,
        time_step, num_samples, random_type, seed)

  # Update the log-spot state only if `time_step > tolerance`
  next_log_spot = tf.cond(time_step > tolerance,
                          _next_log_spot_fn,
                          lambda: current_log_spot)

  # Write to the output arrays only at the time points marked in `keep_mask`.
  vol_paths = tf.cond(
      keep_mask[i + 1],
      lambda: vol_paths.write(written_count, next_vol),
      lambda: vol_paths)
  log_spot_paths = tf.cond(
      keep_mask[i + 1],
      lambda: log_spot_paths.write(written_count, next_log_spot),
      lambda: log_spot_paths)
  written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
  return (i + 1, written_count, next_vol, next_log_spot,
          vol_paths, log_spot_paths)
def __call__(self, step: int):
  lr = self._lr_schedule(step)
  if self._warmup_steps:
    initial_learning_rate = tf.convert_to_tensor(
        self._lr_schedule.initial_learning_rate,
        name="initial_learning_rate")
    dtype = initial_learning_rate.dtype
    global_step_recomp = tf.cast(step, dtype)
    warmup_steps = tf.cast(self._warmup_steps, dtype)
    warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
    lr = tf.cond(global_step_recomp < warmup_steps,
                 lambda: warmup_lr,
                 lambda: lr)
  return lr
def maybe_run_update_step(self):
  """Creates TensorFlow update op for compression."""

  def maybe_update_alpha():
    """Maybe update the alpha param.

    Checks if global_step is between begin_compression_step and
    end_compression_step, and if the current training step is a
    compression step.

    Returns:
      Boolean tensor whether the training step is a compression step.
    """
    is_step_within_compression_range = tf.logical_and(
        tf.greater_equal(
            tf.cast(self._global_step, tf.int32),
            self._spec.begin_compression_step),
        tf.logical_or(
            tf.less_equal(
                tf.cast(self._global_step, tf.int32),
                self._spec.end_compression_step),
            tf.less(self._spec.end_compression_step, 0)))
    is_compression_step = tf.less_equal(
        tf.add(self.last_alpha_update_step,
               self._spec.compression_frequency),
        tf.cast(self._global_step, tf.int32))
    return tf.logical_and(is_step_within_compression_range,
                          is_compression_step)

  def no_update_op():
    pass

  def compressor_and_alpha_update_op_fn():
    return self._compressor_and_alpha_update_op()

  tf.cond(
      pred=maybe_update_alpha(),
      true_fn=compressor_and_alpha_update_op_fn,
      false_fn=no_update_op)
  return
def apply_transform(i, x):
  """Apply the i-th transformation."""

  def brightness_foo():
    if brightness == 0:
      return x
    else:
      return random_brightness(x, max_delta=brightness, impl=impl)

  def contrast_foo():
    if contrast == 0:
      return x
    else:
      return tf.image.random_contrast(
          x, lower=1 - contrast, upper=1 + contrast)

  def saturation_foo():
    if saturation == 0:
      return x
    else:
      return tf.image.random_saturation(
          x, lower=1 - saturation, upper=1 + saturation)

  def hue_foo():
    if hue == 0:
      return x
    else:
      return tf.image.random_hue(x, max_delta=hue)

  x = tf.cond(
      tf.less(i, 2),
      lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
      lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
  return x
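# Sketch of how `apply_transform` above is typically driven (SimCLR-style
# color jitter; assumes `x`, `brightness`, `contrast`, `saturation`, `hue`
# and `impl` are in scope): the four ops are applied in a random order.
perm = tf.random.shuffle(tf.range(4))
for i in range(4):
  x = apply_transform(perm[i], x)
x = tf.clip_by_value(x, 0., 1.)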
def decode_batch_example(self, tfexample_data):
  """Decode multiple features batched in a single tf.Tensor.

  This function is used to decode features wrapped in
  `tfds.features.Sequence()`. By default, this function applies
  `decode_example` to each individual element using `tf.map_fn`. However,
  for optimization, features can overwrite this method to apply a custom
  batch decoding.

  Args:
    tfexample_data: Same `tf.Tensor` inputs as `decode_example`, but with an
      additional first dimension for the sequence length.

  Returns:
    tensor_data: Tensor or dictionary of tensors, output of the
      tf.data.Dataset object.
  """
  ex = tfexample_data
  # Note: This all works fine in Eager mode (without tf.function) because
  # tf.data pipelines are always executed in Graph mode.

  # Apply the decoding to each of the individual distributed features.
  decode_map_fn = functools.partial(
      tf.map_fn,
      self.decode_example,
      fn_output_signature=self.dtype,
      parallel_iterations=10,
      name='sequence_decode',
  )

  if (
      # input/output could potentially be a `dict` for custom feature
      # connectors. Empty length not supported for those for now.
      isinstance(ex, dict)
      or isinstance(self.shape, dict)
      or not _has_shape_ambiguity(in_shape=ex.shape, out_shape=self.shape)
  ):
    return decode_map_fn(ex)
  else:
    # `tf.map_fn` cannot resolve ambiguity when decoding an empty sequence
    # with unknown output shape (e.g. decode images `tf.string`):
    # `(0,)` -> `(0, None, None, 3)`.
    # Instead, we arbitrarily set unknown shapes to `0`:
    # `(0,)` -> `(0, 0, 0, 3)`
    return tf.cond(
        tf.equal(tf.shape(ex)[0], 0),  # Empty sequence
        lambda: _make_empty_seq_output(shape=self.shape, dtype=self.dtype),
        lambda: decode_map_fn(ex),
    )