def testUndirectedGraphUnlimitedNbrs(self):
    """Packs neighbors with undirected edges and `max_nbrs` left at its default.

    The packed output is read back and compared with the expected augmented
    examples keyed by 'A' and 'C'.
    """
    out_path = self._output_nsl_training_data_path
    pack_nbrs_lib.pack_nbrs(
        self._training_examples_path,
        self._neighbor_examples_path,
        self._graph_path,
        out_path,
        add_undirected_edges=True)
    want = dict(
        A=_augmented_a_undirected_two_nbrs(),
        C=_augmented_c_undirected_two_nbrs())
    got = _read_tfrecord_examples(out_path)
    self.assertDictEqual(got, want)
 def testDirectedGraphLimitedNbrs(self):
     """Packs neighbors without adding undirected edges and with `max_nbrs=1`.

     The packed output is read back and compared with the expected augmented
     examples keyed by 'A' and 'C'.
     """
     out_path = self._output_nsl_training_data_path
     pack_nbrs_lib.pack_nbrs(
         self._training_examples_path,
         self._neighbor_examples_path,
         self._graph_path,
         out_path,
         add_undirected_edges=False,
         max_nbrs=1)
     want = dict(
         A=_augmented_a_directed_one_nbr(),
         C=_augmented_c_directed())
     got = _read_tfrecord_examples(out_path)
     self.assertDictEqual(got, want)
    def testUndirectedGraphUnlimitedNbrsNoNeighborExamples(self):
        """Runs pack_nbrs() with an empty string as the neighbor-examples path.

        With no neighbor examples supplied, the edge A-->B is dangling, since
        the input contains no Example named "B".
        """
        out_path = self._output_nsl_training_data_path
        pack_nbrs_lib.pack_nbrs(
            self._training_examples_path,
            '',  # Second argument (neighbor examples) is deliberately empty.
            self._graph_path,
            out_path,
            add_undirected_edges=True)
        want = dict(
            # A is left with a single neighbor: C.
            A=_augmented_a_directed_one_nbr(),
            # With undirected edges, C's only neighbor is A.
            C=_augmented_c_undirected_one_nbr_a())
        got = _read_tfrecord_examples(out_path)
        self.assertDictEqual(got, want)