Ejemplo n.º 1
0
    def __init__(self, commit_file, bundle_service, worker_dir, max_cache_size_bytes):
        """
        Register the state machine stages, make sure the dependencies directory exists,
        then load the committed state and reconcile it with the local file system.

        :param commit_file: path handed to JsonStateCommitter for persisting state
        :param bundle_service: bundle service client, kept for later use
        :param worker_dir: directory under which the dependencies directory is placed
        :param max_cache_size_bytes: cache size limit in bytes, kept for later use
        """
        super(DependencyManager, self).__init__()
        self.add_transition(DependencyStage.DOWNLOADING, self._transition_from_DOWNLOADING)
        for terminal_stage in (DependencyStage.READY, DependencyStage.FAILED):
            self.add_terminal(terminal_stage)

        self._state_committer = JsonStateCommitter(commit_file)
        self._bundle_service = bundle_service
        self._max_cache_size_bytes = max_cache_size_bytes

        dependencies_dir = os.path.join(worker_dir, DependencyManager.DEPENDENCIES_DIR_NAME)
        self.dependencies_dir = dependencies_dir
        if not os.path.exists(dependencies_dir):
            logger.info("{} doesn't exist, creating.".format(dependencies_dir))
            os.makedirs(dependencies_dir, mode=0o770)

        # Concurrency primitives
        self._global_lock = threading.RLock()  # guards add/remove actions
        self._paths_lock = threading.RLock()  # guards path name computations
        self._dependency_locks = {}  # type: Dict[DependencyKey, threading.RLock]

        # Directory names currently reserved for dependencies; consulted to avoid name clashes
        self._paths = set()
        # DependencyKey -> DependencyState
        self._dependencies = {}
        # DependencyKey -> WorkerThread(thread, success, failure_message)
        self._downloading = ThreadDict(fields={'success': False, 'failure_message': None})

        # Read the committed state first, then bring it in line with what is actually on disk
        self._load_state()
        self._sync_state()

        self._stop = False
        self._main_thread = None
    def __init__(
        self,
        commit_file: str,
        bundle_service: BundleServiceClient,
        worker_dir: str,
        max_cache_size_bytes: int,
        download_dependencies_max_retries: int,
    ):
        """
        Register the state machine stages, prepare the dependencies and lock directories
        under worker_dir, create the NFS state lock and synchronize the state.

        :param commit_file: path handed to JsonStateCommitter for persisting state
        :param bundle_service: bundle service client, kept for later use
        :param worker_dir: directory that holds the dependencies and locks_claims directories
        :param max_cache_size_bytes: cache size limit in bytes, kept for later use
        :param download_dependencies_max_retries: retry limit for downloads, kept for later use
        """
        super(DependencyManager, self).__init__()
        self.add_transition(DependencyStage.DOWNLOADING, self._transition_from_DOWNLOADING)
        for terminal_stage in (DependencyStage.READY, DependencyStage.FAILED):
            self.add_terminal(terminal_stage)

        # A short random suffix tells apart the managers that show up in the logs
        id_suffix = uuid.uuid4().hex[:8]
        self._id: str = "worker-dependency-manager-{}".format(id_suffix)
        self._state_committer = JsonStateCommitter(commit_file)
        self._bundle_service = bundle_service
        self._max_cache_size_bytes = max_cache_size_bytes
        dependencies_dir = os.path.join(worker_dir, DependencyManager.DEPENDENCIES_DIR_NAME)
        self.dependencies_dir = dependencies_dir
        self._download_dependencies_max_retries = download_dependencies_max_retries
        if not os.path.exists(dependencies_dir):
            logger.info("{} doesn't exist, creating.".format(dependencies_dir))
            os.makedirs(dependencies_dir, mode=0o770)

        # Lock for concurrency over NFS. The lock files live in their own directory;
        # each one is created when a process tries to claim the main lock.
        locks_claims_dir: str = os.path.join(worker_dir, 'locks_claims')
        try:
            os.makedirs(locks_claims_dir)
        except FileExistsError:
            logger.info(f"A locks directory at {locks_claims_dir} already exists.")
        state_lock_path = os.path.join(locks_claims_dir, 'state.lock')
        self._state_lock = NFSLock(state_lock_path)

        # Directory names currently reserved for dependencies; consulted to avoid name clashes
        self._paths: Set[str] = set()
        # DependencyKey -> WorkerThread(thread, success, failure_message)
        self._downloading = ThreadDict(fields={'success': False, 'failure_message': None})
        # Bring dependency-state.json and the dependency directories on disk in line
        self._sync_state()

        self._stop = False
        self._main_thread = None
        logger.info(f"Initialized Dependency Manager with ID: {self._id}")
Ejemplo n.º 3
0
class JsonStateCommitterTest(unittest.TestCase):
    """Tests for committing and loading state through JsonStateCommitter."""

    def setUp(self):
        # Each test gets its own scratch directory with a committer pointed inside it
        self.state_file = 'test-state.json'
        self.test_dir = tempfile.mkdtemp()
        self.state_path = os.path.join(self.test_dir, self.state_file)
        self.committer = JsonStateCommitter(self.state_path)

    def tearDown(self):
        # Only some tests create the state file, so a failed removal is not an error
        try:
            os.remove(self.state_path)
        except OSError:
            pass
        os.rmdir(self.test_dir)

    def test_path_parsing(self):
        """The committer keeps the state file path exactly as given"""
        expected_path = self.state_path
        self.assertEqual(self.committer._state_file, expected_path)

    def test_commit(self):
        """commit writes the state as JSON and leaves no temp file behind"""
        expected_contents = '{"state": "value"}'
        self.committer.commit({'state': 'value'})
        with open(self.state_path) as state_fp:
            written_contents = state_fp.read()
        self.assertEqual(expected_contents, written_contents)
        temp_file_left_behind = os.path.exists(self.committer.temp_file)
        self.assertFalse(temp_file_left_behind)

    def test_load(self):
        """load returns the contents of an existing state file"""
        with open(self.state_path, 'w') as state_fp:
            state_fp.write('{"state": "value"}')
        expected_state = {'state': 'value'}
        self.assertDictEqual(expected_state, self.committer.load())

    def test_default(self):
        """load falls back to the given default when there is no state file"""
        fallback_state = {'state': 'value'}
        self.assertDictEqual(fallback_state, self.committer.load(default=fallback_state))
    def __init__(
        self,
        worker,  # type: Worker
        image_manager,  # type: DockerImageManager
        dependency_manager,  # type: LocalFileSystemDependencyManager
        commit_file,  # type: str
        cpuset,  # type: Set[str]
        gpuset,  # type: Set[str]
        work_dir,  # type: str
        docker_runtime=docker_utils.DEFAULT_RUNTIME,  # type: str
        docker_network_prefix='codalab_worker_network',  # type: str
    ):
        """
        Store the collaborators, make sure the bundles directory exists under work_dir,
        set up the docker networks and build the run state machine.
        """
        self._worker = worker
        self._state_committer = JsonStateCommitter(commit_file)
        self._reader = LocalReader()
        self._docker = docker.from_env()

        bundles_dir = os.path.join(work_dir, LocalRunManager.BUNDLES_DIR_NAME)
        self._bundles_dir = bundles_dir
        if not os.path.exists(bundles_dir):
            logger.info("{} doesn't exist, creating.".format(bundles_dir))
            os.makedirs(bundles_dir, mode=0o770)

        self._image_manager = image_manager
        self._dependency_manager = dependency_manager
        self._cpuset = cpuset
        self._gpuset = gpuset
        self._stop = False
        self._work_dir = work_dir

        # bundle_uuid -> LocalRunState
        self._runs = {}
        self._lock = threading.RLock()

        # The network attributes read below are not assigned anywhere else in this
        # constructor, so this call has to come before the state machine is built.
        self._init_docker_networks(docker_network_prefix)
        state_machine_kwargs = dict(
            docker_image_manager=self._image_manager,
            dependency_manager=self._dependency_manager,
            worker_docker_network=self.worker_docker_network,
            docker_network_internal=self.docker_network_internal,
            docker_network_external=self.docker_network_external,
            docker_runtime=docker_runtime,
            upload_bundle_callback=self._worker.upload_bundle_contents,
            assign_cpu_and_gpu_sets_fn=self.assign_cpu_and_gpu_sets,
        )
        self._run_state_manager = LocalRunStateMachine(**state_machine_kwargs)
 def __init__(self, commit_file: str, max_image_cache_size: int,
              max_image_size: int):
     """
     Set up a DockerImageManager.

     :param commit_file: String path to where the state file should be committed
     :param max_image_cache_size: Total size in bytes that the image cache can use
     :param max_image_size: Total size in bytes that the image can have
     """
     # Note the argument order expected by the parent: image size first, cache size second
     super().__init__(max_image_size, max_image_cache_size)
     committer = JsonStateCommitter(commit_file)
     docker_client = docker.from_env(timeout=DEFAULT_DOCKER_TIMEOUT)
     self._state_committer = committer  # type: JsonStateCommitter
     self._docker = docker_client  # type: DockerClient
Ejemplo n.º 6
0
    def __init__(self, commit_file, max_image_cache_size):
        """
        Set up a DockerImageManager and load its committed state.

        :param commit_file: String path to where the state file should be committed
        :param max_image_cache_size: Total size in bytes that the image cache can use
        """
        self._state_committer = JsonStateCommitter(commit_file)  # type: JsonStateCommitter
        self._docker = docker.from_env()  # type: DockerClient
        self._image_cache = {}  # type: Dict[str, ImageCacheEntry]

        # Initial field values handed to ThreadDict for each download entry
        download_fields = {'success': False, 'status': 'Download starting.'}
        self._downloading = ThreadDict(fields=download_fields, lock=True)

        self._max_image_cache_size = max_image_cache_size
        self._lock = threading.RLock()

        self._stop = False
        self._sleep_secs = 10
        self._cleanup_thread = None

        self._load_state()
    def __init__(self, commit_file, max_image_cache_size, max_image_size):
        """
        Set up a DockerImageManager.

        :param commit_file: String path to where the state file should be committed
        :param max_image_cache_size: Total size in bytes that the image cache can use
        :param max_image_size: Total size in bytes that the image can have
        """
        committer = JsonStateCommitter(commit_file)
        docker_client = docker.from_env(timeout=DEFAULT_DOCKER_TIMEOUT)
        self._state_committer = committer  # type: JsonStateCommitter
        self._docker = docker_client  # type: DockerClient

        # Initial field values handed to ThreadDict for each download entry
        download_fields = {'success': False, 'status': 'Download starting'}
        self._downloading = ThreadDict(fields=download_fields, lock=True)

        self._max_image_cache_size = max_image_cache_size
        self._max_image_size = max_image_size

        self._stop = False
        self._sleep_secs = 10
        self._cleanup_thread = None
Ejemplo n.º 8
0
 def setUp(self):
     """Create a scratch directory and a committer whose state file lives inside it."""
     self.state_file = 'test-state.json'
     scratch_dir = tempfile.mkdtemp()
     self.test_dir = scratch_dir
     self.state_path = os.path.join(scratch_dir, self.state_file)
     self.committer = JsonStateCommitter(self.state_path)
Ejemplo n.º 9
0
class DependencyManager(StateTransitioner, BaseDependencyManager):
    """
    This dependency manager downloads dependency bundles from Codalab server
    to the local filesystem. It caches all downloaded dependencies but cleans up the
    old ones if the disk use hits the given threshold

    For this class dependencies are uniquely identified by DependencyKey
    """

    DEPENDENCIES_DIR_NAME = 'dependencies'
    DEPENDENCY_FAILURE_COOLDOWN = 10
    # TODO(bkgoksel): The server writes these to the worker_dependencies table, which stores the dependencies
    # json as a SqlAlchemy LargeBinary, which defaults to MySQL BLOB, which has a size limit of
    # 65K. For now we limit this value to about 58K to avoid any issues but we probably want to do
    # something better (either specify MEDIUMBLOB in the SqlAlchemy definition of the table or change
    # the data format of how we store this)
    MAX_SERIALIZED_LEN = 60000

    def __init__(self, commit_file, bundle_service, worker_dir,
                 max_cache_size_bytes):
        """
        Register the state machine stages, make sure the dependencies directory exists,
        then load the committed state and reconcile it with the local file system.

        :param commit_file: path handed to JsonStateCommitter for persisting state
        :param bundle_service: bundle service client, kept for later use
        :param worker_dir: directory under which the dependencies directory is placed
        :param max_cache_size_bytes: cache size limit in bytes, compared against disk use in _cleanup
        """
        super(DependencyManager, self).__init__()
        self.add_transition(DependencyStage.DOWNLOADING, self._transition_from_DOWNLOADING)
        for terminal_stage in (DependencyStage.READY, DependencyStage.FAILED):
            self.add_terminal(terminal_stage)

        self._state_committer = JsonStateCommitter(commit_file)
        self._bundle_service = bundle_service
        self._max_cache_size_bytes = max_cache_size_bytes

        dependencies_dir = os.path.join(worker_dir, DependencyManager.DEPENDENCIES_DIR_NAME)
        self.dependencies_dir = dependencies_dir
        if not os.path.exists(dependencies_dir):
            logger.info("{} doesn't exist, creating.".format(dependencies_dir))
            os.makedirs(dependencies_dir, mode=0o770)

        # Concurrency primitives
        self._global_lock = threading.RLock()  # guards add/remove actions
        self._paths_lock = threading.RLock()  # guards path name computations
        self._dependency_locks = {}  # type: Dict[DependencyKey, threading.RLock]

        # Directory names currently reserved for dependencies; consulted to avoid name clashes
        self._paths = set()
        # DependencyKey -> DependencyState
        self._dependencies = {}
        # DependencyKey -> WorkerThread(thread, success, failure_message)
        self._downloading = ThreadDict(fields={'success': False, 'failure_message': None})

        # Read the committed state first, then bring it in line with what is actually on disk
        self._load_state()
        self._sync_state()

        self._stop = False
        self._main_thread = None

    def _save_state(self):
        """
        Commit the current dependencies and paths through the state committer,
        holding both the global lock and the paths lock while doing so.
        """
        with self._global_lock:
            with self._paths_lock:
                snapshot = {'dependencies': self._dependencies, 'paths': self._paths}
                self._state_committer.commit(snapshot)

    def _load_state(self):
        """
        Read the committed state (empty dependencies and paths if nothing was committed)
        and use it to populate self._dependencies, self._dependency_locks and self._paths.
        Every loaded dependency gets a fresh lock of its own.
        """
        empty_state = {'dependencies': {}, 'paths': set()}
        loaded = self._state_committer.load(default=empty_state)

        loaded_dependencies = dict(loaded['dependencies'].items())
        fresh_locks = {dep: threading.RLock() for dep in loaded_dependencies}

        with self._global_lock, self._paths_lock:
            self._dependencies = loaded_dependencies
            self._dependency_locks = fresh_locks
            self._paths = loaded['paths']

        dependency_count = len(self._dependencies)
        path_count = len(self._paths)
        logger.info('Loaded {} dependencies, {} paths from cache.'.format(dependency_count, path_count))

    def _sync_state(self):
        """
        Reconcile the loaded dependency state with the contents of self.dependencies_dir.

        Only paths that are present in self._paths, referenced by a dependency state and
        found on disk are kept. Dependencies whose path did not survive are dropped from
        the state (together with their locks), directories on disk that did not survive are
        removed, and the resulting state is committed again.
        """
        on_disk = set(os.listdir(self.dependencies_dir))
        referenced_paths = [state.path for state in self._dependencies.values()]
        self._paths = self._paths.intersection(referenced_paths).intersection(on_disk)

        # Dependencies whose path is not among the surviving paths are orphaned state entries
        orphaned_keys = [
            key for key, state in self._dependencies.items() if state.path not in self._paths
        ]
        for key in orphaned_keys:
            missing_path = os.path.join(self.dependencies_dir, self._dependencies[key].path)
            logger.info(
                "Dependency {} in dependency state but its path {} doesn't exist on the local file system. "
                "Removing it from dependency state.".format(key, missing_path)
            )
            del self._dependencies[key]
            del self._dependency_locks[key]

        # Directories on disk that no surviving path refers to are orphaned files
        for entry_name in on_disk - self._paths:
            orphaned_path = os.path.join(self.dependencies_dir, entry_name)
            logger.info(
                "Remove orphaned directory {} from the local file system.".format(orphaned_path)
            )
            remove_path(orphaned_path)

        # The steps above may have changed the state, so write it back
        self._save_state()

    def start(self):
        """
        Launch the background thread that, about once a second and until self._stop is set,
        processes the dependencies, saves the state, cleans up and saves the state again.
        Exceptions raised inside one iteration are printed and the loop carries on.
        """
        logger.info('Starting local dependency manager')

        def run_until_stopped():
            while not self._stop:
                try:
                    self._process_dependencies()
                    self._save_state()
                    self._cleanup()
                    self._save_state()
                except Exception:
                    traceback.print_exc()
                time.sleep(1)

        worker_thread = threading.Thread(target=run_until_stopped)
        self._main_thread = worker_thread
        worker_thread.start()

    def stop(self):
        """
        Ask the background loop to finish, stop the download threads and wait for
        the main thread to exit. Safe to call even if start() was never called.
        """
        logger.info('Stopping local dependency manager')
        self._stop = True
        self._downloading.stop()
        # _main_thread stays None until start() runs; joining None would raise AttributeError
        if self._main_thread is not None:
            self._main_thread.join()
        logger.info('Stopped local dependency manager')

    def _process_dependencies(self):
        for dep_key, dep_state in self._dependencies.items():
            with self._dependency_locks[dep_key]:
                self._dependencies[dep_key] = self.transition(dep_state)

    def _prune_failed_dependencies(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds so that further runs
        get to retry the download. Without pruning, any future run depending on a
        failed dependency would automatically fail indefinitely.
        """
        with self._global_lock:
            self._acquire_all_locks()
            try:
                now = time.time()
                expired_keys = [
                    dep_key
                    for dep_key, dep_state in self._dependencies.items()
                    if dep_state.stage == DependencyStage.FAILED
                    and now - dep_state.last_used > DependencyManager.DEPENDENCY_FAILURE_COOLDOWN
                ]
                for dep_key in expired_keys:
                    self._delete_dependency(dep_key)
            finally:
                # Release even when something above raises: the caller's loop swallows
                # exceptions, so a leaked lock would block every other thread for good
                self._release_all_locks()

    def _cleanup(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds.
        Limit the disk usage of the dependencies (both the bundle files and the serialized state file size)
        Deletes oldest failed dependencies first and then oldest finished dependencies.
        Doesn't touch downloading dependencies.
        """
        self._prune_failed_dependencies()
        # Each pass takes all the locks, removes at most one dependency and releases the locks again
        # (should be fast if no cleanup is needed, otherwise makes sure nothing is corrupted)
        while True:
            with self._global_lock:
                self._acquire_all_locks()
                try:
                    bytes_used = sum(
                        dep_state.size_bytes for dep_state in self._dependencies.values()
                    )
                    serialized_length = len(codalab.worker.pyjson.dumps(self._dependencies))
                    if (
                        bytes_used <= self._max_cache_size_bytes
                        and serialized_length <= DependencyManager.MAX_SERIALIZED_LEN
                    ):
                        # Within both limits: nothing left to do
                        break
                    logger.debug(
                        '%d dependencies in cache, disk usage: %s (max %s), serialized size: %s (max %s)',
                        len(self._dependencies),
                        size_str(bytes_used),
                        size_str(self._max_cache_size_bytes),
                        size_str(serialized_length),
                        DependencyManager.MAX_SERIALIZED_LEN,
                    )
                    failed_deps = {
                        dep_key: dep_state
                        for dep_key, dep_state in self._dependencies.items()
                        if dep_state.stage == DependencyStage.FAILED
                    }
                    # Ready dependencies are only candidates when no run depends on them
                    ready_deps = {
                        dep_key: dep_state
                        for dep_key, dep_state in self._dependencies.items()
                        if dep_state.stage == DependencyStage.READY and not dep_state.dependents
                    }
                    # Failed dependencies go first, then ready ones
                    candidates = failed_deps or ready_deps
                    if not candidates:
                        logger.info(
                            'Dependency quota full but there are only downloading dependencies, not cleaning up until downloads are over'
                        )
                        break
                    dep_key_to_remove = min(
                        candidates.items(), key=lambda dep: dep[1].last_used
                    )[0]
                    self._delete_dependency(dep_key_to_remove)
                finally:
                    # Runs on break, on normal completion and on exceptions alike,
                    # so the dependency locks can never be left held
                    self._release_all_locks()

    def _delete_dependency(self, dependency_key):
        """
        Remove the given dependency from the manager's state
        Also delete any known files on the filesystem if any exist

        File removal is best effort: a failure is logged and the dependency is
        dropped from the state regardless.
        """
        if not self._acquire_if_exists(dependency_key):
            return
        try:
            path_to_remove = self._dependencies[dependency_key].path
            with self._paths_lock:
                # discard rather than remove: a failed download has already dropped its path
                # from self._paths, and that must not prevent the file removal below
                self._paths.discard(path_to_remove)
            # Dependency paths are names relative to self.dependencies_dir
            remove_path(os.path.join(self.dependencies_dir, path_to_remove))
        except Exception:
            logger.exception('Failed to remove the files of dependency %s', dependency_key)
        finally:
            del self._dependencies[dependency_key]
            self._dependency_locks[dependency_key].release()

    def has(self, dependency_key):
        """
        Takes a DependencyKey
        Returns true if the manager has processed this dependency
        """
        with self._global_lock:
            return dependency_key in self._dependencies

    def get(self, uuid, dependency_key):
        """
        Request the dependency for the run with uuid, registering uuid as a dependent of this dependency

        A dependency that is not known yet is created in the DOWNLOADING stage. For a dependency
        in the FAILED stage nothing is updated and the failed state is returned as is.

        :param uuid: uuid of the run that needs the dependency
        :param dependency_key: DependencyKey identifying the dependency
        :return: the DependencyState of the dependency after the request was registered
        """
        now = time.time()
        # The global lock is re-entrant, so the existence check and the creation can share it;
        # this way two callers can never both create the same dependency
        with self._global_lock:
            if not self._acquire_if_exists(dependency_key):
                new_state = DependencyState(
                    stage=DependencyStage.DOWNLOADING,
                    dependency_key=dependency_key,
                    path=self._assign_path(dependency_key),
                    size_bytes=0,
                    dependents={uuid},
                    last_used=now,
                    message="Starting download",
                    killed=False,
                )
                new_lock = threading.RLock()
                new_lock.acquire()
                self._dependency_locks[dependency_key] = new_lock
                self._dependencies[dependency_key] = new_state

        # From here on the dependency lock is held, either by _acquire_if_exists or by the creation above
        try:
            dep_state = self._dependencies[dependency_key]
            # update last_used as long as it isn't in FAILED
            if dep_state.stage != DependencyStage.FAILED:
                dep_state.dependents.add(uuid)
                dep_state = dep_state._replace(last_used=now)
                self._dependencies[dependency_key] = dep_state
            # Returning the state read under the lock avoids a lookup racing with a deletion
            return dep_state
        finally:
            self._dependency_locks[dependency_key].release()

    def release(self, uuid, dependency_key):
        """
        Register that the run with uuid is no longer dependent on this dependency
        If no more runs are dependent on this dependency, kill it
        """
        if self._acquire_if_exists(dependency_key):
            dep_state = self._dependencies[dependency_key]
            if uuid in dep_state.dependents:
                dep_state.dependents.remove(uuid)
            if not dep_state.dependents:
                dep_state = dep_state._replace(killed=True)
                self._dependencies[dependency_key] = dep_state
            self._dependency_locks[dependency_key].release()

    def _acquire_if_exists(self, dependency_key):
        """
        Safely acquires a lock for the given dependency if it exists
        Returns True if depedendency exists, False otherwise
        Callers should remember to release the lock
        """
        with self._global_lock:
            if dependency_key in self._dependencies:
                self._dependency_locks[dependency_key].acquire()
                return True
            else:
                return False

    def _acquire_all_locks(self):
        """
        Acquires all dependency locks in the thread it's called from
        """
        with self._global_lock:
            for dependency, lock in self._dependency_locks.items():
                lock.acquire()

    def _release_all_locks(self):
        """
        Releases all dependency locks in the thread it's called from
        """
        with self._global_lock:
            for dependency, lock in self._dependency_locks.items():
                lock.release()

    def _assign_path(self, dependency_key):
        """
        Normalize the path for the dependency by replacing / with _, avoiding conflicts
        """
        if dependency_key.parent_path:
            path = os.path.join(dependency_key.parent_uuid,
                                dependency_key.parent_path)
        else:
            path = dependency_key.parent_uuid
        path = path.replace(os.path.sep, '_')

        # You could have a conflict between, for example a/b_c and
        # a_b/c. We have to avoid those.
        with self._paths_lock:
            while path in self._paths:
                path = path + '_'
            self._paths.add(path)
        return path

    def _store_dependency(self, dependency_path, fileobj, target_type):
        """
        Write the contents of fileobj to dependency_path on the local filesystem.
        Whatever already exists at that path is removed first (this may happen if the
        filesystem was modified outside the dependency manager, for example during an
        update where the state gets reset but the filesystem doesn't get cleared).
        A 'directory' target is unpacked with un_tar_directory; anything else is
        copied as a single file. Errors propagate to the caller.
        """
        if os.path.exists(dependency_path):
            logger.info('Path %s already exists, overwriting', dependency_path)
            delete = shutil.rmtree if os.path.isdir(dependency_path) else os.remove
            delete(dependency_path)

        if target_type == 'directory':
            un_tar_directory(fileobj, dependency_path, 'gz')
            return

        with open(dependency_path, 'wb') as destination:
            logger.debug('copying file to %s', dependency_path)
            shutil.copyfileobj(fileobj, destination)

    @property
    def all_dependencies(self):
        with self._global_lock:
            return list(self._dependencies.keys())

    def _transition_from_DOWNLOADING(self, dependency_state):
        """
        Transition function for dependencies in the DOWNLOADING stage.

        Registers a download thread for the dependency in self._downloading (add_if_new
        presumably registers and starts it only when none exists for this key -- TODO confirm
        against ThreadDict). While that thread is alive the state is returned unchanged.
        Once it has finished, its 'success' and 'failure_message' fields are read, the entry
        is removed from self._downloading and the state moves to READY, or to FAILED after
        the dependency's path has been dropped from self._paths.

        :param dependency_state: current DependencyState of the dependency
        :return: the DependencyState to store for the dependency
        """
        def download():
            """
            Thread body: fetch the bundle contents, store them under the dependency path and
            record the outcome in the 'success' / 'failure_message' fields of self._downloading.
            """
            def update_state_and_check_killed(bytes_downloaded):
                """
                Progress callback: records the number of bytes downloaded so far in the
                dependency state and raises DownloadAbortedException if the dependency
                has been marked as killed
                """
                with self._dependency_locks[dependency_state.dependency_key]:
                    # Read the current state rather than the captured dependency_state,
                    # which dates from the moment the transition was called
                    state = self._dependencies[dependency_state.dependency_key]
                    if state.killed:
                        raise DownloadAbortedException("Aborted by user")
                    self._dependencies[
                        dependency_state.dependency_key] = state._replace(
                            size_bytes=bytes_downloaded,
                            message="Downloading dependency: %s downloaded" %
                            size_str(bytes_downloaded),
                        )

            dependency_path = os.path.join(self.dependencies_dir,
                                           dependency_state.path)
            logger.debug('Downloading dependency %s',
                         dependency_state.dependency_key)
            try:
                # Start async download to the fileobj
                fileobj, target_type = self._bundle_service.get_bundle_contents(
                    dependency_state.dependency_key.parent_uuid,
                    dependency_state.dependency_key.parent_path,
                )
                with closing(fileobj):
                    # "Bug" the fileobj's read function so that we can keep
                    # track of the number of bytes downloaded so far.
                    old_read_method = fileobj.read
                    # One-element list so that the nested function can update the counter
                    bytes_downloaded = [0]

                    def interruptable_read(*args, **kwargs):
                        data = old_read_method(*args, **kwargs)
                        bytes_downloaded[0] += len(data)
                        # Raises DownloadAbortedException when the dependency was killed,
                        # which aborts the copy below
                        update_state_and_check_killed(bytes_downloaded[0])
                        return data

                    fileobj.read = interruptable_read

                    # Start copying the fileobj to filesystem dependency path
                    self._store_dependency(dependency_path, fileobj,
                                           target_type)

                logger.debug(
                    'Finished downloading %s dependency %s to %s',
                    target_type,
                    dependency_state.dependency_key,
                    dependency_path,
                )
                with self._dependency_locks[dependency_state.dependency_key]:
                    self._downloading[
                        dependency_state.dependency_key]['success'] = True

            except Exception as e:
                # Any failure, including an abort raised from the read hook, is recorded
                # here instead of propagating out of the thread
                with self._dependency_locks[dependency_state.dependency_key]:
                    self._downloading[
                        dependency_state.dependency_key]['success'] = False
                    self._downloading[dependency_state.dependency_key][
                        'failure_message'] = "Dependency download failed: %s " % str(
                            e)

        self._downloading.add_if_new(
            dependency_state.dependency_key,
            threading.Thread(target=download, args=[]))

        # Download still in progress: stay in the DOWNLOADING stage
        if self._downloading[dependency_state.dependency_key].is_alive():
            return dependency_state

        # The thread has finished: collect its outcome before dropping the entry
        success = self._downloading[dependency_state.dependency_key]['success']
        failure_message = self._downloading[
            dependency_state.dependency_key]['failure_message']

        self._downloading.remove(dependency_state.dependency_key)
        if success:
            return dependency_state._replace(stage=DependencyStage.READY,
                                             message="Download complete")
        else:
            # A failed dependency gives up its reserved path name
            with self._paths_lock:
                self._paths.remove(dependency_state.path)
            return dependency_state._replace(stage=DependencyStage.FAILED,
                                             message=failure_message)
class DependencyManager(StateTransitioner, BaseDependencyManager):
    """
    This dependency manager downloads dependency bundles from the CodaLab server
    to the local filesystem. It caches all downloaded dependencies but cleans up the
    old ones if the disk use hits the given threshold. It's also NFS-safe.
    In this class, dependencies are uniquely identified by DependencyKey.
    """

    # Name of the sub-directory, created under the worker directory, that holds the downloaded dependencies.
    DEPENDENCIES_DIR_NAME = 'dependencies'
    # Number of seconds (measured from last_used) a FAILED dependency is kept before it is pruned,
    # so that later runs can retry the download.
    DEPENDENCY_FAILURE_COOLDOWN = 10
    # TODO(bkgoksel): The server writes these to the worker_dependencies table, which stores the dependencies
    # json as a SqlAlchemy LargeBinary, which defaults to MySQL BLOB, which has a size limit of
    # 65K. For now we limit this value to about 58K to avoid any issues but we probably want to do
    # something better (either specify MEDIUMBLOB in the SqlAlchemy definition of the table or change
    # the data format of how we store this)
    MAX_SERIALIZED_LEN = 60000

    # If it has been this long since a worker has downloaded anything, another worker will take over downloading.
    DEPENDENCY_DOWNLOAD_TIMEOUT_SECONDS = 5 * 60

    def __init__(
        self,
        commit_file: str,
        bundle_service: BundleServiceClient,
        worker_dir: str,
        max_cache_size_bytes: int,
        download_dependencies_max_retries: int,
    ):
        """
        Set up the dependency manager and reconcile the persisted state with the file system.

        commit_file: path of the JSON state file, handed to JsonStateCommitter.
        bundle_service: client used to look up bundle info and fetch bundle contents.
        worker_dir: directory under which the dependencies directory and the lock-claims
                    directory are created.
        max_cache_size_bytes: disk budget for the cached dependencies (enforced by _cleanup).
        download_dependencies_max_retries: upper bound on the download attempts per dependency.
        """
        super(DependencyManager, self).__init__()
        self.add_transition(DependencyStage.DOWNLOADING, self._transition_from_DOWNLOADING)
        self.add_terminal(DependencyStage.READY)
        self.add_terminal(DependencyStage.FAILED)

        self._id: str = "worker-dependency-manager-{}".format(uuid.uuid4().hex[:8])
        self._state_committer = JsonStateCommitter(commit_file)
        self._bundle_service = bundle_service
        self._max_cache_size_bytes = max_cache_size_bytes
        self.dependencies_dir = os.path.join(worker_dir, DependencyManager.DEPENDENCIES_DIR_NAME)
        self._download_dependencies_max_retries = download_dependencies_max_retries
        if not os.path.exists(self.dependencies_dir):
            logger.info('{} doesn\'t exist, creating.'.format(self.dependencies_dir))
            # exist_ok=True: another dependency manager sharing this directory may create it
            # between the existence check above and this call.
            os.makedirs(self.dependencies_dir, 0o770, exist_ok=True)

        # Create a lock for concurrency over NFS
        # Create a separate locks directory to hold the lock files.
        # Each lock file is created when a process tries to claim the main lock.
        locks_claims_dir: str = os.path.join(worker_dir, 'locks_claims')
        try:
            os.makedirs(locks_claims_dir)
        except FileExistsError:
            logger.info(f"A locks directory at {locks_claims_dir} already exists.")
        self._state_lock = NFSLock(os.path.join(locks_claims_dir, 'state.lock'))

        # File paths that are currently being used to store dependencies. Used to prevent conflicts
        self._paths: Set[str] = set()
        # DependencyKey -> WorkerThread(thread, success, failure_message)
        self._downloading = ThreadDict(fields={'success': False, 'failure_message': None})
        # Sync states between dependency-state.json and dependency directories on the local file system.
        self._sync_state()

        self._stop = False
        self._main_thread = None
        logger.info(f"Initialized Dependency Manager with ID: {self._id}")

    def _sync_state(self):
        """
        Reconcile the persisted dependency state with what is actually on disk.

        Two sources are compared while holding the state lock:
        1. dependencies and paths, as loaded from the state file
        2. the directory entries found under self.dependencies_dir
        Only entries present in both survive: dependencies whose path is missing are dropped from
        the state, directory entries unknown to the state are removed from disk, and the
        resulting state is written back to the state file.
        """
        with self._state_lock:
            if not self._state_committer.state_file_exists:
                dependencies: Dict[DependencyKey, DependencyState] = dict()
                paths: Set[str] = set()
                logger.info(
                    f'State file did not exist. Will create one at path {self._state_committer.path}.'
                )
            else:
                # No default is passed on purpose: this method prunes dependencies, so if the
                # existing state file cannot be read we must fail right away instead of
                # continuing with an empty state.
                dependencies, paths = self._fetch_state()
                logger.info(
                    'Found {} dependencies, {} paths from cache.'.format(
                        len(dependencies), len(paths)
                    )
                )

            # Keep only the paths known to all three sources: the loaded path set, the loaded
            # dependency states and the entries on the local file system.
            on_disk = set(os.listdir(self.dependencies_dir))
            paths_of_states = [state.path for state in dependencies.values()]
            paths = paths.intersection(paths_of_states, on_disk)

            # Drop every dependency whose path did not survive the intersection.
            orphaned_keys = [key for key, state in dependencies.items() if state.path not in paths]
            for key in orphaned_keys:
                missing_path = os.path.join(self.dependencies_dir, dependencies[key].path)
                logger.info(
                    "Dependency {} in dependency state but its path {} doesn't exist on the local file system. "
                    "Removing it from dependency state.".format(key, missing_path)
                )
                del dependencies[key]

            # Delete the directory entries that no surviving dependency refers to.
            for entry in on_disk - paths:
                full_path = os.path.join(self.dependencies_dir, entry)
                if not os.path.exists(full_path):
                    continue
                logger.info(
                    "Remove orphaned directory {} from the local file system.".format(full_path)
                )
                remove_path(full_path)

            # Persist the reconciled state, since it may differ from what was loaded.
            self._commit_state(dependencies, paths)

    def _fetch_state(self, default=None):
        """
        Load the manager state from the JSON state file and return it as a
        (dependencies, paths) pair.
        NOT NFS-SAFE - self._state_lock must already be held by the caller.
        WARNING: When a value is given for `default`, errors are handled silently.
        """
        assert self._state_lock.is_locked
        loaded: DependencyManagerState = self._state_committer.load(default)
        return loaded['dependencies'], loaded['paths']

    def _fetch_dependencies(self, default=None) -> Dict[DependencyKey, DependencyState]:
        """
        Load the state from the JSON state file and return only its dependencies mapping.
        NOT NFS-SAFE - self._state_lock must already be held by the caller.
        WARNING: When a value is given for `default`, errors are handled silently.
        """
        assert self._state_lock.is_locked
        return self._fetch_state(default)[0]

    def _commit_state(self, dependencies: Dict[DependencyKey, DependencyState], paths: Set[str]):
        """
        Write the given dependencies and paths to the JSON state file through the state committer.
        NOT NFS-SAFE - self._state_lock must already be held by the caller.
        """
        assert self._state_lock.is_locked
        new_state: DependencyManagerState = dict(dependencies=dependencies, paths=paths)
        self._state_committer.commit(new_state)

    def start(self):
        """
        Launch the background thread that, until the stop flag is set, transitions the
        dependencies and cleans up the cache, pausing one second between passes.
        """
        logger.info('Starting local dependency manager...')

        def run_until_stopped():
            while True:
                if self._stop:
                    return
                try:
                    self._transition_dependencies()
                    self._cleanup()
                except Exception:
                    # Keep the loop alive: print the traceback and try again on the next pass.
                    traceback.print_exc()
                time.sleep(1)

        worker_thread = threading.Thread(target=run_until_stopped)
        self._main_thread = worker_thread
        worker_thread.start()

    def stop(self):
        """
        Stop the manager: set the stop flag, stop the download threads, wait for the main loop
        thread (if one was started) and release the state lock.
        """
        logger.info('Stopping local dependency manager...')
        self._stop = True
        self._downloading.stop()
        # start() may never have been called, in which case there is no loop thread to wait for.
        if self._main_thread is not None:
            self._main_thread.join()
        # NOTE(review): released unconditionally, as before; presumably a safeguard against
        # leaving the lock claimed on shutdown - confirm how NFSLock.release behaves when unheld.
        self._state_lock.release()
        logger.info('Stopped local dependency manager.')

    def _transition_dependencies(self):
        """
        Run one state-machine transition for every dependency in the state file and commit the
        result, all while holding the state lock.
        A ValueError or EnvironmentError raised along the way is logged and the pass is skipped.
        """
        with self._state_lock:
            try:
                dependencies, paths = self._fetch_state()

                # Expose the loaded paths on the instance, as the transition function may update them
                self._paths = paths
                # Only values of existing keys are replaced, so iterating while assigning is safe.
                for dep_key, dep_state in dependencies.items():
                    dependencies[dep_key] = self.transition(dep_state)
                self._commit_state(dependencies, self._paths)
            except (ValueError, EnvironmentError):
                # Do nothing (besides logging) if an error is thrown while reading from the state file.
                # Logged through the module logger, like the rest of this class.
                logger.exception("Error reading from state file while transitioning dependencies")

    def _prune_failed_dependencies(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds so that further runs
        get to retry the download. Without pruning, any future run depending on a
        failed dependency would automatically fail indefinitely.
        The state file is only rewritten when at least one dependency was pruned.
        """
        with self._state_lock:
            try:
                dependencies, paths = self._fetch_state()
                # Take the timestamp once so every dependency is judged against the same instant.
                now = time.time()
                expired_keys: List[DependencyKey] = [
                    dep_key
                    for dep_key, dep_state in dependencies.items()
                    if dep_state.stage == DependencyStage.FAILED
                    and now - dep_state.last_used > DependencyManager.DEPENDENCY_FAILURE_COOLDOWN
                ]
                if not expired_keys:
                    return

                for dep_key in expired_keys:
                    self._delete_dependency(dep_key, dependencies, paths)
                self._commit_state(dependencies, paths)
            except (ValueError, EnvironmentError):
                # Do nothing (besides logging) if an error is thrown while reading from the state file.
                # Logged through the module logger, like the rest of this class.
                logger.exception(
                    "Error reading from state file while pruning failed dependencies."
                )

    def _cleanup(self):
        """
        Prune failed dependencies older than DEPENDENCY_FAILURE_COOLDOWN seconds.
        Limit the disk usage of the dependencies (both the bundle files and the serialized state file size)
        Deletes oldest failed dependencies first and then oldest finished dependencies.
        Doesn't touch downloading dependencies.

        One dependency is deleted per iteration, and the state is re-read each time, until both
        limits are met or nothing deletable is left. If the state file cannot be read, the pass
        is abandoned.
        """
        self._prune_failed_dependencies()

        while True:
            with self._state_lock:
                try:
                    dependencies, paths = self._fetch_state()
                except (ValueError, EnvironmentError):
                    logger.exception(
                        "Error reading from state file when cleaning up dependencies. Trying again..."
                    )
                    # Give up for this pass rather than retrying immediately: an unreadable state
                    # file would otherwise make this loop spin without pause. The main loop in
                    # start() calls _cleanup again on its next pass.
                    return

                bytes_used = sum(dep_state.size_bytes for dep_state in dependencies.values())
                serialized_length = len(codalab.worker.pyjson.dumps(dependencies))
                if (
                    bytes_used <= self._max_cache_size_bytes
                    and serialized_length <= DependencyManager.MAX_SERIALIZED_LEN
                ):
                    # Both limits are respected: nothing (more) to clean up.
                    return

                logger.debug(
                    '%d dependencies, disk usage: %s (max %s), serialized size: %s (max %s)',
                    len(dependencies),
                    size_str(bytes_used),
                    size_str(self._max_cache_size_bytes),
                    size_str(serialized_length),
                    DependencyManager.MAX_SERIALIZED_LEN,
                )

                # Failed dependencies are evicted first; READY ones only if no run depends on them.
                candidates = {
                    dep_key: dep_state
                    for dep_key, dep_state in dependencies.items()
                    if dep_state.stage == DependencyStage.FAILED
                }
                if not candidates:
                    candidates = {
                        dep_key: dep_state
                        for dep_key, dep_state in dependencies.items()
                        if dep_state.stage == DependencyStage.READY and not dep_state.dependents
                    }
                if not candidates:
                    logger.info(
                        'Dependency quota full but there are only downloading dependencies, not cleaning up '
                        'until downloads are over.'
                    )
                    return

                # Evict the least recently used candidate, then loop to re-check the limits.
                dep_key_to_remove = min(candidates.items(), key=lambda dep: dep[1].last_used)[0]
                self._delete_dependency(dep_key_to_remove, dependencies, paths)
                self._commit_state(dependencies, paths)

    def _delete_dependency(self, dep_key, dependencies, paths):
        """
        Remove the given dependency from the manager's state
        Modifies `dependencies` and `paths` that are passed in.
        Also deletes the dependency's content under self.dependencies_dir if it exists.

        Removing the content is best effort: a failure is logged and the dependency is still
        dropped from `dependencies`. Does nothing if `dep_key` is not in `dependencies`.

        NOT NFS-SAFE - Caller should acquire self._state_lock before calling this method.
        """
        assert self._state_lock.is_locked

        if dep_key not in dependencies:
            return

        try:
            path_to_remove = dependencies[dep_key].path
            # discard, not remove: the path may already be gone from `paths` (the failed-download
            # transition drops it), and that must not prevent the files from being deleted.
            paths.discard(path_to_remove)
            # Dependency paths are names relative to the dependencies directory, so the location
            # on disk has to be resolved against it before deleting.
            full_path = os.path.join(self.dependencies_dir, path_to_remove)
            if os.path.exists(full_path):
                # Deletes dependency content from disk
                remove_path(full_path)
        except Exception:
            logger.exception(f"Failed to remove the content of dependency {dep_key} from disk.")
        finally:
            del dependencies[dep_key]
            logger.info(f"Deleted dependency {dep_key}.")

    def has(self, dependency_key):
        """
        Return True if the given DependencyKey is present in the persisted dependency state,
        i.e. the manager has processed this dependency.
        """
        with self._state_lock:
            return dependency_key in self._fetch_dependencies()

    def get(self, uuid: str, dependency_key: DependencyKey) -> DependencyState:
        """
        Request the dependency on behalf of the run with the given uuid and return its state.

        An unknown dependency is first recorded in the DOWNLOADING stage with a freshly assigned
        path. Unless the dependency is FAILED, the run is registered as a dependent and
        last_used is refreshed. The state is committed in every case.
        """
        with self._state_lock:
            dependencies, paths = self._fetch_state()
            timestamp = time.time()

            if dependency_key not in dependencies:
                # First request for this dependency: create its initial state.
                dependencies[dependency_key] = DependencyState(
                    stage=DependencyStage.DOWNLOADING,
                    downloading_by=None,
                    dependency_key=dependency_key,
                    path=self._assign_path(paths, dependency_key),
                    size_bytes=0,
                    dependents={uuid},
                    last_used=timestamp,
                    last_downloading=timestamp,
                    message="Starting download",
                    killed=False,
                )

            current = dependencies[dependency_key]
            if current.stage != DependencyStage.FAILED:
                # Failed dependencies keep their last_used and gain no dependents.
                current.dependents.add(uuid)
                current = current._replace(last_used=timestamp)
                dependencies[dependency_key] = current

            self._commit_state(dependencies, paths)
            return current

    def release(self, uuid, dependency_key):
        """
        Unregister the run with the given uuid as a dependent of this dependency.
        When no dependents remain, the dependency is flagged as killed.
        Nothing is written if the dependency is unknown.
        """
        with self._state_lock:
            dependencies, paths = self._fetch_state()

            if dependency_key not in dependencies:
                return

            state = dependencies[dependency_key]
            if uuid in state.dependents:
                state.dependents.remove(uuid)
            if not state.dependents:
                dependencies[dependency_key] = state._replace(killed=True)
            self._commit_state(dependencies, paths)

    def _assign_path(self, paths: Set[str], dependency_key: DependencyKey) -> str:
        """
        Derive a unique, flat path name for the dependency and reserve it in `paths`.

        The name is the parent uuid, joined with the parent path when there is one, with every
        path separator replaced by '_'. Underscores are appended until the name does not clash
        with an entry of `paths`; the final name is added to `paths` and returned.
        """
        if dependency_key.parent_path:
            raw_name: str = os.path.join(dependency_key.parent_uuid, dependency_key.parent_path)
        else:
            raw_name = dependency_key.parent_uuid
        candidate = raw_name.replace(os.path.sep, '_')

        # Flattening can make distinct keys collide, e.g. a/b_c and a_b/c
        while candidate in paths:
            candidate += '_'

        paths.add(candidate)
        return candidate

    def _store_dependency(self, dependency_path, fileobj, target_type):
        """
        Write the dependency read from fileobj to dependency_path on the local filesystem.

        Anything already present at dependency_path is deleted first (this may happen if the
        filesystem was modified outside the dependency manager, for example during an update
        where the state gets reset but the filesystem doesn't get cleared).
        A 'directory' target is unpacked with un_tar_directory; any other target type is copied
        as a single file. Exceptions propagate to the caller.
        """
        if os.path.exists(dependency_path):
            logger.info('Path %s already exists, overwriting', dependency_path)
            delete = shutil.rmtree if os.path.isdir(dependency_path) else os.remove
            delete(dependency_path)

        if target_type == 'directory':
            un_tar_directory(fileobj, dependency_path, 'gz')
            return

        with open(dependency_path, 'wb') as destination:
            logger.debug('copying file to %s', dependency_path)
            shutil.copyfileobj(fileobj, destination)

    @property
    def all_dependencies(self) -> List[DependencyKey]:
        """
        List the keys of all dependencies in the persisted state.
        An empty state is passed as the default, so errors while loading are handled silently
        (see the warning on _fetch_state).
        """
        empty_state = {'dependencies': {}, 'paths': set()}
        with self._state_lock:
            loaded: Dict[DependencyKey, DependencyState] = self._fetch_dependencies(
                default=empty_state
            )
            return list(loaded.keys())

    def _transition_from_DOWNLOADING(self, dependency_state: DependencyState):
        """
        Advance a dependency that is in the DOWNLOADING stage and return its updated state.

        - If no dependency manager is downloading it, or the one that was has not made progress
          for DEPENDENCY_DOWNLOAD_TIMEOUT_SECONDS, this manager claims it and registers a
          download thread.
        - If another manager is downloading it, the state is returned unchanged.
        - While this manager's thread is alive, the progress fields tracked in memory are copied
          into the returned state.
        - Once the thread has finished, the dependency moves to READY or FAILED.

        NOT NFS-SAFE - Caller should acquire self._state_lock before calling this method.
        """
        assert self._state_lock.is_locked

        dep_key = dependency_state.dependency_key

        def download():
            """
            Runs in a separate thread. Only one worker should be running this in a thread at a time.
            """

            def update_state_and_check_killed(bytes_downloaded):
                """
                Callback method for bundle service client updates dependency state and
                raises DownloadAbortedException if download is killed by dep. manager

                Note: This function needs to be fast, since it's called every time fileobj.read is called.
                      Therefore, we keep a copy of the state in memory (self._downloading) and copy over
                      non-critical fields (last_downloading, size_bytes and message) when the download transition
                      function is executed.
                """
                state = self._downloading[dep_key]['state']
                if state.killed:
                    raise DownloadAbortedException("Aborted by user")
                self._downloading[dep_key]['state'] = state._replace(
                    last_downloading=time.time(),
                    size_bytes=bytes_downloaded,
                    message=f"Downloading dependency: {str(bytes_downloaded)} downloaded",
                )

            dependency_path = os.path.join(self.dependencies_dir, dependency_state.path)
            logger.debug('Downloading dependency %s', dep_key)

            attempt = 0
            while attempt < self._download_dependencies_max_retries:
                try:
                    # Start async download to the fileobj
                    target_type = self._bundle_service.get_bundle_info(
                        dep_key.parent_uuid, dep_key.parent_path,
                    )["type"]
                    fileobj = self._bundle_service.get_bundle_contents(
                        dep_key.parent_uuid, dep_key.parent_path,
                    )
                    with closing(fileobj):
                        # "Bug" the fileobj's read function so that we can keep
                        # track of the number of bytes downloaded so far.
                        original_read_method = fileobj.read
                        bytes_downloaded = [0]

                        def interruptable_read(*args, **kwargs):
                            data = original_read_method(*args, **kwargs)
                            bytes_downloaded[0] += len(data)
                            update_state_and_check_killed(bytes_downloaded[0])
                            return data

                        fileobj.read = interruptable_read

                        # Start copying the fileobj to filesystem dependency path
                        # Note: Overwrites if something already exists at dependency_path, such as when
                        #       another worker partially downloads a dependency and then goes offline.
                        self._store_dependency(dependency_path, fileobj, target_type)

                    logger.debug(
                        'Finished downloading %s dependency %s to %s',
                        target_type,
                        dep_key,
                        dependency_path,
                    )
                    self._downloading[dep_key]['success'] = True

                except DownloadAbortedException as e:
                    # The download was aborted on purpose. A retry would hit the same killed flag
                    # on its first read, so record the failure and stop right away.
                    self._downloading[dep_key]['success'] = False
                    self._downloading[dep_key][
                        'failure_message'
                    ] = f"Dependency download failed: {e} "
                    break
                except Exception as e:
                    attempt += 1
                    if attempt >= self._download_dependencies_max_retries:
                        self._downloading[dep_key]['success'] = False
                        self._downloading[dep_key][
                            'failure_message'
                        ] = f"Dependency download failed: {e} "
                    else:
                        logger.warning(
                            f'Failed to download {dep_key} after {attempt} attempt(s) '
                            f'due to {e}. Retrying up to {self._download_dependencies_max_retries} times...',
                            exc_info=True,
                        )
                else:
                    # Break out of the retry loop if no exceptions were thrown
                    break

        # Start downloading if either:
        # 1. No other dependency manager is downloading the dependency
        # 2. There was a dependency manager downloading a dependency, but it has been longer than
        #    DEPENDENCY_DOWNLOAD_TIMEOUT_SECONDS since it last downloaded anything for the particular dependency.
        now = time.time()
        if (
            not dependency_state.downloading_by
            or now - dependency_state.last_downloading
            >= DependencyManager.DEPENDENCY_DOWNLOAD_TIMEOUT_SECONDS
        ):
            if not dependency_state.downloading_by:
                logger.info(f"{self._id} will start downloading dependency: {dep_key}.")
            else:
                logger.info(
                    f"{dependency_state.downloading_by} stopped downloading "
                    f"dependency: {dep_key}. {self._id} will restart downloading."
                )

            self._downloading.add_if_new(dep_key, threading.Thread(target=download, args=[]))
            # In-memory copy of the state that the download thread reads and updates.
            self._downloading[dep_key]['state'] = dependency_state
            dependency_state = dependency_state._replace(downloading_by=self._id)

        # If there is already another worker downloading the dependency,
        # just return the dependency state as downloading is in progress.
        if dependency_state.downloading_by != self._id:
            logger.debug(
                f"Waiting for {dependency_state.downloading_by} "
                f"to download dependency: {dep_key}"
            )
            return dependency_state

        if dep_key in self._downloading and self._downloading[dep_key].is_alive():
            logger.debug(
                f"This dependency manager ({dependency_state.downloading_by}) "
                f"is downloading dependency: {dep_key}"
            )
            state = self._downloading[dep_key]['state']
            # Copy over the values of the non-critical fields of the state in memory
            # that is being updated by the download thread.
            return dependency_state._replace(
                last_downloading=state.last_downloading,
                size_bytes=state.size_bytes,
                message=state.message,
            )

        # At this point, no thread is downloading the dependency, but the dependency is still
        # assigned to the current worker. Check if the download finished.
        success: bool = self._downloading[dep_key]['success']
        failure_message: str = self._downloading[dep_key]['failure_message']

        dependency_state = dependency_state._replace(downloading_by=None)
        self._downloading.remove(dep_key)
        logger.info(f"Download complete. Removing downloading thread for {dep_key}.")

        if success:
            return dependency_state._replace(
                stage=DependencyStage.READY, message="Download complete"
            )

        # discard, not remove: the thread record is already gone at this point, so a KeyError for
        # a path missing from self._paths would abort the transition before the FAILED stage is
        # recorded.
        self._paths.discard(dependency_state.path)
        logger.error(f"Dependency {dep_key} download failed: {failure_message}")
        return dependency_state._replace(stage=DependencyStage.FAILED, message=failure_message)
class LocalRunManager(BaseRunManager):
    """
    LocalRunManager executes the runs locally, each one in its own Docker
    container. It manages its cache of local Docker images and its own local
    Docker network.
    """

    # Network buffer size to use while proxying with netcat
    NETCAT_BUFFER_SIZE = 4096
    # Number of seconds to wait for bundle kills to propagate before forcing kill
    KILL_TIMEOUT = 100
    # Directory name to store running bundles in worker filesystem
    BUNDLES_DIR_NAME = 'runs'

    def __init__(
        self,
        worker,  # type: Worker
        image_manager,  # type: DockerImageManager
        dependency_manager,  # type: LocalFileSystemDependencyManager
        commit_file,  # type: str
        cpuset,  # type: Set[str]
        gpuset,  # type: Set[str]
        work_dir,  # type: str
        docker_runtime=docker_utils.DEFAULT_RUNTIME,  # type: str
        docker_network_prefix='codalab_worker_network',  # type: str
    ):
        """
        Build a run manager that stores its bundles under `work_dir`, commits its
        run states to `commit_file` and uses a Docker client created from the environment.
        """
        # Collaborators and resource pools handed in by the caller
        self._worker = worker
        self._image_manager = image_manager
        self._dependency_manager = dependency_manager
        self._cpuset = cpuset
        self._gpuset = gpuset
        self._work_dir = work_dir

        # Book-keeping for the runs owned by this manager
        self._runs = {}  # bundle_uuid -> LocalRunState
        self._lock = threading.RLock()
        self._stop = False

        self._state_committer = JsonStateCommitter(commit_file)
        self._reader = LocalReader()
        self._docker = docker.from_env()

        # Make sure the directory that holds the running bundles exists
        bundles_dir = os.path.join(work_dir, LocalRunManager.BUNDLES_DIR_NAME)
        self._bundles_dir = bundles_dir
        if not os.path.exists(bundles_dir):
            logger.info("{} doesn't exist, creating.".format(bundles_dir))
            os.makedirs(bundles_dir, 0o770)

        # The networks must be set up first: the state machine below receives them
        self._init_docker_networks(docker_network_prefix)
        self._run_state_manager = LocalRunStateMachine(
            docker_image_manager=image_manager,
            dependency_manager=dependency_manager,
            worker_docker_network=self.worker_docker_network,
            docker_network_internal=self.docker_network_internal,
            docker_network_external=self.docker_network_external,
            docker_runtime=docker_runtime,
            upload_bundle_callback=worker.upload_bundle_contents,
            assign_cpu_and_gpu_sets_fn=self.assign_cpu_and_gpu_sets,
        )

    def _init_docker_networks(self, docker_network_prefix):
        """
        Set up docker networks for runs: one with external network access and one without

        Sets `worker_docker_network`, `docker_network_external` and
        `docker_network_internal`, creating each network or reusing an existing
        network that has exactly the requested name.
        """

        def create_or_get_network(name, internal):
            try:
                logger.debug('Creating docker network %s', name)
                return self._docker.networks.create(name, internal=internal, check_duplicate=True)
            except docker.errors.APIError:
                # The name filter of the network listing matches partial names, so a
                # lookup of `<prefix>` can also return `<prefix>_ext` / `<prefix>_int`.
                # Only an exact name match is the network we failed to create.
                for network in self._docker.networks.list(names=[name]):
                    if network.name == name:
                        logger.debug('Network %s already exists, reusing', name)
                        return network
                # No network with that name exists, so creation failed for another
                # reason; surface the original API error instead of guessing.
                raise

        self.worker_docker_network = create_or_get_network(docker_network_prefix, True)
        self.docker_network_external = create_or_get_network(docker_network_prefix + "_ext", False)
        self.docker_network_internal = create_or_get_network(docker_network_prefix + "_int", True)

    def save_state(self):
        """Commit the current runs through the state committer."""
        # Container objects are stripped before serializing; load_state looks them
        # up again from the stored container_id.
        serializable_runs = {}
        for bundle_uuid, run_state in self._runs.items():
            serializable_runs[bundle_uuid] = run_state._replace(container=None)
        self._state_committer.commit(serializable_runs)

    def load_state(self):
        """
        Restore the runs persisted by the state committer into `self._runs`.

        Runs that recorded a container id get their container object looked up
        again through the Docker API; if that container no longer exists, the
        stale id is cleared. Runs without a container id are restored unchanged.
        """
        runs = self._state_committer.load()
        for uuid, run_state in runs.items():
            if run_state.container_id:
                # Retrieve the complex container objects from the Docker API
                try:
                    run_state = run_state._replace(
                        container=self._docker.containers.get(run_state.container_id)
                    )
                except docker.errors.NotFound as ex:
                    logger.debug('Error getting the container for the run: %s', ex)
                    run_state = run_state._replace(container_id=None)
            # Every persisted run is restored, not only the ones that already had
            # a container when the state was saved.
            self._runs[uuid] = run_state

    def start(self):
        """
        Restore the persisted runs, then start the image and dependency managers
        """
        self.load_state()
        for sub_manager in (self._image_manager, self._dependency_manager):
            sub_manager.start()

    def stop(self):
        """
        Flags this manager as stopped, stops the sub-managers, saves the run
        states and finally tries to remove the internal and external docker networks.
        Blocks until cleanup is complete and it is safe to quit
        """
        logger.info("Stopping Local Run Manager")
        self._stop = True
        for component in (
            self._image_manager,
            self._dependency_manager,
            self._run_state_manager,
        ):
            component.stop()
        self.save_state()
        try:
            # Internal first, then external; the first failure aborts the removal
            for network in (self.docker_network_internal, self.docker_network_external):
                network.remove()
        except docker.errors.APIError as e:
            logger.error("Cannot clear docker networks: {}".format(str(e)))

        logger.info("Stopped Local Run Manager. Exiting")

    def kill_all(self):
        """
        Kills all runs

        Marks every run as killed, then waits until all of them have reached the
        FINISHED stage or KILL_TIMEOUT seconds have passed, whichever comes first.
        """
        logger.debug("Killing all bundles")
        # Set all bundle statuses to killed
        with self._lock:
            for uuid, run_state in list(self._runs.items()):
                run_state.info['kill_message'] = 'Worker stopped'
                self._runs[uuid] = run_state._replace(info=run_state.info, is_killed=True)
        # Wait until all runs finished or KILL_TIMEOUT seconds pass
        for attempt in range(LocalRunManager.KILL_TIMEOUT):
            with self._lock:
                self._runs = {
                    k: v for k, v in self._runs.items() if v.stage != LocalRunStage.FINISHED
                }
                remaining = len(self._runs)
            if remaining == 0:
                # Nothing left to wait for: return instead of sleeping out the timeout
                break
            logger.debug(
                "Waiting for {} more bundles. {} seconds until force quit.".format(
                    remaining, LocalRunManager.KILL_TIMEOUT - attempt
                )
            )
            time.sleep(1)

    def process_runs(self):
        """
        Transition each run then filter out finished runs

        Containers of runs in the FINISHED or FINALIZING stage are force-removed
        (a container that is already gone is ignored); runs in the FINISHED stage
        are then dropped from `self._runs`.
        """
        with self._lock:
            # transition all runs
            for bundle_uuid in list(self._runs.keys()):
                run_state = self._runs[bundle_uuid]
                self._runs[bundle_uuid] = self._run_state_manager.transition(run_state)

            # Collect the container IDs (not the container objects) of finished
            # runs: containers.get() below looks containers up by ID.
            finished_container_ids = [
                run.container_id
                for run in self._runs.values()
                if run.stage in (LocalRunStage.FINISHED, LocalRunStage.FINALIZING)
                and run.container_id is not None
            ]
            for container_id in finished_container_ids:
                try:
                    container = self._docker.containers.get(container_id)
                    container.remove(force=True)
                except (docker.errors.NotFound, docker.errors.NullResource):
                    pass
            # filter out finished runs
            self._runs = {k: v for k, v in self._runs.items() if v.stage != LocalRunStage.FINISHED}

    def create_run(self, bundle, resources):
        """
        Registers a new run for the given bundle and resources in the PREPARING
        stage. Does nothing once this manager has been stopped.
        """
        if self._stop:
            # Run Manager stopped, refuse more runs
            return
        uuid = bundle['uuid']
        real_bundle_path = os.path.realpath(os.path.join(self._bundles_dir, uuid))
        initial_state = LocalRunState(
            stage=LocalRunStage.PREPARING,
            run_status='',
            bundle=bundle,
            bundle_path=real_bundle_path,
            resources=resources,
            start_time=time.time(),
            container_id=None,
            container=None,
            docker_image=None,
            is_killed=False,
            has_contents=False,
            cpuset=None,
            gpuset=None,
            time_used=0,
            max_memory=0,
            disk_utilization=0,
            info={},
        )
        with self._lock:
            self._runs[uuid] = initial_state

    def assign_cpu_and_gpu_sets(self, request_cpus, request_gpus):
        """
        Propose a cpuset and gpuset to a bundle based on given requested resources.
        Note: no side effects (this is important: we don't want to maintain more state than necessary)

        Arguments:
            request_cpus: integer
            request_gpus: integer

        Returns a 2-tuple:
            cpuset: assigned cpuset (str indices).
            gpuset: assigned gpuset (str indices).

        Throws an exception if unsuccessful.
        """
        # Work on copies so the manager's own resource pools stay untouched
        free_cpus = set(self._cpuset)
        free_gpus = set(self._gpuset)

        # Resources held by runs in the RUNNING stage are not available
        with self._lock:
            running_states = [
                state for state in self._runs.values() if state.stage == LocalRunStage.RUNNING
            ]
            for state in running_states:
                free_cpus -= state.cpuset
                free_gpus -= state.gpuset

        enough_resources = len(free_cpus) >= request_cpus and len(free_gpus) >= request_gpus
        if not enough_resources:
            raise Exception("Not enough cpus or gpus to assign!")

        proposed_cpus = {str(el) for el in list(free_cpus)[:request_cpus]}
        proposed_gpus = {str(el) for el in list(free_gpus)[:request_gpus]}
        return proposed_cpus, proposed_gpus

    def get_run(self, uuid):
        """
        Looks up the state of the run with the given UUID; the result is None
        when this RunManager does not manage such a run
        """
        with self._lock:
            run_state = self._runs.get(uuid)
        return run_state

    def mark_finalized(self, uuid):
        """
        Marks the run as finalized server-side so it can be discarded

        Does nothing if no run with the given UUID is managed here.
        """
        # Check and update under the same lock: other methods rebind self._runs
        # while holding it, so a check made outside could be stale by the time
        # the run is accessed.
        with self._lock:
            if uuid in self._runs:
                self._runs[uuid].info['finalized'] = True

    def read(self, run_state, path, dep_paths, args, reply):
        """
        Hands the read command over, unchanged, to this manager's reader helper
        """
        reader = self._reader
        reader.read(run_state, path, dep_paths, args, reply)

    def write(self, run_state, path, dep_paths, string):
        """
        Write `string` (string) to `path` inside the bundle directory of the run.
        Nothing is written when the normalized path is one of `dep_paths`.
        """
        normalized_path = os.path.normpath(path)
        if normalized_path not in dep_paths:
            target_path = os.path.join(run_state.bundle_path, path)
            with open(target_path, 'w') as target_file:
                target_file.write(string)

    def netcat(self, run_state, port, message, reply):
        """
        Write `message` (string) to port of bundle with uuid and read the response.
        The response contents (bytes) are handed to `reply` once the peer closes
        the connection.
        """
        # TODO: handle this in a thread since this could take a while
        container_ip = docker_utils.get_container_ip(
            self.worker_docker_network.name, run_state.container
        )
        chunks = []
        # The context manager closes the socket even if connecting, sending or
        # receiving raises, so a failed exchange does not leak the descriptor.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((container_ip, port))
            s.sendall(message.encode())
            while True:
                data = s.recv(LocalRunManager.NETCAT_BUFFER_SIZE)
                if not data:
                    break
                chunks.append(data)
        reply(None, {}, b''.join(chunks))

    def kill(self, run_state):
        """
        Flag the given run as killed and store the updated state under its bundle uuid
        """
        with self._lock:
            info = run_state.info
            info['kill_message'] = 'Kill requested'
            killed_state = run_state._replace(info=info, is_killed=True)
            self._runs[killed_state.bundle['uuid']] = killed_state

    @property
    def all_runs(self):
        """
        Returns a dict mapping each managed bundle uuid to a summary of its run
        (status, start time, docker image, info, server-side state and worker id)
        """
        summaries = {}
        with self._lock:
            for bundle_uuid, run_state in self._runs.items():
                summaries[bundle_uuid] = {
                    'run_status': run_state.run_status,
                    'start_time': run_state.start_time,
                    'docker_image': run_state.docker_image,
                    'info': run_state.info,
                    'state': LocalRunStage.WORKER_STATE_TO_SERVER_STATE[run_state.stage],
                    'remote': self._worker.id,
                }
        return summaries

    @property
    def all_dependencies(self):
        """
        All dependencies available in this RunManager, as reported by its dependency manager
        """
        dependency_manager = self._dependency_manager
        return dependency_manager.all_dependencies

    @property
    def cpus(self):
        """
        Number of entries in the cpuset this RunManager was given
        """
        cpu_count = len(self._cpuset)
        return cpu_count

    @property
    def gpus(self):
        """
        Number of entries in the gpuset this RunManager was given
        """
        gpu_count = len(self._gpuset)
        return gpu_count

    @property
    def memory_bytes(self):
        """
        Total installed memory, in bytes, of the machine this RunManager runs on
        """
        try:
            page_size = os.sysconf('SC_PAGE_SIZE')
            page_count = os.sysconf('SC_PHYS_PAGES')
        except ValueError:
            # Fallback to sysctl when os.sysconf('SC_PHYS_PAGES') fails on OS X
            sysctl_output = check_output(['sysctl', '-n', 'hw.memsize'])
            return int(sysctl_output.strip())
        return page_size * page_count

    @property
    def free_disk_bytes(self):
        """
        Available disk space by bytes of this RunManager.

        Returns None when `df` cannot be run, writes anything to stderr, produces
        no output, or its output cannot be parsed.
        """
        # -P requests the portable one-line-per-filesystem format and -k fixes the
        # unit to 1024-byte blocks, which the parsing below relies on.
        command = ["df", "-Pk", self._work_dir]
        error_msg = "Failed to run command {}".format(" ".join(command))
        try:
            # stderr has to be piped as well: without it communicate() always
            # returns None for the error stream and failures go unnoticed.
            p = Popen(command, stdout=PIPE, stderr=PIPE)
            output, error = p.communicate()
            # Return None when there is an error.
            if error:
                logger.error(error.decode(errors="replace").strip())
                return None

            if output:
                lines = output.decode().split("\n")
                index = lines[0].split().index("Available")
                # We convert the original result from df command in unit of 1KB blocks into bytes.
                return int(lines[1].split()[index]) * 1024

        except Exception as e:
            # Best effort: any failure to run or parse df is logged and reported as None
            logger.error("{}: {}".format(error_msg, str(e)))
        return None
Ejemplo n.º 12
0
class DockerImageManager:
    def __init__(self, commit_file, max_image_cache_size):
        """
        Initializes a DockerImageManager and loads its previously committed state
        :param commit_file: String path to where the state file should be committed
        :param max_image_cache_size: Total size in bytes that the image cache can use
        """
        # Plain configuration and bookkeeping
        self._max_image_cache_size = max_image_cache_size
        self._stop = False
        self._sleep_secs = 10
        self._cleanup_thread = None
        self._image_cache = {}  # type: Dict[str, ImageCacheEntry]
        self._lock = threading.RLock()

        self._state_committer = JsonStateCommitter(commit_file)  # type: JsonStateCommitter
        self._docker = docker.from_env()  # type: DockerClient
        # Initial field values given to the ThreadDict for each download thread
        initial_thread_fields = {'success': False, 'status': 'Download starting.'}
        self._downloading = ThreadDict(fields=initial_thread_fields, lock=True)

        self._load_state()

    def _save_state(self):
        with self._lock:
            self._state_committer.commit(self._image_cache)

    def _load_state(self):
        with self._lock:
            self._image_cache = self._state_committer.load()

    def start(self):
        """
        Start the background cleanup thread. No thread is started when no
        maximum image cache size is configured.
        """
        logger.info("Starting docker image manager")
        if not self._max_image_cache_size:
            return

        def cleanup_loop():
            # Prune and persist the cache, then sleep, until the manager is stopped.
            # Failures are printed and the loop carries on.
            while not self._stop:
                try:
                    self._cleanup()
                    self._save_state()
                except Exception:
                    traceback.print_exc()
                time.sleep(self._sleep_secs)

        self._cleanup_thread = threading.Thread(target=cleanup_loop)
        self._cleanup_thread.start()

    def stop(self):
        """
        Flag the manager as stopped, stop the download threads and, if a cleanup
        thread was started, wait for it to finish.
        """
        logger.info("Stopping docker image manager")
        self._stop = True
        logger.debug("Stopping docker image manager: stop the downloads threads")
        self._downloading.stop()
        cleanup_thread = self._cleanup_thread
        if cleanup_thread:
            logger.debug("Stopping docker image manager: stop the cleanup thread")
            cleanup_thread.join()
        logger.info("Stopped docker image manager")

    def _cleanup(self):
        """
        Prunes the image cache for runs.
        1. Only care about images we (this DockerImageManager) downloaded and know about
        2. We use sum of VirtualSize's, which is an upper bound on the disk use of our images:
            in case no images share any intermediate layers, this will be the real disk use,
            however if images share layers, the virtual size will count that layer's size for each
            image that uses it, even though it's stored only once in the disk. The 'Size' field
            accounts for the marginal size each image adds on top of the shared layers, but summing
            those is not accurate either since the shared base layers need to be counted once to get
            the total size. (i.e. summing marginal sizes would give us a lower bound on the total disk
            use of images). Calling df gives us an accurate disk use of ALL the images on the machine
            but because of (1) we don't want to use that.
        """
        while not self._stop:
            deletable_entries = set(self._image_cache.values())
            disk_use = sum(cache_entry.virtual_size
                           for cache_entry in deletable_entries)
            while disk_use > self._max_image_cache_size:
                entry_to_remove = min(deletable_entries,
                                      key=lambda entry: entry.last_used)
                logger.info(
                    'Disk use (%s) > max cache size (%s), pruning image: %s',
                    disk_use,
                    self._max_image_cache_size,
                    entry_to_remove.digest,
                )
                try:
                    image_to_delete = self._docker.images.get(
                        entry_to_remove.id)
                    tags_to_delete = image_to_delete.tags
                    for tag in tags_to_delete:
                        self._docker.images.remove(tag)
                    # if we successfully removed the image also remove its cache entry
                    del self._image_cache[entry_to_remove.digest]
                except docker.errors.NotFound:
                    # image doesn't exist anymore for some reason, stop tracking it
                    del self._image_cache[entry_to_remove.digest]
                except docker.errors.APIError as err:
                    # Maybe we can't delete this image because its container is still running
                    # (think a run that takes 4 days so this is the oldest image but still in use)
                    # In that case we just continue with our lives, hoping it will get deleted once
                    # it's no longer in use and the cache becomes full again
                    logger.error("Cannot remove image %s from cache: %s",
                                 entry_to_remove.digest, err)
                deletable_entries.remove(entry_to_remove)
                disk_use = sum(entry.virtual_size
                               for entry in deletable_entries)
        logger.debug("Stopping docker image manager cleanup")

    def get(self, image_spec):
        """
        Look up the requested docker image locally and record it in the image cache;
        if it is not present locally, start or report on a background pull.
        :param image_spec: Repo image_spec of docker image being requested
        :returns: An ImageAvailabilityState object with the state of the docker image
        """
        if ':' not in image_spec:
            # Digests and repo:tag specs both contain ':'; only a bare repo name does not.
            # The images API treats a bare repo name inconsistently:
            # - pull pulls all tags of the image
            # - get tries to get `latest` by default
            # So for an image without a `latest` tag on Dockerhub, pull would succeed (it pulls
            # every other tag) while later get calls would fail to find `latest` on the system.
            # We don't want to guess which tag the user meant; the pull step should fail if no
            # tag is specified and there is no latest tag on dockerhub.
            # Hence the latest tag is appended up front when the spec carries no tag.
            image_spec = image_spec + ':latest'
        try:
            image = self._docker.images.get(image_spec)
            repo_digests = image.attrs.get('RepoDigests', [image_spec])
            if len(repo_digests) == 0:
                no_digest_message = (
                    'No digest available for {}, probably because it was built locally; '
                    'delete the Docker image on the worker and try again'
                ).format(image_spec)
                return ImageAvailabilityState(
                    digest=None, stage=DependencyStage.FAILED, message=no_digest_message
                )
            digest = repo_digests[0]
            with self._lock:
                # Record the image as just used
                self._image_cache[digest] = ImageCacheEntry(
                    id=image.id,
                    digest=digest,
                    last_used=time.time(),
                    virtual_size=image.attrs['VirtualSize'],
                    marginal_size=image.attrs['Size'],
                )
            # We can remove the download thread if it still exists
            if image_spec in self._downloading:
                self._downloading.remove(image_spec)
            return ImageAvailabilityState(
                digest=digest, stage=DependencyStage.READY, message='Image ready'
            )
        except docker.errors.ImageNotFound:
            # Not available locally: start a pull or report on the one in progress
            return self._pull_or_report(image_spec)  # type: ImageAvailabilityState
        except Exception as ex:
            return ImageAvailabilityState(
                digest=None, stage=DependencyStage.FAILED, message=str(ex)
            )

    def _pull_or_report(self, image_spec):
        """
        Report on the background pull of `image_spec`, starting one if none is registered.

        :param image_spec: image spec to pull
        :returns: An ImageAvailabilityState: DOWNLOADING while the pull thread is
            alive (or was just started); READY or FAILED once it has finished, in
            which case the finished thread is removed from the download registry.
        """
        if image_spec in self._downloading:
            with self._downloading[image_spec]['lock']:
                if self._downloading[image_spec].is_alive():
                    return ImageAvailabilityState(
                        digest=None,
                        stage=DependencyStage.DOWNLOADING,
                        message=self._downloading[image_spec]['status'],
                    )
                try:
                    if self._downloading[image_spec]['success']:
                        digest = self._docker.images.get(image_spec).attrs.get(
                            'RepoDigests', [image_spec]
                        )[0]
                        return ImageAvailabilityState(
                            digest=digest,
                            stage=DependencyStage.READY,
                            message=self._downloading[image_spec]['message'],
                        )
                    return ImageAvailabilityState(
                        digest=None,
                        stage=DependencyStage.FAILED,
                        message=self._downloading[image_spec]['message'],
                    )
                finally:
                    # The return values above are computed before this runs. Removing
                    # the finished thread here, even when building the report raised,
                    # keeps a dead entry from tripping up every later call.
                    self._downloading.remove(image_spec)

        def download():
            logger.debug('Downloading Docker image %s', image_spec)
            try:
                self._docker.images.pull(image_spec)
                logger.debug('Download for Docker image %s complete', image_spec)
                self._downloading[image_spec]['success'] = True
                self._downloading[image_spec]['message'] = "Downloading image"
            except Exception as ex:
                # Top of the download thread: record every failure, not only Docker
                # API errors, so that a 'message' is always present for the report.
                logger.debug('Download for Docker image %s failed: %s', image_spec, ex)
                self._downloading[image_spec]['success'] = False
                self._downloading[image_spec]['message'] = "Can't download image: {}".format(ex)

        self._downloading.add_if_new(image_spec, threading.Thread(target=download, args=[]))
        return ImageAvailabilityState(
            digest=None,
            stage=DependencyStage.DOWNLOADING,
            message=self._downloading[image_spec]['status'],
        )