def _monitor_tasks(batch_client, job_id, numtasks):
    # type: (batch.BatchServiceClient, str, int) -> None
    """Monitor tasks for completion
    :param batch_client: The batch client to use.
    :type batch_client: `azure.batch.batch_service_client.BatchServiceClient`
    :param str job_id: id of the job whose tasks are monitored
    :param int numtasks: total number of tasks to wait on
    """
    i = 0
    j = 0
    while True:
        try:
            task_counts = batch_client.job.get_task_counts(job_id=job_id)
        except batchmodels.batch_error.BatchErrorException as ex:
            logger.exception(ex)
        else:
            if (task_counts.validation_status ==
                    batchmodels.TaskCountValidationStatus.validated):
                j = 0
                if task_counts.completed == numtasks:
                    logger.info(task_counts)
                    logger.info('all {} tasks completed'.format(numtasks))
                    break
            else:
                # unvalidated, perform manual list tasks
                j += 1
                if j % 10 == 0:
                    j = 0
                    try:
                        tasks = batch_client.task.list(
                            job_id=job_id,
                            task_list_options=batchmodels.TaskListOptions(
                                select='id,state'))
                        states = [task.state for task in tasks]
                    except batchmodels.batch_error.BatchErrorException as ex:
                        logger.exception(ex)
                    else:
                        if (states.count(
                                batchmodels.TaskState.completed) == numtasks):
                            logger.info(
                                'all {} tasks completed'.format(numtasks))
                            break
            i += 1
            if i % 15 == 0:
                i = 0
                logger.debug(task_counts)
        time.sleep(2)
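
A minimal driver sketch for this helper, assuming the pre-v7 azure-batch SDK these snippets target; the account name, key, and endpoint below are placeholders:

import azure.batch.batch_auth as batchauth
import azure.batch.batch_service_client as batchsc

# hypothetical credentials and endpoint; substitute real account values
credentials = batchauth.SharedKeyCredentials(
    'mybatchaccount', '<base64-account-key>')
batch_client = batchsc.BatchServiceClient(
    credentials, base_url='https://mybatchaccount.<region>.batch.azure.com')
_monitor_tasks(batch_client, 'myjob', 100)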
Example #2
    def _update_task_state(self, latest_transition_time=None):
        """ Query azure and update task state. """

        task_filter = "state eq 'completed'"

        if self.most_recent_transition_time is not None:
            assert self.most_recent_transition_time.tzname() == 'UTC'
            filter_time_string = self.most_recent_transition_time.strftime(
                "%Y-%m-%dT%H:%M:%SZ")
            task_filter += " and stateTransitionTime gt DateTime'{}'".format(
                filter_time_string)

        if latest_transition_time is not None:
            assert latest_transition_time.tzname() == 'UTC'
            filter_time_string = latest_transition_time.strftime(
                "%Y-%m-%dT%H:%M:%SZ")
            task_filter += " and stateTransitionTime lt DateTime'{}'".format(
                filter_time_string)

        list_max_results = 1000
        list_options = batchmodels.TaskListOptions(
            filter=task_filter, max_results=list_max_results)

        tasks = []
        for _, job_id in self.pool_job_map.items():
            tasks += list(
                self.batch_client.task.list(job_id,
                                            task_list_options=list_options))
        sorted_transition_times = sorted(
            [task.state_transition_time for task in tasks])

        if len(tasks) >= list_max_results:
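            # a full page signals possible truncation at list_max_results:
            # recurse on the earlier half of the observed transition times
            # so that no completed task in the window is missed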
            mid_transition_time = sorted_transition_times[
                len(sorted_transition_times) // 2]
            return self._update_task_state(
                latest_transition_time=mid_transition_time)

        elif len(tasks) == 0:
            return False

        self.most_recent_transition_time = sorted_transition_times[-1]

        num_completed_before = len(self.completed_task_ids)
        self.completed_task_ids.update([task.id for task in tasks])
        return len(self.completed_task_ids) > num_completed_before
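
For reference, the OData filter assembled above uses the Batch service's DateTime literal syntax; a minimal sketch with a placeholder cutoff time:

from datetime import datetime, timezone

# hypothetical cutoff; tzname() must be 'UTC' to satisfy the assertion above
cutoff = datetime(2020, 1, 1, tzinfo=timezone.utc)
task_filter = "state eq 'completed'" + \
    " and stateTransitionTime gt DateTime'{}'".format(
        cutoff.strftime('%Y-%m-%dT%H:%M:%SZ'))
# yields: "state eq 'completed' and
#          stateTransitionTime gt DateTime'2020-01-01T00:00:00Z'"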
Example #3
    def wait(self, immediate=False):
        """ Wait for a job to finish.

        Keyword Args:
            immediate (bool): do not wait if no job has finished

        Returns:
            str: job name

        """

        timeout = datetime.timedelta(minutes=10)

        while True:
            timeout_expiration = datetime.datetime.now() + timeout

            while datetime.datetime.now() < timeout_expiration:
                for task_id in self.running_task_ids.intersection(
                        self.completed_task_ids):
                    return self.job_names[task_id]

                if not self._update_task_state():
                    time.sleep(20)

            self.logger.warning(
                "Tasks did not reach 'Completed' state within timeout period of "
                + str(timeout))

            self.logger.info("Most recent transition: {}".format(
                self.most_recent_transition_time))
            task_filter = "state eq 'completed'"
            list_options = batchmodels.TaskListOptions(filter=task_filter)

            for pool_id, job_id in self.pool_job_map.items():
                check_pool_for_failed_nodes(self.batch_client, pool_id,
                                            self.logger)
                tasks = list(
                    self.batch_client.task.list(
                        job_id, task_list_options=list_options))
                self.logger.info("Received total {} tasks".format(len(tasks)))
                for task in tasks:
                    if task.id not in self.completed_task_ids:
                        self.logger.info("Missed completed task: {}".format(
                            task.serialize()))
                        self.completed_task_ids.add(task.id)
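
check_pool_for_failed_nodes is defined elsewhere in this codebase; a plausible sketch of such a helper, assuming it only scans the pool for nodes in failure states (this body is a guess, not the original):

def check_pool_for_failed_nodes(batch_client, pool_id, logger):
    # hypothetical reimplementation: surface nodes stuck in bad states
    bad_states = (batchmodels.ComputeNodeState.unusable,
                  batchmodels.ComputeNodeState.start_task_failed)
    for node in batch_client.compute_node.list(pool_id):
        if node.state in bad_states:
            logger.warning('node {} in pool {} is in state {}'.format(
                node.id, pool_id, node.state))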
Example #4
    def get_task_list(self, job_id, task_filter=None, list_max_results=None):
        """
        Return tasks for a job, optionally filtered and capped.
        :param job_id: lists tasks under this Job Id
        :type job_id: str
        :param task_filter: filter string for list query
        :type task_filter:  str
        :param list_max_results: maximum number of results to return
        :type list_max_results: int or None
        :return: list of tasks
        :rtype: list
        """
        # max_results is honored via TaskListOptions; task.list itself does
        # not accept a list_max_results keyword
        list_options = batchmodels.TaskListOptions(
            filter=task_filter, max_results=list_max_results)

        return list(self.batch_client.task.list(
            job_id, task_list_options=list_options))
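
A short usage sketch; the hook instance, job id, and filter below are placeholders:

# hypothetical call: fetch up to 10 active tasks for job 'myjob'
tasks = client.get_task_list(
    'myjob', task_filter="state eq 'active'", list_max_results=10)
for task in tasks:
    print(task.id, task.state)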
Example #5
def tunnel_tensorboard(batch_client, config, jobid, taskid, logdir, image):
    # type: (batchsc.BatchServiceClient, dict, str, str, str, str) -> None
    """Action: Misc Tensorboard
    :param azure.batch.batch_service_client.BatchServiceClient batch_client:
        batch client
    :param dict config: configuration dict
    :param str jobid: id of the job containing the task
    :param str taskid: id of the task to tunnel to
    :param str logdir: log dir
    :param str image: tensorflow image to use
    """
    # ensure pool ssh private key exists
    pool = settings.pool_settings(config)
    ssh_priv_key = pool.ssh.ssh_private_key
    if ssh_priv_key is None:
        ssh_priv_key = pathlib.Path(pool.ssh.generated_file_export_path,
                                    crypto.get_ssh_key_prefix())
    if not ssh_priv_key.exists():
        raise RuntimeError(
            ('cannot tunnel to remote Tensorboard with non-existent RSA '
             'private key: {}').format(ssh_priv_key))
    if not crypto.check_ssh_private_key_filemode(ssh_priv_key):
        logger.warning('SSH private key filemode is too permissive: {}'.format(
            ssh_priv_key))
    # populate jobid if empty
    if util.is_none_or_empty(jobid):
        jobspecs = settings.job_specifications(config)
        jobid = settings.job_id(jobspecs[0])
    # get the last task for this job
    if util.is_none_or_empty(taskid):
        tasks = batch_client.task.list(
            jobid, task_list_options=batchmodels.TaskListOptions(select='id'))
        taskid = sorted([x.id for x in tasks])[-1]
    # wait for task to be running or completed
    logger.debug('waiting for task {} in job {} to reach a valid state'.format(
        taskid, jobid))
    while True:
        task = batch_client.task.get(jobid, taskid)
        if (task.state == batchmodels.TaskState.running
                or task.state == batchmodels.TaskState.completed):
            break
        logger.debug('waiting for task to enter running or completed state')
        time.sleep(1)
    # parse "--logdir" from task commandline
    if util.is_none_or_empty(logdir):
        for arg in _TENSORBOARD_LOG_ARGS:
            try:
                _tmp = task.command_line.index(arg)
            except ValueError:
                pass
            else:
                _tmp = task.command_line[_tmp + len(arg) + 1:]
                logdir = _tmp.split()[0].rstrip(';').rstrip('"').rstrip('\'')
                if not util.confirm_action(
                        config, 'use auto-detected logdir: {}'.format(logdir)):
                    logdir = None
                else:
                    logger.debug(
                        'using auto-detected logdir: {}'.format(logdir))
                    break
    if util.is_none_or_empty(logdir):
        raise RuntimeError(
            ('cannot automatically determine logdir for task {} in '
             'job {}, please retry command with explicit --logdir '
             'parameter').format(taskid, jobid))
    # construct absolute logpath
    logpath = pathlib.Path(
        settings.temp_disk_mountpoint(config)) / 'batch' / 'tasks'
    if logdir.startswith('$AZ_BATCH'):
        _tmp = logdir.index('/')
        _var = logdir[:_tmp]
        # shift off var
        logdir = logdir[_tmp + 1:]
        if _var == '$AZ_BATCH_NODE_ROOT_DIR':
            pass
        elif _var == '$AZ_BATCH_NODE_SHARED_DIR':
            logpath = logpath / 'shared'
        elif _var == '$AZ_BATCH_NODE_STARTUP_DIR':
            logpath = logpath / 'startup'
        elif _var == '$AZ_BATCH_TASK_WORKING_DIR':
            logpath = logpath / 'workitems' / jobid / 'job-1' / taskid / 'wd'
        else:
            raise RuntimeError(
                ('cannot automatically translate variable {} to absolute '
                 'path, please retry with an absolute path for '
                 '--logdir').format(_var))
    elif not logdir.startswith('/'):
        # default to task working directory
        logpath = logpath / 'workitems' / jobid / 'job-1' / taskid / 'wd'
    logpath = logpath / logdir
    if util.on_windows():
        logpath = str(logpath).replace('\\', '/')
    logger.debug('using logpath: {}'.format(logpath))
    # if logdir still has vars raise error
    if '$AZ_BATCH' in logdir:
        raise RuntimeError(
            ('cannot determine absolute logdir path for task {} in job {}, '
             'please retry with an absolute path for --logdir').format(
                 taskid, jobid))
    # determine tensorflow image to use
    tb = settings.get_tensorboard_docker_image()
    if util.is_none_or_empty(image):
        di = settings.global_resources_docker_images(config)
        di = [x for x in di if 'tensorflow' in x]
        if util.is_not_empty(di):
            image = di[0]
            if not util.confirm_action(
                    config,
                    'use auto-detected Docker image: {}'.format(image)):
                image = None
            else:
                logger.debug(
                    'using auto-detected Docker image: {}'.format(image))
        del di
    if util.is_none_or_empty(image):
        logger.warning(
            'no pre-loaded tensorflow Docker image detected on pool, '
            'using: {}'.format(tb[0]))
        image = tb[0]
    # get node remote login settings
    rls = batch_client.compute_node.get_remote_login_settings(
        pool.id, task.node_info.node_id)
    # set up tensorboard command
    if settings.is_gpu_pool(pool.vm_size):
        exe = 'nvidia-docker'
    else:
        exe = 'docker'
    name = str(uuid.uuid4()).split('-')[0]
    # map both ports (jupyter and tensorboard) to different host ports
    # to avoid conflicts
    host_port = 56006
    tb_ssh_args = [
        'ssh', '-o', 'StrictHostKeyChecking=no', '-o',
        'UserKnownHostsFile={}'.format(os.devnull), '-i',
        str(ssh_priv_key), '-p',
        str(rls.remote_login_port), '-t',
        '{}@{}'.format(pool.ssh.username, rls.remote_login_ip_address),
        ('sudo /bin/bash -c "{exe} run --rm --name={name} -p 58888:8888 '
         '-p {hostport}:{contport} -v {logdir}:/{jobid}.{taskid} {image} '
         'python {tbpy} --port={contport} --logdir=/{jobid}.{taskid}"').format(
             exe=exe,
             name=name,
             hostport=host_port,
             contport=tb[2],
             image=image,
             tbpy=tb[1],
             logdir=str(logpath),
             jobid=jobid,
             taskid=taskid)
    ]
    # set up ssh tunnel command
    tunnel_ssh_args = [
        'ssh', '-o', 'StrictHostKeyChecking=no', '-o',
        'UserKnownHostsFile={}'.format(os.devnull), '-i',
        str(ssh_priv_key), '-p',
        str(rls.remote_login_port), '-N', '-L',
        '{port}:localhost:{hostport}'.format(port=tb[2], hostport=host_port),
        '{}@{}'.format(pool.ssh.username, rls.remote_login_ip_address)
    ]
    # execute command and then tunnel
    tb_proc = None
    tunnel_proc = None
    try:
        tb_proc = util.subprocess_nowait_pipe_stdout(tb_ssh_args, shell=False)
        tunnel_proc = util.subprocess_nowait_pipe_stdout(tunnel_ssh_args,
                                                         shell=False)
        logger.info(
            ('\n\n>> Please connect to Tensorboard at http://localhost:{}/'
             '\n\n>> Note that Tensorboard may take a while to start if the '
             'Docker image is'
             '\n>> not present. Please keep retrying the URL every few '
             'seconds.'
             '\n\n>> Terminate your session with CTRL+C'
             '\n\n>> If you cannot terminate your session cleanly, run:'
             '\n     shipyard pool ssh --nodeid {} '
             'sudo docker kill {}\n').format(tb[2], task.node_info.node_id,
                                             name))
        tb_proc.wait()
    finally:
        logger.debug(
            'attempting clean up of Tensorboard instance and SSH tunnel')
        try:
            if tunnel_proc is not None:
                tunnel_proc.poll()
                if tunnel_proc.returncode is None:
                    tunnel_proc.kill()
        except Exception as e:
            logger.exception(e)
        if tb_proc is not None:
            tb_proc.poll()
            if tb_proc.returncode is None:
                tb_proc.kill()
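
A hedged invocation sketch; batch_client and config would come from shipyard's own client factory and settings loader, and passing empty values lets the function auto-detect the task, logdir, and image as shown above:

# hypothetical invocation; taskid/logdir/image are auto-detected when None
tunnel_tensorboard(batch_client, config, 'myjob', None, None, None)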