コード例 #1
0
    def test_unpool_random(self, num_vertices, num_features):
        """Tests unpooling with random data inputs.

        Builds a pooling map that merges each consecutive pair of vertices
        (2i, 2i + 1) into pooled vertex i. It then checks that unpooling random
        pooled data copies each pooled row back onto both vertices of its
        pair. When `num_vertices` is odd, the last vertex is not referenced by
        the pooling map and is expected to stay zero.

        Args:
          num_vertices: Number of vertices in the unpooled output.
          num_features: Number of feature channels per vertex.
        """
        output_vertices = num_vertices // 2
        pool_map = np.zeros(shape=(output_vertices, num_vertices),
                            dtype=np.float32)
        for i in range(output_vertices):
            pool_map[i, (i * 2, i * 2 + 1)] = (0.5, 0.5)
        data = np.random.uniform(size=(output_vertices,
                                       num_features)).astype(np.float32)

        unpooled = gp.unpool(data,
                             _dense_to_sparse(pool_map),
                             sizes=None,
                             name=None)

        # Expected output: each pooled row appears unchanged (not scaled by
        # the 0.5 map weights) on both vertices of its pair. The slices are
        # bounded by 2 * output_vertices so the assignment shapes match even
        # when num_vertices is odd; the unmapped trailing vertex stays zero.
        # Built outside the subtests because both subtests rely on it.
        paired_vertices = 2 * output_vertices
        expected = np.zeros(shape=(num_vertices, num_features),
                            dtype=np.float32)
        expected[0:paired_vertices:2, :] = data
        expected[1:paired_vertices:2, :] = data

        with self.subTest(name='direct_unpool'):
            self.assertAllClose(unpooled, expected)

        with self.subTest(name='permute_pool_map'):
            # Permuting the columns of the pooling map must permute the rows
            # of the unpooled output in the same way.
            permutation = np.random.permutation(num_vertices)
            pool_map_permute = pool_map[:, permutation]
            unpooled_permute = gp.unpool(data,
                                         _dense_to_sparse(pool_map_permute),
                                         sizes=None)
            expected_permute = expected[permutation, :]
            self.assertAllClose(unpooled_permute, expected_permute)
コード例 #2
0
    def test_unpool_exception_raised_types(self, err_msg, data_type,
                                           pool_map_type, sizes_type):
        """Tests the correct exceptions are raised for invalid types.

        Args:
          err_msg: Regular expression expected to match the TypeError message.
          data_type: NumPy dtype used to build the data array.
          pool_map_type: NumPy dtype of the dense pooling map before it is
            converted with `_dense_to_sparse`.
          sizes_type: NumPy dtype used to build the sizes array.
        """
        data = np.ones((2, 3, 3), dtype=data_type)
        pool_map = _dense_to_sparse(np.ones((2, 3, 3), dtype=pool_map_type))
        sizes = np.array(((1, 2), (2, 3)), dtype=sizes_type)

        # assertRaisesRegex is the canonical name; the assertRaisesRegexp
        # alias was deprecated in Python 3.2 and removed in Python 3.12.
        with self.assertRaisesRegex(TypeError, err_msg):
            gp.unpool(data, pool_map, sizes)
コード例 #3
0
    def test_unpool_exception_raised_shapes(self, err_msg, data_shape,
                                            pool_map_shape, sizes_shape):
        """Tests the correct exceptions are raised for invalid shapes.

        Args:
          err_msg: Regular expression expected to match the ValueError message.
          data_shape: Shape of the float32 data array.
          pool_map_shape: Shape of the dense float32 pooling map before it is
            converted with `_dense_to_sparse`.
          sizes_shape: Shape of the int32 sizes array, or None to call unpool
            without sizes.
        """
        data = np.ones(data_shape, dtype=np.float32)
        pool_map = _dense_to_sparse(np.ones(pool_map_shape, dtype=np.float32))
        sizes = (None if sizes_shape is None
                 else np.ones(sizes_shape, dtype=np.int32))

        # assertRaisesRegex is the canonical name; the assertRaisesRegexp
        # alias was deprecated in Python 3.2 and removed in Python 3.12.
        with self.assertRaisesRegex(ValueError, err_msg):
            gp.unpool(data, pool_map, sizes)
コード例 #4
0
  def test_unpool_identity(self, batch_shape, num_vertices, num_features,
                           data_type):
    """Tests graph unpooling with identity maps.

    Unpooling with the map returned by `_batch_sparse_eye` must give back the
    input data unchanged.

    Args:
      batch_shape: Sequence of leading batch dimensions; may be empty.
      num_vertices: Number of vertices per graph.
      num_features: Number of feature channels per vertex.
      data_type: NumPy dtype of the data; also passed to `_batch_sparse_eye`.
    """
    # Tuple concatenation keeps the shape integral. np.concatenate would
    # promote an empty batch_shape such as () to float64, and
    # np.random.uniform rejects a non-integer `size`.
    data_shape = tuple(batch_shape) + (num_vertices, num_features)
    data = np.random.uniform(size=data_shape).astype(data_type)
    pool_map = _batch_sparse_eye(batch_shape, num_vertices, data_type)

    unpooled = gp.unpool(data, pool_map, sizes=None)
    self.assertAllClose(unpooled, data)
コード例 #5
0
    def test_unpool_jacobian_random(self):
        """Checks the jacobian of unpool w.r.t. the data on padded inputs.

        The batch holds two graphs whose valid sizes are given by
        `batch_sizes` (presumably (pooled vertices, unpooled vertices) per
        graph). The last pooled row of the first graph is treated as padding
        and is zeroed in both the data and the pooling map. The check itself
        is delegated to `assert_jacobian_is_correct`.
        """
        batch_sizes = ((2, 4), (3, 5))
        initial_data = np.random.uniform(size=(2, 3, 6))
        dense_map = np.random.uniform(size=(2, 3, 5))
        # Zero the padding row of the first batch element in both arrays.
        for padded in (initial_data, dense_map):
            padded[0, -1] = 0.
        sparse_map = _dense_to_sparse(dense_map)
        data = tf.convert_to_tensor(value=initial_data)

        unpooled = gp.unpool(data, sparse_map, batch_sizes)

        self.assert_jacobian_is_correct(data, initial_data, unpooled)
コード例 #6
0
  def test_unpool_preset_padded(self):
    """Tests unpooling with preset data and padding.

    The first batch element has only 2 valid pooled vertices (its last data
    row is zeroed padding). Its pooling map sends pooled vertex 0 to output
    vertices 0 and 1, and pooled vertex 1 to output vertex 2. The second batch
    element uses an identity map. The expected output copies each pooled row
    to every output vertex with a non-zero map entry; the 0.5 weights do not
    scale the values.
    """
    data = np.reshape(np.arange(12, dtype=np.float32), (2, 3, 2))
    # Zero the padding row of the first batch element.
    data[0, -1, :] = 0.
    sizes = ((2, 3), (3, 3))
    pool_map = _dense_to_sparse(
        np.array((((0.5, 0.5, 0.), (0., 0., 1.), (0., 0., 0.)),
                  ((1., 0., 0.), (0., 1., 0.), (0., 0., 1.))),
                 dtype=np.float32))

    unpooled = gp.unpool(data, pool_map, sizes)

    true = (((0., 1.), (0., 1.), (2., 3.)), ((6., 7.), (8., 9.), (10., 11.)))
    self.assertAllClose(unpooled, true)
コード例 #7
0
 def gp_unpool(data):
   """Unpools `data` using `pool_map` and `sizes` from the enclosing scope."""
   result = gp.unpool(data, pool_map, sizes=sizes)
   return result