  def testError(self,
                descr,
                mode,
                data,
                repeats,
                axis,
                exception=ValueError,
                error=None):
    # Make sure that this is also an error case for numpy.
    with self.assertRaises(exception):
      np.repeat(data, repeats, axis)

    if mode == 'constant':
      data = constant_op.constant(data)
      repeats = constant_op.constant(repeats)
    elif mode == 'dynamic':
      data = constant_op.constant(data)
      repeats = constant_op.constant(repeats)
      data = array_ops.placeholder_with_default(data, data.shape)
      repeats = array_ops.placeholder_with_default(repeats, repeats.shape)
    elif mode == 'unknown_shape':
      data = array_ops.placeholder_with_default(data, None)
      repeats = array_ops.placeholder_with_default(repeats, None)

    with self.assertRaisesRegex(exception, error):
      ragged_util.repeat(data, repeats, axis)
Example #3
def combine_graphs(batch):
    """Combine multiple graphs into a single network"""

    # Compute the mappings from bond index to graph index
    batch_size = tf.size(batch['n_atom'], name='batch_size')
    mol_id = tf.range(batch_size, name='mol_inds')
    batch['node_graph_indices'] = repeat(mol_id, batch['n_atom'], axis=0)
    batch['bond_graph_indices'] = repeat(mol_id, batch['n_bond'], axis=0)

    # Reshape the bond, connectivity, and node lists
    for c in ['atom', 'bond', 'connectivity']:
        batch[c] = batch[c].flat_values

    # Reshape the connectivity matrix to (None, 2)
    batch['connectivity'] = tf.reshape(batch['connectivity'], (-1, 2))

    # Explicitly set the shapes of the atom and bond matrices.
    #  This is only an issue for TF 1.14, which seems unable to infer them.
    for c in ['atom', 'bond']:
        batch[c].set_shape((None,))

    # Compute offsets for the connectivity matrix
    offset_values = tf.cumsum(batch['n_atom'], exclusive=True)
    offsets = repeat(offset_values, batch['n_bond'], name='offsets', axis=0)
    batch['connectivity'] += tf.expand_dims(offsets, 1)

    return batch
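# A worked sketch of the indexing above, with hypothetical toy values (not
# from the source): two molecules with 2 and 3 atoms and 1 and 2 bonds.
#   mol_id             = [0, 1]
#   node_graph_indices = repeat([0, 1], [2, 3]) = [0, 0, 1, 1, 1]
#   bond_graph_indices = repeat([0, 1], [1, 2]) = [0, 1, 1]
#   offset_values      = cumsum([2, 3], exclusive=True) = [0, 2]
#   offsets            = repeat([0, 2], [1, 2]) = [0, 2, 2]
# Molecule 1's connectivity rows are shifted by 2 because its atoms start
# at flat index 2 after concatenation.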
Example #4
def repeat(values, repeats, axis):
    """See https://github.com/tensorflow/tensorflow/issues/8246."""
    try:
        repeat = tf.repeat
    except AttributeError:
        from tensorflow.python.ops.ragged.ragged_util import repeat  # pylint: disable=import-error
    if values.shape.ndims == 1:
        return repeat(values, repeats, axis=axis)
    else:
        indices = repeat(tf.range(tf.shape(values)[axis]), repeats, axis=0)
        return tf.gather(values, indices, axis=axis)
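# Usage sketch (hypothetical values): the shim matches np.repeat semantics.
#   repeat(tf.constant([1, 2, 3]), tf.constant([2, 0, 1]), axis=0)
#   -> [1, 1, 3]
#   repeat(tf.constant([[1, 2], [3, 4]]), tf.constant([2, 1]), axis=0)
#   -> [[1, 2], [1, 2], [3, 4]]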
def row_splits_to_segment_ids(splits, name=None):
    """Generates the segmentation corresponding to a RaggedTensor `splits` vector.

  Returns an integer vector `segment_ids`, where `segment_ids[i] == j` if
  `splits[j] <= i < splits[j+1]`.  Example:

  ```python
  >>> ragged.row_splits_to_segment_ids([0, 3, 3, 5, 6, 9]).eval()
  [ 0 0 0 2 2 3 4 4 4 ]
  ```

  Args:
    splits: A sorted 1-D int64 Tensor.  `splits[0]` must be zero.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D int64 Tensor, with `shape=[splits[-1]]`

  Raises:
    ValueError: If `splits` is invalid.
  """
    with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name:
        splits = ops.convert_to_tensor(splits,
                                       dtype=dtypes.int64,
                                       name="splits")
        splits.shape.assert_has_rank(1)
        if tensor_shape.dimension_value(splits.shape[0]) == 0:
            raise ValueError("Invalid row_splits: []")
        row_lengths = splits[1:] - splits[:-1]
        nrows = array_ops.shape(splits, out_type=dtypes.int64)[-1] - 1
        indices = math_ops.range(nrows)
        return ragged_util.repeat(indices, repeats=row_lengths, axis=0)
Example #7
def _expand_and_tile(global_features, row_splits_or_k):
    if not isinstance(row_splits_or_k, tf.Tensor) or (
            row_splits_or_k.shape.ndims == 0):
        # knn
        raise NotImplementedError('TODO')
    else:
        from tensorflow.python.ops.ragged.ragged_util import repeat
        return repeat(global_features, row_splits_or_k, axis=0)
Example #8
def tf_repeat(values, repeats):
    values = tf.convert_to_tensor(values)
    repeats = tf.convert_to_tensor(repeats)

    if values.shape.ndims != 1:
        raise ValueError("values must be rank 1, got shape %s" % values.shape)
    if repeats.shape.ndims != 1:
        raise ValueError("repeats must be rank 1, got shape %s" % repeats.shape)
    if not repeats.dtype.is_integer:
        raise ValueError("repeats must be an integer, got %s" % repeats.dtype)

    try:
        from tensorflow.python.ops.ragged.ragged_util import repeat

        return repeat(values, repeats, axis=0)
    except ImportError:
        return _foldl_repeat(values, repeats)
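# `_foldl_repeat` is not included in this snippet.  A minimal mask-based
# stand-in (hypothetical, not the original helper) could look like this:
def _mask_repeat(values, repeats):
    # Tile each element out to the max repeat count, then drop the excess;
    # boolean_mask flattens row-major, which matches np.repeat ordering.
    max_repeats = tf.reduce_max(repeats)
    tiled = tf.tile(tf.expand_dims(values, 1), [1, max_repeats])
    mask = tf.sequence_mask(repeats, max_repeats)
    return tf.boolean_mask(tiled, mask)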
Example #9
  def testValuesMatchesNumpy(self, mode, data, repeats, axis):
    # Exception: we can't handle negative axis if data.ndims is unknown.
    if axis < 0 and mode == 'unknown_shape':
      return

    expected = np.repeat(data, repeats, axis)

    if mode == 'constant':
      data = constant_op.constant(data)
      repeats = constant_op.constant(repeats)
    elif mode == 'dynamic':
      data = constant_op.constant(data)
      repeats = constant_op.constant(repeats)
      data = array_ops.placeholder_with_default(data, data.shape)
      repeats = array_ops.placeholder_with_default(repeats, repeats.shape)
    elif mode == 'unknown_shape':
      data = array_ops.placeholder_with_default(data, None)
      repeats = array_ops.placeholder_with_default(repeats, None)

    result = ragged_util.repeat(data, repeats, axis)
    self.assertAllEqual(result, expected)
Example #11
def flat_expanding_global_deconv(
        global_features, coord_features, row_splits_or_k):
    """
    Global deconvolution operation.

    Args:
        global_features: [pi, fi]
        coord_features: [po, fk]
        row_splits_or_k: [pi+1]

    Returns:
        convolved features: [po, fi*fk]
    """
    from tensorflow.python.ops.ragged.ragged_util import repeat
    if row_splits_or_k.shape.ndims == 0:
        raise NotImplementedError

    global_features = repeat(
        global_features, utils.diff(row_splits_or_k), axis=0)
    merged = utils.merge_features(global_features, coord_features)
    merged = utils.flatten_final_dims(merged, 2)
    return merged
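# Shape-flow sketch (names from the docstring; merge_features is assumed to
# broadcast): repeat lifts global_features [pi, fi] to [po, fi] using the
# row lengths diff(row_splits_or_k); merging with coord_features [po, fk]
# presumably gives [po, fi, fk], and flattening the final two dims yields
# the documented output [po, fi*fk].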
Example #12
def row_splits_to_segment_ids(splits, name=None, out_type=None):
    """Generates the segmentation corresponding to a RaggedTensor `row_splits`.

  Returns an integer vector `segment_ids`, where `segment_ids[i] == j` if
  `splits[j] <= i < splits[j+1]`.  Example:

  >>> print(tf.ragged.row_splits_to_segment_ids([0, 3, 3, 5, 6, 9]))
   tf.Tensor([0 0 0 2 2 3 4 4 4], shape=(9,), dtype=int64)

  Args:
    splits: A sorted 1-D integer Tensor.  `splits[0]` must be zero.
    name: A name prefix for the returned tensor (optional).
    out_type: The dtype for the return value.  Defaults to `splits.dtype`,
      or `tf.int64` if `splits` does not have a dtype.

  Returns:
    A sorted 1-D integer Tensor, with `shape=[splits[-1]]`

  Raises:
    ValueError: If `splits` is invalid.
  """
    with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name:
        splits = ops.convert_to_tensor(splits,
                                       name="splits",
                                       preferred_dtype=dtypes.int64)
        if splits.dtype not in (dtypes.int32, dtypes.int64):
            raise ValueError("splits must have dtype int32 or int64")
        splits.shape.assert_has_rank(1)
        if tensor_shape.dimension_value(splits.shape[0]) == 0:
            raise ValueError("Invalid row_splits: []")
        if out_type is None:
            out_type = splits.dtype
        else:
            out_type = dtypes.as_dtype(out_type)
        row_lengths = splits[1:] - splits[:-1]
        nrows = array_ops.shape(splits, out_type=out_type)[-1] - 1
        indices = math_ops.range(nrows)
        return ragged_util.repeat(indices, repeats=row_lengths, axis=0)
Example #14
  def testRepeat(self, data, repeats, expected, axis=None):
    result = ragged_util.repeat(data, repeats, axis)
    self.assertAllEqual(result, expected)
Example #15
  def _class_maps_to_colored_maps(self, class_map):
    gray_gts = tf.argmax(class_map, axis=3,
                         output_type=tf.int32) * (255 // self.n_classes)
    temp = tf.expand_dims(gray_gts, 3)
    rgb_gt = repeat(temp, 3, axis=3)
    return tf.cast(rgb_gt, dtype=tf.uint8)
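# Shape sketch (assuming NHWC input): class_map [B, H, W, n_classes]
# -> argmax over axis 3 gives integer class maps [B, H, W] scaled to 0-255;
# expand_dims + repeat along axis 3 copies the gray channel into [B, H, W, 3].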
Example #17
def _repeat(args, axis=0):
    from tensorflow.python.ops.ragged.ragged_util import repeat
    data, repeats = args
    return repeat(data, repeats, axis=axis)
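# Usage sketch (hypothetical values): _repeat expects a (data, repeats)
# tuple, e.g. as passed through a Keras Lambda layer:
#   _repeat((tf.constant([5, 6]), tf.constant([1, 3])))
#   -> [5, 6, 6, 6]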
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> tf.compat.v1.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    if ragged_tensor.is_ragged(indices):
      # If the outermost ragged dimension is a batch dimension, recurse.
      if indices_ndims > 2:
        if not ragged_tensor.is_ragged(params):
          raise ValueError('batch shape from indices does '
                           'not match params shape')
        checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)]
        with ops.control_dependencies(checks):
          return ragged_tensor.RaggedTensor.from_row_splits(
              batch_gather(params.values, indices.values), indices.row_splits,
              validate=False)

      # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
      else:
        # Ensure that `params` is ragged and has at least 2 dimensions.
        if not ragged_tensor.is_ragged(params):
          if params.shape.ndims is not None and params.shape.ndims < 2:
            raise ValueError('batch shape from indices does '
                             'not match params shape')
          params = ragged_tensor.RaggedTensor.from_tensor(
              params, ragged_rank=1,
              row_splits_dtype=indices.row_splits.dtype)

        # Adjust indices from within-batch to global (in params.values), and
        # then use ragged.gather to gather them.
        num_indices = indices.row_lengths()
        params_starts = params.row_starts()
        adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
        adjusted_index_values = (
            math_ops.cast(indices.values, adjustments.dtype) + adjustments)
        return ragged_tensor.RaggedTensor.from_row_splits(
            ragged_gather_ops.gather(params.values, adjusted_index_values),
            indices.row_splits, validate=False)

    else:  # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return ragged_gather_ops.gather(params, indices)
      elif indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        adjustments = array_ops.expand_dims(params.row_starts(), 1)
        adjusted_indices = (
            math_ops.cast(indices, adjustments.dtype) + adjustments)
        return ragged_gather_ops.gather(params.values, adjusted_indices)
      else:
        raise ValueError('batch shape from indices does not match params shape')
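# Worked trace of the within-batch -> global index adjustment above, using
# the values from the docstring example:
#   params.values  = ['a','b','c','d','e'],  params.row_starts()   = [0, 3, 4, 4]
#   indices.values = [1, 2, 0, 0, 0],        indices.row_lengths() = [3, 0, 0, 2]
#   adjustments           = repeat([0, 3, 4, 4], [3, 0, 0, 2]) = [0, 0, 0, 4, 4]
#   adjusted_index_values = [1, 2, 0, 0, 0] + [0, 0, 0, 4, 4]  = [1, 2, 0, 4, 4]
#   gather(params.values, [1, 2, 0, 4, 4]) = ['b', 'c', 'a', 'e', 'e'],
# which from_row_splits regroups into [['b', 'c', 'a'], [], [], ['e', 'e']].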
Example #19
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> ragged.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    if ragged_tensor.is_ragged(indices):
      # If the outermost ragged dimension is a batch dimension, recurse.
      if indices_ndims > 2:
        if not ragged_tensor.is_ragged(params):
          raise ValueError('batch shape from indices does '
                           'not match params shape')
        checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)]
        with ops.control_dependencies(checks):
          return ragged_tensor.RaggedTensor.from_row_splits(
              batch_gather(params.values, indices.values), indices.row_splits)

      # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
      else:
        # Ensure that `params` is ragged and has at least 2 dimensions.
        if not ragged_tensor.is_ragged(params):
          if params.shape.ndims is not None and params.shape.ndims < 2:
            raise ValueError('batch shape from indices does '
                             'not match params shape')
          params = ragged_conversion_ops.from_tensor(params, ragged_rank=1)

        # Adjust indices from within-batch to global (in params.values), and
        # then use ragged.gather to gather them.
        num_indices = indices.row_lengths()
        params_starts = params.row_starts()
        adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
        adjusted_index_values = math_ops.to_int64(indices.values) + adjustments
        return ragged_tensor.RaggedTensor.from_row_splits(
            gather(params.values, adjusted_index_values), indices.row_splits)

    else:  # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return gather(params, indices)
      elif indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        adjustments = array_ops.expand_dims(params.row_starts(), 1)
        adjusted_indices = math_ops.to_int64(indices) + adjustments
        return gather(params.values, adjusted_indices)
      else:
        raise ValueError('batch shape from indices does not match params shape')
Example #20
def segmentation_logits(inputs,
                        output_spec,
                        num_obj_classes=16,
                        r0=0.1,
                        initial_filters=(16, ),
                        initial_activation=seg_activation,
                        filters=(32, 64, 128, 256),
                        global_units='combined',
                        query_fn=core.query_pairs,
                        radii_fn=core.constant_radii,
                        global_deconv_all=False,
                        coords_transform=None,
                        weights_transform=None,
                        convolver=None):

    if convolver is None:
        convolver = c.ExpandingConvolver(activation=seg_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    coords = inputs['positions']
    normals = inputs.get('normals')

    if normals is None:
        raise NotImplementedError()
    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        features = tf.ragged.map_flat_values(core_layers.Dense(f), features)
        features = tf.ragged.map_flat_values(initial_activation, features)
    assert (isinstance(features, tf.RaggedTensor)
            and features.ragged_rank == 1)

    class_embeddings = _get_class_embeddings(
        b.as_batched_model_input(inputs['obj_label']), num_obj_classes,
        [initial_filters[-1], filters[0]])

    features = core.add_local_global(features, class_embeddings[0])

    input_row_splits = features.row_splits
    features = utils.flatten_leading_dims(features, 2)

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert (unscaled_radii2.shape == (n_res, ))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        is_tensor_or_var = isinstance(r2, (tf.Tensor, tf.Variable))
        if is_tensor_or_var:
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        else:
            return np.sqrt(r2)

    pp_radii2 = [maybe_feed(r2) for r2 in radii2]

    all_features = []
    in_place_neighborhoods = []
    sampled_neighborhoods = []
    global_features = []
    # encoder
    for i, (radius2, pp_radius2) in enumerate(zip(radii2, pp_radii2)):
        neighbors, sample_rate = query_fn(coords,
                                          pp_radius2,
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        in_place_neighborhoods.append(neighborhood)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)

        all_features.append(features)

        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        # resample
        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            sampled_neighborhoods.append(neighborhood)

            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv
    if global_units is not None:
        row_splits = nested_row_splits[-2]
        if global_units == 'combined':
            global_features = tf.keras.layers.Lambda(
                tf.concat, arguments=dict(axis=-1))(global_features)
        else:
            coord_features = coords_transform(coords, None)
            global_features = convolver.global_conv(features, coord_features,
                                                    row_splits, global_units)

        coord_features = coords_transform(coords, None)
        features = convolver.global_deconv(global_features, coord_features,
                                           row_splits, filters[-1])

    # decoder
    for i in range(n_res - 1, -1, -1):
        if i < n_res - 1:
            # up-sample
            neighborhood = sampled_neighborhoods.pop().transpose
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i], neighborhood, coords_transform,
                weights_transform, convolver.resample_conv)
            if global_deconv_all:
                coords = neighborhood.out_coords
                row_splits = \
                    neighborhood.offset_batched_neighbors.nested_row_splits[-2]
                coord_features = coords_transform(coords, None)
                deconv_features = convolver.global_deconv(
                    global_features, coord_features, row_splits, filters[i])
                features = tf.keras.layers.Add()([features, deconv_features])

        forward_features = all_features.pop()
        if not (i == n_res - 1 and global_units is None):
            features = tf.keras.layers.Lambda(tf.concat,
                                              arguments=dict(axis=-1))(
                                                  [features, forward_features])
        neighborhood = in_place_neighborhoods.pop().transpose
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.resample_conv)

    features = tf.RaggedTensor.from_row_splits(features, input_row_splits)
    features = core.add_local_global(features, class_embeddings[-1])
    logits = tf.ragged.map_flat_values(
        core_layers.Dense(output_spec.shape[-1]), features)

    valid_classes_mask = inputs.get('valid_classes_mask')
    if valid_classes_mask is not None:
        row_lengths = utils.diff(logits.row_splits)
        valid_classes_mask = b.as_batched_model_input(valid_classes_mask)
        valid_classes_mask = repeat(valid_classes_mask, row_lengths, axis=0)

        def flat_fn(flat_logits):
            neg_inf = tf.keras.layers.Lambda(_neg_inf_like)(flat_logits)
            return utils.lambda_call(tf.where, valid_classes_mask, flat_logits,
                                     neg_inf)

        logits = tf.ragged.map_flat_values(flat_fn, logits)
    return logits
Example #21
  def testRepeat(self, data, repeats, expected, axis=None):
    result = ragged_util.repeat(data, repeats, axis)
    with self.test_session():
      self.assertEqual(result.eval().tolist(), expected)