コード例 #1
0
 def test_batch_length(self):
     """len() equals ceil(0.6 * total expected examples / batch_size)."""
     sequence = AntVisionSequence(self.data_paths, self.batch_size,
                                  self.bot_to_emulate)
     sequence.build_indexes(True)
     total_examples = sum(self.expected_example_counts)
     expected_batches = ceil(total_examples * .6 / self.batch_size)
     self.assertEqual(expected_batches, len(sequence))
コード例 #2
0
    def test_create_index(self):
        """Each of the first two data paths gets an index of the expected length.

        The index for a path is looked up by ``game_path`` rather than by
        position in ``s.game_indexes``.
        """
        s = AntVisionSequence(self.data_paths, self.batch_size,
                              self.bot_to_emulate)
        s.build_indexes(True)
        for position in (0, 1):
            game_path = self.data_paths[position]
            # find() runs immediately, so the lambda sees the current game_path.
            game_index = seq(s.game_indexes).find(
                lambda gi: gi.game_path == game_path)
            # find() yields None when nothing matches; fail with a clear message
            # instead of an AttributeError on `.length`.
            self.assertIsNotNone(game_index, f'no index built for {game_path}')
            self.assertEqual(game_index.length,
                             self.expected_example_counts[position])
コード例 #3
0
ファイル: __main__.py プロジェクト: Jason-Turan0/ants
def build_index(task: Tuple[str, str]):
    """Build the indexes for one game file using the named sequence type.

    Args:
        task: a ``(game_path, seq_type)`` pair. ``seq_type`` must be one of
            'MapViewSequence', 'AntVisionSequence' or 'CombinedSequence'.

    The sequence is created over ``[game_path]`` with batch size 50 and bot
    'memetix_1', and then ``build_indexes(False)`` is called on it.
    NOTE(review): the meaning of the ``False`` flag is defined by the
    sequence classes, which are not shown here.

    Returns:
        True once ``build_indexes`` has returned.

    Raises:
        NotImplementedError: if ``seq_type`` is not one of the names above.
    """
    game_path, seq_type = task
    if seq_type == 'MapViewSequence':
        sequence_class = MapViewSequence
    elif seq_type == 'AntVisionSequence':
        sequence_class = AntVisionSequence
    elif seq_type == 'CombinedSequence':
        sequence_class = CombinedSequence
    else:
        raise NotImplementedError(seq_type)
    bot_to_emulate = 'memetix_1'
    sequence = sequence_class([game_path], 50, bot_to_emulate)
    sequence.build_indexes(False)
    return True
コード例 #4
0
 def test_range_intersects(self):
     """Check ranges_intersect against intersecting and disjoint cases.

     Per the cases asserted below, ranges that only touch at an endpoint,
     such as (0, 5) and (5, 10), and the empty (0, 0) pair do not count
     as intersecting.
     """
     sequence = AntVisionSequence(self.data_paths, self.batch_size,
                                  self.bot_to_emulate)
     intersecting_cases = [(0, 6, 5, 10), (0, 6, 4, 10)]
     disjoint_cases = [(0, 5, 5, 10), (10, 20, 5, 10), (0, 0, 0, 0)]
     for bounds in intersecting_cases:
         self.assertTrue(sequence.ranges_intersect(*bounds))
     for bounds in disjoint_cases:
         self.assertFalse(sequence.ranges_intersect(*bounds))
コード例 #5
0
 def create_sequence(self, game_paths: List[str],
                     batch_size: int) -> FileSystemSequence:
     """Return an AntVisionSequence over ``game_paths`` for ``self.bot_name``.

     The trailing positional ``7`` is forwarded unchanged.
     NOTE(review): its meaning is set by the AntVisionSequence constructor,
     which is not shown here; confirm before changing it.
     """
     sequence = AntVisionSequence(game_paths, batch_size, self.bot_name, 7)
     return sequence
コード例 #6
0
    def test_set_sizes(self):
        """The training/test/cross-validation splits contain 1429/477/477 examples.

        For each split, the first-dimension sizes of element 0 of every batch
        are summed and compared with the expected count. The ranges and batch
        counts are printed along the way for debugging.
        """
        s = AntVisionSequence(self.data_paths, self.batch_size,
                              self.bot_to_emulate)
        s.build_indexes(True)

        for game_index in s.game_indexes:
            print(f'{game_index.position_start} {game_index.position_end} {game_index.length}')

        def count_examples(batch_count, get_batch):
            # Total leading-dimension size of element 0 over all batches.
            return sum(get_batch(batch_index)[0].shape[0]
                       for batch_index in range(batch_count))

        print('Training')
        print(s.get_training_range())
        print(s.get_training_batch_count())
        self.assertEqual(
            1429,
            count_examples(s.get_training_batch_count(), s.get_training_batch))

        print('Test')
        pprint(s.get_test_range())
        print(s.get_test_batch_count())
        self.assertEqual(
            477,
            count_examples(s.get_test_batch_count(), s.get_test_batch))

        print('Cross_val')
        pprint(s.get_cross_val_range())
        print(s.get_cross_val_batch_count())
        self.assertEqual(
            477,
            count_examples(s.get_cross_val_batch_count(), s.get_cross_val_batch))
コード例 #7
0
 def test_get_batch_across_index(self):
     """A batch drawn from the boundary after the first game index is full-size.

     NOTE(review): this assumes batch k covers examples
     [k * batch_size, (k + 1) * batch_size) and that game_indexes[0] comes
     first in that ordering. Under that assumption, batch
     ``length // batch_size`` holds the first example after game 0. It also
     holds game-0 examples unless ``length`` is a multiple of ``batch_size``.
     Confirm against the sequence implementation.
     """
     s = AntVisionSequence(self.data_paths, self.batch_size,
                           self.bot_to_emulate)
     s.build_indexes(True)
     # Floor division picks the batch at the game boundary. A modulo here
     # would give the remainder, which is always < batch_size and unrelated
     # to where game 0 ends.
     boundary_index = s.game_indexes[0].length // self.batch_size
     self.assertLess(boundary_index, len(s))
     boundary_batch = s[boundary_index]
     self.assertEqual((50, 12, 12, 7), boundary_batch[0].shape)
コード例 #8
0
 def test_get_last_batch(self):
     """Element 0 of the final batch (index len - 1) has shape (29, 12, 12, 7)."""
     sequence = AntVisionSequence(self.data_paths, self.batch_size,
                                  self.bot_to_emulate)
     sequence.build_indexes(True)
     final_position = len(sequence) - 1
     final_inputs = sequence[final_position][0]
     self.assertEqual((29, 12, 12, 7), final_inputs.shape)
コード例 #9
0
 def test_get_batch(self):
     """Element 0 of the first batch has shape (50, 12, 12, 7)."""
     sequence = AntVisionSequence(self.data_paths, self.batch_size,
                                  self.bot_to_emulate)
     sequence.build_indexes(True)
     first_batch = sequence[0]
     self.assertEqual((50, 12, 12, 7), first_batch[0].shape)