示例#1
0
 def test_sequentially_correlated_all_previous_generator_raise_error(self):
   """Invalid `order` or `correlated_sets` values must raise ValueError."""
   rs = np.random.RandomState(1)
   # Arguments shared by every invalid construction attempt below.
   shared_kwargs = dict(universe_size=30, num_sets=3, set_size=10,
                        shared_prop=0.2, random_state=rs)
   # Each entry makes exactly one of the two options unsupported.
   invalid_options = (
       dict(order='not_implemented', correlated_sets='all'),
       dict(order='random', correlated_sets='not_implemented'),
   )
   for options in invalid_options:
     with self.assertRaises(ValueError):
       set_generator.SequentiallyCorrelatedSetGenerator(
           **options, **shared_kwargs)
示例#2
0
 def test_sequentially_correlated_all_previous_generator_reversed(self):
   """Each set overlaps the union of all earlier sets in exactly 2 ids.

   The yielded sequence is flipped before checking; presumably
   order='reversed' emits the sets opposite to the order in which they were
   correlated (inferred from the [::-1] below — confirm against the
   generator).
   """
   rs = np.random.RandomState(1)
   sc_gen = set_generator.SequentiallyCorrelatedSetGenerator(
       order='reversed', correlated_sets='all', universe_size=30, num_sets=3,
       set_size=10, shared_prop=0.2, random_state=rs)
   ordered_sets = list(sc_gen)[::-1]
   # Union of every set visited so far.
   seen_ids = set(ordered_sets[0])
   for current in ordered_sets[1:]:
     self.assertLen(seen_ids.intersection(current), 2)
     seen_ids.update(current)
示例#3
0
 def test_sequentially_correlated_all_previous_generator_different_sizes(self):
   """Overlap with the union of earlier sets matches the expected sizes.

   Uses sets of unequal sizes (set_sizes=[10, 15, 20, 20], shared_prop=0.2).
   The 2nd, 3rd and 4th sets are expected to share 2, 3 and 4 ids,
   respectively, with the union of all sets before them.
   """
   rs = np.random.RandomState(1)
   set_sizes = [10, 15, 20, 20]
   sc_gen = set_generator.SequentiallyCorrelatedSetGenerator(
       order='original', correlated_sets='all', universe_size=100,
       shared_prop=0.2, set_sizes=set_sizes, random_state=rs)
   expected_overlap_sizes = [2, 3, 4]
   set_ids_list = list(sc_gen)
   # Pin the number of yielded sets (assumed: one per entry of set_sizes, as
   # the three expected overlaps imply). Without this, a short generator
   # would make the loop check fewer or zero overlaps and pass vacuously,
   # and a long one would surface as a bare StopIteration, not a failure.
   self.assertLen(set_ids_list, len(set_sizes))
   # Union of every set visited so far.
   seen_ids = set(set_ids_list[0])
   for set_ids, expected in zip(set_ids_list[1:], expected_overlap_sizes):
     shared_ids = seen_ids.intersection(set_ids)
     self.assertLen(shared_ids, expected)
     seen_ids.update(set_ids)
示例#4
0
 def test_sequentially_correlated_one_previous_generator_original(self):
   """Each set shares 2 ids with its predecessor, and only with it.

   The overlap with the immediate predecessor and the overlap with the union
   of all earlier sets are both required to be 2. Together these mean every
   id a set shares with earlier sets comes from its predecessor.
   """
   rs = np.random.RandomState(1)
   sc_gen = set_generator.SequentiallyCorrelatedSetGenerator(
       order='original', correlated_sets='one', universe_size=30, num_sets=3,
       set_size=10, shared_prop=0.2, random_state=rs)
   generated = list(sc_gen)
   # Union of every set visited so far.
   all_earlier = set(generated[0])
   for predecessor, current in zip(generated, generated[1:]):
     # Overlap with the immediate predecessor ...
     self.assertLen(set(predecessor).intersection(current), 2)
     # ... is the same size as the overlap with everything seen so far.
     self.assertLen(all_earlier.intersection(current), 2)
     all_earlier.update(current)