Exemplo n.º 1
0
def create_query_freq_based_trie(fdir, f_compressed, fname):
    """Build a CharTrie of ``label -> total column frequency`` and pickle it to ``fname``.

    Loads the sparse data under ``fdir`` with ``parallel_get_qp2q_sparse_data``
    (16 jobs), sums every column of the resulting matrix, stores
    ``{column label: column sum}`` in a ``CharTrie`` and pickles the trie with the
    highest available pickle protocol.

    Args:
        fdir: Directory passed to ``parallel_get_qp2q_sparse_data``.
        f_compressed: Forwarded as its ``compressed`` argument.
        fname: Destination file for the pickled trie; missing parent
            directories are created.
    """
    i2r, i2c, smat = parallel_get_qp2q_sparse_data(fdir=fdir, compressed=f_compressed, n_jobs=16)
    sdf = SparseDataFrame(data_matrix=smat, columns=i2c, rows=i2r)

    LOGGER.info("Created sparsedataframe")

    trie = CharTrie()

    flat_mat = np.sum(sdf.data_matrix, axis=0)  # Compute sum for each col
    LOGGER.info("Computed sum of each col")
    # The sums must come back as one 1 x n_cols row so they can be indexed as
    # flat_mat[0, col] (presumably data_matrix is a scipy sparse / np.matrix
    # type, for which np.sum keeps two dimensions — confirm).
    assert flat_mat.shape == (1, sdf.data_matrix.shape[1])

    LOGGER.info("Creating label_freq dictionary")
    label_freq = {c: flat_mat[0, sdf.c2i[c]] for c in sdf.c2i}
    LOGGER.info("Created label_freq dictionary")

    trie.update(label_freq)
    LOGGER.info("Created trie")

    res_dir = os.path.dirname(fname)
    if res_dir:  # fname may be a bare file name in the current directory
        Path(res_dir).mkdir(exist_ok=True, parents=True)
    with open(fname, "wb") as trie_file:
        pickle.dump(trie, trie_file, protocol=pickle.HIGHEST_PROTOCOL)

    # fname already includes res_dir, so log it as-is; the old message
    # printed the directory twice ("<dir>/<dir>/<file>").
    LOGGER.info("Saved trie in {}".format(fname))
Exemplo n.º 2
0
def create_pref_to_topk_dict(fdir, f_compressed, fname, k):
    """Write a JSON map from every label prefix to that prefix's top-``k`` labels.

    Labels are the column names of the sparse data loaded from ``fdir``; the
    frequency of a label is the int-cast sum of its column. Every prefix of every
    label (from the empty string up to the full label) is mapped to
    ``get_top_k_w_label_trie(prefix=..., label_trie=..., k=k)`` evaluated over a
    ``CharTrie`` of ``label -> frequency``. The resulting dict is dumped as JSON
    to ``fname``; its parent directory is created if needed.
    """
    row_labels, col_labels, matrix = parallel_get_qp2q_sparse_data(
        fdir=fdir, compressed=f_compressed, n_jobs=16)
    sdf = SparseDataFrame(data_matrix=matrix, columns=col_labels, rows=row_labels)
    LOGGER.info("Created sparsedataframe")

    col_sums = np.sum(sdf.data_matrix, axis=0)  # one total per column
    LOGGER.info("Computed sum of each col")
    assert col_sums.shape == (1, sdf.data_matrix.shape[1])

    LOGGER.info("Creating label_freq dictionary")
    label_freq = {}
    for label in sdf.c2i:
        label_freq[label] = int(col_sums[0, sdf.c2i[label]])
    LOGGER.info("Created label_freq dictionary")

    # Distinct prefixes in first-seen order; only the keys are used below.
    all_prefs = dict.fromkeys(
        label[:end] for label in label_freq for end in range(len(label) + 1))
    LOGGER.info("Number of prefixes = {}".format(len(all_prefs)))

    label_trie = CharTrie()
    label_trie.update(label_freq)
    LOGGER.info("Created label trie")

    all_prefs_to_topk = {}
    for pref in tqdm(all_prefs):
        all_prefs_to_topk[pref] = get_top_k_w_label_trie(
            prefix=pref, label_trie=label_trie, k=k)

    Path(os.path.dirname(fname)).mkdir(exist_ok=True, parents=True)

    with open(fname, "w") as out_file:
        json.dump(all_prefs_to_topk, out_file)

    LOGGER.info("Saved top-k dict in {}".format(fname))
Exemplo n.º 3
0
 def _populate_trie(self, values: List[str]) -> CharTrie:
     """Fold every element of ``values`` into a brand-new ``CharTrie`` and return it.

     Keys are produced by ``self._populate_trie_reducer`` when
     ``self._default_tokenizer`` is truthy, otherwise by the regex-based
     ``self._populate_trie_reducer_regex``.
     """
     step = (self._populate_trie_reducer if self._default_tokenizer
             else self._populate_trie_reducer_regex)
     return reduce(step, iter(values), CharTrie())
Exemplo n.º 4
0
Arquivo: run.py Projeto: ropc/meme-bot
def run_meme():
    """Run one meme command from the command line (local debugging helper).

    The positional ``command`` words are joined with spaces. The registered alias
    that is the longest prefix of that string selects the meme, and the rest of
    the string (stripped) is passed to it. ``-w/--wait`` blocks until a ptvsd
    debugger attaches. Exits with status 1 when no alias matches.
    """
    # sys.exit() instead of the builtin exit(): the latter is injected by the
    # site module for interactive use and is absent under ``python -S``.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('command',
                        nargs='+',
                        help="meme command. this does not need '!meme'")
    parser.add_argument('-w', '--wait', action='store_true', default=False)
    args = parser.parse_args()

    log.debug(f'debug args: {args}')

    ptvsd.enable_attach()
    if args.wait:
        ptvsd.wait_for_attach()

    # create a trie of alias -> meme so the command can be matched by longest prefix
    # TODO: have this work closer to how it works in memebot.py
    memes = CharTrie()
    for meme in ALL_MEMES:
        for alias in meme.aliases:
            memes[alias] = meme

    # run given meme
    command = ' '.join(args.command).strip()
    alias, meme = memes.longest_prefix(command)
    if not alias or not meme:
        log.error('no such meme found')
        sys.exit(1)

    text = command[len(alias):].strip()
    asyncio.run(_run_generator(meme, text))
Exemplo n.º 5
0
 def __init__(self, season='recent'):
     """Create empty athlete containers for the given ``season`` (default ``'recent'``)."""
     # Directed graph of athletes; presumably filled in by other methods.
     self.athlete_web = nx.DiGraph()
     # Three athlete lookups: by name (trie, for prefix search), by id, by index.
     self.athletes_by_name = CharTrie()
     self.athletes_by_id = {}
     self.athletes_by_index = []
     # Book-keeping collections, all starting empty.
     self.race_history, self.athletes_considered = set(), set()
     self.rankings, self.search_queue = [], []
     self.season = season
Exemplo n.º 6
0
 def cs_filter(data_frame: pd.DataFrame) -> pd.DataFrame:
     """Keep only rows whose ``revision`` belongs to the selected case study.

     ``case_study`` comes from the enclosing scope. When it is ``None``, or the
     frame is empty, the frame is returned unchanged.
     """
     if case_study is None or data_frame.empty:
         return data_frame

     # Trie of the case study's revision hashes, for fast prefix lookup.
     known_hashes = CharTrie()
     for rev in case_study.revisions:
         known_hashes[rev.hash] = True

     def _in_case_study(revision) -> bool:
         # has_node() is non-zero when the hash is stored or is a prefix of a stored hash.
         return known_hashes.has_node(revision.hash) != 0

     keep_mask = data_frame["revision"].apply(_in_case_study)
     return data_frame[keep_mask]
Exemplo n.º 7
0
 def on_start(self):
     """Load persisted data (creating defaults on first launch) and initialise module state.

     Sets the module globals ``user_data``, ``stock_data``, ``symbol_data``,
     ``tag_trie`` and ``save_portfolio``. Points ``stock_scrape`` at the stock
     cache and at a function that saves it to ``stocks.json``, then loads
     portfolio 0.
     """
     global user_data
     global stock_data
     global symbol_data
     global tag_trie
     global save_portfolio
     user_data = self.load_storage_data('data.json')
     if user_data is None: # first time opening the app
         user_data = {'PORTFOLIOS': [Portfolio('My First Portfolio', 10000).get_save_dict()]}
         stock_data = {}
         # Seed the symbol list from the bundled symbols.json. The context
         # manager closes the file even if json.load() raises; the previous
         # open()/close() pair leaked the handle in that case.
         with open('symbols.json') as symbol_json:
             symbol_data = json.load(symbol_json)
         self.save_storage_data(user_data, 'data.json')
         self.save_storage_data(stock_data, 'stocks.json')
         self.save_storage_data(symbol_data, 'symbols.json')
     else:
         stock_data = self.load_storage_data('stocks.json')
         symbol_data = self.load_storage_data('symbols.json')

     # Trie of known symbols (each key -> True), presumably used for tag/prefix search.
     tag_trie = CharTrie()
     for tag in symbol_data:
         tag_trie[tag] = True
     stock_scrape.stock_data_cache = stock_data
     stock_scrape.stock_data_save_func = lambda data: self.save_storage_data(data, 'stocks.json')
     load_portfolio(0)

     def _save_portfolio_func():
         # Flag the change, copy the current portfolio's save dict into
         # user_data, and persist everything to data.json.
         global portfolio_changed
         portfolio_changed = True
         user_data['PORTFOLIOS'][current_portfolio_index] = current_portfolio.get_save_dict()
         self.save_storage_data(user_data, 'data.json')
     save_portfolio = _save_portfolio_func
Exemplo n.º 8
0
    def __init__(self,
                 outer_graph=gall,
                 ner_type_resolver=NERTypeResolver(),
                 metric_threshold=0.8,
                 strict_type_match=True):
        """Store the resolver, the matching settings and an empty trie/cache.

        Args:
            outer_graph: Stored as ``self.outer_graph`` (default: module-level ``gall``).
            ner_type_resolver: Stored as ``self.ntr``.
            metric_threshold: Stored as ``self._metric_threshold`` (default 0.8).
            strict_type_match: Stored as ``self._strict_type_match`` (default True).
        """
        # NOTE(review): the default NERTypeResolver() is built once, when the
        # class body runs, and is shared by every instance that omits this
        # argument. Confirm that sharing is intended.
        self.ntr = ner_type_resolver

        # Storage and matching configuration.
        self._trie = CharTrie()
        self._metric_threshold = metric_threshold
        self._strict_type_match = strict_type_match
        self._allowed_types = ENT_CLASSES

        # External graph and predicate namespace.
        self.outer_graph = outer_graph
        self.predicate_namespace = dbo  # todo: move to constructor args
        self.cache = {}
Exemplo n.º 9
0
class TrieRule:
    """Shared registry of command prefixes that are matched against incoming messages.

    ``prefix`` is a single class-level ``CharTrie``, shared by every user of the
    class, mapping a command prefix to its associated value.
    """

    prefix: CharTrie = CharTrie()

    @classmethod
    def add_prefix(cls, prefix: str, value: Any):
        """Register ``prefix -> value``; an already-registered prefix is logged and ignored."""
        if prefix in cls.prefix:
            logger.warning(f'Duplicated prefix rule "{prefix}"')
            return
        cls.prefix[prefix] = value

    @classmethod
    def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
        """Match the longest registered prefix at the start of the message.

        A ``CMD_RESULT`` with every field ``None`` is created, stored in
        ``state[PREFIX_KEY]`` and returned.

        For a message event whose first segment is text:

        * ``RAW_CMD_KEY`` / ``CMD_KEY`` are set to the matched prefix and its value.
          Both are ``None`` when nothing matches, since pygtrie's no-match result
          has ``key``/``value`` of ``None``.
        * If a prefix matched, ``CMD_ARG_KEY`` is set to a copy of the message with
          that prefix, and the whitespace after it, removed from the first segment.

        Non-message events and empty messages leave every field ``None``.
        """
        prefix = CMD_RESULT(command=None, raw_command=None, command_arg=None)
        state[PREFIX_KEY] = prefix
        if event.get_type() != "message":
            return prefix

        message = event.get_message()
        if not message:
            # An empty message has no first segment; message[0] would raise IndexError.
            return prefix
        message_seg: MessageSegment = message[0]
        if message_seg.is_text():
            segment_text = str(message_seg).lstrip()
            pf = cls.prefix.longest_prefix(segment_text)
            prefix[RAW_CMD_KEY] = pf.key
            prefix[CMD_KEY] = pf.value
            if pf.key:
                # Rebuild the message with the command prefix stripped from the
                # first segment; the remaining text is wrapped in the message's own class.
                msg = message.copy()
                msg.pop(0)
                new_message = msg.__class__(
                    segment_text[len(pf.key):].lstrip())
                for new_segment in reversed(new_message):
                    msg.insert(0, new_segment)
                prefix[CMD_ARG_KEY] = msg

        return prefix
Exemplo n.º 10
0
class Index:
    """Searchable index of records grouped under a display name.

    Slovak identifiers used here:

    * ``zoznamkrajin`` ("list of countries") is the set of display names seen so far.
    * ``uzly`` ("nodes") maps a display name to its ``Uzol``.
    * ``strom`` ("tree") is a sorted ``CharTrie`` filled by ``spravstrom`` and
      queried by ``hladaj``.
    """

    def __init__(self):
        self.zoznamkrajin = set()
        self.uzly = {}
        self.strom = CharTrie()
        self.strom.enable_sorting(enable=True)

    def pridajzaznam(self, zaznam):
        """Add a record ("zaznam") under the display name derived from its attributes.

        ``zaznam.attrib['name']`` is split on ``"_"``, and ``zaznam.attrib['type']``
        selects how the display name is derived:

        * ``'hillshade'``: the first part is dropped and the next one is used;
        * ``'voice'``: the language code (text before ``'-'`` in the first part)
          is mapped through babel's ``LOCALE_ALIASES`` to a territory name
          rendered in English. Unknown codes fall back to ``'Voice <code>'``.
          The resulting name is printed;
        * anything else: the first part is used.

        When more than three parts remain, ``'-' + parts[1]`` is appended. A new
        ``Uzol`` is created the first time a display name is seen, and the record
        is added to it wrapped in ``Zaznam``.
        """
        parts = zaznam.attrib['name'].split("_")
        kind = zaznam.attrib['type']

        if kind == 'hillshade':
            del parts[0]
            label = parts[0]
        elif kind == 'voice':
            lang_code = parts[0].split('-')[0]
            try:
                loc = Locale.parse(babel.core.LOCALE_ALIASES[lang_code])
                label = loc.get_territory_name(Locale('en'))
            except KeyError:
                # Slovak message: "Did not recognise the language abbreviation: ..."
                print('Nerozoznal som skratku jazyka: ' + lang_code)
                label = 'Voice ' + lang_code
            print(label)
        else:
            label = parts[0]

        if len(parts) > 3:
            label += '-'
            label += parts[1]
        if label not in self.zoznamkrajin:
            self.zoznamkrajin.add(label)
            self.uzly[label] = Uzol(label)
        self.uzly[label].pridajzaznam(Zaznam(zaznam))

    def hladaj(self, text):
        """Search: iterate the nodes whose lower-cased name starts with ``text.lower()``.

        Per pygtrie, a ``KeyError`` is raised when no key has that prefix.
        """
        needle = text.lower()
        return self.strom.itervalues(prefix=needle)

    def spravstrom(self):
        """Build the tree: insert every node into ``strom`` under its lower-cased name."""
        for label, node in self.uzly.items():
            key = label.lower()
            self.strom[key] = node
            # Slovak log line: "Added: <key>, <node>"
            print("Pridal som: " + key + ", " + str(node))
Exemplo n.º 11
0
 def __populate_trie_reducer(self, trie_accumulator=None, value="") -> CharTrie:
     """Insert ``value`` into ``trie_accumulator`` under its token key and return the trie.

     Meant as a ``functools.reduce`` step. The key is the tokens of
     ``self.default_tokenizer.tokenize(value)`` joined with ``self.joiner``.
     Each token contributes its ``orth_`` text when ``self.case_sensitive`` is
     true, and its ``lower_`` text otherwise.

     Args:
         trie_accumulator: Trie to insert into. When omitted, a fresh
             ``CharTrie`` is created. The old default was a single trie built
             at definition time and shared by every call that omitted it.
         value: Text to tokenise and store as the value.

     Returns:
         The trie that received the entry.
     """
     if trie_accumulator is None:
         trie_accumulator = CharTrie()
     tokens = self.default_tokenizer.tokenize(value)
     if self.case_sensitive:
         parts = [token.orth_ for token in tokens]
     else:
         parts = [token.lower_ for token in tokens]
     key = self.joiner.join(parts)
     trie_accumulator[key] = value
     return trie_accumulator
Exemplo n.º 12
0
 def _populate_trie_reducer_regex(self, trie_accumulator=None, value="") -> CharTrie:
     """Insert ``value`` into ``trie_accumulator`` under a regex-tokenised key and return the trie.

     Meant as a ``functools.reduce`` step. ``value`` is split into tokens:
     runs of ASCII letters/digits, single characters that are neither word
     characters nor whitespace, and ``"_"``. The tokens, lower-cased unless
     ``self._case_sensitive``, are joined with ``self._joiner`` to form the key.

     Args:
         trie_accumulator: Trie to insert into. When omitted, a fresh
             ``CharTrie`` is created. The old default was a single trie built
             at definition time and shared by every call that omitted it.
         value: Text to tokenise and store as the value.

     Returns:
         The trie that received the entry.
     """
     if trie_accumulator is None:
         trie_accumulator = CharTrie()
     # NOTE(review): non-ASCII letters/digits (e.g. "é") match neither
     # alternative and are silently dropped from the key. Confirm this is intended.
     tokens = re.findall(r"[A-Za-z0-9]+|[^\w\s]|_", value)
     if not self._case_sensitive:
         tokens = [token.lower() for token in tokens]
     key = self._joiner.join(tokens)
     trie_accumulator[key] = value
     return trie_accumulator
Exemplo n.º 13
0
class TrieRule:
    """Class-level registries of command prefixes and suffixes matched against messages.

    ``prefix`` and ``suffix`` are ``CharTrie`` objects shared by the whole class.
    Suffixes are stored reversed, so a suffix match becomes a prefix lookup on
    the reversed text.
    """
    prefix: CharTrie = CharTrie()
    suffix: CharTrie = CharTrie()

    @classmethod
    def add_prefix(cls, prefix: str, value: Any):
        """Register ``prefix -> value``; an already-registered prefix is logged and ignored."""
        if prefix in cls.prefix:
            logger.warning(f'Duplicated prefix rule "{prefix}"')
            return
        cls.prefix[prefix] = value

    @classmethod
    def add_suffix(cls, suffix: str, value: Any):
        """Register ``suffix -> value`` (stored reversed); duplicates are logged and ignored."""
        reversed_suffix = suffix[::-1]
        if reversed_suffix in cls.suffix:
            logger.warning(f'Duplicated suffix rule "{suffix}"')
            return
        cls.suffix[reversed_suffix] = value

    @classmethod
    def get_value(cls, bot: Bot, event: Event,
                  state: dict) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Find the longest registered prefix and suffix of a message.

        The prefix is matched against the left-stripped text of the first
        segment, and the suffix against the right-stripped text of the last
        segment (only for text segments). Each result is a one-entry dict
        ``{matched_key: value}``, or ``{}`` when nothing matched. For suffixes the
        key is the reversed suffix exactly as stored.

        The results are written to ``state["_prefix"]`` / ``state["_suffix"]``, and
        separate copies are returned. Non-message events and empty messages yield
        ``({}, {})``.
        """
        if event.type != "message" or not event.message:
            # Empty messages are guarded: event.message[0] would raise IndexError.
            state["_prefix"] = {}
            state["_suffix"] = {}
            return {}, {}

        prefix = None
        suffix = None
        first_segment = event.message[0]
        if first_segment.type == "text":
            prefix = cls.prefix.longest_prefix(first_segment.data["text"].lstrip())
        last_segment = event.message[-1]
        if last_segment.type == "text":
            suffix = cls.suffix.longest_prefix(
                last_segment.data["text"].rstrip()[::-1])

        state["_prefix"] = {prefix.key: prefix.value} if prefix else {}
        state["_suffix"] = {suffix.key: suffix.value} if suffix else {}

        # Fresh dicts, not the objects stored in ``state``, as before.
        return ({prefix.key: prefix.value} if prefix else {},
                {suffix.key: suffix.value} if suffix else {})
Exemplo n.º 14
0
 def _populate_trie_reducer(self, trie_accumulator=None, value="") -> CharTrie:
     """Insert ``value`` into ``trie_accumulator`` under its token key and return the trie.

     Meant as a ``functools.reduce`` step. ``value`` is tokenised with
     ``self._default_tokenizer.tokenize(value, disable=disable_spacy)``. Tokens
     may be ``Token`` objects (presumably spaCy's) or plain strings. Each token
     contributes ``orth_``/itself when ``self._case_sensitive`` is true, and
     ``lower_``/``.lower()`` otherwise. The parts are joined with ``self._joiner``.

     Args:
         trie_accumulator: Trie to insert into. When omitted, a fresh
             ``CharTrie`` is created. The old default was a single trie built
             at definition time and shared by every call that omitted it.
         value: Text to tokenise and store as the value.

     Returns:
         The trie that received the entry.
     """
     if trie_accumulator is None:
         trie_accumulator = CharTrie()
     tokens = self._default_tokenizer.tokenize(value, disable=disable_spacy)
     if self._case_sensitive:
         parts = [tok.orth_ if isinstance(tok, Token) else tok for tok in tokens]
     else:
         parts = [tok.lower_ if isinstance(tok, Token) else tok.lower() for tok in tokens]
     key = self._joiner.join(parts)
     trie_accumulator[key] = value
     return trie_accumulator
Exemplo n.º 15
0
 def __init__(self, cells, dictionary=None):
     """Store the board and make sure a word dictionary is available.

     Args:
         cells: Sequence of mappings, one per board cell. Each has a required
             ``'letter'`` and optional ``'points'``, ``'P'`` and ``'S'`` entries,
             each defaulting to 1.
         dictionary: Trie of valid words. When falsy, a new ``CharTrie`` is
             filled from ``SK-custom.txt`` (one word per line, each mapped to True).
     """
     board = [
         Cell(spec['letter'], spec.get('points', 1), spec.get('P', 1), spec.get('S', 1))
         for spec in cells
     ]
     self.cells = board
     self.letters = [spec['letter'] for spec in cells]
     self.candidates = {}  # word => (positions, value)
     self.dictionary = dictionary
     if dictionary:
         return
     # NOTE(review): the file is read with the platform default encoding; if it
     # holds non-ASCII (e.g. Slovak) words, consider encoding='utf-8'. Confirm.
     self.dictionary = CharTrie()
     print('Loading custom words ... ', end="")
     with open('SK-custom.txt') as word_file:
         self.dictionary.update((w, True) for w in word_file.read().splitlines())
         print(len(self.dictionary))
Exemplo n.º 16
0
    def __init__(self,
                 command_prefix,
                 help_command=EmbedHelpCommand(),
                 description=None,
                 **options):
        """Initialise the bot, wrap its command table in a trie and register all cogs.

        Configuration comes from environment variables, each read just before
        the cog that uses it:

        * ``MEME_BOT_GUILD_CONFIG`` (default ``'{}'``)
        * ``MEME_BOT_OOC_CHANNEL_ID`` (default 0)
        * ``MEME_BOT_UNLUCKY_ROLL_IDS`` (comma separated)
        * ``MEME_BOT_WOLFRAM_ALPHA_KEY``
        * ``MEME_BOT_GITHUB_TOKEN``
        * ``MEME_BOT_SUGGESTION_GITHUB_PROJECT_COLUMN_ID``
        * ``REMINDERS_SAVE_FILE_PATH`` (default ``'./reminders.pickle'``)
        """
        super().__init__(command_prefix,
                         help_command=help_command,
                         description=description,
                         **options)
        # Re-wrap the command table in a CharTrie (presumably for prefix lookup).
        self.all_commands = CharTrie(self.all_commands)

        guild_config_json = os.getenv('MEME_BOT_GUILD_CONFIG', '{}')
        config = get_guild_config_dict(guild_config_json)

        # Cog setup; registration order is significant and preserved.
        # TODO: this needs auto dependency management
        ooc_channel_id = int(os.getenv('MEME_BOT_OOC_CHANNEL_ID', 0))
        ooc_cog = OutOfContext(self, ooc_channel_id)
        self.add_cog(Quote())
        self.add_cog(ooc_cog)
        self.add_cog(ChatStats())
        unlucky_ids = (
            int(raw)
            for raw in os.getenv('MEME_BOT_UNLUCKY_ROLL_IDS', '').split(',')
            if raw)
        self.add_cog(RollDice(unlucky_ids))
        self.add_cog(Player(self, config))
        self.add_cog(Meme(self))
        self.add_cog(Beans())
        self.add_cog(Meta(self, config, LOG_FILE))
        wolfram_key = os.getenv('MEME_BOT_WOLFRAM_ALPHA_KEY', '')
        self.add_cog(WolframAlpha(wolfram_key))
        self.add_cog(TarotCard())
        github_token = os.getenv('MEME_BOT_GITHUB_TOKEN', '')
        column_id = os.getenv('MEME_BOT_SUGGESTION_GITHUB_PROJECT_COLUMN_ID', '')
        self.add_cog(Suggest(github_token, column_id))
        reminders_path = os.getenv('REMINDERS_SAVE_FILE_PATH', './reminders.pickle')
        self.add_cog(Reminder(self, reminders_path))
        self.add_cog(Hey(ooc_cog=ooc_cog))
Exemplo n.º 17
0
 def __init__(self, keep_naming=False, paths=None, options=None):
     """Resolve default options and paths, then initialise the base class.

     Args:
         keep_naming: Stored as ``self.keep_naming``.
         paths: Path trie forwarded to the base class; an empty ``CharTrie``
             when omitted.
         options: Options list forwarded to the base class. When omitted it is
             ``["HIERARCHIES"]`` if ``paths`` was also omitted, else ``[]``.
     """
     if options is None:
         # If there are paths, keep the options empty; if there are none, set
         # all ("HIERARCHIES"). Identity test: ``== None`` misbehaves for
         # objects that overload ``__eq__``.
         options = [] if paths is not None else ["HIERARCHIES"]
     if paths is None:
         paths = CharTrie()
     super().__init__(options, paths)
     self.keep_naming = keep_naming
     self.transform_list_elements = True
Exemplo n.º 18
0
 def _build_tree(self, words):
     """
     Build the pinyin trie and store it on ``self.tree``.

     Each word is keyed by its concatenated tone-less pinyin
     (``pypinyin.NORMAL``). Words sharing a key are collected in a list, in
     input order.

     :param words: iterable of (Chinese) strings
     :return: None
     """
     tree = CharTrie()
     for word in words:
         syllables = pinyin(word, style=pypinyin.NORMAL)
         key = ''.join(syl[0] for syl in syllables)
         tree.setdefault(key, []).append(word)
     self.tree = tree
Exemplo n.º 19
0
class Autocompleter:
    """Prefix-based autocomplete over a fixed corpus of words.

    Every word is stored as a key of a ``CharTrie``; suggestions are the stored
    keys that begin with the requested prefix.
    """

    def __init__(self, words):
        """Index each element of ``words`` as a trie key (value ``True``)."""
        pairs = ((word, True) for word in words)
        self.trie = CharTrie(pairs)

    def suggest(self, prefix):
        """Return the corpus words starting with ``prefix``, or ``[]`` if there are none."""
        trie = self.trie
        try:
            found = trie.keys(prefix=prefix)
        except KeyError:
            found = []
        return found
Exemplo n.º 20
0
 def __init__ (self, name: Text, layers: List[Layer], keyboard: PhysicalKeyboard):
     """Build lookup tables for a keyboard layout made of several layers.

     Fills ``self._modifierToLayer`` (modifier -> (layer index, layer)) and the
     trie ``self.t``, which maps each string output to the list of button
     combinations producing it. ``self.bufferLen`` ends up as the length of the
     longest string output.
     """
     # XXX: add sanity checks (i.e. modifier are not used elsewhere, no duplicates, …)
     self.name = name
     self.layers = layers
     self.keyboard = keyboard
     self._modifierToLayer : Dict[FrozenSet[Button], Tuple[int, Layer]] = dict ()
     self.bufferLen = 0
     self.t = trie = CharTrie ()
     for index, layer in enumerate (layers):
         for modifier in layer.modifier:
             self._modifierToLayer[modifier] = (index, layer)
         for button, output in layer.layout.items ():
             # Only string outputs are indexed in the trie.
             if not isinstance (output, str):
                 continue
             combos = trie.setdefault (output, [])
             for modifier in layer.modifier:
                 combos.append (ButtonCombination (modifier, frozenset ([button])))
             self.bufferLen = max (len (output), self.bufferLen)
Exemplo n.º 21
0
    def _build_tree(self, words, level):
        """
        Build the pinyin trie for ``words`` and store it on ``self.tree``.

        Each word is keyed by its pinyin encoding; words sharing a key are
        collected in a list, in input order.

        :param words: iterable of (Chinese) strings
        :param level: ``PinyinLevel``. With ``PinyinLevel.CONFUSED`` the key is
            ``py_maker.translate(word=..., level=level.value)`` (the
            "confused pronunciation" encoding). Otherwise it is the concatenated
            tone-less pinyin (``pypinyin.NORMAL``) with no extra encoding.
        :return: None
        """
        prefix_tree = CharTrie()
        use_confused = level == PinyinLevel.CONFUSED
        for seq in words:
            if use_confused:
                # Confused-pronunciation encoding. The plain pinyin used to be
                # computed here too and then discarded.
                py_str = py_maker.translate(word=seq, level=level.value)
            else:
                # Default: plain pinyin, no extra encoding.
                py_str = ''.join(
                    item[0] for item in pinyin(seq, style=pypinyin.NORMAL))
            prefix_tree.setdefault(py_str, []).append(seq)
        self.tree = prefix_tree
def get_available_collections():
    """Map municipality CBS codes to their Open Raadsinformatie (ORI) name and index alias.

    Reads the ``ori_*`` aliases from the ``es_source`` Elasticsearch client, then
    asks the ORI API for up to 500 municipalities. A warning is printed when
    more exist.

    Returns:
        dict: ``{cbs_code: {'ori_name': ..., 'ori_alias': ...}}``.
        ``ori_alias`` is a wildcard ``'ori_<collection>_*'`` when the collection
        has several ``<collection>_...`` aliases. Otherwise it is the exact alias,
        or None when the collection is unknown. Organizations without a CBS
        identifier are skipped with a warning.

    Raises:
        requests.HTTPError: the ORI API answered with an error status.
        requests.Timeout: the ORI API did not answer within 60 seconds.
    """
    aliases_by_index = es_source.indices.get_alias(name='ori_*')
    # Key each alias by its name without the 4-character "ori_" prefix.
    aliases_by_collection = CharTrie({
        alias[4:]: alias
        for props in aliases_by_index.values() for alias in props['aliases']
    })

    def get_alias(collection):
        # A collection split over several indices is addressed with a wildcard.
        if aliases_by_collection.has_subtrie(collection + '_'):
            return u'ori_{}_*'.format(collection)
        return aliases_by_collection.get(collection)

    ori_base_url = 'http://api.openraadsinformatie.nl/v0/'
    resp = requests.post(ori_base_url + 'search/organizations',
                         json={
                             'filters': {
                                 'classification': {
                                     'terms': ['Municipality']
                                 }
                             },
                             'size': 500
                         },
                         timeout=60)  # never hang forever on a stalled API
    resp.raise_for_status()
    data = resp.json()
    if data['meta']['total'] > 500:
        print('WARNING: only loading 500/{} municipalities'.format(
            data['meta']['total']))

    available_collections = {}
    for org in data['organizations']:
        cbs_code = next((ref['identifier'] for ref in org['identifiers']
                         if ref['scheme'] == 'CBS'), None)
        if cbs_code is None:
            # A bare next() used to escape here as StopIteration.
            print('WARNING: skipping organization without CBS identifier: {}'.format(
                org.get('name')))
            continue
        available_collections[cbs_code] = {
            'ori_name': org['name'],
            'ori_alias': get_alias(org['meta']['collection'])
        }
    return available_collections
Exemplo n.º 23
0
    def __init__(self):
        """Build the main window.

        Ensures a ``saves`` directory exists in the current working directory,
        loads ``namesIDs.csv`` from there into the module-level ``ID_ARCHIVE``
        trie, creates every page frame in the same grid cell, shows
        ``StartPage`` and sets the window title.
        """
        global ID_ARCHIVE
        super().__init__()

        cwd = os.getcwd()
        # makedirs(exist_ok=True) replaces the racy isdir()-then-mkdir() pair.
        os.makedirs(os.path.join(cwd, 'saves'), exist_ok=True)

        # newline='' is the mode the csv module documents for files it reads.
        with open(os.path.join(cwd, 'namesIDs.csv'), 'r', newline='') as backup_csv:
            reader = csv.reader(backup_csv)
            # Key "<col0>: <col1>" -> col1. Presumably "<name>: <id>"; confirm
            # against namesIDs.csv.
            ID_ARCHIVE = CharTrie((row[0] + ': ' + row[1], row[1]) for row in reader)

        self.pages = {}
        for page in (StartPage, PageOne, GatherPage, ViewPage, LoadingFrame):
            new_page = page(self)
            self.pages[page] = new_page
            # All pages share one grid cell; set_page() presumably raises the active one.
            new_page.grid(row=0, column=0, sticky='nsew')

        self.set_page(StartPage)

        self.wm_title("Runner Rank")
        self.save = None

        self.configure()
Exemplo n.º 24
0
class Solver:
    """Find dictionary words on a 4x4 letter grid and swipe them with the mouse.

    Cells are numbered 0-15 row by row (``row = pos // 4``, ``col = pos % 4``).
    A word is a path of distinct, pairwise-adjacent cells (diagonals included)
    whose letters spell a key stored in ``self.dictionary``. The best-scoring
    path found for each word is kept in ``self.candidates``.
    """

    def __init__(self, cells, dictionary=None):
        """Store the board and make sure a word dictionary is available.

        Args:
            cells: Sequence of mappings, one per board cell. Each has a required
                ``'letter'`` and optional ``'points'``, ``'P'`` and ``'S'`` entries,
                each defaulting to 1.
            dictionary: Trie of valid words. When falsy, a new ``CharTrie`` is
                filled from ``SK-custom.txt`` (one word per line, value True).
        """
        board = [
            Cell(spec['letter'], spec.get('points', 1), spec.get('P', 1), spec.get('S', 1))
            for spec in cells
        ]
        self.cells = board
        self.letters = [spec['letter'] for spec in cells]
        self.candidates = {}  # word => (positions, value)
        self.dictionary = dictionary
        if dictionary:
            return
        # NOTE(review): read with the platform default encoding; if the file
        # holds non-ASCII (e.g. Slovak) words, consider encoding='utf-8'. Confirm.
        self.dictionary = CharTrie()
        print('Loading custom words ... ', end="")
        with open('SK-custom.txt') as word_file:
            self.dictionary.update((w, True) for w in word_file.read().splitlines())
            print(len(self.dictionary))

    def word_value(self, visited):
        """Score a path: sum of ``points * P`` over its cells, times the product of their ``S``."""
        base = 0
        multiplier = 1
        for idx in visited:
            cell = self.cells[idx]
            base += cell.points * cell.P
            multiplier *= cell.S
        return base * multiplier

    def next_char(self, visited, pos):
        """Extend path ``visited`` with cell ``pos`` and search depth-first from there.

        If the spelled word is in the dictionary and scores strictly higher than
        any path recorded for it, ``(path, score)`` is stored in
        ``self.candidates``. The search stops as soon as no dictionary key
        extends the current word.
        """
        path = visited + (pos, )
        word = ''.join(self.letters[p] for p in path)
        flags = self.dictionary.has_node(word)
        if flags & CharTrie.HAS_VALUE:
            score = self.word_value(path)
            # Ties keep the path that was found first.
            if self.candidates.get(word, (None, 0))[1] < score:
                self.candidates[word] = (path, score)

        # Prune: no dictionary word starts with this one.
        if not flags & CharTrie.HAS_SUBTRIE:
            return

        row, col = divmod(pos, 4)
        # Neighbour offsets in the search order S, SE, E, NE, N, NW, W, SW.
        for d_row, d_col in ((1, 0), (1, 1), (0, 1), (-1, 1),
                             (-1, 0), (-1, -1), (0, -1), (1, -1)):
            n_row, n_col = row + d_row, col + d_col
            if not (0 <= n_row < 4 and 0 <= n_col < 4):
                continue
            neighbour = n_row * 4 + n_col
            if neighbour not in path:
                self.next_char(path, neighbour)

    def solve_and_swipe(self):
        """Search from every start cell, then click the board and swipe each found word.

        Words are swiped in the order they were first found; no sorting is applied.
        """
        for start in range(16):
            self.next_char((), start)

        # Click the top-left cell position first (presumably to focus the game window).
        pyautogui.moveTo(x0, y0, duration=0.1)
        pyautogui.click(duration=0.1)
        for word, (path, score) in self.candidates.items():
            print(word, score)
            self.swipe(path)

    def swipe(self, positions):
        """Drag the mouse through the given cells with the left button held down."""
        grid = 94  # spacing between adjacent cell positions, added to (x0, y0)

        def screen_x(idx):
            return x0 + (idx % 4) * grid

        def screen_y(idx):
            return y0 + (idx // 4) * grid

        first = positions[0]
        pyautogui.moveTo(screen_x(first), screen_y(first), duration=0.1)
        pyautogui.mouseDown(button='left',
                            logScreenshot=False,
                            _pause=False,
                            duration=0.1)
        for idx in positions[1:]:
            pyautogui.moveTo(screen_x(idx), screen_y(idx), duration=0.1)
        pyautogui.mouseUp(button='left',
                          logScreenshot=False,
                          _pause=False,
                          duration=0.1)
Exemplo n.º 25
0
from pygtrie import CharTrie
import os

trie = CharTrie()

trie['hi'] = 'yes'
trie['bye'] = 'no'

# print() with a single argument works on Python 3, and also on Python 2;
# the old ``print x`` statements were a SyntaxError on Python 3.
print('entries')
for entry in trie:  # iterating a trie yields its keys
    print(entry)

# trie['h':] iterates the values of every key that starts with 'h'.
print(list(trie['h':]))
Exemplo n.º 26
0
 def populate_trie(self, values: List[str]) -> CharTrie:
     """Return a new ``CharTrie`` containing every element of ``values``.

     The trie starts empty. Each element is inserted by
     ``self.__populate_trie_reducer``, folded over ``values`` with ``functools.reduce``.
     """
     step = self.__populate_trie_reducer
     items = iter(values)
     seed = CharTrie()
     return reduce(step, items, seed)
Exemplo n.º 27
0
"""Implement the item searchbar for filtering items by various keywords.
"""
from tkinter import ttk
import tkinter as tk
from typing import Optional, Set, Callable, Tuple

import srctools.logger
from pygtrie import CharTrie

from app import UI, TK_ROOT
from localisation import gettext

# Module logger.
LOGGER = srctools.logger.get_logger(__name__)
# Keyword index: word -> set of (str, int) item ids. The annotation is quoted,
# presumably because CharTrie cannot be subscripted at runtime.
word_to_ids: 'CharTrie[Set[Tuple[str, int]]]' = CharTrie()
# Zero-argument callback; None until assigned (init() declares it global).
_type_cback: Optional[Callable[[], None]] = None


def init(
        frm: tk.Frame,
        refresh_cback: Callable[[Optional[Set[Tuple[str, int]]]],
                                None]) -> None:
    """Initialise the UI objects.

    The callback is triggered whenever the UI changes, passing along
    the visible items or None if no filter is specified.
    """
    global _type_cback
    refresh_tim: Optional[str] = None
    result: Optional[set[tuple[str, int]]] = None

    def on_type(*args) -> None:
Exemplo n.º 28
0
 def __init__(self, stream: tp.Iterable[str]) -> None:
     """Build a hash -> integer id lookup from ``"<id>, <hash>"`` lines.

     Args:
         stream: Lines such as ``"42, 1a2b3c"``. Surrounding whitespace is
             stripped, and blank lines (e.g. a trailing newline at end of file)
             are skipped.

     Raises:
         IndexError: a non-blank line contains no ``", "`` separator.
         ValueError: the id field is not an integer.
     """
     self.__hash_to_id: CharTrie = CharTrie()
     for line in stream:
         stripped = line.strip()
         if not stripped:
             # Previously crashed with IndexError on slices[1].
             continue
         slices = stripped.split(', ')
         self.__hash_to_id[slices[1]] = int(slices[0])
Exemplo n.º 29
0
class AutoCompleteEngine:
    """
    AutoCompleteEngine generates suggestions given a users input.

    Lines of the form ``% <command> <args...>`` get two kinds of suggestions:

    * magic-command names, while the cursor is on the command word;
    * command arguments otherwise, delegated to the suggester registered for
      that command with ``add_magic_commands_suggester``.
    """

    # ``command``: "%" at a line start, one or more spaces, then a word.
    # ``args``: the rest of the line as zero or more space-prefixed tokens.
    # Compiled once instead of on every suggest() call.
    _LINE_CLASSIFIER = re.compile(
        r'(?P<command>^%[ ]+[\w]*)(?P<args>( [\S]*)*)', re.MULTILINE)

    def __init__(self, magic_commands: Optional[Iterable[str]]):
        self._magics_args_suggesters: Dict[str, Callable] = {}
        # Each upper-cased suffix of every command name maps to that name, so a
        # prefix lookup here is a case-insensitive substring search on names.
        self._commands_trie = CharTrie()
        if magic_commands is not None:
            for magic in magic_commands:
                self.add_magic_command(magic)

    def suggest(self, code: str, cursor_pos: int) -> Dict:
        """
        @param code: string contains the current state of the user's input. It could be a CWL file
        or magic commands.
        @param cursor_pos: current position of cursor
        @return: {'matches': ['MATCH1', 'MATCH2'],
                'cursor_end': ,
                'cursor_start': , }
        Positions are offsets into ``code``. With no applicable suggestion,
        matches is empty and both cursors equal ``cursor_pos``.
        """
        matches = []
        cursor_end = cursor_pos
        cursor_start = cursor_pos
        for match in self._LINE_CLASSIFIER.finditer(code):  # type: re.Match
            line_start = match.start()
            if match.start('command') <= cursor_pos <= match.end('command'):
                matches, cursor_start, cursor_end = self._suggest_magic_command(
                    match.group(), cursor_pos - line_start)
                cursor_start += line_start
                cursor_end += line_start
            elif line_start <= cursor_pos <= match.end():
                args_start = match.start('args')
                command = match.group('command')[1:].strip()
                matches, cursor_start, cursor_end = self._suggest_magics_arguments(
                    command, match.group('args'), cursor_pos - args_start)
                cursor_start += args_start
                cursor_end += args_start
        return {
            'matches': matches,
            'cursor_end': cursor_end,
            'cursor_start': cursor_start
        }

    def _suggest_magic_command(self, code: str,
                               cursor_pos: int) -> Tuple[List[str], int, int]:
        """Suggest command names containing the token under the cursor.

        A lone ``%`` matches every command. Results are ordered shortest first,
        with ties broken alphabetically so the order is deterministic; plain set
        iteration order depends on string hashing.
        """
        cursor_end, cursor_start, token = self._parse_tokens(code, cursor_pos)
        if token == '%':
            token = ''
        try:
            candidates = set(self._commands_trie.values(prefix=token))
        except KeyError:
            # No command name contains the token.
            return [], cursor_pos, cursor_pos
        matches = sorted(candidates, key=lambda name: (len(name), name))
        return matches, cursor_start, cursor_end

    def _suggest_magics_arguments(
            self, command: str, code: str,
            cursor_pos: int) -> Tuple[List[str], int, int]:
        """Stateless command's arguments suggester.

        Commands without a registered suggester get no suggestions; previously
        this raised KeyError out of suggest().
        """
        if command not in self._magics_args_suggesters:
            return [], cursor_pos, cursor_pos
        cursor_end, cursor_start, query_token = self._parse_tokens(
            code, cursor_pos)
        options: List[str] = self._magics_args_suggesters[command](query_token)
        return options, cursor_start, cursor_end

    def add_magic_commands_suggester(self, magic_name: str,
                                     suggester: Callable) -> None:
        """Register ``suggester(token) -> List[str]`` for arguments of ``magic_name``."""
        self._magics_args_suggesters[magic_name] = suggester

    @classmethod
    def _parse_tokens(cls, code, cursor_pos):
        """Locate the space-delimited token around ``cursor_pos``.

        Returns ``(cursor_end, cursor_start, TOKEN)``, where ``TOKEN`` is the
        stripped, upper-cased token text.
        """
        code_length = len(code)
        token_ends_at = code.find(" ", cursor_pos)
        cursor_end = min(token_ends_at + 1, code_length - 1)
        if token_ends_at == -1:
            token_ends_at = code_length - 1
            cursor_end = code_length
        token_starts_at = code.rfind(" ", 0, cursor_pos)
        cursor_start = token_starts_at + 1
        if token_starts_at == -1:
            token_starts_at = 0
            cursor_start = cursor_pos
        token = code[token_starts_at:token_ends_at + 1].strip().upper()
        return cursor_end, cursor_start, token

    def add_magic_command(self, magic_command_name: str):
        """Index every upper-cased suffix of ``magic_command_name`` -> the name."""
        for i in range(1, len(magic_command_name) + 1):
            self._commands_trie[
                magic_command_name[-i:].upper()] = magic_command_name
Exemplo n.º 30
0
 def __init__(self, magic_commands: Optional[Iterable[str]]):
     """Start with no argument suggesters and register each of ``magic_commands``.

     ``magic_commands`` may be None, in which case no command is registered up front.
     """
     self._magics_args_suggesters: Dict[str, Callable] = {}
     self._commands_trie = CharTrie()
     if magic_commands is None:
         return
     for command_name in magic_commands:
         self.add_magic_command(command_name)