def testSampleFeatureOnlyExtractionWithNoNeighbors(self):
  """Tests sample feature extraction when the input has no neighbor features.

  With `max_neighbors=0` and no neighbor keys in the input, the sample
  features are expected back unchanged, the neighbor feature dictionary is
  expected to be empty, and the neighbor weights are expected to be `None`.
  """
  # Simulate batch size of 1.
  features = {
      'F0': tf.constant([[1.0, 2.0]]),
      'F1': tf.constant([[3.0, 4.0, 5.0]]),
  }
  expected_sample_features = {
      'F0': tf.constant([[1.0, 2.0]]),
      'F1': tf.constant([[3.0, 4.0, 5.0]]),
  }

  neighbor_config = configs.GraphNeighborConfig(max_neighbors=0)
  sample_features, nbr_features, nbr_weights = utils.unpack_neighbor_features(
      features, neighbor_config)
  self.assertIsNone(nbr_weights)

  # Materialize the outputs with `self.evaluate` so that the assertions below
  # compare the evaluated values.
  sample_features, nbr_features = self.evaluate(
      [sample_features, nbr_features])
  self.assertAllEqual(sample_features['F0'], expected_sample_features['F0'])
  self.assertAllEqual(sample_features['F1'], expected_sample_features['F1'])
  self.assertEmpty(nbr_features)
def testExtraNeighborFeaturesIgnored(self):
  """Checks that neighbors beyond `max_neighbors` are left out of the output."""
  # A batch size of 1 keeps things simple. Two neighbors are provided, but the
  # config below asks for just one.
  inputs = {
      'F0': tf.constant([[1.0, 2.0]]),
      'NL_nbr_0_F0': tf.constant([[1.1, 2.1]]),
      'NL_nbr_0_weight': tf.constant([[0.25]]),
      'NL_nbr_1_F0': tf.constant([[1.2, 2.2]]),
      'NL_nbr_1_weight': tf.constant([[0.75]]),
  }

  # Only the first neighbor's feature and weight are expected to survive.
  want_sample = {'F0': tf.constant([[1.0, 2.0]])}
  want_nbrs = {'F0': tf.constant([[1.1, 2.1]])}
  want_weights = tf.constant([[0.25]])

  one_neighbor = configs.GraphNeighborConfig(max_neighbors=1)
  unpacked = utils.unpack_neighbor_features(inputs, one_neighbor)
  got_sample, got_nbrs, got_weights = self.evaluate(unpacked)

  self.assertAllEqual(got_sample['F0'], want_sample['F0'])
  self.assertAllEqual(got_nbrs['F0'], want_nbrs['F0'])
  self.assertAllEqual(got_weights, want_weights)
def __init__(self,
             neighbor_config=None,
             feature_names=None,
             weight_dtype=None,
             **kwargs):
  """Constructs a `NeighborFeatures` instance.

  Args:
    neighbor_config: A `configs.GraphNeighborConfig` instance describing
      neighbor attributes. When `None`, a default-constructed
      `configs.GraphNeighborConfig` is used.
    feature_names: Optional[List[Text]], the keys of the features for which
      neighbor inputs are created. If `None`, all features are assumed to
      have corresponding neighbor features.
    weight_dtype: `tf.DType` for `neighbor_weights`. It is forwarded as the
      layer's `dtype` unless `dtype` is given in `kwargs`. (The original
      author documents the effective default as `tf.float32`; that default
      is decided by the base class, not here.)
    **kwargs: Additional arguments passed on to `tf.keras.layers.Layer`.
  """
  # An explicit `dtype` keyword argument takes precedence over `weight_dtype`.
  layer_dtype = kwargs.pop('dtype', weight_dtype)
  super(NeighborFeatures, self).__init__(
      autocast=False, dtype=layer_dtype, **kwargs)

  if neighbor_config is None:
    self._neighbor_config = configs.GraphNeighborConfig()
  else:
    # Keep the result of `attr.evolve` rather than the caller's object.
    self._neighbor_config = attr.evolve(neighbor_config)

  if feature_names is None:
    self._feature_names = None
  else:
    self._feature_names = set(feature_names)
def testSampleFeatureOnlyExtractionWithNeighbors(self):
  """Checks that `max_neighbors=0` yields sample features only.

  The input carries features and weights for two neighbors, yet with
  `max_neighbors=0` the neighbor feature dictionary is expected to be empty
  and the neighbor weights are expected to be `None`.
  """
  # Simulate batch size of 1.
  inputs = {
      'F0': tf.constant([[1.0, 2.0]]),
      'F1': tf.constant([[3.0, 4.0, 5.0]]),
      'NL_nbr_0_F0': tf.constant([[1.1, 2.1]]),
      'NL_nbr_0_F1': tf.constant([[3.1, 4.1, 5.1]]),
      'NL_nbr_0_weight': tf.constant([[0.25]]),
      'NL_nbr_1_F0': tf.constant([[1.2, 2.2]]),
      'NL_nbr_1_F1': tf.constant([[3.2, 4.2, 5.2]]),
      'NL_nbr_1_weight': tf.constant([[0.75]]),
  }
  want_sample = {
      'F0': tf.constant([[1.0, 2.0]]),
      'F1': tf.constant([[3.0, 4.0, 5.0]]),
  }

  no_neighbors = configs.GraphNeighborConfig(max_neighbors=0)
  got_sample, got_nbrs, got_weights = utils.unpack_neighbor_features(
      inputs, no_neighbors)
  self.assertIsNone(got_weights)

  got_sample, got_nbrs = self.evaluate([got_sample, got_nbrs])
  for feature_name in ('F0', 'F1'):
    self.assertAllEqual(got_sample[feature_name], want_sample[feature_name])
  self.assertEmpty(got_nbrs)
def testInvalidNeighborWeightRank(self):
  """Input containing a rank 3 neighbor weight tensor raises ValueError."""
  # NOTE(review): 'F0' and its neighbor feature are rank 1 here, so the
  # ValueError may not be attributable to the weight's rank alone -- confirm.
  features = {
      'F0': tf.constant([1.0, 2.0]),
      'NL_nbr_0_F0': tf.constant([1.1, 2.1]),
      'NL_nbr_0_weight': tf.constant([[[0.25]]]),
  }

  # Build the config outside the `assertRaises` block so that only the call
  # under test can satisfy the expected exception.
  neighbor_config = configs.GraphNeighborConfig(max_neighbors=1)
  with self.assertRaises(ValueError):
    utils.unpack_neighbor_features(features, neighbor_config)
def testEmptyFeatures(self):
  """Checks `strip_neighbor_features` on an empty feature dictionary."""
  stripped = utils.strip_neighbor_features({}, configs.GraphNeighborConfig())

  # A throwaway constant is evaluated alongside the result so that the
  # computation graph is not empty.
  filler = tf.constant(1.0)
  stripped, filler = self.evaluate([stripped, filler])
  self.assertEmpty(stripped)
def testInvalidRank(self):
  """Input containing rank 1 tensors raises ValueError."""
  # All tensors below are rank 1.
  features = {
      'F0': tf.constant([1.0, 2.0]),
      'NL_nbr_0_F0': tf.constant([1.1, 2.1]),
      'NL_nbr_0_weight': tf.constant([0.25]),
  }

  # Build the config outside the `assertRaises` block so that only the call
  # under test can satisfy the expected exception.
  neighbor_config = configs.GraphNeighborConfig(max_neighbors=1)
  with self.assertRaises(ValueError):
    utils.unpack_neighbor_features(features, neighbor_config)
def testMissingNeighborWeight(self):
  """Missing neighbor weight raises KeyError."""
  # Simulate a batch size of 1 for simplicity. The second neighbor has a
  # feature ('NL_nbr_1_F0') but no 'NL_nbr_1_weight' entry, while the config
  # below asks for two neighbors.
  features = {
      'F0': tf.constant([[1.0, 2.0]]),
      'NL_nbr_0_F0': tf.constant([[1.1, 2.1]]),
      'NL_nbr_0_weight': tf.constant([[0.25]]),
      'NL_nbr_1_F0': tf.constant([[1.2, 2.2]]),
  }

  # Build the config outside the `assertRaises` block so that only the call
  # under test can satisfy the expected exception.
  neighbor_config = configs.GraphNeighborConfig(max_neighbors=2)
  with self.assertRaises(KeyError):
    utils.unpack_neighbor_features(features, neighbor_config)
def testSampleAndNeighborFeatureShapeIncompatibility(self):
  """Sample feature and neighbor feature have incompatible shapes."""
  # Simulate a batch size of 1 for simplicity.
  # The shape of the sample feature is 1x2 while the shape of the
  # corresponding neighbor feature is 1x3.
  features = {
      'F0': tf.constant([[1.0, 2.0]]),
      'NL_nbr_0_F0': tf.constant([[1.1, 2.1, 3.1]]),
      'NL_nbr_0_weight': tf.constant([[0.25]]),
  }

  # Build the config outside the `assertRaises` block so that only the call
  # under test can satisfy the expected exception.
  neighbor_config = configs.GraphNeighborConfig(max_neighbors=1)
  with self.assertRaises(ValueError):
    utils.unpack_neighbor_features(features, neighbor_config)
def testEmptyFeatures(self):
  """Checks `unpack_neighbor_features` on an empty feature dictionary."""
  no_neighbors = configs.GraphNeighborConfig(max_neighbors=0)
  unpacked = utils.unpack_neighbor_features({}, no_neighbors)
  got_sample, got_nbrs, got_weights = unpacked
  self.assertIsNone(got_weights)

  # A throwaway constant is evaluated alongside the results so that the
  # computation graph is not empty.
  filler = tf.constant(1.0)
  got_sample, got_nbrs, filler = self.evaluate(
      [got_sample, got_nbrs, filler])
  self.assertEmpty(got_sample)
  self.assertEmpty(got_nbrs)
def testSparse(self, keep_rank):
  """Exercises the layer on sparse inputs with a varying neighbor count.

  Args:
    keep_rank: Forwarded to `_make_model`; it also selects which output
      layout the assertions below expect.
  """
  batch_size = 4
  input_size = 2

  sample_input = tf.sparse.from_dense(
      np.random.normal(size=(batch_size, input_size)))
  # Every sample but the last has 1 neighbor.
  first_neighbor = tf.RaggedTensor.from_row_starts(
      values=np.random.normal(size=(batch_size - 1) * input_size),
      row_starts=[0, 2, 4, 6]).to_sparse()
  # Only the 1st and 3rd sample have a second neighbor.
  second_neighbor = tf.RaggedTensor.from_row_starts(
      values=np.random.normal(size=(batch_size - 2) * input_size),
      row_starts=[0, 2, 2, 4]).to_sparse()
  features = {
      'input': sample_input,
      'NL_nbr_0_input': first_neighbor,
      'NL_nbr_0_weight': np.expand_dims(np.array([0.9, 0.3, 0.6, 0.]), -1),
      'NL_nbr_1_input': second_neighbor,
      'NL_nbr_1_weight': np.expand_dims(np.array([0.25, 0., 0.75, 0.]), -1),
  }

  model = _make_model(
      configs.GraphNeighborConfig(max_neighbors=2),
      {'input': tf.keras.Input(input_size, dtype=tf.float64)},
      keep_rank,
      tf.float64)
  samples, neighbors, weights = self.evaluate(model(features))

  # The sample values must pass through untouched.
  self.assertAllClose(samples['input'].values,
                      self.evaluate(features['input'].values))

  # Weights must be grouped per sample and have the expected shape.
  if keep_rank:
    weights_shape = (batch_size, 2, 1)
  else:
    weights_shape = (batch_size * 2, 1)
  interleaved_weights = np.array([0.9, 0.25, 0.3, 0., 0.6, 0.75, 0., 0.])
  self.assertAllClose(weights, interleaved_weights.reshape(weights_shape))

  # Neighbors must be grouped per sample as well.
  dense_neighbors = self.evaluate(
      tf.sparse.to_dense(neighbors['input'], -1.))
  neighbor0 = self.evaluate(
      tf.sparse.to_dense(features['NL_nbr_0_input'], -1))
  neighbor1 = self.evaluate(
      tf.sparse.to_dense(features['NL_nbr_1_input'], -1))
  if keep_rank:
    per_sample = dense_neighbors
  else:
    per_sample = np.split(dense_neighbors, batch_size)
  for i in range(batch_size):
    self.assertAllEqual(per_sample[i], np.stack([neighbor0[i], neighbor1[i]]))
def testNeighborWeightShapeIncompatibility(self):
  """One neighbor weight has an incompatible shape."""
  # Simulate a batch size of 1 for simplicity.
  # The shape of one neighbor weight is 1x2 instead of 1x1.
  features = {
      'F0': tf.constant([[1.0, 2.0]]),
      'NL_nbr_0_F0': tf.constant([[1.1, 2.1]]),
      'NL_nbr_0_weight': tf.constant([[0.25]]),
      'NL_nbr_1_F0': tf.constant([[1.2, 2.2]]),
      'NL_nbr_1_weight': tf.constant([[0.5, 0.75]]),
  }

  # Build the config outside the `assertRaises` block so that only the call
  # under test can satisfy the expected exception.
  neighbor_config = configs.GraphNeighborConfig(max_neighbors=2)
  with self.assertRaises(ValueError):
    utils.unpack_neighbor_features(features, neighbor_config)
def testFeaturesWithDynamicBatchSizeAndFeatureShape(self):
  """Strips neighbor features when batch size and feature shape are dynamic."""
  # Both the batch size and the feature shape are left dynamic. The batch size
  # is the first dimension of every spec below; the dynamic feature dimension
  # is the second dimension of the sample and neighbor feature specs.
  signature = {
      'F0': tf.TensorSpec((None, None, 3), tf.float32),
      'NL_nbr_0_F0': tf.TensorSpec((None, None, 3), tf.float32),
      'NL_nbr_0_weight': tf.TensorSpec((None, 1), tf.float32),
  }

  # At run time the batch size is 3 and each pre-batching feature is 2x3.
  sample_values = [
      [[1, 2, 3], [3, 2, 1]],
      [[4, 5, 6], [6, 5, 4]],
      [[7, 8, 9], [9, 8, 7]],
  ]  # 3x2x3
  nbr_0_values = [
      [[1, 3, 5], [5, 3, 1]],
      [[7, 9, 11], [11, 9, 7]],
      [[13, 15, 17], [17, 15, 13]],
  ]  # 3x2x3
  nbr_0_weights = [[0.25], [0.5], [0.75]]  # 3x1
  inputs = {
      'F0': sample_values,
      'NL_nbr_0_F0': nbr_0_values,
      'NL_nbr_0_weight': nbr_0_weights,
  }

  neighbor_config = configs.GraphNeighborConfig()

  @tf.function(input_signature=[signature])
  def _strip(features):
    return utils.strip_neighbor_features(features, neighbor_config)

  stripped = self.evaluate(_strip(inputs))

  # Only the sample feature may remain, and its value must be unchanged.
  self.assertListEqual(sorted(stripped.keys()), ['F0'])
  self.assertAllEqual(stripped['F0'], sample_values)
def testNeighborFeatureShapeIncompatibility(self):
  """One neighbor feature has an incompatible shape."""
  # Simulate a batch size of 1 for simplicity.
  # The shape of the sample feature and one neighbor feature is 1x2, while the
  # shape of another neighbor feature is 1x3.
  features = {
      'F0': tf.constant([[1.0, 2.0]]),
      'NL_nbr_0_F0': tf.constant([[1.1, 2.1]]),
      'NL_nbr_0_weight': tf.constant([[0.25]]),
      'NL_nbr_1_F0': tf.constant([[1.2, 2.2, 3.2]]),
      'NL_nbr_1_weight': tf.constant([[0.5]]),
  }

  # Build the config outside the `assertRaises` block so that only the call
  # under test can satisfy the expected exception.
  neighbor_config = configs.GraphNeighborConfig(max_neighbors=2)
  with self.assertRaises(ValueError):
    utils.unpack_neighbor_features(features, neighbor_config)
def testNoNeighborFeatures(self):
  """Checks `strip_neighbor_features` on input without neighbor features."""
  inputs = {'F0': tf.constant(11.0, shape=[2, 2])}
  default_config = configs.GraphNeighborConfig()
  stripped = utils.strip_neighbor_features(inputs, default_config)

  want = {'F0': tf.constant(11.0, shape=[2, 2])}
  stripped = self.evaluate(stripped)

  # The lone sample feature must be the only key left, with its value intact.
  self.assertListEqual(sorted(stripped.keys()), ['F0'])
  self.assertAllEqual(stripped['F0'], want['F0'])
def testBatchedFeatures(self):
  """Checks `strip_neighbor_features` on batched dense and sparse features."""
  inputs = {
      'F0': tf.constant(11.0, shape=[2, 2]),
      'F1': tf.SparseTensor(
          indices=[[0, 0], [0, 1]], values=[1.0, 2.0], dense_shape=[2, 4]),
      'NL_nbr_0_F0': tf.constant(22.0, shape=[2, 2]),
      'NL_nbr_0_F1': tf.SparseTensor(
          indices=[[1, 0], [1, 1]], values=[3.0, 4.0], dense_shape=[2, 4]),
      'NL_nbr_0_weight': tf.constant(0.25, shape=[2, 1]),
  }

  default_config = configs.GraphNeighborConfig()
  stripped = utils.strip_neighbor_features(inputs, default_config)

  want = {
      'F0': tf.constant(11.0, shape=[2, 2]),
      'F1': tf.SparseTensor(
          indices=[[0, 0], [0, 1]], values=[1.0, 2.0], dense_shape=[2, 4]),
  }
  stripped = self.evaluate(stripped)

  # Only the two sample features may remain.
  self.assertListEqual(sorted(stripped.keys()), ['F0', 'F1'])

  # Their values must be unchanged; the sparse feature is compared component
  # by component.
  self.assertAllEqual(stripped['F0'], want['F0'])
  got_sparse, want_sparse = stripped['F1'], want['F1']
  self.assertAllEqual(got_sparse.values, want_sparse.values)
  self.assertAllEqual(got_sparse.indices, want_sparse.indices)
  self.assertAllEqual(got_sparse.dense_shape, want_sparse.dense_shape)
def make_cora_dataset(
    file_path,
    batch_size=128,
    shuffle=False,
    neighbor_config: Optional[configs.GraphNeighborConfig] = None,
    max_seq_length=1433,
    num_parallel_calls=tf.data.experimental.AUTOTUNE):
  """Builds a batched `tf.data.Dataset` from the data in `file_path`.

  Args:
    file_path: Path of the input file; it is passed on as a one-element list
      of file patterns.
    batch_size: Number of examples per batch. The final batch is dropped when
      it is not full.
    shuffle: Whether the examples are shuffled.
    neighbor_config: A `configs.GraphNeighborConfig` whose `max_neighbors`,
      `prefix` and `weight_suffix` determine which neighbor features are
      parsed. A default-constructed config is used when `None`.
    max_seq_length: Length of the 'words' feature and of every neighbor
      'words' feature.
    num_parallel_calls: Used both as the parallelism of the `map` step and as
      the `prefetch` buffer size.

  Returns:
    The dataset after mapping `pack_nodes_and_edges` over each batch and
    adding prefetching. A single pass over the data is made (`num_epochs=1`).
  """
  if neighbor_config is None:
    neighbor_config = configs.GraphNeighborConfig()

  def _words_spec():
    # A fresh spec on each call: int64 of length `max_seq_length`, with zeros
    # as the default value.
    zeros = tf.constant(0, dtype=tf.int64, shape=[max_seq_length])
    return tf.io.FixedLenFeature(
        [max_seq_length], tf.int64, default_value=zeros)

  parse_spec = {
      'words': _words_spec(),
      'label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),
  }
  # Add a 'words' feature and a weight for each of the configured neighbors.
  for nbr_index in range(neighbor_config.max_neighbors):
    words_key = '{}{}_{}'.format(neighbor_config.prefix, nbr_index, 'words')
    weight_key = '{}{}{}'.format(neighbor_config.prefix, nbr_index,
                                 neighbor_config.weight_suffix)
    parse_spec[words_key] = _words_spec()
    parse_spec[weight_key] = tf.io.FixedLenFeature(
        [1], tf.float32, default_value=tf.constant([0.0]))

  batched = tf.data.experimental.make_batched_features_dataset(
      [file_path],
      batch_size,
      parse_spec,
      label_key='label',
      num_epochs=1,
      shuffle=shuffle,
      drop_final_batch=True)

  pack_fn = functools.partial(pack_nodes_and_edges, batch_size,
                              neighbor_config)
  packed = batched.map(pack_fn, num_parallel_calls=num_parallel_calls)
  return packed.prefetch(num_parallel_calls)
def testBatchedSampleAndNeighborFeatureExtraction(self):
  """Test input contains two samples with one feature and three neighbors."""
  # Simulate a batch size of 2.
  features = {
      'F0': tf.constant(11.0, shape=[2, 2]),
      'NL_nbr_0_F0': tf.constant(22.0, shape=[2, 2]),
      'NL_nbr_0_weight': tf.constant(0.25, shape=[2, 1]),
      'NL_nbr_1_F0': tf.constant(33.0, shape=[2, 2]),
      'NL_nbr_1_weight': tf.constant(0.75, shape=[2, 1]),
      'NL_nbr_2_F0': tf.constant(44.0, shape=[2, 2]),
      'NL_nbr_2_weight': tf.constant(1.0, shape=[2, 1]),
  }

  expected_sample_features = {
      'F0': tf.constant(11.0, shape=[2, 2]),
  }

  # The key in this dictionary will contain the original sample's feature
  # name. The shape of the corresponding tensor will be 6x2, which is the
  # result of doing an interleaved merge of three 2x2 tensors along axis 0.
  expected_neighbor_features = {
      'F0': tf.constant([[22.0, 22.0], [33.0, 33.0], [44.0, 44.0],
                         [22.0, 22.0], [33.0, 33.0], [44.0, 44.0]]),
  }
  # The shape of this tensor is 6x1, which is the result of doing an
  # interleaved merge of three 2x1 tensors along axis 0.
  expected_neighbor_weights = tf.constant([[0.25], [0.75], [1.0], [0.25],
                                           [0.75], [1.0]])

  neighbor_config = configs.GraphNeighborConfig(max_neighbors=3)
  # Materialize the outputs with `self.evaluate` so that the assertions below
  # compare the evaluated values.
  sample_features, nbr_features, nbr_weights = self.evaluate(
      utils.unpack_neighbor_features(features, neighbor_config))

  self.assertAllEqual(sample_features['F0'], expected_sample_features['F0'])
  self.assertAllEqual(nbr_features['F0'], expected_neighbor_features['F0'])
  self.assertAllEqual(nbr_weights, expected_neighbor_weights)
def _create_and_compile_graph_reg_model(model_fn, weight, max_neighbors):
  """Builds a base model plus a compiled graph-regularized wrapper around it.

  Args:
    model_fn: A function that builds a linear regression model. It is called
      with the input shape `(2,)` and `weight`.
    weight: Initial value for the weights variable in the linear regressor.
    max_neighbors: The maximum number of neighbors for graph regularization.

  Returns:
    A pair containing the unregularized model and the graph regularized model
    as `tf.keras.Model` instances. Only the graph regularized model is
    compiled here (SGD with `LEARNING_RATE`, 'MSE' loss).
  """
  base_model = model_fn((2,), weight)

  neighbor_config = configs.GraphNeighborConfig(max_neighbors=max_neighbors)
  reg_config = configs.GraphRegConfig(neighbor_config, multiplier=1)
  regularized_model = graph_regularization.GraphRegularization(
      base_model, reg_config)

  sgd = keras.optimizers.SGD(LEARNING_RATE)
  regularized_model.compile(optimizer=sgd, loss='MSE')
  return base_model, regularized_model
def testDense(self, keep_rank):
  """Checks neighbor grouping for dense image features.

  Args:
    keep_rank: Forwarded to `_make_model`; it also selects which output
      layout the assertions below expect.
  """
  # Fake 8x8 single-channel images.
  batch_size = 4
  image_height = 8
  image_width = 8
  num_neighbors = 3
  batch_image_shape = (batch_size, image_height, image_width, 1)

  def _random_images():
    return np.random.randint(0, 256, size=batch_image_shape).astype(np.uint8)

  def _random_weights():
    return np.random.uniform(size=(batch_size, 1)).astype(np.float32)

  features = {
      'image': _random_images(),
      'NL_nbr_0_image': _random_images(),
      'NL_nbr_1_image': _random_images(),
      'NL_nbr_2_image': _random_images(),
      'NL_nbr_0_weight': _random_weights(),
      'NL_nbr_1_weight': _random_weights(),
      'NL_nbr_2_weight': _random_weights(),
  }

  image_input = tf.keras.Input(
      (image_height, image_width, 1), dtype=tf.uint8, name='image')
  model = _make_model(
      configs.GraphNeighborConfig(max_neighbors=num_neighbors),
      {'image': image_input},
      keep_rank)
  samples, neighbors, weights = self.evaluate(model(features))
  samples, neighbors = (samples['image'], neighbors['image'])

  # The samples must pass through untouched.
  self.assertAllEqual(samples, features['image'])

  # For every sample, its neighbors and weights must be grouped together.
  nbr_image_keys = ('NL_nbr_0_image', 'NL_nbr_1_image', 'NL_nbr_2_image')
  nbr_weight_keys = ('NL_nbr_0_weight', 'NL_nbr_1_weight', 'NL_nbr_2_weight')
  for i in range(batch_size):
    if keep_rank:
      actual_neighbors = neighbors[i]
    else:
      start = i * num_neighbors
      actual_neighbors = neighbors[start:(i + 1) * num_neighbors]
    self.assertAllEqual(
        actual_neighbors,
        np.stack([features[key][i] for key in nbr_image_keys]))

    if keep_rank:
      actual_weights = weights[i]
    else:
      actual_weights = np.split(weights, batch_size)[i]
    self.assertAllEqual(
        actual_weights,
        np.stack([features[key][i] for key in nbr_weight_keys]))
def from_config(cls, config):
  """Builds an instance of `cls` from a configuration dictionary.

  Args:
    config: A dictionary whose 'neighbor_config' entry holds the keyword
      arguments for `configs.GraphNeighborConfig`. All remaining entries are
      forwarded to `cls` as keyword arguments. The dictionary itself is left
      unmodified.

  Returns:
    The result of calling `cls` with the rebuilt neighbor config as the first
    positional argument, followed by the remaining entries as keywords.

  Raises:
    KeyError: If `config` has no 'neighbor_config' entry.
  """
  # Work on a shallow copy so that popping 'neighbor_config' does not mutate
  # the caller's dictionary.
  remaining = dict(config)
  neighbor_config = configs.GraphNeighborConfig(
      **remaining.pop('neighbor_config'))
  return cls(neighbor_config, **remaining)
def testSparseFeature(self):
  """Checks unpacking when the sample has both a dense and a sparse feature."""
  # Simulate batch size of 2.
  inputs = {
      'F0': tf.constant(11.0, shape=[2, 2]),
      'F1': tf.SparseTensor(
          indices=[[0, 0], [0, 1]], values=[1.0, 2.0], dense_shape=[2, 4]),
      'NL_nbr_0_F0': tf.constant(22.0, shape=[2, 2]),
      'NL_nbr_0_F1': tf.SparseTensor(
          indices=[[1, 0], [1, 1]], values=[3.0, 4.0], dense_shape=[2, 4]),
      'NL_nbr_0_weight': tf.constant(0.25, shape=[2, 1]),
      'NL_nbr_1_F0': tf.constant(33.0, shape=[2, 2]),
      'NL_nbr_1_F1': tf.SparseTensor(
          indices=[[0, 2], [1, 3]], values=[5.0, 6.0], dense_shape=[2, 4]),
      'NL_nbr_1_weight': tf.constant(0.75, shape=[2, 1]),
  }

  want_sample = {
      'F0': tf.constant(11.0, shape=[2, 2]),
      'F1': tf.SparseTensor(
          indices=[[0, 0], [0, 1]], values=[1.0, 2.0], dense_shape=[2, 4]),
  }

  # Neighbor features are keyed by the original sample feature names.
  want_nbrs = {
      # 'F0' is 4x2: an interleaved merge of two 2x2 tensors along axis 0.
      'F0': tf.constant([[22, 22], [33, 33], [22, 22], [33, 33]]),
      # 'F1' is 4x4: an interleaved merge of two 2x4 tensors along axis 0.
      'F1': tf.SparseTensor(
          indices=[[1, 2], [2, 0], [2, 1], [3, 3]],
          values=[5.0, 3.0, 4.0, 6.0],
          dense_shape=[4, 4]),
  }
  # The weights are 4x1: an interleaved merge of two 2x1 tensors along axis 0.
  want_weights = tf.constant([[0.25], [0.75], [0.25], [0.75]])

  two_neighbors = configs.GraphNeighborConfig(max_neighbors=2)
  got_sample, got_nbrs, got_weights = self.evaluate(
      utils.unpack_neighbor_features(inputs, two_neighbors))

  def _assert_sparse_equal(actual, expected):
    # Sparse tensors are compared component by component.
    self.assertAllEqual(actual.values, expected.values)
    self.assertAllEqual(actual.indices, expected.indices)
    self.assertAllEqual(actual.dense_shape, expected.dense_shape)

  self.assertAllEqual(got_sample['F0'], want_sample['F0'])
  _assert_sparse_equal(got_sample['F1'], want_sample['F1'])
  self.assertAllEqual(got_nbrs['F0'], want_nbrs['F0'])
  _assert_sparse_equal(got_nbrs['F1'], want_nbrs['F1'])
  self.assertAllEqual(got_weights, want_weights)
def testDynamicBatchSizeAndFeatureShape(self):
  """Unpacks neighbors when batch size and feature shape are both dynamic."""
  # Both the batch size and the feature shape are left dynamic. The batch size
  # is the first dimension of every spec below; the dynamic feature dimension
  # is the second dimension of the sample and neighbor feature specs.
  signature = {
      'F0': tf.TensorSpec((None, None, 3), tf.float32),
      'NL_nbr_0_F0': tf.TensorSpec((None, None, 3), tf.float32),
      'NL_nbr_0_weight': tf.TensorSpec((None, 1), tf.float32),
      'NL_nbr_1_F0': tf.TensorSpec((None, None, 3), tf.float32),
      'NL_nbr_1_weight': tf.TensorSpec((None, 1), tf.float32)
  }

  # At run time the batch size is 3 and each pre-batching feature is 2x3.
  sample_values = [
      [[1, 2, 3], [3, 2, 1]],
      [[4, 5, 6], [6, 5, 4]],
      [[7, 8, 9], [9, 8, 7]],
  ]  # 3x2x3
  nbr_0_values = [
      [[1, 3, 5], [5, 3, 1]],
      [[7, 9, 11], [11, 9, 7]],
      [[13, 15, 17], [17, 15, 13]],
  ]  # 3x2x3
  nbr_0_weights = [[0.25], [0.5], [0.75]]  # 3x1
  nbr_1_values = [
      [[2, 4, 6], [6, 4, 2]],
      [[8, 10, 12], [12, 10, 8]],
      [[14, 16, 18], [18, 16, 14]],
  ]  # 3x2x3
  nbr_1_weights = [[0.75], [0.5], [0.25]]  # 3x1

  inputs = {
      'F0': sample_values,
      'NL_nbr_0_F0': nbr_0_values,
      'NL_nbr_0_weight': nbr_0_weights,
      'NL_nbr_1_F0': nbr_1_values,
      'NL_nbr_1_weight': nbr_1_weights
  }

  # Neighbor features are keyed by the original sample feature name. The
  # expected tensor is 6x2x3: an interleaved merge of two 3x2x3 tensors along
  # axis 0.
  want_nbrs = {
      'F0': [
          [[1, 3, 5], [5, 3, 1]],
          [[2, 4, 6], [6, 4, 2]],
          [[7, 9, 11], [11, 9, 7]],
          [[8, 10, 12], [12, 10, 8]],
          [[13, 15, 17], [17, 15, 13]],
          [[14, 16, 18], [18, 16, 14]],
      ],
  }
  # The expected weights are 6x1: an interleaved merge of two 3x1 tensors
  # along axis 0.
  want_weights = [[0.25], [0.75], [0.5], [0.5], [0.75], [0.25]]

  neighbor_config = configs.GraphNeighborConfig(max_neighbors=2)

  @tf.function(input_signature=[signature])
  def _unpack(features):
    return utils.unpack_neighbor_features(features, neighbor_config)

  got_sample, got_nbrs, got_weights = self.evaluate(_unpack(inputs))

  self.assertAllEqual(got_sample['F0'], sample_values)
  self.assertAllEqual(got_nbrs['F0'], want_nbrs['F0'])
  self.assertAllEqual(got_weights, want_weights)