def __init__(self, local_cache):
    """Bind the state database to the repository owning ``local_cache``.

    Reads ``row_limit`` and ``row_cleanup_quota`` from the optional
    ``state`` section of the repo config, falling back to the class
    defaults. If the repo has no tmp dir, ``state_file`` is set to None and
    the database-related attributes are not initialised.
    """
    from dvc.tree.local import LocalRemoteTree

    owner = local_cache.repo
    self.repo = owner
    self.root_dir = owner.root_dir
    self.tree = LocalRemoteTree(None, {"url": self.root_dir})

    settings = owner.config.get("state", {})
    self.row_limit = settings.get("row_limit", self.STATE_ROW_LIMIT)
    self.row_cleanup_quota = settings.get(
        "row_cleanup_quota", self.STATE_ROW_CLEANUP_QUOTA
    )

    if owner.tmp_dir:
        self.state_file = os.path.join(owner.tmp_dir, self.STATE_FILE)

        # Side files SQLite may place next to the database, see
        # https://www.sqlite.org/tempfiles.html
        self.temp_files = [
            self.state_file + suffix for suffix in ("-journal", "-wal")
        ]

        self.database = None
        self.cursor = None
        self.inserts = 0
    else:
        self.state_file = None
def __init__(self, root_dir=None, scm=None, rev=None):
    """Open the DVC repository found at or above ``root_dir``.

    Args:
        root_dir: directory to start looking from.
        scm: optional SCM object; when given, the repo is read through
            ``scm.get_tree(rev)`` and a ``StateNoop`` is used instead of a
            real ``State``.
        rev: revision handed to ``scm.get_tree`` (only used with ``scm``).
    """
    from dvc.state import State, StateNoop
    from dvc.lock import make_lock
    from dvc.scm import SCM
    from dvc.cache import Cache
    from dvc.data_cloud import DataCloud
    from dvc.repo.metrics import Metrics
    from dvc.repo.plots import Plots
    from dvc.repo.params import Params
    from dvc.tree.local import LocalRemoteTree
    from dvc.utils.fs import makedirs
    from dvc.stage.cache import StageCache

    working_copy = not scm

    if working_copy:
        located = self.find_root(root_dir)
        self.root_dir = os.path.abspath(os.path.realpath(located))
        self.tree = LocalRemoteTree(None, {"url": self.root_dir})
    else:
        rev_tree = scm.get_tree(rev)
        self.root_dir = self.find_root(root_dir, rev_tree)
        self.scm = scm
        self.tree = rev_tree
        self.state = StateNoop()

    self.dvc_dir = os.path.join(self.root_dir, self.DVC_DIR)
    self.config = Config(self.dvc_dir, tree=self.tree)

    if working_copy:
        self.scm = SCM(
            self.root_dir, no_scm=self.config["core"].get("no_scm", False)
        )

    self.tmp_dir = os.path.join(self.dvc_dir, "tmp")
    self.index_dir = os.path.join(self.tmp_dir, "index")
    makedirs(self.index_dir, exist_ok=True)

    lock_path = os.path.join(self.tmp_dir, "lock")
    self.lock = make_lock(
        lock_path,
        tmp_dir=self.tmp_dir,
        hardlink_lock=self.config["core"].get("hardlink_lock", False),
        friendly=True,
    )

    self.cache = Cache(self)
    self.cloud = DataCloud(self)

    if working_copy:
        # NOTE: storing state and link_state in the repository itself to
        # avoid any possible state corruption in 'shared cache dir'
        # scenario.
        self.state = State(self.cache.local)

    self.stage_cache = StageCache(self)

    self.metrics = Metrics(self)
    self.plots = Plots(self)
    self.params = Params(self)

    self._ignore()
def test_protect_ignore_erofs(tmp_dir, mocker):
    """protect() must not propagate EROFS raised by os.chmod."""
    tmp_dir.gen("foo", "foo")
    target = PathInfo("foo")
    local_tree = LocalRemoteTree(None, {})

    read_only_error = OSError(errno.EROFS, "read-only fs")
    chmod_spy = mocker.patch("os.chmod", side_effect=read_only_error)

    local_tree.protect(target)

    assert chmod_spy.called
def test_nobranch(self):
    """Walk a local tree with dvcignore enabled from '.' and from a subdir."""
    local_tree = LocalRemoteTree(
        None, {"url": self._root_dir}, use_dvcignore=True
    )
    sub_dir = join("data_dir", "data_sub_dir")

    expected_from_top = [
        (".", ["data_dir"], ["bar", "тест", "code.py", "foo"]),
        (join("data_dir"), ["data_sub_dir"], ["data"]),
        (sub_dir, [], ["data_sub"]),
    ]
    self.assertWalkEqual(local_tree.walk("."), expected_from_top)

    expected_from_sub = [(sub_dir, [], ["data_sub"])]
    self.assertWalkEqual(local_tree.walk(sub_dir), expected_from_sub)
def test_ignore_collecting_dvcignores(tmp_dir, dvc, dname):
    """Nested .dvcignore files are all collected into the tree's ignores."""
    tmp_dir.gen({"dir": {"subdir": {}}})

    outer_file = (tmp_dir / dname).with_name(DvcIgnore.DVCIGNORE_FILE)
    outer_file.write_text(os.path.basename(dname))
    # Drop the cached ignore object so it is rebuilt from the files on disk.
    dvc.tree.__dict__.pop("dvcignore", None)

    inner_file = tmp_dir / dname / DvcIgnore.DVCIGNORE_FILE
    inner_file.write_text("foo")

    collected = dvc.tree.dvcignore.ignores
    assert len(collected) == 3
    assert DvcIgnoreDirs([".git", ".hg", ".dvc"]) in collected

    # Keep the last matching trie, as a plain overwrite-in-loop would.
    tries = [
        each for each in collected if isinstance(each, DvcIgnorePatternsTrie)
    ]
    trie = tries[-1] if tries else None
    assert trie is not None

    plain_tree = LocalRemoteTree(None, {"url": dvc.root_dir})
    expected = DvcIgnorePatterns.from_files(os.fspath(outer_file), plain_tree)
    assert expected == trie[os.fspath(inner_file)]

    assert any(filter(lambda each: isinstance(each, DvcIgnoreRepo), collected))
def test(self):
    """Walking from ``self._root_dir`` reports entries rooted at that path."""
    top = self._root_dir
    local_tree = LocalRemoteTree(None, {"url": top})

    expected = [
        (top, ["data_dir"], ["code.py", "bar", "тест", "foo"]),
        (join(top, "data_dir"), ["data_sub_dir"], ["data"]),
        (join(top, "data_dir", "data_sub_dir"), [], ["data_sub"]),
    ]
    self.assertWalkEqual(local_tree.walk(top), expected)
def test(self):
    """get_mtime_and_size returns strings and the expected sizes.

    The expected directory size is the summed size of ``DATA`` and
    ``DATA_SUB``.
    """
    clean_tree = CleanTree(LocalRemoteTree(None, {"url": self.root_dir}))

    file_time, file_size = get_mtime_and_size(self.DATA, clean_tree)
    dir_time, dir_size = get_mtime_and_size(self.DATA_DIR, clean_tree)

    expected_file_size = os.path.getsize(self.DATA)
    expected_dir_size = sum(
        os.path.getsize(path) for path in (self.DATA, self.DATA_SUB)
    )

    for mtime, size, expected in (
        (file_time, file_size, expected_file_size),
        (dir_time, dir_size, expected_dir_size),
    ):
        self.assertIs(type(mtime), str)
        self.assertIs(type(size), str)
        self.assertEqual(size, str(expected))
def test_protect_ignore_errors(tmp_dir, mocker, err):
    """A second protect() must not propagate the chmod error ``err``.

    ``err`` is presumably a parametrized errno value (decorator not visible
    here).
    """
    tmp_dir.gen("foo", "foo")
    target = PathInfo("foo")
    local_tree = LocalRemoteTree(None, {})

    # First call runs against the real os.chmod.
    local_tree.protect(target)

    chmod_spy = mocker.patch(
        "os.chmod", side_effect=OSError(err, "something")
    )
    local_tree.protect(target)

    assert chmod_spy.called
def test_path_object_and_str_are_valid_types_get_mtime_and_size(tmp_dir):
    """get_mtime_and_size agrees for str and PathInfo arguments."""
    tmp_dir.gen(
        {"dir": {"dir_file": "dir file content"}, "file": "file_content"}
    )
    clean_tree = CleanTree(LocalRemoteTree(None, {"url": os.fspath(tmp_dir)}))

    for name in ("dir", "file"):
        str_time, str_size = get_mtime_and_size(name, clean_tree)
        obj_time, obj_size = get_mtime_and_size(PathInfo(name), clean_tree)
        assert str_time == obj_time
        assert str_size == obj_size
def __init__(
    self, dvc_dir=None, validate=True, tree=None,
):  # pylint: disable=super-init-not-called
    """Set up config locations and load the configuration.

    Args:
        dvc_dir: path to the ``.dvc`` directory. When falsy, it is looked
            up with ``Repo.find_dvc_dir()``; if no repo is found it is left
            as None.
        validate: forwarded to ``self.load``.
        tree: optional tree object. When given, its ``.tree`` attribute is
            stored as ``self.tree``; otherwise ``self.wtree`` is used.
    """
    from dvc.tree.local import LocalRemoteTree

    if dvc_dir:
        self.dvc_dir = os.path.abspath(os.path.realpath(dvc_dir))
    else:
        try:
            from dvc.repo import Repo

            # find_dvc_dir already returns a joined path string.
            self.dvc_dir = Repo.find_dvc_dir()
        except NotDvcRepoError:
            self.dvc_dir = None

    self.wtree = LocalRemoteTree(None, {"url": self.dvc_dir})
    self.tree = tree.tree if tree else self.wtree

    self.load(validate=validate)
def test_status_download_optimization(mocker, dvc):
    """Skip querying the remote for hashes already present locally.

    When every requested file already exists in the local cache,
    ``status(..., download=True)`` must not call ``hashes_exist`` on the
    other remote.
    """
    local_cache = LocalCache(LocalRemoteTree(dvc, {}))

    named = NamedCache()
    for md5, name in (
        ("acbd18db4cc2f85cedef654fccc4a4d8", "foo"),
        ("37b51d194a7513e45b56f6524f2d51f2", "bar"),
    ):
        named.add("local", md5, name)

    already_local = list(named["local"])
    mocker.patch.object(local_cache, "hashes_exist", return_value=already_local)

    remote = mocker.Mock()
    remote.url = "other_remote"
    remote.hashes_exist.return_value = []
    remote.index = RemoteIndexNoop()

    local_cache.status(named, remote, download=True)

    assert remote.hashes_exist.call_count == 0
def test_subdir(self):
    """Walking a relative nested directory lists only that directory."""
    sub_dir = join("data_dir", "data_sub_dir")
    local_tree = LocalRemoteTree(None, {"url": self._root_dir})

    expected = [(sub_dir, [], ["data_sub"])]
    self.assertWalkEqual(local_tree.walk(sub_dir), expected)
class Repo:
    """A DVC repository rooted at the directory that contains ``.dvc``.

    Command implementations living in ``dvc.repo.*`` modules are bound into
    this class's namespace by the class-level imports below, so they are
    callable as methods (e.g. ``repo.add(...)``).
    """

    DVC_DIR = ".dvc"

    from dvc.repo.destroy import destroy
    from dvc.repo.install import install
    from dvc.repo.add import add
    from dvc.repo.remove import remove
    from dvc.repo.ls import ls
    from dvc.repo.freeze import freeze, unfreeze
    from dvc.repo.move import move
    from dvc.repo.run import run
    from dvc.repo.imp import imp
    from dvc.repo.imp_url import imp_url
    from dvc.repo.reproduce import reproduce
    from dvc.repo.checkout import _checkout
    from dvc.repo.push import push
    from dvc.repo.fetch import _fetch
    from dvc.repo.pull import pull
    from dvc.repo.status import status
    from dvc.repo.gc import gc
    from dvc.repo.commit import commit
    from dvc.repo.diff import diff
    from dvc.repo.brancher import brancher
    from dvc.repo.get import get
    from dvc.repo.get_url import get_url
    from dvc.repo.update import update

    def __init__(self, root_dir=None, scm=None, rev=None):
        """Open the repository found at or above ``root_dir``.

        Args:
            root_dir: directory to start from; ``find_root`` falls back to
                the current directory when it is None.
            scm: optional SCM object. When given, files are read through
                ``scm.get_tree(rev, ...)`` and a ``StateNoop`` is used.
            rev: revision passed to ``scm.get_tree`` (only used with
                ``scm``).

        Raises:
            NotDvcRepoError: raised by ``find_root`` when no DVC directory
                can be found.
        """
        from dvc.state import State, StateNoop
        from dvc.lock import make_lock
        from dvc.scm import SCM
        from dvc.cache import Cache
        from dvc.data_cloud import DataCloud
        from dvc.repo.experiments import Experiments
        from dvc.repo.metrics import Metrics
        from dvc.repo.plots import Plots
        from dvc.repo.params import Params
        from dvc.tree.local import LocalRemoteTree
        from dvc.utils.fs import makedirs
        from dvc.stage.cache import StageCache

        if scm:
            # A plain tree is needed first to locate the root; the tree that
            # is kept is re-created with dvcignore support for that root.
            tree = scm.get_tree(rev)
            self.root_dir = self.find_root(root_dir, tree)
            self.scm = scm
            self.tree = scm.get_tree(
                rev, use_dvcignore=True, dvcignore_root=self.root_dir
            )
            self.state = StateNoop()
        else:
            root_dir = self.find_root(root_dir)
            self.root_dir = os.path.abspath(os.path.realpath(root_dir))
            self.tree = LocalRemoteTree(
                self,
                {"url": self.root_dir},
                use_dvcignore=True,
                dvcignore_root=self.root_dir,
            )

        self.dvc_dir = os.path.join(self.root_dir, self.DVC_DIR)
        self.config = Config(self.dvc_dir, tree=self.tree)

        if not scm:
            no_scm = self.config["core"].get("no_scm", False)
            self.scm = SCM(self.root_dir, no_scm=no_scm)

        self.tmp_dir = os.path.join(self.dvc_dir, "tmp")
        self.index_dir = os.path.join(self.tmp_dir, "index")
        makedirs(self.index_dir, exist_ok=True)

        hardlink_lock = self.config["core"].get("hardlink_lock", False)
        self.lock = make_lock(
            os.path.join(self.tmp_dir, "lock"),
            tmp_dir=self.tmp_dir,
            hardlink_lock=hardlink_lock,
            friendly=True,
        )

        self.cache = Cache(self)
        self.cloud = DataCloud(self)

        if not scm:
            # NOTE: storing state and link_state in the repository itself to
            # avoid any possible state corruption in 'shared cache dir'
            # scenario.
            self.state = State(self.cache.local)

        self.stage_cache = StageCache(self)

        self.metrics = Metrics(self)
        self.plots = Plots(self)
        self.params = Params(self)

        # Experiments are optional: a NotImplementedError from the
        # constructor disables them (``self.experiments`` becomes None).
        try:
            self.experiments = Experiments(self)
        except NotImplementedError:
            self.experiments = None

        self._ignore()

    @property
    def tree(self):
        """The tree used to read repository files."""
        return self._tree

    @tree.setter
    def tree(self, tree):
        self._tree = tree
        # Our graph cache is no longer valid, as it was based on the previous
        # tree.
        self._reset()

    def __repr__(self):
        return f"{self.__class__.__name__}: '{self.root_dir}'"

    @classmethod
    def find_root(cls, root=None, tree=None):
        """Return the directory containing ``DVC_DIR``.

        With ``tree``, only ``root`` itself is checked (through the tree).
        Without it, the local filesystem is searched upward from ``root``
        (default: current directory) until a mount point is reached.

        Raises:
            NotDvcRepoError: if no DVC directory is found or ``root`` does
                not exist.
        """
        root_dir = os.path.realpath(root or os.curdir)

        if tree:
            if tree.isdir(os.path.join(root_dir, cls.DVC_DIR)):
                return root_dir
            raise NotDvcRepoError(f"'{root}' does not contain DVC directory")

        if not os.path.isdir(root_dir):
            raise NotDvcRepoError(f"directory '{root}' does not exist")

        while True:
            dvc_dir = os.path.join(root_dir, cls.DVC_DIR)
            if os.path.isdir(dvc_dir):
                return root_dir
            # Stop at the filesystem boundary instead of crossing mounts.
            if os.path.ismount(root_dir):
                break
            root_dir = os.path.dirname(root_dir)

        message = (
            "you are not inside of a DVC repository "
            "(checked up to mount point '{}')"
        ).format(root_dir)
        raise NotDvcRepoError(message)

    @classmethod
    def find_dvc_dir(cls, root=None):
        """Return the path of the ``.dvc`` directory for ``root``."""
        root_dir = cls.find_root(root)
        return os.path.join(root_dir, cls.DVC_DIR)

    @staticmethod
    def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False):
        """Initialise a new repository at ``root_dir`` and open it."""
        from dvc.repo.init import init

        init(root_dir=root_dir, no_scm=no_scm, force=force, subdir=subdir)
        return Repo(root_dir)

    def unprotect(self, target):
        """Unprotect ``target`` via the local cache's tree."""
        return self.cache.local.tree.unprotect(PathInfo(target))

    def _ignore(self):
        """Add DVC-internal paths to the SCM ignore list.

        Covers the local config file, the tmp dir, the experiments dir (if
        experiments are enabled) and the cache dir when it lives inside
        the repo.
        """
        flist = [
            self.config.files["local"],
            self.tmp_dir,
        ]
        if self.experiments:
            flist.append(self.experiments.exp_dir)

        if path_isin(self.cache.local.cache_dir, self.root_dir):
            flist += [self.cache.local.cache_dir]

        self.scm.ignore_list(flist)

    def get_stage(self, path=None, name=None):
        """Return stage ``name`` from the dvcfile at ``path``.

        ``path`` defaults to ``PIPELINE_FILE``.
        """
        if not path:
            path = PIPELINE_FILE
            logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

        dvcfile = Dvcfile(self, path)
        return dvcfile.stages[name]

    def get_stages(self, path=None, name=None):
        """Return a list of stages from the dvcfile at ``path``.

        If ``name`` is given the list holds just that stage; otherwise it
        holds every stage in the file. ``path`` defaults to
        ``PIPELINE_FILE``.
        """
        if not path:
            path = PIPELINE_FILE
            logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

        if name:
            return [self.get_stage(path, name)]

        dvcfile = Dvcfile(self, path)
        return list(dvcfile.stages.values())

    def check_modified_graph(self, new_stages):
        """Generate graph including the new stage to check for errors"""
        # Building graph might be costly for the ones with many DVC-files,
        # so we provide this undocumented hack to skip it. See [1] for
        # more details. The hack can be used as:
        #
        #     repo = Repo(...)
        #     repo._skip_graph_checks = True
        #     repo.add(...)
        #
        # A user should care about not duplicating outs and not adding cycles,
        # otherwise DVC might have an undefined behaviour.
        #
        # [1] https://github.com/iterative/dvc/issues/2671
        if not getattr(self, "_skip_graph_checks", False):
            self._collect_graph(self.stages + new_stages)

    def _collect_inside(self, path, graph):
        """Return graph stages whose file lies inside ``path`` (postorder)."""
        import networkx as nx

        stages = nx.dfs_postorder_nodes(graph)
        return [stage for stage in stages if path_isin(stage.path, path)]

    def collect(self, target=None, with_deps=False, recursive=False, graph=None):
        """Collect stages for ``target``.

        Returns all stages when ``target`` is empty, the stages under a
        directory when ``recursive``, the addressed stages otherwise, or a
        set of those stages plus their pipeline dependencies when
        ``with_deps``.
        """
        if not target:
            return list(graph) if graph else self.stages

        if recursive and os.path.isdir(target):
            return self._collect_inside(
                os.path.abspath(target), graph or self.graph
            )

        path, name = parse_target(target)
        stages = self.get_stages(path, name)
        if not with_deps:
            return stages

        res = set()
        for stage in stages:
            res.update(self._collect_pipeline(stage, graph=graph))
        return res

    def _collect_pipeline(self, stage, graph=None):
        """Return (lazily) the stages of ``stage``'s pipeline, postorder."""
        import networkx as nx

        pipeline = get_pipeline(get_pipelines(graph or self.graph), stage)
        return nx.dfs_postorder_nodes(pipeline, stage)

    def _collect_from_default_dvcfile(self, target):
        """Return stage ``target`` from ``PIPELINE_FILE``, or None."""
        dvcfile = Dvcfile(self, PIPELINE_FILE)
        if dvcfile.exists():
            return dvcfile.stages.get(target)

    def collect_granular(self, target=None, with_deps=False, recursive=False, graph=None):
        """
        Priority is in the order of following in case of ambiguity:
        - .dvc file or .yaml file
        - dir if recursive and directory exists
        - stage_name
        - output file

        Returns a list of ``(stage, filter_info)`` pairs. ``filter_info`` is
        a ``PathInfo`` only when ``target`` was resolved to an output path;
        otherwise it is None.
        """
        if not target:
            return [(stage, None) for stage in self.stages]

        file, name = parse_target(target)
        stages = []

        # Optimization: do not collect the graph for a specific target
        if not file:
            # parsing is ambiguous when it does not have a colon
            # or if it's not a dvcfile, as it can be a stage name
            # in `dvc.yaml` or, an output in a stage.
            logger.debug("Checking if stage '%s' is in '%s'", target, PIPELINE_FILE)
            if not (recursive and os.path.isdir(target)):
                stage = self._collect_from_default_dvcfile(target)
                if stage:
                    stages = (
                        self._collect_pipeline(stage) if with_deps else [stage]
                    )
        elif not with_deps and is_valid_filename(file):
            stages = self.get_stages(file, name)

        if not stages:
            if not (recursive and os.path.isdir(target)):
                try:
                    (out,) = self.find_outs_by_path(target, strict=False)
                    filter_info = PathInfo(os.path.abspath(target))
                    return [(out.stage, filter_info)]
                except OutputNotFoundError:
                    pass

            try:
                stages = self.collect(target, with_deps, recursive, graph)
            except StageFileDoesNotExistError as exc:
                # collect() might try to use `target` as a stage name
                # and throw error that dvc.yaml does not exist, whereas it
                # should say that both stage name and file does not exist.
                if file and is_valid_filename(file):
                    raise
                raise NoOutputOrStageError(target, exc.file) from exc
            except StageNotFound as exc:
                raise NoOutputOrStageError(target, exc.file) from exc

        return [(stage, None) for stage in stages]

    def used_cache(
        self,
        targets=None,
        all_branches=False,
        with_deps=False,
        all_tags=False,
        all_commits=False,
        remote=None,
        force=False,
        jobs=None,
        recursive=False,
        used_run_cache=None,
    ):
        """Get the stages related to the given target and collect
        the `info` of its outputs.

        This is useful to know what files from the cache are _in use_
        (namely, a file described as an output on a stage).

        The scope is, by default, the working directory, but you can use
        `all_branches`/`all_tags`/`all_commits` to expand the scope.

        Returns:
            A dictionary with Schemes (representing output's location) mapped
            to items containing the output's `dumpd` names and the output's
            children (if the given output is a directory).
        """
        from dvc.cache import NamedCache

        cache = NamedCache()

        for branch in self.brancher(
            all_branches=all_branches,
            all_tags=all_tags,
            all_commits=all_commits,
        ):
            # [None] makes collect_granular return every stage.
            targets = targets or [None]

            pairs = cat(
                self.collect_granular(
                    target, recursive=recursive, with_deps=with_deps
                )
                for target in targets
            )

            suffix = f"({branch})" if branch else ""
            for stage, filter_info in pairs:
                used_cache = stage.get_used_cache(
                    remote=remote,
                    force=force,
                    jobs=jobs,
                    filter_info=filter_info,
                )
                cache.update(used_cache, suffix=suffix)

        if used_run_cache:
            used_cache = self.stage_cache.get_used_cache(
                used_run_cache, remote=remote, force=force, jobs=jobs,
            )
            cache.update(used_cache)

        return cache

    def _collect_graph(self, stages):
        """Generate a graph by using the given stages on the given directory

        The nodes of the graph are the stage's path relative to the root.

        Edges are created when the output of one stage is used as a
        dependency in other stage.

        The direction of the edges goes from the stage to its dependency:

        For example, running the following:

            $ dvc run -o A "echo A > A"
            $ dvc run -d A -o B "echo B > B"
            $ dvc run -d B -o C "echo C > C"

        Will create the following graph:

               ancestors <--
                           |
                C.dvc -> B.dvc -> A.dvc
                |          |
                |          --> descendants
                |
                ------- pipeline ------>
                           |
                           v
              (weakly connected components)

        Args:
            stages (list): used to build a graph, if None given, collect stages
                in the repository.

        Raises:
            OutputDuplicationError: two outputs with the same path
            StagePathAsOutputError: stage inside an output directory
            OverlappingOutputPathsError: output inside output directory
            CyclicGraphError: resulting graph has cycles
        """
        import networkx as nx
        from pygtrie import Trie
        from dvc.exceptions import (
            OutputDuplicationError,
            StagePathAsOutputError,
            OverlappingOutputPathsError,
        )

        G = nx.DiGraph()
        stages = stages or self.stages
        outs = Trie()  # Use trie to efficiently find overlapping outs and deps

        # NOTE(review): falsy stages are skipped only in this loop; the
        # loops below iterate the unfiltered ``stages``.
        for stage in filter(bool, stages):  # bug? not using it later
            for out in stage.outs:
                out_key = out.path_info.parts

                # Check for dup outs
                if out_key in outs:
                    dup_stages = [stage, outs[out_key].stage]
                    raise OutputDuplicationError(str(out), dup_stages)

                # Check for overlapping outs
                if outs.has_subtrie(out_key):
                    parent = out
                    overlapping = first(outs.values(prefix=out_key))
                else:
                    parent = outs.shortest_prefix(out_key).value
                    overlapping = out
                if parent and overlapping:
                    msg = (
                        "Paths for outs:\n'{}'('{}')\n'{}'('{}')\n"
                        "overlap. To avoid unpredictable behaviour, "
                        "rerun command with non overlapping outs paths."
                    ).format(
                        str(parent),
                        parent.stage.addressing,
                        str(overlapping),
                        overlapping.stage.addressing,
                    )
                    raise OverlappingOutputPathsError(parent, overlapping, msg)

                outs[out_key] = out

        # A stage file must not live inside any output path.
        for stage in stages:
            out = outs.shortest_prefix(PathInfo(stage.path).parts).value
            if out:
                raise StagePathAsOutputError(stage, str(out))

        # Building graph
        G.add_nodes_from(stages)
        for stage in stages:
            for dep in stage.deps:
                if dep.path_info is None:
                    continue

                # A dep links to every out that contains it or lies under it.
                dep_key = dep.path_info.parts
                overlapping = [n.value for n in outs.prefixes(dep_key)]
                if outs.has_subtrie(dep_key):
                    overlapping.extend(outs.values(prefix=dep_key))

                G.add_edges_from((stage, out.stage) for out in overlapping)

        check_acyclic(G)

        return G

    @cached_property
    def graph(self):
        """Stage graph for all stages in the repo (cached until _reset)."""
        return self._collect_graph(self.stages)

    @cached_property
    def pipelines(self):
        """Pipelines derived from ``self.graph`` (cached until _reset)."""
        return get_pipelines(self.graph)

    @cached_property
    def stages(self):
        """
        Walks down the root directory looking for Dvcfiles,
        skipping the directories that are related with
        any SCM (e.g. `.git`), DVC itself (`.dvc`), or directories
        tracked by DVC (e.g. `dvc add data` would skip `data/`)

        NOTE: For large repos, this could be an expensive
              operation. Consider using some memoization.
        """
        return self._collect_stages()

    @cached_property
    def plot_templates(self):
        """Plot templates stored under the ``.dvc`` directory."""
        from .plots.template import PlotTemplates

        return PlotTemplates(self.dvc_dir)

    def _collect_stages(self):
        """Walk ``self.tree`` from the root and load every valid dvcfile.

        Directories that are local outputs of already-loaded stages are
        pruned from the walk.
        """
        stages = []
        outs = set()

        for root, dirs, files in self.tree.walk(self.root_dir):
            for file_name in filter(is_valid_filename, files):
                new_stages = self.get_stages(os.path.join(root, file_name))
                stages.extend(new_stages)
                outs.update(
                    out.fspath
                    for stage in new_stages
                    for out in stage.outs
                    if out.scheme == "local"
                )
            # In-place slice assignment prunes the walk's descent.
            dirs[:] = [d for d in dirs if os.path.join(root, d) not in outs]
        return stages

    def find_outs_by_path(self, path, outs=None, recursive=False, strict=True):
        """Return outputs matching ``path``.

        With ``strict`` only exact matches count; otherwise outputs
        equal to or containing ``path`` match. With ``recursive``, outputs
        located inside ``path`` match as well.

        Raises:
            OutputNotFoundError: if nothing matches.
        """
        if not outs:
            outs = [out for stage in self.stages for out in stage.outs]

        abs_path = os.path.abspath(path)
        path_info = PathInfo(abs_path)
        match = path_info.__eq__ if strict else path_info.isin_or_eq

        def func(out):
            if out.scheme == "local" and match(out.path_info):
                return True

            if recursive and out.path_info.isin(path_info):
                return True

            return False

        matched = list(filter(func, outs))
        if not matched:
            raise OutputNotFoundError(path, self)

        return matched

    def find_out_by_relpath(self, relpath):
        """Return the single output at ``relpath`` (relative to root)."""
        path = os.path.join(self.root_dir, relpath)
        (out,) = self.find_outs_by_path(path)
        return out

    def is_dvc_internal(self, path):
        """Return True if ``DVC_DIR`` is one of ``path``'s components."""
        path_parts = os.path.normpath(path).split(os.path.sep)
        return self.DVC_DIR in path_parts

    @contextmanager
    def open_by_relpath(self, path, remote=None, mode="r", encoding=None):
        """Opens a specified resource as a file descriptor"""

        tree = RepoTree(self, stream=True)
        path = os.path.join(self.root_dir, path)
        try:
            with self.state:
                # NOTE(review): ``path`` is already joined with root_dir
                # above, so this second join is redundant.
                with tree.open(
                    os.path.join(self.root_dir, path),
                    mode=mode,
                    encoding=encoding,
                    remote=remote,
                ) as fobj:
                    yield fobj
        except FileNotFoundError as exc:
            raise FileMissingError(path) from exc
        except IsADirectoryError as exc:
            raise DvcIsADirectoryError(f"'{path}' is a directory") from exc

    def close(self):
        """Close the underlying SCM."""
        self.scm.close()

    @locked
    def checkout(self, *args, **kwargs):
        return self._checkout(*args, **kwargs)

    @locked
    def fetch(self, *args, **kwargs):
        return self._fetch(*args, **kwargs)

    def _reset(self):
        """Drop cached graph/stages/pipelines and the tree's dvcignore."""
        self.__dict__.pop("graph", None)
        self.__dict__.pop("stages", None)
        self.__dict__.pop("pipelines", None)
        self.tree.__dict__.pop("dvcignore", None)
def test_is_protected(tmp_dir, dvc, link_name):
    """Protection state is shared between a file and a link to it."""
    local_tree = LocalRemoteTree(dvc, {})
    make_link = getattr(local_tree, link_name)

    (tmp_dir / "foo").write_text("foo")

    source = PathInfo(tmp_dir / "foo")
    link = PathInfo(tmp_dir / "link")

    make_link(source, link)

    for item in (source, link):
        assert not local_tree.is_protected(item)

    local_tree.protect(source)

    for item in (source, link):
        assert local_tree.is_protected(item)

    local_tree.unprotect(link)

    assert not local_tree.is_protected(link)

    # NOTE: NTFS doesn't allow deleting read-only files, which forces us to
    # set write perms on the link, which propagates to the source.
    source_stays_protected = not (os.name == "nt" and link_name == "hardlink")
    assert bool(local_tree.is_protected(source)) == source_stays_protected
class State:  # pylint: disable=too-many-instance-attributes
    """Class for the state database.

    The database is an SQLite file in the repo's tmp dir with three tables:
    ``state`` (mtime/size/md5 keyed by inode), ``state_info`` (a running
    row count) and ``link_state`` (links saved via ``save_link``).

    Args:
        local_cache: cache object whose ``repo`` attribute is the repo
            instance this state belongs to. The optional ``state`` section
            of ``repo.config`` may override ``row_limit`` and
            ``row_cleanup_quota``.

    Raises:
        StateVersionTooNewError: thrown when dvc version is older than the
            state database version.
    """

    VERSION = 3
    STATE_FILE = "state"
    STATE_TABLE = "state"
    STATE_TABLE_LAYOUT = (
        "inode INTEGER PRIMARY KEY, "
        "mtime TEXT NOT NULL, "
        "size TEXT NOT NULL, "
        "md5 TEXT NOT NULL, "
        "timestamp TEXT NOT NULL"
    )

    STATE_INFO_TABLE = "state_info"
    STATE_INFO_TABLE_LAYOUT = "count INTEGER"
    STATE_INFO_ROW = 1

    LINK_STATE_TABLE = "link_state"
    LINK_STATE_TABLE_LAYOUT = (
        "path TEXT PRIMARY KEY, "
        "inode INTEGER NOT NULL, "
        "mtime TEXT NOT NULL"
    )

    STATE_ROW_LIMIT = 100000000
    STATE_ROW_CLEANUP_QUOTA = 50

    MAX_INT = 2 ** 63 - 1
    MAX_UINT = 2 ** 64 - 2

    def __init__(self, local_cache):
        from dvc.tree.local import LocalRemoteTree

        repo = local_cache.repo
        self.repo = repo
        self.root_dir = repo.root_dir
        self.tree = LocalRemoteTree(None, {"url": self.root_dir})

        state_config = repo.config.get("state", {})
        self.row_limit = state_config.get("row_limit", self.STATE_ROW_LIMIT)
        self.row_cleanup_quota = state_config.get(
            "row_cleanup_quota", self.STATE_ROW_CLEANUP_QUOTA
        )

        if not repo.tmp_dir:
            self.state_file = None
            return

        self.state_file = os.path.join(repo.tmp_dir, self.STATE_FILE)

        # https://www.sqlite.org/tempfiles.html
        self.temp_files = [
            self.state_file + "-journal",
            self.state_file + "-wal",
        ]

        self.database = None
        self.cursor = None
        self.inserts = 0

    @property
    def files(self):
        """The SQLite side files followed by the database file itself."""
        return self.temp_files + [self.state_file]

    def __enter__(self):
        self.load()

    def __exit__(self, typ, value, tbck):
        self.dump()

    def _execute(self, cmd, parameters=()):
        logger.trace(cmd)
        return self.cursor.execute(cmd, parameters)

    def _fetchall(self):
        ret = self.cursor.fetchall()
        logger.debug("fetched: %s", ret)
        return ret

    def _to_sqlite(self, num):
        """Encode an unsigned int into SQLite's signed 64-bit range."""
        assert num >= 0
        assert num < self.MAX_UINT
        # NOTE: sqlite stores uint as signed ints, so maximum uint is 2^63-1
        # see http://jakegoulding.com/blog/2011/02/06/sqlite-64-bit-integers/
        # Values above MAX_INT are stored as negatives: MAX_INT + k -> -k.
        if num > self.MAX_INT:
            ret = -(num - self.MAX_INT)
        else:
            ret = num
        assert self._from_sqlite(ret) == num
        return ret

    def _from_sqlite(self, num):
        """Inverse of ``_to_sqlite``."""
        assert abs(num) <= self.MAX_INT
        if num < 0:
            return abs(num) + self.MAX_INT
        assert num < self.MAX_UINT
        assert num >= 0
        return num

    def _prepare_db(self, empty=False):
        """Check the schema version and create the tables if missing.

        An older on-disk version has its tables dropped and recreated.

        Raises:
            StateVersionTooNewError: if the stored version is newer than
                ``VERSION``.
        """
        from dvc import __version__

        if not empty:
            cmd = "PRAGMA user_version;"
            self._execute(cmd)
            ret = self._fetchall()
            assert len(ret) == 1
            assert len(ret[0]) == 1
            assert isinstance(ret[0][0], int)
            version = ret[0][0]

            if version > self.VERSION:
                raise StateVersionTooNewError(
                    __version__, self.VERSION, version
                )
            elif version < self.VERSION:
                logger.warning(
                    "State file version '%d' is too old. "
                    "Reformatting to the current version '%d'.",
                    version,
                    self.VERSION,
                )
                cmd = "DROP TABLE IF EXISTS {};"
                self._execute(cmd.format(self.STATE_TABLE))
                self._execute(cmd.format(self.STATE_INFO_TABLE))
                self._execute(cmd.format(self.LINK_STATE_TABLE))

        # Check that the state file is indeed a database
        cmd = "CREATE TABLE IF NOT EXISTS {} ({})"
        self._execute(cmd.format(self.STATE_TABLE, self.STATE_TABLE_LAYOUT))
        self._execute(
            cmd.format(self.STATE_INFO_TABLE, self.STATE_INFO_TABLE_LAYOUT)
        )
        self._execute(
            cmd.format(self.LINK_STATE_TABLE, self.LINK_STATE_TABLE_LAYOUT)
        )

        cmd = (
            "INSERT OR IGNORE INTO {} (count) SELECT 0 "
            "WHERE NOT EXISTS (SELECT * FROM {})"
        )
        self._execute(cmd.format(self.STATE_INFO_TABLE, self.STATE_INFO_TABLE))

        cmd = "PRAGMA user_version = {};"
        self._execute(cmd.format(self.VERSION))

    def load(self):
        """Loads state database.

        If the file turns out not to be a valid database, it is deleted and
        recreated once; a second failure re-raises ``sqlite3.DatabaseError``.
        """
        retries = 1
        while True:
            assert self.database is None
            assert self.cursor is None
            assert self.inserts == 0
            empty = not os.path.exists(self.state_file)
            # NOTE: we use nolock option because fcntl() lock sqlite uses
            # doesn't work on some older NFS/CIFS filesystems.
            # This opens a possibility of data corruption by concurrent writes,
            # which is prevented by repo lock.
            self.database = _connect_sqlite(self.state_file, {"nolock": 1})
            self.cursor = self.database.cursor()

            # Try loading once to check that the file is indeed a database
            # and reformat it if it is not.
            try:
                self._prepare_db(empty=empty)
                return
            except sqlite3.DatabaseError:
                self.cursor.close()
                self.database.close()
                self.database = None
                self.cursor = None
                self.inserts = 0
                if retries > 0:
                    os.unlink(self.state_file)
                    retries -= 1
                else:
                    raise

    def _vacuum(self):
        # NOTE: see https://bugs.python.org/issue28518
        self.database.isolation_level = None
        self._execute("VACUUM")
        self.database.isolation_level = ""

    def dump(self):
        """Saves state database.

        If the row count exceeds ``row_limit``, the oldest rows (by
        timestamp) are deleted, down to ``row_limit`` minus an extra
        ``row_cleanup_quota`` percent, and the database is vacuumed. Then the
        count is stored, changes are committed and the connection closed.
        """
        assert self.database is not None

        cmd = "SELECT count from {} WHERE rowid=?".format(
            self.STATE_INFO_TABLE
        )
        self._execute(cmd, (self.STATE_INFO_ROW,))
        ret = self._fetchall()
        assert len(ret) == 1
        assert len(ret[0]) == 1
        count = self._from_sqlite(ret[0][0]) + self.inserts

        if count > self.row_limit:
            msg = "cleaning up state, this might take a while."
            logger.warning(msg)

            delete = count - self.row_limit
            delete += int(self.row_limit * (self.row_cleanup_quota / 100.0))
            cmd = (
                "DELETE FROM {} WHERE timestamp IN ("
                "SELECT timestamp FROM {} ORDER BY timestamp ASC LIMIT {});"
            )
            self._execute(
                cmd.format(self.STATE_TABLE, self.STATE_TABLE, delete)
            )

            self._vacuum()

            cmd = "SELECT COUNT(*) FROM {}"

            self._execute(cmd.format(self.STATE_TABLE))
            ret = self._fetchall()
            assert len(ret) == 1
            assert len(ret[0]) == 1
            count = ret[0][0]

        cmd = "UPDATE {} SET count = ? WHERE rowid = ?".format(
            self.STATE_INFO_TABLE
        )
        self._execute(cmd, (self._to_sqlite(count), self.STATE_INFO_ROW))

        self.database.commit()
        self.cursor.close()
        self.database.close()
        self.database = None
        self.cursor = None
        self.inserts = 0

    @staticmethod
    def _file_metadata_changed(actual_mtime, mtime, actual_size, size):
        return actual_mtime != mtime or actual_size != size

    def _update_state_record_timestamp_for_inode(self, actual_inode):
        cmd = "UPDATE {} SET timestamp = ? WHERE inode = ?".format(
            self.STATE_TABLE
        )
        self._execute(
            cmd, (current_timestamp(), self._to_sqlite(actual_inode))
        )

    def _update_state_for_path_changed(
        self, actual_inode, actual_mtime, actual_size, checksum
    ):
        cmd = (
            "UPDATE {} SET "
            "mtime = ?, size = ?, "
            "md5 = ?, timestamp = ? "
            "WHERE inode = ?"
        ).format(self.STATE_TABLE)
        self._execute(
            cmd,
            (
                actual_mtime,
                actual_size,
                checksum,
                current_timestamp(),
                self._to_sqlite(actual_inode),
            ),
        )

    def _insert_new_state_record(
        self, actual_inode, actual_mtime, actual_size, checksum
    ):
        assert checksum is not None

        cmd = (
            "INSERT INTO {}(inode, mtime, size, md5, timestamp) "
            "VALUES (?, ?, ?, ?, ?)"
        ).format(self.STATE_TABLE)
        self._execute(
            cmd,
            (
                self._to_sqlite(actual_inode),
                actual_mtime,
                actual_size,
                checksum,
                current_timestamp(),
            ),
        )
        self.inserts += 1

    def get_state_record_for_inode(self, inode):
        """Return ``(mtime, size, md5, timestamp)`` for ``inode`` or None."""
        cmd = (
            "SELECT mtime, size, md5, timestamp from {} WHERE "
            "inode=?".format(self.STATE_TABLE)
        )
        self._execute(cmd, (self._to_sqlite(inode),))
        results = self._fetchall()
        if results:
            # uniqueness constrain on inode
            assert len(results) == 1
            return results[0]
        return None

    def save(self, path_info, checksum):
        """Save checksum for the specified path info.

        Args:
            path_info (dict): path_info to save checksum for.
            checksum (str): checksum to save.
        """
        assert isinstance(path_info, str) or path_info.scheme == "local"
        assert checksum is not None
        assert os.path.exists(path_info)

        actual_mtime, actual_size = get_mtime_and_size(path_info, self.tree)
        actual_inode = get_inode(path_info)

        existing_record = self.get_state_record_for_inode(actual_inode)
        if not existing_record:
            self._insert_new_state_record(
                actual_inode, actual_mtime, actual_size, checksum
            )
            return

        self._update_state_for_path_changed(
            actual_inode, actual_mtime, actual_size, checksum
        )

    def get(self, path_info):
        """Gets the checksum for the specified path info. Checksum will be
        retrieved from the state database if available.

        Args:
            path_info (dict): path info to get the checksum for.

        Returns:
            str or None: checksum for the specified path info or None if it
            doesn't exist in the state database or the file's mtime/size no
            longer match the stored record.
        """
        assert isinstance(path_info, str) or path_info.scheme == "local"
        path = os.fspath(path_info)

        # NOTE: use os.path.exists instead of LocalRemoteTree.exists
        # because it uses lexists() and will return True for broken
        # symlinks that we cannot stat() in get_mtime_and_size
        if not os.path.exists(path):
            return None

        actual_mtime, actual_size = get_mtime_and_size(path, self.tree)
        actual_inode = get_inode(path)

        existing_record = self.get_state_record_for_inode(actual_inode)
        if not existing_record:
            return None

        mtime, size, checksum, _ = existing_record
        if self._file_metadata_changed(actual_mtime, mtime, actual_size, size):
            return None

        self._update_state_record_timestamp_for_inode(actual_inode)
        return checksum

    def save_link(self, path_info):
        """Adds the specified path to the list of links created by dvc. This
        list is later used on `dvc checkout` to cleanup old links.

        The path is stored relative to ``root_dir``; nothing is saved if the
        path does not exist.

        Args:
            path_info (dict): path info to add to the list of links.
        """
        assert isinstance(path_info, str) or path_info.scheme == "local"

        if not self.tree.exists(path_info):
            return

        mtime, _ = get_mtime_and_size(path_info, self.tree)
        inode = get_inode(path_info)
        relative_path = relpath(path_info, self.root_dir)

        cmd = "REPLACE INTO {}(path, inode, mtime) " "VALUES (?, ?, ?)".format(
            self.LINK_STATE_TABLE
        )
        self._execute(cmd, (relative_path, self._to_sqlite(inode), mtime))

    def get_unused_links(self, used):
        """Return saved links that are not in ``used`` and are unchanged.

        A saved link is reported only if it still exists and its inode and
        mtime match the recorded ones. Nothing is removed here; see
        ``remove_links``.

        Args:
            used (list): absolute paths of links that are still in use.

        Returns:
            list: paths of unused links, relative to ``root_dir``.
        """
        unused = []

        self._execute(f"SELECT * FROM {self.LINK_STATE_TABLE}")
        for row in self.cursor:
            relative_path, inode, mtime = row
            inode = self._from_sqlite(inode)
            path = os.path.join(self.root_dir, relative_path)

            if path in used or not self.tree.exists(path):
                continue

            actual_inode = get_inode(path)
            actual_mtime, _ = get_mtime_and_size(path, self.tree)

            if (inode, mtime) == (actual_inode, actual_mtime):
                logger.debug("Removing '%s' as unused link.", path)
                unused.append(relative_path)

        return unused

    def remove_links(self, unused):
        """Delete links from disk and drop them from the link table.

        Args:
            unused (list): paths relative to ``root_dir``, as returned by
                ``get_unused_links`` (the same values are the table keys).
        """
        for relative_path in unused:
            # Resolve against the repo root, not the process cwd, so the
            # right file is removed when dvc runs from a subdirectory.
            remove(os.path.join(self.root_dir, relative_path))

        for chunk_unused in to_chunks(
            unused, chunk_size=SQLITE_MAX_VARIABLES_NUMBER
        ):
            cmd = "DELETE FROM {} WHERE path IN ({})".format(
                self.LINK_STATE_TABLE, ",".join(["?"] * len(chunk_unused))
            )
            self._execute(cmd, tuple(chunk_unused))
def setUp(self):
    """Run the base fixture setup, then attach a bare local tree."""
    super().setUp()
    # No repo and an empty config dict.
    empty_config = {}
    self.tree = LocalRemoteTree(None, empty_config)
class TestLocalRemoteTree(TestDir):
    """Basic open/exists/isdir/isfile queries on a bare LocalRemoteTree."""

    def setUp(self):
        super().setUp()
        # No repo and an empty config dict.
        self.tree = LocalRemoteTree(None, {})

    def test_open(self):
        cases = (
            (self.FOO, self.FOO_CONTENTS, {}),
            (self.UNICODE, self.UNICODE_CONTENTS, {"encoding": "utf-8"}),
        )
        for path, contents, open_kwargs in cases:
            with self.tree.open(path, **open_kwargs) as fobj:
                self.assertEqual(fobj.read(), contents)

    def test_exists(self):
        for present in (self.FOO, self.UNICODE):
            self.assertTrue(self.tree.exists(present))
        self.assertFalse(self.tree.exists("not-existing-file"))

    def test_isdir(self):
        self.assertTrue(self.tree.isdir(self.DATA_DIR))
        for not_a_dir in (self.FOO, "not-existing-file"):
            self.assertFalse(self.tree.isdir(not_a_dir))

    def test_isfile(self):
        self.assertTrue(self.tree.isfile(self.FOO))
        for not_a_file in (self.DATA_DIR, "not-existing-file"):
            self.assertFalse(self.tree.isfile(not_a_file))