class Server(object):
    def __init__(self, port):
        self.context = None
        self.thread = None
        self.queue = None
        self.webserver = None
        self.port = port

    def __enter__(self):
        self.context = ExitStack()
        self.context.enter_context(self.event_loop_context())
        self.thread = EventLoopThread([self.webserver.server()])
        self.context.enter_context(self.thread)
        return self

    def __exit__(self, *exc):
        return self.context.__exit__(*exc)

    def recv(self):
        result = self.queue.get()
        return result

    def send(self, msg):
        asyncio.run_coroutine_threadsafe(self.webserver.broadcast(msg), self.thread.loop)

    @contextmanager
    def event_loop_context(self):
        with ExitStack() as stack:
            stack.callback(lambda: setattr(self, "queue", None))
            stack.callback(lambda: setattr(self, "webserver", None))
            self.queue = Queue()
            self.webserver = WebsocketServer(self.queue, self.port)
            yield
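A minimal usage sketch for the class above; it assumes the WebsocketServer and EventLoopThread definitions from the other examples on this page, and the port and message payloads are hypothetical:

with Server(port=8765) as server:
    server.send('hello, clients')   # broadcast from the calling thread
    message = server.recv()         # block until a client sends something
    print('got:', message)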
Example #2
def logtail(path, offset_path=None, *, dry_run=False):
    """Yield new lines from a logfile.

    :param path: The path to the file to read from
    :param offset_path: The path to the file where offset/inode
                        information will be stored.  If not set,
                        ``<file>.offset`` will be used.
    :param dry_run: If ``True``, the offset file will not be modified
                    or created.
    """
    if offset_path is None:
        offset_path = path + '.offset'

    try:
        logfile = open(path, encoding='utf-8', errors='replace')
    except OSError as exc:
        warning_echo('Could not read: {} ({})'.format(path, exc))
        return

    closer = ExitStack()
    closer.enter_context(logfile)
    with closer:
        line_iter = iter([])
        stat = os.stat(logfile.fileno())
        debug_echo('logfile inode={}, size={}'.format(stat.st_ino, stat.st_size))
        inode, offset = _parse_offset_file(offset_path)
        if inode is not None:
            if stat.st_ino == inode:
                debug_echo('inodes are the same')
                if offset == stat.st_size:
                    debug_echo('offset points to eof')
                    return
                elif offset > stat.st_size:
                    warning_echo('File shrunk since last read: {} ({} < {})'.format(path, stat.st_size, offset))
                    offset = 0
            else:
                debug_echo('inode changed, checking for rotated file')
                rotated_path = _check_rotated_file(path, inode)
                if rotated_path is not None:
                    try:
                        rotated_file = open(rotated_path, encoding='utf-8', errors='replace')
                    except OSError as exc:
                        warning_echo('Could not read rotated file: {} ({})'.format(rotated_path, exc))
                    else:
                        closer.enter_context(rotated_file)
                        rotated_file.seek(offset)
                        line_iter = itertools.chain(line_iter, iter(rotated_file))
                offset = 0
        logfile.seek(offset)
        line_iter = itertools.chain(line_iter, iter(logfile))
        for line in line_iter:
            line = line.strip()
            yield line
        pos = logfile.tell()
        debug_echo('reached end of logfile at {}'.format(pos))
        if not dry_run:
            debug_echo('writing offset file: ' + offset_path)
            _write_offset_file(offset_path, stat.st_ino, pos)
        else:
            debug_echo('dry run - not writing offset file')
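A usage sketch with a hypothetical log path: only lines appended since the previous run are yielded, because the offset file records the last read position and inode.

for line in logtail('/var/log/app.log'):
    if 'ERROR' in line:
        print(line)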
Example #3
def Function_emitInstructions(self):
    self.emitEntry()

    # Exit stack that might contain a selfScope
    self_stack = ExitStack()
    if "self" in self.closed_context:
        assert self.closed_context["self"].llvm_context_index >= 0
        index = self.closed_context["self"].llvm_context_index
        self_value = State.builder.structGEP(self.llvm_context, index, "")
        self_value = State.builder.load(self_value, "")
        self_stack.enter_context(State.selfScope(self_value))

    with self_stack:
        for object in self.closed_context:
            if object.name == "self": continue
            if object.stats.static: continue

            index = object.llvm_context_index
            object.llvm_value = State.builder.structGEP(self.llvm_context, index, "")

        # Allocate Arguments
        for index, arg in enumerate(self.arguments):
            val = self.llvm_value.getParam(index + 1)
            arg.llvm_value = State.builder.alloca(arg.resolveType().emitType(), resolveName(arg))
            State.builder.store(val, arg.llvm_value)

        self.emitPostContext()

        return State.emitInstructions(self.instructions)
Example #4
    def target(self):
        stack = ExitStack()

        for target, value in self.targets:
            stack.enter_context(target.resolveValue().targetValue(value))

        with stack:
            yield
Example #5
    def checkCompatibility(self, other, check_cache = None):
        stack = ExitStack()
        stack.enter_context(self.target())
        if isinstance(other, ClosedTarget):
            stack.enter_context(other.target())
            other = other.value

        with stack:
            return self.value.checkCompatibility(other, check_cache)
Example #6
class EventLoopThread(object):
    def __init__(self, servers_to_start):
        self.context = None
        self.executor = None
        self.loop = None
        self.servers_to_start = servers_to_start
        self.servers = []

    def __enter__(self):
        self.context = ExitStack()
        self.executor = self.context.enter_context(ThreadPoolExecutor(max_workers=1))
        self.context.enter_context(self.event_loop_context())
        return self

    def __exit__(self, *exc):
        suppress = self.context.__exit__(*exc)
        self.context = None
        self.executor = None
        self.loop = None
        return suppress

    def start_loop(self, event):
        logger.info("starting eventloop server")
        loop = asyncio.new_event_loop()
        self.loop = loop
        asyncio.set_event_loop(loop)
        for server_starter in self.servers_to_start:
            server = loop.run_until_complete(server_starter)
            self.servers.append(server)
        loop.call_soon(event.set)
        loop.run_forever()

    def stop_loop(self):
        logger.info("stopping eventloop server")
        self.loop.create_task(self._close_connections())

    @contextmanager
    def event_loop_context(self):
        event = Event()
        event.clear()
        self.executor.submit(self.start_loop, event)
        event.wait()
        logger.info("started eventloop")
        try:
            yield
        finally:
            self.loop.call_soon_threadsafe(self.stop_loop)
            logger.info("stopped eventloop")

    async def _close_connections(self):
        for server in self.servers:
            server.close()
            await server.wait_closed()
        self.loop.stop()
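A sketch of driving EventLoopThread directly: the items in servers_to_start must be awaitables that resolve to objects with close() and wait_closed(), which asyncio.start_server provides. The echo handler and port below are hypothetical:

import asyncio

async def echo(reader, writer):
    writer.write(await reader.read(100))
    await writer.drain()
    writer.close()

with EventLoopThread([asyncio.start_server(echo, '127.0.0.1', 8888)]) as t:
    # The loop now runs on the worker thread; schedule coroutines onto it
    # with asyncio.run_coroutine_threadsafe(coro, t.loop).
    pass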
Example #7
    def mock_sys(self):
        "Mock system environment for InteractiveConsole"
        # use exit stack to match patch context managers to addCleanup
        stack = ExitStack()
        self.addCleanup(stack.close)
        self.infunc = stack.enter_context(mock.patch("code.input", create=True))
        self.stdout = stack.enter_context(mock.patch("code.sys.stdout"))
        self.stderr = stack.enter_context(mock.patch("code.sys.stderr"))
        prepatch = mock.patch("code.sys", wraps=code.sys, spec=code.sys)
        self.sysmod = stack.enter_context(prepatch)
        if sys.excepthook is sys.__excepthook__:
            self.sysmod.excepthook = self.sysmod.__excepthook__
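The ExitStack-plus-addCleanup pairing above is a general unittest idiom: every context manager entered on the stack is unwound at teardown, even when the test fails. A self-contained sketch of the same idiom:

import os
import tempfile
import unittest
from contextlib import ExitStack

class ExampleTest(unittest.TestCase):
    def setUp(self):
        stack = ExitStack()
        self.addCleanup(stack.close)  # unwind all entered contexts at teardown
        self.tmpdir = stack.enter_context(tempfile.TemporaryDirectory())

    def test_tmpdir_exists(self):
        self.assertTrue(os.path.isdir(self.tmpdir))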
Example #8
class StreamDecoder:
    def __init__(self, file):
        self._file = file
        self._crc = 0
        self._pipe = PipeWriter()
        self._cleanup = ExitStack()
        coroutine = self._pipe.coroutine(self._receive())
        self._cleanup.enter_context(coroutine)

    def close(self):
        self._pipe.close()
        del self._pipe
        self._cleanup.close()

    def feed(self, data):
        self._pipe.write(data)

    def _receive(self):
        while True:
            data = self._pipe.buffer
            pos = data.find(b"=")
            if pos >= 0:
                data = data[:pos]
            data = data.replace(b"\r", b"").replace(b"\n", b"")
            data = data.translate(self.TABLE)
            # TODO: check data size overflow
            self._crc = crc32(data, self._crc)
            self._file.write(data)
            if pos >= 0:  # Escape character (equals sign)
                self._pipe.buffer = self._pipe.buffer[pos + 1:]
                while True:
                    byte = yield from self._pipe.read_one()
                    if byte not in b"\r\n":
                        break
                # TODO: check for size overflow
                [byte] = byte
                data = bytes(((byte - 64 - 42) & bitmask(8),))
                self._crc = crc32(data, self._crc)
                self._file.write(data)
            else:
                try:
                    self._pipe.buffer = yield
                except EOFError:
                    break

    def flush(self):
        pass

    def getCrc32(self):
        return format(self._crc, "08x")

    TABLE = bytes(range(256))
    TABLE = TABLE[-42:] + TABLE[:-42]
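The shift-by-42 translation table and the '=' escape above are the yEnc decoding rules. A usage sketch; PipeWriter comes from the surrounding project, the payload is hypothetical, and whether decoding happens synchronously on feed() is an assumption:

import io

out = io.BytesIO()
decoder = StreamDecoder(out)
decoder.feed(b'hypothetical yEnc-encoded bytes')
decoder.close()
print(out.getvalue(), decoder.getCrc32())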
Example #9
    def setUp(self):
        super(TestOpenSSLManager, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.subp = patches.enter_context(
            mock.patch.object(azure_helper.util, 'subp'))
        try:
            self.open = patches.enter_context(
                mock.patch('__builtin__.open'))
        except ImportError:
            self.open = patches.enter_context(
                mock.patch('builtins.open'))
Example #10
class PulpWritter(object):
    """Use this to create a pulp db."""
    def __init__(self, db_name, msg_dumper=None, idx_dumpers=None):
        self.dir_path = os.path.abspath(db_name)
        self.keys_path = os.path.join(self.dir_path, 'keys')
        if os.path.isdir(self.dir_path):
            shutil.rmtree(self.dir_path)
        os.makedirs(self.dir_path)
        os.makedirs(self.keys_path)

        self.master_table = None
        self.key_tables = {}
        self.table_stack = None
        self.msg_dumper = msg_dumper
        self.idx_dumpers = idx_dumpers
    
    def __enter__(self):
        self.table_stack = ExitStack()  # entering an empty ExitStack is a no-op
        table = MasterTable(self.dir_path, 'w', dumper=self.msg_dumper)
        self.add_table(None, table)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.table_stack is not None:
            self.table_stack.close()

    def add_table(self, key, table):
        if key is None:
            self.master_table = table
        else:
            self.key_tables[key] = table
        self.table_stack.enter_context(table)

    def append(self, data, index_map):
        msg_num = self.master_table.append(data)
        for key, value in index_map.items():
            table = self.key_tables.get(key, None)
            if not table:
                if self.idx_dumpers is not None:
                    dumper = self.idx_dumpers.get(key, None)
                else:
                    dumper = None
                table = KeyTable(self.keys_path, key, 'w', dumper=dumper)
                self.add_table(key, table)

            if isinstance(value, (tuple, list, set)):
                for v in value:
                    table.append(v, msg_num)
            else:
                table.append(value, msg_num)
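A usage sketch with hypothetical records: each append() stores the message in the master table and indexes its message number under every key in index_map. MasterTable and KeyTable come from the surrounding project.

with PulpWritter('example_db') as db:
    db.append({'body': 'hello'}, {'author': 'alice', 'tags': ('a', 'b')})
    db.append({'body': 'world'}, {'author': 'bob'})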
Example #11
    def targetValue(self, value):
        value = value.resolveValue()
        old_value = self.value
        self.value = value

        stack = ExitStack()
        if self._static_value_type is not None:
            targets = [(self._static_value_type, value)]
            stack.enter_context(forward.target(targets))

        try:
            with stack:
                yield
        finally:
            # restore the previous value even if the body raises
            self.value = old_value
Example #12
class TestSocket(unittest.TestCase):
    # Usually the socket will be set up from socket.getaddrinfo() but if that
    # raises socket.gaierror, then it tries to infer the IPv4/IPv6 type from
    # the host name.
    def setUp(self):
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        self._resources.enter_context(patch('aiosmtpd.main.socket.getaddrinfo',
                                            side_effect=socket.gaierror))

    def test_ipv4(self):
        bind = self._resources.enter_context(patch('aiosmtpd.main.bind'))
        mock_sock = setup_sock('host.example.com', 8025)
        bind.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM, 0)
        mock_sock.bind.assert_called_once_with(('host.example.com', 8025))

    def test_ipv6(self):
        bind = self._resources.enter_context(patch('aiosmtpd.main.bind'))
        mock_sock = setup_sock('::1', 8025)
        bind.assert_called_once_with(socket.AF_INET6, socket.SOCK_STREAM, 0)
        mock_sock.bind.assert_called_once_with(('::1', 8025, 0, 0))

    def test_bind_ipv4(self):
        self._resources.enter_context(patch('aiosmtpd.main.socket.socket'))
        mock_sock = setup_sock('host.example.com', 8025)
        mock_sock.setsockopt.assert_called_once_with(
            socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    def test_bind_ipv6(self):
        self._resources.enter_context(patch('aiosmtpd.main.socket.socket'))
        mock_sock = setup_sock('::1', 8025)
        self.assertEqual(mock_sock.setsockopt.call_args_list, [
            call(socket.SOL_SOCKET, socket.SO_REUSEADDR, True),
            call(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False),
            ])
Example #13
    def setUp(self):
        super(TestAzureEndpointHttpClient, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.read_file_or_url = patches.enter_context(
            mock.patch.object(azure_helper.util, 'read_file_or_url'))
Example #14
    def setUp(self):
        super(TestFindEndpoint, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.load_file = patches.enter_context(
            mock.patch.object(azure_helper.util, 'load_file'))
Example #15
    def setUp(self):
        super(TestWalkerHandleHandler, self).setUp()
        tmpdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmpdir)

        self.data = {
            "handlercount": 0,
            "frequency": "",
            "handlerdir": tmpdir,
            "handlers": helpers.ContentHandlers(),
            "data": None,
        }

        self.expected_module_name = "part-handler-%03d" % (self.data["handlercount"],)
        expected_file_name = "%s.py" % self.expected_module_name
        self.expected_file_fullname = os.path.join(self.data["handlerdir"], expected_file_name)
        self.module_fake = FakeModule()
        self.ctype = None
        self.filename = None
        self.payload = "dummy payload"

        # Mock the write_file() function.  We'll assert that it got called as
        # expected in each of the individual tests.
        resources = ExitStack()
        self.addCleanup(resources.close)
        self.write_file_mock = resources.enter_context(mock.patch("cloudinit.util.write_file"))
Example #16
class TestZip(unittest.TestCase):
    def setUp(self):
        # Find the path to the example.*.whl so we can add it to the front of
        # sys.path, where we'll then try to find the metadata thereof.
        self.resources = ExitStack()
        self.addCleanup(self.resources.close)
        wheel = self.resources.enter_context(
            path('importlib_metadata.tests.data',
                 'example-21.12-py3-none-any.whl'))
        sys.path.insert(0, str(wheel))
        self.resources.callback(sys.path.pop, 0)

    def test_zip_version(self):
        self.assertEqual(importlib_metadata.version('example'), '21.12')

    def test_zip_entry_points(self):
        parser = importlib_metadata.entry_points('example')
        entry_point = parser.get('console_scripts', 'example')
        self.assertEqual(entry_point, 'example:main')

    def test_missing_metadata(self):
        distribution = importlib_metadata.distribution('example')
        self.assertIsNone(distribution.read_text('does not exist'))

    def test_case_insensitive(self):
        self.assertEqual(importlib_metadata.version('Example'), '21.12')
Example #17
def determine_context(device_ids: List[int],
                      use_cpu: bool,
                      disable_device_locking: bool,
                      lock_dir: str,
                      exit_stack: ExitStack) -> List[mx.Context]:
    """
    Determine the MXNet context to run on (CPU or GPU).

    :param device_ids: List of device as defined from the CLI.
    :param use_cpu: Whether to use the CPU instead of GPU(s).
    :param disable_device_locking: Disable Sockeye's device locking feature.
    :param lock_dir: Directory to place device lock files in.
    :param exit_stack: An ExitStack from contextlib.
    :return: A list with the context(s) to run on.
    """
    if use_cpu:
        context = [mx.cpu()]
    else:
        num_gpus = get_num_gpus()
        check_condition(num_gpus >= 1,
                        "No GPUs found, consider running on the CPU with --use-cpu "
                        "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                        "binary isn't on the path).")
        if disable_device_locking:
            context = expand_requested_device_ids(device_ids)
        else:
            context = exit_stack.enter_context(acquire_gpus(device_ids, lock_dir=lock_dir))
        context = [mx.gpu(gpu_id) for gpu_id in context]
    return context
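Callers keep the ExitStack open for as long as the contexts are in use, so any GPU locks acquired through acquire_gpus are released when the stack unwinds. A sketch, with a hypothetical device list and lock directory:

from contextlib import ExitStack

with ExitStack() as exit_stack:
    contexts = determine_context(device_ids=[0],
                                 use_cpu=False,
                                 disable_device_locking=False,
                                 lock_dir='/tmp',
                                 exit_stack=exit_stack)
    # ... use `contexts` while the device locks are held ...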
Example #18
def Object_resetEmission(self):
    cache = set()
    stack = ExitStack()
    objects = iter([self])

    while True:
        obj = next(objects, None)
        if obj is None: break
        if obj in cache: continue

        cache.add(obj)
        local_resets = obj.resetLocalEmission()
        if local_resets is not None:
            stack.enter_context(local_resets)
        objects = chain(objects, obj.gatherEmissionResets())

    return stack
Example #19
class WPRobotBase(object):
    def __init__(self):
        self.context = None
        self.devices = []

    def __enter__(self):
        wiringpi2.wiringPiSetupGpio()
        self.context = ExitStack()
        for device in self.devices:
            self.context.enter_context(device)
        return self

    def __exit__(self, *exc):
        self.context.__exit__(*exc)

    def attach_device(self, device):
        self.devices.append(device)
        return device
Example #20
def get_folders(directory: str) -> Iterable[Tuple[str, str]]:
    """
    Get folders composing a directory.
    Yield a 2-item tuples composed of both folder name and folder path.

    :param directory: directory path.
    :return: 2-item tuples composed of both folder name and folder path.
    """
    _collection = []  # type: List[Tuple[str, str]]
    stack = ExitStack()
    try:
        stack.enter_context(ChangeLocalCurrentDirectory(directory))
    except PermissionError:
        pass
    else:
        with stack:
            _collection = [(name, os.path.join(os.getcwd(), name)) for name in os.listdir(".")]
    for name, path in _collection:
        yield name, path
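A usage sketch: list the immediate sub-folders of a hypothetical directory, silently skipping it if permission is denied (ChangeLocalCurrentDirectory is the project's own context manager).

for name, path in get_folders('/tmp'):
    print(name, '->', path)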
Example #21
class RobotBase(object):
    def __init__(self):
        self.context = None
        self.devices = []

    def __enter__(self):
        GPIO.setmode(GPIO.BCM)
        self.context = ExitStack()
        for device in self.devices:
            self.context.enter_context(device)
        return self

    def __exit__(self, *exc):
        self.context.__exit__(*exc)
        GPIO.cleanup()

    def attach_device(self, device):
        self.devices.append(device)
        return device
Example #22
def context_env_update(context_list, env):
    es = ExitStack()
    for item in context_list:
        # create context manager and enter
        tmp_name = '__pw_cm'
        cm_code = compile(ast.Expression(item.context_expr), '<context_eval>', 'eval')
        env[tmp_name] = es.enter_context(eval(cm_code, env))

        # assign to its optional_vars in a separate dict
        if item.optional_vars:
            code = assign_from_ast(item.optional_vars, tmp_name)
            exec(code, env)

    return es
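A sketch of feeding context_env_update the items of a parsed with statement; it assumes assign_from_ast builds an assignment statement from the optional_vars node, and that the hypothetical data.txt exists:

import ast

source = "with open('data.txt') as fh:\n    pass"
node = ast.parse(source).body[0]
env = {}
es = context_env_update(node.items, env)
print(env['fh'].name)  # the managed file is now bound in env
es.close()             # leave every context that was entered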
Example #23
class TestMain(TestCase):
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        # Capture builtin print() output.
        self._stdout = StringIO()
        self._stderr = StringIO()
        self._resources.enter_context(
            patch('argparse._sys.stdout', self._stdout))
        # Capture stderr since this is where argparse will spew to.
        self._resources.enter_context(
            patch('argparse._sys.stderr', self._stderr))

    def test_help(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--help',))
        self.assertEqual(cm.exception.code, 0)
        lines = self._stdout.getvalue().splitlines()
        self.assertTrue(lines[0].startswith('usage: ubuntu-image'),
                        lines[0])

    def test_debug(self):
        with ExitStack() as resources:
            mock = resources.enter_context(
                patch('ubuntu_image.__main__.logging.basicConfig'))
            code = main(('--debug',))
        self.assertEqual(code, 0)
        mock.assert_called_once_with(level=logging.DEBUG)

    def test_no_debug(self):
        with ExitStack() as resources:
            mock = resources.enter_context(
                patch('ubuntu_image.__main__.logging.basicConfig'))
            code = main(())
        self.assertEqual(code, 0)
        mock.assert_not_called()
Example #24
class TestMainWithBadGadget(TestCase):
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        self.model_assertion = resource_filename(
            'ubuntu_image.tests.data', 'model.assertion')

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_bad_gadget_log(self):
        log = self._resources.enter_context(LogCapture())
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            BadGadgetModelAssertionBuilder))
        main(('snap', '--channel', 'edge',
              '--workdir', workdir,
              self.model_assertion))
        self.assertEqual(log.logs, [
            (logging.ERROR, 'gadget.yaml parse error: '
                            'GUID structure type with non-GPT schema'),
            (logging.ERROR, 'Use --debug for more information')
            ])

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_bad_gadget_debug_log(self):
        log = self._resources.enter_context(LogCapture())
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            BadGadgetModelAssertionBuilder))
        main(('snap', '--debug',
              '--workdir', workdir,
              '--channel', 'edge',
              self.model_assertion))
        self.assertEqual(log.logs, [
            (logging.ERROR, 'uncaught exception in state machine step: '
                            '[3] load_gadget_yaml'),
            'IMAGINE THE TRACEBACK HERE',
            (logging.ERROR, 'gadget.yaml parse error'),
            'IMAGINE THE TRACEBACK HERE',
            ])
Example #25
    def setUp(self):
        super(TestWALinuxAgentShim, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.AzureEndpointHttpClient = patches.enter_context(
            mock.patch.object(azure_helper, 'AzureEndpointHttpClient'))
        self.find_endpoint = patches.enter_context(
            mock.patch.object(
                azure_helper.WALinuxAgentShim, 'find_endpoint'))
        self.GoalState = patches.enter_context(
            mock.patch.object(azure_helper, 'GoalState'))
        self.iid_from_shared_config_content = patches.enter_context(
            mock.patch.object(azure_helper, 'iid_from_shared_config_content'))
        self.OpenSSLManager = patches.enter_context(
            mock.patch.object(azure_helper, 'OpenSSLManager'))
        patches.enter_context(
            mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
Example #26
    def setUp(self):
        # We mock out so much of this, is it even worthwhile testing?  Well, it
        # does give us coverage.
        self.loop = asyncio.get_event_loop()
        pfunc = partial(patch.object, self.loop)
        resources = ExitStack()
        self.addCleanup(resources.close)
        self.create_server = resources.enter_context(pfunc('create_server'))
        self.run_until_complete = resources.enter_context(
            pfunc('run_until_complete'))
        self.add_signal_handler = resources.enter_context(
            pfunc('add_signal_handler'))
        resources.enter_context(
            patch.object(logging.getLogger('mail.log'), 'info'))
        self.run_forever = resources.enter_context(pfunc('run_forever'))
Example #27
    def build_context(self, environ):
        """
        Start a request context.

        :param environ: A WSGI environment.
        :return: A context manager for the request. When the context
            manager exits, the request context variables are destroyed and
            all cleanup hooks are run.

        .. note:: This method is intended for internal use; Findig will
            call this method internally on its own. It is *not* re-entrant
            with a single request.

        """
        self.__run_startup_hooks()

        ctx.app = self
        ctx.url_adapter = adapter = self.url_map.bind_to_environ(environ)
        ctx.request = self.request_class(environ) # ALWAYS set this after adapter

        rule, url_values = adapter.match(return_rule=True)
        dispatcher = self  # self.get_dispatcher(rule)

        # Set up context variables
        ctx.url_values = url_values
        ctx.dispatcher = dispatcher
        ctx.resource = dispatcher.get_resource(rule)

        context = ExitStack()
        context.callback(self.__cleanup)
        # Add all the application's context managers to
        # the exit stack. If any of them return a value,
        # we'll add the value to the application context
        # with the function name.
        for hook in self.context_hooks:
            retval = context.enter_context(hook())
            if retval is not None:
                setattr(ctx, hook.__name__, retval)
        return context
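A sketch of a context hook as consumed by the loop above: because the hook's name becomes the attribute under which its yielded value is stored, this hypothetical hook would surface as ctx.db_session (connect_to_database is assumed):

from contextlib import contextmanager

@contextmanager
def db_session():
    session = connect_to_database()  # hypothetical helper
    try:
        yield session                # published as ctx.db_session
    finally:
        session.close()

app.context_hooks.append(db_session)  # register on the list iterated above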
Example #28
def dump(obj, count, json_file):
    usb_tin = obj['usb_tin']

    cleanup = ExitStack()

    if json_file:
        json_file = cleanup.enter_context(json_writer(json_file))

    with open_channel(usb_tin), cleanup:
        num_captured = 0
        while count is None or num_captured < count:

            msg = usb_tin.recv_can_message()

            if json_file:
                data = {'t': time.time(),
                        'msg': (msg.frame.ident,
                                hexlify(msg.frame.data).decode('ascii')), }
                json_file(data)

            click.echo(repr(msg))

            num_captured += 1
Example #29
    def borrow_with(self, stack: contextlib.ExitStack, task: Task) -> None:
        stack.enter_context(self.function.borrow(task))
        self.data.borrow_with(stack, task)
Example #30
class GameEngine(LoggingMixin):
    """
    The core component of :mod:`ppb`.

    To use the engine directly, treat it as a context manager: ::

       with GameEngine(BaseScene, **kwargs) as ge:
           ge.run()
    """
    def __init__(self, first_scene: Type, *,
                 basic_systems=(Renderer, Updater, EventPoller, SoundController, AssetLoadingSystem),
                 systems=(), scene_kwargs=None, **kwargs):
        """
        :param first_scene: A :class:`~ppb.BaseScene` type.
        :type first_scene: Type
        :param basic_systems: :class:`~systemslib.System` classes that are considered
           the "default". Includes: :class:`~systems.Renderer`,
           :class:`~systems.Updater`, :class:`~systems.EventPoller`,
           :class:`~systems.SoundController`, :class:`~systems.AssetLoadingSystem`.
        :type basic_systems: Iterable[systemslib.System]
        :param systems: Additional user defined systems.
        :type systems: Iterable[systemslib.System]
        :param scene_kwargs: Keyword arguments passed along to the first scene.
        :type scene_kwargs: Dict[str, Any]
        :param kwargs: Additional keyword arguments. Passed to the systems.

        .. warning::
           Passing in your own ``basic_systems`` can have unintended
           consequences. Consider passing via the ``systems`` parameter instead.
        """

        super(GameEngine, self).__init__()

        # Engine Configuration
        self.first_scene = first_scene
        self.scene_kwargs = scene_kwargs or {}
        self.kwargs = kwargs

        # Engine State
        self.scenes = []
        self.events = deque()
        self.event_extensions: DefaultDict[Union[Type, _ellipsis], List[Callable[[Any], None]]] = defaultdict(list)
        self.running = False
        self.entered = False
        self._last_idle_time = None

        # Systems
        self.systems_classes = list(chain(basic_systems, systems))
        self.systems = []
        self.exit_stack = ExitStack()

    @property
    def current_scene(self):
        """
        The top of the scene stack.

        :return: The currently running scene.
        :rtype: ppb.BaseScene
        """
        try:
            return self.scenes[-1]
        except IndexError:
            return None

    def __enter__(self):
        self.logger.info("Entering context")
        self.start_systems()
        self.entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.logger.info("Exiting context")
        self.entered = False
        self.exit_stack.close()

    def start_systems(self):
        """Initialize and enter the systems."""
        if self.systems:
            return
        for system in self.systems_classes:
            if isinstance(system, type):
                system = system(engine=self, **self.kwargs)
            self.systems.append(system)
            self.exit_stack.enter_context(system)

    def run(self):
        """
        Begin the main loop.

        If you have not entered the :class:`GameEngine`, this function will
        enter it for you before starting.

        Example: ::

           GameEngine(BaseScene, **kwargs).run()
        """
        if not self.entered:
            with self:
                self.start()
                self.main_loop()
        else:
            self.start()
            self.main_loop()

    def start(self):
        """
        Starts the engine.

        Called by :meth:`GameEngine.run` before :meth:`GameEngine.main_loop`.

        You shouldn't call this yourself unless you're embedding :mod:`ppb` in
        another event loop.
        """
        self.running = True
        self._last_idle_time = time.monotonic()
        self.activate({"scene_class": self.first_scene,
                       "kwargs": self.scene_kwargs})

    def main_loop(self):
        """
        Loop forever.

        If you're embedding :mod:`ppb` in an external event loop you should not
        use this method. Call :meth:`GameEngine.loop_once` instead.
        """
        while self.running:
            time.sleep(0)
            self.loop_once()

    def loop_once(self):
        """
        Iterate once.

        If you're embedding :mod:`ppb` in an external event loop call once per
        loop.
        """
        if not self.entered:
            raise ValueError("Cannot run before things have started",
                             self.entered)
        now = time.monotonic()
        self.signal(events.Idle(now - self._last_idle_time))
        self._last_idle_time = now
        while self.events:
            self.publish()

    def activate(self, next_scene: dict):
        """
        Instantiates and sets up a new scene.

        :param next_scene: A dictionary with the keys:

           * "scene_class": A :class:`~ppb.BaseScene` type.
           * "args": A :class:`list` of positional arguments.
           * "kwargs": A :class:`dict` of keyword arguments.
        """
        scene = next_scene["scene_class"]
        if scene is None:
            return
        args = next_scene.get("args", [])
        kwargs = next_scene.get("kwargs", {})
        self._start_scene(scene(*args, **kwargs), None)

    def signal(self, event):
        """
        Add an event to the event queue.

        Thread-safe.

        You will rarely call this directly from a :class:`GameEngine` instance.
        The current :class:`GameEngine` instance will pass its signal method
        as part of publishing an event.
        """
        self.events.append(event)

    def publish(self):
        """
        Publish the next event to every object in the tree.
        """
        event = self.events.popleft()
        scene = self.current_scene
        event.scene = scene
        extensions = chain(self.event_extensions[type(event)], self.event_extensions[...])

        # Hydrating extensions.
        for callback in extensions:
            callback(event)

        event_handler_name = _get_handler_name(type(event).__name__)
        for obj in self.walk():
            method = getattr(obj, event_handler_name, None)
            if callable(method):
                try:
                    method(event, self.signal)
                except TypeError as ex:
                    from inspect import signature
                    sig = signature(method)
                    try:
                        sig.bind(event, self.signal)
                    except TypeError:
                        raise BadEventHandlerException(obj, event_handler_name, event) from ex
                    else:
                        raise

    def on_start_scene(self, event: events.StartScene, signal: Callable[[Any], None]):
        """
        Start a new scene. The current scene pauses.

        Do not call this method directly. It is called by the GameEngine when a
        :class:`~events.StartScene` event is fired.
        """
        self._pause_scene()
        self._start_scene(event.new_scene, event.kwargs)

    def on_stop_scene(self, event: events.StopScene, signal: Callable[[Any], None]):
        """
        Stop a running scene. If there's a scene on the stack, it resumes.

        Do not call this method directly. It is called by the GameEngine when a
        :class:`~events.StopScene` event is fired.
        """
        self._stop_scene()
        if self.current_scene is not None:
            signal(events.SceneContinued())
        else:
            signal(events.Quit())

    def on_replace_scene(self, event: events.ReplaceScene, signal):
        """
        Replace the running scene with a new one.

        Do not call this method directly. It is called by the GameEngine when a
        :class:`~events.ReplaceScene` event is fired.
        """
        self._stop_scene()
        self._start_scene(event.new_scene, event.kwargs)

    def on_quit(self, quit_event: events.Quit, signal: Callable[[Any], None]):
        """
        Shut down the event loop.

        Do not call this method directly. It is called by the GameEngine when a
        :class:`~events.Quit` event is fired.
        """
        self.running = False

    def _pause_scene(self):
        """Pause the current scene."""
        # Empty the queue before changing scenes.
        self._flush_events()
        self.signal(events.ScenePaused())
        self.publish()

    def _stop_scene(self):
        """Stop the current scene."""
        # Empty the queue before changing scenes.
        self._flush_events()
        self.signal(events.SceneStopped())
        self.publish()
        self.scenes.pop()

    def _start_scene(self, scene, kwargs):
        """Start a scene."""
        if isinstance(scene, type):
            scene = scene(**(kwargs or {}))
        self.scenes.append(scene)
        self.signal(events.SceneStarted())

    def register(self, event_type: Union[Type, _ellipsis], callback: Callable[[], Any]):
        """
        Register a callback to be applied to an event at time of publishing.

        Primarily to be used by subsystems.

        The callback will receive the event. Your code should modify the event
        in place. It does not need to return it.

        :param event_type: The class of an event.
        :param callback: A callable, must accept an event, and return no value.
        :return: None
        """
        if not isinstance(event_type, type) and event_type is not ...:
            raise TypeError(f"{type(self)}.register requires event_type to be a type.")
        if not callable(callback):
            raise TypeError(f"{type(self)}.register requires callback to be callable.")
        self.event_extensions[event_type].append(callback)

    def _flush_events(self):
        """
        Flush the event queue.

        Call before doing anything that will cause signals to be delivered to
        the wrong scene.
        """
        self.events = deque()

    def walk(self):
        """
        Walk the object tree.

        Publication order: The :class:`GameEngine`, the
        :class:`~ppb.systemslib.System` list, the current
        :class:`~ppb.BaseScene`, then finally the :class:`~ppb.Sprite` objects
        in the current scene.
        """
        yield self
        yield from self.systems
        yield self.current_scene
        if self.current_scene is not None:
            yield from self.current_scene
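Since start_systems() enters every system on exit_stack, a system only has to be a context manager that accepts engine and the extra kwargs. A minimal hypothetical system, passed through the systems parameter:

class TimingSystem:
    def __init__(self, engine, **kwargs):
        self.engine = engine

    def __enter__(self):
        print('timing system up')
        return self

    def __exit__(self, *exc):
        print('timing system down')

    def on_idle(self, event, signal):
        pass  # react to engine events here

# GameEngine(BaseScene, systems=(TimingSystem,)).run()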
Example #31
class Configuration:
    def __init__(self, directory=None):
        self._set_defaults()
        # Because the configuration object is a global singleton, it makes for
        # a convenient place to stash information used by widely separate
        # components.  For example, this is a placeholder for rendezvous
        # between the downloader and the D-Bus service.  When running under
        # D-Bus and we get a `paused` signal from the download manager, we need
        # this to plumb through an UpdatePaused signal to our clients.  It
        # rather sucks that we need a global for this, but I can't get the
        # plumbing to work otherwise.  This seems like the least horrible place
        # to stash this global.
        self.dbus_service = None
        # These are used to plumb command line arguments from the main() to
        # other parts of the system.
        self.skip_gpg_verification = False
        self.override_gsm = False
        # Cache.
        self._device = None
        self._build_number = None
        self.build_number_override = False
        self._channel = None
        # This is used only to override the phased percentage via command line
        # and the property setter.
        self._phase_override = None
        self._tempdir = None
        self.config_d = None
        self.ini_files = []
        self.http_base = None
        self.https_base = None
        if directory is not None:
            self.load(directory)
        self._calculate_http_bases()
        self._resources = ExitStack()
        self._stats = DeviceStats()
        atexit.register(self._resources.close)

    def _set_defaults(self):
        self.service = Bag(
            base='system-image.ubports.com',
            http_port=80,
            https_port=443,
            channel='daily',
            build_number=0,
        )
        self.system = Bag(
            timeout=as_timedelta('1h'),
            tempdir='/tmp',
            logfile='/var/log/system-image/client.log',
            loglevel=as_loglevel('info'),
            settings_db='/var/lib/system-image/settings.db',
        )
        self.gpg = Bag(
            archive_master='/usr/share/system-image/archive-master.tar.xz',
            image_master='/var/lib/system-image/keyrings/image-master.tar.xz',
            image_signing='/var/lib/system-image/keyrings/image-signing.tar.xz',
            device_signing='/var/lib/system-image/keyrings/device-signing.tar.xz',
        )
        self.updater = Bag(
            cache_partition='/android/cache/recovery',
            data_partition='/var/lib/system-image',
        )
        self.hooks = Bag(
            device=as_object('systemimage.device.SystemProperty'),
            scorer=as_object('systemimage.scores.WeightedScorer'),
            apply=as_object('systemimage.apply.Reboot'),
        )
        self.dbus = Bag(lifetime=as_timedelta('10m'))

    def _load_file(self, path):
        parser = SafeConfigParser()
        str_path = str(path)
        parser.read(str_path)
        self.ini_files.append(path)
        self.service.update(converters=dict(
            http_port=as_port,
            https_port=as_port,
            build_number=int,
            device=as_stripped,
        ),
                            **parser['service'])
        self.system.update(converters=dict(timeout=as_timedelta,
                                           loglevel=as_loglevel,
                                           settings_db=expand_path,
                                           tempdir=expand_path),
                           **parser['system'])
        self.gpg.update(**parser['gpg'])
        self.updater.update(**parser['updater'])
        self.hooks.update(converters=dict(device=as_object,
                                          scorer=as_object,
                                          apply=as_object),
                          **parser['hooks'])
        self.dbus.update(converters=dict(lifetime=as_timedelta),
                         **parser['dbus'])

    def load(self, directory):
        """Load up the configuration from a config.d directory."""
        # Look for all the files in the given directory with .ini or .cfg
        # suffixes.  The files must start with a number, and the files are
        # loaded in numeric order.
        if self.config_d is not None:
            raise RuntimeError('Configuration already loaded; use .reload()')
        self.config_d = directory
        if not Path(directory).is_dir():
            raise TypeError(
                '.load() requires a directory: {}'.format(directory))
        candidates = []
        for child in Path(directory).glob('*.ini'):
            order, _, base = child.stem.partition('_')
            # XXX 2014-10-03: The logging system isn't initialized when we get
            # here, so we can't log that these files are being ignored.
            if len(_) == 0:
                continue
            try:
                serial = int(order)
            except ValueError:
                continue
            candidates.append((serial, child))
        for serial, path in sorted(candidates):
            self._load_file(path)
        self._calculate_http_bases()

    def reload(self):
        """Reload the configuration directory."""
        # Reset some cached attributes.
        directory = self.config_d
        self.ini_files = []
        self.config_d = None
        self._build_number = None
        # Now load the defaults, then reload the previous config.d directory.
        self._set_defaults()
        self.load(directory)

    def _calculate_http_bases(self):
        if (self.service.http_port is NO_PORT
                and self.service.https_port is NO_PORT):
            raise ValueError('Cannot disable both http and https ports')
        # Construct the HTTP and HTTPS base urls, which most applications will
        # actually use.  We do this in two steps, in order to support disabling
        # one or the other (but not both) protocols.
        if self.service.http_port == 80:
            http_base = 'http://{}'.format(self.service.base)
        elif self.service.http_port is NO_PORT:
            http_base = None
        else:
            http_base = 'http://{}:{}'.format(self.service.base,
                                              self.service.http_port)
        # HTTPS.
        if self.service.https_port == 443:
            https_base = 'https://{}'.format(self.service.base)
        elif self.service.https_port is NO_PORT:
            https_base = None
        else:
            https_base = 'https://{}:{}'.format(self.service.base,
                                                self.service.https_port)
        # Sanity check and final settings.
        if http_base is None:
            assert https_base is not None
            http_base = https_base
        if https_base is None:
            assert http_base is not None
            https_base = http_base
        self.http_base = http_base
        self.https_base = https_base

    @property
    def build_number(self):
        if self._build_number is None:
            self._build_number = self.service.build_number
        return self._build_number

    @build_number.setter
    def build_number(self, value):
        if not isinstance(value, int):
            raise ValueError('integer is required, got: {}'.format(
                type(value).__name__))
        self._build_number = value
        self.build_number_override = True

    @build_number.deleter
    def build_number(self):
        self._build_number = None

    @property
    def device(self):
        if self._device is None:
            # Start by looking for a [service]device setting.  Use this if it
            # exists, otherwise fall back to calling the hook.
            self._device = getattr(self.service, 'device', None)
            if not self._device:
                self._device = self.hooks.device().get_device()
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    @property
    def channel(self):
        if self._channel is None:
            self._channel = self.service.channel
        return self._channel

    @channel.setter
    def channel(self, value):
        self._channel = value

    @property
    def phase_override(self):
        return self._phase_override

    @phase_override.setter
    def phase_override(self, value):
        self._phase_override = max(0, min(100, int(value)))

    @phase_override.deleter
    def phase_override(self):
        self._phase_override = None

    @property
    def tempdir(self):
        if self._tempdir is None:
            makedirs(self.system.tempdir)
            self._tempdir = self._resources.enter_context(
                temporary_directory(prefix='system-image-',
                                    dir=self.system.tempdir))
        return self._tempdir

    @property
    def session(self):
        return self._stats.getSessionId()

    @property
    def instance(self):
        return self._stats.getInstanceId()

    @property
    def user_agent(self):
        return USER_AGENT.format(self)
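The _resources/atexit pairing in __init__ is a reusable pattern for lazily acquired, process-lifetime resources; a self-contained sketch of the same idiom:

import atexit
import tempfile
from contextlib import ExitStack

_resources = ExitStack()
atexit.register(_resources.close)  # unwound at interpreter shutdown

def get_tempdir(_cache=[]):
    # mutable default acts as a memo: the directory is created once,
    # then reused for the life of the process
    if not _cache:
        _cache.append(_resources.enter_context(tempfile.TemporaryDirectory()))
    return _cache[0]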
Example #32
class Workspace(IWorkspace):
    """An IWorkspace that maintains a fixed list of origins, loading repositorylocations
    for all of them on initialization."""
    def __init__(self, workspace_load_target, grpc_server_registry=None):
        self._stack = ExitStack()

        from .cli_target import WorkspaceLoadTarget

        self._workspace_load_target = check.opt_inst_param(
            workspace_load_target, "workspace_load_target",
            WorkspaceLoadTarget)

        if grpc_server_registry:
            self._grpc_server_registry = check.inst_param(
                grpc_server_registry, "grpc_server_registry",
                GrpcServerRegistry)
        else:
            self._grpc_server_registry = self._stack.enter_context(
                ProcessGrpcServerRegistry(reload_interval=0, heartbeat_ttl=30))

        self._load_workspace()

    def _load_workspace(self):
        repository_location_origins = (
            self._workspace_load_target.create_origins()
            if self._workspace_load_target else [])

        self._location_origin_dict = OrderedDict()
        check.list_param(
            repository_location_origins,
            "repository_location_origins",
            of_type=RepositoryLocationOrigin,
        )

        self._location_dict = {}
        self._location_error_dict = {}
        for origin in repository_location_origins:
            check.invariant(
                self._location_origin_dict.get(origin.location_name) is None,
                'Cannot have multiple locations with the same name, got multiple "{name}"'
                .format(name=origin.location_name, ),
            )

            self._location_origin_dict[origin.location_name] = origin
            self._load_location(origin.location_name)

    # Can be overridden in subclasses that need different logic for loading
    # repository locations from origins.
    def create_location_from_origin(self, origin):
        if not self._grpc_server_registry.supports_origin(origin):
            return origin.create_location()
        else:
            endpoint = (self._grpc_server_registry.reload_grpc_endpoint(origin)
                        if self._grpc_server_registry.supports_reload else
                        self._grpc_server_registry.get_grpc_endpoint(origin))

            return GrpcServerRepositoryLocation(
                origin=origin,
                server_id=endpoint.server_id,
                port=endpoint.port,
                socket=endpoint.socket,
                host=endpoint.host,
                heartbeat=True,
                watch_server=False,
                grpc_server_registry=self._grpc_server_registry,
            )

    def _load_location(self, location_name):
        if self._location_dict.get(location_name):
            del self._location_dict[location_name]

        if self._location_error_dict.get(location_name):
            del self._location_error_dict[location_name]

        origin = self._location_origin_dict[location_name]
        try:
            location = self.create_location_from_origin(origin)
            self._location_dict[location_name] = location
        except Exception:  # pylint: disable=broad-except
            error_info = serializable_error_info_from_exc_info(sys.exc_info())
            self._location_error_dict[location_name] = error_info
            warnings.warn(
                "Error loading repository location {location_name}:{error_string}"
                .format(location_name=location_name,
                        error_string=error_info.to_string()))

    def create_snapshot(self):
        return WorkspaceSnapshot(self._location_origin_dict.copy(),
                                 self._location_error_dict.copy())

    @property
    def repository_locations(self):
        return list(self._location_dict.values())

    def has_repository_location(self, location_name):
        check.str_param(location_name, "location_name")
        return location_name in self._location_dict

    def get_repository_location(self, location_name):
        check.str_param(location_name, "location_name")
        return self._location_dict[location_name]

    def has_repository_location_error(self, location_name):
        check.str_param(location_name, "location_name")
        return location_name in self._location_error_dict

    def get_repository_location_error(self, location_name):
        check.str_param(location_name, "location_name")
        return self._location_error_dict[location_name]

    def reload_repository_location(self, location_name):
        self._load_location(location_name)

    def reload_workspace(self):
        for location in self.repository_locations:
            location.cleanup()
        self._load_workspace()

    def get_location(self, origin):
        location_name = origin.location_name
        if self.has_repository_location(location_name):
            return self.get_repository_location(location_name)
        elif self.has_repository_location_error(location_name):
            error_info = self.get_repository_location_error(location_name)
            raise DagsterRepositoryLocationLoadError(
                f"Failure loading {location_name}: {error_info.to_string()}",
                load_error_infos=[error_info],
            )
        else:
            raise DagsterInvariantViolationError(
                f"Location {location_name} does not exist in workspace")

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        for location in self.repository_locations:
            location.cleanup()
        self._stack.close()
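A usage sketch, assuming a WorkspaceLoadTarget built elsewhere in the codebase; exiting the with block cleans up every loaded location and closes the gRPC server registry held on the stack:

with Workspace(workspace_load_target) as workspace:
    for location in workspace.repository_locations:
        print(location.name)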
Example #33
class BasePod:
    """A BasePod is a immutable set of peas, which run in parallel. They share the same input and output socket.
    Internally, the peas can run with the process/thread backend. They can be also run in their own containers
    """
    def __init__(self, args: Union['argparse.Namespace', Dict]):
        """

        :param args: arguments parsed from the CLI
        """
        self.peas = []
        self.is_head_router = False
        self.is_tail_router = False
        self.deducted_head = None
        self.deducted_tail = None
        if hasattr(args, 'polling') and args.polling.is_push:
            # ONLY reset when it is push
            args.reducing_yaml_path = '_forward'
        self._args = args
        self.peas_args = self._parse_args(args)

    @property
    def is_idle(self) -> bool:
        """A Pod is idle when all its peas are idle, see also :attr:`jina.peapods.pea.Pea.is_idle`.
        """
        return all(p.is_idle for p in self.peas if p.is_ready.is_set())

    def close_if_idle(self):
        """Check every second if the pod is in idle, if yes, then close the pod"""
        while True:
            if self.is_idle:
                self.close()
                break  # only run once
            time.sleep(1)

    @property
    def name(self) -> str:
        """The name of this :class:`BasePod`. """
        return self.peas_args['peas'][0].name

    @property
    def port_grpc(self) -> int:
        """Get the grpc port number """
        return self.peas_args['peas'][0].port_grpc

    @property
    def host(self) -> str:
        """Get the grpc host name """
        return self.peas_args['peas'][0].host

    def _parse_args(self, args):
        peas_args = {'head': None, 'tail': None, 'peas': []}

        if getattr(args, 'replicas', 1) > 1:
            # reasons to separate head and tail from peas is that they
            # can be deducted based on the previous and next pods
            peas_args['head'] = _copy_to_head_args(args, args.polling.is_push)
            peas_args['tail'] = _copy_to_tail_args(args)
            peas_args['peas'] = _set_peas_args(args, peas_args['head'],
                                               peas_args['tail'])
            self.is_head_router = True
            self.is_tail_router = True
        else:
            peas_args['peas'] = [args]

        # note that peas_args['peas'][0] exist either way and carries the original property
        return peas_args

    @property
    def head_args(self):
        """Get the arguments for the `head` of this BasePod. """
        if self.is_head_router and self.peas_args['head']:
            return self.peas_args['head']
        elif not self.is_head_router and len(self.peas_args['peas']) == 1:
            return self.peas_args['peas'][0]
        elif self.deducted_head:
            return self.deducted_head
        else:
            raise ValueError(
                'ambiguous head node, maybe it is deducted already?')

    @head_args.setter
    def head_args(self, args):
        """Set the arguments for the `head` of this BasePod. """
        if self.is_head_router and self.peas_args['head']:
            self.peas_args['head'] = args
        elif not self.is_head_router and len(self.peas_args['peas']) == 1:
            self.peas_args['peas'][0] = args
        elif self.deducted_head:
            self.deducted_head = args
        else:
            raise ValueError(
                'ambiguous head node, maybe it is deducted already?')

    @property
    def tail_args(self):
        """Get the arguments for the `tail` of this BasePod. """
        if self.is_tail_router and self.peas_args['tail']:
            return self.peas_args['tail']
        elif not self.is_tail_router and len(self.peas_args['peas']) == 1:
            return self.peas_args['peas'][0]
        elif self.deducted_tail:
            return self.deducted_tail
        else:
            raise ValueError(
                'ambiguous tail node, maybe it is deducted already?')

    @tail_args.setter
    def tail_args(self, args):
        """Get the arguments for the `tail` of this BasePod. """
        if self.is_tail_router and self.peas_args['tail']:
            self.peas_args['tail'] = args
        elif not self.is_tail_router and len(self.peas_args['peas']) == 1:
            self.peas_args['peas'][0] = args
        elif self.deducted_tail:
            self.deducted_tail = args
        else:
            raise ValueError(
                'ambiguous tail node, maybe it is deducted already?')

    @property
    def all_args(self):
        """Get all arguments of all Peas in this BasePod. """
        return self.peas_args['peas'] + (
            [self.peas_args['head']] if self.peas_args['head'] else
            []) + ([self.peas_args['tail']] if self.peas_args['tail'] else [])

    @property
    def num_peas(self) -> int:
        """Get the number of running :class:`BasePea`"""
        return len(self.peas)

    def __eq__(self, other: 'BasePod'):
        return self.num_peas == other.num_peas and self.name == other.name

    def set_runtime(self, runtime: str):
        """Set the parallel runtime of this BasePod.

        :param runtime: possible values: process, thread
        """
        for s in self.all_args:
            s.runtime = runtime
            # for the thread and process backends, which run locally, host_in and host_out should not be set
            # s.host_in = __default_host__
            # s.host_out = __default_host__

    def start_sentinels(self):
        self.sentinel_threads = []
        if isinstance(self._args, argparse.Namespace) and getattr(
                self._args, 'shutdown_idle', False):
            self.sentinel_threads.append(
                Thread(target=self.close_if_idle,
                       name='sentinel-shutdown-idle',
                       daemon=True))
        for t in self.sentinel_threads:
            t.start()

    def start(self):
        """Start to run all Peas in this BasePod.

        Remember to close the BasePod with :meth:`close`.

        Note that this method has a timeout of ``timeout_ready`` set in CLI,
        which is inherited from :class:`jina.peapods.peas.BasePea`
        """
        self.stack = ExitStack()
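        # every pea entered below is registered on this ExitStack, so
        # close() unwinds them all in reverse order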
        # start head and tail
        if self.peas_args['head']:
            p = BasePea(self.peas_args['head'])
            self.peas.append(p)
            self.stack.enter_context(p)

        if self.peas_args['tail']:
            p = BasePea(self.peas_args['tail'])
            self.peas.append(p)
            self.stack.enter_context(p)

        # start the actual peas and assign their replica ids and roles
        if len(self.peas_args['peas']) > 1:
            start_rep_id = 1
            role = PeaRoleType.REPLICA
        else:
            start_rep_id = 0
            role = PeaRoleType.SINGLETON
        for idx, _args in enumerate(self.peas_args['peas'],
                                    start=start_rep_id):
            _args.replica_id = idx
            _args.role = role
            p = Pea(_args, allow_remote=False)
            self.peas.append(p)
            self.stack.enter_context(p)

        self.start_sentinels()
        return self

    @property
    def log_iterator(self):
        """Get the last log using iterator

        The :class:`BasePod` log iterator goes through all peas :attr:`log_iterator` and
        poll them sequentially. If non all them is active anymore, aka :attr:`is_event_loop`
        is False, then the iterator ends.

        .. warning::

            The log may not strictly follow the time order given that we are polling the log
            from all peas in the sequential manner.
        """
        from ..logging.queue import __log_queue__
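        # poll the shared log queue until every pea has shut down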
        while not self.is_shutdown:
            try:
                yield __log_queue__.get_nowait()
            except Empty:
                pass

    @property
    def is_shutdown(self) -> bool:
        return all(not p.is_ready.is_set() for p in self.peas)

    def __enter__(self):
        return self.start()

    @property
    def status(self) -> List:
        """The status of a BasePod is the list of status of all its Peas """
        return [p.status for p in self.peas]

    def is_ready(self) -> bool:
        """Wait till the ready signal of this BasePod.

        The pod is ready only when all the contained Peas returns is_ready
        """
        for p in self.peas:
            p.is_ready.wait()
        return True

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def join(self):
        """Wait until all peas exit"""
        try:
            for s in self.peas:
                s.join()
        except KeyboardInterrupt:
            pass
        finally:
            self.peas.clear()

    def close(self):
        self.stack.close()
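
A minimal usage sketch of the class above, assuming `parsed_args` stands in for
whatever argparse.Namespace the surrounding CLI parser produces (an assumption,
not the project's real entry point):

def run_pod(parsed_args):
    # __enter__ calls start(), which enters every pea on one ExitStack
    with BasePod(parsed_args) as pod:
        pod.is_ready()                 # block until every pea signals ready
        for line in pod.log_iterator:  # poll pea logs until shutdown
            print(line)
    # __exit__ calls close(), tearing the peas down in reverse order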
Exemple #34
class TestMain(TestCase):
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        # Capture builtin print() output.
        self._stdout = StringIO()
        self._stderr = StringIO()
        self._resources.enter_context(
            patch('argparse._sys.stdout', self._stdout))
        # Capture stderr since this is where argparse will spew to.
        self._resources.enter_context(
            patch('argparse._sys.stderr', self._stderr))

    def test_help(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--help',))
        self.assertEqual(cm.exception.code, 0)
        lines = self._stdout.getvalue().splitlines()
        self.assertTrue(lines[0].startswith('Usage'),
                        lines[0])
        self.assertTrue(lines[1].startswith('  ubuntu-image'),
                        lines[1])

    def test_debug(self):
        with ExitStack() as resources:
            mock = resources.enter_context(
                patch('ubuntu_image.__main__.logging.basicConfig'))
            resources.enter_context(patch(
                'ubuntu_image.__main__.ModelAssertionBuilder',
                EarlyExitModelAssertionBuilder))
            # Prevent actual main() from running.
            resources.enter_context(patch('ubuntu_image.__main__.main'))
            code = main(('--debug', 'model.assertion'))
        self.assertEqual(code, 0)
        mock.assert_called_once_with(level=logging.DEBUG)

    def test_no_debug(self):
        with ExitStack() as resources:
            mock = resources.enter_context(
                patch('ubuntu_image.__main__.logging.basicConfig'))
            resources.enter_context(patch(
                'ubuntu_image.__main__.ModelAssertionBuilder',
                EarlyExitModelAssertionBuilder))
            # Prevent actual main() from running.
            resources.enter_context(patch('ubuntu_image.__main__.main'))
            code = main(('model.assertion',))
        self.assertEqual(code, 0)
        mock.assert_not_called()

    def test_state_machine_exception(self):
        with ExitStack() as resources:
            resources.enter_context(patch(
                'ubuntu_image.__main__.ModelAssertionBuilder',
                CrashingModelAssertionBuilder))
            mock = resources.enter_context(patch(
                'ubuntu_image.__main__._logger.exception'))
            code = main(('model.assertion',))
            self.assertEqual(code, 1)
            self.assertEqual(
                mock.call_args_list[-1], call('Crash in state machine'))

    def test_state_machine_snap_command_fails(self):
        # The `snap prepare-image` command fails and main exits with non-zero.
        #
        # This test needs to run the actual snap() helper function, not
        # the testsuite-wide mock.  This is appropriate since we're
        # mocking it ourselves here.
        if NosePlugin.snap_mocker is not None:
            NosePlugin.snap_mocker.patcher.stop()
            self._resources.callback(NosePlugin.snap_mocker.patcher.start)
        self._resources.enter_context(patch(
            'ubuntu_image.helpers.subprocess_run',
            return_value=SimpleNamespace(
                returncode=1,
                stdout='command stdout',
                stderr='command stderr',
                check_returncode=check_returncode,
                )))
        self._resources.enter_context(LogCapture())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            XXXModelAssertionBuilder))
        workdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(workdir, 'my-disk.img')
        code = main(('--until', 'prepare_filesystems',
                     '--channel', 'edge',
                     '--workdir', workdir,
                     '--output', imgfile,
                     'model.assertion'))
        self.assertEqual(code, 1)

    def test_no_arguments(self):
        with self.assertRaises(SystemExit) as cm:
            main(())
        self.assertEqual(cm.exception.code, 2)
        lines = self._stderr.getvalue().splitlines()
        self.assertTrue(
                lines[0].startswith('Warning: for backwards compatibility'),
                lines[0])
        self.assertEqual(lines[1], 'Usage:')
        self.assertEqual(
                lines[2],
                '  ubuntu-image COMMAND [OPTIONS]...')

    def test_with_none(self):
        with self.assertRaises(SystemExit) as cm:
            main((None,))    # code coverage __main__.py 308-309
        self.assertEqual(cm.exception.code, 2)

    def test_snap_subcommand_help(self):
        with self.assertRaises(SystemExit) as cm:
            main(('snap', '--help',))
        self.assertEqual(cm.exception.code, 0)
        lines = self._stdout.getvalue().splitlines()
        self.assertTrue(
              lines[0].startswith('usage: ubuntu-image snap'),
              lines[0])

    def test_classic_subcommand_help(self):
        with self.assertRaises(SystemExit) as cm:
            main(('classic', '--help',))
        self.assertEqual(cm.exception.code, 0)
        lines = self._stdout.getvalue().splitlines()
        self.assertTrue(
              lines[0].startswith('usage: ubuntu-image classic'),
              lines[0])
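
The setUp pattern above generalizes: enter all patches on one ExitStack, register
its close with addCleanup, and every test in the class sees the patched streams
without per-test boilerplate. A self-contained sketch of the same pattern
(patching `sys.stdout` directly here, which is an illustrative choice):

from contextlib import ExitStack
from io import StringIO
from unittest import TestCase
from unittest.mock import patch


class CapturedOutputTest(TestCase):
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)  # unwinds every patch
        self._stdout = self._resources.enter_context(
            patch('sys.stdout', StringIO()))

    def test_print_is_captured(self):
        print('hello')
        self.assertEqual(self._stdout.getvalue(), 'hello\n')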
Exemple #35
class TestOWCSVFileImport(WidgetTest):
    def setUp(self):
        self._stack = ExitStack().__enter__()
        # patch `_local_settings` to avoid side effects, across tests
        fname = self._stack.enter_context(named_file(""))
        s = QSettings(fname, QSettings.IniFormat)
        self._stack.enter_context(
            mock.patch.object(owcsvimport.OWCSVFileImport, "_local_settings",
                              lambda *a: s))
        self.widget = self.create_widget(owcsvimport.OWCSVFileImport)

    def tearDown(self):
        self.widgets.remove(self.widget)
        self.widget.onDeleteWidget()
        self.widget = None
        self._stack.close()

    def test_basic(self):
        w = self.widget
        w.activate_recent(0)
        w.cancel()

    data_regions_options = owcsvimport.Options(
        encoding="ascii",
        dialect=csv.excel_tab(),
        columntypes=[
            (range(0, 1), ColumnType.Categorical),
            (range(1, 2), ColumnType.Text),
            (range(2, 3), ColumnType.Categorical),
        ],
        rowspec=[
            (range(0, 1), RowSpec.Header),
            (range(1, 3), RowSpec.Skipped),
        ],
    )

    def _check_data_regions(self, table):
        self.assertEqual(len(table), 3)
        self.assertTrue(table.domain["id"].is_discrete)
        self.assertTrue(table.domain["continent"].is_discrete)
        self.assertTrue(table.domain["state"].is_string)
        assert_array_equal(table.X, [[0, 1], [1, 1], [2, 0]])
        assert_array_equal(table.metas,
                           np.array([["UK"], ["Russia"], ["Mexico"]], object))

    def test_restore(self):
        dirname = os.path.dirname(__file__)
        path = os.path.join(dirname, "data-regions.tab")

        w = self.create_widget(owcsvimport.OWCSVFileImport,
                               stored_settings={
                                   "_session_items":
                                   [(path, self.data_regions_options.as_dict())
                                    ]
                               })
        item = w.current_item()
        self.assertEqual(item.path(), path)
        self.assertEqual(item.options(), self.data_regions_options)
        out = self.get_output("Data", w)
        self._check_data_regions(out)

    def test_restore_from_local(self):
        dirname = os.path.dirname(__file__)
        path = os.path.join(dirname, "data-regions.tab")
        s = owcsvimport.OWCSVFileImport._local_settings()
        s.clear()
        QSettings_writeArray(
            s, "recent",
            [{
                "path": path,
                "options": json.dumps(self.data_regions_options.as_dict())
            }])
        w = self.create_widget(owcsvimport.OWCSVFileImport)
        item = w.current_item()
        self.assertEqual(item.path(), path)
        self.assertEqual(item.options(), self.data_regions_options)
        self.assertEqual(
            w._session_items,
            [(path, self.data_regions_options.as_dict())],
            "local settings item must be recorded in _session_items when "
            "activated in __init__",
        )
        self._check_data_regions(self.get_output("Data", w))

    def test_summary(self):
        """Check if status bar is updated when data is received"""
        dirname = os.path.dirname(__file__)
        path = os.path.join(dirname, "data-regions.tab")
        widget = self.create_widget(owcsvimport.OWCSVFileImport,
                                    stored_settings={
                                        "_session_items":
                                        [(path,
                                          self.data_regions_options.as_dict())]
                                    })
        output_sum = widget.info.set_output_summary = mock.Mock()
        widget.commit()
        self.wait_until_finished(widget)
        output = self.get_output("Data", widget)
        output_sum.assert_called_with(len(output),
                                      format_summary_details(output))
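
Note the manual `ExitStack().__enter__()` in setUp paired with `close()` in
tearDown: it is the same lifecycle as a `with` block, split across the two
unittest hooks. A library-free sketch of that split lifecycle, using a
throwaway temp file as the managed resource:

import os
import tempfile
from contextlib import ExitStack
from unittest import TestCase


class SplitLifecycleTest(TestCase):
    def setUp(self):
        self._stack = ExitStack().__enter__()       # "enter" half of a with-block
        fd, self.path = tempfile.mkstemp()
        os.close(fd)
        self._stack.callback(os.unlink, self.path)  # runs when the stack closes

    def tearDown(self):
        self._stack.close()                         # "exit" half: run callbacks

    def test_file_exists_during_test(self):
        self.assertTrue(os.path.exists(self.path))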
Exemple #36
def test_simultaneous_ramp_mode_resets_individual_axis_ramp_rates_if_blocking_ramp(
        current_driver, caplog, request):
    ami3d = current_driver

    ami3d.cartesian((0.0, 0.0, 0.0))

    restore_parameters_stack = ExitStack()
    request.addfinalizer(restore_parameters_stack.close)

    restore_parameters_stack.callback(ami3d.cartesian, (0.0, 0.0, 0.0))

    restore_parameters_stack.enter_context(
        ami3d._instrument_x.ramp_rate.restore_at_exit())
    restore_parameters_stack.enter_context(
        ami3d._instrument_y.ramp_rate.restore_at_exit())
    restore_parameters_stack.enter_context(
        ami3d._instrument_z.ramp_rate.restore_at_exit())

    restore_parameters_stack.enter_context(
        ami3d.ramp_mode.set_to("simultaneous"))

    restore_parameters_stack.enter_context(
        ami3d.block_during_ramp.set_to(True))

    with caplog.at_level(logging.DEBUG, logger="qcodes.instrument.base"):

        # Set individual ramp rates to known values
        ami3d._instrument_x.ramp_rate(0.09)
        ami3d._instrument_y.ramp_rate(0.10)
        ami3d._instrument_z.ramp_rate(0.11)

        ami3d.vector_ramp_rate(0.05)

        # Initiate the simultaneous ramp
        ami3d.cartesian((0.5, 0.5, 0.5))

        # Assert the individual axes ramp rates were reverted
        # to the known values set earlier
        assert ami3d._instrument_x.ramp_rate() == 0.09
        assert ami3d._instrument_y.ramp_rate() == 0.10
        assert ami3d._instrument_z.ramp_rate() == 0.11

    messages = [record.message for record in caplog.records]

    expected_log_fragment = "Restoring individual axes ramp rates"
    messages_with_expected_fragment = tuple(
        message for message in messages if expected_log_fragment in message)
    assert (len(messages_with_expected_fragment) == 1
            ), f"found: {messages_with_expected_fragment}"

    expected_log_fragment_2 = "Simultaneous ramp: blocking until ramp is finished"
    messages_with_expected_fragment_2 = tuple(
        message for message in messages if expected_log_fragment_2 in message)
    assert (len(messages_with_expected_fragment_2) == 1
            ), f"found: {messages_with_expected_fragment_2}"

    unexpected_log_fragment = "Simultaneous ramp: not blocking until ramp is finished"
    messages_with_unexpected_fragment = tuple(
        message for message in messages if unexpected_log_fragment in message)
    assert (len(messages_with_unexpected_fragment) == 0
            ), f"found: {messages_with_unexpected_fragment}"
Exemple #37
def test_simultaneous_ramp_mode_does_not_reset_individual_axis_ramp_rates_if_nonblocking_ramp(
        current_driver, caplog, request):
    ami3d = current_driver

    ami3d.cartesian((0.0, 0.0, 0.0))

    restore_parameters_stack = ExitStack()
    request.addfinalizer(restore_parameters_stack.close)

    restore_parameters_stack.callback(ami3d.cartesian, (0.0, 0.0, 0.0))

    restore_parameters_stack.enter_context(
        ami3d._instrument_x.ramp_rate.restore_at_exit())
    restore_parameters_stack.enter_context(
        ami3d._instrument_y.ramp_rate.restore_at_exit())
    restore_parameters_stack.enter_context(
        ami3d._instrument_z.ramp_rate.restore_at_exit())

    restore_parameters_stack.enter_context(
        ami3d.ramp_mode.set_to("simultaneous"))

    restore_parameters_stack.enter_context(
        ami3d.block_during_ramp.set_to(False))

    # Set individual ramp rates to known values
    ami3d._instrument_x.ramp_rate(0.09)
    ami3d._instrument_y.ramp_rate(0.10)
    ami3d._instrument_z.ramp_rate(0.11)

    ami3d.vector_ramp_rate(0.05)

    with caplog.at_level(logging.DEBUG, logger="qcodes.instrument.base"):

        # Initiate the simultaneous ramp
        ami3d.cartesian((0.5, 0.5, 0.5))

        # Assert the individual axes ramp rates were changed and not reverted
        # to the known values set earlier
        assert ami3d._instrument_x.ramp_rate() != 0.09
        assert ami3d._instrument_y.ramp_rate() != 0.10
        assert ami3d._instrument_z.ramp_rate() != 0.11

        # Assert the expected values of the ramp rates of the individual axes
        # set by the simultaneous ramp based on the vector_ramp_rate and the
        # setpoint magnetic field
        expected_ramp_rate = pytest.approx(
            0.5 / np.linalg.norm(ami3d.cartesian(), ord=2) *
            ami3d.vector_ramp_rate())
        assert ami3d._instrument_x.ramp_rate() == expected_ramp_rate
        assert ami3d._instrument_y.ramp_rate() == expected_ramp_rate
        assert ami3d._instrument_z.ramp_rate() == expected_ramp_rate

    messages = [record.message for record in caplog.records]

    expected_log_fragment = "Simultaneous ramp: not blocking until ramp is finished"
    messages_with_expected_fragment = tuple(
        message for message in messages if expected_log_fragment in message)
    assert (len(messages_with_expected_fragment) == 1
            ), f"found: {messages_with_expected_fragment}"

    unexpected_log_fragment = "Restoring individual axes ramp rates"
    messages_with_unexpected_fragment = tuple(
        message for message in messages if unexpected_log_fragment in message)
    assert (len(messages_with_unexpected_fragment) == 0
            ), f"found: {messages_with_unexpected_fragment}"

    # However, calling ``wait_while_all_axes_ramping`` DOES restore the
    # individual ramp rates

    with caplog.at_level(logging.DEBUG, logger="qcodes.instrument.base"):
        ami3d.wait_while_all_axes_ramping()

    messages_2 = [record.message for record in caplog.records]

    expected_log_fragment_2 = "Restoring individual axes ramp rates"
    messages_with_expected_fragment_2 = tuple(
        message for message in messages_2
        if expected_log_fragment_2 in message)
    assert (len(messages_with_expected_fragment_2) == 1
            ), f"found: {messages_with_expected_fragment_2}"

    # Assert calling ``wait_while_all_axes_ramping`` is possible

    ami3d.wait_while_all_axes_ramping()
Exemple #38
class CleanValueTests(DbTestCase):
    def setUp(self):
        super().setUp()
        self.exit_stack = ExitStack()
        self.basedir = self.exit_stack.enter_context(tempdir_context())

    def tearDown(self):
        self.exit_stack.close()
        super().tearDown()

    def _render_context(
        self,
        *,
        wf_module_id=None,
        input_table=None,
        tab_results={},
        params={},
        exit_stack=None,
    ) -> RenderContext:
        if exit_stack is None:
            exit_stack = self.exit_stack
        return RenderContext(
            wf_module_id=wf_module_id,
            input_table=input_table,
            tab_results=tab_results,
            basedir=self.basedir,
            exit_stack=exit_stack,
            params=params,
        )

    def test_clean_float(self):
        result = clean_value(ParamDType.Float(), 3.0, None)
        self.assertEqual(result, 3.0)
        self.assertIsInstance(result, float)

    def test_clean_float_with_int_value(self):
        # ParamDType.Float can have `int` values (because values come from
        # json.parse(), which only gives Numbers and so can give "3" instead
        # of "3.0"). We want to pass that as `float` in the `params` dict.
        result = clean_value(ParamDType.Float(), 3, None)
        self.assertEqual(result, 3.0)
        self.assertIsInstance(result, float)

    def test_clean_file_none(self):
        result = clean_value(ParamDType.File(), None, None)
        self.assertEqual(result, None)

    def test_clean_file_happy_path(self):
        workflow = Workflow.create_and_init()
        tab = workflow.tabs.first()
        step = tab.wf_modules.create(module_id_name="uploadfile",
                                     order=0,
                                     slug="step-1")
        id = str(uuid.uuid4())
        key = f"wf-${workflow.id}/wfm-${step.id}/${id}"
        minio.put_bytes(minio.UserFilesBucket, key, b"1234")
        UploadedFile.objects.create(
            wf_module=step,
            name="x.csv.gz",
            size=4,
            uuid=id,
            bucket=minio.UserFilesBucket,
            key=key,
        )
        with ExitStack() as inner_stack:
            context = self._render_context(wf_module_id=step.id,
                                           exit_stack=inner_stack)
            result: Path = clean_value(ParamDType.File(), id, context)
            self.assertIsInstance(result, Path)
            self.assertEqual(result.read_bytes(), b"1234")
            self.assertEqual(result.suffixes, [".csv", ".gz"])

        # Assert that once `exit_stack` goes out of scope, file is deleted
        self.assertFalse(result.exists())

    def test_clean_file_no_uploaded_file(self):
        workflow = Workflow.create_and_init()
        tab = workflow.tabs.first()
        step = tab.wf_modules.create(module_id_name="uploadfile",
                                     order=0,
                                     slug="step-1")
        context = self._render_context(wf_module_id=step.id)
        result = clean_value(ParamDType.File(), str(uuid.uuid4()), context)
        self.assertIsNone(result)
        # Assert that if a temporary file was created to house the download, it
        # no longer exists.
        self.assertListEqual(list(self.basedir.iterdir()), [])

    def test_clean_file_no_minio_file(self):
        workflow = Workflow.create_and_init()
        tab = workflow.tabs.first()
        step = tab.wf_modules.create(module_id_name="uploadfile",
                                     order=0,
                                     slug="step-1")
        step2 = tab.wf_modules.create(module_id_name="uploadfile",
                                      order=1,
                                      slug="step-2")
        id = str(uuid.uuid4())
        key = f"wf-${workflow.id}/wfm-${step.id}/${id}"
        # Oops -- let's _not_ put the file!
        # minio.put_bytes(minio.UserFilesBucket, key, b'1234')
        UploadedFile.objects.create(
            wf_module=step2,
            name="x.csv.gz",
            size=4,
            uuid=id,
            bucket=minio.UserFilesBucket,
            key=key,
        )
        context = self._render_context(wf_module_id=step.id)
        result = clean_value(ParamDType.File(), id, context)
        self.assertIsNone(result)
        # Assert that if a temporary file was created to house the download, it
        # no longer exists.
        self.assertListEqual(list(self.basedir.iterdir()), [])

    def test_clean_file_wrong_wf_module(self):
        workflow = Workflow.create_and_init()
        tab = workflow.tabs.first()
        step = tab.wf_modules.create(module_id_name="uploadfile",
                                     order=0,
                                     slug="step-1")
        step2 = tab.wf_modules.create(module_id_name="uploadfile",
                                      order=1,
                                      slug="step-2")
        id = str(uuid.uuid4())
        key = f"wf-${workflow.id}/wfm-${step.id}/${id}"
        minio.put_bytes(minio.UserFilesBucket, key, b"1234")
        UploadedFile.objects.create(
            wf_module=step2,
            name="x.csv.gz",
            size=4,
            uuid=id,
            bucket=minio.UserFilesBucket,
            key=key,
        )
        context = self._render_context(wf_module_id=step.id)
        result = clean_value(ParamDType.File(), id, context)
        self.assertIsNone(result)
        # Assert that if a temporary file was created to house the download, it
        # no longer exists.
        self.assertListEqual(list(self.basedir.iterdir()), [])

    def test_clean_normal_dict(self):
        context = self._render_context()
        schema = ParamDType.Dict({
            "str": ParamDType.String(),
            "int": ParamDType.Integer()
        })
        value = {"str": "foo", "int": 3}
        expected = dict(value)  # no-op
        result = clean_value(schema, value, context)
        self.assertEqual(result, expected)

    def test_clean_column_valid(self):
        context = self._render_context(input_table=arrow_table({"A": [1]}))
        result = clean_value(ParamDType.Column(), "A", context)
        self.assertEqual(result, "A")

    def test_clean_column_prompting_error_convert_to_text(self):
        # TODO make this _automatic_ instead of quick-fix?
        # Consider Regex. We probably want to pass the module a text Series
        # _separately_ from the input DataFrame. That way Regex can output
        # a new Text column but preserve its input column's data type.
        #
        # ... but for now: prompt for a Quick Fix.
        context = self._render_context(input_table=arrow_table({"A": [1]}))
        with self.assertRaises(PromptingError) as cm:
            clean_value(ParamDType.Column(column_types=frozenset({"text"})),
                        "A", context)

        self.assertEqual(
            cm.exception.errors,
            [PromptingError.WrongColumnType(["A"], None, frozenset({"text"}))],
        )

    def test_clean_column_prompting_error_convert_to_number(self):
        context = self._render_context(input_table=arrow_table({"A": ["1"]}))
        with self.assertRaises(PromptingError) as cm:
            clean_value(ParamDType.Column(column_types=frozenset({"number"})),
                        "A", context)
        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A"], "text",
                                               frozenset({"number"}))
            ],
        )

    def test_list_prompting_error_concatenate_same_type(self):
        context = self._render_context(input_table=arrow_table({
            "A": ["1"],
            "B": ["2"]
        }))
        schema = ParamDType.List(inner_dtype=ParamDType.Column(
            column_types=frozenset({"number"})))
        with self.assertRaises(PromptingError) as cm:
            clean_value(schema, ["A", "B"], context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A", "B"], "text",
                                               frozenset({"number"}))
            ],
        )

    def test_list_prompting_error_concatenate_different_type(self):
        context = self._render_context(input_table=arrow_table({
            "A": ["1"],
            "B":
            pa.array([datetime.now()], pa.timestamp("ns"))
        }))
        schema = ParamDType.List(inner_dtype=ParamDType.Column(
            column_types=frozenset({"number"})))
        with self.assertRaises(PromptingError) as cm:
            clean_value(schema, ["A", "B"], context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A"], "text",
                                               frozenset({"number"})),
                PromptingError.WrongColumnType(["B"], "datetime",
                                               frozenset({"number"})),
            ],
        )

    def test_list_prompting_error_concatenate_different_type_to_text(self):
        context = self._render_context(input_table=arrow_table({
            "A": [1],
            "B":
            pa.array([datetime.now()], pa.timestamp("ns"))
        }))
        schema = ParamDType.List(inner_dtype=ParamDType.Column(
            column_types=frozenset({"text"})))
        with self.assertRaises(PromptingError) as cm:
            clean_value(schema, ["A", "B"], context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A", "B"], None,
                                               frozenset({"text"}))
            ],
        )

    def test_dict_prompting_error(self):
        context = self._render_context(input_table=arrow_table({
            "A": ["a"],
            "B": ["b"]
        }))
        schema = ParamDType.Dict({
            "col1":
            ParamDType.Column(column_types=frozenset({"number"})),
            "col2":
            ParamDType.Column(column_types=frozenset({"datetime"})),
        })
        with self.assertRaises(PromptingError) as cm:
            clean_value(schema, {"col1": "A", "col2": "B"}, context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A"], "text",
                                               frozenset({"number"})),
                PromptingError.WrongColumnType(["B"], "text",
                                               frozenset({"datetime"})),
            ],
        )

    def test_dict_prompting_error_concatenate_same_type(self):
        context = self._render_context(input_table=arrow_table({
            "A": ["1"],
            "B": ["2"]
        }))
        schema = ParamDType.Dict({
            "x":
            ParamDType.Column(column_types=frozenset({"number"})),
            "y":
            ParamDType.Column(column_types=frozenset({"number"})),
        })
        with self.assertRaises(PromptingError) as cm:
            clean_value(schema, {"x": "A", "y": "B"}, context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A", "B"], "text",
                                               frozenset({"number"}))
            ],
        )

    def test_dict_prompting_error_concatenate_different_types(self):
        context = self._render_context(input_table=arrow_table({
            "A": ["1"],
            "B":
            pa.array([datetime.now()], pa.timestamp("ns"))
        }))
        schema = ParamDType.Dict({
            "x":
            ParamDType.Column(column_types=frozenset({"number"})),
            "y":
            ParamDType.Column(column_types=frozenset({"number"})),
        })
        with self.assertRaises(PromptingError) as cm:
            clean_value(schema, {"x": "A", "y": "B"}, context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A"], "text",
                                               frozenset({"number"})),
                PromptingError.WrongColumnType(["B"], "datetime",
                                               frozenset({"number"})),
            ],
        )

    def test_clean_column_missing_becomes_empty_string(self):
        context = self._render_context(input_table=arrow_table({"A": [1]}))
        result = clean_value(ParamDType.Column(), "B", context)
        self.assertEqual(result, "")

    def test_clean_multicolumn_valid(self):
        context = self._render_context(input_table=arrow_table({
            "A": [1],
            "B": [2]
        }))
        result = clean_value(ParamDType.Multicolumn(), ["A", "B"], context)
        self.assertEqual(result, ["A", "B"])

    def test_clean_multicolumn_sort_in_table_order(self):
        context = self._render_context(input_table=arrow_table({
            "B": [1],
            "A": [2]
        }))
        result = clean_value(ParamDType.Multicolumn(), ["A", "B"], context)
        self.assertEqual(result, ["B", "A"])

    def test_clean_multicolumn_prompting_error_convert_to_text(self):
        # TODO make this _automatic_ instead of quick-fix?
        # ... but for now: prompt for a Quick Fix.
        context = self._render_context(input_table=arrow_table({
            "A": [1],
            "B":
            pa.array([datetime.now()], pa.timestamp("ns")),
            "C": ["x"],
        }))
        with self.assertRaises(PromptingError) as cm:
            schema = ParamDType.Multicolumn(column_types=frozenset({"text"}))
            clean_value(schema, ["A", "B"], context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A", "B"], None,
                                               frozenset({"text"}))
            ],
        )

    def test_clean_multicolumn_missing_is_removed(self):
        context = self._render_context(input_table=arrow_table({
            "A": [1],
            "B": [1]
        }))
        result = clean_value(ParamDType.Multicolumn(), ["A", "X", "B"],
                             context)
        self.assertEqual(result, ["A", "B"])

    def test_clean_multichartseries_missing_is_removed(self):
        context = self._render_context(input_table=arrow_table({
            "A": [1],
            "B": [1]
        }))
        value = [
            {
                "column": "A",
                "color": "#aaaaaa"
            },
            {
                "column": "C",
                "color": "#cccccc"
            },
        ]
        result = clean_value(ParamDType.Multichartseries(), value, context)
        self.assertEqual(result, [{"column": "A", "color": "#aaaaaa"}])

    def test_clean_multichartseries_non_number_is_prompting_error(self):
        context = self._render_context(input_table=arrow_table({
            "A": ["a"],
            "B":
            pa.array([datetime.now()], pa.timestamp("ns"))
        }))
        value = [
            {
                "column": "A",
                "color": "#aaaaaa"
            },
            {
                "column": "B",
                "color": "#cccccc"
            },
        ]
        with self.assertRaises(PromptingError) as cm:
            clean_value(ParamDType.Multichartseries(), value, context)

        self.assertEqual(
            cm.exception.errors,
            [
                PromptingError.WrongColumnType(["A"], "text",
                                               frozenset({"number"})),
                PromptingError.WrongColumnType(["B"], "datetime",
                                               frozenset({"number"})),
            ],
        )

    def test_clean_tab_happy_path(self):
        tab = Tab("tab-1", "Tab 1")
        table = arrow_table({"A": [1, 2]})
        context = self._render_context(tab_results={tab: RenderResult(table)})
        result = clean_value(ParamDType.Tab(), "tab-1", context)
        self.assertEqual(result, TabOutput(tab, table))

    def test_clean_multicolumn_from_other_tab(self):
        tab2 = Tab("tab-2", "Tab 2")
        tab2_output_table = arrow_table({"A-from-tab-2": [1, 2]})

        schema = ParamDType.Dict({
            "tab":
            ParamDType.Tab(),
            "columns":
            ParamDType.Multicolumn(tab_parameter="tab"),
        })
        params = {"tab": "tab-2", "columns": ["A-from-tab-1", "A-from-tab-2"]}
        context = self._render_context(
            input_table=arrow_table({"A-from-tab-1": [1]}),
            tab_results={tab2: RenderResult(tab2_output_table)},
            params=params,
        )
        result = clean_value(schema, params, context)
        # result['tab'] is not what we're testing here
        self.assertEqual(result["columns"], ["A-from-tab-2"])

    def test_clean_multicolumn_from_other_tab_that_does_not_exist(self):
        # The other tab would not exist if the user selected and then deleted
        # it.
        schema = ParamDType.Dict({
            "tab":
            ParamDType.Tab(),
            "columns":
            ParamDType.Multicolumn(tab_parameter="tab"),
        })
        params = {"tab": "tab-missing", "columns": ["A-from-tab-1"]}
        context = self._render_context(
            input_table=arrow_table({"A-from-tab-1": [1]}),
            tab_results={},
            params=params,
        )
        result = clean_value(schema, params, context)
        # result['tab'] is not what we're testing here
        self.assertEqual(result["columns"], [])

    def test_clean_tab_no_tab_selected_gives_none(self):
        context = self._render_context(tab_results={})
        result = clean_value(ParamDType.Tab(), "", context)
        self.assertEqual(result, None)

    def test_clean_tab_missing_tab_selected_gives_none(self):
        """
        If the user has selected a nonexistent tab, pretend tab is blank.

        JS sees nonexistent tab slugs. render() doesn't.
        """
        context = self._render_context(tab_results={})
        result = clean_value(ParamDType.Tab(), "tab-XXX", context)
        self.assertEqual(result, None)

    def test_clean_tab_cycle(self):
        tab = Tab("tab-1", "Tab 1")
        context = self._render_context(tab_results={tab: None})
        with self.assertRaises(TabCycleError):
            clean_value(ParamDType.Tab(), "tab-1", context)

    def test_clean_tab_unreachable(self):
        tab = Tab("tab-error", "Buggy Tab")
        context = self._render_context(tab_results={tab: RenderResult()})
        with self.assertRaises(TabOutputUnreachableError):
            clean_value(ParamDType.Tab(), "tab-error", context)

    def test_clean_tabs_happy_path(self):
        tab2 = Tab("tab-2", "Tab 2")
        tab2_output = arrow_table({"B": [1]})
        tab3 = Tab("tab-3", "Tab 3")
        tab3_output = arrow_table({"C": [1]})

        context = self._render_context(tab_results={
            tab2: RenderResult(tab2_output),
            tab3: RenderResult(tab3_output),
        })
        result = clean_value(ParamDType.Multitab(), ["tab-2", "tab-3"],
                             context)
        self.assertEqual(
            result,
            [TabOutput(tab2, tab2_output),
             TabOutput(tab3, tab3_output)])

    def test_clean_tabs_preserve_ordering(self):
        tab2 = Tab("tab-2", "Tab 2")
        tab2_output = arrow_table({"B": [1]})
        tab3 = Tab("tab-3", "Tab 3")
        tab3_output = arrow_table({"C": [1]})

        context = self._render_context(
            # RenderContext's dict ordering determines desired tab order.
            # (Python 3.7 spec: dict is ordered in insertion order. CPython 3.6
            # and PyPy 7 do this, too.)
            tab_results={
                tab3: RenderResult(tab3_output),
                tab2: RenderResult(tab2_output),
            })
        # Supply wrongly-ordered tabs; renderprep should reorder them.
        result = clean_value(ParamDType.Multitab(), ["tab-2", "tab-3"],
                             context)
        self.assertEqual([t.tab.slug for t in result], ["tab-3", "tab-2"])

    def test_clean_tabs_nix_missing_tab(self):
        context = self._render_context(tab_results={})
        result = clean_value(ParamDType.Multitab(), ["tab-missing"], context)
        self.assertEqual(result, [])

    def test_clean_tabs_tab_cycle(self):
        tab = Tab("tab-1", "Tab 1")
        context = self._render_context(tab_results={tab: None})
        with self.assertRaises(TabCycleError):
            clean_value(ParamDType.Multitab(), ["tab-1"], context)

    def test_clean_tabs_tab_unreachable(self):
        tab = Tab("tab-1", "Tab 1")
        context = self._render_context(tab_results={tab: RenderResult()})
        with self.assertRaises(TabOutputUnreachableError):
            clean_value(ParamDType.Multitab(), ["tab-1"], context)
class TestClassicBuilder(TestCase):
    # XXX These tests rely on external resources, namely that the rootfs can
    # actually be downloaded from the ubuntu archive.
    # That's a test isolation bug and a potential source of test
    # brittleness. We should fix this.
    #
    # XXX These tests also require root, because `lb build`
    # currently requires it.

    def setUp(self):
        self._resources = ExitStack()
        # Mock out the check_root_privilege call
        self._resources.enter_context(
            patch('ubuntu_image.classic_builder.check_root_privilege'))
        self.addCleanup(self._resources.close)
        self.gadget_tree = resource_filename(
            'ubuntu_image.tests.data', 'gadget_tree')

    def test_prepare_gadget_tree_locally(self):
        # Run the actual classic builder through the steps needed to
        # at least call `snapcraft prime`.
        # To create pc-boot.img and pc-core.img, we need to fetch
        # packages like grub-pc-bin, shim-signed from ubuntu archive,
        # even if gadget tree is placed locally on the machine.
        workdir = self._resources.enter_context(TemporaryDirectory())
        args = SimpleNamespace(
            project='ubuntu-cpc',
            suite='xenial',
            arch='amd64',
            image_format='img',
            output=None,
            subproject=None,
            subarch=None,
            output_dir=None,
            workdir=workdir,
            cloud_init=None,
            with_proposed=None,
            extra_ppas=None,
            hooks_directory=[],
            gadget_tree=self.gadget_tree,
            filesystem=None,
            )
        state = self._resources.enter_context(XXXClassicBuilder(args))
        gadget_dir = os.path.join(workdir, 'unpack', 'gadget')
        state.run_thru('prepare_gadget_tree')
        files = [
            '{gadget_dir}/grub-cpc.cfg',
            '{gadget_dir}/grubx64.efi',
            '{gadget_dir}/pc-boot.img',
            '{gadget_dir}/pc-core.img',
            '{gadget_dir}/shim.efi.signed',
            '{gadget_dir}/meta/gadget.yaml',
            ]
        # Check if all needed bootloader bits are in place.
        for filename in files:
            path = filename.format(
                gadget_dir=gadget_dir,
                )
            self.assertTrue(os.path.exists(path), path)

    def test_fs_contents(self):
        # Run the actual classic builder through the steps needed to
        # at least call `lb config && lb build`.
        output = self._resources.enter_context(NamedTemporaryFile())
        workdir = self._resources.enter_context(TemporaryDirectory())
        unpackdir = os.path.join(workdir, 'unpack')
        mock = LiveBuildMocker(unpackdir)
        args = SimpleNamespace(
            project='ubuntu-cpc',
            suite='xenial',
            arch='amd64',
            image_format='img',
            output=output.name,
            subproject='subproject',
            subarch='subarch',
            output_dir=None,
            workdir=workdir,
            cloud_init=None,
            with_proposed='1',
            extra_ppas='******',
            hooks_directory=[],
            gadget_tree=self.gadget_tree,
            filesystem=None,
            )
        state = self._resources.enter_context(XXXClassicBuilder(args))
        # Mock out rootfs generation `live_build`
        # and create dummy top-level filesystem layout.
        self._resources.enter_context(
            patch('ubuntu_image.helpers.run', mock.run))
        state.run_thru('populate_bootfs_contents')
        # How do the root and boot file systems look?
        files = [
            '{boot}/EFI/boot/bootx64.efi',
            '{boot}/EFI/boot/grubx64.efi',
            '{boot}/EFI/ubuntu/grub.cfg',
            '{root}/boot/',
            ]
        for filename in files:
            path = filename.format(
                root=state.rootfs,
                boot=state.gadget.volumes['pc'].bootfs,
                )
            self.assertTrue(os.path.exists(path), path)
        # Simply check if all top-level files and folders exist.
        for dirname in DIRS_UNDER_ROOTFS:
            path = os.path.join(state.rootfs, dirname)
            self.assertTrue(os.path.exists(path), path)

    def test_populate_rootfs_contents_fstab_label(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                workdir=workdir,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                cloud_init=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            # Now we have to craft enough of gadget definition to drive the
            # method under test.
            part = SimpleNamespace(
                role=StructureRole.system_data,
                filesystem_label='writable',
                filesystem=FileSystemType.none,
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.grub,
                schema=VolumeSchema.gpt,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir)
            # Fake some state expected by the method under test.
            state.unpackdir = resources.enter_context(TemporaryDirectory())
            etc_path = os.path.join(state.unpackdir, 'chroot', 'etc')
            os.makedirs(etc_path)
            with open(os.path.join(etc_path, 'fstab'), 'w') as fp:
                fp.write('LABEL=cloudimg-rootfs   /    ext4   defaults    0 0')
            state.rootfs = resources.enter_context(TemporaryDirectory())
            # Jump right to the state method we're trying to test.
            state._next.pop()
            state._next.append(state.populate_rootfs_contents)
            next(state)
            # The filesystem label in fstab should be modified to 'writable'.
            fstab_data = os.path.join(state.rootfs, 'etc', 'fstab')
            with open(fstab_data, 'r', encoding='utf-8') as fp:
                self.assertEqual(fp.read(), 'LABEL=writable   '
                                            '/    ext4   defaults    0 0')

    def test_populate_rootfs_contents_from_filesystem(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            args = SimpleNamespace(
                project=None,
                suite='xenial',
                arch='amd64',
                image_format='img',
                workdir=workdir,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                cloud_init=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            # Now we have to craft enough of gadget definition to drive the
            # method under test.
            part = SimpleNamespace(
                role=StructureRole.system_data,
                filesystem_label='writable',
                filesystem=FileSystemType.none,
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.grub,
                schema=VolumeSchema.gpt,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir)
            # Fake some state expected by the method under test.
            args.filesystem = resources.enter_context(TemporaryDirectory())
            etc_path = os.path.join(args.filesystem, 'etc')
            os.makedirs(etc_path)
            with open(os.path.join(etc_path, 'fstab'), 'w') as fp:
                fp.write('LABEL=cloudimg-rootfs   /    ext4   defaults    0 0')
            state.rootfs = resources.enter_context(TemporaryDirectory())
            # Jump right to the state method we're trying to test.
            state._next.pop()
            state._next.append(state.populate_rootfs_contents)
            next(state)
            # The filesystem label in fstab should be modified to 'writable'.
            fstab_data = os.path.join(state.rootfs, 'etc', 'fstab')
            with open(fstab_data, 'r', encoding='utf-8') as fp:
                self.assertEqual(fp.read(), 'LABEL=writable   '
                                            '/    ext4   defaults    0 0')

    def test_populate_rootfs_contents_empty_fstab_entry(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                workdir=workdir,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                cloud_init=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            # Now we have to craft enough of gadget definition to drive the
            # method under test.
            part = SimpleNamespace(
                role=StructureRole.system_data,
                filesystem_label='writable',
                filesystem=FileSystemType.none,
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.grub,
                schema=VolumeSchema.gpt,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir)
            # Fake some state expected by the method under test.
            state.unpackdir = resources.enter_context(TemporaryDirectory())
            etc_path = os.path.join(state.unpackdir, 'chroot', 'etc')
            os.makedirs(etc_path)
            with open(os.path.join(etc_path, 'fstab'), 'w') as fp:
                pass
            state.rootfs = resources.enter_context(TemporaryDirectory())
            # Jump right to the state method we're trying to test.
            state._next.pop()
            state._next.append(state.populate_rootfs_contents)
            next(state)
            # The filesystem label should be inserted if it doesn't exist.
            fstab_data = os.path.join(state.rootfs, 'etc', 'fstab')
            with open(fstab_data, 'r', encoding='utf-8') as fp:
                self.assertEqual(fp.read(), 'LABEL=writable   '
                                            '/    ext4   defaults    0 0')

    def test_populate_rootfs_contents_without_cloud_init(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            cloud_init = resources.enter_context(
                NamedTemporaryFile('w', encoding='utf-8'))
            print('cloud init user data', end='', flush=True, file=cloud_init)
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                workdir=workdir,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                cloud_init=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            # Now we have to craft enough of gadget definition to drive the
            # method under test.
            part = SimpleNamespace(
                role=StructureRole.system_boot,
                filesystem_label='system-boot',
                filesystem=FileSystemType.none,
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.uboot,
                schema=VolumeSchema.mbr,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir)
            # Fake some state expected by the method under test.
            state.unpackdir = resources.enter_context(TemporaryDirectory())
            os.makedirs(os.path.join(state.unpackdir, 'chroot'))
            state.rootfs = resources.enter_context(TemporaryDirectory())
            # Jump right to the state method we're trying to test.
            state._next.pop()
            state._next.append(state.populate_rootfs_contents)
            next(state)
            # The user data should not have been written and there should be
            # no metadata either.
            seed_path = os.path.join(
                state.rootfs, 'var', 'lib', 'cloud', 'seed', 'nocloud-net')
            self.assertFalse(os.path.exists(
                os.path.join(seed_path, 'user-data')))
            self.assertFalse(os.path.exists(
                os.path.join(seed_path, 'meta-data')))

    def test_populate_rootfs_contents_with_cloud_init(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            cloud_init = resources.enter_context(
                NamedTemporaryFile('w', encoding='utf-8'))
            print('cloud init user data', end='', flush=True, file=cloud_init)
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                cloud_init=cloud_init.name,
                workdir=workdir,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            # Now we have to craft enough of gadget definition to drive the
            # method under test.
            part = SimpleNamespace(
                role=StructureRole.system_boot,
                filesystem_label='system-boot',
                filesystem=FileSystemType.none,
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.uboot,
                schema=VolumeSchema.mbr,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir)
            # Fake some state expected by the method under test.
            state.unpackdir = resources.enter_context(TemporaryDirectory())
            os.makedirs(os.path.join(state.unpackdir, 'chroot'))
            state.rootfs = resources.enter_context(TemporaryDirectory())
            # Jump right to the state method we're trying to test.
            state._next.pop()
            state._next.append(state.populate_rootfs_contents)
            next(state)
            # Both the user data and the seed metadata should exist.
            seed_path = os.path.join(
                state.rootfs,
                'var', 'lib', 'cloud', 'seed', 'nocloud-net')
            user_data = os.path.join(seed_path, 'user-data')
            meta_data = os.path.join(seed_path, 'meta-data')
            with open(user_data, 'r', encoding='utf-8') as fp:
                self.assertEqual(fp.read(), 'cloud init user data')
            with open(meta_data, 'r', encoding='utf-8') as fp:
                self.assertEqual(fp.read(), 'instance-id: nocloud-static\n')

    def test_populate_rootfs_contents_grub_boot_remove(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                workdir=workdir,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                cloud_init=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            # Now we have to craft enough of a gadget definition to drive the
            # method under test.
            part = SimpleNamespace(
                role=StructureRole.system_boot,
                filesystem_label='system-boot',
                filesystem=FileSystemType.none,
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.uboot,
                schema=VolumeSchema.mbr,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir)
            # Fake some state expected by the method under test.
            state.unpackdir = resources.enter_context(TemporaryDirectory())
            os.makedirs(os.path.join(state.unpackdir, 'chroot'))
            state.rootfs = resources.enter_context(TemporaryDirectory())
            # Create some dummy files in the grub directory.
            grub_dir = os.path.join(state.rootfs, 'boot', 'grub')
            os.makedirs(grub_dir, exist_ok=True)
            grub_inside_dir = os.path.join(grub_dir, 'dir')
            os.makedirs(grub_inside_dir, exist_ok=True)
            grub_file = os.path.join(grub_dir, 'test')
            open(grub_file, 'wb').close()
            # Jump right to the state method we're trying to test.
            state._next.pop()
            state._next.append(state.populate_rootfs_contents)
            next(state)
            # /boot/grub should persist, but not the files inside it.
            self.assertTrue(os.path.exists(grub_dir))
            self.assertFalse(os.path.exists(grub_inside_dir))
            self.assertFalse(os.path.exists(grub_file))

    def test_populate_bootfs_contents(self):
        # This test provides coverage for populate_bootfs_contents() when a
        # volume's part is defined as an ext4 or vfat file system type.  In
        # that case, the part's contents are copied to the target directory.
        # There are two paths here: one where the contents are a directory and
        # the other where the contents are a file.  We can test both cases
        # here for full coverage.
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            unpackdir = resources.enter_context(TemporaryDirectory())
            # Fast forward a state machine to the method under test.
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                unpackdir=unpackdir,
                workdir=workdir,
                cloud_init=None,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            state._next.pop()
            state._next.append(state.populate_bootfs_contents)
            # Now we have to craft enough of a gadget definition to drive the
            # method under test.  The two paths (is-a-file and is-a-directory)
            # are differentiated by whether the source ends in a slash or not.
            # When the source ends in a slash, the target must too.
            contents1 = SimpleNamespace(
                source='as.dat',
                target='at.dat',
                )
            contents2 = SimpleNamespace(
                source='bs/',
                target='bt/',
                )
            part = SimpleNamespace(
                role=None,
                filesystem_label='not a boot',
                filesystem=FileSystemType.ext4,
                content=[contents1, contents2],
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.grub,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            # Since we're not running make_temporary_directories(), just set
            # up some additional expected state.
            state.unpackdir = unpackdir
            prep_state(state, workdir)
            # Run the method, the testable effects of which copy all the files
            # in the source directory (i.e. <unpackdir>/gadget/<source>) into
            # the target directory (i.e. <workdir>/part0).  So put some
            # contents into the source locations.
            gadget_dir = os.path.join(unpackdir, 'gadget')
            os.makedirs(gadget_dir)
            src = os.path.join(gadget_dir, 'as.dat')
            with open(src, 'wb') as fp:
                fp.write(b'01234')
            src = os.path.join(gadget_dir, 'bs')
            os.makedirs(src)
            # Put a couple of files and a directory in the source, since
            # directories are copied recursively.
            with open(os.path.join(src, 'c.dat'), 'wb') as fp:
                fp.write(b'56789')
            srcdir = os.path.join(src, 'd')
            os.makedirs(srcdir)
            with open(os.path.join(srcdir, 'e.dat'), 'wb') as fp:
                fp.write(b'0abcd')
            # Run the state machine.
            next(state)
            # Did all the files and directories get copied?
            dstbase = os.path.join(workdir, 'volumes', 'volume1', 'part0')
            with open(os.path.join(dstbase, 'at.dat'), 'rb') as fp:
                self.assertEqual(fp.read(), b'01234')
            with open(os.path.join(dstbase, 'bt', 'c.dat'), 'rb') as fp:
                self.assertEqual(fp.read(), b'56789')
            with open(os.path.join(dstbase, 'bt', 'd', 'e.dat'), 'rb') as fp:
                self.assertEqual(fp.read(), b'0abcd')

    def test_bootloader_options_invalid(self):
        # This test provides coverage for populate_bootfs_contents() when the
        # bootloader has a bogus value.
        #
        # We don't want to run the entire state machine just for this test, so
        # we start by setting up enough of the environment for the method
        # under test to function.
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            # Fast forward a state machine to the method under test.
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                workdir=workdir,
                debug=None,
                cloud_init=None,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            state._next.pop()
            state._next.append(state.populate_bootfs_contents)
            # Now we have to craft enough of a gadget definition to drive the
            # method under test.
            part = SimpleNamespace(
                role=StructureRole.system_boot,
                filesystem_label='system-boot',
                filesystem=FileSystemType.none,
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader='bogus',
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir)
            # Don't blat to stderr.
            resources.enter_context(patch('ubuntu_image.state.log'))
            with self.assertRaises(ValueError) as cm:
                next(state)
            self.assertEqual(
                str(cm.exception),
                'Unsupported volume bootloader value: bogus')

    def test_populate_bootfs_contents_content_mismatch(self):
        # If a content source ends in a slash, so must the target.
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            unpackdir = resources.enter_context(TemporaryDirectory())
            # Fast forward a state machine to the method under test.
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                unpackdir=unpackdir,
                workdir=workdir,
                debug=None,
                cloud_init=None,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            state = resources.enter_context(XXXClassicBuilder(args))
            state._next.pop()
            state._next.append(state.populate_bootfs_contents)
            # Now we have to craft enough of a gadget definition to drive the
            # method under test.  The two paths (is-a-file and is-a-directory)
            # are differentiated by whether the source ends in a slash or not.
            # When the source ends in a slash, the target must too.
            content1 = SimpleNamespace(
                source='bs/',
                # No slash!
                target='bt',
                )
            part = SimpleNamespace(
                role=StructureRole.system_boot,
                filesystem=FileSystemType.ext4,
                content=[content1],
                )
            volume = SimpleNamespace(
                structures=[part],
                bootloader=BootLoader.grub,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            # Since we're not running make_temporary_directories(), just set
            # up some additional expected state.
            state.unpackdir = unpackdir
            prep_state(state, workdir)
            # Run the state machine.  Don't blat to stderr.
            resources.enter_context(patch('ubuntu_image.state.log'))
            with self.assertRaises(ValueError) as cm:
                next(state)
            self.assertEqual(
                str(cm.exception), 'target must end in a slash: bt')

    def test_populate_filesystems_none_type(self):
        # We do a bit-wise copy when the file system has no type.
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            unpackdir = resources.enter_context(TemporaryDirectory())
            # Fast forward a state machine to the method under test.
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                unpackdir=unpackdir,
                workdir=workdir,
                debug=None,
                cloud_init=None,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            # Jump right to the method under test.
            state = resources.enter_context(XXXClassicBuilder(args))
            state._next.pop()
            state._next.append(state.populate_filesystems)
            # Set up expected state.
            state.unpackdir = unpackdir
            state.images = os.path.join(workdir, '.images')
            os.makedirs(state.images)
            part0_img = os.path.join(state.images, 'part0.img')
            # Craft a gadget specification.
            contents1 = SimpleNamespace(
                image='image1.img',
                size=None,
                offset=None,
                )
            contents2 = SimpleNamespace(
                image='image2.img',
                size=23,
                offset=None,
                )
            contents3 = SimpleNamespace(
                image='image3.img',
                size=None,
                offset=None,
                )
            contents4 = SimpleNamespace(
                image='image4.img',
                size=None,
                offset=127,
                )
            part = SimpleNamespace(
                role=None,
                filesystem=FileSystemType.none,
                content=[contents1, contents2, contents3, contents4],
                size=150,
                )
            volume = SimpleNamespace(
                structures=[part],
                schema=VolumeSchema.gpt,
                )
            state.gadget = SimpleNamespace(
                volumes=dict(volume1=volume),
                )
            prep_state(state, workdir, [part0_img])
            # The source image.
            gadget_dir = os.path.join(unpackdir, 'gadget')
            os.makedirs(gadget_dir)
            with open(os.path.join(gadget_dir, 'image1.img'),
                      'wb') as fp:
                fp.write(b'\1' * 47)
            with open(os.path.join(gadget_dir, 'image2.img'),
                      'wb') as fp:
                fp.write(b'\2' * 19)
            with open(os.path.join(gadget_dir, 'image3.img'),
                      'wb') as fp:
                fp.write(b'\3' * 51)
            with open(os.path.join(gadget_dir, 'image4.img'),
                      'wb') as fp:
                fp.write(b'\4' * 11)
            # Mock out the mkfs.ext4 call, and we'll just test the contents
            # directory (i.e. what would go in the ext4 file system).
            resources.enter_context(
                patch('ubuntu_image.common_builder.mkfs_ext4'))
            next(state)
            # Check the contents of the part0 image file.
            with open(part0_img, 'rb') as fp:
                data = fp.read()
            self.assertEqual(
                data,
                b'\1' * 47 +
                b'\2' * 19 +
                # 23 (specified size) - 19 (actual size).
                b'\0' * 4 +
                b'\3' * 51 +
                # 127 (offset) - 121 (written byte count)
                b'\0' * 6 +
                b'\4' * 11 +
                # 150 (image size) - 138 (written byte count)
                b'\0' * 12)

    def test_live_build_command_fails(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            unpackdir = resources.enter_context(TemporaryDirectory())
            # Fast forward a state machine to the method under test.
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                unpackdir=unpackdir,
                workdir=workdir,
                debug=False,
                cloud_init=None,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            # Jump right to the method under test.
            state = resources.enter_context(XXXClassicBuilder(args))
            state.unpackdir = unpackdir
            state._next.pop()
            state._next.append(state.prepare_image)
            resources.enter_context(patch(
                'ubuntu_image.helpers.subprocess_run',
                return_value=SimpleNamespace(
                    returncode=1,
                    stdout='command stdout',
                    stderr='command stderr',
                    check_returncode=check_returncode,
                    )))
            log_capture = resources.enter_context(LogCapture())
            next(state)
            self.assertEqual(state.exitcode, 1)
            # Note that there is no traceback in the output.
            self.assertEqual(log_capture.logs, [
                (logging.ERROR,
                 'COMMAND FAILED: dpkg -L livecd-rootfs | grep "auto$"'),
                (logging.ERROR, 'command stdout'),
                (logging.ERROR, 'command stderr'),
                ])

    def test_live_build_command_fails_debug(self):
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            unpackdir = resources.enter_context(TemporaryDirectory())
            # Fast forward a state machine to the method under test.
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                unpackdir=unpackdir,
                workdir=workdir,
                debug=True,
                cloud_init=None,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=None,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            # Jump right to the method under test.
            state = resources.enter_context(XXXClassicBuilder(args))
            state.unpackdir = unpackdir
            state._next.pop()
            state._next.append(state.prepare_image)
            resources.enter_context(patch(
                'ubuntu_image.helpers.subprocess_run',
                return_value=SimpleNamespace(
                    returncode=1,
                    stdout='command stdout',
                    stderr='command stderr',
                    check_returncode=check_returncode,
                    )))
            log_capture = resources.enter_context(LogCapture())
            next(state)
            self.assertEqual(state.exitcode, 1)
            # Note that there is a traceback in the output now.  Its exact
            # text varies, so compare the fixed entries, then just check the
            # level of the final entry.
            self.assertEqual(log_capture.logs[:4], [
                (logging.ERROR,
                 'COMMAND FAILED: dpkg -L livecd-rootfs | grep "auto$"'),
                (logging.ERROR, 'command stdout'),
                (logging.ERROR, 'command stderr'),
                (logging.ERROR, 'Full debug traceback follows'),
                ])
            self.assertEqual(log_capture.logs[4][0], logging.ERROR)

    def test_live_build_pass_arguments(self):
        with ExitStack() as resources:
            argstoenv = {
                'project': 'PROJECT',
                'suite': 'SUITE',
                'arch': 'ARCH',
                'subproject': 'SUBPROJECT',
                'subarch': 'SUBARCH',
                'with_proposed': 'PROPOSED',
                }
            kwargs_skel = {
                'workdir': '/tmp',
                'output_dir': '/tmp',
                'hooks_directory': '/tmp',
                'output': None,
                'cloud_init': None,
                'gadget_tree': None,
                'unpackdir': None,
                'debug': None,
                'project': None,
                'suite': None,
                'arch': None,
                'subproject': None,
                'subarch': None,
                'with_proposed': None,
                'extra_ppas': None,
                'filesystem': None,
                }
            for arg, env in argstoenv.items():
                kwargs = dict(kwargs_skel)
                kwargs[arg] = 'test' if arg != 'with_proposed' else True
                args = SimpleNamespace(**kwargs)
                # Jump right to the method under test.
                state = resources.enter_context(XXXClassicBuilder(args))
                state._next.pop()
                state._next.append(state.prepare_image)
                mock = resources.enter_context(patch(
                    'ubuntu_image.classic_builder.live_build'))
                next(state)
                self.assertEqual(len(mock.call_args_list), 1)
                posargs, kwargs = mock.call_args_list[0]
                self.assertIn(env, posargs[1])
                self.assertEqual(
                    posargs[1][env],
                    'test' if arg != 'with_proposed' else '1')
            # The extra_ppas argument is actually a list, so it needs a
            # separate test-case.
            outputtoinput = {
                'foo/bar': ['foo/bar'],
                'foo/bar foo/baz': ['foo/bar', 'foo/baz'],
            }
            for outputarg, inputarg in outputtoinput.items():
                kwargs = dict(kwargs_skel)
                kwargs['extra_ppas'] = inputarg
                args = SimpleNamespace(**kwargs)
                # Jump right to the method under test.
                state = resources.enter_context(XXXClassicBuilder(args))
                state._next.pop()
                state._next.append(state.prepare_image)
                mock = resources.enter_context(patch(
                    'ubuntu_image.classic_builder.live_build'))
                next(state)
                self.assertEqual(len(mock.call_args_list), 1)
                posargs, kwargs = mock.call_args_list[0]
                self.assertIn('EXTRA_PPAS', posargs[1])
                self.assertEqual(posargs[1]['EXTRA_PPAS'], outputarg)

    def test_filesystem_no_live_build_call(self):
        with ExitStack() as resources:
            argstoenv = {
                'project': 'PROJECT',
                'suite': 'SUITE',
                'arch': 'ARCH',
                'subproject': 'SUBPROJECT',
                'subarch': 'SUBARCH',
                'with_proposed': 'PROPOSED',
                'extra_ppas': 'EXTRA_PPAS',
                }
            kwargs_skel = {
                'workdir': '/tmp',
                'output_dir': '/tmp',
                'hooks_directory': '/tmp',
                'output': None,
                'cloud_init': None,
                'gadget_tree': None,
                'unpackdir': None,
                'debug': None,
                'project': None,
                'suite': None,
                'arch': None,
                'subproject': None,
                'subarch': None,
                'with_proposed': None,
                'extra_ppas': None,
                'filesystem': '/tmp/fs',
                }
            for arg, env in argstoenv.items():
                kwargs = dict(kwargs_skel)
                kwargs[arg] = 'test'
                args = SimpleNamespace(**kwargs)
                # Jump right to the method under test.
                state = resources.enter_context(XXXClassicBuilder(args))
                state._next.pop()
                state._next.append(state.prepare_image)
                mock = resources.enter_context(patch(
                    'ubuntu_image.classic_builder.live_build'))
                next(state)
                self.assertEqual(len(mock.call_args_list), 0)

    def test_generate_manifests_exclude(self):
        # This is not a full test of the manifest generation process, as that
        # requires more preparation.  Here we check that excluded packages
        # (ubiquity, casper) are filtered out of the manifest.
        with ExitStack() as resources:
            workdir = resources.enter_context(TemporaryDirectory())
            unpackdir = resources.enter_context(TemporaryDirectory())
            outputdir = resources.enter_context(TemporaryDirectory())
            # Fast forward a state machine to the method under test.
            args = SimpleNamespace(
                project='ubuntu-cpc',
                suite='xenial',
                arch='amd64',
                image_format='img',
                unpackdir=unpackdir,
                workdir=workdir,
                debug=True,
                cloud_init=None,
                output=None,
                subproject=None,
                subarch=None,
                output_dir=outputdir,
                with_proposed=None,
                extra_ppas=None,
                hooks_directory=[],
                gadget_tree=self.gadget_tree,
                filesystem=None,
                )
            # Jump right to the method under test.
            state = resources.enter_context(XXXClassicBuilder(args))
            state._next.pop()
            state._next.append(state.generate_manifests)
            # Set up expected state.
            state.rootfs = os.path.join(workdir, 'root')
            test_output = dedent("""\
                                 foo 1.1
                                 bar 3.12.3-0ubuntu1
                                 ubiquity 17.10.8
                                 baz 2.3
                                 casper 1.384
                                 """)

            def run_script(command, *, check=True, **kwargs):
                stdout = kwargs.pop('stdout', PIPE)
                stdout.write(test_output)
                stdout.flush()
            resources.enter_context(patch(
                'ubuntu_image.classic_builder.run',
                side_effect=run_script))
            next(state)
            manifest_path = os.path.join(outputdir, 'filesystem.manifest')
            self.assertTrue(os.path.exists(manifest_path))
            with open(manifest_path) as f:
                self.assertEqual(
                    f.read(),
                    dedent("""\
                           foo 1.1
                           bar 3.12.3-0ubuntu1
                           baz 2.3
                           """))

class SensorEvaluationContext:
    """Sensor execution context.

    An instance of this class is made available as the first argument to the evaluation function
    on SensorDefinition.

    Attributes:
        instance_ref (Optional[InstanceRef]): The serialized instance configured to run the sensor
        cursor (Optional[str]): The cursor, passed back from the last sensor evaluation via
            the cursor attribute of SkipReason and RunRequest
        last_completion_time (float): DEPRECATED The last time that the sensor was evaluated (UTC).
        last_run_key (str): DEPRECATED The run key of the RunRequest most recently created by this
            sensor. Use the preferred `cursor` attribute instead.
        repository_name (Optional[str]): The name of the repository that the sensor belongs to.
        instance (Optional[DagsterInstance]): The deserialized instance can also be passed in
            directly (primarily useful in testing contexts).
    """
    def __init__(
        self,
        instance_ref: Optional[InstanceRef],
        last_completion_time: Optional[float],
        last_run_key: Optional[str],
        cursor: Optional[str],
        repository_name: Optional[str],
        instance: Optional[DagsterInstance] = None,
    ):
        self._exit_stack = ExitStack()
        self._instance_ref = check.opt_inst_param(instance_ref, "instance_ref",
                                                  InstanceRef)
        self._last_completion_time = check.opt_float_param(
            last_completion_time, "last_completion_time")
        self._last_run_key = check.opt_str_param(last_run_key, "last_run_key")
        self._cursor = check.opt_str_param(cursor, "cursor")
        self._repository_name = check.opt_str_param(repository_name,
                                                    "repository_name")
        self._instance = check.opt_inst_param(instance, "instance",
                                              DagsterInstance)

    def __enter__(self):
        return self

    def __exit__(self, _exception_type, _exception_value, _traceback):
        self._exit_stack.close()

    @property
    def instance(self) -> DagsterInstance:
        # self._instance_ref should only ever be None when this SensorEvaluationContext was
        # constructed under test.
        if not self._instance:
            if not self._instance_ref:
                raise DagsterInvariantViolationError(
                    "Attempted to initialize dagster instance, but no instance reference was provided."
                )
            self._instance = self._exit_stack.enter_context(
                DagsterInstance.from_ref(self._instance_ref))
        return cast(DagsterInstance, self._instance)

    @property
    def last_completion_time(self) -> Optional[float]:
        return self._last_completion_time

    @property
    def last_run_key(self) -> Optional[str]:
        return self._last_run_key

    @property
    def cursor(self) -> Optional[str]:
        """The cursor value for this sensor, which was set in an earlier sensor evaluation."""
        return self._cursor

    def update_cursor(self, cursor: Optional[str]) -> None:
        """Updates the cursor value for this sensor, which will be provided on the context for the
        next sensor evaluation.

        This can be used to keep track of progress and avoid duplicate work across sensor
        evaluations.

        Args:
            cursor (Optional[str]): The new cursor value to store, or None to clear it.
        """
        self._cursor = check.opt_str_param(cursor, "cursor")

    @property
    def repository_name(self) -> Optional[str]:
        return self._repository_name
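
A minimal usage sketch (not from the source above): a sensor evaluation
function can read the cursor to skip work it has already done and store
the new position with update_cursor(). Here find_new_entries() and
RunRequest are hypothetical stand-ins.

def my_sensor_fn(context: SensorEvaluationContext):
    # None on the very first evaluation.
    last_seen = context.cursor
    for key in find_new_entries(after=last_seen):  # hypothetical helper
        yield RunRequest(run_key=key)              # hypothetical request type
        # Persist progress so the next evaluation starts after this key.
        context.update_cursor(key)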
Exemple #41
    class OpenAIRemoteEnv(gym.Env):
        '''Base class for remote OpenAI gym compatible environments.

        By inheriting from this class you can provide almost all of the
        code necessary to register a remote Blender environment with
        OpenAI gym.

        See the `examples/control/cartpole_gym` for details.

        Params
        ------
        version : str
            Version of this environment.
        '''
        metadata = {'render.modes': ['rgb_array', 'human']}

        def __init__(self,  version='0.0.1'):
            self.__version__ = version
            self._es = ExitStack()
            self._env = None

        def launch(self, scene, script, background=False, **kwargs):
            '''Launch the remote environment.

            Params
            ------
            scene: path, str
                Blender scene file
            script: path, str
                Python script containing environment implementation.
            background: bool
                Whether or not this environment can run in Blender background mode.
            kwargs: dict
                Any keyword arguments passed as command-line arguments
                to the remote environment. See `btt.env.launch_env` for
                details.
            '''
            assert not self._env, 'Environment already running.'
            self._env = self._es.enter_context(
                launch_env(
                    scene=scene,
                    script=script,
                    background=background,
                    **kwargs
                )
            )

        def step(self, action):
            '''Run one timestep of the environment's dynamics. When end of
            episode is reached, you are responsible for calling `reset()`
            to reset this environment's state.

            Accepts an action and returns a tuple (observation, reward, done, info).
            Note, this method's documentation is a 1:1 copy of OpenAI `gym.Env`.

            Params
            ------
            action: object
                An action provided by the agent

            Returns
            -------
            observation: object
                Agent's observation of the current environment
            reward: float
                Amount of reward returned after previous action
            done: bool
                Whether the episode has ended, in which case further step() calls will return undefined results
            info: dict
                Contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)
            '''
            assert self._env, 'Environment not running.'
            obs, reward, done, info = self._env.step(action)
            return obs, reward, done, info

        def reset(self):
            '''Resets the state of the environment and returns an initial observation.

            Note, this method's documentation is a 1:1 copy of OpenAI `gym.Env`.

            Returns
            -------
            observation: object
                The initial observation.
            '''
            assert self._env, 'Environment not running.'
            obs, info = self._env.reset()
            return obs

        def seed(self, seed):
            '''Sets the seed for this env's random number generator(s).'''
            raise NotImplementedError()

        def render(self, mode='human'):
            '''Renders the environment.

            Note, we consider Blender itself the main vehicle to view
            and manipulate the current environment state. Calling
            this method will usually render a specific camera view 
            in Blender, transmit its image and visualize it. This will
            only work if the remote environment supports such an operation.
            '''
            assert self._env, 'Environment not running.'
            return self._env.render(mode=mode)

        @property
        def env_time(self):
            '''Returns the remote environment time.'''
            return self._env.env_time

        def close(self):
            '''Close the environment.'''
            if self._es:
                self._es.close()
                self._es = None
                self._env = None

        def __del__(self):
            self.close()
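
A usage sketch for the class above; the scene and script paths are
illustrative (compare the examples/control/cartpole_gym reference in the
docstring), and the action value is a placeholder since the action format
depends on the concrete environment.

env = OpenAIRemoteEnv(version='0.0.1')
env.launch(scene='cartpole.blend', script='cartpole_env.py', background=True)
obs = env.reset()
for _ in range(100):
    obs, reward, done, info = env.step(0)  # placeholder action
    if done:
        obs = env.reset()
env.close()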
Exemple #42
from contextlib import ExitStack
from typing_extensions import assert_type


# See issue #7961
class Thing(ExitStack):
    pass


stack = ExitStack()
thing = Thing()
assert_type(stack.enter_context(Thing()), Thing)
assert_type(thing.enter_context(ExitStack()), ExitStack)

with stack as cm:
    assert_type(cm, ExitStack)
with thing as cm2:
    assert_type(cm2, Thing)
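
These assertions only type-check if the stubs keep ExitStack generic:
enter_context() returns whatever the passed context manager's __enter__
produces, and ExitStack.__enter__ returns Self, so a subclass such as
Thing keeps its precise type through both paths.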
Exemple #43
class AbstractWebcamFilterApp(ABC):
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.bodypix_model = None
        self.output_sink = None
        self.image_source = None
        self.image_iterator = None
        self.timer = LoggingTimer()
        self.masks: List[np.ndarray] = []
        self.exit_stack = ExitStack()
        self.bodypix_result_cache_time = None
        self.bodypix_result_cache = None

    @abstractmethod
    def get_output_image(self, image_array: np.ndarray) -> np.ndarray:
        pass

    def get_mask(self, *args, **kwargs):
        return get_mask(
            *args, masks=self.masks, timer=self.timer, args=self.args, **kwargs
        )

    def get_bodypix_result(self, image_array: np.ndarray) -> BodyPixResultWrapper:
        assert self.bodypix_model is not None
        current_time = time()
        if (
            self.bodypix_result_cache is not None
            and current_time < self.bodypix_result_cache_time + self.args.mask_cache_time
        ):
            return self.bodypix_result_cache
        self.bodypix_result_cache = self.bodypix_model.predict_single(image_array)
        self.bodypix_result_cache_time = current_time
        return self.bodypix_result_cache

    def __enter__(self):
        self.exit_stack.__enter__()
        self.bodypix_model = load_bodypix_model(self.args)
        self.output_sink = self.exit_stack.enter_context(get_output_sink(self.args))
        self.image_source = self.exit_stack.enter_context(get_image_source_for_args(self.args))
        self.image_iterator = iter(self.image_source)
        return self

    def __exit__(self, *args, **kwargs):
        self.exit_stack.__exit__(*args, **kwargs)

    def next_frame(self):
        self.timer.on_frame_start(initial_step_name='in')
        try:
            image_array = next(self.image_iterator)
        except StopIteration:
            return False
        self.timer.on_step_start('model')
        output_image = self.get_output_image(image_array)
        self.timer.on_step_start('out')
        self.output_sink(output_image)
        self.timer.on_frame_end()
        return True

    def run(self):
        try:
            self.timer.start()
            while self.next_frame():
                pass
            if self.args.show_output:
                LOGGER.info('waiting for window to be closed')
                while not self.output_sink.is_closed:
                    sleep(0.5)
        except KeyboardInterrupt:
            LOGGER.info('exiting')
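
A subclass sketch showing the single abstract hook that must be provided;
the background-dimming logic and the get_mask() call form are assumptions,
not taken from the code above.

import numpy as np

class DimBackgroundApp(AbstractWebcamFilterApp):
    def get_output_image(self, image_array: np.ndarray) -> np.ndarray:
        result = self.get_bodypix_result(image_array)
        # Assumed call form; get_mask(*args, **kwargs) above forwards to a
        # module-level helper whose exact signature is not shown.
        mask = self.get_mask(result)
        dimmed = image_array // 2  # stand-in for a real background filter
        # Assumes mask broadcasts against the image (e.g. HxWx1 vs HxWx3).
        return np.where(mask, image_array, dimmed)

# Driven through the base class's context manager:
# with DimBackgroundApp(args) as app:
#     app.run()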
Exemple #44
class ServiceManager:
    def __init__(self, service_cls, args):
        self.logger = set_logger(self.__class__.__name__, args.verbose)

        self.services = []  # type: List['BaseService']
        if args.num_parallel > 1:
            from .router import RouterService
            _head_router = copy.deepcopy(args)
            _head_router.port_ctrl = self._get_random_port()
            port_out = self._get_random_port()
            _head_router.port_out = port_out

            _tail_router = copy.deepcopy(args)
            port_in = self._get_random_port()
            _tail_router.port_in = port_in
            _tail_router.port_ctrl = self._get_random_port()

            _tail_router.socket_in = SocketType.PULL_BIND

            if args.parallel_type.is_push:
                _head_router.socket_out = SocketType.PUSH_BIND
            else:
                _head_router.socket_out = SocketType.PUB_BIND
                _head_router.yaml_path = resolve_yaml_path(
                    '!PublishRouter {parameters: {num_part: %d}}' % args.num_parallel)

            if args.parallel_type.is_block:
                _tail_router.yaml_path = resolve_yaml_path('BaseReduceRouter')
                _tail_router.num_part = args.num_parallel

            self.services.append(RouterService(_head_router))
            self.services.append(RouterService(_tail_router))

            for _ in range(args.num_parallel):
                _args = copy.deepcopy(args)
                _args.port_in = port_out
                _args.port_out = port_in
                _args.port_ctrl = self._get_random_port()
                _args.socket_out = SocketType.PUSH_CONNECT
                if args.parallel_type.is_push:
                    _args.socket_in = SocketType.PULL_CONNECT
                else:
                    _args.socket_in = SocketType.SUB_CONNECT
                self.services.append(service_cls(_args))
            self.logger.info('num_parallel=%d, add a router with port_in=%d and a router with port_out=%d' % (
                args.num_parallel, _head_router.port_in, _tail_router.port_out))
        else:
            self.services.append(service_cls(args))

    @staticmethod
    def _get_random_port(min_port: int = 49152, max_port: int = 65536) -> int:
        return random.randrange(min_port, max_port)

    def __enter__(self):
        self.stack = ExitStack()
        for s in self.services:
            self.stack.enter_context(s)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stack.close()

    def join(self):
        for s in self.services:
            s.join()
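
Because __enter__ above pushes every service onto a single ExitStack, one
with block starts them all and tears them all down together. A usage
sketch, where MyService is an illustrative BaseService subclass:

with ServiceManager(MyService, args) as manager:
    manager.join()  # block until every service has finished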
Exemple #45
class TestMainWithGadget(TestCase):
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        # Capture builtin print() output.
        self._stdout = StringIO()
        self._stderr = StringIO()
        self._resources.enter_context(
            patch('argparse._sys.stdout', self._stdout))
        # Capture stderr since this is where argparse will spew to.
        self._resources.enter_context(
            patch('argparse._sys.stderr', self._stderr))
        # Set up a few other useful things for these tests.
        self._resources.enter_context(
            patch('ubuntu_image.__main__.logging.basicConfig'))
        self.model_assertion = resource_filename(
            'ubuntu_image.tests.data', 'model.assertion')
        self.classic_gadget_tree = resource_filename(
            'ubuntu_image.tests.data', 'gadget_tree')

    def test_output_without_subcommand(self):
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(tmpdir, 'my-disk.img')
        self.assertFalse(os.path.exists(imgfile))
        main(('--output', imgfile, self.model_assertion))
        self.assertTrue(os.path.exists(imgfile))

    def test_output(self):
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(tmpdir, 'my-disk.img')
        self.assertFalse(os.path.exists(imgfile))
        main(('snap', '--output', imgfile, self.model_assertion))
        self.assertTrue(os.path.exists(imgfile))

    def test_output_directory(self):
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        outputdir = os.path.join(tmpdir, 'images')
        main(('snap', '--output-dir', outputdir, self.model_assertion))
        self.assertTrue(os.path.exists(os.path.join(outputdir, 'pc.img')))

    def test_output_directory_multiple_images(self):
        class Builder(DoNothingBuilder):
            gadget_yaml = 'gadget-multi.yaml'
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            Builder))
        # Quiet the test suite.
        self._resources.enter_context(patch(
            'ubuntu_image.parser._logger.warning'))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        outputdir = os.path.join(tmpdir, 'images')
        main(('snap', '-O', outputdir, self.model_assertion))
        for name in ('first', 'second', 'third', 'fourth'):
            image_path = os.path.join(outputdir, '{}.img'.format(name))
            self.assertTrue(os.path.exists(image_path))

    def test_output_directory_multiple_images_image_file_list(self):
        class Builder(DoNothingBuilder):
            gadget_yaml = 'gadget-multi.yaml'
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            Builder))
        # Quiet the test suite.
        self._resources.enter_context(patch(
            'ubuntu_image.parser._logger.warning'))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        outputdir = os.path.join(tmpdir, 'images')
        image_file_list = os.path.join(tmpdir, 'ifl.txt')
        main(('snap', '-O', outputdir,
              '--image-file-list', image_file_list,
              self.model_assertion))
        with open(image_file_list, 'r', encoding='utf-8') as fp:
            img_files = set(line.rstrip() for line in fp.readlines())
        self.assertEqual(
            img_files,
            set(os.path.join(outputdir, '{}.img'.format(filename))
                for filename in ('first', 'second', 'third', 'fourth'))
            )

    def test_output_image_file_list(self):
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        # Quiet the test suite.
        self._resources.enter_context(patch(
            'ubuntu_image.parser._logger.warning'))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        output = os.path.join(tmpdir, 'pc.img')
        image_file_list = os.path.join(tmpdir, 'ifl.txt')
        main(('snap', '-o', output,
              '--image-file-list', image_file_list,
              self.model_assertion))
        with open(image_file_list, 'r', encoding='utf-8') as fp:
            img_files = set(line.rstrip() for line in fp.readlines())
        self.assertEqual(img_files, {output})

    def test_tmp_okay_for_classic_snap(self):
        # For reference see:
        # http://snapcraft.io/docs/reference/env
        self._resources.enter_context(envar('SNAP_NAME', 'crack-pop'))
        self._resources.enter_context(chdir('/tmp'))
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        code = main(('snap', '--output-dir', '/tmp/images',
                     '--extra-snaps', '/tmp/extra.snap',
                     '/tmp/model.assertion'))
        self.assertEqual(code, 0)
        self.assertTrue(os.path.exists('/tmp/images/pc.img'))

    def test_resume_and_model_assertion(self):
        with self.assertRaises(SystemExit) as cm:
            main(('snap', '--resume', self.model_assertion))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_and_model_assertion_without_subcommand(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--resume', self.model_assertion))
        self.assertEqual(cm.exception.code, 2)

    def test_no_resume_and_no_model_assertion(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--until', 'whatever'))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_without_workdir(self):
        with self.assertRaises(SystemExit) as cm:
            main(('snap', '--resume'))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_without_workdir_without_subcommand(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--resume',))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_and_gadget_tree(self):
        with self.assertRaises(SystemExit) as cm:
            main(('classic', '--resume', self.classic_gadget_tree))
        self.assertEqual(cm.exception.code, 2)

    def test_no_resume_and_no_gadget_tree(self):
        with self.assertRaises(SystemExit) as cm:
            main(('classic', '--until', 'whatever'))
        self.assertEqual(cm.exception.code, 2)

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_save_resume(self):
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            XXXModelAssertionBuilder))
        workdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(workdir, 'my-disk.img')
        main(('--until', 'prepare_filesystems',
              '--channel', 'edge',
              '--workdir', workdir,
              '--output', imgfile,
              self.model_assertion))
        self.assertTrue(os.path.exists(os.path.join(
            workdir, '.ubuntu-image.pck')))
        self.assertFalse(os.path.exists(imgfile))
        main(('snap', '--resume', '--workdir', workdir))
        self.assertTrue(os.path.exists(imgfile))

    def test_until(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        main(('snap', '--until', 'populate_rootfs_contents',
              '--channel', 'edge',
              '--workdir', workdir,
              self.model_assertion))
        # The pickle file will tell us how far the state machine got.
        with open(os.path.join(workdir, '.ubuntu-image.pck'), 'rb') as fp:
            pickle_state = load(fp).__getstate__()
        # This is the *next* state to execute.
        self.assertEqual(pickle_state['state'], ['populate_rootfs_contents'])

    def test_thru(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        main(('snap', '--thru', 'populate_rootfs_contents',
              '--workdir', workdir,
              '--channel', 'edge',
              self.model_assertion))
        # The pickle file will tell us how far the state machine got.
        with open(os.path.join(workdir, '.ubuntu-image.pck'), 'rb') as fp:
            pickle_state = load(fp).__getstate__()
        # This is the *next* state to execute.
        self.assertEqual(
            pickle_state['state'], ['populate_rootfs_contents_hooks'])

    def test_resume_loads_pickle_snap(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            EarlyExitLeaveATraceAssertionBuilder))
        main(('snap', '--until', 'prepare_image',
              '--workdir', workdir,
              self.model_assertion))
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        main(('--workdir', workdir, '--resume'))
        self.assertTrue(os.path.exists(os.path.join(workdir, 'success')))

    def test_resume_loads_pickle_classic(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ClassicBuilder',
            EarlyExitLeaveATraceClassicBuilder))
        self._resources.enter_context(
            patch('ubuntu_image.classic_builder.check_root_privilege'))
        main(('classic', '--until', 'prepare_image',
              '--workdir', workdir,
              '--project', 'ubuntu-cpc',
              self.classic_gadget_tree))
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        main(('--workdir', workdir, '--resume'))
        self.assertTrue(os.path.exists(os.path.join(workdir, 'success')))

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_does_not_fit(self):
        # The contents of a structure is too large for the image size.
        workdir = self._resources.enter_context(TemporaryDirectory())
        # See LP: #1666580
        main(('snap', '--workdir', workdir,
              '--thru', 'load_gadget_yaml',
              self.model_assertion))
        # Make the gadget's mbr contents too big.
        path = os.path.join(workdir, 'unpack', 'gadget', 'pc-boot.img')
        os.truncate(path, 512)
        mock = self._resources.enter_context(patch(
            'ubuntu_image.__main__._logger.error'))
        code = main(('snap', '--workdir', workdir, '--resume'))
        self.assertEqual(code, 1)
        self.assertEqual(
            mock.call_args_list[-1],
            call('Volume contents do not fit (72B over): '
                 'volumes:<pc>:structure:<mbr> [#0]'))

    def test_classic_not_privileged(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ClassicBuilder',
            EarlyExitLeaveATraceClassicBuilder))
        self._resources.enter_context(
            patch('os.geteuid', return_value=1))
        self._resources.enter_context(
            patch('pwd.getpwuid', return_value=['test']))
        mock = self._resources.enter_context(patch(
            'ubuntu_image.__main__._logger.error'))
        code = main(('classic', '--workdir', workdir,
                     '--project', 'ubuntu-cpc',
                     self.classic_gadget_tree))
        self.assertEqual(code, 1)
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        self.assertEqual(
            mock.call_args_list[-1],
            call('Current user(test) does not have root privilege to build '
                 'classic image. Please run ubuntu-image as root.'))

    def test_classic_cross_build_no_static(self):
        # We need to check that a DependencyError is raised when
        # find_executable does not find the qemu-<ARCH>-static binary in
        # PATH (and no path env is set)
        workdir = self._resources.enter_context(TemporaryDirectory())
        livecd_rootfs = self._resources.enter_context(TemporaryDirectory())
        auto = os.path.join(livecd_rootfs, 'auto')
        os.mkdir(auto)
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ClassicBuilder',
            CallLBLeaveATraceClassicBuilder))
        self._resources.enter_context(
            envar('UBUNTU_IMAGE_LIVECD_ROOTFS_AUTO_PATH', auto))
        self._resources.enter_context(
            patch('ubuntu_image.helpers.run', return_value=None))
        self._resources.enter_context(
            patch('ubuntu_image.helpers.find_executable', return_value=None))
        self._resources.enter_context(
            patch('ubuntu_image.helpers.get_host_arch',
                  return_value='amd64'))
        self._resources.enter_context(
            patch('ubuntu_image.__main__.get_host_distro',
                  return_value='bionic'))
        self._resources.enter_context(
            patch('ubuntu_image.classic_builder.check_root_privilege',
                  return_value=None))
        mock = self._resources.enter_context(patch(
            'ubuntu_image.__main__._logger.error'))
        code = main(('classic', '--workdir', workdir,
                     '--project', 'ubuntu-cpc', '--arch', 'armhf',
                     self.classic_gadget_tree))
        self.assertEqual(code, 1)
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        self.assertEqual(
            mock.call_args_list[-1],
            call('Required dependency qemu-arm-static seems to be missing. '
                 'Use UBUNTU_IMAGE_QEMU_USER_STATIC_PATH in case of '
                 'non-standard archs or custom paths.'))

    def test_hook_fired(self):
        # For the purpose of testing, we will be using the post-populate-rootfs
        # hook as we made sure it's still executed as part of the
        # DoNothingBuilder.
        hookdir = self._resources.enter_context(TemporaryDirectory())
        hookfile = os.path.join(hookdir, 'post-populate-rootfs')
        # Let's make sure that, with the use of post-populate-rootfs, we can
        # modify the rootfs contents.
        with open(hookfile, 'w') as fp:
            fp.write("""\
#!/bin/sh
echo "[MAGIC_STRING_FOR_U-I_HOOKS]" > $UBUNTU_IMAGE_HOOK_ROOTFS/foo
""")
        os.chmod(hookfile, 0o744)
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        code = main(('--hooks-directory', hookdir,
                     '--workdir', workdir,
                     '--output-dir', workdir,
                     self.model_assertion))
        self.assertEqual(code, 0)
        self.assertTrue(os.path.exists(os.path.join(workdir, 'root', 'foo')))
        imagefile = os.path.join(workdir, 'pc.img')
        self.assertTrue(os.path.exists(imagefile))
        # Map the image and grep through it to see if our hook change actually
        # landed in the final image.
        with open(imagefile, 'r+b') as fp:
            m = self._resources.enter_context(mmap(fp.fileno(), 0))
            self.assertGreaterEqual(m.find(b'[MAGIC_STRING_FOR_U-I_HOOKS]'), 0)

    def test_hook_error(self):
        # For the purpose of testing, we will be using the post-populate-rootfs
        # hook as we made sure it's still executed as part of the
        # DoNothingBuilder.
        hookdir = self._resources.enter_context(TemporaryDirectory())
        hookfile = os.path.join(hookdir, 'post-populate-rootfs')
        with open(hookfile, 'w') as fp:
            fp.write("""\
#!/bin/sh
echo -n "Failed" 1>&2
exit 1
""")
        os.chmod(hookfile, 0o744)
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        mock = self._resources.enter_context(patch(
            'ubuntu_image.__main__._logger.error'))
        code = main(('--hooks-directory', hookdir, self.model_assertion))
        self.assertEqual(code, 1)
        self.assertEqual(
            mock.call_args_list[-1],
            call('Hook script in path {} failed for the post-populate-rootfs '
                 'hook with return code 1. Output of stderr:\nFailed'.format(
                    hookfile)))

    def test_hook_fired_after_resume(self):
        # For the purpose of testing, we will be using the post-populate-rootfs
        # hook as we made sure it's still executed as part of the
        # DoNothingBuilder.
        hookdir = self._resources.enter_context(TemporaryDirectory())
        hookfile = os.path.join(hookdir, 'post-populate-rootfs')
        with open(hookfile, 'w') as fp:
            fp.write("""\
#!/bin/sh
touch {}/success
""".format(hookdir))
        os.chmod(hookfile, 0o744)
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            DoNothingBuilder))
        main(('--until', 'prepare_image',
              '--hooks-directory', hookdir,
              '--workdir', workdir,
              self.model_assertion))
        self.assertFalse(os.path.exists(os.path.join(hookdir, 'success')))
        # Check if after a resume the hook path is still correct and the hooks
        # are fired as expected.
        code = main(('--workdir', workdir, '--resume'))
        self.assertEqual(code, 0)
        self.assertTrue(os.path.exists(os.path.join(hookdir, 'success')))

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_hook_official_support(self):
        # This test is responsible for checking if all the officially declared
        # hooks are called as intended, making sure none get dropped by
        # accident.
        self._resources.enter_context(patch(
            'ubuntu_image.__main__.ModelAssertionBuilder',
            XXXModelAssertionBuilder))
        fire_mock = self._resources.enter_context(patch(
            'ubuntu_image.hooks.HookManager.fire'))
        code = main(('--channel', 'edge', self.model_assertion))
        self.assertEqual(code, 0)
        called_hooks = [x[0][0] for x in fire_mock.call_args_list]
        self.assertListEqual(called_hooks, supported_hooks)
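
These tests share one assertion idiom: patch the module's logger, run main(), and compare the last recorded call against an expected call(...). A minimal, self-contained sketch of that idiom (fake_main and its message are illustrative, not ubuntu-image code):

import logging
from unittest.mock import patch, call

_logger = logging.getLogger('example')

def fake_main():
    # Stand-in for main(): log an error and return a failure code.
    _logger.error('Current user(%s) lacks privileges', 'test')
    return 1

with patch.object(_logger, 'error') as mock_error:
    code = fake_main()

assert code == 1
# call_args_list[-1] is the last call made on the mock.
assert mock_error.call_args_list[-1] == call(
    'Current user(%s) lacks privileges', 'test')
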
Exemple #46
class NosePlugin(Plugin):
    configSection = 'ubuntu-image'
    snap_mocker = None

    def __init__(self):
        super().__init__()
        self.patterns = []
        self.addArgument(self.patterns, 'P', 'pattern',
                         'Add a test matching pattern')

    def getTestCaseNames(self, event):
        if len(self.patterns) == 0:
            # No filter patterns, so everything should be tested.
            return
        # Does the pattern match the fully qualified class name?
        for pattern in self.patterns:
            full_class_name = '{}.{}'.format(event.testCase.__module__,
                                             event.testCase.__name__)
            if re.search(pattern, full_class_name):
                # Don't suppress this test class.
                return
        names = filter(event.isTestMethod, dir(event.testCase))
        for name in names:
            full_test_name = '{}.{}.{}'.format(event.testCase.__module__,
                                               event.testCase.__name__, name)
            for pattern in self.patterns:
                if re.search(pattern, full_test_name):
                    break
            else:
                event.excludedNames.append(name)

    def handleFile(self, event):
        path = event.path[len(TOPDIR) + 1:]
        if len(self.patterns) > 0:
            for pattern in self.patterns:
                if re.search(pattern, path):
                    break
            else:
                # Skip this doctest.
                return
        base, ext = os.path.splitext(path)
        if ext != '.rst':
            return
        test = doctest.DocFileTest(path,
                                   package='ubuntu_image',
                                   optionflags=FLAGS,
                                   setUp=setup,
                                   tearDown=teardown)
        # Suppress the extra "Doctest: ..." line.
        test.shortDescription = lambda: None
        event.extraTests.append(test)

    def startTestRun(self, event):
        # Create a mock for the `sudo snap prepare-image` command.  This is an
        # expensive command which hits the actual snap store.  We want this to
        # run at least once so we know our tests are valid.  We can cache the
        # results in a test-suite-wide temporary directory and simulate future
        # calls by just recursively copying the contents to the specified
        # directories.
        #
        # It's a bit more complicated than that though, because it's possible
        # that the channel and model.assertion will be different, so we need
        # to make the cache dependent on those values.
        #
        # Finally, to enable full end-to-end tests, check an environment
        # variable to see if the mocking should even be done.  This way, we
        # can make our Travis-CI job do at least one real end-to-end test.
        self.resources = ExitStack()
        # How should we mock `snap prepare-image`?  If set to 'always' (case
        # insensitive), then use the sample data in the .zip file.  Any other
        # truthy value says to use a second-and-onward mock.
        should_we_mock = os.environ.get('UBUNTU_IMAGE_MOCK_SNAP', 'yes')
        if should_we_mock.lower() == 'always':
            mock_class = AlwaysMock
        elif as_bool(should_we_mock):
            mock_class = SecondAndOnwardMock
        else:
            mock_class = None
        if mock_class is not None:
            tmpdir = self.resources.enter_context(TemporaryDirectory())
            # Record the actual snap mocker on the class so that other tests
            # can temporarily disable it.  Some tests need to run the actual
            # snap() helper function.
            self.__class__.snap_mocker = self.resources.enter_context(
                mock_class(tmpdir))

    def stopTestRun(self, event):
        self.resources.close()
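
The start/stop pair above keeps an ExitStack alive for the whole test run rather than for a single test. A stripped-down sketch of the same shape, assuming nothing beyond the standard library (class and method names are hypothetical):

from contextlib import ExitStack
from tempfile import TemporaryDirectory

class RunScopedResources:
    def start(self):
        # Open the stack at run start; everything entered on it lives
        # until stop() is called.
        self.resources = ExitStack()
        self.tmpdir = self.resources.enter_context(TemporaryDirectory())

    def stop(self):
        self.resources.close()

run = RunScopedResources()
run.start()
print(run.tmpdir)  # the directory exists here
run.stop()         # and is removed here
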
Exemple #47
class FlowPod(BasePod):
    """A :class:`FlowPod` is like a :class:`BasePod`, but it exposes more interfaces for tweaking its connections with
    other Pods, which comes in handy when used in the Flow API
    """
    def __init__(self,
                 kwargs: Dict,
                 needs: Set[str] = None,
                 parser: Callable = set_pod_parser):
        """

        :param kwargs: unparsed arguments in a dict
        :param needs: a list of names this BasePod needs to receive messages from
        """
        _parser = parser()
        self.cli_args, self._args, self.unk_args = get_parsed_args(
            kwargs, _parser, 'FlowPod')
        super().__init__(self._args)
        #: used in the :class:`jina.flow.Flow` to build the graph
        self.needs = needs if needs else set()
        self._kwargs = get_non_defaults_args(self._args, _parser)

    def to_cli_command(self):
        if isinstance(self, GatewayPod):
            cmd = 'jina gateway'
        else:
            cmd = 'jina pod'

        return f'{cmd} {" ".join(self.cli_args)}'

    @staticmethod
    def connect(first: 'BasePod', second: 'BasePod',
                first_socket_type: 'SocketType'):
        """Connect two Pods

        :param first: the first BasePod
        :param second: the second BasePod
        :param first_socket_type: socket type of the first BasePod; available types are PUSH_BIND, PUSH_CONNECT and PUB_BIND
        """
        if first_socket_type == SocketType.PUSH_BIND:
            first.tail_args.socket_out = SocketType.PUSH_BIND
            second.head_args.socket_in = SocketType.PULL_CONNECT

            first.tail_args.host_out = __default_host__
            second.head_args.host_in = _fill_in_host(
                bind_args=first.tail_args, connect_args=second.head_args)
            second.head_args.port_in = first.tail_args.port_out
        elif first_socket_type == SocketType.PUSH_CONNECT:
            first.tail_args.socket_out = SocketType.PUSH_CONNECT
            second.head_args.socket_in = SocketType.PULL_BIND

            first.tail_args.host_out = _fill_in_host(
                connect_args=first.tail_args, bind_args=second.head_args)
            second.head_args.host_in = __default_host__
            first.tail_args.port_out = second.head_args.port_in
        elif first_socket_type == SocketType.PUB_BIND:
            first.tail_args.socket_out = SocketType.PUB_BIND
            first.tail_args.num_part += 1
            first.tail_args.yaml_path = '- !!PublishDriver | {num_part: %d}' % first.tail_args.num_part
            second.head_args.socket_in = SocketType.SUB_CONNECT

            first.tail_args.host_out = __default_host__  # bind always get default 0.0.0.0
            second.head_args.host_in = _fill_in_host(
                bind_args=first.tail_args,
                connect_args=second.head_args)  # the hostname of s_pod
            second.head_args.port_in = first.tail_args.port_out
        else:
            raise NotImplementedError(
                f'{first_socket_type!r} is not supported here')

    def connect_to_tail_of(self, pod: 'BasePod'):
        """Eliminate the head node by connecting prev_args node directly to peas """
        if self._args.replicas > 1 and self.is_head_router:
            # keep the port_in and socket_in of prev_args
            # only reset its output
            pod.tail_args = _copy_to_head_args(pod.tail_args,
                                               self._args.polling.is_push,
                                               as_router=False)
            # update peas to receive from it
            self.peas_args['peas'] = _set_peas_args(self._args, pod.tail_args,
                                                    self.tail_args)
            # remove the head node
            self.peas_args['head'] = None
            # head is no longer a router
            self.is_head_router = False
            self.deducted_head = pod.tail_args
        else:
            raise ValueError(
                'the current pod has no head router, so deducting the head is confusing'
            )

    def connect_to_head_of(self, pod: 'BasePod'):
        """Eliminate the tail node by connecting next_args node directly to peas """
        if self._args.replicas > 1 and self.is_tail_router:
            # keep the port_out and socket_out of next_args
            # only reset its input
            pod.head_args = _copy_to_tail_args(pod.head_args, as_router=False)
            # update peas to receive from it
            self.peas_args['peas'] = _set_peas_args(self._args, self.head_args,
                                                    pod.head_args)
            # remove the tail node
            self.peas_args['tail'] = None
            # tail is no longer a router
            self.is_tail_router = False
            self.deducted_tail = pod.head_args
        else:
            raise ValueError(
                'the current pod has no tail router, so deducting the tail is confusing'
            )

    def start(self):
        if self._args.host == __default_host__:
            return super().start()
        else:
            from .remote import RemoteMutablePod
            _remote_pod = RemoteMutablePod(self.peas_args)
            self.stack = ExitStack()
            self.stack.enter_context(_remote_pod)
            self.start_sentinels()
            return self
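
connect() above always pairs one BIND socket with one CONNECT socket and copies the bound side's port across. The same bookkeeping, minus jina, in a runnable sketch (SimpleNamespace stands in for the pods' head/tail argument objects; attribute names mirror the snippet):

from types import SimpleNamespace

def connect_push_bind(first, second):
    # The CONNECT side dials the BIND side and reuses its advertised port.
    first.socket_out = 'PUSH_BIND'
    second.socket_in = 'PULL_CONNECT'
    second.host_in = first.host_out
    second.port_in = first.port_out

a = SimpleNamespace(socket_out=None, host_out='0.0.0.0', port_out=5555)
b = SimpleNamespace(socket_in=None, host_in=None, port_in=None)
connect_push_bind(a, b)
assert b.port_in == 5555
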
Exemple #48
class FilesystemMockingTestCase(ResourceUsingTestCase):

    def setUp(self):
        super(FilesystemMockingTestCase, self).setUp()
        self.patched_funcs = ExitStack()

    def tearDown(self):
        self.patched_funcs.close()
        ResourceUsingTestCase.tearDown(self)

    def replicateTestRoot(self, example_root, target_root):
        real_root = self.resourceLocation()
        real_root = os.path.join(real_root, 'roots', example_root)
        for (dir_path, _dirnames, filenames) in os.walk(real_root):
            real_path = dir_path
            make_path = rebase_path(real_path[len(real_root):], target_root)
            util.ensure_dir(make_path)
            for f in filenames:
                real_path = util.abs_join(real_path, f)
                make_path = util.abs_join(make_path, f)
                shutil.copy(real_path, make_path)

    def patchUtils(self, new_root):
        patch_funcs = {
            util: [('write_file', 1),
                   ('append_file', 1),
                   ('load_file', 1),
                   ('ensure_dir', 1),
                   ('chmod', 1),
                   ('delete_dir_contents', 1),
                   ('del_file', 1),
                   ('sym_link', -1),
                   ('copy', -1)],
        }
        for (mod, funcs) in patch_funcs.items():
            for (f, am) in funcs:
                func = getattr(mod, f)
                trap_func = retarget_many_wrapper(new_root, am, func)
                self.patched_funcs.enter_context(
                    mock.patch.object(mod, f, trap_func))

        # Handle subprocess calls
        func = getattr(util, 'subp')

        def nsubp(*_args, **_kwargs):
            return ('', '')

        self.patched_funcs.enter_context(
            mock.patch.object(util, 'subp', nsubp))

        def null_func(*_args, **_kwargs):
            return None

        for f in ['chownbyid', 'chownbyname']:
            self.patched_funcs.enter_context(
                mock.patch.object(util, f, null_func))

    def patchOS(self, new_root):
        patch_funcs = {
            os.path: [('isfile', 1), ('exists', 1),
                      ('islink', 1), ('isdir', 1)],
            os: [('listdir', 1), ('mkdir', 1),
                 ('lstat', 1), ('symlink', 2)],
        }
        for (mod, funcs) in patch_funcs.items():
            for f, nargs in funcs:
                func = getattr(mod, f)
                trap_func = retarget_many_wrapper(new_root, nargs, func)
                self.patched_funcs.enter_context(
                    mock.patch.object(mod, f, trap_func))

    def patchOpen(self, new_root):
        trap_func = retarget_many_wrapper(new_root, 1, open)
        name = 'builtins.open' if PY3 else '__builtin__.open'
        self.patched_funcs.enter_context(mock.patch(name, trap_func))

    def patchStdoutAndStderr(self, stdout=None, stderr=None):
        if stdout is not None:
            self.patched_funcs.enter_context(
                mock.patch.object(sys, 'stdout', stdout))
        if stderr is not None:
            self.patched_funcs.enter_context(
                mock.patch.object(sys, 'stderr', stderr))

    def reRoot(self, root=None):
        if root is None:
            root = self.tmp_dir()
        self.patchUtils(root)
        self.patchOS(root)
        return root
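
patchUtils and patchOS above funnel many mock.patch.object contexts through one ExitStack, so a single close() undoes the whole fake filesystem. The core of that trick in isolation (the traced wrapper is a trivial stand-in for retarget_many_wrapper):

import os.path
from contextlib import ExitStack
from unittest import mock

def traced(func):
    # Wrap and delegate, just to make the interception visible.
    def wrapper(*args, **kwargs):
        print('calling', func.__name__)
        return func(*args, **kwargs)
    return wrapper

patched = ExitStack()
for name in ('isfile', 'isdir'):
    original = getattr(os.path, name)
    patched.enter_context(mock.patch.object(os.path, name, traced(original)))

os.path.isfile('/etc/hosts')  # prints 'calling isfile', then delegates
patched.close()               # both patches are reverted here
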
Exemple #49
@contextmanager
def early_exit_if(builder, stack: ExitStack, cond):
    # Run the body only in the "then" branch, then leave the "otherwise"
    # branch open on the caller's ExitStack until it unwinds.
    then, otherwise = stack.enter_context(builder.if_else(cond, likely=False))
    with then:
        yield
    stack.enter_context(otherwise)
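
The generator above enters the "then" branch around its yield, then pushes the "otherwise" branch onto the caller's ExitStack, so code after the with-block keeps running in "otherwise" until the stack unwinds. A runnable sketch of that control flow without the LLVM builder (branch() is a hypothetical stand-in for the builder's branch contexts):

from contextlib import ExitStack, contextmanager

@contextmanager
def branch(name):
    print('enter', name)
    yield
    print('exit', name)

@contextmanager
def early_exit_demo(stack, then, otherwise):
    with then:
        yield
    # Entered here, but only exited when the outer ExitStack closes.
    stack.enter_context(otherwise)

with ExitStack() as stack:
    with early_exit_demo(stack, branch('then'), branch('otherwise')):
        print('early-exit body')
    print('this still runs inside "otherwise"')
# leaving the ExitStack finally exits "otherwise"
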
Exemple #50
class GTiffSingleFileOutputWriter(GTiffOutputReaderFunctions,
                                  base.SingleFileOutputWriter):

    write_in_parent_process = True

    def __init__(self, output_params, **kwargs):
        """Initialize."""
        logger.debug("output is single file")
        self.dst = None
        super().__init__(output_params, **kwargs)
        self._set_attributes(output_params)
        if len(self.output_params["delimiters"]["zoom"]) != 1:
            raise ValueError(
                "single file output only works with one zoom level")
        self.zoom = output_params["delimiters"]["zoom"][0]
        self.cog = output_params.get("cog", False)
        if self.cog or "overviews" in output_params:
            self.overviews = True
            self.overviews_resampling = output_params.get(
                "overviews_resampling", "nearest")
            self.overviews_levels = output_params.get(
                "overviews_levels", [2**i for i in range(1, self.zoom + 1)])
        else:
            self.overviews = False
        self.in_memory = output_params.get("in_memory", True)
        _bucket = self.path.split("/")[2] if self.path.startswith(
            "s3://") else None
        self._bucket_resource = get_boto3_bucket(_bucket) if _bucket else None

    def prepare(self, process_area=None, **kwargs):
        bounds = snap_bounds(
            bounds=Bounds(*process_area.intersection(
                box(*self.output_params["delimiters"]
                    ["effective_bounds"])).bounds),
            pyramid=self.pyramid,
            zoom=self.zoom) if process_area else self.output_params[
                "delimiters"]["effective_bounds"]
        height = math.ceil((bounds.top - bounds.bottom) /
                           self.pyramid.pixel_x_size(self.zoom))
        width = math.ceil((bounds.right - bounds.left) /
                          self.pyramid.pixel_x_size(self.zoom))
        logger.debug("output raster bounds: %s", bounds)
        logger.debug("output raster shape: %s, %s", height, width)
        self._profile = dict(
            GTIFF_DEFAULT_PROFILE,
            driver="GTiff",
            transform=Affine(self.pyramid.pixel_x_size(self.zoom), 0,
                             bounds.left, 0,
                             -self.pyramid.pixel_y_size(self.zoom),
                             bounds.top),
            height=height,
            width=width,
            count=self.output_params["bands"],
            crs=self.pyramid.crs,
            **{
                k: self.output_params.get(k, GTIFF_DEFAULT_PROFILE[k])
                for k in GTIFF_DEFAULT_PROFILE.keys()
            },
            bigtiff=self.output_params.get("bigtiff", "NO"))
        logger.debug("single GTiff profile: %s", self._profile)
        self.in_memory = (self.in_memory if self.in_memory is False else
                          height * width < IN_MEMORY_THRESHOLD)
        # set up rasterio
        if path_exists(self.path):
            if self.output_params["mode"] != "overwrite":
                raise MapcheteConfigError(
                    "single GTiff file already exists, use overwrite mode to replace"
                )
            else:
                logger.debug("remove existing file: %s", self.path)
                os.remove(self.path)
        # create output directory if necessary
        makedirs(os.path.dirname(self.path))
        logger.debug("open output file: %s", self.path)
        self._ctx = ExitStack()
        # (1) use memfile if output is remote or COG
        if self.cog or path_is_remote(self.path):
            if self.in_memory:
                self._memfile = self._ctx.enter_context(MemoryFile())
                self.dst = self._ctx.enter_context(
                    self._memfile.open(**self._profile))
            else:
                # in case output raster is too big, use tempfile on disk
                self._tempfile = self._ctx.enter_context(NamedTemporaryFile())
                self.dst = self._ctx.enter_context(
                    rasterio.open(self._tempfile.name, "w+", **self._profile))
        else:
            self.dst = self._ctx.enter_context(
                rasterio.open(self.path, "w+", **self._profile))

    def read(self, output_tile, **kwargs):
        """
        Read existing process output.

        Parameters
        ----------
        output_tile : ``BufferedTile``
            must be member of output ``TilePyramid``

        Returns
        -------
        NumPy array
        """
        return read_raster_window(self.dst, output_tile)

    def get_path(self, tile=None):
        """
        Determine target file path.

        Parameters
        ----------
        tile : ``BufferedTile``
            must be member of output ``TilePyramid``

        Returns
        -------
        path : string
        """
        return self.path

    def tiles_exist(self, process_tile=None, output_tile=None):
        """
        Check whether output tiles of a tile (either process or output) exists.

        Parameters
        ----------
        process_tile : ``BufferedTile``
            must be member of process ``TilePyramid``
        output_tile : ``BufferedTile``
            must be member of output ``TilePyramid``

        Returns
        -------
        exists : bool
        """
        if process_tile and output_tile:
            raise ValueError(
                "just one of 'process_tile' and 'output_tile' allowed")
        if process_tile:
            return any(not self.read(tile).mask.all()
                       for tile in self.pyramid.intersecting(process_tile))
        if output_tile:
            return not self.read(output_tile).mask.all()

    def write(self, process_tile, data):
        """
        Write data from process tiles into GeoTIFF file(s).

        Parameters
        ----------
        process_tile : ``BufferedTile``
            must be member of process ``TilePyramid``
        """
        data = prepare_array(data,
                             masked=True,
                             nodata=self.output_params["nodata"],
                             dtype=self.profile(process_tile)["dtype"])

        if data.mask.all():
            logger.debug("data empty, nothing to write")
        else:
            # Convert from process_tile to output_tiles and write
            for tile in self.pyramid.intersecting(process_tile):
                out_tile = BufferedTile(tile, self.pixelbuffer)
                write_window = from_bounds(
                    *out_tile.bounds,
                    transform=self.dst.transform,
                    height=self.dst.height,
                    width=self.dst.width).round_lengths(
                        pixel_precision=0).round_offsets(pixel_precision=0)
                if _window_in_out_file(write_window, self.dst):
                    logger.debug("write data to window: %s", write_window)
                    self.dst.write(
                        extract_from_array(in_raster=data,
                                           in_affine=process_tile.affine,
                                           out_tile=out_tile)
                        if process_tile != out_tile else data,
                        window=write_window,
                    )

    def profile(self, tile=None):
        """
        Create a metadata dictionary for rasterio.

        Returns
        -------
        metadata : dictionary
            output profile dictionary used for rasterio.
        """
        return self._profile

    def close(self, exc_type=None, exc_value=None, exc_traceback=None):
        """Build overviews and write file."""
        try:
            # only in case no Exception was raised
            if not exc_type:
                # build overviews
                if self.overviews and self.dst is not None:
                    logger.debug(
                        "build overviews using %s resampling and levels %s",
                        self.overviews_resampling, self.overviews_levels)
                    self.dst.build_overviews(
                        self.overviews_levels,
                        Resampling[self.overviews_resampling])
                    self.dst.update_tags(ns='rio_overview',
                                         resampling=self.overviews_resampling)
                # write
                if self.cog:
                    if path_is_remote(self.path):
                        # remote COG: copy to tempfile and upload to destination
                        logger.debug("upload to %s", self.path)
                        # TODO this writes a memoryfile to disk and uploads the file,
                        # this is inefficient but until we find a solution to copy
                        # from one memoryfile to another the rasterio way (rasterio needs
                        # to rearrange the data so the overviews are at the beginning of
                        # the GTiff in order to be a valid COG).
                        with NamedTemporaryFile() as tmp_dst:
                            copy(self.dst,
                                 tmp_dst.name,
                                 copy_src_overviews=True,
                                 **self._profile)
                            self._bucket_resource.upload_file(
                                Filename=tmp_dst.name,
                                Key="/".join(self.path.split("/")[3:]),
                            )
                    else:
                        # local COG: copy to destination
                        logger.debug("write to %s", self.path)
                        copy(self.dst,
                             self.path,
                             copy_src_overviews=True,
                             **self._profile)
                else:
                    if path_is_remote(self.path):
                        # remote GTiff: upload memfile or tempfile to destination
                        logger.debug("upload to %s", self.path)
                        if self.in_memory:
                            self._bucket_resource.put_object(
                                Body=self._memfile,
                                Key="/".join(self.path.split("/")[3:]),
                            )
                        else:
                            self._bucket_resource.upload_file(
                                Filename=self._tempfile.name,
                                Key="/".join(self.path.split("/")[3:]),
                            )
                    else:
                        # local GTiff: already written, do nothing
                        pass

        finally:
            self._ctx.close()
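
prepare() and close() above reduce to one pattern: pick a destination at runtime, enter it on an ExitStack stored on self, and tear everything down in close(). A dependency-free sketch of that shape (io objects stand in for rasterio datasets; all names are illustrative):

import io
from contextlib import ExitStack
from tempfile import NamedTemporaryFile

class SingleFileWriter:
    def __init__(self, in_memory=True):
        self.in_memory = in_memory
        self.dst = None

    def prepare(self):
        # Whichever destination is chosen, the stack owns it.
        self._ctx = ExitStack()
        if self.in_memory:
            self.dst = self._ctx.enter_context(io.BytesIO())
        else:
            self.dst = self._ctx.enter_context(NamedTemporaryFile())

    def write(self, data):
        self.dst.write(data)

    def close(self):
        self._ctx.close()

w = SingleFileWriter(in_memory=False)
w.prepare()
w.write(b'tile data')
w.close()  # closes (and deletes) whichever destination was opened
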
Exemple #51
class TestConfig(TestCase):
    def setUp(self):
        super(TestConfig, self).setUp()
        self.name = "ca-certs"
        distro = self._fetch_distro('ubuntu')
        self.paths = None
        self.cloud = cloud.Cloud(None, self.paths, None, distro, None)
        self.log = logging.getLogger("TestNoConfig")
        self.args = []

        self.mocks = ExitStack()
        self.addCleanup(self.mocks.close)

        # Mock out the functions that actually modify the system
        self.mock_add = self.mocks.enter_context(
            mock.patch.object(cc_ca_certs, 'add_ca_certs'))
        self.mock_update = self.mocks.enter_context(
            mock.patch.object(cc_ca_certs, 'update_ca_certs'))
        self.mock_remove = self.mocks.enter_context(
            mock.patch.object(cc_ca_certs, 'remove_default_ca_certs'))

    def _fetch_distro(self, kind):
        cls = distros.fetch(kind)
        paths = helpers.Paths({})
        return cls(kind, {}, paths)

    def test_no_trusted_list(self):
        """
        Test that no certificates are written if the 'trusted' key is not
        present.
        """
        config = {"ca-certs": {}}

        cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args)

        self.assertEqual(self.mock_add.call_count, 0)
        self.assertEqual(self.mock_update.call_count, 1)
        self.assertEqual(self.mock_remove.call_count, 0)

    def test_empty_trusted_list(self):
        """Test that no certificate are written if 'trusted' list is empty."""
        config = {"ca-certs": {"trusted": []}}

        cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args)

        self.assertEqual(self.mock_add.call_count, 0)
        self.assertEqual(self.mock_update.call_count, 1)
        self.assertEqual(self.mock_remove.call_count, 0)

    def test_single_trusted(self):
        """Test that a single cert gets passed to add_ca_certs."""
        config = {"ca-certs": {"trusted": ["CERT1"]}}

        cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args)

        self.mock_add.assert_called_once_with(['CERT1'])
        self.assertEqual(self.mock_update.call_count, 1)
        self.assertEqual(self.mock_remove.call_count, 0)

    def test_multiple_trusted(self):
        """Test that multiple certs get passed to add_ca_certs."""
        config = {"ca-certs": {"trusted": ["CERT1", "CERT2"]}}

        cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args)

        self.mock_add.assert_called_once_with(['CERT1', 'CERT2'])
        self.assertEqual(self.mock_update.call_count, 1)
        self.assertEqual(self.mock_remove.call_count, 0)

    def test_remove_default_ca_certs(self):
        """Test remove_defaults works as expected."""
        config = {"ca-certs": {"remove-defaults": True}}

        cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args)

        self.assertEqual(self.mock_add.call_count, 0)
        self.assertEqual(self.mock_update.call_count, 1)
        self.assertEqual(self.mock_remove.call_count, 1)

    def test_no_remove_defaults_if_false(self):
        """Test remove_defaults is not called when config value is False."""
        config = {"ca-certs": {"remove-defaults": False}}

        cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args)

        self.assertEqual(self.mock_add.call_count, 0)
        self.assertEqual(self.mock_update.call_count, 1)
        self.assertEqual(self.mock_remove.call_count, 0)

    def test_correct_order_for_remove_then_add(self):
        """Test remove_defaults is not called when config value is False."""
        config = {"ca-certs": {"remove-defaults": True, "trusted": ["CERT1"]}}

        cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args)

        self.mock_add.assert_called_once_with(['CERT1'])
        self.assertEqual(self.mock_update.call_count, 1)
        self.assertEqual(self.mock_remove.call_count, 1)
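
setUp above pairs an ExitStack with addCleanup so every patch is reverted even when a test fails midway. The same pattern in a self-contained test (patching builtins.abs is purely illustrative):

import unittest
from contextlib import ExitStack
from unittest import mock

class ExampleTest(unittest.TestCase):
    def setUp(self):
        self.mocks = ExitStack()
        # close() runs on teardown, success or failure.
        self.addCleanup(self.mocks.close)
        self.mock_abs = self.mocks.enter_context(
            mock.patch('builtins.abs', return_value=42))

    def test_patched(self):
        self.assertEqual(abs(-5), 42)

if __name__ == '__main__':
    unittest.main()
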
Exemple #52
async def run_nodes(
    repo_root_dir: Union[str, pathlib.Path],
    docs_root_dir: Union[str, pathlib.Path],
    stack: contextlib.ExitStack,
    nodes: List[Dict[str, Any]],
    *,
    setup: Optional[Callable[[Dict[str, Any]], Any]] = None,
) -> None:
    # Ensure pathlib objects
    repo_root_dir = pathlib.Path(repo_root_dir).resolve()
    docs_root_dir = pathlib.Path(docs_root_dir).resolve()
    # Create an async exit stack
    async with contextlib.AsyncExitStack() as astack:
        tempdir = stack.enter_context(tempfile.TemporaryDirectory())

        ctx = {
            "root": str(repo_root_dir),
            "docs": str(docs_root_dir),
            "cwd": tempdir,
            "stack": stack,
            "astack": astack,
            "daemons": {},
            # Items in this context that are not serializable
            "no_serialize": {"stack", "astack", "daemons"},
        }

        # Create a virtualenv for every document
        if setup is not None:
            await setup(ctx)

        for node in nodes:  # type: Element
            if node["consoletestnodetype"] == "consoletest-literalinclude":
                lines = node.get("lines", None)
                if lines is not None:
                    lines = tuple(map(int, lines.split("-")))

                # Handle navigating out of the docs_root_dir
                if node["source"].startswith("/.."):
                    node["source"] = node["source"][1:]

                src = os.path.join(str(docs_root_dir), node["source"])
                dst = os.path.join(ctx["cwd"], *node["filepath"])

                print()
                print("Copying", ctx, src, dst, lines)

                copyfile(src, dst, lines=lines)
                print(pathlib.Path(dst).read_text(), end="")
                print()
            elif node["consoletestnodetype"] == "consoletest-file":
                print()
                filepath = pathlib.Path(ctx["cwd"], *node["filepath"])

                if not filepath.parent.is_dir():
                    filepath.parent.mkdir(parents=True)

                if node["overwrite"] and filepath.is_file():
                    print("Overwriting", ctx, filepath)
                    mode = "wt"
                else:
                    print("Writing", ctx, filepath)
                    mode = "at"

                with open(filepath, mode) as outfile:
                    outfile.seek(0, io.SEEK_END)
                    outfile.write("\n".join(node["content"]) + "\n")

                print(filepath.read_text(), end="")
                print()
            elif node["consoletestnodetype"] == "consoletest":
                if node["consoletest_commands_replace"] is not None:
                    for command, new_cmd in zip(
                        node["consoletest_commands"],
                        call_replace(
                            node["consoletest_commands_replace"],
                            list(
                                map(
                                    lambda command: command.cmd
                                    if isinstance(command, ConsoleCommand)
                                    else [],
                                    node["consoletest_commands"],
                                )
                            ),
                            ctx,
                        ),
                    ):
                        if isinstance(command, ConsoleCommand):
                            command.cmd = new_cmd
                for command in node["consoletest_commands"]:
                    print()
                    print("Running", ctx, command)
                    print()
                    await astack.enter_async_context(command)
                    await command.run(ctx)
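
run_nodes drives two stacks at once: the caller's synchronous ExitStack and a local AsyncExitStack for awaitable contexts, so sync resources can outlive the async block. A minimal runnable sketch of that pairing (the Daemon class is illustrative):

import asyncio
import contextlib
import tempfile

class Daemon:
    async def __aenter__(self):
        print('daemon up')
        return self

    async def __aexit__(self, *exc):
        print('daemon down')

async def main(stack: contextlib.ExitStack):
    async with contextlib.AsyncExitStack() as astack:
        tempdir = stack.enter_context(tempfile.TemporaryDirectory())
        await astack.enter_async_context(Daemon())
        print('working in', tempdir)
    # the daemon is down here; tempdir lives until the outer stack closes

with contextlib.ExitStack() as stack:
    asyncio.run(main(stack))
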
Exemple #53
    barriers = random_barriers(range(BARRIERS), world_area)

    empty = RenderEmpty.SPACE if world_size_auto and barriers else RenderEmpty.DOT

    builder = Builder(world_area, population, barriers)
    roster = builder.roster
    renderer = Renderer(roster, barriers, empty=empty)

    ticks = islice(each_interval(TICK), MAX_AGE)

    tracing_context = ExitStack()

    trace_file_name = environ.get("TRACEFILE")
    if trace_file_name:
        trace_file = open(trace_file_name, mode="w")
        tracing_context.enter_context(trace_file)
        tracing.init(trace_file)

    with tracing_context:
        try:
            for _ in ticks:
                clear()
                for line in renderer.lines:
                    print(line)

                with tracing.span("tick"):
                    old_roster, roster = roster, Tick(roster, barriers).next()

                if old_roster == roster:
                    break
                renderer = Renderer(roster, barriers, empty=empty)
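
The fragment above uses the "maybe enter a context" idiom: build an empty ExitStack unconditionally, enter the trace file only when TRACEFILE is set, and use the stack as the with-target either way. Condensed to its essentials:

from contextlib import ExitStack
from os import environ

tracing_context = ExitStack()
trace_file_name = environ.get('TRACEFILE')
if trace_file_name:
    # Only entered when tracing is requested; closed with the stack.
    trace_file = tracing_context.enter_context(
        open(trace_file_name, mode='w'))

with tracing_context:
    pass  # main loop goes here; trace_file (if any) closes on exit
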
Exemple #54
def __create_file_tuple(path: str, es: ExitStack):
    filename = basename(path)
    file_handle = es.enter_context(open(path, 'rb'))
    mime_type = mimetypes.guess_type(path)[0]
    return filename, (filename, file_handle, mime_type)
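
Typical use of the helper above: build a multi-file upload payload while one ExitStack guarantees every opened handle is closed afterwards. A sketch assuming the helper's module (with its basename/mimetypes imports) is available; the HTTP call is only indicated in a comment:

import os
from contextlib import ExitStack
from tempfile import TemporaryDirectory

with ExitStack() as es:
    tmp = es.enter_context(TemporaryDirectory())
    paths = []
    for name in ('a.txt', 'b.txt'):
        p = os.path.join(tmp, name)
        with open(p, 'w') as fp:
            fp.write('hello')
        paths.append(p)
    # Each value is (filename, handle, mime_type), the shape that the
    # `files=` parameter of requests expects.
    files = dict(__create_file_tuple(p, es) for p in paths)
    # e.g. requests.post(url, files=files) would go here
# every handle opened by the helper is closed once the stack exits
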
Exemple #55
class TestAzureDataSource(TestCase):

    def setUp(self):
        super(TestAzureDataSource, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp)

        # patch cloud_dir, so our 'seed_dir' is guaranteed empty
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')

        self.patches = ExitStack()
        self.addCleanup(self.patches.close)

    def apply_patches(self, patches):
        for module, name, new in patches:
            self.patches.enter_context(mock.patch.object(module, name, new))

    def _get_ds(self, data):

        def dsdevs():
            return data.get('dsdevs', [])

        def _invoke_agent(cmd):
            data['agent_invoked'] = cmd

        def _wait_for_files(flist, _maxwait=None, _naplen=None):
            data['waited'] = flist
            return []

        def _pubkeys_from_crt_files(flist):
            data['pubkey_files'] = flist
            return ["pubkey_from: %s" % f for f in flist]

        if data.get('ovfcontent') is not None:
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
                         {'ovf-env.xml': data['ovfcontent']})

        mod = DataSourceAzure
        mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d

        self.get_metadata_from_fabric = mock.MagicMock(return_value={
            'public-keys': [],
        })

        self.instance_id = 'test-instance-id'

        self.apply_patches([
            (mod, 'list_possible_azure_ds_devs', dsdevs),
            (mod, 'invoke_agent', _invoke_agent),
            (mod, 'wait_for_files', _wait_for_files),
            (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files),
            (mod, 'perform_hostname_bounce', mock.MagicMock()),
            (mod, 'get_hostname', mock.MagicMock()),
            (mod, 'set_hostname', mock.MagicMock()),
            (mod, 'get_metadata_from_fabric', self.get_metadata_from_fabric),
            (mod.util, 'read_dmi_data', mock.MagicMock(
                return_value=self.instance_id)),
        ])

        dsrc = mod.DataSourceAzureNet(
            data.get('sys_cfg', {}), distro=None, paths=self.paths)

        return dsrc

    def xml_equals(self, oxml, nxml):
        """Compare two sets of XML to make sure they are equal"""

        def create_tag_index(xml):
            et = ET.fromstring(xml)
            ret = {}
            for x in et.iter():
                ret[x.tag] = x
            return ret

        def tags_exists(x, y):
            for tag in x.keys():
                self.assertIn(tag, y)
            for tag in y.keys():
                self.assertIn(tag, x)

        def tags_equal(x, y):
            for x_tag, x_val in x.items():
                y_val = y.get(x_val.tag)
                self.assertEquals(x_val.text, y_val.text)

        old_cnt = create_tag_index(oxml)
        new_cnt = create_tag_index(nxml)
        tags_exists(old_cnt, new_cnt)
        tags_equal(old_cnt, new_cnt)

    def xml_notequals(self, oxml, nxml):
        try:
            self.xml_equals(oxml, nxml)
        except AssertionError:
            return
        raise AssertionError("XML is the same")

    def test_basic_seed_dir(self):
        odata = {'HostName': "myhost", 'UserName': "******"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
                'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, "")
        self.assertEqual(dsrc.metadata['local-hostname'], odata['HostName'])
        self.assertTrue(os.path.isfile(
            os.path.join(self.waagent_d, 'ovf-env.xml')))

    def test_waagent_d_has_0700_perms(self):
        # we expect /var/lib/waagent to be created 0700
        dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue(os.path.isdir(self.waagent_d))
        self.assertEqual(stat.S_IMODE(os.stat(self.waagent_d).st_mode), 0o700)

    def test_user_cfg_set_agent_command_plain(self):
        # set dscfg in via plaintext
        # we must have friendly-to-xml formatted plaintext in yaml_cfg
        # not all plaintext is expected to work.
        yaml_cfg = "{agent_command: my_command}\n"
        cfg = yaml.safe_load(yaml_cfg)
        odata = {'HostName': "myhost", 'UserName': "******",
                 'dscfg': {'text': yaml_cfg, 'encoding': 'plain'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_user_cfg_set_agent_command(self):
        # set dscfg in via base64 encoded yaml
        cfg = {'agent_command': "my_command"}
        odata = {'HostName': "myhost", 'UserName': "******",
                 'dscfg': {'text': b64e(yaml.dump(cfg)),
                           'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_sys_cfg_set_agent_command(self):
        sys_cfg = {'datasource': {'Azure': {'agent_command': '_COMMAND'}}}
        data = {'ovfcontent': construct_valid_ovf_env(data={}),
                'sys_cfg': sys_cfg}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], '_COMMAND')

    def test_username_used(self):
        odata = {'HostName': "myhost", 'UserName': "******"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.cfg['system_info']['default_user']['name'],
                         "myuser")

    def test_password_given(self):
        odata = {'HostName': "myhost", 'UserName': "******",
                 'UserPassword': "******"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue('default_user' in dsrc.cfg['system_info'])
        defuser = dsrc.cfg['system_info']['default_user']

        # default user should be updated username and should not be locked.
        self.assertEqual(defuser['name'], odata['UserName'])
        self.assertFalse(defuser['lock_passwd'])
        # passwd is a crypt-formatted string: $id$salt$encrypted.
        # Encrypting the plaintext with the salt (everything up to the
        # final '$') should equal everything after the '$'.
        pos = defuser['passwd'].rfind("$") + 1
        self.assertEqual(defuser['passwd'],
                         crypt.crypt(odata['UserPassword'],
                                     defuser['passwd'][0:pos]))

    def test_userdata_plain(self):
        mydata = "FOOBAR"
        odata = {'UserData': {'text': mydata, 'encoding': 'plain'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(decode_binary(dsrc.userdata_raw), mydata)

    def test_userdata_found(self):
        mydata = "FOOBAR"
        odata = {'UserData': {'text': b64e(mydata), 'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8'))

    def test_no_datasource_expected(self):
        # no source should be found if no seed_dir and no devs
        data = {}
        dsrc = self._get_ds({})
        ret = dsrc.get_data()
        self.assertFalse(ret)
        self.assertFalse('agent_invoked' in data)

    def test_cfg_has_pubkeys_fingerprint(self):
        odata = {'HostName': "myhost", 'UserName': "******"}
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}]
        pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist]
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
                                                      pubkeys=pubkeys)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        for mypk in mypklist:
            self.assertIn(mypk, dsrc.cfg['_pubkeys'])
            self.assertIn('pubkey_from', dsrc.metadata['public-keys'][-1])

    def test_cfg_has_pubkeys_value(self):
        # make sure that provided key is used over fingerprint
        odata = {'HostName': "myhost", 'UserName': "******"}
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': 'value1'}]
        pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist]
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
                                                      pubkeys=pubkeys)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)

        for mypk in mypklist:
            self.assertIn(mypk, dsrc.cfg['_pubkeys'])
            self.assertIn(mypk['value'], dsrc.metadata['public-keys'])

    def test_cfg_has_no_fingerprint_has_value(self):
        # test value is used when fingerprint not provided
        odata = {'HostName': "myhost", 'UserName': "******"}
        mypklist = [{'fingerprint': None, 'path': 'path1', 'value': 'value1'}]
        pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist]
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
                                                      pubkeys=pubkeys)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)

        for mypk in mypklist:
            self.assertIn(mypk['value'], dsrc.metadata['public-keys'])

    def test_default_ephemeral(self):
        # make sure the ephemeral device works
        odata = {}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
                'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()

        self.assertEquals(dsrc.device_name_to_device("ephemeral0"),
                          "/dev/sdb")
        assert 'disk_setup' in cfg
        assert 'fs_setup' in cfg
        self.assertIsInstance(cfg['disk_setup'], dict)
        self.assertIsInstance(cfg['fs_setup'], list)

    def test_provide_disk_aliases(self):
        # Make sure that user can affect disk aliases
        dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}}
        odata = {'HostName': "myhost", 'UserName': "******",
                 'dscfg': {'text': b64e(yaml.dump(dscfg)),
                           'encoding': 'base64'}}
        usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'},
                                  'ephemeral0': False}}
        userdata = '#cloud-config' + yaml.dump(usercfg) + "\n"

        ovfcontent = construct_valid_ovf_env(data=odata, userdata=userdata)
        data = {'ovfcontent': ovfcontent, 'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()
        self.assertTrue(cfg)

    def test_userdata_arrives(self):
        userdata = "This is my user-data"
        xml = construct_valid_ovf_env(data={}, userdata=userdata)
        data = {'ovfcontent': xml}
        dsrc = self._get_ds(data)
        dsrc.get_data()

        self.assertEqual(userdata.encode('us-ascii'), dsrc.userdata_raw)

    def test_password_redacted_in_ovf(self):
        odata = {'HostName': "myhost", 'UserName': "******",
                 'UserPassword': "******"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
        dsrc = self._get_ds(data)
        ret = dsrc.get_data()

        self.assertTrue(ret)
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')

        # The XML should not be same since the user password is redacted
        on_disk_ovf = load_file(ovf_env_path)
        self.xml_notequals(data['ovfcontent'], on_disk_ovf)

        # Make sure that the redacted password on disk is not used by CI
        self.assertNotEquals(dsrc.cfg.get('password'),
                             DataSourceAzure.DEF_PASSWD_REDACTION)

        # Make sure that the password was really encrypted
        et = ET.fromstring(on_disk_ovf)
        for elem in et.iter():
            if 'UserPassword' in elem.tag:
                self.assertEquals(DataSourceAzure.DEF_PASSWD_REDACTION,
                                  elem.text)

    def test_ovf_env_arrives_in_waagent_dir(self):
        xml = construct_valid_ovf_env(data={}, userdata="FOODATA")
        dsrc = self._get_ds({'ovfcontent': xml})
        dsrc.get_data()

        # 'data_dir' is '/var/lib/waagent' (walinux-agent's state dir)
        # we expect that the ovf-env.xml file is copied there.
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')
        self.assertTrue(os.path.exists(ovf_env_path))
        self.xml_equals(xml, load_file(ovf_env_path))

    def test_ovf_can_include_unicode(self):
        xml = construct_valid_ovf_env(data={})
        xml = u'\ufeff{0}'.format(xml)
        dsrc = self._get_ds({'ovfcontent': xml})
        dsrc.get_data()

    def test_exception_fetching_fabric_data_doesnt_propagate(self):
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ds.ds_cfg['agent_command'] = '__builtin__'
        self.get_metadata_from_fabric.side_effect = Exception
        self.assertFalse(ds.get_data())

    def test_fabric_data_included_in_metadata(self):
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ds.ds_cfg['agent_command'] = '__builtin__'
        self.get_metadata_from_fabric.return_value = {'test': 'value'}
        ret = ds.get_data()
        self.assertTrue(ret)
        self.assertEqual('value', ds.metadata['test'])

    def test_instance_id_from_dmidecode_used(self):
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ds.get_data()
        self.assertEqual(self.instance_id, ds.metadata['instance-id'])

    def test_instance_id_from_dmidecode_used_for_builtin(self):
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ds.ds_cfg['agent_command'] = '__builtin__'
        ds.get_data()
        self.assertEqual(self.instance_id, ds.metadata['instance-id'])
Exemple #56
class TestAzureDataSource(TestCase):
    def setUp(self):
        super(TestAzureDataSource, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp)

        # patch cloud_dir, so our 'seed_dir' is guaranteed empty
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')

        self.patches = ExitStack()
        self.addCleanup(self.patches.close)

    def apply_patches(self, patches):
        for module, name, new in patches:
            self.patches.enter_context(mock.patch.object(module, name, new))

    def _get_ds(self, data):
        def dsdevs():
            return data.get('dsdevs', [])

        def _invoke_agent(cmd):
            data['agent_invoked'] = cmd

        def _wait_for_files(flist, _maxwait=None, _naplen=None):
            data['waited'] = flist
            return []

        def _pubkeys_from_crt_files(flist):
            data['pubkey_files'] = flist
            return ["pubkey_from: %s" % f for f in flist]

        def _iid_from_shared_config(path):
            data['iid_from_shared_cfg'] = path
            return 'i-my-azure-id'

        if data.get('ovfcontent') is not None:
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
                         {'ovf-env.xml': data['ovfcontent']})

        mod = DataSourceAzure
        mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d

        self.get_metadata_from_fabric = mock.MagicMock(return_value={
            'instance-id': 'i-my-azure-id',
            'public-keys': [],
        })

        self.apply_patches([
            (mod, 'list_possible_azure_ds_devs', dsdevs),
            (mod, 'invoke_agent', _invoke_agent),
            (mod, 'wait_for_files', _wait_for_files),
            (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files),
            (mod, 'iid_from_shared_config', _iid_from_shared_config),
            (mod, 'perform_hostname_bounce', mock.MagicMock()),
            (mod, 'get_hostname', mock.MagicMock()),
            (mod, 'set_hostname', mock.MagicMock()),
            (mod, 'get_metadata_from_fabric', self.get_metadata_from_fabric),
        ])

        dsrc = mod.DataSourceAzureNet(data.get('sys_cfg', {}),
                                      distro=None,
                                      paths=self.paths)

        return dsrc

    def xml_equals(self, oxml, nxml):
        """Compare two sets of XML to make sure they are equal"""
        def create_tag_index(xml):
            et = ET.fromstring(xml)
            ret = {}
            for x in et.iter():
                ret[x.tag] = x
            return ret

        def tags_exists(x, y):
            for tag in x.keys():
                self.assertIn(tag, y)
            for tag in y.keys():
                self.assertIn(tag, x)

        def tags_equal(x, y):
            for x_tag, x_val in x.items():
                y_val = y.get(x_val.tag)
                self.assertEquals(x_val.text, y_val.text)

        old_cnt = create_tag_index(oxml)
        new_cnt = create_tag_index(nxml)
        tags_exists(old_cnt, new_cnt)
        tags_equal(old_cnt, new_cnt)

    def xml_notequals(self, oxml, nxml):
        try:
            self.xml_equals(oxml, nxml)
        except AssertionError:
            return
        raise AssertionError("XML is the same")

    def test_basic_seed_dir(self):
        odata = {'HostName': "myhost", 'UserName': "******"}
        data = {
            'ovfcontent': construct_valid_ovf_env(data=odata),
            'sys_cfg': {}
        }

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, "")
        self.assertEqual(dsrc.metadata['local-hostname'], odata['HostName'])
        self.assertTrue(
            os.path.isfile(os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertEqual(dsrc.metadata['instance-id'], 'i-my-azure-id')

    def test_waagent_d_has_0700_perms(self):
        # we expect /var/lib/waagent to be created 0700
        dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue(os.path.isdir(self.waagent_d))
        self.assertEqual(stat.S_IMODE(os.stat(self.waagent_d).st_mode), 0o700)

    def test_user_cfg_set_agent_command_plain(self):
        # set dscfg in via plaintext
        # we must have friendly-to-xml formatted plaintext in yaml_cfg
        # not all plaintext is expected to work.
        yaml_cfg = "{agent_command: my_command}\n"
        cfg = yaml.safe_load(yaml_cfg)
        odata = {
            'HostName': "myhost",
            'UserName': "******",
            'dscfg': {
                'text': yaml_cfg,
                'encoding': 'plain'
            }
        }
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_user_cfg_set_agent_command(self):
        # set dscfg in via base64 encoded yaml
        cfg = {'agent_command': "my_command"}
        odata = {
            'HostName': "myhost",
            'UserName': "******",
            'dscfg': {
                'text': b64e(yaml.dump(cfg)),
                'encoding': 'base64'
            }
        }
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_sys_cfg_set_agent_command(self):
        sys_cfg = {'datasource': {'Azure': {'agent_command': '_COMMAND'}}}
        data = {
            'ovfcontent': construct_valid_ovf_env(data={}),
            'sys_cfg': sys_cfg
        }

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], '_COMMAND')

    def test_username_used(self):
        odata = {'HostName': "myhost", 'UserName': "myuser"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.cfg['system_info']['default_user']['name'],
                         "myuser")

    def test_password_given(self):
        odata = {
            'HostName': "myhost",
            'UserName': "******",
            'UserPassword': "******"
        }
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue('default_user' in dsrc.cfg['system_info'])
        defuser = dsrc.cfg['system_info']['default_user']

        # the default user should be updated to the supplied username and not locked.
        self.assertEqual(defuser['name'], odata['UserName'])
        self.assertFalse(defuser['lock_passwd'])
        # passwd is a crypt formatted string $id$salt$encrypted
        # encrypting plaintext with salt value of everything up to final '$'
        # should equal that after the '$'
        pos = defuser['passwd'].rfind("$") + 1
        self.assertEqual(
            defuser['passwd'],
            crypt.crypt(odata['UserPassword'], defuser['passwd'][0:pos]))

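    # A quick sketch of the crypt round-trip the test above relies on
    # (values here are illustrative, not part of the suite): everything up
    # to and including the final '$' of a crypt hash is the salt, and
    # re-encrypting the plaintext with that salt must reproduce the hash.
    #
    #     import crypt
    #     stored = crypt.crypt('mypassword', crypt.mksalt(crypt.METHOD_SHA512))
    #     salt = stored[:stored.rfind('$') + 1]
    #     assert crypt.crypt('mypassword', salt) == stored
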
    def test_userdata_plain(self):
        mydata = "FOOBAR"
        odata = {'UserData': {'text': mydata, 'encoding': 'plain'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(decode_binary(dsrc.userdata_raw), mydata)

    def test_userdata_found(self):
        mydata = "FOOBAR"
        odata = {'UserData': {'text': b64e(mydata), 'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8'))

    def test_no_datasource_expected(self):
        # no source should be found if no seed_dir and no devs
        data = {}
        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertFalse(ret)
        self.assertFalse('agent_invoked' in data)

    def test_cfg_has_pubkeys(self):
        odata = {'HostName': "myhost", 'UserName': "******"}
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1'}]
        pubkeys = [(x['fingerprint'], x['path']) for x in mypklist]
        data = {
            'ovfcontent': construct_valid_ovf_env(data=odata, pubkeys=pubkeys)
        }

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        for mypk in mypklist:
            self.assertIn(mypk, dsrc.cfg['_pubkeys'])

    def test_default_ephemeral(self):
        # make sure the ephemeral device works
        odata = {}
        data = {
            'ovfcontent': construct_valid_ovf_env(data=odata),
            'sys_cfg': {}
        }

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()

        self.assertEqual(dsrc.device_name_to_device("ephemeral0"), "/dev/sdb")
        assert 'disk_setup' in cfg
        assert 'fs_setup' in cfg
        self.assertIsInstance(cfg['disk_setup'], dict)
        self.assertIsInstance(cfg['fs_setup'], list)

    def test_provide_disk_aliases(self):
        # Make sure that user can affect disk aliases
        dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}}
        odata = {
            'HostName': "myhost",
            'UserName': "******",
            'dscfg': {
                'text': b64e(yaml.dump(dscfg)),
                'encoding': 'base64'
            }
        }
        usercfg = {
            'disk_setup': {
                '/dev/sdc': {
                    'something': '...'
                },
                'ephemeral0': False
            }
        }
        userdata = '#cloud-config\n' + yaml.dump(usercfg) + "\n"

        ovfcontent = construct_valid_ovf_env(data=odata, userdata=userdata)
        data = {'ovfcontent': ovfcontent, 'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()
        self.assertTrue(cfg)

    def test_userdata_arrives(self):
        userdata = "This is my user-data"
        xml = construct_valid_ovf_env(data={}, userdata=userdata)
        data = {'ovfcontent': xml}
        dsrc = self._get_ds(data)
        dsrc.get_data()

        self.assertEqual(userdata.encode('us-ascii'), dsrc.userdata_raw)

    def test_password_redacted_in_ovf(self):
        odata = {
            'HostName': "myhost",
            'UserName': "******",
            'UserPassword': "******"
        }
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
        dsrc = self._get_ds(data)
        ret = dsrc.get_data()

        self.assertTrue(ret)
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')

        # The XML should not be same since the user password is redacted
        on_disk_ovf = load_file(ovf_env_path)
        self.xml_notequals(data['ovfcontent'], on_disk_ovf)

        # Make sure that the redacted password on disk is not used by CI
        self.assertNotEqual(dsrc.cfg.get('password'),
                            DataSourceAzure.DEF_PASSWD_REDACTION)

        # Make sure that the password was really encrypted
        et = ET.fromstring(on_disk_ovf)
        for elem in et.iter():
            if 'UserPassword' in elem.tag:
                self.assertEqual(DataSourceAzure.DEF_PASSWD_REDACTION,
                                 elem.text)

    def test_ovf_env_arrives_in_waagent_dir(self):
        xml = construct_valid_ovf_env(data={}, userdata="FOODATA")
        dsrc = self._get_ds({'ovfcontent': xml})
        dsrc.get_data()

        # 'data_dir' is '/var/lib/waagent' (walinux-agent's state dir)
        # we expect that the ovf-env.xml file is copied there.
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')
        self.assertTrue(os.path.exists(ovf_env_path))
        self.xml_equals(xml, load_file(ovf_env_path))

    def test_ovf_can_include_unicode(self):
        xml = construct_valid_ovf_env(data={})
        xml = u'\ufeff{0}'.format(xml)
        dsrc = self._get_ds({'ovfcontent': xml})
        dsrc.get_data()

    def test_existing_ovf_same(self):
        # waagent/SharedConfig left alone if found ovf-env.xml same as cached
        odata = {'UserData': b64e("SOMEUSERDATA")}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        populate_dir(
            self.waagent_d, {
                'ovf-env.xml': data['ovfcontent'],
                'otherfile': 'otherfile-content',
                'SharedConfig.xml': 'mysharedconfig'
            })

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue(
            os.path.exists(os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertTrue(
            os.path.exists(os.path.join(self.waagent_d, 'otherfile')))
        self.assertTrue(
            os.path.exists(os.path.join(self.waagent_d, 'SharedConfig.xml')))

    def test_existing_ovf_diff(self):
        # waagent/SharedConfig must be removed if ovfenv is found elsewhere

        # 'get_data' should remove SharedConfig.xml in /var/lib/waagent
        # if ovf-env.xml differs.
        cached_ovfenv = construct_valid_ovf_env(
            {'userdata': b64e("FOO_USERDATA")})
        new_ovfenv = construct_valid_ovf_env(
            {'userdata': b64e("NEW_USERDATA")})

        populate_dir(
            self.waagent_d, {
                'ovf-env.xml': cached_ovfenv,
                'SharedConfig.xml': "mysharedconfigxml",
                'otherfile': 'otherfilecontent'
            })

        dsrc = self._get_ds({'ovfcontent': new_ovfenv})
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, b"NEW_USERDATA")
        self.assertTrue(
            os.path.exists(os.path.join(self.waagent_d, 'otherfile')))
        self.assertFalse(
            os.path.exists(os.path.join(self.waagent_d, 'SharedConfig.xml')))
        self.assertTrue(
            os.path.exists(os.path.join(self.waagent_d, 'ovf-env.xml')))
        new_xml = load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))
        self.xml_equals(new_ovfenv, new_xml)

    def test_exception_fetching_fabric_data_doesnt_propagate(self):
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ds.ds_cfg['agent_command'] = '__builtin__'
        self.get_metadata_from_fabric.side_effect = Exception
        self.assertFalse(ds.get_data())

    def test_fabric_data_included_in_metadata(self):
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ds.ds_cfg['agent_command'] = '__builtin__'
        self.get_metadata_from_fabric.return_value = {'test': 'value'}
        ret = ds.get_data()
        self.assertTrue(ret)
        self.assertEqual('value', ds.metadata['test'])

class TestAzureDataSource(TestCase):

    def setUp(self):
        super(TestAzureDataSource, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp)

        # patch cloud_dir, so our 'seed_dir' is guaranteed empty
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')

        self.patches = ExitStack()
        self.addCleanup(self.patches.close)

    def apply_patches(self, patches):
        for module, name, new in patches:
            self.patches.enter_context(mock.patch.object(module, name, new))

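    # apply_patches leans on contextlib.ExitStack so any number of
    # mock.patch.object contexts stay active for the whole test and are all
    # undone by the single addCleanup(self.patches.close) above. A minimal
    # standalone sketch of the same pattern (module and attribute names are
    # placeholders, not part of this suite):
    #
    #     from contextlib import ExitStack
    #     from unittest import mock
    #
    #     patches = ExitStack()
    #     patches.enter_context(mock.patch.object(some_module, 'attr', fake))
    #     try:
    #         ...  # code here sees the patched attribute
    #     finally:
    #         patches.close()  # every patch is reverted in one call
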
    def _get_ds(self, data):

        def dsdevs():
            return data.get('dsdevs', [])

        def _invoke_agent(cmd):
            data['agent_invoked'] = cmd

        def _wait_for_files(flist, _maxwait=None, _naplen=None):
            data['waited'] = flist
            return []

        def _pubkeys_from_crt_files(flist):
            data['pubkey_files'] = flist
            return ["pubkey_from: %s" % f for f in flist]

        def _iid_from_shared_config(path):
            data['iid_from_shared_cfg'] = path
            return 'i-my-azure-id'

        def _apply_hostname_bounce(**kwargs):
            data['apply_hostname_bounce'] = kwargs

        if data.get('ovfcontent') is not None:
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
                         {'ovf-env.xml': data['ovfcontent']})

        mod = DataSourceAzure
        mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d

        self.apply_patches([
            (mod, 'list_possible_azure_ds_devs', dsdevs),
            (mod, 'invoke_agent', _invoke_agent),
            (mod, 'wait_for_files', _wait_for_files),
            (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files),
            (mod, 'iid_from_shared_config', _iid_from_shared_config),
            (mod, 'apply_hostname_bounce', _apply_hostname_bounce),
            ])

        dsrc = mod.DataSourceAzureNet(
            data.get('sys_cfg', {}), distro=None, paths=self.paths)

        return dsrc

    def test_basic_seed_dir(self):
        odata = {'HostName': "myhost", 'UserName': "******"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
                'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, "")
        self.assertEqual(dsrc.metadata['local-hostname'], odata['HostName'])
        self.assertTrue(os.path.isfile(
            os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertEqual(dsrc.metadata['instance-id'], 'i-my-azure-id')

    def test_waagent_d_has_0700_perms(self):
        # we expect /var/lib/waagent to be created 0700
        dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue(os.path.isdir(self.waagent_d))
        self.assertEqual(stat.S_IMODE(os.stat(self.waagent_d).st_mode), 0o700)

    def test_user_cfg_set_agent_command_plain(self):
        # set dscfg in via plaintext
        # we must have friendly-to-xml formatted plaintext in yaml_cfg
        # not all plaintext is expected to work.
        yaml_cfg = "{agent_command: my_command}\n"
        cfg = yaml.safe_load(yaml_cfg)
        odata = {'HostName': "myhost", 'UserName': "******",
                'dscfg': {'text': yaml_cfg, 'encoding': 'plain'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_user_cfg_set_agent_command(self):
        # set dscfg in via base64 encoded yaml
        cfg = {'agent_command': "my_command"}
        odata = {'HostName': "myhost", 'UserName': "******",
                'dscfg': {'text': b64e(yaml.dump(cfg)),
                          'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_sys_cfg_set_agent_command(self):
        sys_cfg = {'datasource': {'Azure': {'agent_command': '_COMMAND'}}}
        data = {'ovfcontent': construct_valid_ovf_env(data={}),
                'sys_cfg': sys_cfg}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], '_COMMAND')

    def test_username_used(self):
        odata = {'HostName': "myhost", 'UserName': "myuser"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.cfg['system_info']['default_user']['name'],
                         "myuser")

    def test_password_given(self):
        odata = {'HostName': "myhost", 'UserName': "******",
                 'UserPassword': "******"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue('default_user' in dsrc.cfg['system_info'])
        defuser = dsrc.cfg['system_info']['default_user']

        # the default user should be updated to the supplied username and not locked.
        self.assertEqual(defuser['name'], odata['UserName'])
        self.assertFalse(defuser['lock_passwd'])
        # passwd is a crypt formatted string $id$salt$encrypted
        # encrypting plaintext with salt value of everything up to final '$'
        # should equal that after the '$'
        pos = defuser['passwd'].rfind("$") + 1
        self.assertEqual(defuser['passwd'],
            crypt.crypt(odata['UserPassword'], defuser['passwd'][0:pos]))

    def test_userdata_plain(self):
        mydata = "FOOBAR"
        odata = {'UserData': {'text': mydata, 'encoding': 'plain'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(decode_binary(dsrc.userdata_raw), mydata)

    def test_userdata_found(self):
        mydata = "FOOBAR"
        odata = {'UserData': {'text': b64e(mydata), 'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8'))

    def test_no_datasource_expected(self):
        # no source should be found if no seed_dir and no devs
        data = {}
        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertFalse(ret)
        self.assertFalse('agent_invoked' in data)

    def test_cfg_has_pubkeys(self):
        odata = {'HostName': "myhost", 'UserName': "******"}
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1'}]
        pubkeys = [(x['fingerprint'], x['path']) for x in mypklist]
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
                                                      pubkeys=pubkeys)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        for mypk in mypklist:
            self.assertIn(mypk, dsrc.cfg['_pubkeys'])

    def test_disabled_bounce(self):
        pass

    def test_apply_bounce_call_1(self):
        # hostname needs to get through to apply_hostname_bounce
        odata = {'HostName': 'my-random-hostname'}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        self._get_ds(data).get_data()
        self.assertIn('hostname', data['apply_hostname_bounce'])
        self.assertEqual(data['apply_hostname_bounce']['hostname'],
                         odata['HostName'])

    def test_apply_bounce_call_configurable(self):
        # hostname_bounce should be configurable in datasource cfg
        cfg = {'hostname_bounce': {'interface': 'eth1', 'policy': 'off',
                                   'command': 'my-bounce-command',
                                   'hostname_command': 'my-hostname-command'}}
        odata = {'HostName': "xhost",
                'dscfg': {'text': b64e(yaml.dump(cfg)),
                          'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
        self._get_ds(data).get_data()

        for k in cfg['hostname_bounce']:
            self.assertIn(k, data['apply_hostname_bounce'])

        for k, v in cfg['hostname_bounce'].items():
            self.assertEqual(data['apply_hostname_bounce'][k], v)

    def test_set_hostname_disabled(self):
        # config specifying set_hostname off should not bounce
        cfg = {'set_hostname': False}
        odata = {'HostName': "xhost",
                'dscfg': {'text': b64e(yaml.dump(cfg)),
                          'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
        self._get_ds(data).get_data()

        self.assertEqual(data.get('apply_hostname_bounce', "N/A"), "N/A")

    def test_default_ephemeral(self):
        # make sure the ephemeral device works
        odata = {}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
                'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()

        self.assertEqual(dsrc.device_name_to_device("ephemeral0"),
                         "/dev/sdb")
        assert 'disk_setup' in cfg
        assert 'fs_setup' in cfg
        self.assertIsInstance(cfg['disk_setup'], dict)
        self.assertIsInstance(cfg['fs_setup'], list)

    def test_provide_disk_aliases(self):
        # Make sure that user can affect disk aliases
        dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}}
        odata = {'HostName': "myhost", 'UserName': "******",
                'dscfg': {'text': b64e(yaml.dump(dscfg)),
                          'encoding': 'base64'}}
        usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'},
                                  'ephemeral0': False}}
        userdata = '#cloud-config\n' + yaml.dump(usercfg) + "\n"

        ovfcontent = construct_valid_ovf_env(data=odata, userdata=userdata)
        data = {'ovfcontent': ovfcontent, 'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()
        self.assertTrue(cfg)

    def test_userdata_arrives(self):
        userdata = "This is my user-data"
        xml = construct_valid_ovf_env(data={}, userdata=userdata)
        data = {'ovfcontent': xml}
        dsrc = self._get_ds(data)
        dsrc.get_data()

        self.assertEqual(userdata.encode('us-ascii'), dsrc.userdata_raw)

    def test_ovf_env_arrives_in_waagent_dir(self):
        xml = construct_valid_ovf_env(data={}, userdata="FOODATA")
        dsrc = self._get_ds({'ovfcontent': xml})
        dsrc.get_data()

        # 'data_dir' is '/var/lib/waagent' (walinux-agent's state dir)
        # we expect that the ovf-env.xml file is copied there.
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')
        self.assertTrue(os.path.exists(ovf_env_path))
        self.assertEqual(xml, load_file(ovf_env_path))

    def test_ovf_can_include_unicode(self):
        xml = construct_valid_ovf_env(data={})
        xml = u'\ufeff{0}'.format(xml)
        dsrc = self._get_ds({'ovfcontent': xml})
        dsrc.get_data()

    def test_existing_ovf_same(self):
        # waagent/SharedConfig left alone if found ovf-env.xml same as cached
        odata = {'UserData': b64e("SOMEUSERDATA")}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        populate_dir(self.waagent_d,
            {'ovf-env.xml': data['ovfcontent'],
             'otherfile': 'otherfile-content',
             'SharedConfig.xml': 'mysharedconfig'})

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'otherfile')))
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'SharedConfig.xml')))

    def test_existing_ovf_diff(self):
        # waagent/SharedConfig must be removed if ovfenv is found elsewhere

        # 'get_data' should remove SharedConfig.xml in /var/lib/waagent
        # if ovf-env.xml differs.
        cached_ovfenv = construct_valid_ovf_env(
            {'userdata': b64e("FOO_USERDATA")})
        new_ovfenv = construct_valid_ovf_env(
            {'userdata': b64e("NEW_USERDATA")})

        populate_dir(self.waagent_d,
            {'ovf-env.xml': cached_ovfenv,
             'SharedConfig.xml': "mysharedconfigxml",
             'otherfile': 'otherfilecontent'})

        dsrc = self._get_ds({'ovfcontent': new_ovfenv})
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, b"NEW_USERDATA")
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'otherfile')))
        self.assertFalse(
            os.path.exists(os.path.join(self.waagent_d, 'SharedConfig.xml')))
        self.assertTrue(
            os.path.exists(os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertEqual(new_ovfenv,
            load_file(os.path.join(self.waagent_d, 'ovf-env.xml')))
Exemple #58
class TestAzureBounce(TestCase):
    def mock_out_azure_moving_parts(self):
        self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'invoke_agent'))
        self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'wait_for_files'))
        self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'iid_from_shared_config',
                              mock.MagicMock(return_value='i-my-azure-id')))
        self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'list_possible_azure_ds_devs',
                              mock.MagicMock(return_value=[])))
        self.patches.enter_context(
            mock.patch.object(DataSourceAzure,
                              'find_fabric_formatted_ephemeral_disk',
                              mock.MagicMock(return_value=None)))
        self.patches.enter_context(
            mock.patch.object(DataSourceAzure,
                              'find_fabric_formatted_ephemeral_part',
                              mock.MagicMock(return_value=None)))
        self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'get_metadata_from_fabric',
                              mock.MagicMock(return_value={})))

    def setUp(self):
        super(TestAzureBounce, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
        self.addCleanup(shutil.rmtree, self.tmp)
        DataSourceAzure.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d
        self.patches = ExitStack()
        self.mock_out_azure_moving_parts()
        self.get_hostname = self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'get_hostname'))
        self.set_hostname = self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'set_hostname'))
        self.subp = self.patches.enter_context(
            mock.patch('cloudinit.sources.DataSourceAzure.util.subp'))

    def tearDown(self):
        self.patches.close()

    def _get_ds(self, ovfcontent=None):
        if ovfcontent is not None:
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
                         {'ovf-env.xml': ovfcontent})
        return DataSourceAzure.DataSourceAzureNet({},
                                                  distro=None,
                                                  paths=self.paths)

    def get_ovf_env_with_dscfg(self, hostname, cfg):
        odata = {
            'HostName': hostname,
            'dscfg': {
                'text': b64e(yaml.dump(cfg)),
                'encoding': 'base64'
            }
        }
        return construct_valid_ovf_env(data=odata)

    def test_disabled_bounce_does_not_change_hostname(self):
        cfg = {'hostname_bounce': {'policy': 'off'}}
        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
        self.assertEqual(0, self.set_hostname.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_disabled_bounce_does_not_perform_bounce(self,
                                                     perform_hostname_bounce):
        cfg = {'hostname_bounce': {'policy': 'off'}}
        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
        self.assertEqual(0, perform_hostname_bounce.call_count)

    def test_same_hostname_does_not_change_hostname(self):
        host_name = 'unchanged-host-name'
        self.get_hostname.return_value = host_name
        cfg = {'hostname_bounce': {'policy': 'yes'}}
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
        self.assertEqual(0, self.set_hostname.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_unchanged_hostname_does_not_perform_bounce(
            self, perform_hostname_bounce):
        host_name = 'unchanged-host-name'
        self.get_hostname.return_value = host_name
        cfg = {'hostname_bounce': {'policy': 'yes'}}
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
        self.assertEqual(0, perform_hostname_bounce.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_force_performs_bounce_regardless(self, perform_hostname_bounce):
        host_name = 'unchanged-host-name'
        self.get_hostname.return_value = host_name
        cfg = {'hostname_bounce': {'policy': 'force'}}
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
        self.assertEqual(1, perform_hostname_bounce.call_count)

    def test_different_hostnames_sets_hostname(self):
        expected_hostname = 'azure-expected-host-name'
        self.get_hostname.return_value = 'default-host-name'
        self._get_ds(self.get_ovf_env_with_dscfg(expected_hostname,
                                                 {})).get_data()
        self.assertEqual(expected_hostname,
                         self.set_hostname.call_args_list[0][0][0])

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_different_hostnames_performs_bounce(self,
                                                 perform_hostname_bounce):
        expected_hostname = 'azure-expected-host-name'
        self.get_hostname.return_value = 'default-host-name'
        self._get_ds(self.get_ovf_env_with_dscfg(expected_hostname,
                                                 {})).get_data()
        self.assertEqual(1, perform_hostname_bounce.call_count)

    def test_different_hostnames_sets_hostname_back(self):
        initial_host_name = 'default-host-name'
        self.get_hostname.return_value = initial_host_name
        self._get_ds(self.get_ovf_env_with_dscfg('some-host-name',
                                                 {})).get_data()
        self.assertEqual(initial_host_name,
                         self.set_hostname.call_args_list[-1][0][0])

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_failure_in_bounce_still_resets_host_name(self,
                                                      perform_hostname_bounce):
        perform_hostname_bounce.side_effect = Exception
        initial_host_name = 'default-host-name'
        self.get_hostname.return_value = initial_host_name
        self._get_ds(self.get_ovf_env_with_dscfg('some-host-name',
                                                 {})).get_data()
        self.assertEqual(initial_host_name,
                         self.set_hostname.call_args_list[-1][0][0])

    def test_environment_correct_for_bounce_command(self):
        interface = 'int0'
        hostname = 'my-new-host'
        old_hostname = 'my-old-host'
        self.get_hostname.return_value = old_hostname
        cfg = {'hostname_bounce': {'interface': interface, 'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg(hostname, cfg)
        self._get_ds(data).get_data()
        self.assertEqual(1, self.subp.call_count)
        bounce_env = self.subp.call_args[1]['env']
        self.assertEqual(interface, bounce_env['interface'])
        self.assertEqual(hostname, bounce_env['hostname'])
        self.assertEqual(old_hostname, bounce_env['old_hostname'])

    def test_default_bounce_command_used_by_default(self):
        cmd = 'default-bounce-command'
        DataSourceAzure.BUILTIN_DS_CONFIG['hostname_bounce']['command'] = cmd
        cfg = {'hostname_bounce': {'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
        self._get_ds(data).get_data()
        self.assertEqual(1, self.subp.call_count)
        bounce_args = self.subp.call_args[1]['args']
        self.assertEqual(cmd, bounce_args)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_set_hostname_option_can_disable_bounce(self,
                                                    perform_hostname_bounce):
        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
        self._get_ds(data).get_data()

        self.assertEqual(0, perform_hostname_bounce.call_count)

    def test_set_hostname_option_can_disable_hostname_set(self):
        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
        self._get_ds(data).get_data()

        self.assertEqual(0, self.set_hostname.call_count)
Exemple #59
class TestNoCloudDataSource(TestCase):
    def setUp(self):
        super(TestNoCloudDataSource, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp)
        self.paths = helpers.Paths({'cloud_dir': self.tmp})

        self.cmdline = "root=TESTCMDLINE"

        self.mocks = ExitStack()
        self.addCleanup(self.mocks.close)

        self.mocks.enter_context(
            mock.patch.object(util, 'get_cmdline', return_value=self.cmdline))

    def test_nocloud_seed_dir(self):
        md = {'instance-id': 'IID', 'dsmode': 'local'}
        ud = b"USER_DATA_HERE"
        populate_dir(os.path.join(self.paths.seed_dir, "nocloud"), {
            'user-data': ud,
            'meta-data': yaml.safe_dump(md)
        })

        sys_cfg = {'datasource': {'NoCloud': {'fs_label': None}}}

        ds = DataSourceNoCloud.DataSourceNoCloud

        dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
        ret = dsrc.get_data()
        self.assertEqual(dsrc.userdata_raw, ud)
        self.assertEqual(dsrc.metadata, md)
        self.assertTrue(ret)

    def test_fs_label(self):
        # find_devs_with should not be called if fs_label is None
        ds = DataSourceNoCloud.DataSourceNoCloud

        class PseudoException(Exception):
            pass

        self.mocks.enter_context(
            mock.patch.object(util,
                              'find_devs_with',
                              side_effect=PseudoException))

        # by default, NoCloud should search for filesystems by label
        sys_cfg = {'datasource': {'NoCloud': {}}}
        dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
        self.assertRaises(PseudoException, dsrc.get_data)

        # but disabling searching should just end up with None found
        sys_cfg = {'datasource': {'NoCloud': {'fs_label': None}}}
        dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
        ret = dsrc.get_data()
        self.assertFalse(ret)

    def test_no_datasource_expected(self):
        # no source should be found if no cmdline, config, and fs_label=None
        sys_cfg = {'datasource': {'NoCloud': {'fs_label': None}}}

        ds = DataSourceNoCloud.DataSourceNoCloud
        dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
        self.assertFalse(dsrc.get_data())

    def test_seed_in_config(self):
        ds = DataSourceNoCloud.DataSourceNoCloud

        data = {
            'fs_label': None,
            'meta-data': yaml.safe_dump({'instance-id': 'IID'}),
            'user-data': b"USER_DATA_RAW",
        }

        sys_cfg = {'datasource': {'NoCloud': data}}
        dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
        ret = dsrc.get_data()
        self.assertEqual(dsrc.userdata_raw, b"USER_DATA_RAW")
        self.assertEqual(dsrc.metadata.get('instance-id'), 'IID')
        self.assertTrue(ret)

    def test_nocloud_seed_with_vendordata(self):
        md = {'instance-id': 'IID', 'dsmode': 'local'}
        ud = b"USER_DATA_HERE"
        vd = b"THIS IS MY VENDOR_DATA"

        populate_dir(os.path.join(self.paths.seed_dir, "nocloud"), {
            'user-data': ud,
            'meta-data': yaml.safe_dump(md),
            'vendor-data': vd
        })

        sys_cfg = {'datasource': {'NoCloud': {'fs_label': None}}}

        ds = DataSourceNoCloud.DataSourceNoCloud

        dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
        ret = dsrc.get_data()
        self.assertEqual(dsrc.userdata_raw, ud)
        self.assertEqual(dsrc.metadata, md)
        self.assertEqual(dsrc.vendordata_raw, vd)
        self.assertTrue(ret)

    def test_nocloud_no_vendordata(self):
        populate_dir(os.path.join(self.paths.seed_dir, "nocloud"), {
            'user-data': b"ud",
            'meta-data': "instance-id: IID\n"
        })

        sys_cfg = {'datasource': {'NoCloud': {'fs_label': None}}}

        ds = DataSourceNoCloud.DataSourceNoCloud

        dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
        ret = dsrc.get_data()
        self.assertEqual(dsrc.userdata_raw, b"ud")
        self.assertFalse(dsrc.vendordata)
        self.assertTrue(ret)
Exemple #60
class Flow:
    def __init__(self, args: 'argparse.Namespace' = None, **kwargs):
        """Initialize a flow object

        :param kwargs: other keyword arguments that will be shared by all pods in this flow


        More explanation of ``optimize_level``:

        As an example, the following flow will generate 6 Peas,

        .. highlight:: python
        .. code-block:: python

            f = Flow(optimize_level=FlowOptimizeLevel.NONE).add(uses='forward', parallel=3)

        The optimized version, i.e. :code:`Flow(optimize_level=FlowOptimizeLevel.FULL)`
        will generate 4 Peas, but it will force the :class:`GatewayPea` to take BIND role,
        as the head and tail routers are removed.
        """
        self.logger = get_logger(self.__class__.__name__)
        self._pod_nodes = OrderedDict()  # type: Dict[str, 'FlowPod']
        self._build_level = FlowBuildLevel.EMPTY
        self._pod_name_counter = 0
        self._last_changed_pod = [
            'gateway'
        ]  #: the default first pod is the gateway; it is added during build()

        self._update_args(args, **kwargs)

    def _update_args(self, args, **kwargs):
        from ..main.parser import set_flow_parser
        _flow_parser = set_flow_parser()
        if args is None:
            from ..helper import get_parsed_args
            _, args, _ = get_parsed_args(kwargs, _flow_parser, 'Flow')

        self.args = args
        if kwargs and self.args.logserver and 'log_sse' not in kwargs:
            kwargs['log_sse'] = True
        self._common_kwargs = kwargs
        self._kwargs = get_non_defaults_args(args,
                                             _flow_parser)  #: for yaml dump

    @classmethod
    def to_yaml(cls, representer, data):
        """Required by :mod:`ruamel.yaml.constructor` """
        tmp = data._dump_instance_to_yaml(data)
        representer.sort_base_mapping_type_on_output = False
        return representer.represent_mapping('!' + cls.__name__, tmp)

    @staticmethod
    def _dump_instance_to_yaml(data):
        # note: we only save non-default properties for the sake of clarity
        r = {}

        if data._kwargs:
            r['with'] = data._kwargs

        if data._pod_nodes:
            r['pods'] = {}

        if 'gateway' in data._pod_nodes:
            # always dump gateway as the first pod, if it exists
            r['pods']['gateway'] = {}

        for k, v in data._pod_nodes.items():
            if k == 'gateway':
                continue

            kwargs = {'needs': list(v.needs)} if v.needs else {}
            kwargs.update(v._kwargs)

            if 'name' in kwargs:
                kwargs.pop('name')

            r['pods'][k] = kwargs
        return r

    @classmethod
    def from_yaml(cls, constructor, node):
        """Required by :mod:`ruamel.yaml.constructor` """
        return cls._get_instance_from_yaml(constructor, node)[0]

    def save_config(self, filename: str = None) -> bool:
        """
        Serialize the object to a yaml file

        :param filename: file path of the yaml file, if not given then :attr:`config_abspath` is used
        :return: successfully dumped or not
        """
        f = filename
        if not f:
            f = tempfile.NamedTemporaryFile('w',
                                            delete=False,
                                            dir=os.environ.get(
                                                'JINA_EXECUTOR_WORKDIR',
                                                None)).name
        yaml.register_class(Flow)
        # yaml.sort_base_mapping_type_on_output = False
        # yaml.representer.add_representer(OrderedDict, yaml.Representer.represent_dict)

        with open(f, 'w', encoding='utf8') as fp:
            yaml.dump(self, fp)
        self.logger.info(f'{self}\'s yaml config is saved to {f}')
        return True

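    # Hedged usage sketch (the file name is illustrative): build a small flow
    # and persist its non-default options as YAML via save_config.
    #
    #     f = Flow().add(name='pod1')
    #     f.save_config('my_flow.yml')
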
    @property
    def yaml_spec(self):
        yaml.register_class(Flow)
        stream = StringIO()
        yaml.dump(self, stream)
        return stream.getvalue().strip()

    @classmethod
    def load_config(cls: Type['Flow'], filename: Union[str, TextIO]) -> 'Flow':
        """Build an executor from a YAML file.

        :param filename: the file path of the YAML file or a ``TextIO`` stream to be loaded from
        :return: an executor object
        """
        yaml.register_class(Flow)
        if not filename:
            raise FileNotFoundError
        if isinstance(filename, str):
            # deserialize from the yaml
            with open(filename, encoding='utf8') as fp:
                return yaml.load(fp)
        else:
            with filename:
                return yaml.load(filename)

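    # Hedged counterpart to save_config (the path is illustrative): a Flow
    # can be rebuilt from the YAML produced above.
    #
    #     f = Flow.load_config('my_flow.yml')
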
    @classmethod
    def _get_instance_from_yaml(cls, constructor, node):

        data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
            constructor, node, deep=True)

        p = data.get('with', {})  # type: Dict[str, Any]
        a = p.pop('args') if 'args' in p else ()
        k = p.pop('kwargs') if 'kwargs' in p else {}
        # maybe there are some hanging kwargs in "parameters"
        tmp_a = (expand_env_var(v) for v in a)
        tmp_p = {kk: expand_env_var(vv) for kk, vv in {**k, **p}.items()}
        obj = cls(*tmp_a, **tmp_p)

        pp = data.get('pods', {})
        for pod_name, pod_attr in pp.items():
            p_pod_attr = {
                kk: expand_env_var(vv)
                for kk, vv in pod_attr.items()
            }
            if pod_name != 'gateway':
                # ignore gateway when reading, it will be added during build()
                obj.add(name=pod_name, **p_pod_attr, copy_flow=False)

        obj.logger.success(
            f'successfully built {cls.__name__} from a yaml config')

        # if node.tag in {'!CompoundExecutor'}:
        #     os.environ['JINA_WARN_UNNAMED'] = 'YES'

        return obj, data

    @staticmethod
    def _parse_endpoints(op_flow,
                         pod_name,
                         endpoint,
                         connect_to_last_pod=False) -> Set:
        # parsing needs
        if isinstance(endpoint, str):
            endpoint = [endpoint]
        elif not endpoint:
            if op_flow._last_changed_pod and connect_to_last_pod:
                endpoint = [op_flow._last_changed_pod[-1]]
            else:
                endpoint = []

        if isinstance(endpoint, (list, tuple)):
            for s in endpoint:
                if s == pod_name:
                    raise FlowTopologyError(
                        'the input/output of a pod can not be itself')
        else:
            raise ValueError(f'endpoint={endpoint} is not parsable')
        return set(endpoint)

    def set_last_pod(self, name: str, copy_flow: bool = True) -> 'Flow':
        """
        Set a pod as the last pod in the flow, useful when modifying the flow.

        :param name: the name of the existing pod
        :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification
        :return: a (new) flow object with modification
        """
        op_flow = copy.deepcopy(self) if copy_flow else self

        if name not in op_flow._pod_nodes:
            raise FlowMissingPodError(f'{name} can not be found in this Flow')

        if not (op_flow._last_changed_pod
                and name == op_flow._last_changed_pod[-1]):
            op_flow._last_changed_pod.append(name)

        # graph is now changed so we need to
        # reset the build level to the lowest
        op_flow._build_level = FlowBuildLevel.EMPTY

        return op_flow

    def _add_gateway(self, needs, **kwargs):
        pod_name = 'gateway'

        kwargs.update(self._common_kwargs)
        kwargs['name'] = 'gateway'
        self._pod_nodes[pod_name] = GatewayFlowPod(kwargs, needs)

    def join(self, needs: Union[Tuple[str], List[str]], *args,
             **kwargs) -> 'Flow':
        """
        Add a blocker to the flow, wait until all peas defined in **needs** have completed.

        :param needs: list of service names to wait for
        :return: the modified flow
        """
        if len(needs) <= 1:
            raise FlowTopologyError(
                'no need to wait for a single service, need len(needs) > 1')
        return self.add(name='joiner',
                        uses='_merge',
                        needs=needs,
                        *args,
                        **kwargs)

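    # Hedged sketch: fan out to two pods that both read from the gateway,
    # then block until both finish (pod names are illustrative).
    #
    #     f = (Flow()
    #          .add(name='p1')
    #          .add(name='p2', needs='gateway')
    #          .join(needs=['p1', 'p2']))
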
    def add(self,
            needs: Union[str, Tuple[str], List[str]] = None,
            copy_flow: bool = True,
            **kwargs) -> 'Flow':
        """
        Add a pod to the current flow object and return the new modified flow object.
        The attribute of the pod can be later changed with :py:meth:`set` or deleted with :py:meth:`remove`

        Note there are shortcut versions of this method.
        It is recommended to use :py:meth:`add_encoder`, :py:meth:`add_preprocessor`,
        :py:meth:`add_router`, :py:meth:`add_indexer` whenever possible.

        :param needs: the name of the pod(s) that this pod receives data from.
                           One can also use 'pod.Gateway' to indicate the connection with the gateway.
        :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification
        :param kwargs: other keyword-value arguments that the pod CLI supports
        :return: a (new) flow object with modification
        """

        op_flow = copy.deepcopy(self) if copy_flow else self

        pod_name = kwargs.get('name', None)

        if pod_name in op_flow._pod_nodes:
            raise FlowTopologyError(
                f'name: {pod_name} is used in this Flow already!')

        if not pod_name:
            pod_name = f'pod{op_flow._pod_name_counter}'
            op_flow._pod_name_counter += 1

        if not pod_name.isidentifier():
            # hyphen - can not be used in the name
            raise ValueError(
                f'name: {pod_name} is invalid, please follow the python variable name conventions'
            )

        needs = op_flow._parse_endpoints(op_flow,
                                         pod_name,
                                         needs,
                                         connect_to_last_pod=True)

        kwargs.update(op_flow._common_kwargs)
        kwargs['name'] = pod_name
        op_flow._pod_nodes[pod_name] = FlowPod(kwargs=kwargs, needs=needs)

        op_flow.set_last_pod(pod_name, False)

        return op_flow

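    # Hedged sketch of chaining: each add() returns a copied Flow by default,
    # so calls compose; 'needs' wires p3 to read from both predecessors
    # (names are illustrative).
    #
    #     f = (Flow()
    #          .add(name='p1')
    #          .add(name='p2', needs='gateway')
    #          .add(name='p3', needs=['p1', 'p2']))
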
    def build(self, copy_flow: bool = False) -> 'Flow':
        """
        Build the current flow and make it ready to use

        .. note::

            No need to manually call it since 0.0.8. When using flow with the
            context manager, or using :meth:`start`, :meth:`build` will be invoked.

        :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification
        :return: the current flow (by default)

        .. note::
            ``copy_flow=True`` is recommended if you are building the same flow multiple times in a row. e.g.

            .. highlight:: python
            .. code-block:: python

                f = Flow()
                with f:
                    f.index()

                with f.build(copy_flow=True) as fl:
                    fl.search()

        """

        op_flow = copy.deepcopy(self) if copy_flow else self

        _pod_edges = set()

        if 'gateway' not in op_flow._pod_nodes:
            op_flow._add_gateway(needs={op_flow._last_changed_pod[-1]})

        # construct a map from each start node to the list of its end nodes
        _outgoing_map = defaultdict(list)
        for end, pod in op_flow._pod_nodes.items():
            for start in pod.needs:
                if start not in op_flow._pod_nodes:
                    raise FlowMissingPodError(
                        f'{start} is not in this flow, misspelled name?')
                _outgoing_map[start].append(end)
                _pod_edges.add((start, end))

        op_flow = _build_flow(op_flow, _outgoing_map)
        op_flow = _optimize_flow(op_flow, _outgoing_map, _pod_edges)
        op_flow._build_level = FlowBuildLevel.GRAPH
        return op_flow

    def __call__(self, *args, **kwargs):
        return self.build(*args, **kwargs)

    def __enter__(self):
        return self.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _start_log_server(self):
        try:
            import urllib.request
            import flask, flask_cors
            self._sse_logger = threading.Thread(
                name='sentinel-sse-logger',
                target=start_sse_logger,
                daemon=True,
                args=(self.args.logserver_config, self.yaml_spec))
            self._sse_logger.start()
            time.sleep(1)
            urllib.request.urlopen(JINA_GLOBAL.logserver.ready, timeout=5)
            self.logger.success(
                f'logserver is started and available at {JINA_GLOBAL.logserver.address}'
            )
        except ModuleNotFoundError:
            self.logger.error(
                'sse logserver can not start because "flask" and "flask_cors" are missing, '
                'use pip install "jina[http]" (with double quotes) to install the dependencies'
            )
        except Exception:
            self.logger.error('logserver fails to start')

    def start(self):
        """Start to run all Pods in this Flow.

        Remember to close the Flow with :meth:`close`.

        Note that this method has a timeout of ``timeout_ready`` set in CLI,
        which is inherited all the way from :class:`jina.peapods.peas.BasePea`
        """

        if self._build_level.value < FlowBuildLevel.GRAPH.value:
            self.build(copy_flow=False)

        if self.args.logserver:
            self.logger.info('start logserver...')
            self._start_log_server()

        self._pod_stack = ExitStack()
        for v in self._pod_nodes.values():
            self._pod_stack.enter_context(v)

        self.logger.info('%d Pods (i.e. %d Peas) are running in this Flow' %
                         (self.num_pods, self.num_peas))

        self.logger.success(
            f'flow is now ready for use, current build_level is {self._build_level}'
        )

        return self

    @property
    def num_pods(self) -> int:
        """Get the number of pods in this flow"""
        return len(self._pod_nodes)

    @property
    def num_peas(self) -> int:
        """Get the number of peas (parallel count) in this flow"""
        return sum(v.num_peas for v in self._pod_nodes.values())

    def close(self):
        """Close the flow and release all resources associated to it. """
        if hasattr(self, '_pod_stack'):
            self._pod_stack.close()
        self._build_level = FlowBuildLevel.EMPTY
        self.logger.success(
            f'flow is closed and all resources should be released already, current build level is {self._build_level}'
        )

    def __eq__(self, other: 'Flow'):
        """
        Comparing the topology of a flow with another flow.
        Identification is defined by whether two flows share the same set of edges.

        :param other: the second flow object
        """

        if self._build_level.value < FlowBuildLevel.GRAPH.value:
            a = self.build()
        else:
            a = self

        if other._build_level.value < FlowBuildLevel.GRAPH.value:
            b = other.build()
        else:
            b = other

        return a._pod_nodes == b._pod_nodes

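    # Hedged sketch: per the docstring, equality is purely topological, so
    # two flows built from the same sequence of add() calls should compare
    # equal (both sides are built on demand inside __eq__).
    #
    #     f1 = Flow().add(name='a').add(name='b')
    #     f2 = Flow().add(name='a').add(name='b')
    #     assert f1 == f2
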
    @build_required(FlowBuildLevel.GRAPH)
    def _get_client(self, **kwargs):
        kwargs.update(self._common_kwargs)
        from ..clients import py_client
        if 'port_expose' not in kwargs:
            kwargs['port_expose'] = self.port_expose
        if 'host' not in kwargs:
            kwargs['host'] = self.host
        return py_client(**kwargs)

    @deprecated_alias(buffer='input_fn', callback='output_fn')
    def train(self,
              input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes],
                              Callable] = None,
              output_fn: Callable[['jina_pb2.Message'], None] = None,
              **kwargs):
        """Do training on the current flow

        It will start a :py:class:`CLIClient` and call :py:func:`train`.

        Example,

        .. highlight:: python
        .. code-block:: python

            with f:
                f.train(input_fn)
                ...


        This will call the pre-built reader to read files into an iterator of bytes and feed them to the flow.

        You may also build a reader/generator on your own.

        Example,

        .. highlight:: python
        .. code-block:: python

            def my_reader():
                for _ in range(10):
                    yield b'abcdfeg'   # each yield generates a document for training

            with f.build(runtime='thread') as flow:
                flow.train(bytes_gen=my_reader())

        :param input_fn: An iterator of bytes. If not given, then you have to specify it in **kwargs**.
        :param output_fn: the callback function to invoke after training
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        self._get_client(**kwargs).train(input_fn, output_fn)

    def index_ndarray(self,
                      array: 'np.ndarray',
                      axis: int = 0,
                      size: int = None,
                      shuffle: bool = False,
                      output_fn: Callable[['jina_pb2.Message'], None] = None,
                      **kwargs):
        """Using numpy ndarray as the index source for the current flow

        :param array: the numpy ndarray data source
        :param axis: iterate over that axis
        :param size: the maximum number of the sub arrays
        :param shuffle: shuffle the numpy data source beforehand
        :param output_fn: the callback function to invoke after indexing
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        from ..clients.python.io import input_numpy
        self._get_client(**kwargs).index(
            input_numpy(array, axis, size, shuffle), output_fn)

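    # Hedged usage sketch: index 100 random 64-d vectors, iterating over
    # axis 0 so each row becomes one document (shapes are illustrative).
    #
    #     import numpy as np
    #     with f:
    #         f.index_ndarray(np.random.random([100, 64]), axis=0)
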
    def search_ndarray(self,
                       array: 'np.ndarray',
                       axis: int = 0,
                       size: int = None,
                       shuffle: bool = False,
                       output_fn: Callable[['jina_pb2.Message'], None] = None,
                       **kwargs):
        """Use a numpy ndarray as the query source for searching on the current flow

        :param array: the numpy ndarray data source
        :param axis: iterate over that axis
        :param size: the maximum number of the sub arrays
        :param shuffle: shuffle the numpy data source beforehand
        :param output_fn: the callback function to invoke after searching
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        from ..clients.python.io import input_numpy
        self._get_client(**kwargs).search(
            input_numpy(array, axis, size, shuffle), output_fn)

    def index_lines(self,
                    lines: Iterator[str] = None,
                    filepath: str = None,
                    size: int = None,
                    sampling_rate: float = None,
                    read_mode='r',
                    output_fn: Callable[['jina_pb2.Message'], None] = None,
                    **kwargs):
        """ Use a list of lines as the index source for indexing on the current flow

        :param lines: a list of strings, each is considered a document
        :param filepath: a text file that each line contains a document
        :param size: the maximum number of the documents
        :param sampling_rate: the sampling rate between [0, 1]
        :param read_mode: specifies the mode in which the file
                is opened. 'r' for reading in text mode, 'rb' for reading in binary mode
        :param output_fn: the callback function to invoke after indexing
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        from ..clients.python.io import input_lines
        self._get_client(**kwargs).index(
            input_lines(lines, filepath, size, sampling_rate, read_mode),
            output_fn)

    def index_files(self,
                    patterns: Union[str, List[str]],
                    recursive: bool = True,
                    size: int = None,
                    sampling_rate: float = None,
                    read_mode: str = None,
                    output_fn: Callable[['jina_pb2.Message'], None] = None,
                    **kwargs):
        """ Use a set of files as the index source for indexing on the current flow

        :param patterns: The pattern may contain simple shell-style wildcards, e.g. '\*.py', ['\*.zip', '\*.gz']
        :param recursive: If recursive is true, the pattern '**' will match any files and
                    zero or more directories and subdirectories.
        :param size: the maximum number of files to take
        :param sampling_rate: the sampling rate between [0, 1]
        :param read_mode: specifies the mode in which the file
                is opened. 'r' for reading in text mode, 'rb' for reading in binary mode
        :param output_fn: the callback function to invoke after indexing
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        from ..clients.python.io import input_files
        self._get_client(**kwargs).index(
            input_files(patterns, recursive, size, sampling_rate, read_mode),
            output_fn)

    def search_files(self,
                     patterns: Union[str, List[str]],
                     recursive: bool = True,
                     size: int = None,
                     sampling_rate: float = None,
                     read_mode: str = None,
                     output_fn: Callable[['jina_pb2.Message'], None] = None,
                     **kwargs):
        """ Use a set of files as the query source for searching on the current flow

        :param patterns: The pattern may contain simple shell-style wildcards, e.g. '\*.py', ['\*.zip', '\*.gz']
        :param recursive: If recursive is true, the pattern '**' will match any files and
                    zero or more directories and subdirectories.
        :param size: the maximum number of files to take
        :param sampling_rate: the sampling rate between [0, 1]
        :param read_mode: specifies the mode in which the file
                is opened. 'r' for reading in text mode, 'rb' for reading in binary mode
        :param output_fn: the callback function to invoke after searching
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        from ..clients.python.io import input_files
        self._get_client(**kwargs).search(
            input_files(patterns, recursive, size, sampling_rate, read_mode),
            output_fn)

    def search_lines(self,
                     filepath: str = None,
                     lines: Iterator[str] = None,
                     size: int = None,
                     sampling_rate: float = None,
                     read_mode='r',
                     output_fn: Callable[['jina_pb2.Message'], None] = None,
                     **kwargs):
        """ Use a list of files as the query source for searching on the current flow

        :param filepath: a text file in which each line is considered a document
        :param lines: a list of strings, each of which is considered a document
        :param size: the maximum number of documents to take
        :param sampling_rate: the sampling rate between [0, 1]
        :param read_mode: specifies the mode in which the file
                is opened. 'r' for reading in text mode, 'rb' for reading in binary mode
        :param output_fn: the callback function to invoke after searching
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        from ..clients.python.io import input_lines
        self._get_client(**kwargs).search(
            input_lines(lines, filepath, size, sampling_rate, read_mode),
            output_fn)

    @deprecated_alias(buffer='input_fn', callback='output_fn')
    def index(self,
              input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes],
                              Callable] = None,
              output_fn: Callable[['jina_pb2.Message'], None] = None,
              **kwargs):
        """Do indexing on the current flow

        Example,

        .. highlight:: python
        .. code-block:: python

            with f:
                f.index(input_fn)
                ...


        This will call the pre-built reader to read files into an iterator of bytes and feed it to the flow.

        You may also build your own reader/generator.

        Example,

        .. highlight:: python
        .. code-block:: python

            def my_reader():
                for _ in range(10):
                    yield b'abcdefg'  # each yield generates a document to index

            with f.build(runtime='thread') as flow:
                flow.index(bytes_gen=my_reader())

        It will start a :py:class:`CLIClient` and call :py:func:`index`.

        :param input_fn: An iterator of bytes. If not given, then you have to specify it in **kwargs**.
        :param output_fn: the callback function to invoke after indexing
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        self._get_client(**kwargs).index(input_fn, output_fn)

    @deprecated_alias(buffer='input_fn', callback='output_fn')
    def search(self,
               input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes],
                               Callable] = None,
               output_fn: Callable[['jina_pb2.Message'], None] = None,
               **kwargs):
        """Do searching on the current flow

        It will start a :py:class:`CLIClient` and call :py:func:`search`.


        Example,

        .. highlight:: python
        .. code-block:: python

            with f:
                f.search(input_fn)
                ...


        This will call the pre-built reader to read files into an iterator of bytes and feed it to the flow.

        You may also build your own reader/generator.

        Example,

        .. highlight:: python
        .. code-block:: python

            def my_reader():
                for _ in range(10):
                    yield b'abcdefg'   # each yield generates a query for searching

            with f.build(runtime='thread') as flow:
                flow.search(bytes_gen=my_reader())

        :param input_fn: An iterator of bytes. If not given, then you have to specify it in **kwargs**.
        :param output_fn: the callback function to invoke after searching
        :param kwargs: accepts all keyword arguments of `jina client` CLI
        """
        self._get_client(**kwargs).search(input_fn, output_fn)

    def dry_run(self, **kwargs):
        """Send a DRYRUN request to this flow, passing through all pods in this flow,
        useful for testing connectivity and debugging"""
        self.logger.warning(
            'calling dry_run() on a flow is depreciated, it will be removed in the future version. '
            'calling index(), search(), train() will trigger a dry_run()')

    @build_required(FlowBuildLevel.GRAPH)
    def to_swarm_yaml(self, path: TextIO):
        """
        Generate the Docker Swarm YAML compose file
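
        Example (a minimal sketch, assuming ``f`` has been built to the graph level),

        .. highlight:: python
        .. code-block:: python

            with open('docker-compose.yml', 'w') as fp:
                f.to_swarm_yaml(fp)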

        :param path: the output yaml path
        """
        swarm_yml = {'version': '3.4', 'services': {}}

        for k, v in self._pod_nodes.items():
            swarm_yml['services'][k] = {
                'command': v.to_cli_command(),
                'deploy': {
                    'replicas': 1  # compose v3 uses 'replicas' under 'deploy'
                }
            }

        yaml.dump(swarm_yml, path)

    @property
    @build_required(FlowBuildLevel.GRAPH)
    def port_expose(self):
        """The exposed port of this flow's gateway"""
        return self._pod_nodes['gateway'].port_expose

    @property
    @build_required(FlowBuildLevel.GRAPH)
    def host(self):
        """The host address of this flow's gateway"""
        return self._pod_nodes['gateway'].host

    def __iter__(self):
        """Iterate over all pods in this flow"""
        return iter(self._pod_nodes.values())

    def block(self):
        """Block the process until user hits KeyboardInterrupt """
        try:
            self.logger.success(
                f'flow is started at {self.host}:{self.port_expose}, '
                f'you can now use a client to send requests!')
            threading.Event().wait()
        except KeyboardInterrupt:
            pass

    def use_grpc_gateway(self):
        """Change to use gRPC gateway for IO """
        self._common_kwargs['rest_api'] = False

    def use_rest_gateway(self):
        """Change to use REST gateway for IO """
        self._common_kwargs['rest_api'] = True