示例#1
0
 def test_none_throws_error(self, none_field):
   """Checks that a `None` GraphsTuple field is rejected with a `ValueError`.

   Args:
     none_field: Name of the GraphsTuple field that is overwritten with `None`
       before the specs are requested.
   """
   single_graph = utils_np.data_dicts_to_graphs_tuple([self.graphs_dicts[1]])
   broken_graph = single_graph.replace(**{none_field: None})
   # Only the beginning of the error message is matched.
   expected_regex = "`{}` was `None`. All fields of the `G".format(none_field)
   with self.assertRaisesRegex(ValueError, expected_regex):
     utils_tf.specs_from_graphs_tuple(broken_graph)
示例#2
0
def _compile_with_tf_function(fn, graphs_tuple):
    """Returns `fn` wrapped in a `tf.function` with fully dynamic graph specs.

    The input signature is derived from `graphs_tuple`, with the number of
    graphs, nodes and edges all requested as dynamic.

    Args:
        fn: Callable taking a single graphs tuple.
        graphs_tuple: Sample graphs tuple the input signature is derived from.

    Returns:
        The `tf.function`-wrapped callable.
    """
    dynamic_spec = utils_tf.specs_from_graphs_tuple(
        graphs_tuple,
        dynamic_num_graphs=True,
        dynamic_num_nodes=True,
        dynamic_num_edges=True,
    )

    def compiled_fn(graphs_tuple):
        # Guard against an accidentally static signature: for each of these
        # fields `_leading_static_shape` has to report `None`.
        checked_fields = (
            graphs_tuple.n_node,
            graphs_tuple.senders,
            graphs_tuple.nodes,
        )
        for field in checked_fields:
            assert _leading_static_shape(field) is None
        return fn(graphs_tuple)

    return tf.function(compiled_fn, input_signature=[dynamic_spec])
示例#3
0
def get_graph_subsampling_dataset(prefix, arrays, shuffle_indices,
                                  ratio_unlabeled_data_to_labeled_data,
                                  max_nodes, max_edges, **subsampler_kwargs):
    """Builds a `tf.data.Dataset` that subsamples one graph per root node.

    Args:
        prefix: Selects the labeled root indices via `arrays[f"{prefix}_indices"]`.
        arrays: Mapping holding the index arrays, `paper_year` and
            `paper_label` that are read below.
        shuffle_indices: If truthy, a shuffled copy of the root indices is
            used on every pass over the generator.
        ratio_unlabeled_data_to_labeled_data: If positive, this many random
            extra root indices (relative to the number of labeled ones) are
            drawn with `np.random.choice(NUM_PAPERS, ..., replace=False)`.
        max_nodes: Forwarded to `sub_sampler.subsample_graph`.
        max_edges: Forwarded to `sub_sampler.subsample_graph`.
        **subsampler_kwargs: Extra keyword arguments for
            `sub_sampler.subsample_graph`.

    Returns:
        A dataset created with `tf.data.Dataset.from_generator` whose output
        signature is derived from one eagerly generated sample graph.
    """

    def _root_indices():
        # Labeled roots first, optionally followed by random extra roots.
        labeled = arrays[f"{prefix}_indices"]
        if ratio_unlabeled_data_to_labeled_data > 0:
            extra_count = int(
                ratio_unlabeled_data_to_labeled_data * labeled.shape[0])
            extra = np.random.choice(
                NUM_PAPERS, size=extra_count, replace=False)
            roots = np.concatenate([labeled, extra])
        else:
            roots = labeled
        if shuffle_indices:
            # Shuffle a copy so the array stored in `arrays` stays untouched.
            roots = roots.copy()
            np.random.shuffle(roots)
        return roots

    def generator():
        for root in _root_indices():
            sampled = sub_sampler.subsample_graph(
                root,
                arrays["author_institution_index"],
                arrays["institution_author_index"],
                arrays["author_paper_index"],
                arrays["paper_author_index"],
                arrays["paper_paper_index"],
                arrays["paper_paper_index_t"],
                paper_years=arrays["paper_year"],
                max_nodes=max_nodes,
                max_edges=max_edges,
                **subsampler_kwargs)
            sampled = add_nodes_label(sampled, arrays["paper_label"])
            sampled = add_nodes_year(sampled, arrays["paper_year"])
            yield tf_graphs.GraphsTuple(*sampled)

    # One graph is produced up front purely to derive the output signature.
    sample_graph = next(generator())
    signature = utils_tf.specs_from_graphs_tuple(sample_graph)
    return tf.data.Dataset.from_generator(
        generator, output_signature=signature)
示例#4
0
  def test_correct_signature(
      self,
      dynamic_num_nodes,
      dynamic_num_edges,
      dynamic_num_graphs,
      batched,
      replace_globals_with_constant):
    """Checks the spec produced for each combination of the size options.

    Args:
      dynamic_num_nodes: Passed through to `specs_from_graphs_tuple`.
      dynamic_num_edges: Passed through to `specs_from_graphs_tuple`.
      dynamic_num_graphs: Passed through to `specs_from_graphs_tuple`.
      batched: If true, two graphs are batched together; otherwise one is used.
      replace_globals_with_constant: If true, the globals field is replaced by
        a scalar float32 array before the spec is computed.
    """
    graph_indices = (1, 2) if batched else (1,)
    input_data_dicts = [self.graphs_dicts[index] for index in graph_indices]

    graph = utils_np.data_dicts_to_graphs_tuple(input_data_dicts)
    total_edges = sum(graph.n_edge).item()
    total_nodes = sum(graph.n_node).item()

    # Add some variety to the tested fields: the edges become a rank-1 array
    # (float64, the `np.zeros` default) and the globals optionally a scalar.
    graph = graph.replace(edges=np.zeros(total_edges))
    if replace_globals_with_constant:
      graph = graph.replace(globals=np.array(0.0, dtype=np.float32))

    spec_signature = utils_tf.specs_from_graphs_tuple(
        graph, dynamic_num_graphs, dynamic_num_nodes, dynamic_num_edges)

    # A dynamic number of graphs makes the node and edge counts dynamic as
    # well, independently of their own flags.
    if dynamic_num_graphs:
      graphs_dim = nodes_dim = edges_dim = None
    else:
      graphs_dim = len(input_data_dicts)
      nodes_dim = None if dynamic_num_nodes else total_nodes
      edges_dim = None if dynamic_num_edges else total_edges

    globals_shape = (
        [] if replace_globals_with_constant
        else [graphs_dim,] + test_utils.GLOBALS_DIMS)

    def make_spec(shape, dtype):
      return tf.TensorSpec(shape=shape, dtype=dtype)

    expected_answer = graphs.GraphsTuple(
        nodes=make_spec([nodes_dim,] + test_utils.NODES_DIMS, tf.float32),
        # Edges were manually replaced above, hence rank 1 and float64.
        edges=make_spec([edges_dim], tf.float64),
        n_node=make_spec([graphs_dim], tf.int32),
        n_edge=make_spec([graphs_dim], tf.int32),
        globals=make_spec(globals_shape, tf.float32),
        receivers=make_spec([edges_dim], tf.int32),
        senders=make_spec([edges_dim], tf.int32),
        )

    with self.subTest(name="Correct Type."):
      self.assertIsInstance(spec_signature, graphs.GraphsTuple)

    with self.subTest(name="Correct Signature."):
      self.assertAllEqual(spec_signature, expected_answer)
示例#5
0
	
	shape = list(tensor_sample.shape)
	dtype = tensor_sample.dtype

	return description_fn(shape=shape, dtype=dtype) 

# Get some example data that resembles the tensors that will be fed
# into update_step():

# Batch 0 only serves as a sample here; its first element is round-tripped
# through data dicts so that the GraphsTuple is rebuilt by utils_tf.
Input_data, example_target_data = code.next_batch(0)
graph_dicts = utils_np.graphs_tuple_to_data_dicts(Input_data)	
example_input_data = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)

# Get the input signature for that function by obtaining the specs.
# NOTE(review): the order (graphs first, targets second) presumably has to
# match update_step's positional parameters -- update_step is not visible
# here, confirm against its definition.
input_signature = [
  utils_tf.specs_from_graphs_tuple(example_input_data),
  specs_from_tensor(example_target_data)
]

# Compile the update function using the input signature for speedy code.
compiled_update_step = tf.function(update_step, input_signature=input_signature)
############# tf_function ###################

# Outer loop: `max_iteration` passes over the data.
# NOTE(review): `echo` looks like a typo for "epoch" -- confirm before renaming.
for echo in range(max_iteration):
	# Called once per pass, before any batch is fetched; presumably this
	# reshuffles the batch order -- confirm against `code`.
	code.mess_up_order()
	
	for i in range(code.total_number):
		Input_data, Output_data = code.next_batch(i)
		# Round-trip the batch through data dicts to rebuild it with utils_tf.
		graph_dicts = utils_np.graphs_tuple_to_data_dicts(Input_data)
		graphs_tuple_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
		
示例#6
0
def get_signatures(in_graphs, gt_graphs):
    """Returns the specs of the input and ground-truth graphs as a pair.

    Each spec comes from `utils_tf.specs_from_graphs_tuple`, called with `True`
    as its second positional argument.

    Args:
        in_graphs: Graphs tuple the first signature is derived from.
        gt_graphs: Graphs tuple the second signature is derived from.

    Returns:
        A tuple `(in_signature, gt_signature)`.
    """
    return tuple(
        utils_tf.specs_from_graphs_tuple(sample, True)
        for sample in (in_graphs, gt_graphs)
    )
示例#7
0
    def __init__(
        self,
        rgt,
        num_epochs,
        optimizer,
        init_lr,
        decay_steps,
        end_lr,
        power,
        cycle,
        tr_size,
        tr_batch_size,
        val_batch_size,
        tr_path_data,
        val_path_data,
        file_ext,
        seed,
        msg_ratio=0.45,
        input_fields=None,
        class_weights=[1.0, 1.0],
        scaler=True,
        delta_time_validation=60,
        log_path="logs",
        restore_path=None,
        compile=False,
        debug=False,
    ):
        """Sets up training state, logging and the update/eval step functions.

        Args:
            rgt: Model object, stored as `self._model`.
            num_epochs: Stored as `self._num_epochs`.
            optimizer: Name of the attribute of `snt.optimizers` to
                instantiate; it is called with `learning_rate=self._lr`.
            init_lr: Initial value of the learning-rate variable and first
                argument of the `PolynomialDecay` schedule.
            decay_steps: Forwarded to `PolynomialDecay`.
            end_lr: Forwarded to `PolynomialDecay`.
            power: Forwarded to `PolynomialDecay`.
            cycle: Forwarded to `PolynomialDecay`.
            tr_size: Stored as `self._tr_size`.
            tr_batch_size: Stored as `self._tr_batch_size`.
            val_batch_size: Stored as `self._val_batch_size`; also used to
                draw the sample batch when `compile` is true.
            tr_path_data: Stored as `self._tr_path_data`.
            val_path_data: Stored as `self._val_path_data`; the sample batch
                for `compile` is read from it.
            file_ext: Bound into the batch generator (`init_generator`).
            seed: Seeds NumPy's and TensorFlow's global RNGs and the private
                `np.random.RandomState`; also passed to `__set_managers`.
            msg_ratio: Bound into the loss function (`binary_crossentropy`).
            input_fields: Stored and bound into the batch generator.
            class_weights: Bound into the loss function.
            scaler: Bound into the batch generator.
            delta_time_validation: Stored as `self._delta_time_validation`
                (presumably the time between validations -- confirm unit
                against the training loop).
            log_path: Base directory of the log directory.
            restore_path: If not `None`, `log_path/restore_path` is used as
                log directory (it is not created here) and `__set_managers`
                is told to restore; otherwise a new timestamped directory is
                created under `log_path`.
            compile: If true, the update and eval steps are wrapped in
                `tf.function` with signatures taken from one validation batch.
            debug: If true, TensorFlow debug-info dumping is enabled under
                `<log dir>/debug`.
        """
        np.random.seed(seed)
        tf.random.set_seed(seed)
        self._rs = np.random.RandomState(seed)
        super(EstimatorRGT, self).__init__(name="EstimatorRGT")

        self._best_acc = tf.Variable(0, trainable=False)
        # `np.infty` was removed in NumPy 2.0; `np.inf` is the same value and
        # exists in every NumPy version.
        self._best_delta = tf.Variable(np.inf, trainable=False)
        self._delta_time_validation = delta_time_validation

        self._model = rgt
        self._tr_size = tr_size
        self._num_epochs = num_epochs
        self._tr_batch_size = tr_batch_size
        self._val_batch_size = val_batch_size
        # NOTE: `class_weights` has a mutable default; it is only forwarded
        # here and never modified in this method.
        self._loss_fn = partial(
            binary_crossentropy,
            entity="edges",
            class_weights=class_weights,
            msg_ratio=msg_ratio,
        )

        self._lr = tf.Variable(init_lr, trainable=False, dtype=tf.float32, name="lr")
        self._step = tf.Variable(0, trainable=False, dtype=tf.float32, name="tr_step")
        # Look the optimizer class up by name.
        optimizer_cls = getattr(snt.optimizers, optimizer)
        self._opt = optimizer_cls(learning_rate=self._lr)
        self._schedule_lr_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            init_lr, decay_steps, end_lr, power=power, cycle=cycle
        )

        self._tr_path_data = tr_path_data
        self._val_path_data = val_path_data
        self._input_fields = input_fields
        self._batch_generator = partial(
            init_generator,
            scaler=scaler,
            random_state=self._rs,
            file_ext=file_ext,
            input_fields=input_fields,
        )

        if restore_path is not None:
            # Restoring: reuse the given run directory as-is.
            self._log_dir = os.path.join(log_path, restore_path)
        else:
            # Fresh run: create a new directory named after the start time.
            self._log_dir = os.path.join(
                log_path, datetime.now().strftime("%Y%m%d-%H%M%S")
            )
            os.makedirs(self._log_dir)
        self.__set_managers(seed, restore_path is not None)

        if debug:
            tf.debugging.experimental.enable_dump_debug_info(
                dump_root=os.path.join(self._log_dir, "debug"),
                tensor_debug_mode="FULL_HEALTH",
                circular_buffer_size=-1,
            )

        if compile:
            # One validation batch is drawn only to derive the input
            # signatures of the compiled functions.
            val_generator = self._batch_generator(
                self._val_path_data, self._val_batch_size
            )
            in_val_graphs, gt_val_graphs, _ = next(val_generator)
            in_signature = specs_from_graphs_tuple(in_val_graphs, True)
            gt_signature = specs_from_graphs_tuple(gt_val_graphs, True)
            self._update_model_weights = tf.function(
                self.__update_model_weights,
                input_signature=[in_signature, gt_signature],
            )
            self._eval = tf.function(
                self.__eval,
                input_signature=[in_signature],
            )
        else:
            self._update_model_weights = self.__update_model_weights
            self._eval = self.__eval