def test_none_throws_error(self, none_field):
    """Checks that building specs fails when one GraphsTuple field is `None`.

    Args:
      none_field: Name of the GraphsTuple field that gets replaced by `None`.
    """
    expected_message = "`{}` was `None`. All fields of the `G".format(none_field)
    source_dict = self.graphs_dicts[1]
    broken_graph = utils_np.data_dicts_to_graphs_tuple([source_dict]).replace(
        **{none_field: None})
    with self.assertRaisesRegex(ValueError, expected_message):
        utils_tf.specs_from_graphs_tuple(broken_graph)
def _compile_with_tf_function(fn, graphs_tuple):
    """Wraps `fn` in a `tf.function` traced with a fully dynamic signature.

    The input signature is derived from `graphs_tuple` with the number of
    graphs, nodes and edges all marked dynamic.

    Args:
      fn: Callable taking a single GraphsTuple.
      graphs_tuple: Example GraphsTuple used only to derive the signature.

    Returns:
      The compiled function, which takes a GraphsTuple and returns `fn(...)`.
    """
    signature = utils_tf.specs_from_graphs_tuple(
        graphs_tuple,
        dynamic_num_graphs=True,
        dynamic_num_nodes=True,
        dynamic_num_edges=True,
    )

    def compiled_fn(graphs_tuple):
        # These plain Python asserts run while tf.function traces this body.
        # They check (via `_leading_static_shape`, presumably the static size
        # of the first axis) that the dynamic signature actually took effect.
        for field_name in ("n_node", "senders", "nodes"):
            assert _leading_static_shape(getattr(graphs_tuple, field_name)) is None
        return fn(graphs_tuple)

    return tf.function(compiled_fn, input_signature=[signature])
def get_graph_subsampling_dataset(
    prefix,
    arrays,
    shuffle_indices,
    ratio_unlabeled_data_to_labeled_data,
    max_nodes,
    max_edges,
    **subsampler_kwargs
):
    """Returns tf_dataset for online sampling.

    Each element is a subgraph produced by `sub_sampler.subsample_graph`
    around one root node index. It is annotated with paper labels and years,
    then converted to a `GraphsTuple`. The root indices are
    `arrays[f"{prefix}_indices"]`, optionally extended with extra indices
    drawn from `range(NUM_PAPERS)` and optionally shuffled. The generator is
    handed to `tf.data.Dataset.from_generator` as a callable, so the root
    selection is redone on every pass over the dataset.

    Args:
      prefix: Selects the root-index array `arrays[f"{prefix}_indices"]`.
      arrays: Mapping of array names to arrays. Besides the root-index array
        it must provide the adjacency index arrays passed to the subsampler,
        "paper_year" and "paper_label".
      shuffle_indices: Whether to shuffle the root indices on each pass.
      ratio_unlabeled_data_to_labeled_data: If > 0, add
        `int(ratio * len(root-index array))` extra roots drawn without
        replacement from `range(NUM_PAPERS)`.
      max_nodes: Forwarded to `sub_sampler.subsample_graph`.
      max_edges: Forwarded to `sub_sampler.subsample_graph`.
      **subsampler_kwargs: Extra keyword arguments for
        `sub_sampler.subsample_graph`.

    Returns:
      A `tf.data.Dataset` whose output signature is inferred from the first
      generated graph.

    Raises:
      ValueError: If there are no root nodes to sample from, i.e. the
        root-index array is empty.
    """

    def generator():
        labeled_indices = arrays[f"{prefix}_indices"]
        if ratio_unlabeled_data_to_labeled_data > 0:
            num_unlabeled_data_to_add = int(
                ratio_unlabeled_data_to_labeled_data * labeled_indices.shape[0])
            # NOTE(review): legacy np.random.choice with replace=False builds a
            # permutation of the whole population, so this costs O(NUM_PAPERS)
            # time and memory per pass. Worth revisiting if NUM_PAPERS is large.
            unlabeled_indices = np.random.choice(
                NUM_PAPERS, size=num_unlabeled_data_to_add, replace=False)
            root_node_indices = np.concatenate(
                [labeled_indices, unlabeled_indices])
        else:
            root_node_indices = labeled_indices
        if shuffle_indices:
            # Copy first: without it, the in-place shuffle would reorder the
            # caller's array when no extra roots were concatenated.
            root_node_indices = root_node_indices.copy()
            np.random.shuffle(root_node_indices)
        for index in root_node_indices:
            graph = sub_sampler.subsample_graph(
                index,
                arrays["author_institution_index"],
                arrays["institution_author_index"],
                arrays["author_paper_index"],
                arrays["paper_author_index"],
                arrays["paper_paper_index"],
                arrays["paper_paper_index_t"],
                paper_years=arrays["paper_year"],
                max_nodes=max_nodes,
                max_edges=max_edges,
                **subsampler_kwargs)
            graph = add_nodes_label(graph, arrays["paper_label"])
            graph = add_nodes_year(graph, arrays["paper_year"])
            yield tf_graphs.GraphsTuple(*graph)

    # The first graph is only used to infer the output signature. Note that it
    # runs the generator once, which consumes global NumPy RNG state.
    sample_graph = next(generator(), None)
    if sample_graph is None:
        # Without this check an empty root set would leak a bare StopIteration
        # out of this (non-generator) function.
        raise ValueError(
            f"No root nodes to sample: `arrays['{prefix}_indices']` is empty.")
    return tf.data.Dataset.from_generator(
        generator,
        output_signature=utils_tf.specs_from_graphs_tuple(sample_graph))
def test_correct_signature(
        self, dynamic_num_nodes, dynamic_num_edges, dynamic_num_graphs,
        batched, replace_globals_with_constant):
    """Checks the spec built for each combination of dynamic/static options.

    One or two example graphs (depending on `batched`) are turned into specs.
    The result must be a GraphsTuple of TensorSpecs whose leading dimensions
    are `None` exactly where the corresponding count was requested dynamic.
    """
    dict_positions = (1, 2) if batched else (1,)
    input_data_dicts = [self.graphs_dicts[pos] for pos in dict_positions]
    graph = utils_np.data_dicts_to_graphs_tuple(input_data_dicts)

    total_graphs = len(input_data_dicts)
    total_edges = sum(graph.n_edge).item()
    total_nodes = sum(graph.n_node).item()

    # For variety in the tested shapes: edges get rank 1 (no feature axis).
    # np.zeros defaults to float64, hence the float64 edge spec below.
    graph = graph.replace(edges=np.zeros(total_edges))
    if replace_globals_with_constant:
        # Scalar (rank-0) globals, whose expected spec shape is [].
        graph = graph.replace(globals=np.array(0.0, dtype=np.float32))

    spec_signature = utils_tf.specs_from_graphs_tuple(
        graph, dynamic_num_graphs, dynamic_num_nodes, dynamic_num_edges)

    # A dynamic number of graphs also makes node and edge counts unknown.
    if dynamic_num_nodes or dynamic_num_graphs:
        total_nodes = None
    if dynamic_num_edges or dynamic_num_graphs:
        total_edges = None
    if dynamic_num_graphs:
        total_graphs = None

    if replace_globals_with_constant:
        globals_shape = []
    else:
        globals_shape = [total_graphs] + test_utils.GLOBALS_DIMS

    expected_answer = graphs.GraphsTuple(
        nodes=tf.TensorSpec(
            shape=[total_nodes] + test_utils.NODES_DIMS, dtype=tf.float32),
        edges=tf.TensorSpec(shape=[total_edges], dtype=tf.float64),
        n_node=tf.TensorSpec(shape=[total_graphs], dtype=tf.int32),
        n_edge=tf.TensorSpec(shape=[total_graphs], dtype=tf.int32),
        globals=tf.TensorSpec(shape=globals_shape, dtype=tf.float32),
        receivers=tf.TensorSpec(shape=[total_edges], dtype=tf.int32),
        senders=tf.TensorSpec(shape=[total_edges], dtype=tf.int32),
    )

    with self.subTest(name="Correct Type."):
        self.assertIsInstance(spec_signature, graphs.GraphsTuple)
    with self.subTest(name="Correct Signature."):
        self.assertAllEqual(spec_signature, expected_answer)
shape = list(tensor_sample.shape) dtype = tensor_sample.dtype return description_fn(shape=shape, dtype=dtype) # Get some example data that resembles the tensors that will be fed # into update_step(): Input_data, example_target_data = code.next_batch(0) graph_dicts = utils_np.graphs_tuple_to_data_dicts(Input_data) example_input_data = utils_tf.data_dicts_to_graphs_tuple(graph_dicts) # Get the input signature for that function by obtaining the specs input_signature = [ utils_tf.specs_from_graphs_tuple(example_input_data), specs_from_tensor(example_target_data) ] # Compile the update function using the input signature for speedy code. compiled_update_step = tf.function(update_step, input_signature=input_signature) ############# tf_function ################### for echo in range(max_iteration): code.mess_up_order() for i in range(code.total_number): Input_data, Output_data = code.next_batch(i) graph_dicts = utils_np.graphs_tuple_to_data_dicts(Input_data) graphs_tuple_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
def get_signatures(in_graphs, gt_graphs):
    """Builds TensorSpec signatures for an input / ground-truth graph pair.

    Both signatures are derived with `dynamic_num_graphs=True`. Node and edge
    dynamism use the defaults of `utils_tf.specs_from_graphs_tuple`.

    Args:
      in_graphs: Example input GraphsTuple.
      gt_graphs: Example ground-truth GraphsTuple.

    Returns:
      A `(in_signature, gt_signature)` tuple.
    """
    in_signature, gt_signature = (
        utils_tf.specs_from_graphs_tuple(sample, dynamic_num_graphs=True)
        for sample in (in_graphs, gt_graphs)
    )
    return in_signature, gt_signature
def __init__(
    self,
    rgt,
    num_epochs,
    optimizer,
    init_lr,
    decay_steps,
    end_lr,
    power,
    cycle,
    tr_size,
    tr_batch_size,
    val_batch_size,
    tr_path_data,
    val_path_data,
    file_ext,
    seed,
    msg_ratio=0.45,
    input_fields=None,
    class_weights=None,
    scaler=True,
    delta_time_validation=60,
    log_path="logs",
    restore_path=None,
    compile=False,
    debug=False,
):
    """Sets up model, optimizer, LR schedule, data generators, logs and checkpoints.

    Args:
        rgt: Model to train; stored as `self._model`.
        num_epochs: Stored as `self._num_epochs`.
        optimizer: Name of a class in `snt.optimizers`. It is instantiated
            with `learning_rate=self._lr`.
        init_lr: Initial learning rate. Initializes `self._lr` and the
            `PolynomialDecay` schedule.
        decay_steps: Forwarded to `PolynomialDecay`.
        end_lr: Forwarded to `PolynomialDecay`.
        power: Forwarded to `PolynomialDecay`.
        cycle: Forwarded to `PolynomialDecay`.
        tr_size: Stored as `self._tr_size`.
        tr_batch_size: Stored as `self._tr_batch_size`.
        val_batch_size: Stored as `self._val_batch_size`. Also used for the
            validation batch that defines the compiled signatures.
        tr_path_data: Stored; presumably the training data location.
        val_path_data: Validation data location, passed to the batch generator.
        file_ext: Forwarded to `init_generator`.
        seed: Seeds NumPy's global RNG, TensorFlow, and a private
            `np.random.RandomState`. Also passed to the checkpoint-manager setup.
        msg_ratio: Forwarded to `binary_crossentropy`.
        input_fields: Forwarded to `init_generator` and stored.
        class_weights: Forwarded to `binary_crossentropy`. Defaults to
            `[1.0, 1.0]`; a new list is created for every instance.
        scaler: Forwarded to `init_generator`.
        delta_time_validation: Stored; presumably the interval between
            validations (units not visible here, TODO confirm).
        log_path: Root directory for run logs.
        restore_path: Sub-directory of `log_path` to resume from. If None, a
            new timestamped sub-directory is used.
        compile: If True, wrap the update/eval steps in `tf.function` with
            input signatures derived from one validation batch.
        debug: If True, dump TensorFlow debug info under `<log_dir>/debug`.
    """
    # Sentinel default: a list literal in the signature would be a single
    # object shared by every instance (and captured by the loss partial).
    if class_weights is None:
        class_weights = [1.0, 1.0]

    np.random.seed(seed)
    tf.random.set_seed(seed)
    self._rs = np.random.RandomState(seed)
    super(EstimatorRGT, self).__init__(name="EstimatorRGT")

    # NOTE(review): tf.Variable(0) infers an integer dtype; confirm this is
    # intended if a float accuracy is ever assigned to `_best_acc`.
    self._best_acc = tf.Variable(0, trainable=False)
    # np.inf instead of np.infty: the alias was removed in NumPy 2.0.
    self._best_delta = tf.Variable(np.inf, trainable=False)
    self._delta_time_validation = delta_time_validation

    self._model = rgt
    self._tr_size = tr_size
    self._num_epochs = num_epochs
    self._tr_batch_size = tr_batch_size
    self._val_batch_size = val_batch_size

    self._loss_fn = partial(
        binary_crossentropy,
        entity="edges",
        class_weights=class_weights,
        msg_ratio=msg_ratio,
    )

    self._lr = tf.Variable(init_lr, trainable=False, dtype=tf.float32, name="lr")
    self._step = tf.Variable(0, trainable=False, dtype=tf.float32, name="tr_step")
    self._opt = getattr(snt.optimizers, optimizer)(learning_rate=self._lr)
    self._schedule_lr_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        init_lr, decay_steps, end_lr, power=power, cycle=cycle
    )

    self._tr_path_data = tr_path_data
    self._val_path_data = val_path_data
    self._input_fields = input_fields
    self._batch_generator = partial(
        init_generator,
        scaler=scaler,
        random_state=self._rs,
        file_ext=file_ext,
        input_fields=input_fields,
    )

    if restore_path is not None:
        self._log_dir = os.path.join(log_path, restore_path)
    else:
        self._log_dir = os.path.join(
            log_path, datetime.now().strftime("%Y%m%d-%H%M%S")
        )
    # exist_ok: the restore directory may already exist, and two runs started
    # within the same second would otherwise collide on the timestamp.
    os.makedirs(self._log_dir, exist_ok=True)
    self.__set_managers(seed, restore_path is not None)

    if debug:
        tf.debugging.experimental.enable_dump_debug_info(
            dump_root=os.path.join(self._log_dir, "debug"),
            tensor_debug_mode="FULL_HEALTH",
            circular_buffer_size=-1,
        )

    if compile:
        # One validation batch fixes the tf.function input signatures. Only
        # the number of graphs is forced dynamic here; node/edge dynamism
        # follows the specs_from_graphs_tuple defaults.
        val_generator = self._batch_generator(
            self._val_path_data, self._val_batch_size
        )
        in_val_graphs, gt_val_graphs, _ = next(val_generator)
        in_signature = specs_from_graphs_tuple(in_val_graphs, dynamic_num_graphs=True)
        gt_signature = specs_from_graphs_tuple(gt_val_graphs, dynamic_num_graphs=True)
        self._update_model_weights = tf.function(
            self.__update_model_weights,
            input_signature=[in_signature, gt_signature],
        )
        self._eval = tf.function(
            self.__eval,
            input_signature=[in_signature],
        )
    else:
        self._update_model_weights = self.__update_model_weights
        self._eval = self.__eval