def euclidean_dist_transform(
    images: TensorLike,
    dtype: Type[tf.dtypes.DType] = tf.float32,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Applies euclidean distance transform(s) to the image(s).

    Args:
        images: A tensor of shape `(num_images, num_rows, num_columns, 1)`
            (NHWC), or `(num_rows, num_columns, 1)` (HWC) or
            `(num_rows, num_columns)` (HW).
        dtype: `tf.dtypes.DType` of the output tensor.
        name: The name of the op.

    Returns:
        Image(s) with the type `dtype` and same shape as `images`, with the
        transform applied. If a tensor of all ones is given as input, the
        output tensor will be filled with the max value of the `dtype`.

    Raises:
        TypeError: If `images` is not tf.uint8, or `dtype` is not floating
            point.
        ValueError: If `images` has more than one channel, or `images` is not
            of rank between 2 and 4.
    """
    with tf.name_scope(name or "euclidean_distance_transform"):
        image_or_images = tf.convert_to_tensor(images, name="images")

        if image_or_images.dtype.base_dtype != tf.uint8:
            raise TypeError(
                "Invalid dtype %s. Expected uint8." % image_or_images.dtype
            )

        images = img_utils.to_4D_image(image_or_images)
        original_ndims = img_utils.get_ndims(image_or_images)

        if images.get_shape()[3] != 1 and images.get_shape()[3] is not None:
            raise ValueError("`images` must have only one channel")

        if dtype not in [tf.float16, tf.float32, tf.float64]:
            raise TypeError("`dtype` must be float16, float32 or float64")

        images = tf.cast(images, dtype)
        output = _image_so.ops.addons_euclidean_distance_transform(images)

        return img_utils.from_4D_image(output, original_ndims)
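
# A minimal usage sketch (hypothetical, not part of the original module):
# each nonzero pixel of a binary uint8 image receives its euclidean distance
# to the nearest zero pixel. Assumes TensorFlow is installed and
# `euclidean_dist_transform` above is in scope.
def _euclidean_dist_transform_demo():
    import tensorflow as tf

    img = tf.constant(
        [
            [0, 0, 0, 0],
            [0, 1, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 0],
        ],
        dtype=tf.uint8,
    )  # rank-2 (HW) input is accepted and the rank is restored on output
    dist = euclidean_dist_transform(img)  # float32, same (4, 4) shape
    return dist
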
def sparsemax(logits: types.TensorLike, axis: int = -1) -> tf.Tensor:
    r"""Sparsemax activation function.

    For each batch $i$, and class $j$, compute sparsemax activation function:

    $$
    \mathrm{sparsemax}(x)[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0).
    $$

    See [From Softmax to Sparsemax: A Sparse Model of Attention and
    Multi-Label Classification](https://arxiv.org/abs/1602.02068).

    Usage:

    >>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
    >>> tfa.activations.sparsemax(x)
    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
    array([[0., 0., 1.],
           [0., 0., 1.]], dtype=float32)>

    Args:
        logits: A `Tensor`.
        axis: `int`, axis along which the sparsemax operation is applied.

    Returns:
        A `Tensor`, output of sparsemax transformation. Has the same type and
        shape as `logits`.

    Raises:
        ValueError: In case `dim(logits) == 1`.
    """
    logits = tf.convert_to_tensor(logits, name="logits")

    # We need its original shape for shape inference.
    shape = logits.get_shape()
    rank = shape.rank
    is_last_axis = (axis == -1) or (axis == rank - 1)

    if is_last_axis:
        output = _compute_2d_sparsemax(logits)
        output.set_shape(shape)
        return output

    # If dim is not the last dimension, we have to do a transpose so that we
    # can still perform sparsemax on its last dimension.

    # Swap logits' dimension of dim and its last dimension.
    rank_op = tf.rank(logits)
    axis_norm = axis % rank
    logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1))

    # Do the actual sparsemax on its last dimension.
    output = _compute_2d_sparsemax(logits)
    output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))

    # Make shape inference work since transpose may erase its static shape.
    output.set_shape(shape)
    return output
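
# Sketch of the non-default `axis` path above (a hypothetical example):
# applying sparsemax along axis 0 exercises the transpose branch; entries
# along that axis form a probability distribution, typically with exact zeros.
def _sparsemax_axis_demo():
    import tensorflow as tf

    x = tf.constant([[-1.0, 0.0], [1.0, 2.0], [3.0, -2.0]])
    p = sparsemax(x, axis=0)  # normalize down the columns
    return tf.reduce_sum(p, axis=0)  # ~[1.0, 1.0]
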
def hamming_loss_fn(
    y_true: TensorLike,
    y_pred: TensorLike,
    threshold: Union[FloatTensorLike, None],
    mode: str,
) -> tf.Tensor:
    """Computes hamming loss.

    Hamming loss is the fraction of wrong labels to the total number of
    labels.

    In multi-class classification, hamming loss is calculated as the hamming
    distance between `y_true` and `y_pred`. In multi-label classification,
    hamming loss penalizes only the individual labels.

    Args:
        y_true: actual target value.
        y_pred: predicted target value.
        threshold: Elements of `y_pred` greater than threshold are converted
            to be 1, and the rest 0. If threshold is None, the argmax is
            converted to 1, and the rest 0.
        mode: multi-class or multi-label.

    Returns:
        hamming loss: float.
    """
    if mode not in ["multiclass", "multilabel"]:
        raise TypeError("mode must be either multiclass or multilabel")

    if threshold is None:
        threshold = tf.reduce_max(y_pred, axis=-1, keepdims=True)
        # make sure [0, 0, 0] doesn't become [1, 1, 1]
        # Use abs(x) > eps, instead of x != 0 to check for zero
        y_pred = tf.logical_and(y_pred >= threshold, tf.abs(y_pred) > 1e-12)
    else:
        y_pred = y_pred > threshold

    y_true = tf.cast(y_true, tf.int32)
    y_pred = tf.cast(y_pred, tf.int32)

    if mode == "multiclass":
        nonzero = tf.cast(tf.math.count_nonzero(y_true * y_pred, axis=-1), tf.float32)
        return 1.0 - nonzero
    else:
        nonzero = tf.cast(tf.math.count_nonzero(y_true - y_pred, axis=-1), tf.float32)
        return nonzero / y_true.get_shape()[-1]
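
# A hypothetical worked example of both modes (assumes the function above is
# in scope). With an explicit threshold of 0.5, row [0.9, 0.6, 0.4, 0.1]
# binarizes to [1, 1, 0, 0]; against y_true [1, 0, 1, 0] two of the four
# label slots disagree, so the multilabel loss is 2 / 4 = 0.5.
def _hamming_loss_demo():
    import tensorflow as tf

    y_true = tf.constant([[1, 0, 1, 0]], dtype=tf.int32)
    y_pred = tf.constant([[0.9, 0.6, 0.4, 0.1]], dtype=tf.float32)
    multilabel = hamming_loss_fn(y_true, y_pred, threshold=0.5, mode="multilabel")

    # multiclass with threshold=None: the per-row argmax is compared against
    # the one-hot target; a correct argmax gives a per-sample loss of 0.
    onehot = tf.constant([[0, 1, 0]], dtype=tf.int32)
    scores = tf.constant([[0.1, 0.7, 0.2]], dtype=tf.float32)
    multiclass = hamming_loss_fn(onehot, scores, threshold=None, mode="multiclass")
    return multilabel, multiclass
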
def sparsemax(logits: types.TensorLike, axis: int = -1) -> tf.Tensor:
    """Sparsemax activation function [1].

    For each batch `i` and class `j` we have

    $$sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)$$

    [1]: https://arxiv.org/abs/1602.02068

    Args:
        logits: Input tensor.
        axis: Integer, axis along which the sparsemax operation is applied.

    Returns:
        Tensor, output of sparsemax transformation. Has the same type and
        shape as `logits`.

    Raises:
        ValueError: In case `dim(logits) == 1`.
    """
    logits = tf.convert_to_tensor(logits, name="logits")

    # We need its original shape for shape inference.
    shape = logits.get_shape()
    rank = shape.rank
    is_last_axis = (axis == -1) or (axis == rank - 1)

    if is_last_axis:
        output = _compute_2d_sparsemax(logits)
        output.set_shape(shape)
        return output

    # If dim is not the last dimension, we have to do a transpose so that we
    # can still perform sparsemax on its last dimension.

    # Swap logits' dimension of dim and its last dimension.
    rank_op = tf.rank(logits)
    axis_norm = axis % rank
    logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1))

    # Do the actual sparsemax on its last dimension.
    output = _compute_2d_sparsemax(logits)
    output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))

    # Make shape inference work since transpose may erase its static shape.
    output.set_shape(shape)
    return output
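
# Hypothetical sketch of the sparsity property stated above: unlike softmax,
# sparsemax can return exact zeros, because max(logits - tau, 0) clamps
# sub-threshold logits.
def _sparsemax_vs_softmax_demo():
    import tensorflow as tf

    z = tf.constant([[0.1, 1.0, 2.5]])
    soft = tf.nn.softmax(z)  # all entries strictly positive
    sparse = sparsemax(z)    # here the support collapses to [0., 0., 1.]
    return soft, sparse
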
def transform(
    images: TensorLike,
    transforms: TensorLike,
    interpolation: str = "NEAREST",
    output_shape: Optional[list] = None,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Applies the given transform(s) to the image(s).

    Args:
        images: A tensor of shape (num_images, num_rows, num_columns,
            num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC),
            or (num_rows, num_columns) (HW).
        transforms: Projective transform matrix/matrices. A vector of length 8
            or tensor of size N x 8. If one row of transforms is
            [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
            `(x, y)` to a transformed *input* point
            `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
            where `k = c0 x + c1 y + 1`. The transforms are *inverted*
            compared to the transform mapping input points to output points.
            Note that gradients are not backpropagated into transformation
            parameters.
        interpolation: Interpolation mode. Supported values: "NEAREST",
            "BILINEAR".
        output_shape: Output dimension after the transform, [height, width].
            If None, output is the same size as input image.
        name: The name of the op.

    Returns:
        Image(s) with the same type and shape as `images`, with the given
        transform(s) applied. Transformed coordinates outside of the input
        image will be filled with zeros.

    Raises:
        TypeError: If `images` is an invalid type.
        ValueError: If output shape is not a 1-D int32 Tensor.
    """
    with tf.name_scope(name or "transform"):
        image_or_images = tf.convert_to_tensor(images, name="images")
        transform_or_transforms = tf.convert_to_tensor(
            transforms, name="transforms", dtype=tf.dtypes.float32
        )
        if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
            raise TypeError("Invalid dtype %s." % image_or_images.dtype)
        images = img_utils.to_4D_image(image_or_images)
        original_ndims = img_utils.get_ndims(image_or_images)

        if output_shape is None:
            output_shape = tf.shape(images)[1:3]

        output_shape = tf.convert_to_tensor(
            output_shape, tf.dtypes.int32, name="output_shape"
        )

        if not output_shape.get_shape().is_compatible_with([2]):
            raise ValueError(
                "output_shape must be a 1-D Tensor of 2 elements: "
                "new_height, new_width"
            )

        if len(transform_or_transforms.get_shape()) == 1:
            transforms = transform_or_transforms[None]
        elif transform_or_transforms.get_shape().ndims is None:
            raise ValueError("transforms rank must be statically known")
        elif len(transform_or_transforms.get_shape()) == 2:
            transforms = transform_or_transforms
        else:
            transforms = transform_or_transforms
            raise ValueError(
                "transforms should have rank 1 or 2, but got rank %d"
                % len(transforms.get_shape())
            )

        output = tf.raw_ops.ImageProjectiveTransformV2(
            images=images,
            transforms=transforms,
            output_shape=output_shape,
            interpolation=interpolation.upper(),
        )
        return img_utils.from_4D_image(output, original_ndims)
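
# A hypothetical usage sketch of the inverse-mapping convention documented
# above: to shift an image one pixel to the right, each output point (x, y)
# samples input point (x - 1, y), i.e. a0=1, a1=0, a2=-1, b0=0, b1=1, b2=0,
# c0=c1=0.
def _transform_demo():
    import tensorflow as tf

    img = tf.random.uniform([1, 8, 8, 1])  # NHWC
    shifted = transform(img, [1.0, 0.0, -1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
    return shifted
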
def transform(
    images: TensorLike,
    transforms: TensorLike,
    interpolation: str = "nearest",
    fill_mode: str = "constant",
    output_shape: Optional[list] = None,
    name: Optional[str] = None,
    fill_value: TensorLike = 0.0,
) -> tf.Tensor:
    """Applies the given transform(s) to the image(s).

    Args:
        images: A tensor of shape (num_images, num_rows, num_columns,
            num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC),
            or (num_rows, num_columns) (HW).
        transforms: Projective transform matrix/matrices. A vector of length 8
            or tensor of size N x 8. If one row of transforms is
            [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
            `(x, y)` to a transformed *input* point
            `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
            where `k = c0 x + c1 y + 1`. The transforms are *inverted*
            compared to the transform mapping input points to output points.
            Note that gradients are not backpropagated into transformation
            parameters.
        interpolation: Interpolation mode. Supported values: "nearest",
            "bilinear".
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode (one of
            `{'constant', 'reflect', 'wrap', 'nearest'}`).
            - *reflect*: `(d c b a | a b c d | d c b a)` The input is extended
              by reflecting about the edge of the last pixel.
            - *constant*: `(k k k k | a b c d | k k k k)` The input is
              extended by filling all values beyond the edge with the same
              constant value k = 0.
            - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by
              wrapping around to the opposite edge.
            - *nearest*: `(a a a a | a b c d | d d d d)` The input is extended
              by the nearest pixel.
        fill_value: a float representing the value to be filled outside the
            boundaries when `fill_mode` is "constant".
        output_shape: Output dimension after the transform, [height, width].
            If None, output is the same size as input image.
        name: The name of the op.

    Returns:
        Image(s) with the same type and shape as `images`, with the given
        transform(s) applied. Transformed coordinates outside of the input
        image will be filled with zeros.

    Raises:
        TypeError: If `images` is an invalid type.
        ValueError: If output shape is not a 1-D int32 Tensor.
    """
    with tf.name_scope(name or "transform"):
        image_or_images = tf.convert_to_tensor(images, name="images")
        transform_or_transforms = tf.convert_to_tensor(
            transforms, name="transforms", dtype=tf.dtypes.float32
        )
        if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
            raise TypeError("Invalid dtype %s." % image_or_images.dtype)
        images = img_utils.to_4D_image(image_or_images)
        original_ndims = img_utils.get_ndims(image_or_images)

        if output_shape is None:
            output_shape = tf.shape(images)[1:3]

        output_shape = tf.convert_to_tensor(
            output_shape, tf.dtypes.int32, name="output_shape"
        )

        if not output_shape.get_shape().is_compatible_with([2]):
            raise ValueError(
                "output_shape must be a 1-D Tensor of 2 elements: "
                "new_height, new_width"
            )

        if len(transform_or_transforms.get_shape()) == 1:
            transforms = transform_or_transforms[None]
        elif transform_or_transforms.get_shape().ndims is None:
            raise ValueError("transforms rank must be statically known")
        elif len(transform_or_transforms.get_shape()) == 2:
            transforms = transform_or_transforms
        else:
            transforms = transform_or_transforms
            raise ValueError(
                "transforms should have rank 1 or 2, but got rank %d"
                % len(transforms.get_shape())
            )

        if LooseVersion(tf.__version__) >= LooseVersion("2.4.0"):
            fill_value = tf.convert_to_tensor(
                fill_value, dtype=tf.float32, name="fill_value"
            )
            output = tf.raw_ops.ImageProjectiveTransformV3(
                images=images,
                transforms=transforms,
                output_shape=output_shape,
                interpolation=interpolation.upper(),
                fill_mode=fill_mode.upper(),
                fill_value=fill_value,
            )
        else:
            fill_mode = fill_mode.upper()
            # TODO(WindQAQ): Get rid of the check once we drop TensorFlow < 2.4 support.
            if fill_mode == "CONSTANT":
                warnings.warn(
                    "fill_value is not supported and is always 0 for TensorFlow < 2.4.0."
                )
            if fill_mode == "NEAREST":
                raise ValueError(
                    "NEAREST fill_mode is not supported for TensorFlow < 2.4.0."
                )
            output = tf.raw_ops.ImageProjectiveTransformV2(
                images=images,
                transforms=transforms,
                output_shape=output_shape,
                interpolation=interpolation.upper(),
                fill_mode=fill_mode,
            )
        return img_utils.from_4D_image(output, original_ndims)
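
# The same one-pixel translation against the newer signature above, now
# choosing how out-of-boundary points are filled (a hypothetical sketch; a
# nonzero fill_value takes effect only on TensorFlow >= 2.4.0, per the
# version check in the function).
def _transform_fill_demo():
    import tensorflow as tf

    img = tf.random.uniform([1, 8, 8, 1])
    return transform(
        img,
        [1.0, 0.0, -1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        interpolation="bilinear",
        fill_mode="constant",
        fill_value=0.5,  # ignored (always 0) on TensorFlow < 2.4.0
    )
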
def sparse_image_warp(
    image: TensorLike,
    source_control_point_locations: TensorLike,
    dest_control_point_locations: TensorLike,
    interpolation_order: int = 2,
    regularization_weight: FloatTensorLike = 0.0,
    num_boundary_points: int = 0,
    name: str = "sparse_image_warp",
) -> tf.Tensor:
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tfa.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field. Then,
    we warp the image using this dense flow field
    (`tfa.image.dense_image_warp`).

    Let t index our control points. For `regularization_weight = 0`, we have:

    warped_image[b, dest_control_point_locations[b, t, 0],
                 dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
          source_control_point_locations[b, t, 1], :].

    For `regularization_weight > 0`, this condition is met approximately,
    since regularized interpolation trades off smoothness of the interpolant
    vs. reconstruction of the interpolant at the control points.
    See `tfa.image.interpolate_spline` for further documentation of the
    `interpolation_order` and `regularization_weight` arguments.

    Args:
        image: `[batch, height, width, channels]` float `Tensor`
        source_control_point_locations: `[batch, num_control_points, 2]` float
            `Tensor`
        dest_control_point_locations: `[batch, num_control_points, 2]` float
            `Tensor`
        interpolation_order: polynomial order used by the spline interpolation
        regularization_weight: weight on smoothness regularizer in
            interpolation
        num_boundary_points: How many zero-flow boundary points to include at
            each image edge. Usage:
            `num_boundary_points=0`: don't add zero-flow points
            `num_boundary_points=1`: 4 corners of the image
            `num_boundary_points=2`: 4 corners and one in the middle of each
                edge (8 points total)
            `num_boundary_points=n`: 4 corners and n-1 along each edge
        name: A name for the operation (optional).

        Note that image and offsets can be of type tf.half, tf.float32, or
        tf.float64, and do not necessarily have to be the same type.

    Returns:
        warped_image: `[batch, height, width, channels]` float `Tensor` with
            same type as input image.
        flow_field: `[batch, height, width, 2]` float `Tensor` containing the
            dense flow field produced by the interpolation.
    """
    image = tf.convert_to_tensor(image)
    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations
    )
    dest_control_point_locations = tf.convert_to_tensor(dest_control_point_locations)
    control_point_flows = dest_control_point_locations - source_control_point_locations

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        batch_size, image_height, image_width, _ = image.get_shape().as_list()

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = np.reshape(
            grid_locations, [image_height * image_width, 2]
        )

        flattened_grid_locations = tf.constant(
            _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype
        )

        if clamp_boundaries:
            (
                dest_control_point_locations,
                control_point_flows,
            ) = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations,
                control_point_flows,
                image_height,
                image_width,
                boundary_points_per_edge,
            )

        flattened_flows = interpolate_spline(
            dest_control_point_locations,
            control_point_flows,
            flattened_grid_locations,
            interpolation_order,
            regularization_weight,
        )

        dense_flows = tf.reshape(
            flattened_flows, [batch_size, image_height, image_width, 2]
        )

        warped_image = dense_image_warp(image, dense_flows)

        return warped_image, dense_flows
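
# A hypothetical sketch: warp a single image by dragging one control point
# eight pixels down, pinning the four corners with zero-flow boundary points
# so the spline stays anchored (control points are (row, col) locations).
def _sparse_image_warp_demo():
    import tensorflow as tf

    image = tf.random.uniform([1, 64, 64, 3])
    src = tf.constant([[[32.0, 32.0]]])  # [batch, num_control_points, 2]
    dst = tf.constant([[[40.0, 32.0]]])
    warped, flow = sparse_image_warp(image, src, dst, num_boundary_points=1)
    return warped, flow
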
def hamming_loss_fn(
    y_true: TensorLike,
    y_pred: TensorLike,
    threshold: Union[FloatTensorLike, None],
    mode: str,
) -> tf.Tensor:
    """Computes hamming loss.

    Hamming loss is the fraction of wrong labels to the total number of
    labels.

    In multi-class classification, hamming loss is calculated as the hamming
    distance between `actual` and `predictions`. In multi-label
    classification, hamming loss penalizes only the individual labels.

    Args:
        y_true: actual target value.
        y_pred: predicted target value.
        threshold: Elements of `y_pred` greater than threshold are converted
            to be 1, and the rest 0. If threshold is None, the argmax is
            converted to 1, and the rest 0.
        mode: multi-class or multi-label.

    Returns:
        hamming loss: float.

    Usage:

    >>> # multi-class hamming loss
    >>> hl = HammingLoss(mode='multiclass', threshold=0.6)
    >>> actuals = tf.constant([[1, 0, 0, 0], [0, 0, 1, 0],
    ...                        [0, 0, 0, 1], [0, 1, 0, 0]], dtype=tf.float32)
    >>> predictions = tf.constant([[0.8, 0.1, 0.1, 0],
    ...                            [0.2, 0, 0.8, 0],
    ...                            [0.05, 0.05, 0.1, 0.8],
    ...                            [1, 0, 0, 0]], dtype=tf.float32)
    >>> hl.update_state(actuals, predictions)
    <tf.Variable 'UnreadVariable' shape=() dtype=float32, numpy=4.0>
    >>> hl.result().numpy()
    0.25
    >>> # multi-label hamming loss
    >>> hl = HammingLoss(mode='multilabel', threshold=0.8)
    >>> actuals = tf.constant([[1, 0, 1, 0], [0, 1, 0, 1],
    ...                        [0, 0, 0, 1]], dtype=tf.int32)
    >>> predictions = tf.constant([[0.82, 0.5, 0.90, 0],
    ...                            [0, 1, 0.4, 0.98],
    ...                            [0.89, 0.79, 0, 0.3]], dtype=tf.float32)
    >>> hl.update_state(actuals, predictions)
    <tf.Variable 'UnreadVariable' shape=() dtype=float32, numpy=3.0>
    >>> hl.result().numpy()
    0.16666667
    """
    if mode not in ["multiclass", "multilabel"]:
        raise TypeError("mode must be either multiclass or multilabel")

    if threshold is None:
        threshold = tf.reduce_max(y_pred, axis=-1, keepdims=True)
        # make sure [0, 0, 0] doesn't become [1, 1, 1]
        # Use abs(x) > eps, instead of x != 0 to check for zero
        y_pred = tf.logical_and(y_pred >= threshold, tf.abs(y_pred) > 1e-12)
    else:
        y_pred = y_pred > threshold

    y_true = tf.cast(y_true, tf.int32)
    y_pred = tf.cast(y_pred, tf.int32)

    if mode == "multiclass":
        nonzero = tf.cast(tf.math.count_nonzero(y_true * y_pred, axis=-1), tf.float32)
        return 1.0 - nonzero
    else:
        nonzero = tf.cast(tf.math.count_nonzero(y_true - y_pred, axis=-1), tf.float32)
        return nonzero / y_true.get_shape()[-1]
def sequence_loss(
    logits: TensorLike,
    targets: TensorLike,
    weights: TensorLike,
    average_across_timesteps: bool = True,
    average_across_batch: bool = True,
    sum_over_timesteps: bool = False,
    sum_over_batch: bool = False,
    softmax_loss_function: Optional[Callable] = None,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Weighted cross-entropy loss for a sequence of logits.

    Depending on the values of `average_across_timesteps` /
    `sum_over_timesteps` and `average_across_batch` / `sum_over_batch`, the
    returned Tensor will have rank 0, 1, or 2 as these arguments reduce the
    cross-entropy at each target, which has shape
    `[batch_size, sequence_length]`, over their respective dimensions. For
    example, if `average_across_timesteps` is `True` and
    `average_across_batch` is `False`, then the returned Tensor will have
    shape `[batch_size]`.

    Note that `average_across_timesteps` and `sum_over_timesteps` cannot be
    True at the same time. Same for `average_across_batch` and
    `sum_over_batch`.

    The recommended loss reduction in TF 2.0 has been changed to sum_over,
    instead of weighted average. Users are recommended to use
    `sum_over_timesteps` and `sum_over_batch` for reduction.

    Args:
        logits: A Tensor of shape
            `[batch_size, sequence_length, num_decoder_symbols]` and dtype
            float. The logits correspond to the prediction across all classes
            at each timestep.
        targets: A Tensor of shape `[batch_size, sequence_length]` and dtype
            int. The target represents the true class at each timestep.
        weights: A Tensor of shape `[batch_size, sequence_length]` and dtype
            float. `weights` constitutes the weighting of each prediction in
            the sequence. When using `weights` as masking, set all valid
            timesteps to 1 and all padded timesteps to 0, e.g. a mask returned
            by `tf.sequence_mask`.
        average_across_timesteps: If set, sum the cost across the sequence
            dimension and divide the cost by the total label weight across
            timesteps.
        average_across_batch: If set, sum the cost across the batch dimension
            and divide the returned cost by the batch size.
        sum_over_timesteps: If set, sum the cost across the sequence dimension
            and divide by the size of the sequence. Note that any element with
            0 weights will be excluded from size calculation.
        sum_over_batch: If set, sum the cost across the batch dimension and
            divide the total cost by the batch size. Note that any element
            with 0 weights will be excluded from size calculation.
        softmax_loss_function: Function (labels, logits) -> loss-batch to be
            used instead of the standard softmax (the default if this is
            None). **Note that to avoid confusion, it is required for the
            function to accept named arguments.**
        name: Optional name for this operation, defaults to "sequence_loss".

    Returns:
        A float Tensor of rank 0, 1, or 2 depending on the
        `average_across_timesteps` and `average_across_batch` arguments. By
        default, it has rank 0 (scalar) and is the weighted average
        cross-entropy (log-perplexity) per symbol.

    Raises:
        ValueError: logits does not have 3 dimensions or targets does not
            have 2 or 3 dimensions or weights does not have 2 dimensions.
    """
    if len(logits.get_shape()) != 3:
        raise ValueError(
            "Logits must be a [batch_size x sequence_length x logits] tensor"
        )

    targets_rank = len(targets.get_shape())
    if targets_rank != 2 and targets_rank != 3:
        raise ValueError(
            "Targets must be either a [batch_size x sequence_length] tensor "
            "where each element contains the labels' index "
            "or a [batch_size x sequence_length x num_classes] tensor "
            "where the third axis is a one-hot representation of the labels"
        )

    if len(weights.get_shape()) != 2:
        raise ValueError("Weights must be a [batch_size x sequence_length] tensor")

    if average_across_timesteps and sum_over_timesteps:
        raise ValueError(
            "average_across_timesteps and sum_over_timesteps cannot "
            "be set to True at same time."
        )
    if average_across_batch and sum_over_batch:
        raise ValueError(
            "average_across_batch and sum_over_batch cannot be set "
            "to True at same time."
        )
    if average_across_batch and sum_over_timesteps:
        raise ValueError(
            "average_across_batch and sum_over_timesteps cannot be set "
            "to True at same time because of ambiguous order."
        )
    if sum_over_batch and average_across_timesteps:
        raise ValueError(
            "sum_over_batch and average_across_timesteps cannot be set "
            "to True at same time because of ambiguous order."
        )

    with tf.name_scope(name or "sequence_loss"):
        num_classes = tf.shape(input=logits)[2]
        logits_flat = tf.reshape(logits, [-1, num_classes])
        if softmax_loss_function is None:
            if targets_rank == 2:
                targets = tf.reshape(targets, [-1])
                crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=targets, logits=logits_flat
                )
            else:
                targets = tf.reshape(targets, [-1, num_classes])
                crossent = tf.nn.softmax_cross_entropy_with_logits(
                    labels=targets, logits=logits_flat
                )
        else:
            targets = tf.reshape(targets, [-1])
            crossent = softmax_loss_function(labels=targets, logits=logits_flat)
        crossent *= tf.reshape(weights, [-1])
        if average_across_timesteps and average_across_batch:
            crossent = tf.reduce_sum(input_tensor=crossent)
            total_size = tf.reduce_sum(input_tensor=weights)
            crossent = tf.math.divide_no_nan(crossent, total_size)
        elif sum_over_timesteps and sum_over_batch:
            crossent = tf.reduce_sum(input_tensor=crossent)
            total_count = tf.cast(tf.math.count_nonzero(weights), crossent.dtype)
            crossent = tf.math.divide_no_nan(crossent, total_count)
        else:
            crossent = tf.reshape(crossent, tf.shape(input=logits)[0:2])
            if average_across_timesteps or average_across_batch:
                reduce_axis = [0] if average_across_batch else [1]
                crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis)
                total_size = tf.reduce_sum(input_tensor=weights, axis=reduce_axis)
                crossent = tf.math.divide_no_nan(crossent, total_size)
            elif sum_over_timesteps or sum_over_batch:
                reduce_axis = [0] if sum_over_batch else [1]
                crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis)
                total_count = tf.cast(
                    tf.math.count_nonzero(weights, axis=reduce_axis),
                    dtype=crossent.dtype,
                )
                crossent = tf.math.divide_no_nan(crossent, total_count)
        return crossent