コード例 #1
0
    def test_exec_proc_io(self):
        """`exec_proc` routes the child's output to the given callbacks.

        Three modes are checked: a stdout callback only, separate stdout and
        stderr callbacks, and stderr merged into stdout via
        ``stderr_to_stdout=True``.  The child always exits with code 123,
        which `proc.poll()` must report after the context exits.
        """
        with TemporaryDirectory() as tempdir:
            # a known file, so `ls` in the child produces recognisable stdout
            with open(os.path.join(tempdir, 'test_payload.txt'), 'wb') as f:
                f.write(b'hello, world!')

            # writes the listing to stdout, a marker to stderr, then exits 123
            args = ['bash', '-c', 'ls && echo error_message >&2 && exit 123']

            # test stdout only: stderr must not leak into the stdout callback
            stdout = io.BytesIO()
            with exec_proc(args, on_stdout=stdout.write, cwd=tempdir) as proc:
                proc.wait()
            self.assertEqual(123, proc.poll())
            self.assertIn(b'test_payload.txt', stdout.getvalue())
            self.assertNotIn(b'error_message', stdout.getvalue())

            # test separated stdout and stderr
            stdout = io.BytesIO()
            stderr = io.BytesIO()
            with exec_proc(args, on_stdout=stdout.write, on_stderr=stderr.write,
                           cwd=tempdir) as proc:
                proc.wait()
            self.assertEqual(123, proc.poll())
            self.assertIn(b'test_payload.txt', stdout.getvalue())
            self.assertNotIn(b'error_message', stdout.getvalue())
            self.assertNotIn(b'test_payload.txt', stderr.getvalue())
            self.assertIn(b'error_message', stderr.getvalue())

            # test redirect stderr to stdout: both streams land in one buffer
            stdout = io.BytesIO()
            with exec_proc(args, on_stdout=stdout.write, stderr_to_stdout=True,
                           cwd=tempdir) as proc:
                proc.wait()
            self.assertEqual(123, proc.poll())
            self.assertIn(b'test_payload.txt', stdout.getvalue())
            self.assertIn(b'error_message', stdout.getvalue())
コード例 #2
0
    def test_exec_proc_kill(self):
        """Leaving the `exec_proc` context interrupts the child process.

        The interruptable child must print both messages and exit with 0.
        The non-interruptable child catches every KeyboardInterrupt; with
        ``ctrl_c_timeout=1`` the test expects `exec_proc` to stop waiting
        for it and the process to end with a non-zero exit code.
        """
        import sys

        interruptable = _strip('''
        |import time
        |try:
        |  while True:
        |    time.sleep(1)
        |except KeyboardInterrupt:
        |  print("kbd interrupt")
        |print("exited")
        ''')
        non_interruptable = _strip('''
        |import time
        |while True:
        |  try:
        |    time.sleep(1)
        |  except KeyboardInterrupt:
        |    print("kbd interrupt")
        |print("exited")
        ''')

        # Use the interpreter running the tests: a bare 'python' may be
        # missing from PATH (python3-only systems) or be a different version.
        python = sys.executable

        # test interruptable
        stdout = io.BytesIO()
        with exec_proc(
                [python, '-u', '-c', interruptable],
                on_stdout=stdout.write) as proc:
            timed_wait_proc(proc, 1.)
        self.assertEqual(b'kbd interrupt\nexited\n', stdout.getvalue())
        self.assertEqual(0, proc.poll())

        # test non-interruptable, give up waiting
        stdout = io.BytesIO()
        with exec_proc(
                [python, '-u', '-c', non_interruptable],
                on_stdout=stdout.write,
                ctrl_c_timeout=1) as proc:
            timed_wait_proc(proc, 1.)
        self.assertEqual(b'kbd interrupt\n', stdout.getvalue())
        self.assertNotEqual(0, proc.poll())
コード例 #3
0
def run_experiment(client_args):
    """
    Run the experiment with specified argument.

    Creates the experiment on the MLStorage server, verifies that its
    storage directory is shared with the server, prepares the working
    directory (script copies, data-file symlinks, ``config.json``), and then
    runs ``client_args.args`` inside it.  While the program runs, background
    jobs send heartbeats and collect ``config.json``,
    ``config.defaults.json``, ``result.json`` and ``webui.json``.  Afterwards
    the experiment is marked ``COMPLETED`` (with exit code and storage size)
    or, if something failed before that, ``FAILED`` (with the traceback).

    Exceptions derived from :class:`Exception` that occur after the
    experiment was created are logged and not re-raised.  When
    ``client_args.debug`` is set, the experiment is deleted at the end.

    Args:
        client_args: The client arguments.
    """
    # initialize logging and print the arguments to log
    logging.basicConfig(level='DEBUG' if client_args.debug else 'INFO',
                        format='%(asctime)s [%(levelname)s]: %(message)s')
    logger = getLogger(__name__)
    logger.info('MLStorage Runner %s', mlstorage_client.__version__)
    if client_args.debug:
        logger.debug('DEBUGGING FLAG IS SET!')
    if client_args.parent_id:
        logger.info('Launched within a parent experiment: %s',
                    client_args.parent_id)
    log_dict('Environmental variables', client_args.env)
    log_dict('Experiment config', client_args.config)

    # establish connection to the server, and create the experiment
    api = ApiClientV1(client_args.server)
    doc = api.create(client_args.name, get_creation_doc(client_args))
    cleanup_helper = CleanupHelper()
    proc = None
    # set once the COMPLETED status was stored, so a later error does not
    # overwrite it with FAILED
    final_status_set = False

    try:
        # extract information from the returned doc
        exp_id = doc['id']
        storage_dir = doc['storage_dir']
        logger.info('Experiment ID: %s', exp_id)
        logger.info('Work-dir: %s', storage_dir)

        # ensure the shared network file system works
        os.makedirs(storage_dir, exist_ok=True)
        _verify_shared_storage(api, exp_id, storage_dir)

        # construct the env dict
        env = get_environ_dict(client_args, exp_id, storage_dir)

        # further update the doc according to `exp_id` and `storage_dir`
        exc_info = doc.get('exc_info', {})
        exc_info.update({'work_dir': storage_dir, 'env': env})
        doc = api.update(exp_id, {'exc_info': exc_info})

        # prepare for the working directory: scripts are copied, data files
        # are symlinked (and the links are registered for later cleanup)
        for script_file in client_args.script_files:
            clone_file_or_dir(os.path.join(client_args.cwd, script_file),
                              script_file,
                              storage_dir,
                              symlink=False)
        if not client_args.no_link:
            for data_file in client_args.data_files:
                if data_file not in client_args.script_files:
                    clone_file_or_dir(os.path.join(client_args.cwd, data_file),
                                      data_file,
                                      storage_dir,
                                      symlink=True)
                    cleanup_helper.add(os.path.join(storage_dir, data_file))

        if client_args.config:
            config_json = json.dumps(client_args.config,
                                     cls=JsonEncoder,
                                     sort_keys=True,
                                     indent=2)
            with codecs.open(os.path.join(storage_dir, 'config.json'), 'wb',
                             'utf-8') as f:
                f.write(config_json)

        # scoped class for injecting TensorBoard webui
        class TensorBoardWebUI(object):
            """Put the TensorBoard URI under `key` in the collected webui
            dict, or remove that key while no URI is set."""

            def __init__(self, key='TensorBoard'):
                self.uri = None
                self.key = key

            def postprocess(self, webui):
                if self.uri:
                    webui[self.key] = self.uri
                else:
                    webui.pop(self.key, None)
                return webui

            @contextmanager
            def set_uri(self, uri):
                self.uri = uri
                try:
                    yield
                finally:
                    self.uri = None

        # run the program
        heartbeat_job = HeartbeatJob('send heartbeat', api, doc)
        config_job = CollectJsonDictJob('collect config',
                                        api,
                                        doc,
                                        filename='config.json',
                                        field='config')
        default_config_job = CollectJsonDictJob(
            'collect default config',
            api,
            doc,
            filename='config.defaults.json',
            field='default_config')
        result_job = CollectJsonDictJob('collect result',
                                        api,
                                        doc,
                                        filename='result.json',
                                        field='result')
        tb_webui = TensorBoardWebUI()
        webui_job = CollectJsonDictJob('collect webui',
                                       api,
                                       doc,
                                       filename='webui.json',
                                       field='webui',
                                       postprocess=tb_webui.postprocess)

        try:
            with maybe_run_tensorboard(client_args, api, doc) as tb_uri, \
                    tb_webui.set_uri(tb_uri), \
                    heartbeat_job.run_in_background(), \
                    config_job.run_in_background(), \
                    default_config_job.run_in_background(), \
                    result_job.run_in_background(), \
                    webui_job.run_in_background(), \
                    ConsoleDuplicator(storage_dir, 'console.log') as out_dup, \
                    exec_proc(client_args.args,
                              on_stdout=out_dup.on_output,
                              stderr_to_stdout=True,
                              cwd=storage_dir,
                              env=env) as p:
                proc = p

                # final update the doc, to store pid
                exc_info = doc.get('exc_info', {})
                exc_info.update({'pid': proc.pid})
                doc = api.update(exp_id, {'exc_info': exc_info})

                p.wait()
                logger.debug('Process exited normally.')

        finally:
            # collect the JSON dict for the last time
            retry(lambda: webui_job.run_once(True), webui_job.name)
            retry(lambda: config_job.run_once(True), config_job.name)
            retry(lambda: default_config_job.run_once(True),
                  default_config_job.name)
            retry(lambda: result_job.run_once(True), result_job.name)
            logger.debug('JSON file collected.')

            # cleanup the working directory
            cleanup_helper.cleanup()
            logger.debug('Working directory cleanup finished.')

            # compute the storage size and update the result; only when the
            # program was actually started, otherwise the outer handler
            # reports the failure
            if proc is not None:
                result_dict = {
                    'exit_code': proc.poll(),
                    'storage_size': compute_fs_size(storage_dir)
                }
                retry(lambda: api.set_finished(exp_id, 'COMPLETED',
                                               result_dict),
                      'store the experiment result')
                final_status_set = True
                logger.info('Experiment exited with code: %s', proc.poll())

    except Exception as ex:
        if not final_status_set:
            logger.exception('Failed to run the experiment.')
            error_dict = {
                'message': str(ex),
                'traceback':
                ''.join(traceback.format_exception(*sys.exc_info()))
            }
            retry(
                lambda: api.set_finished(doc['id'], 'FAILED',
                                         {'error': error_dict}),
                'store the experiment failure')
        else:
            # the COMPLETED status is already stored and must not be
            # replaced, but the error should not vanish silently either
            logger.exception('Error occurred after the experiment result '
                             'had been stored.')

    finally:
        # NOTE(review): cleanup() also runs in the inner `finally`; this
        # relies on CleanupHelper.cleanup() being safe to call twice.
        cleanup_helper.cleanup()  # ensure to cleanup
        if client_args.debug:
            retry(lambda: api.delete(doc['id']),
                  'cleanup debugging experiment')
            logger.debug('Experiment deleted.')


def _verify_shared_storage(api, exp_id, storage_dir):
    """
    Check that `storage_dir` is visible to the server at the same content.

    Writes a file with a random name and random content into `storage_dir`,
    fetches it back via ``api.getfile(exp_id, name)`` and compares the bytes.
    The probe file is removed even when fetching it fails.

    Args:
        api: The API client used to fetch the file back.
        exp_id: The experiment ID passed to ``api.getfile``.
        storage_dir (str): The (existing) storage directory to probe.

    Raises:
        ValueError: If the fetched content differs from what was written.
    """
    needle_fn = str(uuid.uuid4()) + '.txt'
    needle_path = os.path.join(storage_dir, needle_fn)
    needle_content = str(uuid.uuid4()).encode('utf-8')
    with open(needle_path, 'wb') as f:
        f.write(needle_content)
    try:
        remote_content = api.getfile(exp_id, needle_fn)
    finally:
        # never leave the probe file behind in the experiment's storage
        os.remove(needle_path)
    if remote_content != needle_content:
        raise ValueError('The content of remote file does not agree '
                         'with the generated content.')
コード例 #4
0
def run_tensorboard(path, log_file=None, host='0.0.0.0', port=0, timeout=30):
    """
    Run TensorBoard in background.

    TensorBoard's stdout and stderr are scanned for a line like
    ``TensorBoard ... at http://<host>:<port>``; the resulting URL is yielded.
    A wildcard host (``0.0.0.0``) is replaced by this machine's address.

    Args:
        path (str): The log directory for TensorBoard.
        log_file (str): If specified, will write the log of TensorBoard
            into this file. (default :obj:`None`)
        host (str): Bind TensorBoard to this host. (default "0.0.0.0")
        port (int): Bind TensorBoard to this port. (default 0, any free port)
        timeout (float): Wait the TensorBoard to start for this number of
            seconds. (default 30)

    Yields:
        str: The URI of the launched TensorBoard.

    Raises:
        ValueError: If TensorBoard does not report its URL within `timeout`
            seconds.
    """
    @contextmanager
    def maybe_open_log():
        if log_file is not None:
            with open(log_file, 'wb') as f:
                yield f
        else:
            yield None

    url_q = Queue()
    # Created once and shared by every output callback, so output is
    # accumulated across chunks (a banner split between two reads is still
    # matched) and emptying it after a match really stops further searching.
    headbuf = [b'']
    args = ['tensorboard',
            '--logdir', path,
            '--host', host,
            '--port', str(port)]
    env = dict(os.environ)
    env['PYTHONUNBUFFERED'] = '1'
    env['CUDA_VISIBLE_DEVICES'] = ''  # force not using GPUs
    with maybe_open_log() as log_f, \
            exec_proc(args,
                      on_stdout=lambda data: _capture_tensorboard_output(
                          data, fout=log_f, headbuf=headbuf, url_q=url_q
                      ),
                      stderr_to_stdout=True,
                      env=env):
        try:
            url = url_q.get(timeout=timeout)
        except Empty:
            raise ValueError('TensorBoard did not report its url in '
                             '{} seconds.'.format(timeout)) from None
        yield url


def _capture_tensorboard_output(data, fout, headbuf, url_q,
                                pattern=re.compile(br'TensorBoard \S+ at '
                                                   br'http://([^:]+):(\d+)')):
    """
    Handle one chunk of TensorBoard output.

    While `headbuf` is non-empty, `data` is appended to ``headbuf[0]`` and the
    accumulated output is searched with `pattern`.  On a match the URL is put
    into `url_q` and `headbuf` is emptied, which disables further searching.
    The chunk is always copied to `fout` (and flushed) when it is given.

    Args:
        data (bytes): The output chunk.
        fout: Binary file-like object receiving a copy of the output, or
            :obj:`None`.
        headbuf (list[bytes]): One-element list with the output seen so far;
            must be the same list across calls.
        url_q (Queue): Receives the discovered URL as a ``str``.
        pattern: Compiled bytes regex; group 1 is the host, group 2 the port.
    """
    if headbuf:
        headbuf[0] = headbuf[0] + data
        m = pattern.search(headbuf[0])
        if m:
            url_host = m.group(1).decode('utf-8')
            url_port = m.group(2).decode('utf-8')
            # a wildcard bind address is not reachable; report this host
            if not url_host or (url_host in ('0.0.0.0', '::0')):
                url_host = socket.gethostbyname(socket.gethostname())
            the_url = 'http://{}:{}'.format(url_host, url_port)
            url_q.put(the_url)
            del headbuf[:]
    if fout is not None:
        fout.write(data)
        fout.flush()