class DependencyManager(StateTransitioner, BaseDependencyManager):
    """
    This dependency manager downloads dependency bundles from the CodaLab server
    to the local filesystem. It caches all downloaded dependencies but cleans up the
    old ones if disk usage hits the given threshold. It's also NFS-safe.
    In this class, dependencies are uniquely identified by DependencyKey.
    """

    DEPENDENCIES_DIR_NAME = 'dependencies'
    DEPENDENCY_FAILURE_COOLDOWN = 10
    # TODO(bkgoksel): The server writes these to the worker_dependencies table, which stores the dependencies
    # json as a SqlAlchemy LargeBinary, which defaults to MySQL BLOB, which has a size limit of
    # 65K. For now we limit this value to about 58K to avoid any issues but we probably want to do
    # something better (either specify MEDIUMBLOB in the SqlAlchemy definition of the table or change
    # the data format of how we store this)
    MAX_SERIALIZED_LEN = 60000

    # If it has been this long since a worker has downloaded anything, another worker will take over downloading.
    DEPENDENCY_DOWNLOAD_TIMEOUT_SECONDS = 5 * 60

    def __init__(
        self,
        commit_file: str,
        bundle_service: BundleServiceClient,
        worker_dir: str,
        max_cache_size_bytes: int,
        download_dependencies_max_retries: int,
    ):
        super(DependencyManager, self).__init__()
        self.add_transition(DependencyStage.DOWNLOADING, self._transition_from_DOWNLOADING)
        self.add_terminal(DependencyStage.READY)
        self.add_terminal(DependencyStage.FAILED)

        self._id: str = "worker-dependency-manager-{}".format(uuid.uuid4().hex[:8])
        self._state_committer = JsonStateCommitter(commit_file)
        self._bundle_service = bundle_service
        self._max_cache_size_bytes = max_cache_size_bytes
        self.dependencies_dir = os.path.join(worker_dir, DependencyManager.DEPENDENCIES_DIR_NAME)
        self._download_dependencies_max_retries = download_dependencies_max_retries
        if not os.path.exists(self.dependencies_dir):
            logger.info('{} doesn\'t exist, creating.'.format(self.dependencies_dir))
            os.makedirs(self.dependencies_dir, 0o770)

        # Create a lock for concurrency over NFS
        # Create a separate locks directory to hold the lock files.
        # Each lock file is created when a process tries to claim the main lock.
        locks_claims_dir: str = os.path.join(worker_dir, 'locks_claims')
        try:
            os.makedirs(locks_claims_dir)
        except FileExistsError:
            logger.info(f"A locks directory at {locks_claims_dir} already exists.")
        self._state_lock = NFSLock(os.path.join(locks_claims_dir, 'state.lock'))
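        # The claim-file idea behind an NFS-safe lock, as assumed here (a sketch of
        # the usual technique, not necessarily this NFSLock's exact implementation):
        # each contender writes its own claim file under locks_claims_dir and then
        # tries to hard-link it to the shared 'state.lock' name; link() is atomic
        # even over NFS, so exactly one claimant wins.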

        # File paths that are currently being used to store dependencies. Used to prevent conflicts
        self._paths: Set[str] = set()
        # DependencyKey -> WorkerThread(thread, success, failure_message)
        self._downloading = ThreadDict(fields={'success': False, 'failure_message': None})
        # Sync states between dependencies-state.json and the dependency directories on the local file system.
        self._sync_state()

        self._stop = False
        self._main_thread = None
        logger.info(f"Initialized Dependency Manager with ID: {self._id}")

    def _sync_state(self):
        """
        Synchronize dependency states between dependencies-state.json and the local file system as follows:
        1. dependencies and paths: populated from dependencies-state.json
        2. directories on the local file system: the bundle contents
    This function forces 1 and 2 to be in sync by taking their intersection (e.g., deleting bundles from the
    local file system that don't appear in dependencies-state.json, and vice versa).
        """
        with self._state_lock:
            # Load states from dependencies-state.json, which contains information about bundles (e.g., state,
            # dependencies, last used, etc.).
            if self._state_committer.state_file_exists:
                # If the state file exists, do not pass in a default. It's critical that we read the contents
                # of the state file, as this method prunes dependencies. If we can't read the contents of the
                # state file, fail immediately.
                dependencies, paths = self._fetch_state()
                logger.info(
                    'Found {} dependencies, {} paths from cache.'.format(
                        len(dependencies), len(paths)
                    )
                )
            else:
                dependencies: Dict[DependencyKey, DependencyState] = dict()
                paths: Set[str] = set()
                logger.info(
                    f'State file did not exist. Will create one at path {self._state_committer.path}.'
                )

            # Compute the intersection of the paths recorded in the loaded state and
            # the dependency directories that actually exist under self.dependencies_dir
            local_directories = set(os.listdir(self.dependencies_dir))
            paths_in_loaded_state = [dep_state.path for dep_state in dependencies.values()]
            paths = paths.intersection(paths_in_loaded_state).intersection(local_directories)

            # Remove the orphaned dependencies if they don't exist in paths
            # (intersection of paths in dependency state, loaded paths and the paths on the local file system)
            dependencies_to_remove = [
                dep for dep, dep_state in dependencies.items() if dep_state.path not in paths
            ]
            for dep in dependencies_to_remove:
                logger.info(
                    "Dependency {} in dependency state but its path {} doesn't exist on the local file system. "
                    "Removing it from dependency state.".format(
                        dep, os.path.join(self.dependencies_dir, dependencies[dep].path)
                    )
                )
                del dependencies[dep]

            # Remove the orphaned directories from the local file system
            directories_to_remove = local_directories - paths
            for directory in directories_to_remove:
                full_path = os.path.join(self.dependencies_dir, directory)
                if os.path.exists(full_path):
                    logger.info(
                        "Remove orphaned directory {} from the local file system.".format(full_path)
                    )
                    remove_path(full_path)

            # Save the current synced state back to the state file (dependencies-state.json),
            # as the state may have changed during the syncing phase above
            self._commit_state(dependencies, paths)
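
    # A small worked example of the syncing above (hypothetical names):
    #   paths from the state file:         {'0xaaa', '0xbbb'}
    #   paths referenced by dependencies:  ['0xaaa', '0xccc']
    #   directories on disk:               {'0xaaa', '0xddd'}
    # The intersection leaves paths == {'0xaaa'}; the dependency whose path is
    # '0xccc' is dropped from state, and the orphaned directory '0xddd' is deleted
    # from disk before the synced state is committed.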

    def _fetch_state(self, default=None):
        """
        Fetch state from JSON file stored on disk.
        NOT NFS-SAFE - Caller should acquire self._state_lock before calling this method.
        WARNING: If a value for `default` is specified, errors will be silently handled.
        """
        assert self._state_lock.is_locked
        state: DependencyManagerState = self._state_committer.load(default)
        dependencies: Dict[DependencyKey, DependencyState] = state['dependencies']
        paths: Set[str] = state['paths']
        return dependencies, paths

    def _fetch_dependencies(self, default=None) -> Dict[DependencyKey, DependencyState]:
        """
        Fetch dependencies from JSON file stored on disk.
        NOT NFS-SAFE - Caller should acquire self._state_lock before calling this method.
        WARNING: If a value for `default` is specified, errors will be silently handled.
        """
        assert self._state_lock.is_locked
        dependencies, _ = self._fetch_state(default)
        return dependencies

    def _commit_state(self, dependencies: Dict[DependencyKey, DependencyState], paths: Set[str]):
        """
        Update state in dependencies JSON file stored on disk.
        NOT NFS-SAFE - Caller should acquire self._state_lock before calling this method.
        """
        assert self._state_lock.is_locked
        state: DependencyManagerState = {'dependencies': dependencies, 'paths': paths}
        self._state_committer.commit(state)

    def start(self):
        logger.info('Starting local dependency manager...')

        def loop(self):
            while not self._stop:
                try:
                    self._transition_dependencies()
                    self._cleanup()
                except Exception:
                    traceback.print_exc()
                time.sleep(1)

        self._main_thread = threading.Thread(target=loop, args=[self])
        self._main_thread.start()

    def stop(self):
        logger.info('Stopping local dependency manager...')
        self._stop = True
        self._downloading.stop()
        self._main_thread.join()
        self._state_lock.release()
        logger.info('Stopped local dependency manager.')

    def _transition_dependencies(self):
        with self._state_lock:
            try:
                dependencies, paths = self._fetch_state()

                # Update the class variable _paths as the transition function may update it
                self._paths = paths
                for dep_key, dep_state in dependencies.items():
                    dependencies[dep_key] = self.transition(dep_state)
                self._commit_state(dependencies, self._paths)
            except (ValueError, EnvironmentError):
                # Log and skip this iteration if the state file can't be read
                logging.exception("Error reading from state file while transitioning dependencies")
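
    # Rough shape of the inherited StateTransitioner dispatch used above (a sketch,
    # not the base class's actual source):
    #
    #   def transition(self, state):
    #       if state.stage in self._terminal_states:       # READY, FAILED
    #           return state                               # terminal: unchanged
    #       return self._transitions[state.stage](state)   # DOWNLOADING -> _transition_from_DOWNLOADING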

    def _prune_failed_dependencies(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds so that further runs
        get to retry the download. Without pruning, any future run depending on a
        failed dependency would automatically fail indefinitely.
        """
        with self._state_lock:
            try:
                dependencies, paths = self._fetch_state()
                failed_deps: Dict[DependencyKey, DependencyState] = {
                    dep_key: dep_state
                    for dep_key, dep_state in dependencies.items()
                    if dep_state.stage == DependencyStage.FAILED
                    and time.time() - dep_state.last_used
                    > DependencyManager.DEPENDENCY_FAILURE_COOLDOWN
                }
                if len(failed_deps) == 0:
                    return

                for dep_key, dep_state in failed_deps.items():
                    self._delete_dependency(dep_key, dependencies, paths)
                self._commit_state(dependencies, paths)
            except (ValueError, EnvironmentError):
                # Log and skip pruning if the state file can't be read
                logging.exception(
                    "Error reading from state file while pruning failed dependencies."
                )

    def _cleanup(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds.
        Limit the disk usage of the dependencies (both the bundle files and the serialized state file size)
        Deletes oldest failed dependencies first and then oldest finished dependencies.
        Doesn't touch downloading dependencies.
        """
        self._prune_failed_dependencies()

        while True:
            with self._state_lock:
                try:
                    dependencies, paths = self._fetch_state()
                except (ValueError, EnvironmentError):
                    # Retry if an error is thrown while reading from the state file
                    logging.exception(
                        "Error reading from state file when cleaning up dependencies. Trying again..."
                    )
                    continue

                bytes_used = sum(dep_state.size_bytes for dep_state in dependencies.values())
                serialized_length = len(codalab.worker.pyjson.dumps(dependencies))
                if (
                    bytes_used > self._max_cache_size_bytes
                    or serialized_length > DependencyManager.MAX_SERIALIZED_LEN
                ):
                    logger.debug(
                        '%d dependencies, disk usage: %s (max %s), serialized size: %s (max %s)',
                        len(dependencies),
                        size_str(bytes_used),
                        size_str(self._max_cache_size_bytes),
                        size_str(serialized_length),
                        DependencyManager.MAX_SERIALIZED_LEN,
                    )
                    ready_deps = {
                        dep_key: dep_state
                        for dep_key, dep_state in dependencies.items()
                        if dep_state.stage == DependencyStage.READY and not dep_state.dependents
                    }
                    failed_deps = {
                        dep_key: dep_state
                        for dep_key, dep_state in dependencies.items()
                        if dep_state.stage == DependencyStage.FAILED
                    }

                    if failed_deps:
                        dep_key_to_remove = min(
                            failed_deps.items(), key=lambda dep: dep[1].last_used
                        )[0]
                    elif ready_deps:
                        dep_key_to_remove = min(
                            ready_deps.items(), key=lambda dep: dep[1].last_used
                        )[0]
                    else:
                        logger.info(
                            'Dependency quota full but there are only downloading dependencies, not cleaning up '
                            'until downloads are over.'
                        )
                        break
                    if dep_key_to_remove:
                        self._delete_dependency(dep_key_to_remove, dependencies, paths)
                        self._commit_state(dependencies, paths)
                else:
                    break
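
    # Eviction order in the loop above, sketched with hypothetical timestamps:
    #   FAILED:               {k1: last_used=100, k2: last_used=50}  -> k2 evicted first
    #   READY, no dependents: {k3: last_used=10}                     -> consulted only once
    #                                                                   no FAILED remain
    # DOWNLOADING dependencies (and READY ones that still have dependents) are never
    # evicted; if only those remain, the loop exits and waits for downloads to finish.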

    def _delete_dependency(self, dep_key, dependencies, paths):
        """
        Remove the given dependency from the manager's state
        Modifies `dependencies` and `paths` that are passed in.
        Also deletes any known files on the filesystem if any exist.

        NOT NFS-SAFE - Caller should acquire self._state_lock before calling this method.
        """
        assert self._state_lock.is_locked

        if dep_key in dependencies:
            try:
                path_to_remove = dependencies[dep_key].path
                paths.remove(path_to_remove)
                # Deletes dependency content from disk
                remove_path(path_to_remove)
            except Exception:
                pass
            finally:
                del dependencies[dep_key]
                logger.info(f"Deleted dependency {dep_key}.")

    def has(self, dependency_key):
        """
        Takes a DependencyKey and returns true if the manager has processed this dependency
        """
        with self._state_lock:
            dependencies: Dict[DependencyKey, DependencyState] = self._fetch_dependencies()
            return dependency_key in dependencies

    def get(self, uuid: str, dependency_key: DependencyKey) -> DependencyState:
        """
        Request the dependency for the run with uuid, registering uuid as a dependent of this dependency
        """
        with self._state_lock:
            dependencies, paths = self._fetch_state()

            now = time.time()
            # Add dependency state if it does not exist
            if dependency_key not in dependencies:
                dependencies[dependency_key] = DependencyState(
                    stage=DependencyStage.DOWNLOADING,
                    downloading_by=None,
                    dependency_key=dependency_key,
                    path=self._assign_path(paths, dependency_key),
                    size_bytes=0,
                    dependents={uuid},
                    last_used=now,
                    last_downloading=now,
                    message="Starting download",
                    killed=False,
                )

            # Update last_used as long as it isn't in a FAILED stage
            if dependencies[dependency_key].stage != DependencyStage.FAILED:
                dependencies[dependency_key].dependents.add(uuid)
                dependencies[dependency_key] = dependencies[dependency_key]._replace(last_used=now)

            self._commit_state(dependencies, paths)
            return dependencies[dependency_key]
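
    # How a caller typically drives get()/release() (mirrors the polling in
    # RunStateMachine._transition_from_PREPARING further down this page; a sketch):
    #
    #   dep_state = manager.get(bundle_uuid, dep_key)   # registers bundle_uuid as a dependent
    #   if dep_state.stage == DependencyStage.DOWNLOADING:
    #       pass                                        # report progress, poll again next tick
    #   elif dep_state.stage == DependencyStage.FAILED:
    #       manager.release(bundle_uuid, dep_key)       # give up on this run
    #   else:                                           # DependencyStage.READY
    #       pass                                        # mount dep_state.path and start the run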

    def release(self, uuid, dependency_key):
        """
        Register that the run with uuid is no longer dependent on this dependency
        If no more runs are dependent on this dependency, kill it.
        """
        with self._state_lock:
            dependencies, paths = self._fetch_state()

            if dependency_key in dependencies:
                dep_state = dependencies[dependency_key]
                if uuid in dep_state.dependents:
                    dep_state.dependents.remove(uuid)
                if not dep_state.dependents:
                    dep_state = dep_state._replace(killed=True)
                    dependencies[dependency_key] = dep_state
                self._commit_state(dependencies, paths)

    def _assign_path(self, paths: Set[str], dependency_key: DependencyKey) -> str:
        """
        Computes a unique directory name for the dependency by flattening path
        separators to '_', checks the result against `paths` to avoid conflicts,
        and records the final name in `paths`.
        """
        path: str = (
            os.path.join(dependency_key.parent_uuid, dependency_key.parent_path)
            if dependency_key.parent_path
            else dependency_key.parent_uuid
        )
        path = path.replace(os.path.sep, '_')

        # You could have a conflict between, for example a/b_c and a_b/c
        while path in paths:
            path = path + '_'

        paths.add(path)
        return path
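
    # Worked example of the collision handling above (hypothetical keys):
    #   DependencyKey(parent_uuid='0xabc', parent_path='a/b_c')  -> '0xabc_a_b_c'
    #   DependencyKey(parent_uuid='0xabc', parent_path='a_b/c')  -> '0xabc_a_b_c_'
    # The second key collides once separators are flattened to '_', so underscores
    # are appended until the name is unused; the final name is recorded in `paths`.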

    def _store_dependency(self, dependency_path, fileobj, target_type):
        """
        Copy the dependency fileobj to its path on the local filesystem
        Overwrite existing files by the same name if found
        (may happen if filesystem modified outside the dependency manager,
         for example during an update if the state gets reset but filesystem
         doesn't get cleared)
        """
        try:
            if os.path.exists(dependency_path):
                logger.info('Path %s already exists, overwriting', dependency_path)
                if os.path.isdir(dependency_path):
                    shutil.rmtree(dependency_path)
                else:
                    os.remove(dependency_path)
            if target_type == 'directory':
                un_tar_directory(fileobj, dependency_path, 'gz')
            else:
                with open(dependency_path, 'wb') as f:
                    logger.debug('copying file to %s', dependency_path)
                    shutil.copyfileobj(fileobj, f)
        except Exception:
            raise

    @property
    def all_dependencies(self) -> List[DependencyKey]:
        with self._state_lock:
            dependencies: Dict[DependencyKey, DependencyState] = self._fetch_dependencies(
                default={'dependencies': {}, 'paths': set()}
            )
            return list(dependencies.keys())

    def _transition_from_DOWNLOADING(self, dependency_state: DependencyState):
        """
        Transition function for the DOWNLOADING stage: starts (or takes over) the
        download, reports progress while the download thread runs, and moves the
        dependency to READY or FAILED once the thread finishes.
        NOT NFS-SAFE - Caller should acquire self._state_lock before calling this method.
        """
        assert self._state_lock.is_locked

        def download():
            """
            Runs in a separate thread. Only one worker should be running this in a thread at a time.
            """

            def update_state_and_check_killed(bytes_downloaded):
                """
                Callback for the bundle service client: updates the dependency state and
                raises DownloadAbortedException if the download is killed by the dependency manager.

                Note: This function needs to be fast, since it's called every time fileobj.read is called.
                      Therefore, we keep a copy of the state in memory (self._downloading) and copy over
                      non-critical fields (last_downloading, size_bytes and message) when the download transition
                      function is executed.
                """
                state = self._downloading[dependency_state.dependency_key]['state']
                if state.killed:
                    raise DownloadAbortedException("Aborted by user")
                self._downloading[dependency_state.dependency_key]['state'] = state._replace(
                    last_downloading=time.time(),
                    size_bytes=bytes_downloaded,
                    message=f"Downloading dependency: {str(bytes_downloaded)} downloaded",
                )

            dependency_path = os.path.join(self.dependencies_dir, dependency_state.path)
            logger.debug('Downloading dependency %s', dependency_state.dependency_key)

            attempt = 0
            while attempt < self._download_dependencies_max_retries:
                try:
                    # Start async download to the fileobj
                    target_type = self._bundle_service.get_bundle_info(
                        dependency_state.dependency_key.parent_uuid,
                        dependency_state.dependency_key.parent_path,
                    )["type"]
                    fileobj = self._bundle_service.get_bundle_contents(
                        dependency_state.dependency_key.parent_uuid,
                        dependency_state.dependency_key.parent_path,
                    )
                    with closing(fileobj):
                        # "Bug" the fileobj's read function so that we can keep
                        # track of the number of bytes downloaded so far.
                        original_read_method = fileobj.read
                        bytes_downloaded = [0]

                        def interruptable_read(*args, **kwargs):
                            data = original_read_method(*args, **kwargs)
                            bytes_downloaded[0] += len(data)
                            update_state_and_check_killed(bytes_downloaded[0])
                            return data

                        fileobj.read = interruptable_read

                        # Start copying the fileobj to filesystem dependency path
                        # Note: Overwrites if something already exists at dependency_path, such as when
                        #       another worker partially downloads a dependency and then goes offline.
                        self._store_dependency(dependency_path, fileobj, target_type)

                    logger.debug(
                        'Finished downloading %s dependency %s to %s',
                        target_type,
                        dependency_state.dependency_key,
                        dependency_path,
                    )
                    self._downloading[dependency_state.dependency_key]['success'] = True

                except Exception as e:
                    attempt += 1
                    if attempt >= self._download_dependencies_max_retries:
                        self._downloading[dependency_state.dependency_key]['success'] = False
                        self._downloading[dependency_state.dependency_key][
                            'failure_message'
                        ] = f"Dependency download failed: {e} "
                    else:
                        logger.warning(
                            f'Failed to download {dependency_state.dependency_key} after {attempt} attempt(s) '
                            f'due to {e}. Retrying up to {self._download_dependencies_max_retries} times...',
                            exc_info=True,
                        )
                else:
                    # Break out of the retry loop if no exceptions were thrown
                    break

        # Start downloading if either:
        # 1. No other dependency manager is downloading the dependency
        # 2. There was a dependency manager downloading a dependency, but it has been longer than
        #    DEPENDENCY_DOWNLOAD_TIMEOUT_SECONDS since it last downloaded anything for the particular dependency.
        now = time.time()
        if (
            not dependency_state.downloading_by
            or now - dependency_state.last_downloading
            >= DependencyManager.DEPENDENCY_DOWNLOAD_TIMEOUT_SECONDS
        ):
            if not dependency_state.downloading_by:
                logger.info(
                    f"{self._id} will start downloading dependency: {dependency_state.dependency_key}."
                )
            else:
                logger.info(
                    f"{dependency_state.downloading_by} stopped downloading "
                    f"dependency: {dependency_state.dependency_key}. {self._id} will restart downloading."
                )

            self._downloading.add_if_new(
                dependency_state.dependency_key, threading.Thread(target=download, args=[])
            )
            self._downloading[dependency_state.dependency_key]['state'] = dependency_state
            dependency_state = dependency_state._replace(downloading_by=self._id)

        # If there is already another worker downloading the dependency,
        # just return the dependency state as downloading is in progress.
        if dependency_state.downloading_by != self._id:
            logger.debug(
                f"Waiting for {dependency_state.downloading_by} "
                f"to download dependency: {dependency_state.dependency_key}"
            )
            return dependency_state

        if (
            dependency_state.dependency_key in self._downloading
            and self._downloading[dependency_state.dependency_key].is_alive()
        ):
            logger.debug(
                f"This dependency manager ({dependency_state.downloading_by}) "
                f"is downloading dependency: {dependency_state.dependency_key}"
            )
            state = self._downloading[dependency_state.dependency_key]['state']
            # Copy over the values of the non-critical fields of the state in memory
            # that is being updated by the download thread.
            return dependency_state._replace(
                last_downloading=state.last_downloading,
                size_bytes=state.size_bytes,
                message=state.message,
            )

        # At this point, no thread is downloading the dependency, but the dependency is still
        # assigned to the current worker. Check if the download finished.
        success: bool = self._downloading[dependency_state.dependency_key]['success']
        failure_message: str = self._downloading[dependency_state.dependency_key]['failure_message']

        dependency_state = dependency_state._replace(downloading_by=None)
        self._downloading.remove(dependency_state.dependency_key)
        logger.info(
            f"Download thread finished. Removing downloading thread for {dependency_state.dependency_key}."
        )

        if success:
            return dependency_state._replace(
                stage=DependencyStage.READY, message="Download complete"
            )
        else:
            self._paths.remove(dependency_state.path)
            logger.error(
                f"Dependency {dependency_state.dependency_key} download failed: {failure_message}"
            )
            return dependency_state._replace(stage=DependencyStage.FAILED, message=failure_message)
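
# A minimal lifecycle sketch for this manager (values are hypothetical, and the
# bundle-service handle is assumed to come from this module's existing imports):
#
#   manager = DependencyManager(
#       commit_file='/var/lib/worker/dependencies-state.json',
#       bundle_service=BundleServiceClient(...),
#       worker_dir='/var/lib/worker',
#       max_cache_size_bytes=10 * 1024 ** 3,        # 10 GiB cache
#       download_dependencies_max_retries=3,
#   )
#   manager.start()                                 # spawns the transition/cleanup loop
#   state = manager.get(run_uuid, dep_key)          # registers run_uuid as a dependent
#   ...poll until state.stage is READY or FAILED...
#   manager.release(run_uuid, dep_key)              # drop the dependent when done
#   manager.stop()                                  # joins the loop thread, releases the lock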
Example #2
class DependencyManager(StateTransitioner, BaseDependencyManager):
    """
    This dependency manager downloads dependency bundles from the CodaLab server
    to the local filesystem. It caches all downloaded dependencies but cleans up the
    old ones if disk usage hits the given threshold.

    For this class, dependencies are uniquely identified by DependencyKey.
    """

    DEPENDENCIES_DIR_NAME = 'dependencies'
    DEPENDENCY_FAILURE_COOLDOWN = 10
    # TODO(bkgoksel): The server writes these to the worker_dependencies table, which stores the dependencies
    # json as a SqlAlchemy LargeBinary, which defaults to MySQL BLOB, which has a size limit of
    # 65K. For now we limit this value to about 58K to avoid any issues but we probably want to do
    # something better (either specify MEDIUMBLOB in the SqlAlchemy definition of the table or change
    # the data format of how we store this)
    MAX_SERIALIZED_LEN = 60000

    def __init__(self, commit_file, bundle_service, worker_dir,
                 max_cache_size_bytes):
        super(DependencyManager, self).__init__()
        self.add_transition(DependencyStage.DOWNLOADING,
                            self._transition_from_DOWNLOADING)
        self.add_terminal(DependencyStage.READY)
        self.add_terminal(DependencyStage.FAILED)

        self._state_committer = JsonStateCommitter(commit_file)
        self._bundle_service = bundle_service
        self._max_cache_size_bytes = max_cache_size_bytes
        self.dependencies_dir = os.path.join(
            worker_dir, DependencyManager.DEPENDENCIES_DIR_NAME)
        if not os.path.exists(self.dependencies_dir):
            logger.info('{} doesn\'t exist, creating.'.format(
                self.dependencies_dir))
            os.makedirs(self.dependencies_dir, 0o770)

        # Locks for concurrency
        self._dependency_locks = dict(
        )  # type: Dict[DependencyKey, threading.RLock]
        self._global_lock = threading.RLock()  # Used for add/remove actions
        self._paths_lock = threading.RLock()  # Used for path name computations

        # File paths that are currently being used to store dependencies. Used to prevent conflicts
        self._paths = set()
        # DependencyKey -> DependencyState
        self._dependencies = dict()
        # DependencyKey -> WorkerThread(thread, success, failure_message)
        self._downloading = ThreadDict(fields={
            'success': False,
            'failure_message': None
        })
        self._load_state()
        # Sync states between dependencies-state.json and the dependency directories on the local file system.
        self._sync_state()

        self._stop = False
        self._main_thread = None

    def _save_state(self):
        with self._global_lock, self._paths_lock:
            self._state_committer.commit({
                'dependencies': self._dependencies,
                'paths': self._paths
            })

    def _load_state(self):
        """
        Load states from dependencies-state.json, which contains information about bundles (e.g., state, dependencies,
        last used, etc.) and populates values for self._dependencies, self._dependency_locks, and self._paths
        """
        state = self._state_committer.load(default={
            'dependencies': {},
            'paths': set()
        })

        dependencies = {}
        dependency_locks = {}

        for dep, dep_state in state['dependencies'].items():
            dependencies[dep] = dep_state
            dependency_locks[dep] = threading.RLock()

        with self._global_lock, self._paths_lock:
            self._dependencies = dependencies
            self._dependency_locks = dependency_locks
            self._paths = state['paths']

        logger.info('Loaded {} dependencies, {} paths from cache.'.format(
            len(self._dependencies), len(self._paths)))

    def _sync_state(self):
        """
        Synchronize dependency states between dependencies-state.json and the local file system as follows:
        1. self._dependencies, self._dependency_locks, and self._paths: populated from dependencies-state.json
            in function _load_state()
        2. directories on the local file system: the bundle contents
        This function forces the 1 and 2 to be in sync by taking the intersection (e.g., deleting bundles from the
        local file system that don't appear in the dependencies-state.json and vice-versa)
        """
        # Get the paths that exist in dependency state, loaded path and
        # the local file system (the dependency directories under self.dependencies_dir)
        local_directories = set(os.listdir(self.dependencies_dir))
        paths_in_loaded_state = [
            dep_state.path for dep_state in self._dependencies.values()
        ]
        self._paths = self._paths.intersection(
            paths_in_loaded_state).intersection(local_directories)

        # Remove the orphaned dependencies from self._dependencies and
        # self._dependency_locks if they don't exist in self._paths (intersection of paths in dependency state,
        # loaded paths and the paths on the local file system)
        dependencies_to_remove = [
            dep for dep, dep_state in self._dependencies.items()
            if dep_state.path not in self._paths
        ]
        for dep in dependencies_to_remove:
            logger.info(
                "Dependency {} in dependency state but its path {} doesn't exist on the local file system. "
                "Removing it from dependency state.".format(
                    dep,
                    os.path.join(self.dependencies_dir,
                                 self._dependencies[dep].path)))
            del self._dependencies[dep]
            del self._dependency_locks[dep]

        # Remove the orphaned directories from the local file system
        directories_to_remove = local_directories - self._paths
        for directory in directories_to_remove:
            full_path = os.path.join(self.dependencies_dir, directory)
            logger.info(
                "Remove orphaned directory {} from the local file system.".
                format(full_path))
            remove_path(full_path)

        # Save the current synced state back to the state file (dependencies-state.json),
        # as the state may have changed during the syncing phase above
        self._save_state()

    def start(self):
        logger.info('Starting local dependency manager')

        def loop(self):
            while not self._stop:
                try:
                    self._process_dependencies()
                    self._save_state()
                    self._cleanup()
                    self._save_state()
                except Exception:
                    traceback.print_exc()
                time.sleep(1)

        self._main_thread = threading.Thread(target=loop, args=[self])
        self._main_thread.start()

    def stop(self):
        logger.info('Stopping local dependency manager')
        self._stop = True
        self._downloading.stop()
        self._main_thread.join()
        logger.info('Stopped local dependency manager')

    def _process_dependencies(self):
        for dep_key, dep_state in self._dependencies.items():
            with self._dependency_locks[dep_key]:
                self._dependencies[dep_key] = self.transition(dep_state)

    def _prune_failed_dependencies(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds so that further runs
        get to retry the download. Without pruning, any future run depending on a
        failed dependency would automatically fail indefinitely.
        """
        with self._global_lock:
            self._acquire_all_locks()
            failed_deps = {
                dep_key: dep_state
                for dep_key, dep_state in self._dependencies.items()
                if dep_state.stage == DependencyStage.FAILED and time.time() -
                dep_state.last_used >
                DependencyManager.DEPENDENCY_FAILURE_COOLDOWN
            }
            for dep_key, dep_state in failed_deps.items():
                self._delete_dependency(dep_key)
            self._release_all_locks()

    def _cleanup(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds.
        Limit the disk usage of the dependencies (both the bundle files and the serialized state file size)
        Deletes oldest failed dependencies first and then oldest finished dependencies.
        Doesn't touch downloading dependencies.
        """
        self._prune_failed_dependencies()
        # Hold all the locks (this should be fast if no cleanup is needed); otherwise,
        # make sure nothing gets corrupted while dependencies are deleted.
        while True:
            with self._global_lock:
                self._acquire_all_locks()
                bytes_used = sum(dep_state.size_bytes
                                 for dep_state in self._dependencies.values())
                serialized_length = len(
                    codalab.worker.pyjson.dumps(self._dependencies))
                if (bytes_used > self._max_cache_size_bytes
                        or serialized_length >
                        DependencyManager.MAX_SERIALIZED_LEN):
                    logger.debug(
                        '%d dependencies in cache, disk usage: %s (max %s), serialized size: %s (max %s)',
                        len(self._dependencies),
                        size_str(bytes_used),
                        size_str(self._max_cache_size_bytes),
                        size_str(serialized_length),
                        DependencyManager.MAX_SERIALIZED_LEN,
                    )
                    ready_deps = {
                        dep_key: dep_state
                        for dep_key, dep_state in self._dependencies.items()
                        if dep_state.stage == DependencyStage.READY
                        and not dep_state.dependents
                    }
                    failed_deps = {
                        dep_key: dep_state
                        for dep_key, dep_state in self._dependencies.items()
                        if dep_state.stage == DependencyStage.FAILED
                    }
                    if failed_deps:
                        dep_key_to_remove = min(
                            failed_deps.items(),
                            key=lambda dep: dep[1].last_used)[0]
                    elif ready_deps:
                        dep_key_to_remove = min(
                            ready_deps.items(),
                            key=lambda dep: dep[1].last_used)[0]
                    else:
                        logger.info(
                            'Dependency quota full but there are only downloading dependencies, not cleaning up until downloads are over'
                        )
                        self._release_all_locks()
                        break
                    if dep_key_to_remove:
                        self._delete_dependency(dep_key_to_remove)
                    self._release_all_locks()
                else:
                    self._release_all_locks()
                    break

    def _delete_dependency(self, dependency_key):
        """
        Remove the given dependency from the manager's state
        Also delete any known files on the filesystem if any exist
        """
        if self._acquire_if_exists(dependency_key):
            try:
                path_to_remove = self._dependencies[dependency_key].path
                self._paths.remove(path_to_remove)
                remove_path(path_to_remove)
            except Exception:
                pass
            finally:
                del self._dependencies[dependency_key]
                self._dependency_locks[dependency_key].release()

    def has(self, dependency_key):
        """
        Takes a DependencyKey
        Returns true if the manager has processed this dependency
        """
        with self._global_lock:
            return dependency_key in self._dependencies

    def get(self, uuid, dependency_key):
        """
        Request the dependency for the run with uuid, registering uuid as a dependent of this dependency
        """
        now = time.time()
        if not self._acquire_if_exists(
                dependency_key):  # add dependency state if it does not exist
            with self._global_lock:
                self._dependency_locks[dependency_key] = threading.RLock()
                self._dependency_locks[dependency_key].acquire()
                self._dependencies[dependency_key] = DependencyState(
                    stage=DependencyStage.DOWNLOADING,
                    dependency_key=dependency_key,
                    path=self._assign_path(dependency_key),
                    size_bytes=0,
                    dependents=set([uuid]),
                    last_used=now,
                    message="Starting download",
                    killed=False,
                )

        # update last_used as long as it isn't in FAILED
        if self._dependencies[dependency_key].stage != DependencyStage.FAILED:
            self._dependencies[dependency_key].dependents.add(uuid)
            self._dependencies[dependency_key] = self._dependencies[
                dependency_key]._replace(last_used=now)
        self._dependency_locks[dependency_key].release()
        return self._dependencies[dependency_key]

    def release(self, uuid, dependency_key):
        """
        Register that the run with uuid is no longer dependent on this dependency
        If no more runs are dependent on this dependency, kill it
        """
        if self._acquire_if_exists(dependency_key):
            dep_state = self._dependencies[dependency_key]
            if uuid in dep_state.dependents:
                dep_state.dependents.remove(uuid)
            if not dep_state.dependents:
                dep_state = dep_state._replace(killed=True)
                self._dependencies[dependency_key] = dep_state
            self._dependency_locks[dependency_key].release()

    def _acquire_if_exists(self, dependency_key):
        """
        Safely acquires a lock for the given dependency if it exists
        Returns True if the dependency exists, False otherwise
        Callers should remember to release the lock
        """
        with self._global_lock:
            if dependency_key in self._dependencies:
                self._dependency_locks[dependency_key].acquire()
                return True
            else:
                return False
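
    # Usage sketch for the acquire/release pairing above. The methods in this class
    # release inline; a try/finally variant (hypothetical) would also survive an
    # exception thrown between acquire and release:
    #
    #   if self._acquire_if_exists(dep_key):
    #       try:
    #           ...mutate self._dependencies[dep_key]...
    #       finally:
    #           self._dependency_locks[dep_key].release()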

    def _acquire_all_locks(self):
        """
        Acquires all dependency locks in the thread it's called from
        """
        with self._global_lock:
            for dependency, lock in self._dependency_locks.items():
                lock.acquire()

    def _release_all_locks(self):
        """
        Releases all dependency locks in the thread it's called from
        """
        with self._global_lock:
            for dependency, lock in self._dependency_locks.items():
                lock.release()

    def _assign_path(self, dependency_key):
        """
        Normalize the path for the dependency by replacing / with _, avoiding conflicts
        """
        if dependency_key.parent_path:
            path = os.path.join(dependency_key.parent_uuid,
                                dependency_key.parent_path)
        else:
            path = dependency_key.parent_uuid
        path = path.replace(os.path.sep, '_')

        # You could have a conflict between, for example a/b_c and
        # a_b/c. We have to avoid those.
        with self._paths_lock:
            while path in self._paths:
                path = path + '_'
            self._paths.add(path)
        return path

    def _store_dependency(self, dependency_path, fileobj, target_type):
        """
        Copy the dependency fileobj to its path on the local filesystem
        Overwrite existing files by the same name if found
        (may happen if filesystem modified outside the dependency manager,
         for example during an update if the state gets reset but filesystem
         doesn't get cleared)
        """
        try:
            if os.path.exists(dependency_path):
                logger.info('Path %s already exists, overwriting',
                            dependency_path)
                if os.path.isdir(dependency_path):
                    shutil.rmtree(dependency_path)
                else:
                    os.remove(dependency_path)
            if target_type == 'directory':
                un_tar_directory(fileobj, dependency_path, 'gz')
            else:
                with open(dependency_path, 'wb') as f:
                    logger.debug('copying file to %s', dependency_path)
                    shutil.copyfileobj(fileobj, f)
        except Exception:
            raise

    @property
    def all_dependencies(self):
        with self._global_lock:
            return list(self._dependencies.keys())

    def _transition_from_DOWNLOADING(self, dependency_state):
        def download():
            def update_state_and_check_killed(bytes_downloaded):
                """
                Callback for the bundle service client: updates the dependency state and
                raises DownloadAbortedException if the download is killed by the dependency manager.
                """
                with self._dependency_locks[dependency_state.dependency_key]:
                    state = self._dependencies[dependency_state.dependency_key]
                    if state.killed:
                        raise DownloadAbortedException("Aborted by user")
                    self._dependencies[
                        dependency_state.dependency_key] = state._replace(
                            size_bytes=bytes_downloaded,
                            message="Downloading dependency: %s downloaded" %
                            size_str(bytes_downloaded),
                        )

            dependency_path = os.path.join(self.dependencies_dir,
                                           dependency_state.path)
            logger.debug('Downloading dependency %s',
                         dependency_state.dependency_key)
            try:
                # Start async download to the fileobj
                fileobj, target_type = self._bundle_service.get_bundle_contents(
                    dependency_state.dependency_key.parent_uuid,
                    dependency_state.dependency_key.parent_path,
                )
                with closing(fileobj):
                    # "Bug" the fileobj's read function so that we can keep
                    # track of the number of bytes downloaded so far.
                    old_read_method = fileobj.read
                    bytes_downloaded = [0]

                    def interruptable_read(*args, **kwargs):
                        data = old_read_method(*args, **kwargs)
                        bytes_downloaded[0] += len(data)
                        update_state_and_check_killed(bytes_downloaded[0])
                        return data

                    fileobj.read = interruptable_read

                    # Start copying the fileobj to filesystem dependency path
                    self._store_dependency(dependency_path, fileobj,
                                           target_type)

                logger.debug(
                    'Finished downloading %s dependency %s to %s',
                    target_type,
                    dependency_state.dependency_key,
                    dependency_path,
                )
                with self._dependency_locks[dependency_state.dependency_key]:
                    self._downloading[
                        dependency_state.dependency_key]['success'] = True

            except Exception as e:
                with self._dependency_locks[dependency_state.dependency_key]:
                    self._downloading[
                        dependency_state.dependency_key]['success'] = False
                    self._downloading[dependency_state.dependency_key][
                        'failure_message'] = "Dependency download failed: %s " % str(
                            e)

        self._downloading.add_if_new(
            dependency_state.dependency_key,
            threading.Thread(target=download, args=[]))

        if self._downloading[dependency_state.dependency_key].is_alive():
            return dependency_state

        success = self._downloading[dependency_state.dependency_key]['success']
        failure_message = self._downloading[
            dependency_state.dependency_key]['failure_message']

        self._downloading.remove(dependency_state.dependency_key)
        if success:
            return dependency_state._replace(stage=DependencyStage.READY,
                                             message="Download complete")
        else:
            with self._paths_lock:
                self._paths.remove(dependency_state.path)
            return dependency_state._replace(stage=DependencyStage.FAILED,
                                             message=failure_message)
Example #3
class RunStateMachine(StateTransitioner):
    """
    Manages the state machine of the runs running on the local machine

    Note that in general there are two types of errors:
    - User errors (fault of bundle) - we fail the bundle (move to CLEANING_UP state).
    - System errors (fault of worker) - we freeze this worker (Exception is thrown up).
    It's not always clear where the line is.
    """

    _ROOT = '/'
    _CURRENT_DIRECTORY = '.'
    RESTAGED_REASON = 'The bundle is not in terminal states {READY, FAILED, KILLED} when the worker checks termination'

    def __init__(
            self,
            image_manager,  # Component to request docker images from
            dependency_manager,  # Component to request dependency downloads from
            worker_docker_network,  # Docker network to add all bundles to
            docker_network_internal,  # Docker network to add non-net connected bundles to
            docker_network_external,  # Docker network to add internet connected bundles to
            docker_runtime,  # Docker runtime to use for containers (nvidia or runc)
            upload_bundle_callback,  # Function to call to upload bundle results to the server
            assign_cpu_and_gpu_sets_fn,  # Function to call to assign CPU and GPU resources to each run
            shared_file_system,  # If True, bundle mount is shared with server
    ):
        super(RunStateMachine, self).__init__()
        self.add_transition(RunStage.PREPARING,
                            self._transition_from_PREPARING)
        self.add_transition(RunStage.RUNNING, self._transition_from_RUNNING)
        self.add_transition(RunStage.CLEANING_UP,
                            self._transition_from_CLEANING_UP)
        self.add_transition(RunStage.UPLOADING_RESULTS,
                            self._transition_from_UPLOADING_RESULTS)
        self.add_transition(RunStage.FINALIZING,
                            self._transition_from_FINALIZING)
        self.add_terminal(RunStage.FINISHED)
        self.add_terminal(RunStage.RESTAGED)

        self.dependency_manager = dependency_manager
        self.image_manager = image_manager
        self.worker_docker_network = worker_docker_network
        self.docker_network_external = docker_network_external
        self.docker_network_internal = docker_network_internal
        # todo aditya: docker_runtime will be None if the worker is a singularity worker. handle this.
        self.docker_runtime = docker_runtime
        # bundle.uuid -> {'thread': Thread, 'run_status': str}
        self.uploading = ThreadDict(fields={
            'run_status': 'Upload started',
            'success': False
        })
        # bundle.uuid -> {'thread': Thread, 'disk_utilization': int, 'running': bool}
        self.disk_utilization = ThreadDict(fields={
            'disk_utilization': 0,
            'running': True,
            'lock': None
        })
        self.upload_bundle_callback = upload_bundle_callback
        self.assign_cpu_and_gpu_sets_fn = assign_cpu_and_gpu_sets_fn
        self.shared_file_system = shared_file_system
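
        # Stage pipeline wired up above, as a sketch:
        #   PREPARING -> RUNNING -> CLEANING_UP -> UPLOADING_RESULTS -> FINALIZING -> FINISHED
        # RESTAGED is the other terminal stage; killed or restaged bundles short-circuit
        # from PREPARING straight to CLEANING_UP (see _transition_from_PREPARING below).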

    def stop(self):
        for uuid in self.disk_utilization.keys():
            self.disk_utilization[uuid]['running'] = False
        self.disk_utilization.stop()
        self.uploading.stop()

    def _transition_from_PREPARING(self, run_state):
        """
        1- Request the docker image from docker image manager
            - if image is failed, move to CLEANING_UP state
        2- Request the dependencies from dependency manager
            - if any are failed, move to CLEANING_UP state
        3- If all dependencies and docker image are ready:
            - Set up the local filesystem for the run
            - Create symlinks to dependencies
            - Allocate resources and prepare the docker container
            - Start the docker container
        4- If all is successful, move to RUNNING state
        """
        def mount_dependency(dependency, shared_file_system):
            if not shared_file_system:
                # Set up symlinks for the content at dependency path
                Path(dependency.child_path).parent.mkdir(parents=True,
                                                         exist_ok=True)
                os.symlink(dependency.docker_path, dependency.child_path)
            # The following will be converted into a Docker volume binding like:
            #   dependency_path:docker_dependency_path:ro
            docker_dependencies.append(
                (dependency.parent_path, dependency.docker_path))

        if run_state.is_killed or run_state.is_restaged:
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.CLEANING_UP,
                reason=
                f'the bundle was {"killed" if run_state.is_killed else "restaged"}',
            )
            return run_state._replace(stage=RunStage.CLEANING_UP)

        # Check CPU and GPU availability
        try:
            cpuset, gpuset = self.assign_cpu_and_gpu_sets_fn(
                run_state.resources.cpus, run_state.resources.gpus)
        except Exception as e:
            message = "Unexpectedly unable to assign enough resources to bundle {}: {}".format(
                run_state.bundle.uuid, str(e))
            logger.error(message)
            logger.error(traceback.format_exc())
            return run_state._replace(run_status=message)

        dependencies_ready = True
        status_messages = []

        if not self.shared_file_system:
            # No need to download dependencies if we're in the shared FS,
            # since they're already in our FS
            for dep in run_state.bundle.dependencies:
                dep_key = DependencyKey(dep.parent_uuid, dep.parent_path)
                dependency_state = self.dependency_manager.get(
                    run_state.bundle.uuid, dep_key)
                if dependency_state.stage == DependencyStage.DOWNLOADING:
                    status_messages.append(
                        'Downloading dependency %s: %s done (archived size)' %
                        (dep.child_path, size_str(
                            dependency_state.size_bytes)))
                    dependencies_ready = False
                elif dependency_state.stage == DependencyStage.FAILED:
                    # Failed to download dependency; -> CLEANING_UP
                    log_bundle_transition(
                        bundle_uuid=run_state.bundle.uuid,
                        previous_stage=run_state.stage,
                        next_stage=RunStage.CLEANING_UP,
                        reason=
                        f'Dependency has failed for this bundle. Dependency child uuid: {dep.child_uuid}. Dependency child path: {dep.child_path}',
                    )
                    return run_state._replace(
                        stage=RunStage.CLEANING_UP,
                        failure_message='Failed to download dependency %s: %s'
                        % (dep.child_path, dependency_state.message),
                    )

        # get the docker image
        docker_image = run_state.resources.docker_image
        image_state = self.image_manager.get(docker_image)
        if image_state.stage == DependencyStage.DOWNLOADING:
            status_messages.append('Pulling docker image %s %s' %
                                   (docker_image, image_state.message))
            dependencies_ready = False
        elif image_state.stage == DependencyStage.FAILED:
            # Failed to pull image; -> CLEANING_UP
            message = 'Failed to download Docker image: %s' % image_state.message
            logger.error(message)
            return run_state._replace(stage=RunStage.CLEANING_UP,
                                      failure_message=message)

        # stop proceeding if dependency and image downloads aren't all done
        if not dependencies_ready:
            status_message = status_messages.pop()
            if status_messages:
                status_message += "(and downloading %d other dependencies and docker images)" % len(
                    status_messages)
            logger.info(
                f'bundle is not ready yet. uuid: {run_state.bundle.uuid}. status message: {status_message}'
            )
            return run_state._replace(run_status=status_message)

        # All dependencies ready! Set up directories, symlinks and container. Start container.
        # 1) Set up a directory to store the bundle.
        if self.shared_file_system:
            if not os.path.exists(run_state.bundle_path):
                if run_state.bundle_dir_wait_num_tries == 0:
                    message = (
                        "Bundle directory cannot be found on the shared filesystem. "
                        "Please ensure the shared fileystem between the server and "
                        "your worker is mounted properly or contact your administrators."
                    )
                    log_bundle_transition(
                        bundle_uuid=run_state.bundle.uuid,
                        previous_stage=run_state.stage,
                        next_stage=RunStage.CLEANING_UP,
                        reason=
                        "Bundle directory cannot be found on the shared filesystem.",
                    )
                    return run_state._replace(stage=RunStage.CLEANING_UP,
                                              failure_message=message)
                next_bundle_dir_wait_num_tries = run_state.bundle_dir_wait_num_tries - 1
                logger.info(
                    f'Waiting for bundle directory to be created by the server, uuid: {run_state.bundle.uuid}, bundle_dir_wait_num_tries: {next_bundle_dir_wait_num_tries}'
                )
                return run_state._replace(
                    run_status=
                    "Waiting for bundle directory to be created by the server",
                    bundle_dir_wait_num_tries=next_bundle_dir_wait_num_tries,
                )
        else:
            remove_path(run_state.bundle_path)
            os.makedirs(run_state.bundle_path)

        # 2) Set up symlinks
        docker_dependencies = []
        docker_dependencies_path = (
            RunStateMachine._ROOT + run_state.bundle.uuid +
            ('_dependencies' if not self.shared_file_system else ''))

        for dep in run_state.bundle.dependencies:
            full_child_path = os.path.normpath(
                os.path.join(run_state.bundle_path, dep.child_path))
            to_mount = []
            dependency_path = self._get_dependency_path(run_state, dep)

            if dep.child_path == RunStateMachine._CURRENT_DIRECTORY:
                # Mount all the content of the dependency_path to the top-level of the bundle
                for child in os.listdir(dependency_path):
                    child_path = os.path.normpath(
                        os.path.join(run_state.bundle_path, child))
                    to_mount.append(
                        DependencyToMount(
                            docker_path=os.path.join(docker_dependencies_path,
                                                     child),
                            child_path=child_path,
                            parent_path=os.path.join(dependency_path, child),
                        ))
                    run_state = run_state._replace(
                        paths_to_remove=(run_state.paths_to_remove or []) +
                        [child_path])
            else:
                to_mount.append(
                    DependencyToMount(
                        docker_path=os.path.join(docker_dependencies_path,
                                                 dep.child_path),
                        child_path=full_child_path,
                        parent_path=dependency_path,
                    ))

                first_element_of_path = Path(dep.child_path).parts[0]
                if first_element_of_path == RunStateMachine._ROOT:
                    run_state = run_state._replace(
                        paths_to_remove=(run_state.paths_to_remove or []) +
                        [full_child_path])
                else:
                    # child_path can be a nested path, so later remove everything from the first element of the path
                    path_to_remove = os.path.join(run_state.bundle_path,
                                                  first_element_of_path)
                    run_state = run_state._replace(
                        paths_to_remove=(run_state.paths_to_remove or []) +
                        [path_to_remove])
            for dependency in to_mount:
                try:
                    mount_dependency(dependency, self.shared_file_system)
                except OSError as e:
                    log_bundle_transition(
                        bundle_uuid=run_state.bundle.uuid,
                        previous_stage=run_state.stage,
                        next_stage=RunStage.CLEANING_UP,
                        reason=str(e.__class__),
                        level=logging.ERROR,
                    )
                    return run_state._replace(stage=RunStage.CLEANING_UP,
                                              failure_message=str(e))

        if run_state.resources.network:
            docker_network = self.docker_network_external.name
        else:
            docker_network = self.docker_network_internal.name

        # 3) Start container
        try:
            container = docker_utils.start_bundle_container(
                run_state.bundle_path,
                run_state.bundle.uuid,
                docker_dependencies,
                run_state.bundle.command,
                run_state.resources.docker_image,
                network=docker_network,
                cpuset=cpuset,
                gpuset=gpuset,
                memory_bytes=run_state.resources.memory,
                runtime=self.docker_runtime,
            )
            self.worker_docker_network.connect(container)
        except docker_utils.DockerUserErrorException as e:
            message = 'Cannot start Docker container: {}'.format(e)
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.CLEANING_UP,
                reason='Cannot start Docker container.',
                level=logging.ERROR,
            )
            return run_state._replace(stage=RunStage.CLEANING_UP,
                                      failure_message=message)
        except Exception as e:
            message = 'Cannot start container: {}'.format(e)
            logger.error(message)
            logger.error(traceback.format_exc())
            raise

        return run_state._replace(
            stage=RunStage.RUNNING,
            run_status='Running job in container',
            container_id=container.id,
            container=container,
            docker_image=image_state.digest,
            has_contents=True,
            cpuset=cpuset,
            gpuset=gpuset,
        )

    def _get_dependency_path(self, run_state, dependency):
        if self.shared_file_system:
            # TODO(Ashwin): make this not fs-specific.
            # On a shared FS, we know where the dependency is stored and can get the contents directly
            return os.path.realpath(
                os.path.join(dependency.location, dependency.parent_path))
        else:
            # On a dependency_manager setup, ask the manager where the dependency is
            dep_key = DependencyKey(dependency.parent_uuid,
                                    dependency.parent_path)
            return os.path.join(
                self.dependency_manager.dependencies_dir,
                self.dependency_manager.get(run_state.bundle.uuid,
                                            dep_key).path,
            )

    def _transition_from_RUNNING(self, run_state):
        """
        1- Check run status of the docker container
        2- If run is killed, kill the container
        3- If run is finished, move to CLEANING_UP state
        """
        def check_and_report_finished(run_state):
            try:
                finished, exitcode, failure_msg = docker_utils.check_finished(
                    run_state.container)
            except docker_utils.DockerException:
                logger.error(traceback.format_exc())
                finished, exitcode, failure_msg = False, None, None
            return run_state._replace(finished=finished,
                                      exitcode=exitcode,
                                      failure_message=failure_msg)

        def check_resource_utilization(run_state: RunState):
            logger.info(
                f'Checking resource utilization for bundle. uuid: {run_state.bundle.uuid}'
            )
            cpu_usage, memory_usage = docker_utils.get_container_stats_with_docker_stats(
                run_state.container)
            run_state = run_state._replace(cpu_usage=cpu_usage,
                                           memory_usage=memory_usage)

            kill_messages = []

            run_stats = docker_utils.get_container_stats(run_state.container)

            run_state = run_state._replace(max_memory=max(
                run_state.max_memory, run_stats.get('memory', 0)))
            run_state = run_state._replace(
                disk_utilization=self.disk_utilization[
                    run_state.bundle.uuid]['disk_utilization'])

            container_time_total = docker_utils.get_container_running_time(
                run_state.container)
            run_state = run_state._replace(
                container_time_total=container_time_total,
                container_time_user=run_stats.get(
                    'container_time_user', run_state.container_time_user),
                container_time_system=run_stats.get(
                    'container_time_system', run_state.container_time_system),
            )

            if run_state.resources.time and container_time_total > run_state.resources.time:
                kill_messages.append(
                    'Time limit exceeded. (Container uptime %s > time limit %s)'
                    % (duration_str(container_time_total),
                       duration_str(run_state.resources.time)))

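            # Exit code 137 corresponds to SIGKILL (128 + 9), which is what the
            # container's process receives when it exceeds its memory limit and
            # the kernel's OOM killer steps in.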
            if run_state.max_memory > run_state.resources.memory or run_state.exitcode == 137:
                kill_messages.append('Memory limit %s exceeded.' %
                                     size_str(run_state.resources.memory))

            if run_state.resources.disk and run_state.disk_utilization > run_state.resources.disk:
                kill_messages.append('Disk limit %sb exceeded.' %
                                     size_str(run_state.resources.disk))

            if kill_messages:
                run_state = run_state._replace(
                    kill_message=' '.join(kill_messages), is_killed=True)
            return run_state

        def check_disk_utilization():
            logger.info(
                f'Checking disk utilization for bundle. uuid: {run_state.bundle.uuid}'
            )
            running = True
            while running:
                start_time = time.time()
                try:
                    disk_utilization = get_path_size(run_state.bundle_path)
                    self.disk_utilization[run_state.bundle.uuid][
                        'disk_utilization'] = disk_utilization
                    running = self.disk_utilization[
                        run_state.bundle.uuid]['running']
                except Exception:
                    logger.error(traceback.format_exc())
                end_time = time.time()

                # To ensure that we don't hammer the disk for this computation when
                # there are lots of files, we run it at most 10% of the time.
                time.sleep(max((end_time - start_time) * 10, 1.0))
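                # e.g. a scan that takes 30s is followed by a 300s sleep, so the
                # scan occupies at most 30/330 (about 9%) of wall-clock time.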

        self.disk_utilization.add_if_new(
            run_state.bundle.uuid,
            threading.Thread(target=check_disk_utilization, args=[]))
        run_state = check_and_report_finished(run_state)
        run_state = check_resource_utilization(run_state)

        if run_state.is_killed or run_state.is_restaged:
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.CLEANING_UP,
                reason=
                f'the bundle was {"killed" if run_state.is_killed else "restaged"}',
            )
            if docker_utils.container_exists(run_state.container):
                try:
                    run_state.container.kill()
                except docker.errors.APIError:
                    finished, _, _ = docker_utils.check_finished(
                        run_state.container)
                    if not finished:
                        logger.error(traceback.format_exc())
            self.disk_utilization[run_state.bundle.uuid]['running'] = False
            self.disk_utilization.remove(run_state.bundle.uuid)
            return run_state._replace(stage=RunStage.CLEANING_UP)
        if run_state.finished:
            logger.debug(
                'Finished run with UUID %s, exitcode %s, failure_message %s',
                run_state.bundle.uuid,
                run_state.exitcode,
                run_state.failure_message,
            )
            self.disk_utilization[run_state.bundle.uuid]['running'] = False
            self.disk_utilization.remove(run_state.bundle.uuid)
            return run_state._replace(stage=RunStage.CLEANING_UP,
                                      run_status='Uploading results')
        else:
            return run_state

    def _transition_from_CLEANING_UP(self, run_state):
        """
        1- delete the container if still existent
        2- clean up the dependencies from bundle directory
        3- release the dependencies in dependency manager
        4- If bundle has contents to upload (i.e. was RUNNING at some point),
            move to UPLOADING_RESULTS state
           Otherwise move to FINALIZING state
        """
        def remove_path_no_fail(path):
            try:
                remove_path(path)
            except Exception:
                logger.error(traceback.format_exc())

        if run_state.container_id is not None:
            while docker_utils.container_exists(run_state.container):
                try:
                    finished, _, _ = docker_utils.check_finished(
                        run_state.container)
                    if finished:
                        run_state.container.remove(force=True)
                        run_state = run_state._replace(container=None,
                                                       container_id=None)
                        break
                    else:
                        try:
                            run_state.container.kill()
                        except docker.errors.APIError:
                            logger.error(traceback.format_exc())
                            time.sleep(1)
                except docker.errors.APIError:
                    logger.error(traceback.format_exc())
                    time.sleep(1)

        for dep in run_state.bundle.dependencies:
            if not self.shared_file_system:  # No dependencies if shared fs worker
                dep_key = DependencyKey(dep.parent_uuid, dep.parent_path)
                self.dependency_manager.release(run_state.bundle.uuid, dep_key)

        # Clean up dependencies paths
        for path in run_state.paths_to_remove or []:
            remove_path_no_fail(path)
        run_state = run_state._replace(paths_to_remove=[])

        if run_state.is_restaged:
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.RESTAGED,
                reason=self.RESTAGED_REASON,
            )
            return run_state._replace(stage=RunStage.RESTAGED)

        if not self.shared_file_system and run_state.has_contents:
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.UPLOADING_RESULTS,
            )
            return run_state._replace(stage=RunStage.UPLOADING_RESULTS,
                                      run_status='Uploading results',
                                      container=None)
        else:
            # No need to upload results since results are directly written to the bundle store.
            # Delete any files that match the exclude_patterns.
            for exclude_pattern in run_state.bundle.metadata[
                    "exclude_patterns"]:
                full_pattern = os.path.join(run_state.bundle_path,
                                            exclude_pattern)
                for file_path in glob.glob(full_pattern, recursive=True):
                    # Only remove files that are subpaths of run_state.bundle_path, in case
                    # that exclude_pattern is something like "../../../".
                    if path_is_parent(parent_path=run_state.bundle_path,
                                      child_path=file_path):
                        remove_path(file_path)
            return self.finalize_run(run_state)

    def _transition_from_UPLOADING_RESULTS(self, run_state):
        """
        If bundle not already uploading:
            Use the RunManager API to upload contents at bundle_path to the server
            Pass the callback to that API such that if the bundle is killed during the upload,
            the callback returns false, allowing killable uploads.
        If uploading and not finished:
            Update run_status with upload progress
        If uploading and finished:
            Move to FINALIZING state
        """
        if run_state.is_restaged:
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.RESTAGED,
                reason=self.RESTAGED_REASON,
            )
            return run_state._replace(stage=RunStage.RESTAGED)

        def upload_results():
            try:
                # Upload results
                logger.debug('Uploading results for run with UUID %s',
                             run_state.bundle.uuid)

                def progress_callback(bytes_uploaded):
                    run_status = 'Uploading results: %s done (archived size)' % size_str(
                        bytes_uploaded)
                    self.uploading[
                        run_state.bundle.uuid]['run_status'] = run_status
                    return True

                self.upload_bundle_callback(
                    run_state.bundle.uuid,
                    run_state.bundle_path,
                    run_state.bundle.metadata["exclude_patterns"],
                    progress_callback,
                )
                self.uploading[run_state.bundle.uuid]['success'] = True
            except Exception as e:
                self.uploading[run_state.bundle.uuid]['run_status'] = (
                    "Error while uploading: %s" % e)
                logger.error(traceback.format_exc())

        self.uploading.add_if_new(
            run_state.bundle.uuid,
            threading.Thread(target=upload_results, args=[]))

        if self.uploading[run_state.bundle.uuid].is_alive():
            return run_state._replace(
                run_status=self.uploading[run_state.bundle.uuid]['run_status'])
        elif not self.uploading[run_state.bundle.uuid]['success']:
            # upload failed
            failure_message = run_state.failure_message
            if failure_message:
                failure_message = (
                    f'{failure_message}. {self.uploading[run_state.bundle.uuid]["run_status"]}'
                )
            else:
                failure_message = self.uploading[
                    run_state.bundle.uuid]['run_status']
            logger.info(
                f'Upload failed. uuid: {run_state.bundle.uuid}. failure message: {failure_message}'
            )
            run_state = run_state._replace(failure_message=failure_message)

        self.uploading.remove(run_state.bundle.uuid)
        return self.finalize_run(run_state)

    def finalize_run(self, run_state):
        """
        Prepare the finalize message to be sent with the next checkin
        """
        if run_state.is_killed:
            # Append kill_message, which contains more useful info on why a run was killed, to the failure message.
            failure_message = ("{}. {}".format(run_state.failure_message,
                                               run_state.kill_message)
                               if run_state.failure_message else
                               run_state.kill_message)
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.FINALIZING,
                reason=
                f'Bundle is killed. uuid: {run_state.bundle.uuid}. failure message: {failure_message}',
            )
            run_state = run_state._replace(failure_message=failure_message)
        else:
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.FINALIZING,
            )
        return run_state._replace(stage=RunStage.FINALIZING,
                                  run_status="Finalizing bundle")

    def _transition_from_FINALIZING(self, run_state):
        """
        If a full worker cycle has passed since we entered the FINALIZING state, we have already
        reported to the server. If the bundle is going to be sent back to the server, move on to
        the RESTAGED state; otherwise, move on to the FINISHED state. The bundle_path can also be
        removed now.
        """
        if run_state.is_restaged:
            log_bundle_transition(
                bundle_uuid=run_state.bundle.uuid,
                previous_stage=run_state.stage,
                next_stage=RunStage.RESTAGED,
                reason=
                'the bundle is restaged, as `pass-down-termination` is specified for worker',
            )
            return run_state._replace(stage=RunStage.RESTAGED)
        elif run_state.finalized:
            if not self.shared_file_system:
                remove_path(
                    run_state.bundle_path)  # don't remove bundle if shared FS
            return run_state._replace(stage=RunStage.FINISHED,
                                      run_status='Finished')
        else:
            return run_state


class DockerImageManager:

    CACHE_TAG = 'codalab-image-cache/last-used'

    def __init__(self, commit_file, max_image_cache_size, max_image_size):
        """
        Initializes a DockerImageManager
        :param commit_file: String path to where the state file should be committed
        :param max_image_cache_size: Total size in bytes that the image cache can use
        :param max_image_size: Total size in bytes that the image can have
        """
        self._state_committer = JsonStateCommitter(
            commit_file)  # type: JsonStateCommitter
        self._docker = docker.from_env()  # type: DockerClient
        self._downloading = ThreadDict(fields={
            'success': False,
            'status': 'Download starting.'
        },
                                       lock=True)
        self._max_image_cache_size = max_image_cache_size
        self._max_image_size = max_image_size

        self._stop = False
        self._sleep_secs = 10
        self._cleanup_thread = None

    def start(self):
        logger.info("Starting docker image manager")

        if self._max_image_cache_size:

            def cleanup_loop(self):
                while not self._stop:
                    try:
                        self._cleanup()
                    except Exception:
                        traceback.print_exc()
                    time.sleep(self._sleep_secs)

            self._cleanup_thread = threading.Thread(target=cleanup_loop,
                                                    args=[self])
            self._cleanup_thread.start()

    def stop(self):
        logger.info("Stopping docker image manager")
        self._stop = True
        logger.debug(
            "Stopping docker image manager: stop the downloads threads")
        self._downloading.stop()
        if self._cleanup_thread:
            logger.debug(
                "Stopping docker image manager: stop the cleanup thread")
            self._cleanup_thread.join()
        logger.info("Stopped docker image manager")

    def _get_cache_use(self):
        return sum(
            float(image.attrs['VirtualSize'])
            for image in self._docker.images.list(self.CACHE_TAG))

    def _cleanup(self):
        """
        Prunes the image cache for runs.
        1. Only care about images we (this DockerImageManager) downloaded and know about.
        2. We also try to prune any dangling docker images on the system.
        3. We use sum of VirtualSize's, which is an upper bound on the disk use of our images:
            if no images share any intermediate layers, this will be the real disk use; however,
            if images share layers, the virtual size counts each shared layer's size once per
            image that uses it, even though it's stored only once on disk. The 'Size' field
            accounts for the marginal size each image adds on top of the shared layers, but summing
            those is not accurate either, since the shared base layers need to be counted once to
            get the total size (i.e. summing marginal sizes would give us a lower bound on the
            total disk use of images). Calling Docker's df gives us an accurate disk use of ALL the
            images on the machine, but because of (1) we don't want to use that.
        """

        # Sort the image cache in LRU order
        def last_used(image):
            for tag in image.tags:
                if tag.split(":")[0] == self.CACHE_TAG:
                    return float(tag.split(":")[1])

        cache_use = self._get_cache_use()
        if cache_use > self._max_image_cache_size:
            logger.info(
                'Disk use (%s) > max cache size (%s): starting image pruning',
                cache_use,
                self._max_image_cache_size,
            )
            all_images = self._docker.images.list(self.CACHE_TAG)
            all_images_sorted = sorted(all_images, key=last_used)
            logger.info("Cached docker images: {}".format(all_images_sorted))
            for image in all_images_sorted:
                # We re-list all the images to get an updated total size since we may have deleted some
                cache_use = self._get_cache_use()
                if cache_use > self._max_image_cache_size:
                    image_tag = (image.attrs['RepoTags'][-1]
                                 if len(image.attrs['RepoTags']) > 0 else
                                 '<none>')
                    logger.info(
                        'Disk use (%s) > max cache size (%s), pruning image: %s',
                        cache_use,
                        self._max_image_cache_size,
                        image_tag,
                    )
                    try:
                        self._docker.images.remove(image.id, force=True)
                    except docker.errors.APIError as err:
                        # Two types of 409 Client Error can be thrown here:
                        # 1. 409 Client Error: Conflict ("conflict: unable to delete <image_id> (cannot be forced)")
                        #   This happens when an image either has a running container or has multiple child dependents.
                        # 2. 409 Client Error: Conflict ("conflict: unable to delete <image_id> (must be forced)")
                        #   This happens when an image is referenced in multiple repositories.
                        # We can only remove images in the 2nd case using force=True, not in the 1st case. So if
                        # removing the image with force=True fails, it indicates that we were trying to remove an
                        # image in the 1st case. Since we can't do much for those images, we just continue,
                        # hoping the image will get deleted once it's no longer in use and the cache becomes
                        # full again.
                        logger.warning(
                            "Cannot forcibly remove image %s from cache: %s",
                            image_tag, err)
            logger.debug("Stopping docker image manager cleanup")

    def get(self, image_spec):
        """
        Request the newest docker image from Docker Hub if there is no download thread for it yet, and
        return the current availability status (READY, FAILED, or DOWNLOADING).
        Depending on the state of the requested image:
        1. If it's not available on the machine, we start downloading the image and return DOWNLOADING.
        2. If another thread is actively downloading it, we return DOWNLOADING.
        3. If a download thread for it exists but is no longer active by the time this request is made, we return:
            * READY if the image was downloaded successfully.
            * FAILED if the image couldn't be downloaded for any reason.
        :param image_spec: Repo image_spec of the docker image being requested
        :returns: An ImageAvailabilityState object with the state of the docker image
        """
        def image_availability_state(image_spec, success_message,
                                     failure_message):
            """
            Try to get the image specified by image_spec from host machine.
            Return ImageAvailabilityState.
            """
            try:
                image = self._docker.images.get(image_spec)
                digests = image.attrs.get('RepoDigests', [image_spec])
                digest = digests[0] if len(digests) > 0 else None
                new_timestamp = str(time.time())
                image.tag(self.CACHE_TAG, tag=new_timestamp)
                for tag in image.tags:
                    tag_label, timestamp = tag.split(":")
                    # remove any other timestamp but not the current one
                    if tag_label == self.CACHE_TAG and timestamp != new_timestamp:
                        try:
                            self._docker.images.remove(tag)
                        except docker.errors.NotFound as err:
                            # It's possible that we get a 404 not found error here when removing the image,
                            # since another worker on the same system has already done so. We just
                            # ignore this 404, since any extraneous tags will be removed during the next iteration.
                            logger.warning(
                                "Attempted to remove image %s from cache, but image was not found: %s",
                                tag,
                                err,
                            )

                return ImageAvailabilityState(digest=digest,
                                              stage=DependencyStage.READY,
                                              message=success_message)
            except Exception as ex:
                if using_sentry():
                    capture_exception()
                return ImageAvailabilityState(digest=None,
                                              stage=DependencyStage.FAILED,
                                              message=failure_message % ex)

        if ':' not in image_spec:
            # Both digest and repo:tag specs include the ':' character. The only case without it is when
            # a repo is specified without a tag.
            # When this is the case, different images API methods act differently:
            # - pull pulls all tags of the image
            # - get tries to get the `latest` tag by default
            # That means if someone requests a docker image without a tag, and the image does not have a
            # latest tag pushed to Docker Hub, pull will succeed since it pulls all other tags, but later
            # get calls will fail since the `latest` tag won't be found on the system.
            # We don't want to assume which tag the user wanted, so we want the pull step to fail if no
            # tag is specified and there's no latest tag on Docker Hub.
            # Hence, we append the latest tag to the image spec here at the very beginning if no tag is
            # specified.
            image_spec += ':latest'
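            # e.g. 'ubuntu'              -> 'ubuntu:latest'
            #      'ubuntu:20.04'        -> unchanged (already tagged)
            #      'repo@sha256:abc...'  -> unchanged (digest specs contain ':')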
        try:
            if image_spec in self._downloading:
                with self._downloading[image_spec]['lock']:
                    if self._downloading[image_spec].is_alive():
                        return ImageAvailabilityState(
                            digest=None,
                            stage=DependencyStage.DOWNLOADING,
                            message=self._downloading[image_spec]['status'],
                        )
                    else:
                        if self._downloading[image_spec]['success']:
                            status = image_availability_state(
                                image_spec,
                                success_message='Image ready',
                                failure_message=
                                'Image {} was downloaded successfully, '
                                'but it cannot be found locally due to unhandled error %s'
                                .format(image_spec),
                            )
                        else:
                            status = image_availability_state(
                                image_spec,
                                success_message=
                                'Image {} cannot be downloaded from Docker Hub '
                                'but it is found locally'.format(image_spec),
                                failure_message=self._downloading[image_spec]
                                ['message'] + ": %s",
                            )
                        self._downloading.remove(image_spec)
                        return status
            else:

                def download():
                    logger.debug('Downloading Docker image %s', image_spec)
                    try:
                        self._docker.images.pull(image_spec)
                        logger.debug('Download for Docker image %s complete',
                                     image_spec)
                        self._downloading[image_spec]['success'] = True
                        self._downloading[image_spec][
                            'message'] = "Downloading image"
                    except (docker.errors.APIError,
                            docker.errors.ImageNotFound) as ex:
                        logger.debug('Download for Docker image %s failed: %s',
                                     image_spec, ex)
                        self._downloading[image_spec]['success'] = False
                        self._downloading[image_spec][
                            'message'] = "Can't download image: {}".format(ex)

                # Check docker image size before pulling from Docker Hub.
                # Do not download images larger than self._max_image_size
                # Download images if size cannot be obtained
                if self._max_image_size:
                    try:
                        image_size_bytes = docker_utils.get_image_size_without_pulling(
                            image_spec)
                        if image_size_bytes is None:
                            failure_msg = (
                                "Unable to find Docker image: {} from Docker HTTP Rest API V2. "
                                "Skipping Docker image size precheck.".format(
                                    image_spec))
                            logger.info(failure_msg)
                        elif image_size_bytes > self._max_image_size:
                            failure_msg = (
                                "The size of " + image_spec +
                                ": {} exceeds the maximum image size allowed {}."
                                .format(size_str(image_size_bytes),
                                        size_str(self._max_image_size)))
                            return ImageAvailabilityState(
                                digest=None,
                                stage=DependencyStage.FAILED,
                                message=failure_msg)
                    except Exception as ex:
                        failure_msg = "Cannot fetch image size before pulling Docker image: {} from Docker Hub: {}.".format(
                            image_spec, ex)
                        logger.error(failure_msg)
                        return ImageAvailabilityState(
                            digest=None,
                            stage=DependencyStage.FAILED,
                            message=failure_msg)

                self._downloading.add_if_new(
                    image_spec, threading.Thread(target=download, args=[]))
                return ImageAvailabilityState(
                    digest=None,
                    stage=DependencyStage.DOWNLOADING,
                    message=self._downloading[image_spec]['status'],
                )
        except Exception as ex:
            return ImageAvailabilityState(digest=None,
                                          stage=DependencyStage.FAILED,
                                          message=str(ex))


class RunStateMachine(StateTransitioner):
    """
    Manages the state machine of the runs running on the local machine

    Note that in general there are two types of errors:
    - User errors (fault of bundle) - we fail the bundle (move to CLEANING_UP state).
    - System errors (fault of worker) - we freeze this worker (Exception is thrown up).
    It's not always clear where the line is.
    """
    def __init__(
            self,
            docker_image_manager,  # Component to request docker images from
            dependency_manager,  # Component to request dependency downloads from
            worker_docker_network,  # Docker network to add all bundles to
            docker_network_internal,  # Docker network to add non-net connected bundles to
            docker_network_external,  # Docker network to add internet connected bundles to
            docker_runtime,  # Docker runtime to use for containers (nvidia or runc)
            upload_bundle_callback,  # Function to call to upload bundle results to the server
            assign_cpu_and_gpu_sets_fn,  # Function to call to assign CPU and GPU resources to each run
            shared_file_system,  # If True, bundle mount is shared with server
    ):
        super(RunStateMachine, self).__init__()
        self.add_transition(RunStage.PREPARING,
                            self._transition_from_PREPARING)
        self.add_transition(RunStage.RUNNING, self._transition_from_RUNNING)
        self.add_transition(RunStage.CLEANING_UP,
                            self._transition_from_CLEANING_UP)
        self.add_transition(RunStage.UPLOADING_RESULTS,
                            self._transition_from_UPLOADING_RESULTS)
        self.add_transition(RunStage.FINALIZING,
                            self._transition_from_FINALIZING)
        self.add_terminal(RunStage.FINISHED)

        self.dependency_manager = dependency_manager
        self.docker_image_manager = docker_image_manager
        self.worker_docker_network = worker_docker_network
        self.docker_network_external = docker_network_external
        self.docker_network_internal = docker_network_internal
        self.docker_runtime = docker_runtime
        # bundle.uuid -> {'thread': Thread, 'run_status': str}
        self.uploading = ThreadDict(fields={
            'run_status': 'Upload started',
            'success': False
        })
        # bundle.uuid -> {'thread': Thread, 'disk_utilization': int, 'running': bool}
        self.disk_utilization = ThreadDict(fields={
            'disk_utilization': 0,
            'running': True,
            'lock': None
        })
        self.upload_bundle_callback = upload_bundle_callback
        self.assign_cpu_and_gpu_sets_fn = assign_cpu_and_gpu_sets_fn
        self.shared_file_system = shared_file_system

    def stop(self):
        for uuid in self.disk_utilization.keys():
            self.disk_utilization[uuid]['running'] = False
        self.disk_utilization.stop()
        self.uploading.stop()

    def _transition_from_PREPARING(self, run_state):
        """
        1- Request the docker image from docker image manager
            - if image is failed, move to CLEANING_UP state
        2- Request the dependencies from dependency manager
            - if any are failed, move to CLEANING_UP state
        3- If all dependencies and docker image are ready:
            - Set up the local filesystem for the run
            - Create symlinks to dependencies
            - Allocate resources and prepare the docker container
            - Start the docker container
        4- If all is successful, move to RUNNING state
        """
        if run_state.is_killed:
            return run_state._replace(stage=RunStage.CLEANING_UP)

        dependencies_ready = True
        status_messages = []

        if not self.shared_file_system:
            # No need to download dependencies if we're in the shared FS since they're already in our FS
            for dep_key, dep in run_state.bundle.dependencies.items():
                dependency_state = self.dependency_manager.get(
                    run_state.bundle.uuid, dep_key)
                if dependency_state.stage == DependencyStage.DOWNLOADING:
                    status_messages.append(
                        'Downloading dependency %s: %s done (archived size)' %
                        (dep.child_path, size_str(
                            dependency_state.size_bytes)))
                    dependencies_ready = False
                elif dependency_state.stage == DependencyStage.FAILED:
                    # Failed to download dependency; -> CLEANING_UP
                    return run_state._replace(
                        stage=RunStage.CLEANING_UP,
                        failure_message='Failed to download dependency %s: %s'
                        % (dep.child_path, dependency_state.message),
                    )

        # get the docker image
        docker_image = run_state.resources.docker_image
        image_state = self.docker_image_manager.get(docker_image)
        if image_state.stage == DependencyStage.DOWNLOADING:
            status_messages.append('Pulling docker image: ' +
                                   (image_state.message or docker_image or ""))
            dependencies_ready = False
        elif image_state.stage == DependencyStage.FAILED:
            # Failed to pull image; -> CLEANING_UP
            message = 'Failed to download Docker image: %s' % image_state.message
            logger.error(message)
            return run_state._replace(stage=RunStage.CLEANING_UP,
                                      failure_message=message)

        # stop proceeding if dependency and image downloads aren't all done
        if not dependencies_ready:
            status_message = status_messages.pop()
            if status_messages:
                status_message += "(and downloading %d other dependencies and docker images)" % len(
                    status_messages)
            return run_state._replace(run_status=status_message)

        # All dependencies ready! Set up directories, symlinks and container. Start container.
        # 1) Set up a directory to store the bundle.
        if self.shared_file_system:
            if not os.path.exists(run_state.bundle_path):
                if run_state.bundle_dir_wait_num_tries == 0:
                    message = (
                        "Bundle directory cannot be found on the shared filesystem. "
                        "Please ensure the shared fileystem between the server and "
                        "your worker is mounted properly or contact your administrators."
                    )
                    logger.error(message)
                    return run_state._replace(stage=RunStage.CLEANING_UP,
                                              failure_message=message)
                return run_state._replace(
                    run_status=
                    "Waiting for bundle directory to be created by the server",
                    bundle_dir_wait_num_tries=run_state.
                    bundle_dir_wait_num_tries - 1,
                )
        else:
            remove_path(run_state.bundle_path)
            os.mkdir(run_state.bundle_path)

        # 2) Set up symlinks
        docker_dependencies = []
        docker_dependencies_path = (
            '/' + run_state.bundle.uuid +
            ('_dependencies' if not self.shared_file_system else ''))
        for dep_key, dep in run_state.bundle.dependencies.items():
            full_child_path = os.path.normpath(
                os.path.join(run_state.bundle_path, dep.child_path))
            if not full_child_path.startswith(run_state.bundle_path):
                # Dependencies should end up in their bundles (i.e., prevent using relative paths
                # like .. to get out of their parent bundles).
                message = 'Invalid key for dependency: %s' % (dep.child_path)
                logger.error(message)
                return run_state._replace(stage=RunStage.CLEANING_UP,
                                          failure_message=message)
            docker_dependency_path = os.path.join(docker_dependencies_path,
                                                  dep.child_path)
            if self.shared_file_system:
                # On a shared FS, we know where the dep is stored and can get the contents directly
                dependency_path = os.path.realpath(
                    os.path.join(dep.location, dep.parent_path))
            else:
                # On a dependency_manager setup ask the manager where the dependency is
                dependency_path = os.path.join(
                    self.dependency_manager.dependencies_dir,
                    self.dependency_manager.get(run_state.bundle.uuid,
                                                dep_key).path,
                )
                os.symlink(docker_dependency_path, full_child_path)
            # These are turned into docker volume bindings like:
            #   dependency_path:docker_dependency_path:ro
            docker_dependencies.append(
                (dependency_path, docker_dependency_path))

        # 3) Set up container
        if run_state.resources.network:
            docker_network = self.docker_network_external.name
        else:
            docker_network = self.docker_network_internal.name

        try:
            cpuset, gpuset = self.assign_cpu_and_gpu_sets_fn(
                run_state.resources.cpus, run_state.resources.gpus)
        except Exception as e:
            message = "Cannot assign enough resources: %s" % str(e)
            logger.error(message)
            logger.error(traceback.format_exc())
            return run_state._replace(run_status=message)

        # 4) Start container
        try:
            container = docker_utils.start_bundle_container(
                run_state.bundle_path,
                run_state.bundle.uuid,
                docker_dependencies,
                run_state.bundle.command,
                run_state.resources.docker_image,
                network=docker_network,
                cpuset=cpuset,
                gpuset=gpuset,
                memory_bytes=run_state.resources.memory,
                runtime=self.docker_runtime,
            )
            self.worker_docker_network.connect(container)
        except Exception as e:
            message = 'Cannot start Docker container: {}'.format(e)
            logger.error(message)
            logger.error(traceback.format_exc())
            raise

        return run_state._replace(
            stage=RunStage.RUNNING,
            run_status='Running job in Docker container',
            container_id=container.id,
            container=container,
            docker_image=image_state.digest,
            has_contents=True,
            cpuset=cpuset,
            gpuset=gpuset,
        )

    def _transition_from_RUNNING(self, run_state):
        """
        1- Check run status of the docker container
        2- If run is killed, kill the container
        3- If run is finished, move to CLEANING_UP state
        """
        def check_and_report_finished(run_state):
            try:
                finished, exitcode, failure_msg = docker_utils.check_finished(
                    run_state.container)
            except docker_utils.DockerException:
                logger.error(traceback.format_exc())
                finished, exitcode, failure_msg = False, None, None
            return run_state._replace(finished=finished,
                                      exitcode=exitcode,
                                      failure_message=failure_msg)

        def check_resource_utilization(run_state):
            kill_messages = []

            run_stats = docker_utils.get_container_stats(run_state.container)

            run_state = run_state._replace(max_memory=max(
                run_state.max_memory, run_stats.get('memory', 0)))
            run_state = run_state._replace(
                disk_utilization=self.disk_utilization[
                    run_state.bundle.uuid]['disk_utilization'])

            container_time_total = docker_utils.get_container_running_time(
                run_state.container)
            run_state = run_state._replace(
                container_time_total=container_time_total,
                container_time_user=run_stats.get(
                    'container_time_user', run_state.container_time_user),
                container_time_system=run_stats.get(
                    'container_time_system', run_state.container_time_system),
            )

            if run_state.resources.time and container_time_total > run_state.resources.time:
                kill_messages.append(
                    'Time limit exceeded. (Container uptime %s > time limit %s)'
                    % (duration_str(container_time_total),
                       duration_str(run_state.resources.time)))

            if run_state.max_memory > run_state.resources.memory or run_state.exitcode == 137:
                kill_messages.append('Memory limit %s exceeded.' %
                                     size_str(run_state.resources.memory))

            if run_state.resources.disk and run_state.disk_utilization > run_state.resources.disk:
                kill_messages.append('Disk limit %sb exceeded.' %
                                     size_str(run_state.resources.disk))

            if kill_messages:
                run_state = run_state._replace(
                    kill_message=' '.join(kill_messages), is_killed=True)
            return run_state

        def check_disk_utilization():
            running = True
            while running:
                start_time = time.time()
                try:
                    disk_utilization = get_path_size(run_state.bundle_path)
                    self.disk_utilization[run_state.bundle.uuid][
                        'disk_utilization'] = disk_utilization
                    running = self.disk_utilization[
                        run_state.bundle.uuid]['running']
                except Exception:
                    logger.error(traceback.format_exc())
                end_time = time.time()

                # To ensure that we don't hammer the disk for this computation when
                # there are lots of files, we run it at most 10% of the time.
                time.sleep(max((end_time - start_time) * 10, 1.0))

        self.disk_utilization.add_if_new(
            run_state.bundle.uuid,
            threading.Thread(target=check_disk_utilization, args=[]))
        run_state = check_and_report_finished(run_state)
        run_state = check_resource_utilization(run_state)

        if run_state.is_killed:
            if docker_utils.container_exists(run_state.container):
                try:
                    run_state.container.kill()
                except docker.errors.APIError:
                    finished, _, _ = docker_utils.check_finished(
                        run_state.container)
                    if not finished:
                        logger.error(traceback.format_exc())
            self.disk_utilization[run_state.bundle.uuid]['running'] = False
            self.disk_utilization.remove(run_state.bundle.uuid)
            return run_state._replace(stage=RunStage.CLEANING_UP)
        if run_state.finished:
            logger.debug(
                'Finished run with UUID %s, exitcode %s, failure_message %s',
                run_state.bundle.uuid,
                run_state.exitcode,
                run_state.failure_message,
            )
            self.disk_utilization[run_state.bundle.uuid]['running'] = False
            self.disk_utilization.remove(run_state.bundle.uuid)
            return run_state._replace(stage=RunStage.CLEANING_UP,
                                      run_status='Uploading results')
        else:
            return run_state

    def _transition_from_CLEANING_UP(self, run_state):
        """
        1- delete the container if still existent
        2- clean up the dependencies from bundle directory
        3- release the dependencies in dependency manager
        4- If bundle has contents to upload (i.e. was RUNNING at some point),
            move to UPLOADING_RESULTS state
           Otherwise move to FINALIZING state
        """
        if run_state.container_id is not None:
            while docker_utils.container_exists(run_state.container):
                try:
                    finished, _, _ = docker_utils.check_finished(
                        run_state.container)
                    if finished:
                        run_state.container.remove(force=True)
                        run_state = run_state._replace(container=None,
                                                       container_id=None)
                        break
                    else:
                        try:
                            run_state.container.kill()
                        except docker.errors.APIError:
                            logger.error(traceback.format_exc())
                            time.sleep(1)
                except docker.errors.APIError:
                    logger.error(traceback.format_exc())
                    time.sleep(1)

        for dep_key, dep in run_state.bundle.dependencies.items():
            if not self.shared_file_system:  # No dependencies if shared fs worker
                self.dependency_manager.release(run_state.bundle.uuid, dep_key)

            child_path = os.path.join(run_state.bundle_path, dep.child_path)
            try:
                remove_path(child_path)
            except Exception:
                logger.error(traceback.format_exc())

        if not self.shared_file_system and run_state.has_contents:
            # No need to upload results since results are directly written to bundle store
            return run_state._replace(stage=RunStage.UPLOADING_RESULTS,
                                      run_status='Uploading results',
                                      container=None)
        else:
            return self.finalize_run(run_state)

    def _transition_from_UPLOADING_RESULTS(self, run_state):
        """
        If bundle not already uploading:
            Use the RunManager API to upload contents at bundle_path to the server
            Pass a callback to that API so that the upload can be aborted
            (by the callback returning False) if the bundle is killed during the upload.
        If uploading and not finished:
            Update run_status with upload progress
        If uploading and finished:
            Move to FINALIZING state
        """
        def upload_results():
            try:
                # Upload results
                logger.debug('Uploading results for run with UUID %s',
                             run_state.bundle.uuid)

                def progress_callback(bytes_uploaded):
                    run_status = 'Uploading results: %s done (archived size)' % size_str(
                        bytes_uploaded)
                    self.uploading[
                        run_state.bundle.uuid]['run_status'] = run_status
                    return True  # keep uploading; a False return would abort the upload

                self.upload_bundle_callback(run_state.bundle.uuid,
                                            run_state.bundle_path,
                                            progress_callback)
                self.uploading[run_state.bundle.uuid]['success'] = True
            except Exception as e:
                self.uploading[run_state.bundle.uuid]['run_status'] = (
                    "Error while uploading: %s" % e)
                logger.error(traceback.format_exc())

        self.uploading.add_if_new(
            run_state.bundle.uuid,
            threading.Thread(target=upload_results, args=[]))

        if self.uploading[run_state.bundle.uuid].is_alive():
            return run_state._replace(
                run_status=self.uploading[run_state.bundle.uuid]['run_status'])
        elif not self.uploading[run_state.bundle.uuid]['success']:
            # upload failed
            failure_message = run_state.failure_message
            if failure_message:
                run_state = run_state._replace(failure_message=(
                    failure_message + '. ' +
                    self.uploading[run_state.bundle.uuid]['run_status']))
            else:
                run_state = run_state._replace(failure_message=self.uploading[
                    run_state.bundle.uuid]['run_status'])

        self.uploading.remove(run_state.bundle.uuid)
        return self.finalize_run(run_state)

    def finalize_run(self, run_state):
        """
        Prepare the finalize message to be sent with the next checkin
        """
        if run_state.is_killed:
            # Append kill_message, which contains more useful info on why a run was killed, to the failure message.
            failure_message = ("{}. {}".format(run_state.failure_message,
                                               run_state.kill_message)
                               if run_state.failure_message else
                               run_state.kill_message)
            run_state = run_state._replace(failure_message=failure_message)
        return run_state._replace(stage=RunStage.FINALIZING,
                                  run_status="Finalizing bundle")

    def _transition_from_FINALIZING(self, run_state):
        """
        If a full worker cycle has passed since we entered FINALIZING, we have already
        reported to the server, so we can move on to FINISHED. We can also remove bundle_path now.
        """
        if run_state.finalized:
            if not self.shared_file_system:
                # don't remove the bundle contents if they live on a shared filesystem
                remove_path(run_state.bundle_path)
            return run_state._replace(stage=RunStage.FINISHED,
                                      run_status='Finished')
        else:
            return run_state
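
A minimal sketch of how the _transition_from_* methods above are typically driven. It assumes the enclosing class exposes a transition(run_state) dispatcher that routes each run state to the handler registered for its stage; the names process_runs, manager, run_states, and stop_event are hypothetical, not part of the class above.

import threading
import time

def process_runs(manager, run_states: dict, stop_event: threading.Event):
    """Hypothetical driver loop: repeatedly advance every tracked run."""
    while not stop_event.is_set():
        for uuid in list(run_states.keys()):
            # Each transition returns a new RunState namedtuple; stages advance
            # e.g. CLEANING_UP -> UPLOADING_RESULTS -> FINALIZING -> FINISHED.
            run_states[uuid] = manager.transition(run_states[uuid])
        time.sleep(1)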
Example #6
class ImageManager:
    """
    ImageManager is the base interface for all image managers (docker, singularity, etc.).
    This is an abstract class.
    An ImageManager manages image instances independently of the container runtime,
    downloading images and cleaning up the image cache via its start/stop loops.
    Subclasses need to implement the _cleanup, _image_availability_state, _download,
    and _image_size_without_pulling methods.
    """
    def __init__(self, max_image_size: int, max_image_cache_size: int):
        """
        Args:
            max_image_size: maximum image size in bytes of any given image
            max_image_cache_size: max number of bytes the image cache will hold at any given time.
        """
        self._max_image_size = max_image_size
        self._max_image_cache_size = max_image_cache_size
        self._stop = False
        self._sleep_secs = 10
        self._downloading = ThreadDict(
            fields={'success': False, 'status': 'Download starting'}, lock=True)
        self._cleanup_thread = None  # type: Optional[threading.Thread]

    def start(self) -> None:
        """
        Start the image manager.
        If the _max_image_cache_size argument is defined, the image manager will
            clean up the cache where images are held.

        Returns: None

        """
        logger.info("Starting image manager")
        if self._max_image_cache_size:

            def cleanup_loop(self):
                while not self._stop:
                    try:
                        self._cleanup()
                    except Exception:
                        traceback.print_exc()
                    time.sleep(self._sleep_secs)

            self._cleanup_thread = threading.Thread(target=cleanup_loop,
                                                    args=[self])
            self._cleanup_thread.start()

    def stop(self) -> None:
        """
        stop will stop the running cleanup loop, and therefore, stop the image manager.
        Returns: None
        """
        logger.info("Stopping image manager")
        self._stop = True
        logger.debug("Stopping image manager: stop the downloads threads")
        self._downloading.stop()
        if self._cleanup_thread:
            logger.debug("Stopping image manager: stop the cleanup thread")
            self._cleanup_thread.join()
        logger.info("Stopped image manager")

    def get(self, image_spec: str) -> ImageAvailabilityState:
        """
        Always request the newest image from the cloud if it is not already being downloaded,
        and return the current availability state (READY, FAILED, or DOWNLOADING).
        Depending on the state of the requested image:
        1. If it's not available on the platform, we start downloading it and return DOWNLOADING.
        2. If another thread is actively downloading it, we return DOWNLOADING.
        3. If another thread was downloading it but is no longer active by the time the request
           was sent, we return:
            * READY if the image was downloaded successfully.
            * FAILED if the image could not be downloaded for any reason.
        Args:
            image_spec: the image that the requester needs.
                The caller will need to determine the type of image they need before calling this function.
                It is usually safe to prefix the image with the type of image.
                For example, the docker image go would be docker://go

        Returns:
            ImageAvailabilityState of the image requested.
        """
        try:
            if image_spec in self._downloading:
                with self._downloading[image_spec]['lock']:
                    if self._downloading[image_spec].is_alive():
                        return ImageAvailabilityState(
                            digest=None,
                            stage=DependencyStage.DOWNLOADING,
                            message=self._downloading[image_spec]['status'],
                        )
                    else:
                        # The '%s' placeholder in failure_message is filled in with the
                        # underlying error by the subclass's _image_availability_state.
                        if self._downloading[image_spec]['success']:
                            status = self._image_availability_state(
                                image_spec,
                                success_message='Image ready',
                                failure_message=
                                'Image {} was downloaded successfully, '
                                'but it cannot be found locally due to unhandled error %s'
                                .format(image_spec),
                            )
                        else:
                            status = self._image_availability_state(
                                image_spec,
                                success_message=
                                'Image {} could not be downloaded from the cloud, '
                                'but it was found locally'.format(image_spec),
                                failure_message=self._downloading[image_spec]
                                ['message'] + ": %s",
                            )
                        self._downloading.remove(image_spec)
                        return status
            else:
                if self._max_image_size:
                    try:
                        try:
                            image_size_bytes = self._image_size_without_pulling(
                                image_spec)
                        except NotImplementedError:
                            failure_msg = (
                                "Could not query size of {} from container runtime hub. "
                                "Skipping size precheck.".format(image_spec))
                            logger.info(failure_msg)
                            image_size_bytes = 0
                        if image_size_bytes > self._max_image_size:
                            failure_msg = (
                                "The size of {} ({}) exceeds the maximum allowed image size ({})."
                                .format(image_spec,
                                        size_str(image_size_bytes),
                                        size_str(self._max_image_size)))
                            logger.error(failure_msg)
                            return ImageAvailabilityState(
                                digest=None,
                                stage=DependencyStage.FAILED,
                                message=failure_msg)
                    except Exception as ex:
                        failure_msg = "Cannot fetch image size before pulling Docker image: {} from Docker Hub: {}.".format(
                            image_spec, ex)
                        logger.error(failure_msg)
                        return ImageAvailabilityState(
                            digest=None,
                            stage=DependencyStage.FAILED,
                            message=failure_msg)

                self._downloading.add_if_new(
                    image_spec,
                    threading.Thread(target=self._download, args=[image_spec]))
                return ImageAvailabilityState(
                    digest=None,
                    stage=DependencyStage.DOWNLOADING,
                    message=self._downloading[image_spec]['status'],
                )
        except Exception as ex:
            logger.error(ex)
            return ImageAvailabilityState(digest=None,
                                          stage=DependencyStage.FAILED,
                                          message=str(ex))

    def _cleanup(self):
        """
        Prune and clean up images in accordance with the image cache and
            how much space the images take up. The logic will vary per container runtime.
            Therefore, this function should be implemented in the subclasses.
        """
        raise NotImplementedError

    def _image_availability_state(
            self, image_spec: str, success_message: str,
            failure_message: str) -> ImageAvailabilityState:
        """
        Try to get the image specified by image_spec from the host machine.
        The message field of the returned ImageAvailabilityState will contain success_message or
        failure_message, with the respective stage of DependencyStage.READY or DependencyStage.FAILED,
        depending on whether or not the image can be found locally.
        Args:
            image_spec: container-runtime specific spec
            success_message: message to return in the ImageAvailabilityState if the image was successfully found
            failure_message: message to return if there were any issues in getting the image.

        Returns: ImageAvailabilityState

        """
        raise NotImplementedError

    def _download(self, image_spec: str) -> None:
        """
        Download the container image from the cloud to the host machine.
        This function needs to be implemented specific to the container runtime.
        For instance:
            - _download's docker image implementation should pull from DockerHub.
            - _download's singularity image implementation should pull from the Singularity Hub
                or Sylabs Cloud, based on the image scheme.
        This function will update the _downloading ThreadDict with the status and progress of the
            download.
        Args:
            image_spec: container-runtime specific image specification

        Returns: None

        """
        raise NotImplementedError

    def _image_size_without_pulling(self, image_spec: str):
        """
        Attempt to query the requested image's size, based on the container runtime.
        Args:
            image_spec: image specification

        Returns: the size of the image in bytes (float). Raises NotImplementedError if the
            runtime cannot query the image size without pulling (callers treat this as
            "skip the size precheck").

        """
        raise NotImplementedError
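
To make the abstract contract concrete, here is a minimal, hypothetical subclass sketch. The runtime argument and its methods (digest, pull, remove, images_by_last_used, total_image_bytes) are placeholders for a real container-runtime client, not part of the classes above.

class ToyImageManager(ImageManager):
    def __init__(self, runtime, max_image_size: int, max_image_cache_size: int):
        super().__init__(max_image_size, max_image_cache_size)
        self._runtime = runtime  # hypothetical container-runtime client

    def _cleanup(self):
        # Evict least-recently-used images until the cache fits the budget.
        for image in self._runtime.images_by_last_used():
            if self._runtime.total_image_bytes() <= self._max_image_cache_size:
                break
            self._runtime.remove(image)

    def _image_availability_state(self, image_spec, success_message, failure_message):
        try:
            digest = self._runtime.digest(image_spec)  # raises if not found locally
            return ImageAvailabilityState(
                digest=digest, stage=DependencyStage.READY, message=success_message)
        except Exception as ex:
            # failure_message carries a '%s' slot for the underlying error.
            return ImageAvailabilityState(
                digest=None, stage=DependencyStage.FAILED, message=failure_message % ex)

    def _download(self, image_spec):
        try:
            self._runtime.pull(image_spec)
            self._downloading[image_spec]['success'] = True
            self._downloading[image_spec]['message'] = 'Image downloaded'
        except Exception as ex:
            self._downloading[image_spec]['success'] = False
            self._downloading[image_spec]['message'] = "Can't download image: {}".format(ex)

    def _image_size_without_pulling(self, image_spec):
        # This toy runtime cannot query remote sizes; get() skips the size precheck.
        raise NotImplementedError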
Example #7
class DockerImageManager:
    def __init__(self, commit_file, max_image_cache_size):
        """
        Initializes a DockerImageManager
        :param commit_file: String path to where the state file should be committed
        :param max_image_cache_size: Total size in bytes that the image cache can use
        """
        self._state_committer = JsonStateCommitter(
            commit_file)  # type: JsonStateCommitter
        self._docker = docker.from_env()  # type: DockerClient
        self._image_cache = {}  # type: Dict[str, ImageCacheEntry]
        self._downloading = ThreadDict(
            fields={'success': False, 'status': 'Download starting.'}, lock=True)
        self._max_image_cache_size = max_image_cache_size
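        # Reentrant lock guarding _image_cache: held while loading/committing state
        # and while get() records a new cache entry.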
        self._lock = threading.RLock()

        self._stop = False
        self._sleep_secs = 10
        self._cleanup_thread = None

        self._load_state()

    def _save_state(self):
        with self._lock:
            self._state_committer.commit(self._image_cache)

    def _load_state(self):
        with self._lock:
            self._image_cache = self._state_committer.load()

    def start(self):
        logger.info("Starting docker image manager")
        if self._max_image_cache_size:

            def cleanup_loop(self):
                while not self._stop:
                    try:
                        self._cleanup()
                        self._save_state()
                    except Exception:
                        traceback.print_exc()
                    time.sleep(self._sleep_secs)

            self._cleanup_thread = threading.Thread(target=cleanup_loop,
                                                    args=[self])
            self._cleanup_thread.start()

    def stop(self):
        logger.info("Stopping docker image manager")
        self._stop = True
        logger.debug(
            "Stopping docker image manager: stop the download threads")
        self._downloading.stop()
        if self._cleanup_thread:
            logger.debug(
                "Stopping docker image manager: stop the cleanup thread")
            self._cleanup_thread.join()
        logger.info("Stopped docker image manager")

    def _cleanup(self):
        """
        Prunes the image cache for runs.
        1. Only care about images we (this DockerImageManager) downloaded and know about
        2. We use the sum of VirtualSizes, which is an upper bound on the disk use of our images:
            if no images share any intermediate layers, this is the real disk use; however, if
            images share layers, the virtual size counts each shared layer's size once per image
            that uses it, even though it's stored only once on disk. The 'Size' field
            accounts for the marginal size each image adds on top of the shared layers, but summing
            those is not accurate either, since the shared base layers need to be counted once to
            get the total size (i.e. summing marginal sizes would give a lower bound on the total
            disk use of images). Calling df gives an accurate disk use of ALL the images on the
            machine, but because of (1) we don't want to use that.
        """
        # A single pruning pass: the cleanup_loop in start() calls this repeatedly
        # and saves state between passes.
        deletable_entries = set(self._image_cache.values())
        disk_use = sum(cache_entry.virtual_size
                       for cache_entry in deletable_entries)
        while disk_use > self._max_image_cache_size:
            entry_to_remove = min(deletable_entries,
                                  key=lambda entry: entry.last_used)
            logger.info(
                'Disk use (%s) > max cache size (%s), pruning image: %s',
                disk_use,
                self._max_image_cache_size,
                entry_to_remove.digest,
            )
            try:
                image_to_delete = self._docker.images.get(entry_to_remove.id)
                tags_to_delete = image_to_delete.tags
                for tag in tags_to_delete:
                    self._docker.images.remove(tag)
                # if we successfully removed the image, also remove its cache entry
                del self._image_cache[entry_to_remove.digest]
            except docker.errors.NotFound:
                # the image doesn't exist anymore for some reason; stop tracking it
                del self._image_cache[entry_to_remove.digest]
            except docker.errors.APIError as err:
                # Maybe we can't delete this image because its container is still running
                # (think of a run that takes 4 days, so this is the oldest image but still
                # in use). In that case we just continue with our lives, hoping it gets
                # deleted once it's no longer in use and the cache becomes full again.
                logger.error("Cannot remove image %s from cache: %s",
                             entry_to_remove.digest, err)
            deletable_entries.remove(entry_to_remove)
            disk_use = sum(entry.virtual_size for entry in deletable_entries)

    def get(self, image_spec):
        """
        Request the docker image specified by image_spec.
        :param image_spec: Repo image_spec of the docker image being requested
        :returns: An ImageAvailabilityState object with the state of the docker image
        """
        if ':' not in image_spec:
            # Both digests and repo:tag specs include the ':' character. The only case without it
            # is when a repo is specified without a tag (such as 'latest').
            # When this is the case, different images API methods act differently:
            # - pull pulls all tags of the image
            # - get tries to get `latest` by default
            # That means if someone requests a docker image without a tag, and the image does not
            # have a latest tag pushed to Dockerhub, pull will succeed (since it pulls all other
            # tags), but later get calls will fail, since the `latest` tag won't be found on the
            # system. We don't want to assume which tag the user wanted, so we want the pull step
            # to fail if no tag is specified and there's no latest tag on Dockerhub.
            # Hence, if no tag is specified, we explicitly append ':latest' to the image spec
            # at the very beginning.
            image_spec += ':latest'
        try:
            image = self._docker.images.get(image_spec)
            digests = image.attrs.get('RepoDigests', [image_spec])
            if len(digests) == 0:
                return ImageAvailabilityState(
                    digest=None,
                    stage=DependencyStage.FAILED,
                    message=
                    'No digest available for {}, probably because it was built locally; delete the Docker image on the worker and try again'
                    .format(image_spec),
                )
            digest = digests[0]
            with self._lock:
                self._image_cache[digest] = ImageCacheEntry(
                    id=image.id,
                    digest=digest,
                    last_used=time.time(),
                    virtual_size=image.attrs['VirtualSize'],
                    marginal_size=image.attrs['Size'],
                )
            # We can remove the download thread if it still exists
            if image_spec in self._downloading:
                self._downloading.remove(image_spec)
            return ImageAvailabilityState(digest=digest,
                                          stage=DependencyStage.READY,
                                          message='Image ready')
        except docker.errors.ImageNotFound:
            return self._pull_or_report(
                image_spec)  # type: ImageAvailabilityState
        except Exception as ex:
            return ImageAvailabilityState(digest=None,
                                          stage=DependencyStage.FAILED,
                                          message=str(ex))

    def _pull_or_report(self, image_spec):
        if image_spec in self._downloading:
            with self._downloading[image_spec]['lock']:
                if self._downloading[image_spec].is_alive():
                    return ImageAvailabilityState(
                        digest=None,
                        stage=DependencyStage.DOWNLOADING,
                        message=self._downloading[image_spec]['status'],
                    )
                else:
                    if self._downloading[image_spec]['success']:
                        digest = self._docker.images.get(image_spec).attrs.get(
                            'RepoDigests', [image_spec])[0]
                        status = ImageAvailabilityState(
                            digest=digest,
                            stage=DependencyStage.READY,
                            message=self._downloading[image_spec]['message'],
                        )
                    else:
                        status = ImageAvailabilityState(
                            digest=None,
                            stage=DependencyStage.FAILED,
                            message=self._downloading[image_spec]['message'],
                        )
                    self._downloading.remove(image_spec)
                    return status
        else:

            def download():
                logger.debug('Downloading Docker image %s', image_spec)
                try:
                    self._docker.images.pull(image_spec)
                    logger.debug('Download for Docker image %s complete',
                                 image_spec)
                    self._downloading[image_spec]['success'] = True
                    self._downloading[image_spec][
                        'message'] = "Image downloaded"
                except (docker.errors.APIError,
                        docker.errors.ImageNotFound) as ex:
                    logger.debug('Download for Docker image %s failed: %s',
                                 image_spec, ex)
                    self._downloading[image_spec]['success'] = False
                    self._downloading[image_spec][
                        'message'] = "Can't download image: {}".format(ex)

            self._downloading.add_if_new(
                image_spec, threading.Thread(target=download, args=[]))
            return ImageAvailabilityState(
                digest=None,
                stage=DependencyStage.DOWNLOADING,
                message=self._downloading[image_spec]['status'],
            )
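
A sketch of how DockerImageManager might be driven end to end. The commit-file path, image spec, cache size, and polling interval are arbitrary illustrative choices, not part of the class.

import time

manager = DockerImageManager('/tmp/image-state.json',
                             max_image_cache_size=10 * 1024 ** 3)  # 10 GiB budget
manager.start()
try:
    # Poll get() until the image is READY or FAILED; get() starts (or reuses)
    # the download thread and reports progress in the meantime.
    state = manager.get('python:3.8-slim')
    while state.stage == DependencyStage.DOWNLOADING:
        time.sleep(2)
        state = manager.get('python:3.8-slim')
    print(state.stage, state.message)
finally:
    manager.stop()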