Example #1
def kgcnn_ops_tensor_scatter_nd_by_name(segment_name, tensor, indices, updates,
                                        name=None):
    """Scatter operation chosen by name that can replace segment-operations.

    Args:
        segment_name (str): Operation used to combine scattered updates. One of
            'mean', 'sum', 'max' or 'min', each also accepted with a
            'segment_' or 'reduce_' prefix (e.g. 'segment_sum', 'reduce_min').
        tensor (tf.tensor): Tensor to scatter updates into.
        indices (tf.tensor): Indices for the updates.
        updates (tf.tensor): Updates of new entries for tensor.
        name (str): Name of the resulting tensor. Defaults to None.

    Returns:
        tf.tensor: Updates scattered into tensor with the chosen update rule.

    Raises:
        TypeError: If `segment_name` is not one of the supported operations.
    """
    if segment_name in ["segment_mean", "mean", "reduce_mean"]:
        # tensor_scatter_nd_mean is a helper defined elsewhere in this module.
        return tensor_scatter_nd_mean(tensor, indices, updates, name=name)
    if segment_name in ["segment_sum", "sum", "reduce_sum"]:
        return tf.tensor_scatter_nd_add(tensor, indices, updates, name=name)
    if segment_name in ["segment_max", "max", "reduce_max"]:
        return tf.tensor_scatter_nd_max(tensor, indices, updates, name=name)
    if segment_name in ["segment_min", "min", "reduce_min"]:
        return tf.tensor_scatter_nd_min(tensor, indices, updates, name=name)
    raise TypeError(
        "Unknown pooling '%s', choose: 'mean', 'sum', 'max', 'min'." % segment_name)
Example #2
def build_grid(indexes, truths, preds, ind_mask, update=False, grid=None):
    """Broadcast a list of ground-truth values into the output grid shape.

    Scatters `truths` into a grid shaped like `preds` at the positions given
    by `indexes` (with a batch index prepended). This is used for the ground
    truth map construction in the scaled loss and the classification map in
    the darknet loss.

    Args:
        indexes: A `Tensor` for the indexes. A batch index is prepended along
            the last axis and the result is reshaped to rows of 4, so the last
            axis presumably has size 3 — confirm against callers.
        truths: A `Tensor` for the ground truth. It is flattened to
            [-1, last dim of `preds`].
        preds: A `Tensor` for the predictions. Only its shape (and, when `grid`
            is None, its dtype via `tf.zeros_like`) is used.
        ind_mask: A `Tensor` for the index masks. It is cast to the dtype of
            `indexes`.
        update: A `bool`. If True, scattered values overwrite the grid
            (`tf.tensor_scatter_nd_update`). Otherwise the element-wise max of
            the existing grid and the scattered values is kept
            (`tf.tensor_scatter_nd_max`).
        grid: An optional `Tensor` to scatter into. Defaults to zeros shaped
            like `preds`.

    Returns:
        grid: A `Tensor` representing the augmented grid.
    """
    num_flatten = tf.shape(preds)[-1]

    # NOTE(review): masked-out rows end up with negative indices below. TF
    # scatter ops raise on out-of-bounds indices on CPU, which is presumably
    # why the original author asked how to verify we are not on the CPU.
    ind_mask = tf.cast(ind_mask, indexes.dtype)

    # Build zero-based batch indexes from a cumulative sum of a ones tensor:
    # cumsum(ones, axis=0) - 1 yields 0, 1, 2, ... along axis 0.
    bhep = tf.reduce_max(tf.ones_like(indexes), axis=-1, keepdims=True)
    bhep = tf.math.cumsum(bhep, axis=0) - 1

    # Prepend the batch index to each index. bhep shares the dtype of indexes,
    # so ind_mask (already cast above) needs no second cast.
    indexes = tf.concat([bhep, indexes], axis=-1)
    indexes = apply_mask(ind_mask, indexes)
    # (ind_mask - 1) is 0 for kept entries and -1 for masked-out ones.
    # Presumably apply_mask zeroed the masked-out rows, which therefore become -1.
    indexes = indexes + (ind_mask - 1)

    # Flatten all axes but the last so each row is one 4-element index.
    indexes = tf.reshape(indexes, [-1, 4])

    # Flatten the ground truth the same way, keeping the feature axis.
    truths = tf.reshape(truths, [-1, num_flatten])

    # Build a zero grid in the same shape as the predictions.
    if grid is None:
        grid = tf.zeros_like(preds)

    # Remove invalid values (NaN and inf) from the truths that may have come up
    # from computation.
    truths = math_ops.rm_nan_inf(truths)

    # Scatter the truths into the grid: overwrite, or keep the element-wise max.
    if update:
        grid = tf.tensor_scatter_nd_update(grid, indexes, truths)
    else:
        grid = tf.tensor_scatter_nd_max(grid, indexes, truths)

    # NOTE(review): an earlier comment here mentioned stopping gradients "to
    # avoid TPU errors", but no tf.stop_gradient is applied. Confirm whether
    # callers are expected to handle this.
    return grid
Example #3
    def _loss_semantic_segmentation(self, pred_seg, mask_gt, classes):
        """Compute the semantic segmentation loss, averaged over pixels and batch.

        For each image:
        - The ground-truth instance masks are resized to the prediction
          resolution and rounded to integers.
        - The masks are max-scattered into a per-class target map whose
          channel 0 is background.
        - The sigmoid cross-entropy between the target (background channel
          excluded) and the predicted logits is summed.

        The summed loss is divided by the number of pixels and the batch size.

        Args:
            pred_seg: Predicted segmentation logits, indexed as
                [batch, height, width, num_classes] by the tf.shape calls
                below. num_classes here excludes the background class.
            mask_gt: Per-image ground-truth masks. `mask_gt[i]` gets a trailing
                channel axis and is resized, so it is presumably
                [num_objects, H, W] — confirm.
            classes: Per-image class ids. `classes[i]` indexes axis 0 of a
                (num_classes + 1)-channel target whose channel 0 is dropped as
                background, so ids are presumably 1-based — confirm.

        Returns:
            A scalar float32 loss tensor.
        """
        batch_size = tf.shape(pred_seg)[0]
        mask_h = tf.shape(pred_seg)[1]
        mask_w = tf.shape(pred_seg)[2]
        num_classes = tf.shape(pred_seg)[3]
        loss_s = 0.0

        for i in range(batch_size):
            cur_segment = pred_seg[i]
            cur_class_gt = classes[i]

            # Resize masks to the prediction resolution, then round:
            # cast(x + 0.5) rounds non-negative values to the nearest integer.
            masks = tf.expand_dims(mask_gt[i], axis=-1)
            masks = tf.image.resize(masks, [mask_h, mask_w],
                                    method=tf.image.ResizeMethod.BILINEAR)
            masks = tf.cast(masks + 0.5, tf.int64)
            # Squeeze only the channel axis added above. A bare tf.squeeze
            # would also drop the object axis when an image has exactly one
            # object, breaking the scatter below.
            masks = tf.squeeze(tf.cast(masks, tf.float32), axis=-1)

            # Channel-first target so class ids can index axis 0; it has
            # num_classes + 1 channels because channel 0 is background.
            segment_gt = tf.zeros((num_classes + 1, mask_h, mask_w))

            obj_cls = tf.expand_dims(cur_class_gt, axis=-1)
            segment_gt = tf.tensor_scatter_nd_max(segment_gt,
                                                  indices=obj_cls,
                                                  updates=masks)
            segment_gt = tf.transpose(segment_gt, perm=(1, 2, 0))

            # Exclude the background channel from segment_gt.
            # NOTE(review): confirm sigmoid cross-entropy with
            # (labels=segment_gt, logits=cur_segment) is the intended loss.
            loss_s += tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=segment_gt[:, :, 1:], logits=cur_segment))

        loss_s /= (tf.cast(mask_h, tf.float32) * tf.cast(mask_w, tf.float32))
        return loss_s / tf.cast(batch_size, tf.float32)
    def _encode_centers_and_offets(self, instance_mask):
        """Generates center heatmaps and offets from instance id mask.

      instance_mask: `tf.Tensor` of shape [height, width] representing
        groundtruth instance id mask.
      instance_centers_heatmap: `tf.Tensor` of shape [height, width, 1]
      instance_centers_offset: `tf.Tensor` of shape [height, width, 2]
        shape = tf.shape(instance_mask)
        height, width = shape[0], shape[1]

        padding_start = int(3 * self._sigma + 1)
        padding_end = int(3 * self._sigma + 2)

        # padding should be equal to self._gaussian_size which is calculated
        # as size = int(6 * sigma + 3)
        padding = padding_start + padding_end

        instance_centers_heatmap = tf.zeros(
            shape=[height + padding, width + padding], dtype=tf.float32)
        centers_offset_y = tf.zeros(shape=[height, width], dtype=tf.float32)
        centers_offset_x = tf.zeros(shape=[height, width], dtype=tf.float32)
        semantic_weights = tf.ones(shape=[height, width], dtype=tf.float32)

        unique_instance_ids, _ = tf.unique(tf.reshape(instance_mask, [-1]))

        # The following method for encoding center heatmaps and offets is inspired
        # by the reference implementation available at
        # https://github.com/google-research/deeplab2/blob/main/data/sample_generator.py  # pylint: disable=line-too-long
        for instance_id in unique_instance_ids:
            if instance_id == self._ignore_label:

            mask = tf.equal(instance_mask, instance_id)
            mask_area = tf.reduce_sum(tf.cast(mask, dtype=tf.float32))
            mask_indices = tf.cast(tf.where(mask), dtype=tf.float32)
            mask_center = tf.reduce_mean(mask_indices, axis=0)
            mask_center_y = tf.cast(tf.round(mask_center[0]), dtype=tf.int32)
            mask_center_x = tf.cast(tf.round(mask_center[1]), dtype=tf.int32)

            if mask_area < self._small_instance_area_threshold:
                semantic_weights = tf.where(mask, self._small_instance_weight,

            gaussian_size = self._gaussian_size
            indices_y = tf.range(mask_center_y, mask_center_y + gaussian_size)
            indices_x = tf.range(mask_center_x, mask_center_x + gaussian_size)

            indices = tf.stack(tf.meshgrid(indices_y, indices_x))
            indices = tf.reshape(indices,
                                 shape=[2, gaussian_size * gaussian_size])
            indices = tf.transpose(indices)

            instance_centers_heatmap = tf.tensor_scatter_nd_max(

            centers_offset_y = tf.tensor_scatter_nd_update(
                indices=tf.cast(mask_indices, dtype=tf.int32),
                updates=tf.cast(mask_center_y, dtype=tf.float32) -
                mask_indices[:, 0])

            centers_offset_x = tf.tensor_scatter_nd_update(
                indices=tf.cast(mask_indices, dtype=tf.int32),
                updates=tf.cast(mask_center_x, dtype=tf.float32) -
                mask_indices[:, 1])

        instance_centers_heatmap = instance_centers_heatmap[
            padding_start:padding_start + height,
            padding_start:padding_start + width]
        instance_centers_heatmap = tf.expand_dims(instance_centers_heatmap,

        instance_centers_offset = tf.stack(
            [centers_offset_y, centers_offset_x], axis=-1)

        return (instance_centers_heatmap, instance_centers_offset,