コード例 #1
0
ファイル: model_factory.py プロジェクト: Jason-Turan0/ants
def map_to_input(
        task: Tuple[str, EncodingType, str]) -> Tuple[np.ndarray, np.ndarray]:
    """Encode one recorded game into ``(features, labels)`` training arrays.

    Encoded arrays are cached as ``.npy`` files next to the game file, so a
    repeated call for the same game, bot and encoding loads from disk instead
    of re-translating the game state.

    Args:
        task: ``(bot_name, encoding_type, game_path)`` packed into one tuple
            (presumably so this can be handed to a ``map``-style call —
            TODO confirm against callers). ``game_path`` is expected to end
            in ``.json``.

    Returns:
        ``(features, labels)``. If translating, encoding or caching the game
        fails, the error is printed and empty arrays are returned instead:
        features of shape ``(0, 12, 12, 7)`` for ANT_VISION_2D or
        ``(0, 43, 39, 7)`` for MAP_2D, and labels of shape ``(0, 5)``.

    Raises:
        NotImplementedError: for an encoding type other than
            ``EncodingType.ANT_VISION_2D`` or ``EncodingType.MAP_2D``.
    """
    # Not named `type`, which would shadow the builtin.
    bot_name, encoding_type, game_path = task
    channel_count = 7
    if encoding_type == EncodingType.ANT_VISION_2D:
        return _encode_with_cache(
            game_path, bot_name, channel_count,
            cache_tag='ANT_VISION_2D',
            convert=lambda gst, gs: gst.convert_to_2d_ant_vision(
                bot_name, [gs]),
            encode=enc.encode_2d_examples,
            empty_feature_shape=(0, 12, 12, channel_count))
    if encoding_type == EncodingType.MAP_2D:
        return _encode_with_cache(
            game_path, bot_name, channel_count,
            cache_tag='ANT_VISION_2DMAP',
            convert=lambda gst, gs: gst.convert_to_antmap(bot_name, [gs]),
            encode=enc.encode_map_examples,
            empty_feature_shape=(0, 43, 39, channel_count))
    raise NotImplementedError()


def _cache_paths(game_path: str, cache_tag: str, bot_name: str,
                 channel_count: int) -> Tuple[str, str]:
    """Return the ``(features, labels)`` ``.npy`` cache paths for one game."""
    # Strip only a trailing '.json'. str.replace would also rewrite '.json'
    # elsewhere in the path and, for a path without '.json', would hand back
    # the game file itself as the "cache" path.
    base = game_path
    if base.endswith('.json'):
        base = base[:-len('.json')]
    suffix = f'{bot_name}_{channel_count}.npy'
    return (f'{base}_{cache_tag}_FEATURES_{suffix}',
            f'{base}_{cache_tag}_LABELS_{suffix}')


def _encode_with_cache(game_path: str, bot_name: str, channel_count: int,
                       cache_tag: str, convert, encode,
                       empty_feature_shape: Tuple[int, ...]):
    """Load cached ``(features, labels)`` for one game, or compute and cache them.

    Args:
        game_path: path of the game record to load on a cache miss.
        bot_name: bot whose moves are encoded; part of the cache filename.
        channel_count: channel count passed to ``encode``; part of the
            cache filename.
        cache_tag: encoding-specific tag embedded in the cache filenames.
        convert: ``convert(translator, game_state)`` -> examples to encode.
        encode: ``encode(examples, channel_count)`` -> ``(features, labels)``.
        empty_feature_shape: shape of the empty features array returned
            when encoding fails.
    """
    feature_cache_path, label_cache_path = _cache_paths(
        game_path, cache_tag, bot_name, channel_count)
    # Require both files: a run interrupted between the two np.save calls
    # leaves only the features file, and loading the missing labels would
    # raise.
    if os.path.exists(feature_cache_path) and os.path.exists(label_cache_path):
        return np.load(feature_cache_path), np.load(label_cache_path)
    gs = load_game_state(game_path)
    try:
        examples = convert(GameStateTranslator(), gs)
        features, labels = encode(examples, channel_count)
        print(f'Saving {feature_cache_path}')
        np.save(feature_cache_path, features)
        np.save(label_cache_path, labels)
        return features, labels
    except Exception as err:
        # Best-effort: one bad game should not abort a whole batch, but the
        # cause is reported rather than silently dropped. Catching Exception
        # (not a bare except) lets KeyboardInterrupt/SystemExit through.
        print(f'Failed to load {game_path}: {err!r}')
        return np.empty(empty_feature_shape), np.empty([0, 5])
コード例 #2
0
    def test_encode_2d_map(self):
        """Decoding map_to_input's MAP_2D output reproduces the translated ant map.

        Every example produced by ``GameStateTranslator.convert_to_antmap``
        must come back from ``decode_map_examples`` with the same label and
        the same value at every expected feature position.

        NOTE(review): map_to_input reads cached .npy files next to
        ``self.data_path`` when they exist, so a stale cache from an earlier
        run can make this test check the cache rather than the current
        encoder — TODO confirm the fixture directory is kept cache-free.
        """
        bot_to_emulate = 'pkmiec_1'

        generator = GameStateGenerator()
        translator = GameStateTranslator()
        test_game_state = generator.generate_from_file(self.data_path)
        expected_map = list(
            translator.convert_to_antmap(bot_to_emulate, [test_game_state]))

        actual_encoded_map = map_to_input(
            (bot_to_emulate, EncodingType.MAP_2D, self.data_path))
        actual_decoded_map = list(decode_map_examples(actual_encoded_map))

        # Compare counts first: indexing by the expected position alone lets
        # extra decoded examples pass unnoticed and turns missing ones into an
        # IndexError instead of a readable assertion failure.
        self.assertEqual(len(expected_map), len(actual_decoded_map))
        for expected, actual in zip(expected_map, actual_decoded_map):
            self.assertEqual(expected.label, actual.label)
            for expected_pos, expected_value in expected.features.items():
                self.assertEqual(expected_value,
                                 actual.features[expected_pos])