Пример #1
0
    def test_check_valid_graph_pooling_exception_raised_types(
            self, err_msg, data_type, pool_map_type, sizes_type):
        """Check the type errors for invalid input types."""
        data = tf.convert_to_tensor(value=np.ones((2, 3, 3), dtype=data_type))
        pool_map = _dense_to_sparse(np.ones((2, 3, 3), dtype=pool_map_type))
        sizes = tf.convert_to_tensor(
            value=np.array(((1, 2), (2, 3)), dtype=sizes_type))

        with self.assertRaisesRegexp(TypeError, err_msg):
            utils.check_valid_graph_pooling_input(data, pool_map, sizes)
Пример #2
0
def pool(data, pool_map, sizes, algorithm='max', name=None):
  #  pyformat: disable
  """Implements graph pooling.

  The features at each output vertex are computed by pooling over a subset of
  vertices in the input graph. This pooling window is specified by the input
  `pool_map`.

  The shorthands used below are
    `V1`: The number of vertices in the input data.
    `V2`: The number of vertices in the pooled output data.
    `C`: The number of channels in the data.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    data: A `float` tensor with shape `[A1, ..., An, V1, C]`.
    pool_map: A `SparseTensor` with the same type as `data` and with shape
      `[A1, ..., An, V2, V1]`. The features for an output vertex `v2` will be
      computed by pooling over the corresponding input vertices specified by
      the entries in `pool_map[A1, ..., An, v2, :]`.
    sizes: An `int` tensor of shape `[A1, ..., An, 2]` indicating the true
      input sizes in case of padding (`sizes=None` indicates no padding).
      `sizes[A1, ..., An, 0] <= V2` specifies the padding in the (pooled)
      output, and `sizes[A1, ..., An, 1] <= V1` specifies the padding in the
      input.
    algorithm: The pooling function, must be either 'max' or 'weighted'. Default
      is 'max'. For 'max' pooling, the output features are the maximum over the
      input vertices (in this case only the indices of the `SparseTensor`
      `pool_map` are used, the values are ignored). For 'weighted', the output
      features are a weighted sum of the input vertices, the weights specified
      by the values of `pool_map`.
    name: A name for this op. Defaults to 'graph_pooling_pool'.

  Returns:
    Tensor with shape `[A1, ..., An, V2, C]`.

  Raises:
    TypeError: if the input types are invalid.
    ValueError: if the input dimensions are invalid.
    ValueError: if `algorithm` is invalid.
  """
  #  pyformat: enable
  with tf.compat.v1.name_scope(
      name, 'graph_pooling_pool', [data, pool_map, sizes]):
    data = tf.convert_to_tensor(value=data)
    pool_map = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=pool_map)
    if sizes is not None:
      sizes = tf.convert_to_tensor(value=sizes)
    utils.check_valid_graph_pooling_input(data, pool_map, sizes)

    if sizes is not None:
      sizes_output, sizes_input = tf.split(sizes, 2, axis=-1)
      sizes_output = tf.squeeze(sizes_output, axis=-1)
      sizes_input = tf.squeeze(sizes_input, axis=-1)
    else:
      sizes_output = None
      sizes_input = None

    batched = data.shape.ndims > 2
    if batched:
      x_flat, _ = utils.flatten_batch_to_2d(data, sizes_input)
      pool_map_block_diagonal = utils.convert_to_block_diag_2d(pool_map, sizes)
    else:
      x_flat = data
      pool_map_block_diagonal = pool_map

    if algorithm == 'weighted':
      pooled = tf.sparse.sparse_dense_matmul(pool_map_block_diagonal, x_flat)
    elif algorithm == 'max':
      pool_groups = tf.gather(x_flat, pool_map_block_diagonal.indices[:, 1])
      pooled = tf.math.segment_max(
          data=pool_groups, segment_ids=pool_map_block_diagonal.indices[:, 0])
    else:
      raise ValueError('The pooling method must be "weighted" or "max"')

    if batched:
      if sizes_output is not None:
        pooled = utils.unflatten_2d_to_batch(pooled, sizes_output)
      else:
        output_shape = tf.concat((tf.shape(input=pool_map)[:-1], (-1,)), axis=0)
        pooled = tf.reshape(pooled, output_shape)

    return pooled