Example #1
def fetch_remote_index():
    url = REMOTE_INDEX_URL
    logger.info('Retrieving remote index... ')
    db = requests.get(url).json()
    logger.info('done.')
    save_ds_db(db)
    return db
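Note that `REMOTE_INDEX_URL`, `logger`, and `save_ds_db` come from the surrounding module. A minimal sketch of a slightly more defensive variant, assuming the same helpers, adds a timeout and surfaces HTTP errors before parsing:

def fetch_remote_index():
    # Assumes REMOTE_INDEX_URL, logger and save_ds_db are defined in the module.
    logger.info('Retrieving remote index...')
    response = requests.get(REMOTE_INDEX_URL, timeout=30)
    response.raise_for_status()  # fail loudly instead of json()-decoding an error page
    db = response.json()
    save_ds_db(db)  # cache the index locally
    logger.info('done.')
    return db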
Example #2
    def after_run(self):
        self.output_log.close()
        self.progress = 1.0
        # register to bucket!
        bucket_name = os.path.basename(self.download_path)
        bucket.add(self.username, bucket_name)
        logger.info("done.")
Example #3
    def delete(self):
        if self.status in (RUN, WAIT):
            self.aborted.set()
            self.status = ABORT
        logger.info('Set abort {} and delete'.format(self.id))
        time.sleep(1)
        if os.path.exists(self._dir):
            shutil.rmtree(self._dir)
Example #4
def download_remote_checkpoint(checkpoint, username):
    # Remove any existing checkpoint directory so the download starts from a clean state.
    output = get_checkpoint_path(checkpoint['id'])
    if os.path.exists(output):
        shutil.rmtree(output)
        logger.info("delete existing checkpoint {}.".format(output))
    logger.info("create CPDownloader")
    CPDownloader(checkpoint['url'], output, checkpoint['id'], username)
Example #5
    def pull(self, tag):
        tags = tag.rsplit(':', 1)
        if len(tags) < 2:
            tags.append("latest")
        image_name = ":".join(tags)
        logger.info("create ImagePuller for {}".format(image_name))
        ImagePuller(image_name)
        # self.client.images.pull(tags[0], tags[1])
        return {'status': 'ok'}
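The `rsplit(':', 1)` call is what makes an untagged image default to `latest`. Illustrating just that normalization with made-up image names:

def normalize_tag(tag):
    # Same logic as pull(): append ':latest' when no tag is given.
    parts = tag.rsplit(':', 1)
    if len(parts) < 2:
        parts.append("latest")
    return ":".join(parts)

print(normalize_tag("ubuntu"))         # ubuntu:latest
print(normalize_tag("myrepo/app:v1"))  # myrepo/app:v1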
Example #6
def download(id_or_alias, username):
    logger.info("download checkpoint {}".format(id_or_alias))
    db = read_checkpoint_db()
    checkpoint = get_checkpoint(db, id_or_alias)
    if not checkpoint:
        raise ValueError(
            "Checkpoint '{}' not found in index.".format(id_or_alias))

    if checkpoint['source'] != 'remote':
        raise ValueError(
            "Checkpoint is not remote. If you intended to download a remote "
            "checkpoint and used an alias, try using the id directly.")

    if checkpoint['status'] != 'NOT_DOWNLOADED':
        raise ValueError("Checkpoint is already downloaded.")
    logger.info("checkpoint: {}".format(checkpoint))
    download_remote_checkpoint(checkpoint, username)
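Callers are expected to handle the `ValueError` raised for unknown, non-remote, or already-downloaded checkpoints. A hypothetical call site (the alias and username are made up):

try:
    download('traffic-signs-v2', 'alice')
except ValueError as err:
    logger.error("download failed: {}".format(err))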
Example #7
def create(username,
           container,
           num_gpu,
           dataset,
           project,
           port_list=None,
           user_args=None):
    # Avoid a shared mutable default for port_list.
    if port_list is None:
        port_list = []
    args = []
    bucket.check_exist(dataset)
    Image().get(container)
    if user_args:
        args = user_args.split()
    logger.info("create container: {}".format(
        [username, container, num_gpu, dataset, project, port_list,
         user_args]))
    ser = Work(username, container, port_list, num_gpu, dataset, project, args)
    id = ser.id
    return id
Example #8
    def run(self, resources):
        self.before_run()

        host, repo, image = self.parse_tag(self.image_tag)
        logger.info("Pulling {}... ".format(self.image_tag))
        try:
            cli = docker.APIClient(base_url='unix://var/run/docker.sock')
            with app.app_context():
                registry = get_registry(host)
                registry.login(cli)

            for line in cli.pull(self.image_tag, stream=True, decode=True):
                if "progressDetail" in line:
                    progress = line["progressDetail"]
                    if "current" in progress:
                        percentage = float(progress['current']) / float(
                            progress['total'])
                        self.progress = percentage
        except Exception as e:
            logger.warning("failed to pull image, {}".format(e))
        logger.info("done.")
        self.after_run()
        self.status = DONE
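With `stream=True, decode=True`, `APIClient.pull()` yields one dict per status line, and the fraction is just `current / total` from `progressDetail`. The same parsing in isolation, fed with fake events instead of a live daemon:

def progress_from_events(events):
    # Each event mirrors one decoded line from APIClient.pull(..., stream=True, decode=True).
    progress = 0.0
    for line in events:
        detail = line.get("progressDetail") or {}
        if "current" in detail and detail.get("total"):
            progress = float(detail["current"]) / float(detail["total"])
    return progress

fake_events = [
    {"status": "Downloading", "progressDetail": {"current": 512, "total": 2048}},
    {"status": "Downloading", "progressDetail": {"current": 2048, "total": 2048}},
]
print(progress_from_events(fake_events))  # 1.0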
Example #9
def get_checkpoint(db, id_or_alias):
    """Returns checkpoint entry in `db` indicated by `id_or_alias`.

    First tries to match an ID, then an alias. For the case of repeated
    aliases, will first match local checkpoints and then remotes. In both
    cases, matching will be newest first.
    """
    # Go through the checkpoints ordered by creation date. There shouldn't be
    # repeated aliases, but if there are, prioritize the newest one.
    local_checkpoints = sorted(
        [c for c in db['checkpoints'] if c['source'] == 'local'],
        key=lambda c: c['created_at'],
        reverse=True)
    remote_checkpoints = sorted(
        [c for c in db['checkpoints'] if c['source'] == 'remote'],
        key=lambda c: c['created_at'],
        reverse=True)

    selected = []
    for cp in local_checkpoints:
        if cp['id'] == id_or_alias or cp['alias'] == id_or_alias:
            selected.append(cp)

    for cp in remote_checkpoints:
        if cp['id'] == id_or_alias or cp['alias'] == id_or_alias:
            selected.append(cp)

    if len(selected) < 1:
        return None

    if len(selected) > 1:
        logger.info(
            "Multiple checkpoints found for '{}' ({}). Returning '{}'.".format(
                id_or_alias, len(selected), selected[0]['id']))

    return selected[0]
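A toy index with made-up ids and aliases shows the lookup behaviour: an exact id match, an alias match, and `None` when nothing matches.

toy_db = {
    'checkpoints': [
        {'id': 'e3256d', 'alias': 'accurate', 'source': 'local',
         'created_at': '2019-03-01T10:00:00'},
        {'id': '5f21aa', 'alias': 'fast', 'source': 'remote',
         'created_at': '2019-02-01T10:00:00'},
    ]
}

print(get_checkpoint(toy_db, 'e3256d')['alias'])  # accurate (matched by id)
print(get_checkpoint(toy_db, 'fast')['id'])       # 5f21aa (matched by alias)
print(get_checkpoint(toy_db, 'missing'))          # None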
Example #10
    def run(self, resources):
        self.before_run()
        tempdir = tempfile.mkdtemp()
        tmpfile = os.path.join(tempdir, 'temp.tar')
        logger.info("Downloading {}... ".format(self.url))
        with requests.get(self.url, stream=True) as r:
            r.raise_for_status()
            length = int(r.headers.get('Content-Length'))
            dl = 0.0
            with open(tmpfile, 'wb') as f:
                for chunk in r.iter_content(chunk_size=16 * 1024):
                    dl += len(chunk)
                    f.write(chunk)
                    self.progress = dl / length
        logger.info("Importing {}... ".format(self.download_path))
        with tarfile.open(tmpfile) as f:
            members = f.getmembers()
            f.extractall(self.download_path, members)
        logger.info("done.")

        self.after_run()
        self.status = DONE
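The progress computation relies on the `Content-Length` header, which some servers omit. A sketch of the same chunked-download pattern with that case guarded; the function name is made up for illustration:

import requests

def download_with_progress(url, dest_path, chunk_size=16 * 1024):
    # Stream the response to disk and report progress when the size is known.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total = int(r.headers.get('Content-Length') or 0)
        done = 0
        with open(dest_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                done += len(chunk)
                if total:
                    print("progress: {:.0%}".format(float(done) / total))
    return dest_path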
Example #11
    def do_train(self):
        src_to = None
        repo_name = None
        if self.repo_path and (self.repo_path.startswith("ssh://")
                               or self.repo_path.startswith("git@")):
            # copy pkg from gogs
            logger.info("pkg clone from gogs or git")
            gogsop.repo_clone(self.username, self.repo_path, self._dir)
            repo_name = _repo_name(self.repo_path)
            src_to = os.path.join(self._dir, repo_name)
        elif self.repo_path and os.path.exists(self.repo_path):
            # copy from another job
            logger.info("pkg clone from path {}".format(self.repo_path))
            repo_name = _repo_name(self.repo_path)
            src_to = os.path.join(self._dir, repo_name)
            system_copy(self.repo_path, src_to)
        elif self.repo_path and os.path.exists(
                os.path.join("/home", self.username, self.repo_path)):
            # copy from workspace
            workspace = os.path.join("/home", self.username, self.repo_path)
            logger.info("pkg clone from workspace {}".format(workspace))
            repo_name = _repo_name(workspace)
            src_to = os.path.join(self._dir, repo_name)
            system_copy(workspace, src_to)
        elif self.repo_path:
            # copy from CLI already
            logger.info("pkg from CLI {}".format(self.repo_path))
            repo_name = _repo_name(self.repo_path)
            if os.path.exists(os.path.join(self._dir, repo_name)):
                src_to = os.path.join(self._dir, repo_name)
        elif self.repo_path == '' and self.parent:
            # replicate another job's data directory
            logger.info("replicate job {}".format(self.parent))
            job_dir = os.path.join('/data', 'dataset', self.parent, '')
            if os.path.exists(job_dir):
                rsync_copy(job_dir, self._dir)
        if src_to and os.path.isdir(src_to):
            self._workspace = "/workspace/{}".format(repo_name)
        logger.info('add instance {} to scheduler'.format(self.id))
        scheduler.add_instance(self)
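`_repo_name` is not shown in these examples; a minimal sketch of what it presumably does, assuming it takes the basename of the path or URL and strips a trailing `.git`:

import os

def _repo_name(repo_path):
    # Hypothetical helper: 'ssh://gogs/alice/detector.git' -> 'detector',
    # '/data/jobs/1234/detector' -> 'detector'.
    name = os.path.basename(repo_path.rstrip('/'))
    if name.endswith('.git'):
        name = name[:-len('.git')]
    return name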
Example #12
    def run(self, resources):
        self.before_run()
        env = os.environ.copy()
        env['PYTHONPATH'] = os.pathsep.join(['.', self._dir, env.get('PYTHONPATH', '')] + sys.path)
        gpus = [i for (i, _) in resources['gpus']]
        env['CUDA_VISIBLE_DEVICES'] = ','.join(str(g) for g in gpus)
        env['NV_GPU'] = ','.join(str(g) for g in gpus)
        container_id = None
        if self.parameters:
            with open(os.path.join(self._dir, self.PARAMS), 'w') as f:
                f.write(yaml.dump(dict(self.parameters)))
        user_args = copy.copy(self.user_args)
        user_uid = pwd.getpwnam(self.username).pw_uid
        # prepare docker job parameters
        job_real_dir = get_real_path(self._dir)
        args = ['/usr/bin/docker', 'create', '--runtime=nvidia', '--rm']
        args.extend(['-e', 'NVIDIA_VISIBLE_DEVICES='+env['NV_GPU']])
        args.extend(['-e', 'JOB_DIR=/workspace'])
        metrics_path = os.path.join("/workspace", self.METRICS)
        args.extend(['-e', 'METRICS_PATH={}'.format(metrics_path)])
        #args.extend(['-u', '{}:{}'.format(user_uid, user_uid)])
        if self.dataset_path:
            dataset_host_path = get_real_path(self.dataset_path)
            args.extend(['-v', dataset_host_path+':/dataset'])
        # check for user arguments
        for n, arg in enumerate(user_args):
            if "JOB_DIR" in arg:
                user_args[n] = arg.replace("JOB_DIR", "/workspace")
            if self.dataset_path:
                if "DATASET_DIR" in arg:
                    user_args[n] = arg.replace("DATASET_DIR", "/dataset")
        with open(os.path.join(self._dir, "user_args.sh"), 'w') as f:
            f.write(' '.join(user_args))
        args.extend(['-v', job_real_dir + ':/workspace', '-w', self._workspace,
                     self.image_tag, "bash", "-x", "/workspace/user_args.sh"])
        run_args = ' '.join(args)
        logger.info("Run {}".format(run_args))
        try:
            output = subprocess.check_output(run_args, stderr=subprocess.STDOUT,
                                             shell=True, universal_newlines=True)
        except subprocess.CalledProcessError as exc:
            for line in exc.output.split('\n'):
                self.output_log.write('%s\n' % line)
            self.after_run()
            raise exc
        logger.info("Run output: {}".format(output))
        container_id = re.findall(r"^\w+", output)[0][:6]
        args = ["/usr/bin/docker", "start", "-i", container_id]
        run_args = ' '.join(args)
        # End of docker start
        logger.info("Run {}".format(run_args))
        self.p = subprocess.Popen(run_args,
                                  shell=True,
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE,
                                  cwd=self._dir,
                                  close_fds=True,
                                  env=env,
                                  )
        try:
            sigterm_time = None  # When was the SIGTERM signal sent
            sigterm_timeout = 120  # When should the SIGKILL signal be sent
            while self.p.poll() is None:
                for line in utils.nonblocking_readlines_p(self.p):
                    if self.aborted.is_set():
                        if sigterm_time is None:
                            if container_id:
                                subprocess.check_output("/usr/bin/docker stop {}".format(container_id), shell=True)
                            # Attempt graceful shutdown
                            self.p.send_signal(signal.SIGTERM)
                            sigterm_time = time.time()
                            self.status = ABORT
                        break
                    try:
                        subprocess.check_output("/usr/bin/docker ps |grep {}".format(container_id), shell=True)
                    except:
                        ## Bug, docker start may hang while container not exist.
                        break
                        #self.abort()
                    if line is not None:
                        # Remove whitespace
                        line = line.strip()

                    if line:
                        self.output_log.write('%s\n' % line.encode("utf-8"))
                        self.output_log.flush()
                    else:
                        time.sleep(0.05)
                if sigterm_time is not None and (time.time() - sigterm_time > sigterm_timeout):
                    self.p.send_signal(signal.SIGKILL)
                    logger.debug("Sent SIGKILL to task {}".format(self.name))
                time.sleep(0.01)
        except Exception as e:
            logger.debug("exception, {}, {}".format(e, traceback.format_exc()))
            self.p.terminate()
            self.after_run()
            raise e

        self.after_run()
        if self.status != RUN:
            return False
        if self.p.returncode != 0:
            self.returncode = self.p.returncode
            self.status = ERROR
        else:
            self.status = DONE
        return True
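The abort path uses a two-stage shutdown: `docker stop` plus SIGTERM first, then SIGKILL if the process is still alive after `sigterm_timeout` seconds. The escalation pattern in isolation, with a plain `sleep` standing in for the container and a shortened timeout:

import signal
import subprocess
import time

def terminate_gracefully(cmd, sigterm_timeout=5):
    # Ask the process to stop with SIGTERM, escalate to SIGKILL after the timeout.
    p = subprocess.Popen(cmd)
    p.send_signal(signal.SIGTERM)
    sent_at = time.time()
    while p.poll() is None:
        if time.time() - sent_at > sigterm_timeout:
            p.send_signal(signal.SIGKILL)
            break
        time.sleep(0.1)
    p.wait()
    return p.returncode

print(terminate_gracefully(["sleep", "60"]))  # -15 on POSIX (terminated by SIGTERM)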
Example #13
    def run(self, resources):
        logger.info("Run worker!")
        self.before_run()
        env = os.environ.copy()
        gpus = [i for (i, _) in resources['gpus']]
        ports = [i for (i, _) in resources['ports']]
        env['NV_GPU'] = ','.join(str(g) for g in gpus)
        args = [
            '/usr/bin/docker', 'create', '--runtime=nvidia', '--rm', '--name',
            self.id
        ]
        args.extend(['-e', 'NVIDIA_VISIBLE_DEVICES=' + env['NV_GPU']])
        #user_uid = pwd.getpwnam(self.username).pw_uid
        #args.extend(['-u', '{}:{}'.format(user_uid, user_uid)])
        if ports:
            for i, port in enumerate(ports):
                args.extend(['-p', "{}:{}".format(port, self.port_list[i])])
        if self.dataset_path:
            dataset_host_path = get_real_path(self.dataset_path)
            args.extend(['-v', dataset_host_path + ':/dataset'])
        if self.user_args:
            for n, arg in enumerate(self.user_args):
                if "JOB_ID" in arg:
                    self.user_args[n] = arg.replace("JOB_ID", self.id)
        home_dir = os.path.join('/home/', self.username)
        home_real_path = get_real_path(home_dir)
        args.extend(['-v', home_real_path + ':/workspace', '-w', '/workspace'])
        args.extend([self.container] + self.user_args)
        logger.info("Run {}".format(' '.join(args)))
        try:
            output = subprocess.check_output(args, universal_newlines=True)
        except subprocess.CalledProcessError as exc:
            for line in exc.output.split('\n'):
                self.output_log.write('%s\n' % line)
            self.after_run()
            logger.debug("docker create: {}".format(exc))
            raise exc
        logger.info("Run output: {}".format(output))
        container_id = re.findall(r"^\w+", output)[0][:6]
        args = ["/usr/bin/docker", "start", "-i", container_id]
        self.p = subprocess.Popen(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=self._dir,
            close_fds=True,
            env=env,
        )
        start_time = time.time()
        self.idle_sec = 0
        try:
            sigterm_time = None  # When was the SIGTERM signal sent
            sigterm_timeout = 120  # When should the SIGKILL signal be sent
            while self.p.poll() is None:
                for line in utils.nonblocking_readlines_p(self.p):
                    if self.aborted.is_set():
                        if sigterm_time is None:
                            if container_id:
                                subprocess.check_output(
                                    "/usr/bin/docker stop {}".format(
                                        container_id),
                                    shell=True)
                            # Attempt graceful shutdown
                            self.p.send_signal(signal.SIGTERM)
                            sigterm_time = time.time()
                            self.status = ABORT
                        break
                    try:
                        # Check whether a client is still connected to the
                        # Jupyter port (8888); reset the idle timer while a
                        # connection is ESTABLISHED.
                        netst = subprocess.check_output(
                            "/usr/bin/docker exec {} netstat -nat | grep 8888".format(
                                container_id),
                            shell=True,
                            universal_newlines=True)
                        if 'ESTABLISHED' in netst:
                            start_time = time.time()
                        idle_sec = time.time() - start_time
                        if idle_sec > 10:
                            self.idle_sec = idle_sec
                        if idle_sec > IDLE_TIMEOUT:
                            self.abort()
                            logger.info(
                                "jupyter {} timeout, terminating..".format(
                                    self.id))
                    except Exception:
                        pass
                    if line is not None:
                        # Remove whitespace
                        line = line.strip()

                    if line:
                        self.output_log.write('%s\n' % line.encode("utf-8"))
                        self.output_log.flush()
                    else:
                        time.sleep(0.05)
                if sigterm_time is not None and (time.time() - sigterm_time >
                                                 sigterm_timeout):
                    self.p.send_signal(signal.SIGKILL)
                    logger.debug("Sent SIGKILL to task {}".format(self.name))
                time.sleep(0.01)
        except Exception as e:
            logger.debug("exception, {}, {}".format(e, traceback.format_exc()))
            self.p.terminate()
            self.after_run()
            raise e

        self.after_run()
        if self.status != RUN:
            return False
        if self.p.returncode != 0:
            self.returncode = self.p.returncode
            self.status = ERROR
        else:
            self.status = DONE
        return True
Example #14
def merge_index(local_index, remote_index):
    """Merge the `remote_index` into `local_index`.

    The merging process is only applied over the checkpoints in `local_index`
    marked as ``remote``.
    """

    non_remotes_in_local = [
        c for c in local_index['checkpoints'] if c['source'] != 'remote'
    ]
    remotes_in_local = {
        c['id']: c
        for c in local_index['checkpoints'] if c['source'] == 'remote'
    }

    to_add = []
    seen_ids = set()
    for checkpoint in remote_index['checkpoints']:
        seen_ids.add(checkpoint['id'])
        local = remotes_in_local.get(checkpoint['id'])
        if local:
            # Checkpoint is in local index. Overwrite all the fields.
            local.update(**checkpoint)
        else:
            # Checkpoint not found, it's an addition. Transform into our schema
            # before appending to `to_add`. (The remote index schema is exactly
            # the same except for the ``source`` and ``status`` keys.)
            checkpoint['source'] = 'remote'
            checkpoint['status'] = 'NOT_DOWNLOADED'
            to_add.append(checkpoint)

    # Out of the removed checkpoints, only keep those with status
    # ``DOWNLOADED`` and turn them into local checkpoints.
    missing_ids = set(remotes_in_local.keys()) - seen_ids
    already_downloaded = [
        c for c in remotes_in_local.values()
        if c['id'] in missing_ids and c['status'] == 'DOWNLOADED'
    ]
    for checkpoint in already_downloaded:
        checkpoint['status'] = 'LOCAL'
        checkpoint['source'] = 'local'

    new_remotes = [
        c for c in remotes_in_local.values()
        if c['id'] not in missing_ids  # Drop checkpoints no longer in the remote index.
    ] + to_add + already_downloaded

    if len(to_add):
        logger.info('{} new remote checkpoints added.'.format(len(to_add)))
    if len(missing_ids):
        if len(already_downloaded):
            logger.info('{} remote checkpoints transformed to local.'.format(
                len(already_downloaded)))
        logger.info('{} remote checkpoints removed.'.format(
            len(missing_ids) - len(already_downloaded)))
    if not len(to_add) and not len(missing_ids):
        logger.info('No changes in remote index.')

    local_index['checkpoints'] = non_remotes_in_local + new_remotes

    return local_index
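A small worked example with made-up ids (and the module's `logger` assumed to be configured): a remote checkpoint that disappeared upstream but was already downloaded becomes local, a new upstream entry is added, and purely local checkpoints are left alone.

local_index = {
    'checkpoints': [
        {'id': 'aaa111', 'source': 'local', 'status': 'LOCAL'},
        {'id': 'bbb222', 'source': 'remote', 'status': 'DOWNLOADED'},
        {'id': 'ccc333', 'source': 'remote', 'status': 'NOT_DOWNLOADED'},
    ]
}
remote_index = {
    'checkpoints': [
        {'id': 'ccc333'},  # still published upstream
        {'id': 'ddd444'},  # new upstream checkpoint
    ]
}

merged = merge_index(local_index, remote_index)
for c in merged['checkpoints']:
    print(c['id'], c['source'], c['status'])
# aaa111 local LOCAL              (not a remote, untouched)
# ccc333 remote NOT_DOWNLOADED    (kept, refreshed from the remote index)
# ddd444 remote NOT_DOWNLOADED    (added)
# bbb222 local LOCAL              (removed upstream but already downloaded)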