コード例 #1
0
ファイル: model_factory.py プロジェクト: Jason-Turan0/ants
def _load_or_encode(bot_name, game_path, cache_tag, channel_count, convert,
                    encode, empty_feature_shape):
    """Return ``(features, labels)`` for one game file, backed by a .npy cache.

    Cache files are written next to ``game_path``: its extension is replaced
    with ``_{cache_tag}_FEATURES_{bot_name}_{channel_count}.npy`` and
    ``_{cache_tag}_LABELS_{bot_name}_{channel_count}.npy``.

    On a cache miss the game is loaded with ``load_game_state``, translated
    with ``convert(bot_name, [game_state])`` and encoded with
    ``encode(examples, channel_count)``. The result is then saved to the cache
    and returned.

    If translation, encoding or saving raises, a message is printed. Empty
    arrays are returned instead of raising: features of shape
    ``empty_feature_shape`` and labels of shape ``(0, 5)``. A failure inside
    ``load_game_state`` itself is not caught.
    """
    # splitext touches only the real extension. A '.json' elsewhere in the
    # path is left alone, and a path without '.json' no longer maps both
    # caches onto the game file itself. For '<name>.json' the cache paths are
    # identical to the previous str.replace-based ones.
    stem = os.path.splitext(game_path)[0]
    suffix = f'{bot_name}_{channel_count}.npy'
    feature_cache_path = f'{stem}_{cache_tag}_FEATURES_{suffix}'
    label_cache_path = f'{stem}_{cache_tag}_LABELS_{suffix}'

    # Require both files. An interruption between the two np.save calls
    # below would otherwise leave a features cache with no labels file,
    # and np.load would fail on every later run.
    if os.path.exists(feature_cache_path) and os.path.exists(label_cache_path):
        return np.load(feature_cache_path), np.load(label_cache_path)

    game_state = load_game_state(game_path)
    try:
        examples = convert(bot_name, [game_state])
        features, labels = encode(examples, channel_count)
        print(f'Saving {feature_cache_path}')
        np.save(feature_cache_path, features)
        np.save(label_cache_path, labels)
    except Exception as err:  # best-effort: report and skip this game
        # `except Exception` (not a bare except) lets Ctrl-C / SystemExit
        # propagate. The message previously printed a stray literal '$'.
        print(f'Failed to load {game_path}: {err}')
        return np.empty(empty_feature_shape), np.empty([0, 5])
    return features, labels


def map_to_input(
        task: Tuple[str, EncodingType, str]) -> Tuple[np.ndarray, np.ndarray]:
    """Encode one game file into ``(features, labels)`` arrays for a bot.

    Args:
        task: ``(bot_name, encoding_type, game_path)``. It is packed as one
            tuple, presumably so the function can be passed to a ``map``-style
            API (TODO confirm against callers).

    Returns:
        ``(features, labels)``. The result is read from the .npy cache next to
        ``game_path`` when present; otherwise it is computed and cached. If
        encoding fails, empty arrays are returned: features of shape
        ``(0, 12, 12, 7)`` for ``ANT_VISION_2D`` or ``(0, 43, 39, 7)`` for
        ``MAP_2D``, and labels of shape ``(0, 5)``.

    Raises:
        NotImplementedError: for any encoding type other than
            ``EncodingType.ANT_VISION_2D`` or ``EncodingType.MAP_2D``.
    """
    bot_name, encoding_type, game_path = task
    channel_count = 7
    gst = GameStateTranslator()
    if encoding_type == EncodingType.ANT_VISION_2D:
        return _load_or_encode(bot_name, game_path, 'ANT_VISION_2D',
                               channel_count, gst.convert_to_2d_ant_vision,
                               enc.encode_2d_examples, [0, 12, 12, 7])
    if encoding_type == EncodingType.MAP_2D:
        return _load_or_encode(bot_name, game_path, 'ANT_VISION_2DMAP',
                               channel_count, gst.convert_to_antmap,
                               enc.encode_map_examples, [0, 43, 39, 7])
    raise NotImplementedError(f'Unsupported encoding type: {encoding_type}')
コード例 #2
0
    def test_encode_2d_ant_vision(self):
        """Round-trip the ANT_VISION_2D encoding and compare to the translator.

        The examples produced directly by ``convert_to_2d_ant_vision`` must
        match, one for one, the examples recovered by encoding the same file
        with ``map_to_input`` and decoding the result.
        """
        bot_to_emulate = 'pkmiec_1'

        gsg = GameStateGenerator()
        gst = GameStateTranslator()
        test_game_state = gsg.generate_from_file(self.data_path)
        # list() so the length check below also works if a lazy iterable is
        # returned.
        expected_ant_vision = list(gst.convert_to_2d_ant_vision(
            bot_to_emulate, [test_game_state]))

        # NOTE(review): map_to_input returns cached .npy files next to
        # self.data_path when they exist, so data cached by an earlier run may
        # be what gets decoded here. Consider deleting those files in setUp.
        actual_encoded_ant_vision = map_to_input(
            (bot_to_emulate, EncodingType.ANT_VISION_2D, self.data_path))
        actual_decoded_ant_vision = decode_ant_vision_2d_examples(
            actual_encoded_ant_vision)

        # Without this check, extra decoded examples pass unnoticed and a
        # short result shows up as an IndexError instead of a test failure.
        self.assertEqual(len(expected_ant_vision),
                         len(actual_decoded_ant_vision),
                         'number of decoded examples')
        for index, (expected, actual) in enumerate(
                zip(expected_ant_vision, actual_decoded_ant_vision)):
            self.assertEqual(expected.label, actual.label,
                             f'label of example {index}')
            for expected_pos, expected_value in expected.features.items():
                self.assertEqual(
                    expected_value, actual.features[expected_pos],
                    f'feature {expected_pos} of example {index}')
コード例 #3
0
 def test_create_2d_ant_vision(self):
     """Translating the shared test game state for 'pkmiec_1' yields a result."""
     state = create_test_game_state()
     result = GameStateTranslator().convert_to_2d_ant_vision('pkmiec_1',
                                                             [state])
     self.assertIsNotNone(result)