def __init__(self, tree: GitObject, rev: str):
    self.tree = tree
    self.rev = rev
    self.trie = Trie()
    self.trie[()] = tree
    self._build(tree, ())
from functools import reduce
from typing import List, Set, Tuple

import pygtrie


def _dfs_visit(
    cube: Tuple[int, int],
    grid: List[List[str]],
    word_list: pygtrie.Trie,
    word: str,
    cubes_visited: Set[Tuple[int, int]],
) -> Set[str]:
    """
    Visit the next state in a depth-first search for words in the letter grid.

    :param Tuple[int, int] cube: grid indices of a letter cube
    :param List[List[str]] grid: a grid of letter cubes to search
    :param pygtrie.Trie word_list: a Trie containing all valid words
    :param str word: a string of letters representing the current search path
    :param Set[Tuple[int, int]] cubes_visited: set of cube indices already visited
    :return: set of valid words found from the current search state
    :rtype: Set[str]
    """
    i, j = cube
    # Capitalize all words
    word += grid[i][j].upper()
    if not word_list.has_node(word):
        return set()
    words_found = {word} if word_list.has_key(word) and len(word) > 2 else set()
    neighbors = _get_neighboring_cubes(cube, grid, cubes_visited)
    neighboring_words = [
        _dfs_visit(n, grid, word_list, word, cubes_visited.union({cube}))
        for n in neighbors
    ]
    return reduce(lambda x, y: x.union(y), neighboring_words, words_found)
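
# `_get_neighboring_cubes` is referenced above but not part of this snippet.
# A minimal sketch under the usual Boggle assumptions (8-directional
# adjacency, stay in bounds, skip already-visited cubes); this version is
# hypothetical:
def _get_neighboring_cubes(
    cube: Tuple[int, int],
    grid: List[List[str]],
    cubes_visited: Set[Tuple[int, int]],
) -> List[Tuple[int, int]]:
    i, j = cube
    neighbors = []
    for di in (-1, 0, 1):
        for dj in (-1, 0, 1):
            if (di, dj) == (0, 0):
                continue  # skip the cube itself
            ni, nj = i + di, j + dj
            if (
                0 <= ni < len(grid)
                and 0 <= nj < len(grid[ni])
                and (ni, nj) not in cubes_visited
            ):
                neighbors.append((ni, nj))
    return neighbors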
def __init__(self, fs, root_dir):
    from dvc.repo import Repo

    default_ignore_patterns = [
        ".hg/",
        ".git/",
        ".git",
        f"{Repo.DVC_DIR}/",
    ]

    self.fs = fs
    self.root_dir = root_dir
    self.ignores_trie_fs = Trie()
    self._ignores_trie_subrepos = Trie()

    key = self._get_key(root_dir)
    self.ignores_trie_fs[key] = DvcIgnorePatterns(
        default_ignore_patterns,
        root_dir,
        fs.sep,
    )
    self._ignores_trie_subrepos[key] = self.ignores_trie_fs[key]
    self._update(
        self.root_dir,
        self._ignores_trie_subrepos,
        dnames=None,
        ignore_subrepos=False,
    )
    self._update(
        self.root_dir,
        self.ignores_trie_fs,
        dnames=None,
        ignore_subrepos=True,
    )
def build_outs_trie(stages):
    outs = Trie()

    for stage in filter(bool, stages):  # bug? not using it later
        for out in stage.outs:
            out_key = out.path_info.parts

            # Check for dup outs
            if out_key in outs:
                dup_stages = [stage, outs[out_key].stage]
                raise OutputDuplicationError(str(out), dup_stages)

            # Check for overlapping outs
            if outs.has_subtrie(out_key):
                parent = out
                overlapping = first(outs.values(prefix=out_key))
            else:
                parent = outs.shortest_prefix(out_key).value
                overlapping = out
            if parent and overlapping:
                msg = (
                    "Paths for outs:\n'{}'('{}')\n'{}'('{}')\n"
                    "overlap. To avoid unpredictable behaviour, "
                    "rerun command with non overlapping outs paths."
                ).format(
                    str(parent),
                    parent.stage.addressing,
                    str(overlapping),
                    overlapping.stage.addressing,
                )
                raise OverlappingOutputPathsError(parent, overlapping, msg)

            outs[out_key] = out

    return outs
def build_outs_trie(stages):
    outs = Trie()

    for stage in stages:
        for out in stage.outs:
            out_key = out.fs.path.parts(out.fs_path)

            # Check for dup outs
            if out_key in outs:
                dup_stages = [stage, outs[out_key].stage]
                raise OutputDuplicationError(str(out), dup_stages)

            # Check for overlapping outs
            if outs.has_subtrie(out_key):
                parent = out
                overlapping = first(outs.values(prefix=out_key))
            else:
                parent = outs.shortest_prefix(out_key).value
                overlapping = out
            if parent and overlapping:
                msg = (
                    "The output paths:\n'{}'('{}')\n'{}'('{}')\n"
                    "overlap and are thus in the same tracked directory.\n"
                    "To keep reproducibility, outputs should be in separate "
                    "tracked directories or tracked individually."
                ).format(
                    str(parent),
                    parent.stage.addressing,
                    str(overlapping),
                    overlapping.stage.addressing,
                )
                raise OverlappingOutputPathsError(parent, overlapping, msg)

            outs[out_key] = out

    return outs
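
# A self-contained sketch of the overlap check both build_outs_trie variants
# above rely on: has_subtrie() detects an existing out nested under the new
# key, while shortest_prefix() detects an existing out that contains it.
# Paths and stage names here are hypothetical.
from pygtrie import Trie

outs = Trie()
outs[("data", "raw")] = "stage-a"

new_key = ("data", "raw", "images")
assert not outs.has_subtrie(new_key)  # nothing is nested under the new key
assert outs.shortest_prefix(new_key).value == "stage-a"  # but data/raw contains it

assert outs.has_subtrie(("data",))  # an out at data/ would overlap data/raw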
from pygtrie import Trie


def keys_without_children(trie: Trie):
    result = []

    def childless_collector(key_transformer, path, children, _=None):
        # `children` is a generator; consuming it both recurses into the
        # subtree and tells us whether this node is a leaf.
        if not list(children):
            result.append(key_transformer(path))

    trie.traverse(childless_collector)
    return result
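
# Hypothetical usage: traverse() visits every node, and only the keys with
# no children (the leaves) are collected.
trie = Trie()
trie[("a",)] = 1
trie[("a", "b")] = 2
trie[("a", "c")] = 3
assert sorted(keys_without_children(trie)) == [("a", "b"), ("a", "c")]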
class DirInfo:
    PARAM_RELPATH = "relpath"

    def __init__(self):
        self.trie = Trie()

    @property
    def size(self):
        try:
            return sum(
                hash_info.size
                for _, hash_info in self.trie.iteritems()  # noqa: B301
            )
        except TypeError:
            return None

    @property
    def nfiles(self):
        return len(self.trie)

    def items(self, path_info=None):
        for key, hash_info in self.trie.iteritems():  # noqa: B301
            path = posixpath.sep.join(key)

            if path_info is not None:
                path = path_info / path

            yield path, hash_info

    @classmethod
    def from_list(cls, lst):
        ret = DirInfo()
        for _entry in lst:
            entry = _entry.copy()
            relpath = entry.pop(cls.PARAM_RELPATH)
            parts = tuple(relpath.split(posixpath.sep))
            ret.trie[parts] = HashInfo.from_dict(entry)
        return ret

    def to_list(self):
        # Sorting the list by path to ensure reproducibility
        return sorted(
            (
                {
                    # NOTE: not using hash_info.to_dict() because we don't
                    # want size/nfiles fields at this point.
                    hash_info.name: hash_info.value,
                    self.PARAM_RELPATH: posixpath.sep.join(parts),
                }
                for parts, hash_info in self.trie.iteritems()  # noqa: B301
            ),
            key=itemgetter(self.PARAM_RELPATH),
        )

    def merge(self, ancestor, their):
        merged = DirInfo()
        merged.trie = _merge(ancestor.trie, self.trie, their.trie)
        return merged
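
# The key scheme DirInfo relies on: relative paths are split into tuples of
# parts, so a plain pygtrie.Trie answers per-directory prefix queries. A
# minimal sketch with string placeholders standing in for HashInfo objects:
import posixpath

from pygtrie import Trie

trie = Trie()
trie[tuple("data/images/1.png".split(posixpath.sep))] = "md5:aaaa"
trie[tuple("data/images/2.png".split(posixpath.sep))] = "md5:bbbb"

# Everything under data/images/ in a single prefix query:
assert len(trie.values(prefix=("data", "images"))) == 2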
def _update_sub_repo(self, path, ignore_trie: Trie):
    from dvc.repo import Repo

    if path == self.root_dir:
        return

    dvc_dir = self.fs.path.join(path, Repo.DVC_DIR)
    if not self.fs.exists(dvc_dir):
        return

    root, dname = self.fs.path.split(path)
    key = self._get_key(root)
    pattern_info = PatternInfo(f"/{dname}/", f"in sub_repo:{dname}")
    new_pattern = DvcIgnorePatterns([pattern_info], root, self.fs.sep)
    old_pattern = ignore_trie.longest_prefix(key).value
    if old_pattern:
        plist, prefix = merge_patterns(
            self.fs.path.flavour,
            old_pattern.pattern_list,
            old_pattern.dirname,
            new_pattern.pattern_list,
            new_pattern.dirname,
        )
        ignore_trie[key] = DvcIgnorePatterns(plist, prefix, self.fs.sep)
    else:
        ignore_trie[key] = new_pattern
def walk(self, top, topdown=True, **kwargs):
    from pygtrie import Trie

    assert topdown

    if not self.exists(top):
        raise FileNotFoundError

    if not self.isdir(top):
        raise NotADirectoryError

    root = PathInfo(os.path.abspath(top))
    outs = self._find_outs(top, recursive=True, strict=False)

    trie = Trie()

    for out in outs:
        trie[out.path_info.parts] = out

        if out.is_dir_checksum and (self.fetch or self.stream):
            # pull dir cache if needed
            dir_cache = out.get_dir_cache(**kwargs)

            # pull dir contents if needed
            if self.fetch:
                if out.changed_cache(filter_info=top):
                    used_cache = out.get_used_cache(filter_info=top)
                    self.repo.cloud.pull(used_cache, **kwargs)

            for entry in dir_cache:
                entry_relpath = entry[out.remote.PARAM_RELPATH]
                path_info = out.path_info / entry_relpath
                trie[path_info.parts] = None

    yield from self._walk(root, trie, topdown=topdown)
def walk(self, top, topdown=True, onerror=None, **kwargs):
    from pygtrie import Trie

    assert topdown

    root = os.path.abspath(top)
    try:
        info = self.info(root)
    except FileNotFoundError:
        if onerror is not None:
            onerror(FileNotFoundError(top))
        return

    if info["type"] != "directory":
        if onerror is not None:
            onerror(NotADirectoryError(top))
        return

    trie = Trie()
    for out in info["outs"]:
        trie[out.fs.path.parts(out.fs_path)] = out

        if out.is_dir_checksum and self.path.isin_or_eq(root, out.fs_path):
            self._add_dir(trie, out, **kwargs)

    yield from self._walk(root, trie, topdown=topdown, **kwargs)
def walk(self, top, topdown=True):
    from pygtrie import Trie

    assert topdown

    if not self.exists(top):
        raise FileNotFoundError

    if not self.isdir(top):
        raise NotADirectoryError

    root = PathInfo(os.path.abspath(top))
    outs = self._find_outs(top, recursive=True, strict=False)

    trie = Trie()

    for out in outs:
        trie[out.path_info.parts] = out

        if out.is_dir_checksum and (self.fetch or self.stream):
            # will pull dir cache if needed
            with self.repo.state:
                cache = out.collect_used_dir_cache()
            for _, names in cache.scheme_names(out.scheme):
                for name in names:
                    path_info = out.path_info.parent / name
                    trie[path_info.parts] = None

    yield from self._walk(root, trie, topdown=topdown)
def walk(self, top, topdown=True, onerror=None, **kwargs):
    from pygtrie import Trie

    assert topdown

    if not self.exists(top):
        if onerror is not None:
            onerror(FileNotFoundError(top))
        return

    if not self.isdir(top):
        if onerror is not None:
            onerror(NotADirectoryError(top))
        return

    root = PathInfo(os.path.abspath(top))
    outs = self._find_outs(top, recursive=True, strict=False)

    trie = Trie()

    for out in outs:
        trie[out.path_info.parts] = out

        if out.is_dir_checksum and root.isin_or_eq(out.path_info):
            self._add_dir(top, trie, out, **kwargs)

    yield from self._walk(root, trie, topdown=topdown, **kwargs)
def walk(self, top, topdown=True, onerror=None, **kwargs):
    from pygtrie import Trie

    assert topdown

    root = PathInfo(os.path.abspath(top))
    try:
        meta = self.metadata(root)
    except OutputNotFoundError:
        if onerror is not None:
            onerror(FileNotFoundError(top))
        return

    if not meta.isdir:
        if onerror is not None:
            onerror(NotADirectoryError(top))
        return

    trie = Trie()
    for out in meta.outs:
        trie[out.path_info.parts] = out

        if out.is_dir_checksum and root.isin_or_eq(out.path_info):
            self._add_dir(top, trie, out, **kwargs)

    yield from self._walk(root, trie, topdown=topdown, **kwargs)
from typing import Tuple

import pygtrie


def filter_legal_actions_for_vw_node(legal_actions: Tuple[int],
                                     is_recursive: bool,
                                     current_trie: pygtrie.Trie):
    if is_recursive:
        # We need to check that only pieces which can be extended are valid.
        legal_actions = [
            action for action in legal_actions
            if current_trie.has_subtrie((action,))
        ]
    else:
        # We need to check that only pieces which yield a valid word are
        # valid.
        legal_actions = [
            action for action in legal_actions
            if current_trie.has_key((action,))
        ]
    return legal_actions
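
# Hypothetical usage with words stored as tuples of int actions: action 1
# can still be extended (it has a subtrie), while only (2,) is already a
# complete word.
trie = pygtrie.Trie()
trie[(1, 3)] = True
trie[(2,)] = True
assert filter_legal_actions_for_vw_node((1, 2, 3), True, trie) == [1]
assert filter_legal_actions_for_vw_node((1, 2, 3), False, trie) == [2]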
def __init__(
    self,
    repo: Optional["Repo"] = None,
    subrepos=False,
    repo_factory: RepoFactory = None,
    **kwargs,
):
    super().__init__()

    from pygtrie import Trie

    if repo is None:
        repo, repo_factory = self._repo_from_fs_config(subrepos=subrepos, **kwargs)

    if not repo_factory:
        from dvc.repo import Repo

        self.repo_factory: RepoFactory = Repo
    else:
        self.repo_factory = repo_factory

    def _getcwd():
        relparts = ()
        if repo.fs.path.isin(repo.fs.path.getcwd(), repo.root_dir):
            relparts = repo.fs.path.relparts(repo.fs.path.getcwd(), repo.root_dir)
        return self.root_marker + self.sep.join(relparts)

    self.path = Path(self.sep, getcwd=_getcwd)
    self.repo = repo
    self.hash_jobs = repo.fs.hash_jobs
    self._traverse_subrepos = subrepos

    self._subrepos_trie = Trie()
    """Keeps track of each and every path with the corresponding repo."""

    key = self._get_key(self.repo.root_dir)
    self._subrepos_trie[key] = repo

    self._datafss = {}
    """Keep a datafs instance of each repo."""

    if hasattr(repo, "dvc_dir"):
        self._datafss[key] = DataFileSystem(repo=repo)
from functional import seq  # PyFunctional
from pygtrie import Trie


def remove_children_of(trie: Trie, keys):
    def delete_key(key):
        del trie[key]

    seq(keys) \
        .filter(lambda key: key in trie) \
        .flat_map(lambda key: trie.keys(prefix=key)) \
        .filter(lambda key: key not in keys) \
        .for_each(delete_key)
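
# Hypothetical usage: everything nested under ("a",) is deleted, while
# ("a",) itself survives because it is listed in `keys`.
trie = Trie()
trie[("a",)] = 1
trie[("a", "b")] = 2
trie[("a", "c")] = 3
remove_children_of(trie, [("a",)])
assert list(trie.keys()) == [("a",)]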
from datetime import datetime
from logging import info  # assumption: `info` is the module's logging helper


def build_trie(alphabet, vocab):
    from pygtrie import CharTrie as Trie

    trie = Trie()
    start_time = datetime.now()
    info('start building trie at {}'.format(
        start_time.strftime("%H:%M:%S")))
    for v in vocab:
        trie[v] = 1
    end_time = datetime.now()
    info('finish building trie at {} (delta {})'.format(
        end_time.strftime("%H:%M:%S"), end_time - start_time))
    return trie
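
# What build_trie produces, shown directly on a CharTrie: string keys are
# indexed per character, so prefix queries come for free. The vocabulary
# here is hypothetical.
from pygtrie import CharTrie

trie = CharTrie()
for v in ["apple", "app", "apply"]:
    trie[v] = 1
assert trie.has_key("app")
assert sorted(trie.keys(prefix="appl")) == ["apple", "apply"]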
def cal_left_right_entropy(self):
    left_right_entropy = {}
    for n in range(self.min_n, self.max_n + 1):
        ngrams_entropy = {}
        target_ngrams = self.ngrams_words[n]
        parent_words = self.ngrams_words[n + 1]
        left_neighbors = Trie()
        right_neighbors = Trie()
        for parent_word in tqdm(parent_words, desc='build neighbors'):
            right_neighbors[parent_word] = self.ngrams_freq[parent_word]
            left_neighbors[parent_word[1:] + parent_word[0]] = \
                self.ngrams_freq[parent_word]
        for target_ngram in tqdm(target_ngrams, desc='target ngram'):
            try:
                right_neighbors_counts = (
                    right_neighbors.values(target_ngram))
                right_entropy = self.cal_ngram_entropy(
                    right_neighbors_counts)
            except KeyError:
                right_entropy = 0
            try:
                left_neighbors_counts = (
                    left_neighbors.values(target_ngram))
                left_entropy = self.cal_ngram_entropy(
                    left_neighbors_counts)
            except KeyError:
                left_entropy = 0
            ngrams_entropy[target_ngram] = (left_entropy, right_entropy)
        left_right_entropy.update(ngrams_entropy)
    return left_right_entropy
from datetime import datetime
from logging import info  # assumption: `info` is the module's logging helper


def build_trie(alphabet, vocab):
    from pygtrie import CharTrie as Trie

    start_time = datetime.now()
    info('start building trie at {}'.format(
        start_time.strftime("%H:%M:%S")))
    trie = Trie()
    for i, v in enumerate(vocab, start=1):
        trie[v] = 1
        if i % 10000 == 0:
            info('inserted {} ...'.format(i))
    end_time = datetime.now()
    info('finish building trie at {} (delta {})'.format(
        end_time.strftime("%H:%M:%S"), end_time - start_time))
    return trie
def walk(
    self, top, topdown=True, onerror=None, download_callback=None, **kwargs
):
    from pygtrie import Trie

    assert topdown

    if not self.exists(top):
        if onerror is not None:
            onerror(FileNotFoundError(top))
        return

    if not self.isdir(top):
        if onerror is not None:
            onerror(NotADirectoryError(top))
        return

    root = PathInfo(os.path.abspath(top))
    outs = self._find_outs(top, recursive=True, strict=False)

    trie = Trie()

    for out in outs:
        trie[out.path_info.parts] = out

        if out.is_dir_checksum and (self.fetch or self.stream):
            # pull dir cache if needed
            dir_cache = out.get_dir_cache(**kwargs)

            # pull dir contents if needed
            if self.fetch:
                if out.changed_cache(filter_info=top):
                    used_cache = out.get_used_cache(filter_info=top)
                    downloaded = self.repo.cloud.pull(used_cache, **kwargs)
                    if download_callback:
                        download_callback(downloaded)

            for entry in dir_cache:
                entry_relpath = entry[out.remote.tree.PARAM_RELPATH]
                if os.name == "nt":
                    entry_relpath = entry_relpath.replace("/", os.sep)
                path_info = out.path_info / entry_relpath
                trie[path_info.parts] = None

    yield from self._walk(root, trie, topdown=topdown)
from typing import Iterable, List, Union

from jieba import posseg
from pygtrie import Trie


class JiebaTokenizer(object):
    # `Token(word, flag)` is assumed to be defined elsewhere in this module.

    def __init__(self, user_dict: Union[str, Iterable] = None):
        self.t = posseg.POSTokenizer()
        self.t.initialize()
        self.trie = Trie()
        if user_dict:
            self.load_user_dict(user_dict)

    def load_user_dict(self, user_dict: Union[str, Iterable] = None):
        if isinstance(user_dict, str):
            with open(user_dict) as fin:
                user_dict = [line.strip('\r\n') for line in fin]
        seg_dict = {}
        for line in user_dict:
            line = line.strip()
            if not line:
                continue
            arr = line.split('=', 1)
            key = arr[0].strip()
            if len(arr) > 1 and len(arr[1].strip()) > 0:
                value = [x for x in arr[1].strip().split()]
                assert len(key) == sum(len(x) for x in value)
            else:
                value = [key]
            seg_dict[key] = value
            self.t.tokenizer.add_word(key, 100)
        # jieba always splits English strings, so for English entries we
        # still need the trie to merge the pieces back together.
        for key, value in seg_dict.items():
            self.trie[self.t.tokenizer.lcut(key)] = value

    def tokenize(self, s: str) -> List[Token]:
        res = []
        term_list = list(self.t.cut(s))
        text_list = [x.word for x in term_list]
        idx = 0
        while idx < len(text_list):
            m = self.trie.longest_prefix(text_list[idx:])
            if m:
                res.extend(Token(x, 'nz') for x in m.value)
                idx += len(m.key)
            else:
                res.append(Token(term_list[idx].word, term_list[idx].flag))
                idx += 1
        return res
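
# Hypothetical usage, assuming jieba is installed and Token is available.
# The user-dict entry pins the segmentation of "深度学习" into 深度 + 学习,
# which tokenize() should then emit as two 'nz' tokens:
tokenizer = JiebaTokenizer(user_dict=["深度学习=深度 学习"])
print([(t.word, t.flag) for t in tokenizer.tokenize("深度学习")])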
def _get_restricted_trie(rule: str, trie: pygtrie.Trie) -> pygtrie.Trie:
    if trie:
        if '[pos]' in rule or '[or]' in rule:
            # For [pos] or [or] operators, remove words from input trie
            # that are not valid intersect_word_actions.
            valid_word_actions = self.valid_word_actions.intersect_word_actions
        elif '[neg]' in rule:
            # For [neg] rules, remove words from the input trie that are
            # not valid diff_word_actions.
            valid_word_actions = self.valid_word_actions.diff_word_actions
        else:
            valid_word_actions = None

        if valid_word_actions:
            result = copy.deepcopy(trie)
            for word in trie.iterkeys(()):
                if not _can_construct(word, valid_word_actions):
                    del result[word]
            return result
    return trie
def walk(self, top, topdown=True):
    from pygtrie import Trie

    assert topdown

    if not self.exists(top):
        raise FileNotFoundError

    if not self.isdir(top):
        raise NotADirectoryError

    root = PathInfo(os.path.abspath(top))
    outs = self._find_outs(top, recursive=True, strict=False)

    trie = Trie()

    for out in outs:
        trie[out.path_info.parts] = out

    yield from self._walk(root, trie, topdown=topdown)
def _calc_ngram_entropy(ngram_freq, ngram_keys, n):
    """
    Compute entropy information from ngram frequency counts.

    :param ngram_freq: mapping from ngram to its frequency
    :param ngram_keys: mapping from length to the ngrams of that length
    :param n: a single ngram length, or an iterable of lengths
    :return: mapping from ngram to its (left_entropy, right_entropy) pair
    """
    if isinstance(n, Iterable):
        ## Compute the entropy for several ngram lengths in one call.
        entropy = {}
        for ni in n:
            entropy = {
                **entropy,
                **_calc_ngram_entropy(ngram_freq, ngram_keys, ni)
            }
        return entropy

    ngram_entropy = {}
    target_ngrams = ngram_keys[n]
    parent_candidates = ngram_keys[n + 1]

    if CPU_COUNT == 1:
        ## Build tries over the (n+1)-grams.
        left_neighbors = Trie()
        right_neighbors = Trie()
        for parent_candidate in parent_candidates:
            right_neighbors[parent_candidate] = ngram_freq[parent_candidate]
            left_neighbors[parent_candidate[1:] + parent_candidate[0]] = \
                ngram_freq[parent_candidate]

        ## Compute the entropies.
        for target_ngram in target_ngrams:
            try:
                ## Some candidate ngrams have no left/right neighbors at all.
                right_neighbor_counts = right_neighbors.values(target_ngram)
                right_entropy = _ngram_entropy_scorer(right_neighbor_counts)
            except KeyError:
                right_entropy = 0
            try:
                left_neighbor_counts = left_neighbors.values(target_ngram)
                left_entropy = _ngram_entropy_scorer(left_neighbor_counts)
            except KeyError:
                left_entropy = 0
            ngram_entropy[target_ngram] = (left_entropy, right_entropy)
        return ngram_entropy
    else:
        ## TODO: multiprocess computation
        pass
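
# `_ngram_entropy_scorer` is referenced above but not part of this snippet.
# A minimal sketch, assuming it is the Shannon entropy of the neighbor-count
# distribution, H = -sum(p * log p) with p = count / total:
import math

def _ngram_entropy_scorer(neighbor_counts):
    total = sum(neighbor_counts)
    if total <= 0:
        return 0
    return -sum(
        (c / total) * math.log(c / total) for c in neighbor_counts if c > 0
    )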
def _update_trie(self, dirname: str, trie: Trie) -> None:
    key = self._get_key(dirname)
    old_pattern = trie.longest_prefix(key).value
    matches = old_pattern.matches(dirname, DvcIgnore.DVCIGNORE_FILE, False)

    path = self.fs.path.join(dirname, DvcIgnore.DVCIGNORE_FILE)
    if not matches and self.fs.exists(path):
        name = self.fs.path.relpath(path, self.root_dir)
        new_pattern = DvcIgnorePatterns.from_file(path, self.fs, name)
        if old_pattern:
            plist, prefix = merge_patterns(
                self.fs.path.flavour,
                old_pattern.pattern_list,
                old_pattern.dirname,
                new_pattern.pattern_list,
                new_pattern.dirname,
            )
            trie[key] = DvcIgnorePatterns(plist, prefix, self.fs.sep)
        else:
            trie[key] = new_pattern
    elif old_pattern:
        trie[key] = old_pattern
def calcul_ngram_entropy(ngram_freq, ngram_keys, n):
    """
    Compute entropy from ngram frequencies.
    """
    # Compute the entropy for several ngram lengths in one call.
    if isinstance(n, collections.abc.Iterable):
        entropy = {}
        for ni in n:
            entropy = {
                **entropy,
                **calcul_ngram_entropy(ngram_freq, ngram_keys, ni)
            }
        return entropy

    ngram_entropy = {}
    parent_candidates = ngram_keys[n + 1]
    if n != 1:
        target_ngrams = ngram_keys[n]
    else:
        target_ngrams = [
            l for l in ngram_keys[n] if ToolWord().is_english_word(l[0])
        ]

    if hp.CPU_COUNT == 1:
        # Build tries over the (n+1)-grams.
        left_neighbors = Trie()
        right_neighbors = Trie()
        for parent_candidate in parent_candidates:
            right_neighbors[parent_candidate] = ngram_freq[parent_candidate]
            left_neighbors[parent_candidate[1:] + (parent_candidate[0],)] = \
                ngram_freq[parent_candidate]

        # Compute the entropies.
        for target_ngram in target_ngrams:
            try:
                right_neighbor_counts = right_neighbors.values(target_ngram)
                right_entropy = entropy_of_list(right_neighbor_counts)
            except KeyError:
                right_entropy = 0
            try:
                left_neighbor_counts = left_neighbors.values(target_ngram)
                left_entropy = entropy_of_list(left_neighbor_counts)
            except KeyError:
                left_entropy = 0
            ngram_entropy[target_ngram] = (left_entropy, right_entropy)
        return ngram_entropy
    else:
        # TODO: multiprocess computation
        pass
def _collect_graph(self, stages):
    """Generate a graph by using the given stages on the given directory

    The nodes of the graph are the stage's path relative to the root.

    Edges are created when the output of one stage is used as a
    dependency in other stage.

    The direction of the edges goes from the stage to its dependency:

    For example, running the following:

        $ dvc run -o A "echo A > A"
        $ dvc run -d A -o B "echo B > B"
        $ dvc run -d B -o C "echo C > C"

    Will create the following graph:

           ancestors <--
                       |
            C.dvc -> B.dvc -> A.dvc
            |          |
            |          --> descendants
            |
            ------- pipeline ------>
                       |
                       v
          (weakly connected components)

    Args:
        stages (list): used to build a graph, if None given, collect stages
            in the repository.

    Raises:
        OutputDuplicationError: two outputs with the same path
        StagePathAsOutputError: stage inside an output directory
        OverlappingOutputPathsError: output inside output directory
        CyclicGraphError: resulting graph has cycles
    """
    import networkx as nx
    from pygtrie import Trie

    from dvc.exceptions import (
        OutputDuplicationError,
        OverlappingOutputPathsError,
        StagePathAsOutputError,
    )

    G = nx.DiGraph()
    stages = stages or self.stages

    outs = Trie()  # Use trie to efficiently find overlapping outs and deps

    for stage in filter(bool, stages):  # bug? not using it later
        for out in stage.outs:
            out_key = out.path_info.parts

            # Check for dup outs
            if out_key in outs:
                dup_stages = [stage, outs[out_key].stage]
                raise OutputDuplicationError(str(out), dup_stages)

            # Check for overlapping outs
            if outs.has_subtrie(out_key):
                parent = out
                overlapping = first(outs.values(prefix=out_key))
            else:
                parent = outs.shortest_prefix(out_key).value
                overlapping = out
            if parent and overlapping:
                msg = (
                    "Paths for outs:\n'{}'('{}')\n'{}'('{}')\n"
                    "overlap. To avoid unpredictable behaviour, "
                    "rerun command with non overlapping outs paths."
                ).format(
                    str(parent),
                    parent.stage.addressing,
                    str(overlapping),
                    overlapping.stage.addressing,
                )
                raise OverlappingOutputPathsError(parent, overlapping, msg)

            outs[out_key] = out

    for stage in stages:
        out = outs.shortest_prefix(PathInfo(stage.path).parts).value
        if out:
            raise StagePathAsOutputError(stage, str(out))

    # Building graph
    G.add_nodes_from(stages)
    for stage in stages:
        for dep in stage.deps:
            if dep.path_info is None:
                continue

            dep_key = dep.path_info.parts
            overlapping = [n.value for n in outs.prefixes(dep_key)]
            if outs.has_subtrie(dep_key):
                overlapping.extend(outs.values(prefix=dep_key))

            G.add_edges_from((stage, out.stage) for out in overlapping)

    check_acyclic(G)
    return G
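
# Hypothetical illustration of the dependency lookup above: prefixes() walks
# every out that contains the dep path, and values(prefix=...) finds every
# out contained inside it.
from pygtrie import Trie

outs = Trie()
outs[("data",)] = "stage-data"
outs[("data", "raw", "img")] = "stage-img"

dep_key = ("data", "raw")
containing = [step.value for step in outs.prefixes(dep_key)]
contained = outs.values(prefix=dep_key)
assert containing == ["stage-data"] and contained == ["stage-img"]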
def trie(self):
    return Trie(self._dict)
def trie(self):
    from pygtrie import Trie

    return Trie(self._dict)
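
# pygtrie's Trie accepts a mapping in its constructor, which is what both
# variants above rely on to rebuild the trie from `self._dict`. A minimal
# sketch with hypothetical keys:
from pygtrie import Trie

trie = Trie({("foo", "bar"): 1, ("foo", "baz"): 2})
assert trie.longest_prefix(("foo", "bar", "qux")).value == 1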
class _DvcFileSystem(AbstractFileSystem):  # pylint:disable=abstract-method
    """DVC + git-tracked files fs.

    Args:
        repo: DVC or git repo.
        subrepos: traverse to subrepos (by default, it ignores subrepos)
        repo_factory: A function to initialize subrepo with, default is Repo.
        kwargs: Additional keyword arguments passed to the `DataFileSystem()`.
    """

    root_marker = "/"

    PARAM_REPO_URL = "repo_url"
    PARAM_REPO_ROOT = "repo_root"
    PARAM_REV = "rev"
    PARAM_CACHE_DIR = "cache_dir"
    PARAM_CACHE_TYPES = "cache_types"
    PARAM_SUBREPOS = "subrepos"

    def __init__(
        self,
        repo: Optional["Repo"] = None,
        subrepos=False,
        repo_factory: RepoFactory = None,
        **kwargs,
    ):
        super().__init__()

        from pygtrie import Trie

        if repo is None:
            repo, repo_factory = self._repo_from_fs_config(
                subrepos=subrepos, **kwargs
            )

        if not repo_factory:
            from dvc.repo import Repo

            self.repo_factory: RepoFactory = Repo
        else:
            self.repo_factory = repo_factory

        def _getcwd():
            relparts = ()
            if repo.fs.path.isin(repo.fs.path.getcwd(), repo.root_dir):
                relparts = repo.fs.path.relparts(
                    repo.fs.path.getcwd(), repo.root_dir
                )
            return self.root_marker + self.sep.join(relparts)

        self.path = Path(self.sep, getcwd=_getcwd)
        self.repo = repo
        self.hash_jobs = repo.fs.hash_jobs
        self._traverse_subrepos = subrepos

        self._subrepos_trie = Trie()
        """Keeps track of each and every path with the corresponding repo."""

        key = self._get_key(self.repo.root_dir)
        self._subrepos_trie[key] = repo

        self._datafss = {}
        """Keep a datafs instance of each repo."""

        if hasattr(repo, "dvc_dir"):
            self._datafss[key] = DataFileSystem(repo=repo)

    def _get_key(self, path):
        parts = self.repo.fs.path.relparts(path, self.repo.root_dir)
        if parts == (".",):
            parts = ()
        return parts

    @property
    def repo_url(self):
        if self.repo is None:
            return None
        return self.repo.url

    @property
    def config(self):
        return {
            self.PARAM_REPO_URL: self.repo_url,
            self.PARAM_REPO_ROOT: self.repo.root_dir,
            self.PARAM_REV: getattr(self.repo.fs, "rev", None),
            self.PARAM_CACHE_DIR: os.path.abspath(
                self.repo.odb.local.cache_dir
            ),
            self.PARAM_CACHE_TYPES: self.repo.odb.local.cache_types,
            self.PARAM_SUBREPOS: self._traverse_subrepos,
        }

    @classmethod
    def _repo_from_fs_config(
        cls, **config
    ) -> Tuple["Repo", Optional["RepoFactory"]]:
        from dvc.external_repo import erepo_factory, external_repo
        from dvc.repo import Repo

        url = config.get(cls.PARAM_REPO_URL)
        root = config.get(cls.PARAM_REPO_ROOT)
        assert url or root

        def _open(*args, **kwargs):
            # NOTE: if original repo was an erepo (and has a URL),
            # we cannot use Repo.open() since it will skip erepo
            # cache/remote setup for local URLs
            if url is None:
                return Repo.open(*args, **kwargs)
            return external_repo(*args, **kwargs)

        cache_dir = config.get(cls.PARAM_CACHE_DIR)
        cache_config = (
            {}
            if not cache_dir
            else {
                "cache": {
                    "dir": cache_dir,
                    "type": config.get(cls.PARAM_CACHE_TYPES),
                }
            }
        )

        repo_kwargs: dict = {
            "rev": config.get(cls.PARAM_REV),
            "subrepos": config.get(cls.PARAM_SUBREPOS, False),
            "uninitialized": True,
        }

        factory: Optional["RepoFactory"] = None
        if url is None:
            repo_kwargs["config"] = cache_config
        else:
            repo_kwargs["cache_dir"] = cache_dir
            factory = erepo_factory(url, root, cache_config)

        with _open(
            url if url else root,
            **repo_kwargs,
        ) as repo:
            return repo, factory

    def _get_repo(self, path: str) -> "Repo":
        """Returns repo that the path falls in, using prefix.

        If the path is already tracked/collected, it just returns the repo.

        Otherwise, it collects the repos that might be in the path's
        parents and then returns the appropriate one.
        """
        if not self.repo.fs.path.isin_or_eq(path, self.repo.root_dir):
            # outside of repo
            return self.repo

        key = self._get_key(path)
        repo = self._subrepos_trie.get(key)
        if repo:
            return repo

        prefix_key, repo = self._subrepos_trie.longest_prefix(key)
        prefix = self.repo.fs.path.join(
            self.repo.root_dir,
            *prefix_key,  # pylint: disable=not-an-iterable
        )

        parents = (parent for parent in self.repo.fs.path.parents(path))
        dirs = [path] + list(takewhile(lambda p: p != prefix, parents))
        dirs.reverse()
        self._update(dirs, starting_repo=repo)
        return self._subrepos_trie.get(key) or self.repo

    @wrap_with(threading.Lock())
    def _update(self, dirs, starting_repo):
        """Checks for subrepo in directories and updates them."""
        repo = starting_repo
        for d in dirs:
            key = self._get_key(d)
            if self._is_dvc_repo(d):
                repo = self.repo_factory(
                    d,
                    fs=self.repo.fs,
                    scm=self.repo.scm,
                    repo_factory=self.repo_factory,
                )
                self._datafss[key] = DataFileSystem(repo=repo)
            self._subrepos_trie[key] = repo

    def _is_dvc_repo(self, dir_path):
        """Check if the directory is a dvc repo."""
        if not self._traverse_subrepos:
            return False

        from dvc.repo import Repo

        repo_path = self.repo.fs.path.join(dir_path, Repo.DVC_DIR)
        return self.repo.fs.isdir(repo_path)

    def _get_fs_pair(self, path) -> Tuple[
        Optional[FileSystem],
        Optional[str],
        Optional[DataFileSystem],
        Optional[str],
    ]:
        """
        Returns a pair of fss based on repo the path falls in, using prefix.
        """
        parts = self.path.relparts(path, self.root_marker)
        if parts and parts[0] == os.curdir:
            parts = parts[1:]

        fs_path = self.repo.fs.path.join(self.repo.root_dir, *parts)

        repo = self._get_repo(fs_path)
        fs = repo.fs

        repo_parts = fs.path.relparts(repo.root_dir, self.repo.root_dir)
        if repo_parts[0] == os.curdir:
            repo_parts = repo_parts[1:]

        dvc_parts = parts[len(repo_parts):]
        if dvc_parts and dvc_parts[0] == os.curdir:
            dvc_parts = dvc_parts[1:]

        key = self._get_key(repo.root_dir)
        dvc_fs = self._datafss.get(key)
        if dvc_fs:
            dvc_path = dvc_fs.path.join(*dvc_parts) if dvc_parts else ""
        else:
            dvc_path = None

        return fs, fs_path, dvc_fs, dvc_path

    def open(  # pylint: disable=arguments-renamed, arguments-differ
        self, path, mode="r", encoding="utf-8", **kwargs
    ):
        if "b" in mode:
            encoding = None

        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)
        try:
            return fs.open(fs_path, mode=mode, encoding=encoding)
        except FileNotFoundError:
            if not dvc_fs:
                raise

        return dvc_fs.open(dvc_path, mode=mode, encoding=encoding, **kwargs)

    def isdvc(self, path, **kwargs):
        _, _, dvc_fs, dvc_path = self._get_fs_pair(path)
        return dvc_fs is not None and dvc_fs.isdvc(dvc_path, **kwargs)

    def ls(  # pylint: disable=arguments-differ
        self, path, detail=True, dvc_only=False, **kwargs
    ):
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)

        repo = dvc_fs.repo if dvc_fs else self.repo
        dvcignore = repo.dvcignore
        ignore_subrepos = kwargs.get("ignore_subrepos", True)

        names = set()
        if dvc_fs:
            with suppress(FileNotFoundError):
                for entry in dvc_fs.ls(dvc_path, detail=False):
                    names.add(dvc_fs.path.name(entry))

        if not dvc_only and fs:
            try:
                for entry in dvcignore.ls(
                    fs, fs_path, detail=False,
                    ignore_subrepos=ignore_subrepos,
                ):
                    names.add(fs.path.name(entry))
            except (FileNotFoundError, NotADirectoryError):
                pass

        dvcfiles = kwargs.get("dvcfiles", False)

        def _func(fname):
            from dvc.dvcfile import is_valid_filename
            from dvc.ignore import DvcIgnore

            if dvcfiles:
                return True

            return not (
                is_valid_filename(fname) or fname == DvcIgnore.DVCIGNORE_FILE
            )

        names = filter(_func, names)

        infos = []
        paths = []
        for name in names:
            entry_path = self.path.join(path, name)
            try:
                info = self.info(entry_path, ignore_subrepos=ignore_subrepos)
            except FileNotFoundError:
                continue
            infos.append(info)
            paths.append(entry_path)

        if not detail:
            return paths

        return infos

    def get_file(  # pylint: disable=arguments-differ
        self, rpath, lpath, callback=DEFAULT_CALLBACK, **kwargs
    ):
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(rpath)

        if fs:
            try:
                fs.get_file(fs_path, lpath, callback=callback, **kwargs)
                return
            except FileNotFoundError:
                if not dvc_fs:
                    raise

        dvc_fs.get_file(dvc_path, lpath, callback=callback, **kwargs)

    def info(self, path, **kwargs):
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)

        repo = dvc_fs.repo if dvc_fs else self.repo
        dvcignore = repo.dvcignore
        ignore_subrepos = kwargs.get("ignore_subrepos", True)

        dvc_info = None
        if dvc_fs:
            try:
                dvc_info = dvc_fs.info(dvc_path)
            except FileNotFoundError:
                pass

        fs_info = None
        if fs:
            try:
                fs_info = fs.info(fs_path)
                if dvcignore.is_ignored(
                    fs, fs_path, ignore_subrepos=ignore_subrepos
                ):
                    fs_info = None
            except (FileNotFoundError, NotADirectoryError):
                if not dvc_info:
                    raise

        # NOTE: if some parent in fs_path turns out to be a file, it means
        # that the whole repofs branch doesn't exist.
        if fs and not fs_info and dvc_info:
            for parent in fs.path.parents(fs_path):
                try:
                    if fs.info(parent)["type"] != "directory":
                        dvc_info = None
                        break
                except FileNotFoundError:
                    continue

        if not dvc_info and not fs_info:
            raise FileNotFoundError

        info = _merge_info(repo, fs_info, dvc_info)
        info["name"] = path
        return info

    def checksum(self, path):
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)

        try:
            return fs.checksum(fs_path)
        except FileNotFoundError:
            return dvc_fs.checksum(dvc_path)
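
# Hypothetical illustration of the _subrepos_trie lookup in _get_repo above:
# longest_prefix() maps any path key to the innermost repo whose root is a
# prefix of that key.
from pygtrie import Trie

subrepos = Trie()
subrepos[()] = "root-repo"
subrepos[("sub",)] = "sub-repo"

prefix_key, repo = subrepos.longest_prefix(("sub", "data", "file.txt"))
assert repo == "sub-repo"
assert subrepos.longest_prefix(("other",)).value == "root-repo"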