Пример #1
0
class DeviceScope(object):
    """Context manager that enters a configurable set of builder scopes.

    On ``__enter__`` it enters, in order: the virtual-graph scope (when
    ``virtualGraph`` is not None), the pipeline-stage scope (only when
    ``pipeline`` is truthy and ``pipelineStage`` is not None), the name scope
    (when ``nameScope`` is not None), and then every context manager in
    ``additional_scopes``. ``__exit__`` unwinds them in reverse order.

    Args:
        builder: Object whose ``virtualGraph``, ``pipelineStage`` and
            ``nameScope`` methods return context managers.
        pipeline: Whether the pipeline-stage scope may be applied.
        virtualGraph: Value for ``builder.virtualGraph``; skipped when None.
        pipelineStage: Value for ``builder.pipelineStage``; skipped when None.
        nameScope: Value for ``builder.nameScope``; skipped when None.
        additional_scopes: Extra context managers entered after the others.
    """

    def __init__(self,
                 builder,
                 pipeline=False,
                 virtualGraph=None,
                 pipelineStage=None,
                 nameScope=None,
                 additional_scopes=None):
        self.builder = builder
        self.pipeline = pipeline
        self.virtualGraph = virtualGraph
        self.pipelineStage = pipelineStage
        self.nameScope = nameScope
        self.additional_scopes = additional_scopes or []

    def __enter__(self):
        self.stack = ExitStack()
        try:
            if self.virtualGraph is not None:
                self.stack.enter_context(
                    self.builder.virtualGraph(self.virtualGraph))

            if self.pipeline and self.pipelineStage is not None:
                self.stack.enter_context(
                    self.builder.pipelineStage(self.pipelineStage))

            if self.nameScope is not None:
                self.stack.enter_context(
                    self.builder.nameScope(self.nameScope))
            for scope in self.additional_scopes:
                self.stack.enter_context(scope)
        except BaseException:
            # __exit__ is never called when __enter__ raises, so unwind the
            # scopes that were already entered before propagating the error.
            self.stack.close()
            raise
        return self

    def __exit__(self, *exp):
        self.stack.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False
Пример #2
0
 def _logged_func(self):
     """Run ``self.func`` (repeatedly while ``self.loop``) with logging and cleanup.

     The value of the most recent call is stored in ``self._result``.
     ``ProcessExiting`` is logged at debug level and re-raised.
     ``KeyboardInterrupt`` is logged, stored in ``self.exc`` via
     ``apply_timestamp`` and re-raised only when ``IS_GEVENT`` is true. Any
     other ``Exception`` is logged and stored in ``self.exc`` without being
     re-raised. ``self.stop`` and ``self.timer.stop`` are always called on
     the way out.
     """
     stack = ExitStack()
     self.exc = None
     self.timer = Timer()
     # ExitStack runs callbacks LIFO on close(): self.stop, then timer.stop.
     stack.callback(self.timer.stop)
     stack.callback(self.stop)
     try:
         if not self.console_logging:
             # Presumably silences _logger output; exited first on close().
             stack.enter_context(_logger.suppressed())
         _logger.debug("%s - starting", self)
         while True:
             self._result = self.func(*self.args, **self.kwargs)
             if not self.loop:
                 return
             # A true result from self.wait() ends the loop (presumably a
             # stop request observed while sleeping -- TODO confirm).
             if self.wait(self.sleep):
                 _logger.debug("%s - stopped", self)
                 return
     except ProcessExiting as exc:
         _logger.debug(exc)
         raise
     except KeyboardInterrupt as exc:
         _logger.silent_exception("KeyboardInterrupt in thread running %s:",
                                  self.func)
         self.exc = apply_timestamp(exc)
         if IS_GEVENT:
             raise  # in gevent we should let this exception propagate to the main greenlet
     except Exception as exc:
         _logger.silent_exception(
             "Exception in thread running %s: %s (traceback can be found in debug-level logs)",
             self.func, type(exc))
         self.exc = apply_timestamp(exc)
     finally:
         # Runs on every exit path, including the early returns above.
         stack.close()
Пример #3
0
def _write_video(
    stack: ExitStack,
    max_threads_semaphore: Semaphore,
    video: VideoClip,
    filename: str,
    codec: str,
    fps: float,
    ext: str,
):
    """Write ``video`` to ``filename`` while holding one semaphore permit.

    Whatever the outcome, ``stack`` is closed (it holds the video files opened
    for this round) and the permit is released.

    Args:
        stack: ExitStack owning the resources to release after writing.
        max_threads_semaphore: Semaphore bounding concurrent writers; one
            permit is held for the duration of this call.
        video: Clip whose ``write_videofile`` is called.
        filename: Output path.
        codec: Codec passed to ``write_videofile``.
        fps: Frame rate passed to ``write_videofile``.
        ext: Not used by this function; kept for interface compatibility.

    Returns:
        ``filename`` on success, or None if writing failed. On failure the
        partial output file and the temporary ``...TEMP_MPY_wvf_snd.mp3``
        audio file next to it are removed if present.
    """
    max_threads_semaphore.acquire()
    try:
        video.write_videofile(
            filename,
            fps=fps,
            codec=codec,
            preset=FFMPEG_PRESET,
        )
    except Exception as e:
        print("\r\nVideo (%s) failed to write: maybe not enough memory/disk" %
              filename)
        print(e)
        if os.path.exists(filename):
            os.remove(filename)
        # Strip the real extension instead of assuming it is 4 characters.
        temp_mp3_filename = (os.path.splitext(filename)[0] +
                             "TEMP_MPY_wvf_snd.mp3")
        if os.path.exists(temp_mp3_filename):
            os.remove(temp_mp3_filename)
        filename = None
    finally:
        try:
            stack.close()  # close all video files for this round
        finally:
            # Release even if closing a file raised; otherwise the permit is
            # lost for good and later writers can block forever.
            max_threads_semaphore.release()
    return filename
Пример #4
0
class DeviceScope(object):
    """Context manager that enters the requested builder scopes.

    On ``__enter__`` it enters, in order and only for arguments that are not
    None: the virtual-graph scope, the pipeline-stage scope and the name
    scope. ``__exit__`` unwinds them in reverse order.

    Args:
        builder: Object whose ``virtualGraph``, ``pipelineStage`` and
            ``nameScope`` methods return context managers.
        virtualGraph: Value for ``builder.virtualGraph``; skipped when None.
        pipelineStage: Value for ``builder.pipelineStage``; skipped when None.
        nameScope: Value for ``builder.nameScope``; skipped when None.
    """

    def __init__(self,
                 builder,
                 virtualGraph=None,
                 pipelineStage=None,
                 nameScope=None):
        self.builder = builder
        self.virtualGraph = virtualGraph
        self.pipelineStage = pipelineStage
        self.nameScope = nameScope

    def __enter__(self):
        self.stack = ExitStack()
        try:
            if self.virtualGraph is not None:
                self.stack.enter_context(
                    self.builder.virtualGraph(self.virtualGraph))
            if self.pipelineStage is not None:
                self.stack.enter_context(
                    self.builder.pipelineStage(self.pipelineStage))
            if self.nameScope is not None:
                self.stack.enter_context(
                    self.builder.nameScope(self.nameScope))
        except BaseException:
            # __exit__ is never called when __enter__ raises, so unwind the
            # scopes that were already entered before propagating the error.
            self.stack.close()
            raise
        return self

    def __exit__(self, *exp):
        self.stack.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False
Пример #5
0
 def _logged_func(self):
     """Run ``self.func`` (repeatedly while ``self.loop``) with logging and cleanup.

     The value of the most recent call is stored in ``self._result``.
     ``ProcessExiting`` is logged at debug level and re-raised. Both
     ``KeyboardInterrupt`` and any other ``Exception`` are logged, stored in
     ``self.exc`` and NOT re-raised. ``self.stop`` and ``self.timer.stop``
     are always called on the way out.
     """
     stack = ExitStack()
     self.exc = None
     self.timer = Timer()
     # ExitStack runs callbacks LIFO on close(): self.stop, then timer.stop.
     stack.callback(self.timer.stop)
     stack.callback(self.stop)
     try:
         if not self.console_logging:
             # Presumably silences _logger output; exited first on close().
             stack.enter_context(_logger.suppressed())
         _logger.debug("%s - starting", self)
         while True:
             self._result = self.func(*self.args, **self.kwargs)
             if not self.loop:
                 return
             # A true result from self.wait() ends the loop (presumably a
             # stop request observed while sleeping -- TODO confirm).
             if self.wait(self.sleep):
                 _logger.debug("%s - stopped", self)
                 return
     except ProcessExiting as exc:
         _logger.debug(exc)
         raise
     except (KeyboardInterrupt, Exception) as exc:
         _logger.silent_exception(
             "Exception in thread running %s: %s (traceback can be found in debug-level logs)",
             self.func, type(exc))
         self.exc = exc
         # Best effort: record when the failure happened; ignore objects that
         # reject the new attribute.
         try:
             exc.timestamp = time.time()
         except Exception:
             pass
     finally:
         # Runs on every exit path, including the early returns above.
         stack.close()
Пример #6
0
def combine_contexts(*managers: ItemOrList[ContextManager]):
    """Enter every given context manager and yield the combined ExitStack.

    ``managers`` is flattened with ``iter_list_items`` (assumed to accept
    single managers or lists of them, per the ``ItemOrList`` annotation) and
    entered in order. Everything entered so far is exited in reverse order
    if entering a later manager fails, and all of them are exited when the
    generator resumes, is closed, or has an exception thrown into it.
    """
    managers: List[ContextManager] = list(iter_list_items(*managers))
    # A ``with`` block (rather than a trailing close()) guarantees cleanup on
    # every path, including exceptions raised at the ``yield``.
    with ExitStack() as stack:
        for manager in managers:
            stack.enter_context(manager)
        yield stack
Пример #7
0
class TestLoadChannelOverHTTPS(unittest.TestCase):
    """channels.json MUST be downloaded over HTTPS.

    Start an HTTP server (and no HTTPS server) to show that the download fails.
    """
    @classmethod
    def setUpClass(cls):
        SystemImagePlugin.controller.set_mode(cert_pem='cert.pem')

    def setUp(self):
        # Resources acquired here are released by tearDown() via this stack.
        self._stack = ExitStack()
        try:
            self._serverdir = self._stack.enter_context(temporary_directory())
            copy('channel.channels_01.json', self._serverdir, 'channels.json')
            sign(os.path.join(self._serverdir, 'channels.json'),
                 'image-signing.gpg')
        except BaseException:
            # tearDown() is not run when setUp() fails, so release whatever
            # was acquired so far before re-raising.
            self._stack.close()
            raise

    def tearDown(self):
        self._stack.close()

    @configuration
    def test_load_channel_over_https_port_with_http_fails(self):
        # We maliciously put an HTTP server on the HTTPS port.
        setup_keyrings()
        state = State()
        # Try to get the blacklist.  This will fail silently since it's okay
        # not to find a blacklist.
        state.run_thru('get_blacklist_1')
        # This will fail to get the channels.json file.
        with make_http_server(self._serverdir, 8943):
            self.assertRaises(FileNotFoundError, next, state)
Пример #8
0
class TestAdapter(Adapter[M]):
    """In-memory adapter that records acks/nacks and feeds a processor.

    The processing context is created lazily on the first ``send`` (using the
    type of that first message) and stays open until ``close``.
    """

    # Tell pytest not to try and collect this class
    __test__ = False

    def __init__(self, processor: "Processor[M]"):
        self.processor = processor
        self.acked: List[M] = []
        self.nacked: List[M] = []
        # Owns the processing context opened by the first send().
        self.stack = ExitStack()
        self.ctx: Optional[ProcessingContext[M]] = None

    def ack(self, message: M) -> None:
        """Record ``message`` as acknowledged."""
        self.acked.append(message)

    def nack(self, message: M) -> None:
        """Record ``message`` as negatively acknowledged."""
        self.nacked.append(message)

    def send(self, message: M) -> None:
        """Hand ``message`` to the processing context, opening it if needed."""
        context = self.ctx
        if context is None:
            factory = self.processor.context(type(message), self)
            context = self.stack.enter_context(factory)
            self.ctx = context
        context.handle(message)

    def close(self) -> None:
        """Exit the processing context opened by ``send`` (if any)."""
        self.stack.close()
Пример #9
0
class DeviceScope:
    """Enter a builder's virtualGraph and pipelineStage scopes together.

    ``pattern`` is ``"<ipu_id>"`` or ``"<ipu_id>_<pipeline_id>"``; when the
    pipeline id is omitted it defaults to 0. Any other shape raises
    ``RuntimeError`` on ``__enter__``.
    """

    # Largest IPU id seen so far across all instances (the maximum id used,
    # not id + 1).
    IPUCount = 0

    def __init__(self, pattern):
        self.builder = bF.get_builder()
        self.pattern = pattern
        self.stack = ExitStack()

    def __enter__(self):
        strs = self.pattern.split("_")
        if len(strs) == 2:
            ipu_id, pipeline_id = strs
        elif len(strs) == 1:
            ipu_id, pipeline_id = strs[0], '0'
        else:
            raise RuntimeError('unknown input')
        ipu_id, pipeline_id = int(ipu_id), int(pipeline_id)
        DeviceScope.IPUCount = max(DeviceScope.IPUCount, ipu_id)

        try:
            self.stack.enter_context(self.builder.virtualGraph(ipu_id))
            self.stack.enter_context(self.builder.pipelineStage(pipeline_id))
        except BaseException:
            # __exit__ is never called when __enter__ raises, so leave the
            # virtual graph scope here if the pipeline stage fails.
            self.stack.close()
            raise
        return self

    def __exit__(self, *p):
        self.stack.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False
Пример #10
0
def get_temporary_directory() -> Generator[Path, None, None]:
    """Yield a freshly created temporary directory as a ``Path``.

    The directory and its contents are removed once the generator is
    resumed, closed, or has an exception thrown into it.
    """
    with TemporaryDirectory() as tmp_name:
        yield Path(tmp_name)
Пример #11
0
class BaseGethChain(BaseChain):
    """Chain backed by a locally managed geth process (deprecated).

    Subclasses provide the process via ``get_geth_process_instance``.
    Entering the chain enters that process's context and waits until it is
    reachable over the enabled transports.
    """
    stack = None
    geth = None

    def __init__(self, *args, **kwargs):
        warn_msg = (
            "Support for this chain will be dropped in the next populus version. "
            "Populus will not run the chains, and will use the better and more robust "
            "Web3.py providers directly, as ExternalChain. "
            "Please configure your chains as ExternalChain.")
        # Force this DeprecationWarning to be shown, then restore the
        # caller's warning filters (warnings.resetwarnings() would wipe every
        # filter configured by the application or test runner).
        with warnings.catch_warnings():
            warnings.simplefilter('always', DeprecationWarning)
            warnings.warn(warn_msg, DeprecationWarning)
        super(BaseGethChain, self).__init__(*args, **kwargs)

    def initialize_chain(self):
        # context manager shenanigans
        self.stack = ExitStack()
        self.geth = self.get_geth_process_instance()

    def get_web3_config(self):
        """Return the base web3 config, filling provider settings from geth.

        Only applies when no ``provider.settings`` are configured: IPC
        providers get geth's ``ipc_path``, HTTP providers get a localhost URI
        on geth's ``rpc_port``; any other provider class raises ValueError.
        """
        base_config = super(BaseGethChain, self).get_web3_config()
        config = copy.deepcopy(base_config)
        if not config.get('provider.settings'):
            if issubclass(base_config.provider_class, IPCProvider):
                config['provider.settings.ipc_path'] = self.geth.ipc_path
            elif issubclass(base_config.provider_class, HTTPProvider):
                config[
                    'provider.settings.endpoint_uri'] = "http://127.0.0.1:{0}".format(
                        self.geth.rpc_port, )
            else:
                raise ValueError("Unknown provider type")
        return config

    @property
    def geth_kwargs(self):
        return self.config.get('chain.settings', {})

    def get_geth_process_instance(self):
        raise NotImplementedError("Must be implemented by subclasses")

    def __enter__(self, *args, **kwargs):
        self.stack.enter_context(self.geth)
        try:
            if self.geth.is_mining:
                self.geth.wait_for_dag(600)
            if self.geth.ipc_enabled:
                self.geth.wait_for_ipc(60)
            if self.geth.rpc_enabled:
                self.geth.wait_for_rpc(60)
        except BaseException:
            # __exit__ is not called when __enter__ raises (e.g. a wait times
            # out), so exit the geth context here instead of presumably
            # leaving the process running.
            self.stack.close()
            raise

        self._running = True

        return self

    def __exit__(self, *exc_info):
        self.stack.close()
        self._running = False
Пример #12
0
class Fixture(object):
    """Holds a goal, a workspace and an ExitStack of fixtures to release.

    Use it as a context manager: everything registered on ``self.fixtures``
    is closed when the ``with`` block exits.
    """

    def __init__(self, goal, workspace):
        self.fixtures = ExitStack()
        self.goal = goal
        self.workspace = workspace

    def __enter__(self):
        # A class defining __exit__ also needs __enter__ to be usable in a
        # ``with`` statement.
        return self

    def __exit__(self, *args):
        self.fixtures.close()
Пример #13
0
 class StreamDecoder:
     """Incrementally decode an escaped byte stream into ``file``.

     Encoded data is pushed in with ``feed``; decoding is done by the
     ``_receive`` generator, driven through a ``PipeWriter``. Each decoded
     byte is written to ``file`` and folded into a running CRC-32 exposed by
     ``getCrc32``. The mapping (subtract 42, and an extra 64 after an ``=``
     escape) looks like yEnc decoding -- NOTE(review): confirm.
     """
     def __init__(self, file):
         self._file = file
         self._crc = 0  # running CRC-32 of all decoded bytes
         self._pipe = PipeWriter()
         self._cleanup = ExitStack()
         # Wrap the decoding generator with the pipe; the resulting context
         # is exited by close() through self._cleanup.
         coroutine = self._pipe.coroutine(self._receive())
         self._cleanup.enter_context(coroutine)

     def close(self):
         """Close the pipe, drop it, then exit the decoding coroutine context."""
         self._pipe.close()
         del self._pipe
         self._cleanup.close()

     def feed(self, data):
         """Push a chunk of encoded bytes into the decoder."""
         self._pipe.write(data)

     def _receive(self):
         """Decoding generator driven by the ``PipeWriter``.

         Decodes ``self._pipe.buffer`` up to the next ``=`` escape, then
         either decodes the escaped byte or suspends until a new buffer is
         sent in. Stops when ``EOFError`` is thrown in at the suspend point.
         """
         while True:
             data = self._pipe.buffer
             pos = data.find(b"=")
             if pos >= 0:
                 # Only decode the plain run before the escape character.
                 data = data[:pos]
             # Line breaks are not payload; drop them.
             data = data.replace(b"\r", b"").replace(b"\n", b"")
             # TABLE subtracts 42 (mod 256) from every byte.
             data = data.translate(self.TABLE)
             # TODO: check data size overflow
             self._crc = crc32(data, self._crc)
             self._file.write(data)
             if pos >= 0:  # Escape character (equals sign)
                 # Drop the decoded run and the "=" itself from the buffer.
                 self._pipe.buffer = self._pipe.buffer[pos + 1:]
                 # The escaped value is the next byte that is not CR/LF
                 # (read_one is assumed to yield a length-1 bytes object).
                 while True:
                     byte = yield from self._pipe.read_one()
                     if byte not in b"\r\n":
                         break
                 # TODO: check for size overflow
                 # Unpack the single-byte bytes object into its int value.
                 [byte] = byte
                 # Escaped bytes carry an extra 64 on top of the usual 42;
                 # bitmask(8) presumably yields 0xFF -- TODO confirm.
                 data = bytes(((byte - 64 - 42) & bitmask(8),))
                 self._crc = crc32(data, self._crc)
                 self._file.write(data)
             else:
                 try:
                     # Nothing pending: suspend until the next buffer is
                     # sent into this generator.
                     self._pipe.buffer = yield
                 except EOFError:
                     # EOFError thrown in (presumably at end of input).
                     break

     def flush(self):
         """Intentionally does nothing."""
         pass

     def getCrc32(self):
         """Return the running CRC-32 as zero-padded lowercase hex (8 digits)."""
         return format(self._crc, "08x")

     # Translation table mapping each byte value b to (b - 42) % 256.
     TABLE = bytes(range(256))
     TABLE = TABLE[-42:] + TABLE[:-42]
Пример #14
0
class ScheduleEvaluationContext:
    """Schedule-specific execution context.

    An instance of this class is made available as the first argument to various ScheduleDefinition
    functions. It is passed as the first argument to ``run_config_fn``, ``tags_fn``,
    and ``should_execute``.

    Attributes:
        instance_ref (Optional[InstanceRef]): The serialized instance configured to run the schedule
        scheduled_execution_time (datetime):
            The time at which the execution was scheduled to happen. May differ slightly
            from both the actual execution time and the time at which the run config is computed.
            Not available in all schedulers - currently only set in deployments using
            DagsterDaemonScheduler.
    """

    __slots__ = [
        "_instance_ref", "_scheduled_execution_time", "_exit_stack",
        "_instance"
    ]

    def __init__(self, instance_ref: Optional[InstanceRef],
                 scheduled_execution_time: Optional[datetime]):
        # Owns the lazily created DagsterInstance; closed by __exit__.
        self._exit_stack = ExitStack()
        self._instance = None

        self._instance_ref = check.opt_inst_param(instance_ref, "instance_ref",
                                                  InstanceRef)
        self._scheduled_execution_time = check.opt_inst_param(
            scheduled_execution_time, "scheduled_execution_time", datetime)

    def __enter__(self):
        return self

    def __exit__(self, _exception_type, _exception_value, _traceback):
        self._exit_stack.close()

    @property
    def instance(self) -> "DagsterInstance":
        """The DagsterInstance, created from ``instance_ref`` on first access.

        Raises:
            DagsterInvariantViolationError: if no instance reference was given.
        """
        # self._instance_ref should only ever be None when this ScheduleEvaluationContext was
        # constructed under test.
        if self._instance_ref is None:
            raise DagsterInvariantViolationError(
                "Attempted to initialize dagster instance, but no instance reference was provided."
            )
        # Compare against None: a cached instance must never be rebuilt just
        # because the object happens to evaluate as falsy.
        if self._instance is None:
            self._instance = self._exit_stack.enter_context(
                DagsterInstance.from_ref(self._instance_ref))
        return cast(DagsterInstance, self._instance)

    @property
    def scheduled_execution_time(self) -> Optional[datetime]:
        return self._scheduled_execution_time
Пример #15
0
    class StreamDecoder:
        """Incrementally decode an escaped byte stream into ``file``.

        Encoded data is pushed in with ``feed``; decoding is done by the
        ``_receive`` generator, driven through a ``PipeWriter``. Each decoded
        byte is written to ``file`` and folded into a running CRC-32 exposed
        by ``getCrc32``. The mapping (subtract 42, and an extra 64 after an
        ``=`` escape) looks like yEnc decoding -- NOTE(review): confirm.
        """
        def __init__(self, file):
            self._file = file
            self._crc = 0  # running CRC-32 of all decoded bytes
            self._pipe = PipeWriter()
            self._cleanup = ExitStack()
            # Wrap the decoding generator with the pipe; the resulting
            # context is exited by close() through self._cleanup.
            coroutine = self._pipe.coroutine(self._receive())
            self._cleanup.enter_context(coroutine)

        def close(self):
            """Close the pipe, drop it, then exit the decoding coroutine context."""
            self._pipe.close()
            del self._pipe
            self._cleanup.close()

        def feed(self, data):
            """Push a chunk of encoded bytes into the decoder."""
            self._pipe.write(data)

        def _receive(self):
            """Decoding generator driven by the ``PipeWriter``.

            Decodes ``self._pipe.buffer`` up to the next ``=`` escape, then
            either decodes the escaped byte or suspends until a new buffer is
            sent in. Stops when ``EOFError`` is thrown in at the suspend point.
            """
            while True:
                data = self._pipe.buffer
                pos = data.find(b"=")
                if pos >= 0:
                    # Only decode the plain run before the escape character.
                    data = data[:pos]
                # Line breaks are not payload; drop them.
                data = data.replace(b"\r", b"").replace(b"\n", b"")
                # TABLE subtracts 42 (mod 256) from every byte.
                data = data.translate(self.TABLE)
                # TODO: check data size overflow
                self._crc = crc32(data, self._crc)
                self._file.write(data)
                if pos >= 0:  # Escape character (equals sign)
                    # Drop the decoded run and the "=" itself from the buffer.
                    self._pipe.buffer = self._pipe.buffer[pos + 1:]
                    # The escaped value is the next byte that is not CR/LF
                    # (read_one is assumed to yield a length-1 bytes object).
                    while True:
                        byte = yield from self._pipe.read_one()
                        if byte not in b"\r\n":
                            break
                    # TODO: check for size overflow
                    # Unpack the single-byte bytes object into its int value.
                    [byte] = byte
                    # Escaped bytes carry an extra 64 on top of the usual 42;
                    # bitmask(8) presumably yields 0xFF -- TODO confirm.
                    data = bytes(((byte - 64 - 42) & bitmask(8), ))
                    self._crc = crc32(data, self._crc)
                    self._file.write(data)
                else:
                    try:
                        # Nothing pending: suspend until the next buffer is
                        # sent into this generator.
                        self._pipe.buffer = yield
                    except EOFError:
                        # EOFError thrown in (presumably at end of input).
                        break

        def flush(self):
            """Intentionally does nothing."""
            pass

        def getCrc32(self):
            """Return the running CRC-32 as zero-padded lowercase hex (8 digits)."""
            return format(self._crc, "08x")

        # Translation table mapping each byte value b to (b - 42) % 256.
        TABLE = bytes(range(256))
        TABLE = TABLE[-42:] + TABLE[:-42]
Пример #16
0
class SandboxProcessServer:
    """Server side of a sandboxed process, reached through a ``PipeBoundary``.

    ``run`` starts ``executable`` on the process channel and then keeps
    answering requests from the request queue until ``handle_request`` is
    called with ``wait="1"``.
    """

    def __init__(self, *, sandbox_dir, executable):
        self.executable = executable

        self.boundary = PipeBoundary(sandbox_dir)
        self.boundary.create_channel(SANDBOX_PROCESS_CHANNEL)
        self.boundary.create_queue(SANDBOX_REQUEST_QUEUE)

        self.done = False
        self.process = None

        # Owns the running process context; closed by the final request.
        self.process_exit_stack = ExitStack()

    def run(self):
        """Start the sandboxed process, then serve requests until done."""
        logger.debug("starting process...")

        channel = self.boundary.open_channel(SANDBOX_PROCESS_CHANNEL,
                                             PipeBoundarySide.SERVER)
        with channel as pipes:
            process_cm = self.executable.run(SandboxProcessConnection(**pipes))
            self.process = self.process_exit_stack.enter_context(process_cm)

        logger.debug("process started")

        while not self.done:
            logger.debug("handling requests...")
            self.boundary.handle_request(SANDBOX_REQUEST_QUEUE,
                                         self.handle_request)

    def handle_request(self, *, wait):
        """Report resource usage; with ``wait == "1"`` also finish the process.

        Returns a dict of strings with keys ``error``, ``stacktrace``,
        ``time_usage`` and ``memory_usage``. ``error``/``stacktrace`` come
        from an ``AlgorithmRuntimeError`` raised while closing the process
        context, and are empty otherwise.
        """
        assert not self.done
        assert wait in ("0", "1")

        process = self.process
        time_usage = self.executable.get_time_usage(process)
        memory_usage = self.executable.get_memory_usage(process)

        error_message = ""
        error_stacktrace = ""
        if wait == "1":
            self.done = True
            self.process = None
            try:
                self.process_exit_stack.close()
            except AlgorithmRuntimeError as err:
                error_message, error_stacktrace = err.args

        return {
            "error": error_message,
            "stacktrace": error_stacktrace,
            "time_usage": str(time_usage),
            "memory_usage": str(memory_usage),
        }
Пример #17
0
class SensorExecutionContext:
    """Sensor execution context.

    An instance of this class is made available as the first argument to the evaluation function
    on SensorDefinition.

    Attributes:
        instance_ref (InstanceRef): The serialized instance configured to run the sensor
        last_completion_time (float): The last time that the sensor was evaluated (UTC).
        last_run_key (str): The run key of the RunRequest most recently created by this sensor.
    """

    __slots__ = [
        "_instance_ref",
        "_last_completion_time",
        "_last_run_key",
        "_exit_stack",
        "_instance",
    ]

    def __init__(self, instance_ref, last_completion_time, last_run_key):
        # Owns the lazily created DagsterInstance; closed by __exit__.
        self._exit_stack = ExitStack()
        self._instance = None

        self._instance_ref = check.inst_param(instance_ref, "instance_ref",
                                              InstanceRef)
        self._last_completion_time = check.opt_float_param(
            last_completion_time, "last_completion_time")
        self._last_run_key = check.opt_str_param(last_run_key, "last_run_key")

    def __enter__(self):
        return self

    def __exit__(self, _exception_type, _exception_value, _traceback):
        self._exit_stack.close()

    @property
    def instance(self):
        """The DagsterInstance, created from ``instance_ref`` on first access."""
        # Compare against None: a cached instance must never be rebuilt just
        # because the object happens to evaluate as falsy.
        if self._instance is None:
            self._instance = self._exit_stack.enter_context(
                DagsterInstance.from_ref(self._instance_ref))
        return self._instance

    @property
    def last_completion_time(self):
        return self._last_completion_time

    @property
    def last_run_key(self):
        return self._last_run_key
Пример #18
0
class PulpWritter(object):
    """Use this to create a pulp db.

    Any existing directory at ``db_name`` is deleted and recreated (with a
    ``keys`` subdirectory). Use as a context manager: entering opens the
    master table, and every table added through ``add_table`` is exited when
    the ``with`` block ends.
    """
    def __init__(self, db_name, msg_dumper=None, idx_dumpers=None):
        self.dir_path = os.path.abspath(db_name)
        self.keys_path = os.path.join(self.dir_path, 'keys')
        if os.path.isdir(self.dir_path):
            shutil.rmtree(self.dir_path)
        os.makedirs(self.dir_path)
        os.makedirs(self.keys_path)

        self.master_table = None
        self.key_tables = {}
        self.table_stack = None
        self.msg_dumper = msg_dumper
        self.idx_dumpers = idx_dumpers

    def __enter__(self):
        self.table_stack = ExitStack()
        table = MasterTable(self.dir_path, 'w', dumper=self.msg_dumper)
        self.add_table(None, table)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.table_stack is not None:
            self.table_stack.close()

    def add_table(self, key, table):
        """Register ``table`` (as the master table when ``key`` is None) and enter it."""
        if key is None:
            self.master_table = table
        else:
            self.key_tables[key] = table
        self.table_stack.enter_context(table)

    def append(self, data, index_map):
        """Append ``data`` to the master table and index it by ``index_map``.

        ``index_map`` maps a key to one value or to a tuple/list/set of
        values; each value is appended to that key's table together with the
        message number returned by the master table. A key's table is
        created (and entered) the first time the key is seen.
        """
        msg_num = self.master_table.append(data)
        for key, value in index_map.items():
            table = self.key_tables.get(key)
            # ``is None`` rather than truthiness: an existing table that is
            # empty (falsy) must be reused, not silently replaced.
            if table is None:
                if self.idx_dumpers is not None:
                    dumper = self.idx_dumpers.get(key)
                else:
                    dumper = None
                table = KeyTable(self.keys_path, key, 'w', dumper=dumper)
                self.add_table(key, table)

            values = value if isinstance(value, (tuple, list, set)) else (value,)
            for v in values:
                table.append(v, msg_num)
Пример #19
0
class ZMQComm:
    """ZeroMQ communicator: one bound PULL socket plus PUSH senders.

    Entering binds a PULL socket to ``addr``; ``connect`` opens PUSH sockets
    to peers. The context and every socket are closed on exit. ``poll`` and
    ``recv`` are coroutines and presumably rely on ``zmq_context()`` giving
    asyncio-aware sockets -- TODO confirm.
    """
    def __init__(self, addr):
        self.addr = addr

    def __enter__(self):
        self.exit_stack = ExitStack()
        try:
            self.ctx = self.exit_stack.enter_context(zmq_context())
            self.recv_socket = self.exit_stack.enter_context(
                closing(self.ctx.socket(zmq.PULL)))
            # self.recv_socket.setsockopt(zmq.SUBSCRIBE,b'')
            self.recv_socket.bind(self.addr)
        except BaseException:
            # __exit__ is not called when __enter__ raises (e.g. bind fails),
            # so close the socket and context that were already opened.
            self.exit_stack.close()
            raise

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.exit_stack.close()

    def connect(self, addr):
        """Open a PUSH socket connected to ``addr``; it is closed on exit."""
        send_socket = self.exit_stack.enter_context(
            closing(self.ctx.socket(zmq.PUSH)))
        # Do not keep unsent messages around once the socket is closed.
        send_socket.setsockopt(zmq.LINGER, 0)
        send_socket.connect(addr)
        return send_socket

    def disconnect(self, socket):
        pass
        # TODO:???
        # socket.close()

    def send(self, socket, data):
        """Send ``data`` as a single-frame multipart message."""
        # TODO: allow multipart messages
        socket.send_multipart([data])

    # TODO: remove once client has its own thread
    async def poll(self):
        """Return the receive socket's poll result without waiting."""
        return await self.recv_socket.poll(timeout=0)

    async def recv(self):
        """Return the first frame of the next message, or None.

        None is returned when nothing arrives within the 2 ms poll timeout or
        when the first frame is empty.
        """
        # TODO: make 2 milliseconds a config parameter
        if (await self.recv_socket.poll(timeout=2)) == 0:
            return None
        full_msg = await self.recv_socket.recv_multipart()
        msg = full_msg[0]
        if len(msg) == 0:
            return None
        return msg
Пример #20
0
class BaseGethChain(Chain):
    """Chain backed by a locally managed geth process.

    Subclasses provide the process via ``get_geth_process_instance``.
    Entering the chain enters that process's context and waits until it is
    reachable over the enabled transports.
    """
    stack = None
    geth = None

    def initialize_chain(self):
        # context manager shenanigans
        self.stack = ExitStack()
        self.geth = self.get_geth_process_instance()

    @property
    def geth_kwargs(self):
        return self.config.get('chain.settings', {})

    @property
    def has_registrar(self):
        return 'registrar' in self.config

    @cached_property
    def registrar(self):
        """Registrar at the configured address; KeyError if none is configured."""
        if not self.has_registrar:
            raise KeyError(
                "The configuration for the {0} chain does not include a "
                "registrar.  Please set this value to the address of the "
                "deployed registrar contract.".format(self.chain_name))
        return get_registrar(
            self.web3,
            address=self.config['registrar'],
        )

    def get_geth_process_instance(self):
        raise NotImplementedError("Must be implemented by subclasses")

    def __enter__(self, *args, **kwargs):
        self.stack.enter_context(self.geth)
        try:
            if self.geth.is_mining:
                self.geth.wait_for_dag(600)
            if self.geth.ipc_enabled:
                self.geth.wait_for_ipc(60)
            if self.geth.rpc_enabled:
                self.geth.wait_for_rpc(60)
        except BaseException:
            # __exit__ is not called when __enter__ raises (e.g. a wait times
            # out), so exit the geth context here instead of presumably
            # leaving the process running.
            self.stack.close()
            raise

        self._running = True

        return self

    def __exit__(self, *exc_info):
        self.stack.close()
        self._running = False
Пример #21
0
class BaseGethChain(BaseChain):
    """Chain backed by a locally managed geth process.

    Subclasses provide the process via ``get_geth_process_instance``.
    Entering the chain enters that process's context and waits until it is
    reachable over the enabled transports.
    """
    stack = None
    geth = None

    def initialize_chain(self):
        # context manager shenanigans
        self.stack = ExitStack()
        self.geth = self.get_geth_process_instance()

    def get_web3_config(self):
        """Return the base web3 config, filling provider settings from geth.

        Only applies when no ``provider.settings`` are configured: IPC
        providers get geth's ``ipc_path``, HTTP providers get a localhost URI
        on geth's ``rpc_port``; any other provider class raises ValueError.
        """
        base_config = super(BaseGethChain, self).get_web3_config()
        config = copy.deepcopy(base_config)
        if not config.get('provider.settings'):
            if issubclass(base_config.provider_class, IPCProvider):
                config['provider.settings.ipc_path'] = self.geth.ipc_path
            elif issubclass(base_config.provider_class, HTTPProvider):
                config[
                    'provider.settings.endpoint_uri'] = "http://127.0.0.1:{0}".format(
                        self.geth.rpc_port, )
            else:
                raise ValueError("Unknown provider type")
        return config

    @property
    def geth_kwargs(self):
        return self.config.get('chain.settings', {})

    def get_geth_process_instance(self):
        raise NotImplementedError("Must be implemented by subclasses")

    def __enter__(self, *args, **kwargs):
        self.stack.enter_context(self.geth)
        try:
            if self.geth.is_mining:
                self.geth.wait_for_dag(600)
            if self.geth.ipc_enabled:
                self.geth.wait_for_ipc(60)
            if self.geth.rpc_enabled:
                self.geth.wait_for_rpc(60)
        except BaseException:
            # __exit__ is not called when __enter__ raises (e.g. a wait times
            # out), so exit the geth context here instead of presumably
            # leaving the process running.
            self.stack.close()
            raise

        self._running = True

        return self

    def __exit__(self, *exc_info):
        self.stack.close()
        self._running = False
Пример #22
0
class MockServer:
    """Start/stop a ``_MockServer`` and inspect the mock it records calls on.

    Responses may only be set, and recorded requests only read, while the
    server is not running.
    """
    def __init__(self, server):
        # type: (_MockServer) -> None
        self._server = server
        self._running = False
        # Holds the running-server context and the _running flag; closed by stop().
        self.context = ExitStack()

    @property
    def port(self):
        return self._server.port

    @property
    def host(self):
        return self._server.host

    def set_responses(self, responses):
        # type: (Iterable[Responder]) -> None
        assert not self._running, "responses cannot be set on running server"
        self._server.mock.side_effect = responses

    def start(self):
        # type: () -> None
        assert not self._running, "running server cannot be started"
        self.context.enter_context(server_running(self._server))
        # Entered last, so _running is cleared first when the stack unwinds.
        self.context.enter_context(self._set_running())

    @contextmanager
    def _set_running(self):
        self._running = True
        try:
            yield
        finally:
            self._running = False

    def stop(self):
        # type: () -> None
        assert self._running, "idle server cannot be stopped"
        self.context.close()

    def get_requests(self):
        # type: () -> List[Dict[str, str]]
        """Get environ for each received request.

        Returns a list with one entry per recorded call on the server's
        mock: the first positional argument of that call.
        """
        assert not self._running, "cannot get mock from running server"
        # Legacy: replace call[0][0] with call.args[0]
        # when pip drops support for python3.7
        return [call[0][0] for call in self._server.mock.call_args_list]
Пример #23
0
class DeviceScope(object):
    """Enter a Paddle IPU shard guard and, optionally, a name scope.

    Args:
        index: Passed as ``index`` to ``paddle.static.ipu_shard_guard``.
        stage: Passed as ``stage`` to ``paddle.static.ipu_shard_guard``.
        name_scope: When not None, ``paddle.static.name_scope(name_scope)``
            is entered as well.
    """
    def __init__(self, index, stage, name_scope=None):
        self.index = index
        self.stage = stage
        self.name_scope = name_scope

    def __enter__(self):
        self.stack = ExitStack()
        try:
            self.stack.enter_context(
                paddle.static.ipu_shard_guard(index=self.index,
                                              stage=self.stage))
            if self.name_scope is not None:
                self.stack.enter_context(
                    paddle.static.name_scope(self.name_scope))
        except BaseException:
            # __exit__ is never called when __enter__ raises, so leave the
            # shard guard here if entering the name scope fails.
            self.stack.close()
            raise
        return self

    def __exit__(self, *exp):
        self.stack.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False
Пример #24
0
class Progress:
    """Run the registered processes once for every item of ``case``.

    Processes are registered with the ``add_proc`` decorator, optionally with
    a context-manager factory that is entered right before the process runs.
    All contexts entered for one case item are exited after that item's
    processes finish, or as soon as one of them raises.
    """
    def __init__(self, case):
        self._case       = case
        self._current    = None
        self._processes  = []
        self._contexts   = []
        self._extensions = []
        self._stack      = ExitStack()

    @property
    def case(self):
        return self._case

    @property
    def n_proc(self):
        return len(self._processes)

    # decorator
    def add_proc(self, cm_func=None, *args, **kwargs):
        """Return a decorator that registers ``proc_func`` as the next process.

        When ``cm_func`` is not None, ``cm_func(*args, **kwargs)`` is entered
        as a context manager just before the process runs. ``proc_func``
        receives this Progress and must return ``(job, checker)``: the first
        process of a case item calls ``checker.check(job)``, later ones call
        ``execute(job, checker)`` on the state returned by the previous one.
        The decorator itself returns None.
        """
        self._contexts.append((cm_func, args, kwargs))
        def _add_proc(proc_func):
            # proc_func: (Progress) -> Job, checker
            # This should be defined in the test.
            def wrapper(progress):
                job, checker = proc_func(progress)
                if self._current is None:
                    state = checker.check(job)
                else:
                    state = self._current.execute(job, checker)
                self._current = state
            self._processes.append(wrapper)
        return _add_proc

    # Should the process-related logic be split out into a Process class?
    def advance(self):
        """Run every registered process, in order, for each item of ``case``."""
        for _ in self._case:
            try:
                steps = zip(self._processes, self._contexts)
                for proc, (cm_func, args, kwargs) in steps:
                    if cm_func is not None:
                        self._stack.enter_context(cm_func(*args, **kwargs))
                    proc(self)
            finally:
                # Exit this item's contexts even when a process raised, and
                # start the next item from a clean state.
                self._current = None
                self._stack.close()

    def extend(self, extension):
        self._extensions.append(extension)
Пример #25
0
class ScheduleExecutionContext:
    """Schedule-specific execution context.

    An instance of this class is made available as the first argument to various ScheduleDefinition
    functions. It is passed as the first argument to ``run_config_fn``, ``tags_fn``,
    and ``should_execute``.

    Attributes:
        instance_ref (InstanceRef): The serialized instance configured to run the schedule
        scheduled_execution_time (datetime):
            The time at which the execution was scheduled to happen. May differ slightly
            from both the actual execution time and the time at which the run config is computed.
            Not available in all schedulers - currently only set in deployments using
            DagsterDaemonScheduler.
    """

    __slots__ = [
        "_instance_ref", "_scheduled_execution_time", "_exit_stack",
        "_instance"
    ]

    def __init__(self, instance_ref, scheduled_execution_time):
        # Owns the lazily created DagsterInstance; closed by __exit__.
        self._exit_stack = ExitStack()
        self._instance = None

        self._instance_ref = check.inst_param(instance_ref, "instance_ref",
                                              InstanceRef)
        self._scheduled_execution_time = check.opt_inst_param(
            scheduled_execution_time, "scheduled_execution_time", datetime)

    def __enter__(self):
        return self

    def __exit__(self, _exception_type, _exception_value, _traceback):
        self._exit_stack.close()

    @property
    def instance(self):
        """The DagsterInstance, created from ``instance_ref`` on first access."""
        # Compare against None: a cached instance must never be rebuilt just
        # because the object happens to evaluate as falsy.
        if self._instance is None:
            self._instance = self._exit_stack.enter_context(
                DagsterInstance.from_ref(self._instance_ref))
        return self._instance

    @property
    def scheduled_execution_time(self):
        return self._scheduled_execution_time
Пример #26
0
def download_to_memory(
        url: str,
        show_progress: bool = False) -> Generator[BytesIO, None, None]:
    """Download *url* completely into memory and yield it as a ``BytesIO``.

    The response (and the tqdm wrapper, when enabled) and the in-memory buffer
    are all closed when the generator is resumed past its ``yield`` or closed.

    :param url: URL opened with ``urllib.request.urlopen``.
    :param show_progress: wrap the response with ``tqdm.wrapattr`` so reads
        report progress; the total comes from the Content-Length header if
        present and numeric.
    :yield: a ``BytesIO`` holding the full body, rewound to offset 0.
    """
    with ExitStack() as stack:
        response = stack.enter_context(urllib.request.urlopen(url))
        # ``headers`` exists on HTTP responses and on the addinfourl objects
        # urlopen returns for file:/data: URLs; ``getheader`` is HTTP-only.
        try:
            total = int(response.headers.get("Content-Length"))
        except (TypeError, ValueError):
            # Header missing (None -> TypeError) or not an integer.
            total = None
        if show_progress:
            response = stack.enter_context(
                tqdm.wrapattr(response, "read", total=total))
        memory_stream = stack.enter_context(BytesIO())
        shutil.copyfileobj(response, memory_stream)
        # copyfileobj leaves the position at EOF; rewind so the consumer can
        # read() the downloaded bytes directly.
        memory_stream.seek(0)
        yield memory_stream
Пример #27
0
class Triggers:
    """Pair of panic and threshold triggers built from one TriggerConfig.

    Entering this object enters both triggers; exiting releases them. A fresh
    ExitStack is created on every ``__enter__``, so the object can be entered
    again after it has been exited.
    """

    def __init__(self, triggers_config: TriggerConfig, report: Report) -> None:
        # Each trigger gets the global commands of its kind plus, for every
        # temperature name, the matching per-temp action.
        self.panic_trigger = PanicTrigger(
            global_commands=triggers_config.global_commands.panic,
            temp_commands={
                temp_name: actions.panic
                for temp_name, actions in
                triggers_config.temp_commands.items()
            },
            report=report,
        )
        self.threshold_trigger = ThresholdTrigger(
            global_commands=triggers_config.global_commands.threshold,
            temp_commands={
                temp_name: actions.threshold
                for temp_name, actions in
                triggers_config.temp_commands.items()
            },
            report=report,
        )
        # Holds the entered triggers between __enter__ and __exit__.
        self._stack: Optional[ExitStack] = None

    def __enter__(self):  # reusable
        # Enter both triggers on a local stack. If either one raises (any
        # BaseException, including KeyboardInterrupt), leaving the `with`
        # unwinds whatever was already entered. On success, pop_all() moves
        # the registered exits to self._stack for __exit__ to release.
        with ExitStack() as stack:
            stack.enter_context(self.panic_trigger)
            stack.enter_context(self.threshold_trigger)
            self._stack = stack.pop_all()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        assert self._stack is not None
        self._stack.close()
        return None

    @property
    def is_alerting(self) -> bool:
        """True if either the panic or the threshold trigger is alerting."""
        return self.panic_trigger.is_alerting or self.threshold_trigger.is_alerting

    def check(self, temps: Mapping[TempName, Optional[TempStatus]]) -> None:
        """Pass *temps* to both triggers' ``check``.

        NOTE(review): a None status presumably means the reading is
        unavailable; confirm against the trigger implementations.
        """
        self.panic_trigger.check(temps)
        self.threshold_trigger.check(temps)
Пример #28
0
class DeviceScope(object):
    """Context manager that enters a set of builder scopes for one device placement.

    Which builder scopes are entered depends on ``execution_mode``:

    * virtualGraph: in every mode except ``ExecutionMode.PHASED``;
    * pipelineStage: only in ``ExecutionMode.PIPELINE``;
    * executionPhase: only in ``ExecutionMode.PHASED``;
    * nameScope and ``additional_scopes``: always.

    In each case the scope is entered only when its value is not None.
    ``builder`` must provide ``virtualGraph``, ``pipelineStage``,
    ``executionPhase`` and ``nameScope`` methods returning context managers.
    """

    def __init__(self,
                 builder,
                 execution_mode=ExecutionMode.DEFAULT,
                 virtualGraph=None,
                 pipelineStage=None,
                 executionPhase=None,
                 nameScope=None,
                 additional_scopes=None):
        self.builder = builder
        self.execution_mode = execution_mode
        self.virtualGraph = virtualGraph
        self.pipelineStage = pipelineStage
        self.executionPhase = executionPhase
        self.nameScope = nameScope
        self.additional_scopes = additional_scopes or []

    def __enter__(self):
        self.stack = ExitStack()
        # Build on a local stack: __exit__ is not called when __enter__ raises,
        # so if any scope fails to enter, the `with` below exits the ones that
        # were already entered. On success pop_all() hands them to self.stack.
        with ExitStack() as stack:
            # ExecutionPhase will automatically set the virtualGraph attributes based on execution phase
            if self.execution_mode != ExecutionMode.PHASED \
                    and self.virtualGraph is not None:
                stack.enter_context(
                    self.builder.virtualGraph(self.virtualGraph))

            if self.execution_mode == ExecutionMode.PIPELINE \
                    and self.pipelineStage is not None:
                stack.enter_context(
                    self.builder.pipelineStage(self.pipelineStage))

            if self.execution_mode == ExecutionMode.PHASED \
                    and self.executionPhase is not None:
                stack.enter_context(
                    self.builder.executionPhase(self.executionPhase))

            if self.nameScope is not None:
                stack.enter_context(self.builder.nameScope(self.nameScope))
            for scope in self.additional_scopes:
                stack.enter_context(scope)
            self.stack = stack.pop_all()
        return self

    def __exit__(self, *exp):
        self.stack.close()
        # Never suppress exceptions raised inside the scope.
        return False
Пример #29
0
class ResolveSource:
    """Context manager that handles all the different input sources and provides a path to a local file/directory."""
    def __init__(self, src: Union[str, Path]) -> None:
        """Performs the needed combination of downloading/extracting source files.

        :param src: HTTP(S) URL or local path to the OVPN profiles to process.
        :raises ValueError: if *src* is a URL with a scheme other than HTTP(S),
            or the resolved path is neither an existing file nor a directory.
        """
        self.exit_stack = ExitStack()
        try:
            src_as_str = str(src)
            # Anything that starts with "<scheme>://" is treated as a URL.
            if re.search(r"^[a-zA-Z0-9]+://", src_as_str):
                if src_as_str.startswith(("http://", "https://")):
                    logger.debug("Determined source was remote file, downloading")
                    self.path = self.exit_stack.enter_context(
                        HandleDownload(src_as_str))
                else:
                    raise ValueError("Only HTTP(S) supported as remote protocol")
            else:
                self.path = Path(src)

            if self.path.is_file():
                try:
                    self.path = self.exit_stack.enter_context(HandleZip(self.path))
                except zipfile.BadZipFile:
                    # Make an assumption that our thing is a non-zip file
                    pass
            elif not self.path.is_dir():
                raise ValueError("Path does not exist")
        except BaseException:
            # __exit__ never runs when the constructor raises, so release any
            # download/extraction already registered on the stack here.
            self.exit_stack.close()
            raise

    def __enter__(self) -> Path:
        """Context manager __enter__ function.

        :return: Resolved path to either a local file or directory containing OVPN profile(s).
        """
        return self.path

    def __exit__(self, exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[TracebackType]) -> Literal[False]:
        """Cleans up all the temporary files created by calling child context managers."""
        self.exit_stack.close()
        return False
Пример #30
0
 def updateParamFromPen(param, pen):
     """
     Applies settings from a pen to either a Parameter or dict. The Parameter or dict must already
     be populated with the relevant keys that can be found in `PenSelectorDialog.mkParam`.

     For each option ``opt``, a boolean current value is refreshed from ``pen.is<Opt>()``
     (e.g. ``cosmetic`` -> ``pen.isCosmetic()``); any other value from ``pen.<opt>()``.
     """
     with ExitStack() as stack:
         if isinstance(param, Parameter):
             names = param.names
             # Block changes until all are finalized. The `with` guarantees the
             # blocker is released even if reading an option from the pen raises.
             stack.enter_context(param.treeChangeBlocker())
         else:
             names = param
         for opt in names:
             # Booleans have different naming convention
             if isinstance(param[opt], bool):
                 attrName = f'is{opt.title()}'
             else:
                 attrName = opt
             param[opt] = getattr(pen, attrName)()
Пример #31
0
class hosts(ContextDecorator):
    """Context manager / decorator that applies the socket-resolution patches.

    Each selected patch helper (``patch_gethostbyname``,
    ``patch_gethostbyname_ex``, ``patch_getaddrinfo``) is called with *hosts*
    and entered on ``__enter__``; all of them are undone on ``__exit__``.

    :param hosts: value passed through to every patch helper.
        NOTE(review): presumably a hostname -> address mapping; confirm.
    :param only: names of the patches to apply; defaults to all three.
    """
    def __init__(self, hosts, only=None):
        self.hosts = hosts
        if only is None:
            self.patchs = ["gethostbyname", "gethostbyname_ex", "getaddrinfo"]
        else:
            self.patchs = only
        self.exit_stack = ExitStack()

    def __enter__(self):
        # Enter the patches on a local stack first: if one of them fails,
        # leaving the `with` undoes those already applied (__exit__ is not
        # called when __enter__ raises). On success the entered patches are
        # handed, as one unit, to the long-lived stack that __exit__ closes.
        with ExitStack() as stack:
            if "gethostbyname" in self.patchs:
                stack.enter_context(patch_gethostbyname(self.hosts))
            if "gethostbyname_ex" in self.patchs:
                stack.enter_context(patch_gethostbyname_ex(self.hosts))
            if "getaddrinfo" in self.patchs:
                stack.enter_context(patch_getaddrinfo(self.hosts))
            self.exit_stack.enter_context(stack.pop_all())
        return self

    def __exit__(self, *exc):
        self.exit_stack.close()
        # Never suppress exceptions from the wrapped code.
        return False
Пример #32
0
class ScienceBeamParserBaseSession:
    """Session state shared by ScienceBeamParser sessions.

    The session owns an ExitStack. Everything registered on it (currently the
    lazily created temporary directory) is released by ``close()``, which is
    also what leaving a ``with`` block calls.
    """

    def __init__(
        self,
        parser: 'ScienceBeamParser',
        temp_dir: Optional[str] = None,
        fulltext_processor_config: Optional[FullTextProcessorConfig] = None,
        document_request_parameters: Optional[DocumentRequestParameters] = None
    ):
        self.parser = parser
        self.exit_stack = ExitStack()
        # A caller-supplied directory is used as-is and never cleaned up here.
        self._temp_dir: Optional[str] = temp_dir
        # Fall back to the parser's config / default request parameters.
        self.fulltext_processor_config = (
            fulltext_processor_config
            if fulltext_processor_config is not None
            else parser.fulltext_processor_config
        )
        self.document_request_parameters = (
            document_request_parameters
            if document_request_parameters is not None
            else DocumentRequestParameters()
        )

    def __enter__(self) -> 'ScienceBeamParserBaseSession':
        return self

    def close(self):
        """Release every resource registered on the session's exit stack."""
        self.exit_stack.close()

    def __exit__(self, exc, value, tb):
        self.close()

    @property
    def temp_dir(self) -> str:
        """Working directory, created on first use as a TemporaryDirectory.

        The created directory is registered on the exit stack, so it is removed
        by ``close()``.
        """
        if self._temp_dir:
            return self._temp_dir
        self._temp_dir = self.exit_stack.enter_context(
            TemporaryDirectory(suffix='-sb-parser'))
        return self._temp_dir

    @property
    def temp_path(self) -> Path:
        """``temp_dir`` as a ``pathlib.Path``."""
        return Path(self.temp_dir)
Пример #33
0
class PulpReader(object):
    """Read-only access to a pulp database directory.

    Query entry points::

        db.idx       # db.idx(stuff) == KeyQuery or vQuery
        db.stream    # StreamQuery; db.stream(stuff) == StreamQuery or vQuery
        db.vQuery

    All three main query objects (KeyQuery, StreamQuery, vQuery) support the
    same API, so their results can be merged.

    The master table and key tables are only open while the reader is entered
    as a context manager.
    """
    def __init__(self, db_name, msg_dumper=None, msg_loader=None, idx_dumpers=None, idx_loaders=None):
        """
        :param db_name: database directory; key tables live in its ``keys``
            sub-directory.
        :param msg_dumper: dumper passed to the MasterTable.
        :param msg_loader: loader passed to the MasterTable.
        :param idx_dumpers: optional mapping of key name -> dumper for that KeyTable.
        :param idx_loaders: optional mapping of key name -> loader for that KeyTable.
        """
        self.dir_path = os.path.abspath(db_name)
        self.keys_path = os.path.join(self.dir_path, 'keys')
        # Only a warning here; __enter__ will fail when it opens/lists the
        # missing paths.
        if not all(os.path.exists(p) for p in [self.dir_path, self.keys_path]):
            print("Missing directory: one of {}".format([self.dir_path,
                                                         self.keys_path]))
        self.table_stack = None
        self.master_table = None
        self.key_tables = {}
        self.msg_dumper = msg_dumper
        self.msg_loader = msg_loader
        self.idx_dumpers = idx_dumpers
        self.idx_loaders = idx_loaders

    @property
    def stream(self):
        """StreamQuery over the master table."""
        return StreamQuery(self.master_table)

    @property
    def idx(self):
        """FieldQueryDispatch over the master table and the key tables."""
        return FieldQueryDispatch(self.master_table, self.key_tables)

    def __enter__(self):
        # Open tables on a local stack: __exit__ is not called when __enter__
        # raises, so if listing the keys directory or opening any table fails,
        # the `with` closes the tables opened so far. On success pop_all()
        # transfers them to self.table_stack for __exit__.
        with ExitStack() as stack:
            # _add_table registers each table on self.table_stack.
            self.table_stack = stack
            table = MasterTable(self.dir_path, 'r', dumper=self.msg_dumper, loader=self.msg_loader)
            self._add_table(None, table)

            # TODO: more than just the .meta file is needed for a key table;
            # only .meta is checked here.
            keys = [os.path.splitext(f)[0] for f in os.listdir(self.keys_path) if f.endswith('.meta')]
            for key in keys:
                dumper = self.idx_dumpers.get(key) if self.idx_dumpers is not None else None
                loader = self.idx_loaders.get(key) if self.idx_loaders is not None else None
                table = KeyTable(self.keys_path, key, 'r', dumper=dumper, loader=loader)
                self._add_table(key, table)
            self.table_stack = stack.pop_all()
        return self

    def _add_table(self, key, table):
        """Enter *table*; key None means the master table, else a key table."""
        if key is None:
            self.master_table = table
        else:
            self.key_tables[key] = table
        self.table_stack.enter_context(table)

    def __exit__(self, exc_type, exc_value, traceback):
        if self.table_stack is not None:
            self.table_stack.close()

    def __len__(self):
        """Dispatches to stream"""
        return len(self.stream)

    def __iter__(self):
        """Dispatches to stream"""
        return iter(self.stream)

    def __getitem__(self, index):
        """Dispatches int and slice indexes to stream; other types are unsupported."""
        if isinstance(index, (int, slice)):
            return self.stream[index]
        raise NotImplementedError("index")
Пример #34
0
class FilesystemMockingTestCase(ResourceUsingTestCase):
    """Test case whose patch* helpers redirect filesystem/util calls.

    Every patch applied by the helpers is registered on ``self.patched_funcs``
    and undone in tearDown.
    """
    def setUp(self):
        super(FilesystemMockingTestCase, self).setUp()
        self.patched_funcs = ExitStack()

    def tearDown(self):
        self.patched_funcs.close()
        ResourceUsingTestCase.tearDown(self)

    def replicateTestRoot(self, example_root, target_root):
        """Copy the ``roots/<example_root>`` resource tree into *target_root*.

        Directories are recreated with util.ensure_dir and files are copied
        with shutil.copy.
        """
        real_root = os.path.join(self.resourceLocation(), 'roots', example_root)
        for (dir_path, _dirnames, filenames) in os.walk(real_root):
            make_dir = rebase_path(dir_path[len(real_root):], target_root)
            util.ensure_dir(make_dir)
            for f in filenames:
                # Join each file onto its own directory. Reassigning the
                # directory variables here (as before) nested every later
                # file under the previous file's path.
                shutil.copy(util.abs_join(dir_path, f),
                            util.abs_join(make_dir, f))

    def patchUtils(self, new_root):
        """Retarget util's file helpers into *new_root* and stub out
        util.subp, util.chownbyid and util.chownbyname."""
        # (function name, count passed to retarget_many_wrapper).
        # NOTE(review): the count presumably selects which path arguments get
        # retargeted (-1 = all?); confirm against retarget_many_wrapper.
        patch_funcs = {
            util: [('write_file', 1),
                   ('append_file', 1),
                   ('load_file', 1),
                   ('ensure_dir', 1),
                   ('chmod', 1),
                   ('delete_dir_contents', 1),
                   ('del_file', 1),
                   ('sym_link', -1),
                   ('copy', -1)],
        }
        for (mod, funcs) in patch_funcs.items():
            for (f, am) in funcs:
                func = getattr(mod, f)
                trap_func = retarget_many_wrapper(new_root, am, func)
                self.patched_funcs.enter_context(
                    mock.patch.object(mod, f, trap_func))

        # Handle subprocess calls: never run a real command; report empty
        # output instead.
        def nsubp(*_args, **_kwargs):
            return ('', '')

        self.patched_funcs.enter_context(
            mock.patch.object(util, 'subp', nsubp))

        def null_func(*_args, **_kwargs):
            return None

        for f in ['chownbyid', 'chownbyname']:
            self.patched_funcs.enter_context(
                mock.patch.object(util, f, null_func))

    def patchOS(self, new_root):
        """Retarget selected os / os.path functions into *new_root*."""
        patch_funcs = {
            os.path: [('isfile', 1), ('exists', 1),
                      ('islink', 1), ('isdir', 1)],
            os: [('listdir', 1), ('mkdir', 1),
                 ('lstat', 1), ('symlink', 2)],
        }
        for (mod, funcs) in patch_funcs.items():
            for f, nargs in funcs:
                func = getattr(mod, f)
                trap_func = retarget_many_wrapper(new_root, nargs, func)
                self.patched_funcs.enter_context(
                    mock.patch.object(mod, f, trap_func))

    def patchOpen(self, new_root):
        """Retarget the builtin ``open`` into *new_root* (py2 and py3 names)."""
        trap_func = retarget_many_wrapper(new_root, 1, open)
        name = 'builtins.open' if PY3 else '__builtin__.open'
        self.patched_funcs.enter_context(mock.patch(name, trap_func))

    def patchStdoutAndStderr(self, stdout=None, stderr=None):
        """Replace sys.stdout and/or sys.stderr with the given objects (None = keep)."""
        if stdout is not None:
            self.patched_funcs.enter_context(
                mock.patch.object(sys, 'stdout', stdout))
        if stderr is not None:
            self.patched_funcs.enter_context(
                mock.patch.object(sys, 'stderr', stderr))
Пример #35
0
    class TestCase(unittest.TestCase):
        """unittest.TestCase with cleanup and membership/identity asserts.

        Cleanups registered through ``addCleanup`` are kept on an ExitStack and
        run in LIFO order during tearDown.
        """
        def setUp(self):
            super(TestCase, self).setUp()
            self.__all_cleanups = ExitStack()

        def tearDown(self):
            self.__all_cleanups.close()
            unittest.TestCase.tearDown(self)

        def addCleanup(self, function, *args, **kws):
            """Register ``function(*args, **kws)`` to run during tearDown."""
            self.__all_cleanups.callback(function, *args, **kws)

        def assertIs(self, expr1, expr2, msg=None):
            if expr1 is not expr2:
                standardMsg = '%r is not %r' % (expr1, expr2)
                self.fail(self._formatMessage(msg, standardMsg))

        def assertIn(self, member, container, msg=None):
            if member not in container:
                standardMsg = '%r not found in %r' % (member, container)
                self.fail(self._formatMessage(msg, standardMsg))

        def assertNotIn(self, member, container, msg=None):
            if member in container:
                standardMsg = '%r unexpectedly found in %r'
                standardMsg = standardMsg % (member, container)
                self.fail(self._formatMessage(msg, standardMsg))

        def assertIsNone(self, value, msg=None):
            if value is not None:
                # Wrap in a 1-tuple: a bare tuple value would otherwise be
                # unpacked as several format arguments and raise TypeError.
                standardMsg = '%r is not None' % (value,)
                self.fail(self._formatMessage(msg, standardMsg))

        def assertIsInstance(self, obj, cls, msg=None):
            """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
            default message."""
            if not isinstance(obj, cls):
                standardMsg = '%s is not an instance of %r' % (repr(obj), cls)
                self.fail(self._formatMessage(msg, standardMsg))

        def assertDictContainsSubset(self, expected, actual, msg=None):
            """Fail unless every key of *expected* is in *actual* with an equal value."""
            missing = []
            mismatched = []
            for k, v in expected.items():
                if k not in actual:
                    missing.append(k)
                elif actual[k] != v:
                    mismatched.append('%r, expected: %r, actual: %r'
                                      % (k, v, actual[k]))

            if not missing and not mismatched:
                return

            standardMsg = ''
            if missing:
                # Format each key first: str.join rejects non-string keys.
                standardMsg = 'Missing: %r' % ','.join('%s' % (m,) for m in missing)
            if mismatched:
                if standardMsg:
                    standardMsg += '; '
                standardMsg += 'Mismatched values: %s' % ','.join(mismatched)

            self.fail(self._formatMessage(msg, standardMsg))
Пример #36
0
class NosePlugin(Plugin):
    """Test-runner plugin for ubuntu-image.

    * ``-P/--pattern`` regexes filter test classes/methods and .rst doctests;
    * .rst files are loaded as doctests;
    * ``snap prepare-image`` is mocked for the whole run unless disabled via
      the UBUNTU_IMAGE_MOCK_SNAP environment variable.
    """
    configSection = 'ubuntu-image'
    snap_mocker = None

    def __init__(self):
        super().__init__()
        self.patterns = []
        self.addArgument(self.patterns, 'P', 'pattern',
                         'Add a test matching pattern')

    def _matches_any(self, text):
        """Return True if any configured pattern matches (re.search) *text*."""
        return any(re.search(pattern, text) for pattern in self.patterns)

    def getTestCaseNames(self, event):
        if not self.patterns:
            # No filter patterns, so everything should be tested.
            return
        full_class_name = '{}.{}'.format(
            event.testCase.__module__, event.testCase.__name__)
        # Does the pattern match the fully qualified class name?
        if self._matches_any(full_class_name):
            # Don't suppress this test class.
            return
        # Otherwise keep only the methods whose fully qualified name matches.
        for name in filter(event.isTestMethod, dir(event.testCase)):
            if not self._matches_any('{}.{}'.format(full_class_name, name)):
                event.excludedNames.append(name)

    def handleFile(self, event):
        # Path relative to the top of the source tree.
        path = event.path[len(TOPDIR)+1:]
        if self.patterns and not self._matches_any(path):
            # Skip this doctest.
            return
        if os.path.splitext(path)[1] != '.rst':
            return
        test = doctest.DocFileTest(
            path, package='ubuntu_image',
            optionflags=FLAGS,
            setUp=setup,
            tearDown=teardown)
        # Suppress the extra "Doctest: ..." line.
        test.shortDescription = lambda: None
        event.extraTests.append(test)

    def startTestRun(self, event):
        # Create a mock for the `sudo snap prepare-image` command.  This is an
        # expensive command which hits the actual snap store.  We want this to
        # run at least once so we know our tests are valid.  We can cache the
        # results in a test-suite-wide temporary directory and simulate future
        # calls by just recursively copying the contents to the specified
        # directories.
        #
        # It's a bit more complicated than that though, because it's possible
        # that the channel and model.assertion will be different, so we need
        # to make the cache dependent on those values.
        #
        # Finally, to enable full end-to-end tests, check an environment
        # variable to see if the mocking should even be done.  This way, we
        # can make our Travis-CI job do at least one real end-to-end test.
        self.resources = ExitStack()
        # How should we mock `snap prepare-image`?  If set to 'always' (case
        # insensitive), then use the sample data in the .zip file.  Any other
        # truthy value says to use a second-and-onward mock.
        should_we_mock = os.environ.get('UBUNTU_IMAGE_MOCK_SNAP', 'yes')
        if should_we_mock.lower() == 'always':
            mock_class = AlwaysMock
        elif as_bool(should_we_mock):
            mock_class = SecondAndOnwardMock
        else:
            mock_class = None
        if mock_class is not None:
            tmpdir = self.resources.enter_context(TemporaryDirectory())
            # Record the actual snap mocker on the class so that other tests
            # can temporarily disable it.  Some tests need to run the actual
            # snap() helper function.
            self.__class__.snap_mocker = self.resources.enter_context(
                mock_class(tmpdir))

    def stopTestRun(self, event):
        # startTestRun creates self.resources; if it never ran there is
        # nothing to release.
        resources = getattr(self, 'resources', None)
        if resources is not None:
            resources.close()
Пример #37
0
class RateLimiterTestCase(unittest.TestCase):
    """Tests for RateLimiter with its ``_clock`` and ``_sleep`` patched out."""

    def setUp(self):
        self._stack = ExitStack()
        # tearDown is skipped when setUp raises, so also register the close as
        # a cleanup: a patch entered before a later failure is still undone.
        # Closing an already-closed ExitStack is a no-op.
        self.addCleanup(self._stack.close)
        self.mock_clock = self._stack.enter_context(
            patch.object(RateLimiter, '_clock'))
        self.mock_sleep = self._stack.enter_context(
            patch.object(RateLimiter, '_sleep'))
        self.mock_func = MagicMock()

    def tearDown(self):
        self._stack.close()

    def test_min_delay(self):
        min_delay = 3.5

        self.mock_clock.side_effect = [1]
        rl = RateLimiter(self.mock_func, min_delay_seconds=min_delay)

        # First call -- no delay
        clock_first = 10
        self.mock_clock.side_effect = [clock_first, clock_first]  # no delay here
        rl(sentinel.arg, kwa=sentinel.kwa)
        self.mock_sleep.assert_not_called()
        self.mock_func.assert_called_once_with(sentinel.arg, kwa=sentinel.kwa)

        # Second call after min_delay/3 seconds -- should be delayed
        clock_second = clock_first + (min_delay / 3)
        self.mock_clock.side_effect = [clock_second, clock_first + min_delay]
        rl(sentinel.arg, kwa=sentinel.kwa)
        self.mock_sleep.assert_called_with(min_delay - (clock_second - clock_first))
        self.mock_sleep.reset_mock()

        # Third call after min_delay*2 seconds -- no delay again
        clock_third = clock_first + min_delay + min_delay * 2
        self.mock_clock.side_effect = [clock_third, clock_third]
        rl(sentinel.arg, kwa=sentinel.kwa)
        self.mock_sleep.assert_not_called()

    def test_max_retries(self):
        self.mock_clock.return_value = 1
        rl = RateLimiter(self.mock_func, max_retries=3,
                         return_value_on_exception=sentinel.return_value)

        # Non-geopy errors must not be swallowed
        self.mock_func.side_effect = ValueError
        with self.assertRaises(ValueError):
            rl(sentinel.arg)
        self.assertEqual(1, self.mock_func.call_count)
        self.mock_func.reset_mock()

        # geopy errors must be swallowed and retried
        self.mock_func.side_effect = GeocoderServiceError
        self.assertEqual(sentinel.return_value, rl(sentinel.arg))
        self.assertEqual(4, self.mock_func.call_count)
        self.mock_func.reset_mock()

        # Successful value must be returned
        self.mock_func.side_effect = [
            GeocoderServiceError, GeocoderServiceError, sentinel.good
        ]
        self.assertEqual(sentinel.good, rl(sentinel.arg))
        self.assertEqual(3, self.mock_func.call_count)
        self.mock_func.reset_mock()

        # When swallowing is disabled, the exception must be raised
        rl.swallow_exceptions = False
        self.mock_func.side_effect = GeocoderQuotaExceeded
        with self.assertRaises(GeocoderQuotaExceeded):
            rl(sentinel.arg)
        self.assertEqual(4, self.mock_func.call_count)
        self.mock_func.reset_mock()

    def test_error_wait_seconds(self):
        error_wait = 3.3

        self.mock_clock.return_value = 1
        rl = RateLimiter(self.mock_func, max_retries=3,
                         error_wait_seconds=error_wait,
                         return_value_on_exception=sentinel.return_value)

        # 1 initial call + 3 retries, with a sleep of error_wait before each retry.
        self.mock_func.side_effect = GeocoderServiceError
        self.assertEqual(sentinel.return_value, rl(sentinel.arg))
        self.assertEqual(4, self.mock_func.call_count)
        self.assertEqual(3, self.mock_sleep.call_count)
        self.mock_sleep.assert_called_with(error_wait)
        self.mock_func.reset_mock()
Пример #38
0
class TestAzureBounce(TestCase):

    def mock_out_azure_moving_parts(self):
        """Patch the DataSourceAzure hooks listed below onto ``self.patches``.

        Entries whose replacement is ``mock.DEFAULT`` get an auto-created
        MagicMock, exactly as ``patch.object`` does when ``new`` is omitted.
        Patches are entered in table order, so closing ``self.patches``
        undoes them in reverse.
        """
        replacements = (
            (DataSourceAzure, 'invoke_agent', mock.DEFAULT),
            (DataSourceAzure, 'wait_for_files', mock.DEFAULT),
            (DataSourceAzure, 'list_possible_azure_ds_devs',
             mock.MagicMock(return_value=[])),
            (DataSourceAzure, 'find_fabric_formatted_ephemeral_disk',
             mock.MagicMock(return_value=None)),
            (DataSourceAzure, 'find_fabric_formatted_ephemeral_part',
             mock.MagicMock(return_value=None)),
            (DataSourceAzure, 'get_metadata_from_fabric',
             mock.MagicMock(return_value={})),
            (DataSourceAzure.util, 'read_dmi_data',
             mock.MagicMock(return_value='test-instance-id')),
        )
        for owner, attribute, replacement in replacements:
            self.patches.enter_context(
                mock.patch.object(owner, attribute, replacement))

    def setUp(self):
        """Create a scratch cloud dir and patch out Azure's side effects.

        Every patch, including the override of the module-wide
        ``BUILTIN_DS_CONFIG['data_dir']``, is held on ``self.patches``.
        Closing that stack is registered as a cleanup right away, so the
        patches are undone even if a later step of setUp raises. In that
        case tearDown is never called.
        """
        super(TestAzureBounce, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp)
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
        self.patches = ExitStack()
        # Closing an ExitStack twice is harmless, so this coexists with the
        # explicit close in tearDown.
        self.addCleanup(self.patches.close)
        # Previously this was a plain assignment that was never undone and
        # leaked into later tests; patch.dict restores the prior value.
        self.patches.enter_context(
            mock.patch.dict(DataSourceAzure.BUILTIN_DS_CONFIG,
                            {'data_dir': self.waagent_d}))
        self.mock_out_azure_moving_parts()
        self.get_hostname = self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'get_hostname'))
        self.set_hostname = self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'set_hostname'))
        self.subp = self.patches.enter_context(
            mock.patch('cloudinit.sources.DataSourceAzure.util.subp'))

    def tearDown(self):
        """Undo every patch from setUp, then run the base-class teardown.

        setUp chains to the base class, so tearDown must too. The base
        teardown sits in ``finally`` so it still runs if unwinding a patch
        raises.
        """
        try:
            self.patches.close()
        finally:
            super(TestAzureBounce, self).tearDown()

    def _get_ds(self, ovfcontent=None):
        """Build a DataSourceAzureNet over ``self.paths``.

        When *ovfcontent* is given, it is first written as ``ovf-env.xml``
        into the ``azure`` seed directory.
        """
        if ovfcontent is not None:
            seed_dir = os.path.join(self.paths.seed_dir, "azure")
            populate_dir(seed_dir, {'ovf-env.xml': ovfcontent})
        datasource_cls = DataSourceAzure.DataSourceAzureNet
        return datasource_cls({}, distro=None, paths=self.paths)

    def get_ovf_env_with_dscfg(self, hostname, cfg):
        """Return an OVF env with *hostname* and *cfg* as base64 YAML dscfg."""
        encoded_cfg = b64e(yaml.dump(cfg))
        dscfg = {'text': encoded_cfg, 'encoding': 'base64'}
        return construct_valid_ovf_env(
            data={'HostName': hostname, 'dscfg': dscfg})

    def test_disabled_bounce_does_not_change_hostname(self):
        """With bounce policy 'off', set_hostname is never called."""
        ovf = self.get_ovf_env_with_dscfg(
            'test-host', {'hostname_bounce': {'policy': 'off'}})
        datasource = self._get_ds(ovf)
        datasource.get_data()
        self.assertEqual(0, self.set_hostname.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_disabled_bounce_does_not_perform_bounce(self, bounce_mock):
        """With bounce policy 'off', perform_hostname_bounce never runs."""
        ovf = self.get_ovf_env_with_dscfg(
            'test-host', {'hostname_bounce': {'policy': 'off'}})
        datasource = self._get_ds(ovf)
        datasource.get_data()
        self.assertEqual(0, bounce_mock.call_count)

    def test_same_hostname_does_not_change_hostname(self):
        """If the OVF hostname equals the current one, nothing is set."""
        current = 'unchanged-host-name'
        self.get_hostname.return_value = current
        ovf = self.get_ovf_env_with_dscfg(
            current, {'hostname_bounce': {'policy': 'yes'}})
        self._get_ds(ovf).get_data()
        self.assertEqual(0, self.set_hostname.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_unchanged_hostname_does_not_perform_bounce(self, bounce_mock):
        """Policy 'yes' does not bounce when the hostname is unchanged."""
        current = 'unchanged-host-name'
        self.get_hostname.return_value = current
        ovf = self.get_ovf_env_with_dscfg(
            current, {'hostname_bounce': {'policy': 'yes'}})
        self._get_ds(ovf).get_data()
        self.assertEqual(0, bounce_mock.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_force_performs_bounce_regardless(self, bounce_mock):
        """Policy 'force' bounces once even when the hostname is unchanged."""
        current = 'unchanged-host-name'
        self.get_hostname.return_value = current
        ovf = self.get_ovf_env_with_dscfg(
            current, {'hostname_bounce': {'policy': 'force'}})
        self._get_ds(ovf).get_data()
        self.assertEqual(1, bounce_mock.call_count)

    def test_different_hostnames_sets_hostname(self):
        """The first set_hostname call uses the hostname from the OVF env."""
        wanted = 'azure-expected-host-name'
        self.get_hostname.return_value = 'default-host-name'
        ovf = self.get_ovf_env_with_dscfg(wanted, {})
        self._get_ds(ovf).get_data()
        first_call = self.set_hostname.call_args_list[0]
        self.assertEqual(wanted, first_call[0][0])

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_different_hostnames_performs_bounce(self, bounce_mock):
        """A differing OVF hostname triggers exactly one bounce."""
        self.get_hostname.return_value = 'default-host-name'
        ovf = self.get_ovf_env_with_dscfg('azure-expected-host-name', {})
        self._get_ds(ovf).get_data()
        self.assertEqual(1, bounce_mock.call_count)

    def test_different_hostnames_sets_hostname_back(self):
        """The last set_hostname call restores the original hostname."""
        original = 'default-host-name'
        self.get_hostname.return_value = original
        ovf = self.get_ovf_env_with_dscfg('some-host-name', {})
        self._get_ds(ovf).get_data()
        last_call = self.set_hostname.call_args_list[-1]
        self.assertEqual(original, last_call[0][0])

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_failure_in_bounce_still_resets_host_name(self, bounce_mock):
        """Even if the bounce raises, the original hostname is set last."""
        original = 'default-host-name'
        self.get_hostname.return_value = original
        bounce_mock.side_effect = Exception
        ovf = self.get_ovf_env_with_dscfg('some-host-name', {})
        self._get_ds(ovf).get_data()
        last_call = self.set_hostname.call_args_list[-1]
        self.assertEqual(original, last_call[0][0])

    def test_environment_correct_for_bounce_command(self):
        """The bounce subprocess env has interface, hostname, old_hostname."""
        self.get_hostname.return_value = 'my-old-host'
        bounce_cfg = {'interface': 'int0', 'policy': 'force'}
        ovf = self.get_ovf_env_with_dscfg(
            'my-new-host', {'hostname_bounce': bounce_cfg})
        self._get_ds(ovf).get_data()
        self.assertEqual(1, self.subp.call_count)
        env = self.subp.call_args[1]['env']
        for key, expected in (('interface', 'int0'),
                              ('hostname', 'my-new-host'),
                              ('old_hostname', 'my-old-host')):
            self.assertEqual(expected, env[key])

    def test_default_bounce_command_used_by_default(self):
        """Without a configured command, the built-in default is run.

        The default lives in the module-level BUILTIN_DS_CONFIG, which every
        test shares. It is overridden with ``mock.patch.dict`` on
        ``self.patches``, so the previous command is restored when the patch
        stack closes. Assigning it directly leaked into later tests.
        """
        cmd = 'default-bounce-command'
        bounce_defaults = DataSourceAzure.BUILTIN_DS_CONFIG['hostname_bounce']
        self.patches.enter_context(
            mock.patch.dict(bounce_defaults, {'command': cmd}))
        cfg = {'hostname_bounce': {'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
        self._get_ds(data).get_data()
        self.assertEqual(1, self.subp.call_count)
        bounce_args = self.subp.call_args[1]['args']
        self.assertEqual(cmd, bounce_args)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_set_hostname_option_can_disable_bounce(self, bounce_mock):
        """set_hostname=False suppresses the bounce even under 'force'."""
        dscfg = {'set_hostname': False,
                 'hostname_bounce': {'policy': 'force'}}
        ovf = self.get_ovf_env_with_dscfg('some-hostname', dscfg)
        self._get_ds(ovf).get_data()

        self.assertEqual(0, bounce_mock.call_count)

    def test_set_hostname_option_can_disable_hostname_set(self):
        """set_hostname=False means set_hostname is never called."""
        dscfg = {'set_hostname': False,
                 'hostname_bounce': {'policy': 'force'}}
        ovf = self.get_ovf_env_with_dscfg('some-hostname', dscfg)
        self._get_ds(ovf).get_data()

        self.assertEqual(0, self.set_hostname.call_count)