def convolve(image, pixel_filter, channels=3, name=None):
    """Apply a 2D pixel filter to each channel of an image independently.

    Arguments:
      image: A 3D `float32` `Tensor` of shape `[height, width, channels]`,
        where `channels` matches the `channels` argument and the first
        two dimensions are arbitrary.
      pixel_filter: A 2D `Tensor` of per-pixel kernel weights. It is
        expanded into the 4D kernel that `tf.nn.conv2d` expects by
        multiplying it with an identity matrix over the two channel
        axes, so every output channel only observes neighboring pixels
        of the same input channel.
      channels: An integer number of channels in the image (e.g., 3 for
        RGB).
      name: Optional name forwarded to `tf.name_scope`, with
        `"convolve"` supplied as the second argument.

    Returns:
      A 3D `float32` `Tensor` of the same shape as the input.
    """
    with tf.name_scope(name, "convolve"):
        tf.assert_type(image, tf.float32)
        # Identity over the two channel axes keeps channels from mixing.
        channel_identity = tf.eye(channels)
        # [fh, fw] -> [fh, fw, 1, 1]
        spatial_weights = tf.expand_dims(tf.expand_dims(pixel_filter, -1), -1)
        # [c, c] -> [1, 1, c, c]
        channel_weights = tf.expand_dims(
            tf.expand_dims(channel_identity, 0), 0)
        kernel = spatial_weights * channel_weights
        # conv2d operates on batches, so wrap the image in a batch of one...
        batched_image = tf.stack([image])
        batched_output = tf.nn.conv2d(
            batched_image,
            filter=kernel,
            strides=[1, 1, 1, 1],
            padding="SAME",
        )
        # ...and strip the batch dimension again on the way out.
        return batched_output[0]
def _get_grid_locations(image_height, image_width):
    """Build a grid of (y, x) index pairs via `tf.meshgrid`.

    Args:
      image_height: Number of rows; checked to be of type `int32`.
      image_width: Number of columns; checked to be of type `int32`.

    Returns:
      The 'ij'-indexed row and column grids stacked along a new last
      axis, so the final axis holds `(y, x)` for each grid location.
    """
    tfv1.assert_type(image_height, tf.int32)
    tfv1.assert_type(image_width, tf.int32)

    rows = tf.range(image_height)
    cols = tf.range(image_width)
    # 'ij' indexing: the first grid varies along axis 0 (y), the second
    # along axis 1 (x).
    grids = tf.meshgrid(rows, cols, indexing='ij')
    return tf.stack(grids, -1)
def _buckets(data, bucket_count=None):
    """Create a TensorFlow op that groups data into histogram buckets.

    Arguments:
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
        Defaults to `summary_v2.DEFAULT_BUCKET_COUNT`.

    Returns:
      A `Tensor` of shape `[k, 3]` and type `float64`. Row `i` is the
      triple `[left_edge, right_edge, count]` of one bucket. `k` is `0`
      for empty data, `1` when every value is identical (a single
      unit-width bucket centered on that value), and `bucket_count`
      otherwise (equal-width buckets spanning `[min, max]`).
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if bucket_count is None:
        bucket_count = summary_v2.DEFAULT_BUCKET_COUNT

    with tf.name_scope('buckets', values=[data, bucket_count]):
        preconditions = [
            tf.assert_scalar(bucket_count),
            tf.assert_type(bucket_count, tf.int32),
        ]
        with tf.control_dependencies(preconditions):
            # Work on a flat float64 view of the data.
            flat = tf.cast(tf.reshape(data, shape=[-1]), tf.float64)
            has_no_data = tf.equal(tf.size(input=flat), 0)

            def _no_buckets():
                # Zero rows, but still three columns.
                return tf.constant([], shape=(0, 3), dtype=tf.float64)

            def _single_bucket(center):
                # All values coincide: one bucket of width 1 around them.
                lefts = tf.stack([center - 0.5])
                rights = tf.stack([center + 0.5])
                counts = tf.stack(
                    [tf.cast(tf.size(input=flat), tf.float64)])
                return tf.transpose(a=tf.stack([lefts, rights, counts]))

            def _even_buckets(low, high, span):
                width = span / tf.cast(bucket_count, tf.float64)
                raw_index = tf.cast(
                    tf.floor((flat - low) / width), dtype=tf.int32)
                # The maximum lands exactly on the upper edge; fold it
                # into the last bucket.
                index = tf.minimum(raw_index, bucket_count - 1)
                membership = tf.one_hot(index, depth=bucket_count)
                counts = tf.cast(
                    tf.reduce_sum(input_tensor=membership, axis=0),
                    dtype=tf.float64)
                edges = tf.linspace(low, high, bucket_count + 1)
                return tf.transpose(
                    a=tf.stack([edges[:-1], edges[1:], counts]))

            def _some_buckets():
                low = tf.reduce_min(input_tensor=flat)
                high = tf.reduce_max(input_tensor=flat)
                span = high - low
                return tf.cond(
                    tf.equal(span, 0),
                    lambda: _single_bucket(low),
                    lambda: _even_buckets(low, high, span))

            return tf.cond(has_no_data, _no_buckets, _some_buckets)
def op(name, images, max_outputs=3, display_name=None, description=None,
       collections=None):
    """Create a legacy image summary op for use in a TensorFlow graph.

    Arguments:
      name: A unique name for the generated summary node.
      images: A `uint8` `Tensor` of pixel data with shape `[k, h, w, c]`,
        where `k` is the number of images, `h` and `w` are their height
        and width, and `c` is the number of channels, which should be 1,
        3, or 4. Any of the dimensions may be statically unknown (i.e.,
        `None`).
      max_outputs: Optional non-negative `int` or rank-0 integer
        `Tensor`. Only the first `max_outputs` images are encoded; any
        further images are silently dropped.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new
        summary op is added to these collections. Defaults to
        `[Graph Keys.SUMMARIES]`.

    Returns:
      A TensorFlow summary op. The summarized tensor is a string vector:
      the image width, the image height, then one PNG-encoded string per
      retained image.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    summary_metadata = metadata.create_summary_metadata(
        display_name=name if display_name is None else display_name,
        description=description)

    with tf.name_scope(name):
        preconditions = [
            tf.assert_rank(images, 4),
            tf.assert_type(images, tf.uint8),
            tf.assert_non_negative(max_outputs),
        ]
        with tf.control_dependencies(preconditions):
            retained = images[:max_outputs]
            png_strings = tf.map_fn(tf.image.encode_png,
                                    retained,
                                    dtype=tf.string,
                                    name='encode_each_image')
            # Dimensions are taken from the full input and stored
            # width-first.
            shape = tf.shape(images)
            width_str = tf.as_string(shape[2], name='width')
            height_str = tf.as_string(shape[1], name='height')
            dimensions = tf.stack([width_str, height_str],
                                  name='dimensions')
            payload = tf.concat([dimensions, png_strings], axis=0)
            return tf.summary.tensor_summary(
                name='image_summary',
                tensor=payload,
                collections=collections,
                summary_metadata=summary_metadata)
def nllfun(x, alpha, scale):
    r"""Compute the negative log-likelihood (NLL) of the robust distribution.

    Specifically, this evaluates -log(p(x | 0, \alpha, c)) of Equation 16
    in the paper as nllfun(x, alpha, scale): the exact general loss plus
    the log of the partition function.

    Args:
      x: The residual for which the NLL is being computed. x can have any
        shape, and alpha and scale will be broadcasted to match x's shape
        if necessary. Must be a tensorflow tensor or numpy array of
        floats.
      alpha: The shape parameter of the NLL (\alpha in the paper). Must be
        non-negative (checked at run time) and have the same dtype as
        `x`. Varying alpha over [0, 2] interpolates smoothly between a
        Cauchy distribution (alpha = 0) and a Normal distribution
        (alpha = 2), similar to a Student's T distribution. The gradient
        of the NLL with respect to alpha has singularities at 0 and 2,
        so you may want to limit usage to (0, 2) during gradient
        descent.
      scale: The scale parameter of the loss. Must be strictly positive
        (checked at run time) and have the same dtype as `x`. When
        |x| < scale, the NLL is like that of a (possibly unnormalized)
        normal distribution, and when |x| > scale the NLL takes on a
        different shape according to alpha.

    Returns:
      The NLLs for each element of x, in the same shape as x. This is
      returned as a TensorFlow graph node of floats with the same
      precision as x.
    """
    # Both parameters must share the dtype of `x`.
    tf.assert_type(scale, x.dtype)
    tf.assert_type(alpha, x.dtype)

    # Run-time range checks: scale > 0 and alpha >= 0.
    scale_is_positive = tf.Assert(tf.reduce_all(scale > 0.), [scale])
    alpha_is_nonnegative = tf.Assert(tf.reduce_all(alpha >= 0.), [alpha])

    with tf.control_dependencies([scale_is_positive, alpha_is_nonnegative]):
        robust_loss = general.lossfun(x, alpha, scale, approximate=False)
        # log(Z) = log(scale) + log of the scale-free partition function.
        log_normalizer = (
            tf.math.log(scale) + log_base_partition_function(alpha))
        return robust_loss + log_normalizer
def _buckets(data, bucket_count=None):
    """Create a TensorFlow op that lays raw data values out as buckets.

    No binning is performed. After flattening, the `i`th value of `data`
    is used directly as the count of the unit-width bucket
    `[i, i + 1]`, for at most `bucket_count` values.

    NOTE(review): this replaces the usual min/max histogram binning; the
    original author's comment said this function still needs to be
    revisited. Confirm this layout is what downstream consumers expect.

    Arguments:
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
        Defaults to `summary_v2.DEFAULT_BUCKET_COUNT`.

    Returns:
      A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
      the triple `[i, i + 1, data[i]]`. `k` is the smaller of
      `bucket_count` and the number of elements in `data` (so `0` for
      empty data).
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if bucket_count is None:
        bucket_count = summary_v2.DEFAULT_BUCKET_COUNT
    with tf.name_scope('buckets', values=[data, bucket_count]), \
            tf.control_dependencies([tf.assert_scalar(bucket_count),
                                     tf.assert_type(bucket_count, tf.int32)]):
        data = tf.reshape(data, shape=[-1])  # flatten
        data = tf.cast(data, tf.float64)
        # `tf.Tensor` has no `.size()` method; `tf.size` yields the
        # element count as a scalar tensor.
        data_size = tf.size(input=data)
        is_empty = tf.equal(data_size, 0)

        def when_empty():
            return tf.constant([], shape=(0, 3), dtype=tf.float64)

        def when_nonempty():
            # Pack the input values straight into buckets: one bucket per
            # value, capped at `bucket_count` buckets.
            bucket_num = tf.minimum(bucket_count, data_size)
            # Integer positions 0..bucket_num-1 as float64, so they can be
            # stacked with the float64 counts. (`tf.linspace` does not
            # accept integer endpoints.)
            left_edges = tf.cast(tf.range(bucket_num), tf.float64)
            right_edges = left_edges + 1.0
            # Stack the three row vectors into [3, k], then transpose to
            # [k, 3].
            return tf.transpose(a=tf.stack(
                [left_edges, right_edges, data[:bucket_num]]))

        return tf.cond(is_empty, when_empty, when_nonempty)
def op(name, data, display_name=None, description=None, collections=None):
    """Create a legacy text summary op.

    Text data summarized via this plugin will be visible in the Text
    Dashboard in TensorBoard. The standard TensorBoard Text Dashboard
    will render markdown in the strings, and will automatically organize
    1D and 2D tensors into tables. If a tensor with more than 2
    dimensions is provided, a 2D subarray will be displayed along with a
    warning message. (Note that this behavior is not intrinsic to the
    text summary API, but rather to the default TensorBoard text plugin.)

    Args:
      name: A name for the generated node. Will also serve as a series
        name in TensorBoard.
      data: A string-type Tensor to summarize. The text must be encoded
        in UTF-8.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of ops.GraphKeys. The collections to
        which to add the summary. Defaults to [Graph Keys.SUMMARIES].

    Returns:
      A TensorSummary op that is configured so that TensorBoard will
      recognize that it contains textual data. The TensorSummary is a
      scalar `Tensor` of type `string` which contains `Summary`
      protobufs.

    Raises:
      ValueError: If tensor has the wrong type.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    summary_metadata = metadata.create_summary_metadata(
        display_name=name if display_name is None else display_name,
        description=description,
    )
    with tf.name_scope(name):
        # The summary is only produced once `data` is confirmed to be a
        # string tensor.
        is_string = tf.assert_type(data, tf.string)
        with tf.control_dependencies([is_string]):
            text_summary = tf.summary.tensor_summary(
                name="text_summary",
                tensor=data,
                collections=collections,
                summary_metadata=summary_metadata,
            )
            return text_summary
def interpolate1d(x, values, tangents):
    r"""Perform cubic hermite spline interpolation on a 1D spline.

    The x coordinates of the spline knots are at [0 : 1 : len(values)-1].
    Queries outside of the range of the spline are computed using linear
    extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline
    for details, where "x" corresponds to `x`, "p" corresponds to
    `values`, and "m" corresponds to `tangents`.

    Args:
      x: A tensor of any size of single or double precision floats
        containing the set of values to be used for interpolation into
        the spline.
      values: A vector of single or double precision floats containing
        the value of each knot of the spline being interpolated into.
        Must be the same length as `tangents` and the same type as `x`.
      tangents: A vector of single or double precision floats containing
        the tangent (derivative) of each knot of the spline being
        interpolated into. Must be the same length as `values` and the
        same type as `x`.

    Returns:
      The result of interpolating along the spline defined by `values`,
      and `tangents`, using `x` as the query values. Will be the same
      length and type as `x`.
    """
    # `values` and `tangents` must have the same type as `x`.
    tf.assert_type(values, x.dtype)
    tf.assert_type(tangents, x.dtype)
    float_dtype = x.dtype
    assert_ops = [
        # `values` must be a vector.
        tf.Assert(tf.equal(tf.rank(values), 1), [tf.shape(values)]),
        # `tangents` must be a vector. Report the shape of `tangents`
        # itself so a failure shows the offending tensor.
        tf.Assert(tf.equal(tf.rank(tangents), 1), [tf.shape(tangents)]),
        # `values` and `tangents` must have the same length.
        tf.Assert(
            tf.equal(tf.shape(values)[0], tf.shape(tangents)[0]),
            [tf.shape(values)[0], tf.shape(tangents)[0]]),
    ]
    with tf.control_dependencies(assert_ops):
        # Find the indices of the knots below and above each x. Clipping
        # to [0, len(values) - 2] keeps `x_hi` a valid index even for
        # out-of-range queries.
        x_lo = tf.cast(
            tf.floor(
                tf.clip_by_value(
                    x, 0., tf.cast(tf.shape(values)[0] - 2, float_dtype))),
            tf.int32)
        x_hi = x_lo + 1

        # Compute the relative distance between each `x` and the knot
        # below it. Because of the clipping above, t < 0 for queries left
        # of the first knot and t > 1 for queries right of the last knot.
        t = x - tf.cast(x_lo, float_dtype)

        # Compute the cubic hermite expansion of `t`.
        t_sq = tf.square(t)
        t_cu = t * t_sq
        h01 = -2. * t_cu + 3. * t_sq
        h00 = 1. - h01
        h11 = t_cu - t_sq
        h10 = h11 - t_sq + t

        # Linearly extrapolate above and below the extents of the spline
        # for all values; the right result is selected per element below.
        value_before = tangents[0] * t + values[0]
        value_after = tangents[-1] * (t - 1.) + values[-1]

        # Cubically interpolate between the knots below and above each
        # query point.
        neighbor_values_lo = tf.gather(values, x_lo)
        neighbor_values_hi = tf.gather(values, x_hi)
        neighbor_tangents_lo = tf.gather(tangents, x_lo)
        neighbor_tangents_hi = tf.gather(tangents, x_hi)
        value_mid = (
            neighbor_values_lo * h00 + neighbor_values_hi * h01 +
            neighbor_tangents_lo * h10 + neighbor_tangents_hi * h11)

        # Return the interpolated or extrapolated values for each query
        # point, depending on whether or not the query lies within the
        # span of the spline.
        return tf.where(t < 0., value_before,
                        tf.where(t > 1., value_after, value_mid))
def op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create a PR curve summary op for a single binary classifier.

    Computes true/false positive/negative values for the given
    `predictions` against the ground truth `labels`, against a list of
    evenly distributed threshold values in `[0, 1]` of length
    `num_thresholds`.

    Each number in `predictions`, a float in `[0, 1]`, is compared with
    its corresponding boolean label in `labels`, and counts as a single
    tp/fp/tn/fn value at each threshold. This is then multiplied with
    `weights` which can be used to reweight certain values, or more
    commonly used for masking values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for
        organization.
      labels: The ground truth values. A Tensor of `bool` values with
        arbitrary shape.
      predictions: A float32 `Tensor` whose values are in the range
        `[0, 1]`; values outside that range are clipped into it.
        Dimensions must match those of `labels`.
      num_thresholds: Number of thresholds, evenly distributed in
        `[0, 1]`, to compute PR metrics for. Should be `>= 2`. This value
        should be a constant integer value, not a Tensor that stores an
        integer. Defaults to `_DEFAULT_NUM_THRESHOLDS`.
      weights: Optional float32 `Tensor`. Individual counts are
        multiplied by this value. This tensor must be either the same
        shape as or broadcastable to the `labels` tensor. Defaults to
        `1.0`.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new
        summary op is added to these collections. Defaults to
        `[Graph Keys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. The float32
      tensor produced by the summary operation is of dimension
      (6, num_thresholds). The first dimension (of length 6) is of the
      order: true positives, false positives, true negatives, false
      negatives, precision, recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS
    if weights is None:
        weights = 1.0

    dtype = predictions.dtype

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tf.assert_type(labels, tf.bool)
        # Casting the boolean labels yields exactly 0.0 or 1.0.
        label_values = tf.cast(labels, dtype)
        # Force predictions into [0.0, 1.0].
        clipped = tf.minimum(1.0, tf.maximum(0.0, predictions))
        # Per-example weight credited to the positive / negative class.
        positive_weight = label_values * weights
        negative_weight = (1.0 - label_values) * weights

        # Flatten predictions; make the weights column vectors so they
        # broadcast against the one-hot bucket rows below.
        flat_predictions = tf.reshape(clipped, [-1])
        positive_weight = tf.reshape(positive_weight, [-1, 1])
        negative_weight = tf.reshape(negative_weight, [-1, 1])

        # At threshold t the classifier is C(t) = (predictions >= t), so
        #   TP(t) = sum(C(t) * positive_weight)
        #   FP(t) = sum(C(t) * negative_weight)
        # Evaluating C(t) separately for every t is wasteful. Since the
        # thresholds are evenly spaced,
        #   width = 1.0 / (num_thresholds - 1)
        #   thresholds = [0.0, 1*width, 2*width, ..., 1.0]
        # each prediction p falls into bucket
        #   floor(p * (num_thresholds - 1))
        # and C(t_i) is the sum over all buckets j >= i, i.e. a reversed
        # cumulative sum of the per-bucket totals.
        bucket_indices = tf.cast(
            tf.floor(flat_predictions * (num_thresholds - 1)), tf.int32
        )

        def _bucket_totals(per_example_weight):
            # One-hot row per prediction, scaled by its weight, then
            # summed over predictions to give a total per bucket.
            return tf.reduce_sum(
                input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds)
                * per_example_weight,
                axis=0,
            )

        tp_buckets = _bucket_totals(positive_weight)
        fp_buckets = _bucket_totals(negative_weight)

        # Reversed cumulative sums turn bucket totals into counts at or
        # above each threshold.
        tp = tf.cumsum(tp_buckets, reverse=True, name="tp")
        fp = tf.cumsum(fp_buckets, reverse=True, name="fp")
        # Index 0 holds the grand total for each class, so the
        # complementary counts follow by subtraction:
        #   tn = sum(fp_buckets) - fp = fp[0] - fp
        #   fn = sum(tp_buckets) - tp = tp[0] - tp
        tn = fp[0] - fp
        fn = tp[0] - tp

        # The denominators are floored at _MINIMUM_COUNT to avoid
        # dividing by zero.
        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        return _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )
def draw_samples(alpha, scale):
    r"""Draw samples from the robust distribution.

    This function implements Algorithm 1 of the paper. The code is
    written to allow sampling from a set of different distributions,
    each parametrized by its own alpha and scale values, as opposed to
    the more standard approach of drawing N samples from the same
    distribution. This is done by repeatedly performing N instances of
    rejection sampling for each of the N distributions until at least
    one proposal for each of the N distributions has been accepted.
    All samples are drawn with a zero mean; to use a non-zero mean just
    add each mean to each sample.

    Args:
      alpha: A TF tensor/scalar or numpy array/scalar of floats where
        each element is the shape parameter of that element's
        distribution. Must be non-negative (checked at run time).
      scale: A TF tensor/scalar or numpy array/scalar of floats where
        each element is the scale parameter of that element's
        distribution. Must be strictly positive and have the same shape
        and dtype as `alpha`.

    Returns:
      A TF tensor with the same shape and precision as `alpha` and
      `scale` where each element is a sample drawn from the distribution
      specified for that element by `alpha` and `scale`.
    """
    # `scale` must have the same type as `alpha`.
    float_dtype = alpha.dtype
    tf.assert_type(scale, float_dtype)

    assert_ops = [
        # `scale` must be > 0.
        tf.Assert(tf.reduce_all(scale > 0.), [scale]),
        # `alpha` must be >= 0.
        tf.Assert(tf.reduce_all(alpha >= 0.), [alpha]),
        # `alpha` and `scale` must have the same shape.
        tf.Assert(
            tf.reduce_all(tf.equal(tf.shape(alpha), tf.shape(scale))),
            [tf.shape(alpha), tf.shape(scale)]),
    ]

    with tf.control_dependencies(assert_ops):
        shape = tf.shape(alpha)

        # Distributions used by the rejection sampler. The sqrt(2) scale
        # on the Cauchy proposal corrects for the differing conventions
        # for standardization.
        proposal = tfp.distributions.Cauchy(loc=0., scale=tf.sqrt(2.))
        unit_uniform = tfp.distributions.Uniform(low=0., high=1.)

        def keep_going(_, accepted):
            """Continue until every element has an accepted sample."""
            return tf.logical_not(tf.reduce_all(accepted))

        def propose_and_filter(samples, accepted):
            """Draw one proposal per element and accept or reject each."""
            # Proposals from the Cauchy distribution.
            candidate = tf.cast(proposal.sample(shape), float_dtype)
            # NLL of each proposal under its (unit-scale) target
            # distribution.
            candidate_nll = nllfun(
                candidate, alpha, tf.cast(1, float_dtype))
            # Bound on the NLL. The exact loss is used here, as the
            # approximate loss may cause unpredictable behavior in the
            # context of sampling.
            bound = general.lossfun(
                candidate,
                tf.cast(0, float_dtype),
                tf.cast(1, float_dtype),
                approximate=False) + log_base_partition_function(alpha)
            # One uniform draw per element decides acceptance.
            coin = tf.cast(unit_uniform.sample(shape), float_dtype)
            keep = coin <= tf.math.exp(bound - candidate_nll)
            # Accepted proposals overwrite the stored sample, and their
            # `accepted` flag is latched to True.
            updated_samples = tf.where(keep, candidate, samples)
            updated_flags = keep | accepted
            return (updated_samples, updated_flags)

        # Initial state: the sample values are placeholders that get
        # overwritten; the accepted flags must start all False.
        initial_state = (tf.zeros(shape, float_dtype),
                         tf.zeros(shape, dtype=bool))

        # Rejection-sample until every element has been accepted.
        final_state = tf.while_loop(
            cond=keep_going,
            body=propose_and_filter,
            loop_vars=initial_state)

        # The distribution is a location-scale family, so sampling from
        # p(x | 0, \alpha, 1) and multiplying by `scale` gives samples
        # from p(x | 0, \alpha, scale).
        return tf.multiply(final_state[0], scale)
def lossfun(x, alpha, scale, approximate=False, epsilon=1e-6):
    r"""Evaluate the general robust loss.

    This implements the rho(x, \alpha, c) function described in "A
    General and Adaptive Robust Loss Function", Jonathan T. Barron,
    https://arxiv.org/abs/1701.03077.

    Args:
      x: The residual for which the loss is being computed. x can have
        any shape, and alpha and scale will be broadcasted to match x's
        shape if necessary. Must be a tensorflow tensor or numpy array of
        floats.
      alpha: The shape parameter of the loss (\alpha in the paper), where
        more negative values produce a loss with more robust behavior
        (outliers "cost" less), and more positive values produce a loss
        with less robust behavior (outliers are penalized more heavily).
        Alpha can be any value in [-infinity, infinity], but the gradient
        of the loss with respect to alpha is 0 at -infinity, infinity, 0,
        and 2. Must have the same dtype as `x`. Varying alpha allows for
        smooth interpolation between a number of discrete robust losses:
          alpha=-Infinity: Welsch/Leclerc Loss.
          alpha=-2: Geman-McClure loss.
          alpha=0: Cauchy/Lorentzian loss.
          alpha=1: Charbonnier/pseudo-Huber loss.
          alpha=2: L2 loss.
      scale: The scale parameter of the loss. Must be strictly positive
        (checked at run time) and have the same dtype as `x`. When
        |x| < scale, the loss is an L2-like quadratic bowl, and when
        |x| > scale the loss function takes on a different shape
        according to alpha.
      approximate: a bool, where if True, this function returns an
        approximate and faster form of the loss, as described in the
        appendix of the paper. This approximation holds well everywhere
        except as x and alpha approach zero.
      epsilon: A float that determines how inaccurate the "approximate"
        version of the loss will be. Larger values are less accurate but
        more numerically stable. Must be greater than single-precision
        machine epsilon. Only used when `approximate` is True.

    Returns:
      The losses for each element of x, in the same shape as x, as a
      TensorFlow graph node.
    """
    # `scale` and `alpha` must have the same type as `x`.
    tf.assert_type(scale, x.dtype)
    tf.assert_type(alpha, x.dtype)
    float_dtype = x.dtype

    # `scale` must be > 0.
    scale_is_positive = tf.Assert(
        tf.reduce_all(tf.greater(scale, 0.)), [scale])

    with tf.control_dependencies([scale_is_positive]):
        # Broadcast `alpha` and `scale` to have the same shape as `x`.
        alpha = tf.broadcast_to(alpha, tf.shape(x))
        scale = tf.broadcast_to(scale, tf.shape(x))

        if approximate:
            # `epsilon` must be greater than single-precision machine
            # epsilon.
            assert epsilon > np.finfo(np.float32).eps
            # Faster closed form; inaccurate when x and alpha are near
            # zero. `b` is |alpha - 2| and `d` is alpha, each pushed away
            # from zero by epsilon.
            b = tf.abs(alpha - tf.cast(2., float_dtype)) + epsilon
            d = tf.where(
                tf.greater_equal(alpha, 0.),
                alpha + epsilon,
                alpha - epsilon)
            return (b / d) * (
                tf.pow(tf.square(x / scale) / b + 1., 0.5 * d) - 1.)

        # Exact loss. (x / scale)^2 is shared by every case below.
        squared_scaled_x = tf.square(x / scale)

        # Closed forms for the special values of alpha.
        loss_two = 0.5 * squared_scaled_x  # alpha == 2
        loss_zero = util.log1p_safe(0.5 * squared_scaled_x)  # alpha == 0
        # alpha == -infinity
        loss_neginf = -tf.math.expm1(-0.5 * squared_scaled_x)
        # alpha == +infinity
        loss_posinf = util.expm1_safe(0.5 * squared_scaled_x)

        # General case.
        machine_epsilon = tf.cast(np.finfo(np.float32).eps, float_dtype)
        # Clamp |2 - alpha| to be >= machine epsilon so that it's safe to
        # divide by.
        beta_safe = tf.maximum(machine_epsilon, tf.abs(alpha - 2.))
        # Clamp |alpha| to be >= machine epsilon, keeping alpha's sign
        # (zero counts as positive), so that it's safe to divide by.
        alpha_sign = tf.where(
            tf.greater_equal(alpha, 0.),
            tf.ones_like(alpha),
            -tf.ones_like(alpha))
        alpha_safe = alpha_sign * tf.maximum(machine_epsilon, tf.abs(alpha))
        loss_otherwise = (beta_safe / alpha_safe) * (
            tf.pow(squared_scaled_x / beta_safe + 1., 0.5 * alpha) - 1.)

        # Start from the general case and override each special value of
        # alpha with its closed form. The four conditions are mutually
        # exclusive, so the order of the overrides does not matter.
        infinity = tf.cast(float('inf'), float_dtype)
        loss = tf.where(tf.equal(alpha, infinity), loss_posinf,
                        loss_otherwise)
        loss = tf.where(tf.equal(alpha, 2.), loss_two, loss)
        loss = tf.where(tf.equal(alpha, 0.), loss_zero, loss)
        loss = tf.where(tf.equal(alpha, -infinity), loss_neginf, loss)
        return loss