def get_data():
  """Builds the training and generalization (test) datasets.

  Each split is produced by `create_data`. Its inputs get size-1 zero edge
  features and size-1 zero global features. The targets of both splits get
  size-1 zero global features.

  Returns:
    A 6-tuple `(inputs_tr, targets_tr, sort_indices_tr,
    inputs_ge, targets_ge, sort_indices_ge)`.
  """

  def _prepare_split(batch_size, num_elements_min_max):
    # Create one split and zero-fill the edge/global features of its inputs.
    split_inputs, split_targets, split_sort_indices, _ = create_data(
        batch_size, num_elements_min_max)
    split_inputs = utils_tf.set_zero_edge_features(split_inputs, 1)
    split_inputs = utils_tf.set_zero_global_features(split_inputs, 1)
    return split_inputs, split_targets, split_sort_indices

  train_split = _prepare_split(batch_size_tr, num_elements_min_max_tr)
  # Test/generalization split.
  general_split = _prepare_split(batch_size_ge, num_elements_min_max_ge)

  # Targets are processed after both splits exist, keeping the original
  # order in which ops are added to the graph.
  inputs_tr, raw_targets_tr, sort_indices_tr = train_split
  inputs_ge, raw_targets_ge, sort_indices_ge = general_split
  targets_tr = utils_tf.set_zero_global_features(raw_targets_tr, 1)
  targets_ge = utils_tf.set_zero_global_features(raw_targets_ge, 1)

  return (inputs_tr, targets_tr, sort_indices_tr,
          inputs_ge, targets_ge, sort_indices_ge)
def test_fill_global_state(self, global_size):
  """Checks that zero globals of size `global_size` are attached per graph.

  The input dicts are stripped of their "globals" entry first. The resulting
  globals tensor must then have static shape `(num_graphs, global_size)`.
  """
  for graph_dict in self.graphs_dicts_in:
    graph_dict.pop("globals")
  graphs = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  # The leading dimension of n_edge is the number of graphs in the batch.
  expected_num_graphs = self.reference_graph.n_edge.shape[0]
  graphs = utils_tf.set_zero_global_features(graphs, global_size)
  actual_shape = graphs.globals.get_shape().as_list()
  self.assertAllEqual((expected_num_graphs, global_size), actual_shape)
def test_fill_state_user_specified_types(self, dtype):
  """Tests that zero-filled features are created with the requested dtype.

  Node, edge and global features are removed from the input dicts. They are
  then re-created with the `set_zero_*_features` helpers, passing `dtype`
  explicitly. All three resulting tensors must carry that dtype.

  Args:
    dtype: The dtype passed to every `set_zero_*_features` call.
  """
  for g in self.graphs_dicts_in:
    g.pop("nodes")
    g.pop("globals")
    g.pop("edges")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, 1, dtype)
  graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, 1, dtype)
  graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, 1, dtype)
  self.assertEqual(dtype, graphs_tuple.edges.dtype)
  self.assertEqual(dtype, graphs_tuple.nodes.dtype)
  self.assertEqual(dtype, graphs_tuple.globals.dtype)
def test_fill_state_default_types(self):
  """Checks that zero-filled features default to tf.float32 when no dtype is given.

  Node, edge and global features are removed from the input dicts and then
  re-created with size 1, passing sizes by keyword and omitting dtype.
  """
  removed_fields = ("nodes", "globals", "edges")
  for graph_dict in self.graphs_dicts_in:
    for field in removed_fields:
      graph_dict.pop(field)
  graphs = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graphs = utils_tf.set_zero_edge_features(graphs, edge_size=1)
  graphs = utils_tf.set_zero_node_features(graphs, node_size=1)
  graphs = utils_tf.set_zero_global_features(graphs, global_size=1)
  # Assertions run in the order edges, nodes, globals.
  for feature in (graphs.edges, graphs.nodes, graphs.globals):
    self.assertEqual(tf.float32, feature.dtype)
def test_fill_global_state_dynamic(self, global_size):
  """Checks zero-filling of globals when n_node has no static length.

  After the session runs, the globals must be all zeros with shape
  `(num_graphs, global_size)`.
  """
  for graph_dict in self.graphs_dicts_in:
    graph_dict.pop("globals")
  graphs = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  # Swap n_node for a placeholder with static shape [None], so the number
  # of graphs is not statically visible to set_zero_global_features.
  dynamic_n_node = tf.placeholder_with_default(graphs.n_node, shape=[None])
  graphs = graphs._replace(n_node=dynamic_n_node)
  expected_num_graphs = self.reference_graph.n_edge.shape[0]
  graphs = utils_tf.set_zero_global_features(graphs, global_size)
  with self.test_session() as sess:
    actual_globals = sess.run(graphs.globals)
  expected_globals = np.zeros((expected_num_graphs, global_size))
  self.assertNDArrayNear(expected_globals, actual_globals, err=1e-4)
def zeros_graph(sample_graph, edge_size, node_size, global_size):
  """Returns a copy of `sample_graph` whose features are all zeros.

  The nodes, edges and globals of `sample_graph` are first cleared via
  `replace`. They are then filled with zero features of the given sizes,
  in the order edges, nodes, globals.

  Args:
    sample_graph: Graph whose non-feature fields are reused.
    edge_size: Size of the zero edge features.
    node_size: Size of the zero node features.
    global_size: Size of the zero global features.
  """
  stripped = sample_graph.replace(nodes=None, edges=None, globals=None)
  with_edges = utils_tf.set_zero_edge_features(stripped, edge_size)
  with_nodes = utils_tf.set_zero_node_features(with_edges, node_size)
  return utils_tf.set_zero_global_features(with_nodes, global_size)
def main():
  """Builds a one-graph GraphsTuple from the first test graph and inspects it.

  The graph is `test_graph()[0][0]`. Size-1 zero globals are added before
  the tuple is passed to `dtype_shape_from_graphs_tuple`. Its return value
  is discarded.
  """
  first_graph_dict = test_graph()[0][0]
  graphs = utils_tf.data_dicts_to_graphs_tuple([first_graph_dict])
  graphs = utils_tf.set_zero_global_features(graphs, 1)
  dtype_shape_from_graphs_tuple(graphs)