def __init__(self, clientLabel, kw=None):
        """Prepare (but do not yet apply) a temporary collectors repository.

        clientLabel: this *relatively short* string will be used to construct
          the temporary database name. It shouldn't contain any characters that
          would make it inappropriate for a database name (no spaces, etc.)
        kw: name of keyword argument to add to the decorated function(s). Its
          value will be a reference to this instance of ManagedTempRepository.
          Ignored when this instance is used as context manager. Defaults to
          kw=None to avoid having it added to the keyword args.
        """
        # Bookkeeping flags: nothing has been patched or created yet
        self._configPatchApplied = False
        self._attemptedToCreateDatabase = False
        self._kw = kw

        # NOTE(review): name suggests an engine not bound to a specific
        # database (used for database-level statements) -- confirm
        self._unaffiliatedEngine = collectorsdb.getUnaffiliatedEngine()

        baseDbName = CollectorsDbConfig().get(self.REPO_SECTION_NAME,
                                              self.REPO_DATABASE_ATTR_NAME)
        # <configured db>_<clientLabel>_<uuid1 hex> keeps the name unique
        self.tempDatabaseName = "{original}_{label}_{uid}".format(
            uid=uuid.uuid1().hex,
            label=clientLabel,
            original=baseDbName)

        # Config patch that, once started, swaps the configured repository
        # database name for the temporary one
        overrides = ((self.REPO_SECTION_NAME, self.REPO_DATABASE_ATTR_NAME,
                      self.tempDatabaseName), )
        self._configPatch = ConfigAttributePatch(self.REPO_CONFIG_NAME,
                                                 self.REPO_BASE_CONFIG_DIR,
                                                 values=overrides)
# Code example #2
# 0
  def __init__(self, clientLabel, kw=None):
    """Prepare (but do not yet apply) a temporary htmengine repository.

    clientLabel: this *relatively short* string will be used to construct the
      temporary database name. It shouldn't contain any characters that would
      make it inappropriate for a database name (no spaces, etc.)
    kw: name of keyword argument to add to the decorated function(s). Its value
      will be a reference to this instance of ManagedTempRepository. Ignored
      when this instance is used as context manager. Defaults to kw=None to
      avoid having it added to the keyword args.
    """
    # Bookkeeping flags: nothing has been patched or created yet
    self._configPatchApplied = False
    self._attemptedToCreateDatabase = False
    self._kw = kw

    # NOTE(review): name suggests an engine not bound to a specific database
    # (used for database-level statements) -- confirm
    self._unaffiliatedEngine = repository.getUnaffiliatedEngine(
      htmengine.APP_CONFIG)

    baseDbName = htmengine.APP_CONFIG.get(self.REPO_SECTION_NAME,
                                          self.REPO_DATABASE_ATTR_NAME)
    # <configured db>_<clientLabel>_<uuid1 hex> keeps the name unique
    self.tempDatabaseName = "{original}_{label}_{uid}".format(
      uid=uuid.uuid1().hex,
      label=clientLabel,
      original=baseDbName)

    # Config patch that, once started, swaps the configured repository
    # database name for the temporary one
    overrides = ((self.REPO_SECTION_NAME, self.REPO_DATABASE_ATTR_NAME,
                  self.tempDatabaseName),)
    self._configPatch = ConfigAttributePatch(self.REPO_CONFIG_NAME,
                                             self.REPO_BASE_CONFIG_DIR,
                                             values=overrides)
# Code example #3
# 0
  def testTransientErrorRetryDecorator(self):
    """Verify that monitorsdb.retryOnTransientErrors recovers from a dropped
    database connection.

    All monitorsdb traffic is routed through a local proxy on port 6033. On its
    first attempt the decorated function kills the proxy, so its query fails
    with OperationalError. On the second attempt it restarts the proxy. The
    retried query must eventually return the expected value.
    """
    # Setup proxy.  We'll patch config later, so we need to cache the values
    # so that the original proxy may be restarted with the original params
    config = monitorsdb.MonitorsDbConfig()

    originalHost = config.get("repository", "host")
    originalPort = config.getint("repository", "port")

    def _startProxy():
      p = startProxy(originalHost, originalPort, 6033)
      p.next()
      return p

    proxy = _startProxy()
    self.addCleanup(proxy.send, "kill")

    # Patch monitorsdb config with local proxy
    with ConfigAttributePatch(
          config.CONFIG_NAME,
          config.baseConfigDir,
          (("repository", "host", "127.0.0.1"),
           ("repository", "port", "6033"))):

      # Force refresh of engine singleton
      monitorsdb._EngineSingleton._pid = None
      # Force another refresh once the test is over. The engine created below
      # is bound to the proxy address (127.0.0.1:6033), and later tests must
      # not keep reusing it.
      self.addCleanup(setattr, monitorsdb._EngineSingleton, "_pid", None)
      engine = monitorsdb.engineFactory()

      # First, make sure valid query returns expected results
      res = engine.execute("select 1")
      self.assertEqual(res.scalar(), 1)

      # One entry per completed setup step. The decorator retries the function
      # and this state must persist across those calls, so it lives in the
      # enclosing scope rather than in a mutable default argument.
      attempts = []

      @monitorsdb.retryOnTransientErrors
      def _killProxyTryRestartProxyAndTryAgain():
        if not attempts:
          # Kill the proxy on first attempt
          proxy.send("kill")
          proxy.next()
          try:
            engine.execute("select 1")
            self.fail("Proxy did not terminate as expected...")
          except sqlalchemy.exc.OperationalError:
            pass
          attempts.append(None)
        elif len(attempts) == 1:
          # Restore proxy in second attempt
          newProxy = _startProxy()
          self.addCleanup(newProxy.send, "kill")
          attempts.append(None)

        # Fails (triggering a retry) while the proxy is down
        return engine.execute("select 2")

      # Try again w/ retry decorator
      result = _killProxyTryRestartProxyAndTryAgain()

      # Verify that the expected value is eventually returned
      self.assertEqual(result.scalar(), 2)
    def start(self):
        """Redirect model checkpoint storage to a fresh temporary directory.

        Creates a temp parent directory with a "tempStorageRoot"
        subdirectory. It then starts a ConfigAttributePatch that points the
        "storage"/"root" attribute of model-checkpoint.conf (in the directory
        named by $APPLICATION_CONFIG_PATH) at that subdirectory. Sets
        ``self.active`` once the patch is in place.

        If any step after the temp directory is created fails, the directory
        is removed before the exception propagates, so it does not leak.
        """
        import shutil

        assert not self.active

        self._tempParentDir = tempfile.mkdtemp(prefix=self.__class__.__name__)

        try:
            self.tempModelCheckpointDir = os.path.join(self._tempParentDir,
                                                       "tempStorageRoot")
            os.mkdir(self.tempModelCheckpointDir)

            self._configPatch = ConfigAttributePatch(
                "model-checkpoint.conf",
                os.environ.get("APPLICATION_CONFIG_PATH"),
                (("storage", "root", self.tempModelCheckpointDir), ))

            self._configPatch.start()
        except Exception:
            # self.active is still False, so no stop() will clean this up
            shutil.rmtree(self._tempParentDir, ignore_errors=True)
            raise

        self.active = True
        self._logger.info("%s: redirected model checkpoint storage to %s",
                          self.__class__.__name__, self.tempModelCheckpointDir)
# Code example #5
# 0
  def start(self):
    """Redirect model checkpoint storage to a fresh temporary directory.

    Creates a temp parent directory with a "tempStorageRoot" subdirectory.
    It then starts a ConfigAttributePatch that points the "storage"/"root"
    attribute of model-checkpoint.conf (in the directory named by
    $APPLICATION_CONFIG_PATH) at that subdirectory. Sets ``self.active`` once
    the patch is in place.

    If any step after the temp directory is created fails, the directory is
    removed before the exception propagates, so it does not leak.
    """
    import shutil

    assert not self.active

    self._tempParentDir = tempfile.mkdtemp(
      prefix=self.__class__.__name__)

    try:
      self.tempModelCheckpointDir = os.path.join(self._tempParentDir,
                                                 "tempStorageRoot")
      os.mkdir(self.tempModelCheckpointDir)

      self._configPatch = ConfigAttributePatch(
        "model-checkpoint.conf",
        os.environ.get("APPLICATION_CONFIG_PATH"),
        (("storage", "root", self.tempModelCheckpointDir),))

      self._configPatch.start()
    except Exception:
      # self.active is still False, so no stop() will clean this up
      shutil.rmtree(self._tempParentDir, ignore_errors=True)
      raise

    self.active = True
    self._logger.info("%s: redirected model checkpoint storage to %s",
                      self.__class__.__name__, self.tempModelCheckpointDir)
# Code example #6
# 0
class ManagedTempRepositoryBase(object):
  """Base class for context manager and function decorator that on entry patches
  the repository database name with a unique temp name and creates a temp
  repository; then drops the repository database on exit.

  This effectively redirects repository object transactions to the
  temporary database while in scope of ManagedTempRepository.

  NOTE: this affects repository access in the currently-executing process and
  its descendant processes; it has no impact on processes started externally or
  processes started without inheriting the environment variables of the
  current process.

  Subclasses must implement initTempDatabase().

  Sorry, but there is no class decorator capability provided at this time.

  Context Manager Example::

      with ManagedTempRepository(clientLabel=self.__class__.__name__) as repoCM:
        print repoCM.tempDatabaseName
        <do test logic>

  Function Decorator Example::

      @ManagedTempRepository(clientLabel="testSomething", kw="tempRepoPatch")
      def testSomething(self, tempRepoPatch):
        print tempRepoPatch.tempDatabaseName
        <do test logic>

  """
  __metaclass__ = abc.ABCMeta


  REPO_CONFIG_NAME = htmengine.APP_CONFIG.configName
  REPO_BASE_CONFIG_DIR = htmengine.APP_CONFIG.baseConfigDir
  REPO_SECTION_NAME = "repository"
  REPO_DATABASE_ATTR_NAME = "db"


  def __init__(self, clientLabel, kw=None):
    """
    clientLabel: this *relatively short* string will be used to construct the
      temporary database name. It shouldn't contain any characters that would
      make it inappropriate for a database name (no spaces, etc.)
    kw: name of keyword argument to add to the decorated function(s). Its value
      will be a reference to this instance of ManagedTempRepository. Ignored
      when this instance is used as context manager. Defaults to kw=None to
      avoid having it added to the keyword args.
    """
    self._kw = kw

    self._unaffiliatedEngine = repository.getUnaffiliatedEngine(
      htmengine.APP_CONFIG)

    dbNameFromConfig = htmengine.APP_CONFIG.get(self.REPO_SECTION_NAME,
                                                self.REPO_DATABASE_ATTR_NAME)
    # <configured db>_<clientLabel>_<uuid1 hex> keeps the name unique
    self.tempDatabaseName = "{original}_{label}_{uid}".format(
      original=dbNameFromConfig,
      label=clientLabel,
      uid=uuid.uuid1().hex)

    # Create a Config patch to override the Repository database name
    self._configPatch = ConfigAttributePatch(
      self.REPO_CONFIG_NAME,
      self.REPO_BASE_CONFIG_DIR,
      values=((self.REPO_SECTION_NAME, self.REPO_DATABASE_ATTR_NAME,
               self.tempDatabaseName),))
    # True while self._configPatch is started and not yet stopped
    self._configPatchApplied = False

    # True from the moment database creation is attempted until stop()
    # drops it
    self._attemptedToCreateDatabase = False


  @abc.abstractmethod
  def initTempDatabase(self):
    """Initialize the temporary repository database with default schema and
    contents
    """
    raise NotImplementedError


  def __enter__(self):
    """Apply the patch and create the temp database; return this instance."""
    self.start()
    return self


  def __exit__(self, *args):
    """Drop the temp database and remove the patch; never suppress errors."""
    self.stop()
    return False


  def __call__(self, f):
    """ Implement the function decorator

    The returned wrapper runs start() before and stop() after every call of
    `f`. If `kw` was given, it passes this instance to `f` as that keyword
    argument.
    """

    @functools.wraps(f)
    def applyTempRepositoryPatch(*args, **kwargs):
      self.start()
      try:
        if self._kw is not None:
          kwargs[self._kw] = self
        return f(*args, **kwargs)
      finally:
        self.stop()

    return applyTempRepositoryPatch


  def start(self):
    """Apply the config patch, then create and verify the temp database.

    The cached repository engine is reset first, so an engine created with
    the unpatched configuration is not reused. On any failure, stop() is
    called to undo whatever was done and the exception is re-raised.
    """
    # Removes possible left over cached engine
    # (needed if non-patched engine is run prior)
    repository.engineFactory(config=htmengine.APP_CONFIG, reset=True)

    # Override the Repository database name
    try:
      self._configPatch.start()
      self._configPatchApplied = True

      # Now create the temporary repository database
      self._attemptedToCreateDatabase = True
      self.initTempDatabase()

      # Verify that the temporary repository database got created
      numDbFound = self._unaffiliatedEngine.execute(
        "SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE "
        "`SCHEMA_NAME` = '{db}'".format(db=self.tempDatabaseName)).scalar()
      assert numDbFound == 1, (
        "Temp repo db={db} not found (numFound={numFound})".format(
          db=self.tempDatabaseName,
          numFound=numDbFound))
    except:
      # Attempt to clean up
      self.stop()

      raise


  def stop(self):
    """Undo whatever start() managed to do.

    Drops the temp database if its creation was attempted and removes the
    config patch if it was applied. It also resets the cached repository
    engine and disposes of the unaffiliated engine's connection pool.

    Each bookkeeping flag is cleared as its step is undone. A later call,
    e.g. from a failed start() on a reused decorator instance, therefore
    does not stop an already-stopped config patch.
    """
    try:
      if self._attemptedToCreateDatabase:
        self._attemptedToCreateDatabase = False
        # Drop the temporary repository database, if any
        self._unaffiliatedEngine.execute(
          "DROP DATABASE IF EXISTS {db}".format(db=self.tempDatabaseName))
    finally:
      if self._configPatchApplied:
        self._configPatchApplied = False
        self._configPatch.stop()

      repository.engineFactory(config=htmengine.APP_CONFIG, reset=True)

      # Dispose of the unaffiliated engine's connection pool
      self._unaffiliatedEngine.dispose()
# Code example #7
# 0
class ManagedTempRepository(object):
    """Context manager that on entry patches the repository database name with
    a unique temp name and creates the repository; then deletes the repository
    on exit.

    This effectively redirects repository object transactions to the
    temporary database while in scope of ManagedTempRepository.

    It may be used as a context manager or as a function decorator (sorry, but
    no class decorator capability at this time)

    Context Manager Example::

        with ManagedTempRepository(clientLabel=self.__class__.__name__) as repoCM:
          print repoCM.tempDatabaseName
          <do test logic>

    Function Decorator Example::

        @ManagedTempRepository(clientLabel="testSomething", kw="tempRepoPatch")
        def testSomething(self, tempRepoPatch):
          print tempRepoPatch.tempDatabaseName
          <do test logic>

    """
    REPO_CONFIG_NAME = config.CONFIG_NAME
    REPO_BASE_CONFIG_DIR = config.baseConfigDir
    REPO_SECTION_NAME = "repository"
    REPO_DATABASE_ATTR_NAME = "db"

    def __init__(self, clientLabel, kw=None):
        """
        clientLabel: this *relatively short* string will be used to construct
          the temporary database name. It shouldn't contain any characters that
          would make it inappropriate for a database name (no spaces, etc.)
        kw: name of keyword argument to add to the decorated function(s). Its
          value will be a reference to this instance of ManagedTempRepository.
          Ignored when this instance is used as context manager. Defaults to
          kw=None to avoid having it added to the keyword args.
        """
        self._kw = kw

        # <configured db>_<clientLabel>_<uuid1 hex> keeps the name unique
        self.tempDatabaseName = "%s_%s_%s" % (self.getDatabaseNameFromConfig(),
                                              clientLabel, uuid.uuid1().hex)

        # Create a Config patch to override the Repository database name
        self._configPatch = ConfigAttributePatch(
            self.REPO_CONFIG_NAME,
            self.REPO_BASE_CONFIG_DIR,
            values=((self.REPO_SECTION_NAME, self.REPO_DATABASE_ATTR_NAME,
                     self.tempDatabaseName), ))
        # True while self._configPatch is started and not yet stopped
        self._configPatchApplied = False

        # True from the moment repository creation is attempted until stop()
        # drops it
        self._attemptedToCreateRepository = False

    @classmethod
    def getDatabaseNameFromConfig(cls):
        """Return the repository database name currently set in `config`."""
        return config.get(cls.REPO_SECTION_NAME, cls.REPO_DATABASE_ATTR_NAME)

    def __enter__(self):
        """Apply the patch and create the temp repository; return self."""
        self.start()
        return self

    def __exit__(self, *args):
        """Drop the temp repository and remove the patch; never suppress."""
        self.stop()
        return False

    def __call__(self, f):
        """ Implement the function decorator

        The returned wrapper runs start() before and stop() after every call
        of `f`. If `kw` was given, it passes this instance to `f` as that
        keyword argument.
        """
        @functools.wraps(f)
        def applyTempRepositoryPatch(*args, **kwargs):
            self.start()
            try:
                if self._kw is not None:
                    kwargs[self._kw] = self
                return f(*args, **kwargs)
            finally:
                self.stop()

        return applyTempRepositoryPatch

    def start(self):
        """Apply the config patch, then create and verify the temp repository.

        On any failure, stop() is called to undo whatever was done and the
        exception is re-raised.
        """
        # Removes possible left over cached engine
        # (needed if non-patched engine is run prior)
        repository.engineFactory(reset=True)

        # Override the Repository database name
        try:
            self._configPatch.start()
            self._configPatchApplied = True

            # Verify that the database doesn't exist yet
            assert self.tempDatabaseName not in getAllDatabaseNames(), (
                "Temp repo db=%s already existed" % (self.tempDatabaseName, ))

            # Now create the temporary repository database
            self._attemptedToCreateRepository = True
            repository.reset()

            # Verify that the temporary repository database got created
            assert self.tempDatabaseName in getAllDatabaseNames(), (
                "Temp repo db=%s not found" % (self.tempDatabaseName, ))
        except:
            # Attempt to clean up
            self.stop()

            raise

    def stop(self):
        """Undo whatever start() managed to do.

        Drops the temp database if its creation was attempted and removes the
        config patch if it was applied. It also discards the cached
        ``repository.engineFactory.engine``, if any.

        Each bookkeeping flag is cleared as its step is undone. A later call,
        e.g. from a failed start() on a reused decorator instance, therefore
        does not stop an already-stopped config patch.
        """
        try:
            if self._attemptedToCreateRepository:
                self._attemptedToCreateRepository = False
                # Delete the temporary repository database, if any
                with ENGINE.connect() as connection:
                    connection.execute("DROP DATABASE IF EXISTS %s" %
                                       (self.tempDatabaseName, ))
        finally:
            if self._configPatchApplied:
                self._configPatchApplied = False
                self._configPatch.stop()
            try:
                del repository.engineFactory.engine
            except AttributeError:
                # No engine was cached
                pass
class ManagedTempRepository(object):
    """Context manager and function decorator that on entry patches the
    repository database name with a unique temp name and creates a temp
    repository; then drops the repository database on exit.

    This effectively redirects repository object transactions to the
    temporary database while in scope of ManagedTempRepository.

    NOTE: this affects repository access in the currently-executing process
    and its descendant processes; it has no impact on processes started
    externally or processes started without inheriting the environment
    variables of the current process.

    Sorry, but there is no class decorator capability provided at this time.

    Context Manager Example::

        with ManagedTempRepository(clientLabel=self.__class__.__name__) as repoCM:
          print repoCM.tempDatabaseName
          <do test logic>

    Function Decorator Example::

        @ManagedTempRepository(clientLabel="testSomething", kw="tempRepoPatch")
        def testSomething(self, tempRepoPatch):
          print tempRepoPatch.tempDatabaseName
          <do test logic>

    """
    REPO_CONFIG_NAME = CollectorsDbConfig.CONFIG_NAME
    REPO_BASE_CONFIG_DIR = CollectorsDbConfig.CONFIG_DIR
    REPO_SECTION_NAME = "repository"
    REPO_DATABASE_ATTR_NAME = "db"

    def __init__(self, clientLabel, kw=None):
        """
        clientLabel: this *relatively short* string will be used to construct
          the temporary database name. It shouldn't contain any characters that
          would make it inappropriate for a database name (no spaces, etc.)
        kw: name of keyword argument to add to the decorated function(s). Its
          value will be a reference to this instance of ManagedTempRepository.
          Ignored when this instance is used as context manager. Defaults to
          kw=None to avoid having it added to the keyword args.
        """
        self._kw = kw

        self._unaffiliatedEngine = collectorsdb.getUnaffiliatedEngine()

        dbNameFromConfig = CollectorsDbConfig().get(
            self.REPO_SECTION_NAME, self.REPO_DATABASE_ATTR_NAME)
        # <configured db>_<clientLabel>_<uuid1 hex> keeps the name unique
        self.tempDatabaseName = "{original}_{label}_{uid}".format(
            original=dbNameFromConfig, label=clientLabel, uid=uuid.uuid1().hex)

        # Create a Config patch to override the Repository database name
        self._configPatch = ConfigAttributePatch(
            self.REPO_CONFIG_NAME,
            self.REPO_BASE_CONFIG_DIR,
            values=((self.REPO_SECTION_NAME, self.REPO_DATABASE_ATTR_NAME,
                     self.tempDatabaseName), ))
        # True while self._configPatch is started and not yet stopped
        self._configPatchApplied = False

        # True from the moment database creation is attempted until stop()
        # drops it
        self._attemptedToCreateDatabase = False

    def __enter__(self):
        """Apply the patch and create the temp database; return self."""
        self.start()
        return self

    def __exit__(self, *args):
        """Drop the temp database and remove the patch; never suppress."""
        self.stop()
        return False

    def __call__(self, f):
        """ Implement the function decorator

        The returned wrapper runs start() before and stop() after every call
        of `f`. If `kw` was given, it passes this instance to `f` as that
        keyword argument.
        """
        @functools.wraps(f)
        def applyTempRepositoryPatch(*args, **kwargs):
            self.start()
            try:
                if self._kw is not None:
                    kwargs[self._kw] = self
                return f(*args, **kwargs)
            finally:
                self.stop()

        return applyTempRepositoryPatch

    def start(self):
        """Apply the config patch, then create and verify the temp database.

        On any failure, stop() is called to undo whatever was done and the
        exception is re-raised.
        """
        # Removes possible left over cached engine
        # (needed if non-patched engine is run prior)
        collectorsdb.resetEngineSingleton()

        # Override the Repository database name
        try:
            self._configPatch.start()
            self._configPatchApplied = True

            # Now create the temporary repository database
            self._attemptedToCreateDatabase = True
            collectorsdb.reset(suppressPromptAndObliterateDatabase=True)

            # Verify that the temporary repository database got created
            numDbFound = self._unaffiliatedEngine.execute(
                "SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE "
                "`SCHEMA_NAME` = '{db}'".format(
                    db=self.tempDatabaseName)).scalar()
            assert numDbFound == 1, (
                "Temp repo db={db} not found (numFound={numFound})".format(
                    db=self.tempDatabaseName, numFound=numDbFound))
        except:
            # Attempt to clean up
            self.stop()

            raise

    def stop(self):
        """Undo whatever start() managed to do.

        Drops the temp database if its creation was attempted and removes the
        config patch if it was applied. It also resets the collectorsdb engine
        singleton and disposes of the unaffiliated engine's connection pool.

        Each bookkeeping flag is cleared as its step is undone. A later call,
        e.g. from a failed start() on a reused decorator instance, therefore
        does not stop an already-stopped config patch.
        """
        try:
            if self._attemptedToCreateDatabase:
                self._attemptedToCreateDatabase = False
                # Drop the temporary repository database, if any
                self._unaffiliatedEngine.execute(
                    "DROP DATABASE IF EXISTS {db}".format(
                        db=self.tempDatabaseName))
        finally:
            if self._configPatchApplied:
                self._configPatchApplied = False
                self._configPatch.stop()

            collectorsdb.resetEngineSingleton()

            # Dispose of the unaffiliated engine's connection pool
            self._unaffiliatedEngine.dispose()
# Code example #9
# 0
class RabbitmqVirtualHostPatch(object):
    """ An instance of this class may be used as a decorator, class decorator or
    Context Manager for overriding the default virtual host both in-proc and in
    subprocesses.

    On start: creates a temporary virtual host and patches the "virtual_host"
      attribute in the "connection" section of rabbitmq.conf configuration file

    On stop: deletes the temporary virtual host and unpatches the
      "virtual_host" rabbitmq configuration attribute.

    NOTE: the patch assumes that the code under test connects to RabbitMQ using
    the virtual_host from "virtual_host" attribute in the "connection" section
    of rabbitmq.conf configuration file.

    NOTE: this decorator will only decorate methods beginning with the
    `mock.patch.TEST_PREFIX` prefix (defaults to "test"). Please keep this in
    mind when decorating entire classes.
    """

    _RABBIT_MANAGEMENT_HEADERS = {"content-type": "application/json"}

    def __init__(self, clientLabel, kw=None, logger=logging):
        """
        clientLabel: this string will be used to construct the temporary
          endpoint names. The following characters are permitted, and it
          shouldn't be too long: [._a-zA-Z]. This may be helpful with
          diagnostics. A specific test class name (or similar) would make a
          reasonable clientLabel.
        kw: name of keyword argument to add to the decorated function(s). Its
          value will be a reference to this instance of
          RabbitmqVirtualHostPatch. Ignored when this instance is used as
          context manager. Defaults to kw=None to avoid having it added to the
          keyword args.
        logger: object with info()/exception() methods used for diagnostics;
          defaults to the `logging` module.
        """
        # True when applied successfully; False after successfully removed or
        # not applied
        self.active = False

        self._clientLabel = clientLabel
        self._kw = kw
        self._logger = logger

        # Name of override RabbitMQ virtual host (generated lazily by _vhost)
        self._cachedVirtualHost = None

        # True once the temporary vhost has been created on the server
        self._virtualHostCreated = False

        # ConfigAttributePatch for rabbitmq.conf; set by start()
        self._configPatch = None

    @property
    def _vhost(self):
        """Temporary vhost name "<clientLabel>_<uuid1 hex>", generated once per
        instance and then reused.
        """
        if self._cachedVirtualHost is None:
            self._cachedVirtualHost = "%s_%s" % (
                self._clientLabel,
                uuid.uuid1().hex,
            )
        return self._cachedVirtualHost

    def __enter__(self):
        """Apply the patch and return this instance."""
        self.start()
        return self

    def __exit__(self, *args):
        """Remove the patch; never suppress exceptions."""
        self.stop()
        return False

    def __call__(self, f):
        """ Implement the function or class decorator

        Classes are handed to _decorateClass(). Any other callable is wrapped
        so that start() runs before and stop() after every call. If `kw` was
        given, the wrapper passes this instance as that keyword argument.
        """
        # `type` is the same object as Python 2's types.TypeType, and the
        # check also works on Python 3
        if isinstance(f, type):
            return self._decorateClass(f)

        @functools.wraps(f)
        def applyVirtualHostPatch(*args, **kwargs):
            self.start()
            try:
                if self._kw is not None:
                    kwargs[self._kw] = self
                return f(*args, **kwargs)
            finally:
                self.stop()

        return applyVirtualHostPatch

    def _decorateClass(self, targetClass):
        """ Decorate the test methods in the given class. Honors
        `mock.patch.TEST_PREFIX` for choosing which methods to wrap

        Each method gets its own RabbitmqVirtualHostPatch instance, and so its
        own temporary vhost name.
        """
        for attrName in dir(targetClass):
            if attrName.startswith(patch.TEST_PREFIX):
                f = getattr(targetClass, attrName)
                if callable(f):
                    decoratedFunc = RabbitmqVirtualHostPatch(
                        self._clientLabel, self._kw, self._logger)(f)
                    setattr(targetClass, attrName, decoratedFunc)
        return targetClass

    def start(self):
        """Create the temporary vhost, grant the management user full
        permissions on it, and patch rabbitmq.conf to use it.

        On any failure, the steps already performed are undone via
        _removePatches() and the exception is re-raised.
        """
        assert not self.active

        # Use RabbitMQ Management Plugin to create the new temporary vhost
        connectionParams = amqp.connection.RabbitmqManagementConnectionParams()

        url = "http://%s:%s/api/vhosts/%s" % (
            connectionParams.host, connectionParams.port, self._vhost)

        try:
            try:
                response = requests.put(
                    url,
                    headers=self._RABBIT_MANAGEMENT_HEADERS,
                    auth=(connectionParams.username,
                          connectionParams.password))

                response.raise_for_status()

                self._virtualHostCreated = True
                self._logger.info("%s: created temporary rabbitmq vhost=%s",
                                  self.__class__.__name__, self._vhost)
            except Exception:
                self._logger.exception(
                    "Attempt to create temporary vhost=%s failed. url=%r",
                    self._vhost, url)
                raise

            # Configure permissions on the new temporary vhost
            try:
                url = "http://%s:%s/api/permissions/%s/%s" % (
                    connectionParams.host, connectionParams.port, self._vhost,
                    connectionParams.username)

                response = requests.put(
                    url,
                    headers=self._RABBIT_MANAGEMENT_HEADERS,
                    data=json.dumps({
                        "configure": ".*",
                        "write": ".*",
                        "read": ".*"
                    }),
                    auth=(connectionParams.username,
                          connectionParams.password))

                response.raise_for_status()

                self._logger.info(
                    "%s: Configured permissions on temporary rabbitmq vhost=%s",
                    self.__class__.__name__, self._vhost)
            except Exception:
                self._logger.exception(
                    "Attempt to configure permissions on vhost=%s failed. url=%r",
                    self._vhost, url)
                raise

            # Apply a config patch to override the rabbitmq virtual host to be
            # used by message_bus_connector and others
            rabbitmqConfig = amqp.connection.RabbitmqConfig()
            self._configPatch = ConfigAttributePatch(
                rabbitmqConfig.CONFIG_NAME, rabbitmqConfig.baseConfigDir,
                (("connection", "virtual_host", self._vhost), ))

            self._configPatch.start()

            self._logger.info("%s: overrode rabbitmq vhost=%s",
                              self.__class__.__name__, self._vhost)

            # Self-validation
            connectionParams = (
                amqp.connection.getRabbitmqConnectionParameters())
            actualVhost = connectionParams.vhost
            assert actualVhost == self._vhost, (
                "Expected vhost=%r, but got vhost=%r") % (self._vhost,
                                                          actualVhost)

        except Exception:
            self._logger.exception("patch failed, deleting vhost=%s",
                                   self._vhost)
            self._removePatches()
            raise

        self.active = True
        self._logger.info("%s: applied patch", self.__class__.__name__)

    def stop(self):
        """Remove the config patch and delete the temporary vhost."""
        assert self.active

        self._removePatches()
        self.active = False

    def _removePatches(self):
        """ Stop the config patch if it is active, then delete the temporary
        vhost if it was created.

        NOTE: may be called internally to clean-up mid-application of patch
        """
        try:
            if self._configPatch is not None and self._configPatch.active:
                self._configPatch.stop()
        finally:
            if self._virtualHostCreated:
                self._deleteTemporaryVhost()
                self._virtualHostCreated = False

        self._logger.info("%s: removed patch", self.__class__.__name__)

    def _deleteTemporaryVhost(self):
        """ Delete a RabbitMQ virtual host

        Raises whatever requests raises (including HTTP error statuses via
        raise_for_status) after logging it.
        """
        # Use RabbitMQ Management Plugin to delete the temporary vhost
        connectionParams = (
            amqp.connection.RabbitmqManagementConnectionParams())

        url = "http://%s:%s/api/vhosts/%s" % (
            connectionParams.host, connectionParams.port, self._vhost)

        try:
            response = requests.delete(url,
                                       headers=self._RABBIT_MANAGEMENT_HEADERS,
                                       auth=(connectionParams.username,
                                             connectionParams.password))

            response.raise_for_status()

            self._logger.info("%s: deleted temporary rabbitmq vhost=%s",
                              self.__class__.__name__, self._vhost)
        except Exception:
            self._logger.exception(
                "Attempt to delete temporary vhost=%s failed. url=%r",
                self._vhost, url)
            raise
# Code example #10
# 0
    def start(self):
        """Create the temporary vhost, grant the management user full
        permissions on it, and patch rabbitmq.conf to use it.

        On any failure, the steps already performed are undone via
        self._removePatches() and the exception is re-raised. Sets
        ``self.active`` on success.
        """
        assert not self.active

        # Use RabbitMQ Management Plugin to create the new temporary vhost
        connectionParams = amqp.connection.RabbitmqManagementConnectionParams()

        url = "http://%s:%s/api/vhosts/%s" % (
            connectionParams.host, connectionParams.port, self._vhost)

        try:
            try:
                response = requests.put(
                    url,
                    headers=self._RABBIT_MANAGEMENT_HEADERS,
                    auth=(connectionParams.username,
                          connectionParams.password))

                response.raise_for_status()

                self._virtualHostCreated = True
                self._logger.info("%s: created temporary rabbitmq vhost=%s",
                                  self.__class__.__name__, self._vhost)
            except Exception:
                self._logger.exception(
                    "Attempt to create temporary vhost=%s failed. url=%r",
                    self._vhost, url)
                raise

            # Configure permissions on the new temporary vhost
            try:
                url = "http://%s:%s/api/permissions/%s/%s" % (
                    connectionParams.host, connectionParams.port, self._vhost,
                    connectionParams.username)

                response = requests.put(
                    url,
                    headers=self._RABBIT_MANAGEMENT_HEADERS,
                    data=json.dumps({
                        "configure": ".*",
                        "write": ".*",
                        "read": ".*"
                    }),
                    auth=(connectionParams.username,
                          connectionParams.password))

                response.raise_for_status()

                self._logger.info(
                    "%s: Configured permissions on temporary rabbitmq vhost=%s",
                    self.__class__.__name__, self._vhost)
            except Exception:
                self._logger.exception(
                    "Attempt to configure permissions on vhost=%s failed. url=%r",
                    self._vhost, url)
                raise

            # Apply a config patch to override the rabbitmq virtual host to be
            # used by message_bus_connector and others
            rabbitmqConfig = amqp.connection.RabbitmqConfig()
            self._configPatch = ConfigAttributePatch(
                rabbitmqConfig.CONFIG_NAME, rabbitmqConfig.baseConfigDir,
                (("connection", "virtual_host", self._vhost), ))

            self._configPatch.start()

            self._logger.info("%s: overrode rabbitmq vhost=%s",
                              self.__class__.__name__, self._vhost)

            # Self-validation
            connectionParams = (
                amqp.connection.getRabbitmqConnectionParameters())
            actualVhost = connectionParams.vhost
            assert actualVhost == self._vhost, (
                "Expected vhost=%r, but got vhost=%r") % (self._vhost,
                                                          actualVhost)

        except Exception:
            self._logger.exception("patch failed, deleting vhost=%s",
                                   self._vhost)
            self._removePatches()
            raise

        self.active = True
        self._logger.info("%s: applied patch", self.__class__.__name__)
# Code example #11
# 0
class ManagedTempRepository(object):
  """ Context manager that on entry patches the repository database name with
  a unique temp name and creates the repository; then deletes the repository on
  exit.

  This effectively redirects repository object transactions to the
  temporary database while in scope of ManagedTempRepository.

  It may be used as a context manager or as a function decorator (sorry, but
  no class decorator capability at this time)

  Context Manager Example::

      with ManagedTempRepository(clientLabel=self.__class__.__name__) as repoCM:
        print repoCM.tempDatabaseName
        <do test logic>

  Function Decorator Example::

      @ManagedTempRepository(clientLabel="testSomething", kw="tempRepoPatch")
      def testSomething(self, tempRepoPatch):
        print tempRepoPatch.tempDatabaseName
        <do test logic>

  NOTE: when used as a decorator, the same instance is started and stopped on
  every call of the decorated function, so stop() resets the instance's state
  flags to leave it reusable by a subsequent start().
  """
  REPO_CONFIG_NAME = config.CONFIG_NAME
  REPO_BASE_CONFIG_DIR = config.baseConfigDir
  REPO_SECTION_NAME = "repository"
  REPO_DATABASE_ATTR_NAME = "db"


  def __init__(self, clientLabel, kw=None):
    """
    clientLabel: this *relatively short* string will be used to construct the
      temporary database name. It shouldn't contain any characters that would
      make it inappropriate for a database name (no spaces, etc.)
    kw: name of keyword argument to add to the decorated function(s). Its value
      will be a reference to this instance of ManagedTempRepository. Ignored
      when this instance is used as context manager. Defaults to kw=None to
      avoid having it added to the keyword args.
    """
    self._kw = kw

    # <configured db name>_<clientLabel>_<uuid1 hex>
    self.tempDatabaseName = "%s_%s_%s" % (self.getDatabaseNameFromConfig(),
                                          clientLabel, uuid.uuid1().hex)

    # Create a Config patch to override the Repository database name
    self._configPatch = ConfigAttributePatch(
      self.REPO_CONFIG_NAME,
      self.REPO_BASE_CONFIG_DIR,
      values=((self.REPO_SECTION_NAME, self.REPO_DATABASE_ATTR_NAME,
               self.tempDatabaseName),))

    # True while self._configPatch is started and must be undone by stop()
    self._configPatchApplied = False

    # True once start() has attempted to create the temp database, telling
    # stop() that it needs to drop it
    self._attemptedToCreateRepository = False


  @classmethod
  def getDatabaseNameFromConfig(cls):
    """ Return the repository database name currently set in the config """
    return config.get(cls.REPO_SECTION_NAME,
                      cls.REPO_DATABASE_ATTR_NAME)


  def __enter__(self):
    self.start()
    return self


  def __exit__(self, *args):
    self.stop()
    # Don't suppress exceptions raised inside the with-block
    return False


  def __call__(self, f):
    """ Implement the function decorator """

    @functools.wraps(f)
    def applyTempRepositoryPatch(*args, **kwargs):
      self.start()
      try:
        if self._kw is not None:
          kwargs[self._kw] = self
        return f(*args, **kwargs)
      finally:
        self.stop()

    return applyTempRepositoryPatch


  def start(self):
    """ Apply the database-name config patch and create the temporary
    repository database.

    On any failure (including KeyboardInterrupt), stop() is called to undo
    whatever was done so far and the exception is re-raised.
    """
    # Removes possible left over cached engine
    # (needed if non-patched engine is run prior)
    repository.engineFactory(reset=True)

    # Override the Repository database name
    try:
      self._configPatch.start()
      self._configPatchApplied = True

      # Verify that the database doesn't exist yet
      assert self.tempDatabaseName not in getAllDatabaseNames(), (
        "Temp repo db=%s already existed" % (self.tempDatabaseName,))

      # Now create the temporary repository database
      self._attemptedToCreateRepository = True
      repository.reset()

      # Verify that the temporary repository database got created
      assert self.tempDatabaseName in getAllDatabaseNames(), (
        "Temp repo db=%s not found" % (self.tempDatabaseName,))
    except BaseException:
      # Attempt to clean up. BaseException (rather than Exception) so that
      # cleanup also runs on KeyboardInterrupt/SystemExit.
      self.stop()

      raise


  def stop(self):
    """ Drop the temporary database (if start() attempted to create it),
    remove the config patch (if it was applied) and discard the engine cached
    on repository.engineFactory.

    Each step runs even if an earlier one raises. The state flags are reset,
    so this instance may be started again (as happens when it is used as a
    decorator and the decorated function is called repeatedly).
    """
    try:
      if self._attemptedToCreateRepository:
        self._attemptedToCreateRepository = False
        # Delete the temporary repository database, if any.
        # NOTE(review): ENGINE is defined outside this view; presumably an
        # engine that is not bound to the temp database -- confirm.
        with ENGINE.connect() as connection:
          connection.execute(
            "DROP DATABASE IF EXISTS %s" % (self.tempDatabaseName,))
    finally:
      try:
        if self._configPatchApplied:
          # Reset first so a later start() on this reused instance whose
          # patch.start() fails won't make stop() undo an un-started patch
          self._configPatchApplied = False
          self._configPatch.stop()
      finally:
        # Discard the engine cached on repository.engineFactory (if any)
        # even if removing the config patch failed
        try:
          del repository.engineFactory.engine
        except AttributeError:
          pass
コード例 #12
0
    def testMetricCollectorRun(self, createAdapterMock, repoMock,
                               metricStreamerMock, multiprocessingMock):
        """ Run MetricCollector.run() against mocked repository, datasource
        adapter, metric streamer and multiprocessing, then verify that:
          * run() is terminated by the injected KeyboardInterrupt;
          * getMetricData is called once per pending metric;
          * only metrics with non-empty data are streamed, with all their rows;
          * resource status is fetched and saved for every pending metric.

        NOTE(review): the mock arguments are presumably injected by patch
        decorators applied outside this view -- confirm their targets.
        """
        metricsPerChunk = 4

        # Configure multiprocessing
        # Pool.map_async is replaced by a synchronous stand-in whose wait()
        # maps fn over the tasks in this process.
        # NOTE(review): relies on Python 2's eager map(); under Python 3 map()
        # is lazy and fn would never be called.
        def mapAsync(fn, tasks):
            class _(object):
                def wait(self):
                    map(fn, tasks)

            return _()

        multiprocessingMock.Pool.return_value.map_async.side_effect = mapAsync
        # Pipe and Manager().JoinableQueue delegate to the real multiprocessing
        # implementations
        multiprocessingMock.Pipe.side_effect = multiprocessing.Pipe
        multiprocessingMock.Manager = (Mock(return_value=(Mock(
            JoinableQueue=(Mock(side_effect=multiprocessing.JoinableQueue))))))

        metricPollInterval = 5

        now = datetime.datetime.today()

        # Successive results of getCloudwatchMetricsPendingDataCollection: no
        # metrics, one metric, two metrics, and finally an exception instance,
        # which the mock's side_effect raises to break out of the run loop.
        resultsOfGetCloudwatchMetricsPendingDataCollection = [
            [], [_makeMetricMockInstance(metricPollInterval, now, 1)],
            [
                _makeMetricMockInstance(metricPollInterval, now, 2),
                _makeMetricMockInstance(metricPollInterval, now, 3)
            ],
            KeyboardInterrupt("Fake KeyboardInterrupt to interrupt run-loop")
        ]

        repoMock.getCloudwatchMetricsPendingDataCollection.side_effect = (
            resultsOfGetCloudwatchMetricsPendingDataCollection)
        # retryOnTransientErrors becomes a pass-through: it returns the wrapped
        # function unchanged
        repoMock.retryOnTransientErrors.side_effect = lambda f: f

        # Configure the metric_collector.adapters module mock
        # One getMetricData result per pending metric (3 in total), each a
        # (rows, datetime) pair: empty rows, metricsPerChunk rows, and
        # metricsPerChunk * 5 + 1 rows (more than five chunks' worth).
        mockResults = [([], now),
                       ([[now, 1]] * metricsPerChunk,
                        now + datetime.timedelta(seconds=metricPollInterval)),
                       ([[now, 2]] * (metricsPerChunk * 5 + 1),
                        now + datetime.timedelta(seconds=metricPollInterval))]

        adapterInstanceMock = Mock(spec_set=_CloudwatchDatasourceAdapter)
        adapterInstanceMock.getMetricData.side_effect = mockResults
        adapterInstanceMock.getMetricResourceStatus.return_value = "status"

        createAdapterMock.return_value = adapterInstanceMock

        # Now, run MetricCollector and check results
        # runCollector stores either run()'s return value ("returnCode") or
        # the exception it raised ("exception") for inspection below.
        resultOfRunCollector = dict()

        def runCollector():
            try:
                collector = metric_collector.MetricCollector()
                resultOfRunCollector["returnCode"] = collector.run()
            # Bare except so that KeyboardInterrupt (not an Exception
            # subclass) is captured as well
            except:
                resultOfRunCollector["exception"] = sys.exc_info()[1]
                raise

        # Override metric_streamer/chunk_size with metricsPerChunk for the
        # duration of the run
        with ConfigAttributePatch(
                YOMP.app.config.CONFIG_NAME, YOMP.app.config.baseConfigDir,
            (("metric_streamer", "chunk_size", str(metricsPerChunk)), )):

            # We run it in a thread in order to detect if MetricCollector.run fails to
            # return and to make sure that the test script will finish (in case run
            # doesn't)
            thread = threading.Thread(target=runCollector)
            thread.setDaemon(True)
            thread.start()

            # run() must have exited within the 60-second join timeout
            thread.join(60)
            self.assertFalse(thread.isAlive())

        # run() must have been ended by the fake KeyboardInterrupt rather than
        # returning normally
        self.assertIn("exception", resultOfRunCollector)
        self.assertIsInstance(resultOfRunCollector["exception"],
                              KeyboardInterrupt)
        self.assertNotIn("returnCode", resultOfRunCollector)

        # getMetricData called exactly once per pending metric
        self.assertEqual(adapterInstanceMock.getMetricData.call_count,
                         len(mockResults))

        # Validate that all expected data points were published

        # ... validate metricIDs
        # Expected: the uid of each pending metric whose getMetricData result
        # had non-empty rows, in the order the metrics were returned
        metricIDs = [
            kwargs["metricID"] for (args, kwargs) in
            metricStreamerMock.return_value.streamMetricData.call_args_list
        ]
        expectedMetricIDs = []
        getDataIndex = 0
        for metrics in resultsOfGetCloudwatchMetricsPendingDataCollection:
            if not metrics or isinstance(metrics, BaseException):
                continue
            for m in metrics:
                results = mockResults[getDataIndex][0]
                if results:
                    expectedMetricIDs.append(m.uid)

                getDataIndex += 1

        self.assertEqual(metricIDs, expectedMetricIDs)

        # ... validate data points
        # The rows passed positionally to streamMetricData, concatenated across
        # calls, must equal all non-empty mock rows in order
        dataPoints = list(
            itertools.chain(*[
                args[0] for (args, kwargs) in
                metricStreamerMock.return_value.streamMetricData.call_args_list
            ]))
        expectedDataPoints = list(
            itertools.chain(
                *[copy.deepcopy(r[0]) for r in mockResults if r[0]]))
        self.assertEqual(dataPoints, expectedDataPoints)

        # Assert instance status collected
        self.assertTrue(adapterInstanceMock.getMetricResourceStatus.called)

        # saveMetricInstanceStatus uses a connection, not an engine
        mockConnection = (repoMock.engineFactory.return_value.begin.
                          return_value.__enter__.return_value)

        # Assert instance status recorded
        # (for every metric in the second and third pending batches)
        for metricObj in resultsOfGetCloudwatchMetricsPendingDataCollection[1]:
            repoMock.saveMetricInstanceStatus.assert_any_call(
                mockConnection, metricObj.server,
                adapterInstanceMock.getMetricResourceStatus.return_value)

        for metricObj in resultsOfGetCloudwatchMetricsPendingDataCollection[2]:
            repoMock.saveMetricInstanceStatus.assert_any_call(
                mockConnection, metricObj.server,
                adapterInstanceMock.getMetricResourceStatus.return_value)
class ModelCheckpointStoragePatch(object):
    """ An instance of this class may be used as a decorator, class decorator
    or Context Manager for redirecting ModelCheckpoint storage to a temporary
    directory in-proc and in child processes.

    start() creates a fresh temporary directory and patches the "root"
    attribute of the "storage" section of model-checkpoint.conf to point at
    it; stop() removes that patch and deletes the temporary directory.

    NOTE: when decorating a class, only methods whose names begin with
    `mock.patch.TEST_PREFIX` are wrapped.
    """
    def __init__(self, kw=None, logger=logging):
        """
        kw: name of keyword argument to add to the decorated function(s). Its
          value will be a reference to this instance of
          ModelCheckpointStoragePatch. Ignored when this instance is used as
          context manager. Defaults to kw=None to avoid having it added to the
          keyword args.
        logger: object used for diagnostic messages (the logging module by
          default)
        """
        # True when applied successfully; False after successfully removed or
        # not applied
        self.active = False

        self._kw = kw
        self._logger = logger
        # Temporary directory created by start(); contains tempModelCheckpointDir
        self._tempParentDir = None
        # Directory that model checkpoint storage is redirected to
        self.tempModelCheckpointDir = None
        self._configPatch = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()
        # Don't suppress exceptions raised inside the with-block
        return False

    def __call__(self, f):
        """ Implement the function or class decorator """
        if isinstance(f, types.TypeType):
            return self._decorateClass(f)

        @functools.wraps(f)
        def applyModelCheckpointPatch(*args, **kwargs):
            self.start()
            try:
                if self._kw is not None:
                    kwargs[self._kw] = self
                return f(*args, **kwargs)
            finally:
                self.stop()

        return applyModelCheckpointPatch

    def _decorateClass(self, targetClass):
        """ Decorate the test methods in the given class. Honors
        `mock.patch.TEST_PREFIX` for choosing which methods to wrap. Each
        method gets its own ModelCheckpointStoragePatch instance.
        """
        for attrName in dir(targetClass):
            if attrName.startswith(patch.TEST_PREFIX):
                f = getattr(targetClass, attrName)
                if callable(f):
                    decoratedFunc = ModelCheckpointStoragePatch(
                        self._kw, self._logger)(f)
                    setattr(targetClass, attrName, decoratedFunc)
        return targetClass

    def start(self):
        """ Create the temporary storage directory and apply the config patch.

        If anything fails after the temporary directory was created, the
        directory is removed before the exception propagates, so a failed
        start() doesn't leak it.
        """
        assert not self.active

        self._tempParentDir = tempfile.mkdtemp(prefix=self.__class__.__name__)

        try:
            self.tempModelCheckpointDir = os.path.join(self._tempParentDir,
                                                       "tempStorageRoot")
            os.mkdir(self.tempModelCheckpointDir)

            # Point storage/root of model-checkpoint.conf (in the config dir
            # named by $APPLICATION_CONFIG_PATH) at the temporary directory
            self._configPatch = ConfigAttributePatch(
                "model-checkpoint.conf",
                os.environ.get("APPLICATION_CONFIG_PATH"),
                (("storage", "root", self.tempModelCheckpointDir), ))

            self._configPatch.start()
        except Exception:
            shutil.rmtree(self._tempParentDir, ignore_errors=True)
            raise

        self.active = True
        self._logger.info("%s: redirected model checkpoint storage to %s",
                          self.__class__.__name__, self.tempModelCheckpointDir)

    def stop(self):
        """ Remove the config patch and delete the temporary storage directory.

        The directory is deleted even if removing the config patch raises.
        """
        assert self.active

        try:
            self._configPatch.stop()
        finally:
            shutil.rmtree(self._tempParentDir)

        self.active = False
        self._logger.info("%s: removed model checkpoint storage override %s",
                          self.__class__.__name__, self.tempModelCheckpointDir)
コード例 #14
0
class RabbitmqVirtualHostPatch(object):
  """ Decorator, class decorator or Context Manager that overrides the default
  RabbitMQ virtual host, both in-proc and in subprocesses.

  On start: creates a temporary virtual host through the RabbitMQ Management
    Plugin and patches the "virtual_host" attribute in the "connection" section
    of grok's rabbitmq.conf configuration file to name it.

  On stop: unpatches the "virtual_host" rabbitmq configuration attribute and
    deletes the temporary virtual host.

  NOTE: the patch assumes that the code under test connects to RabbitMQ using
  the virtual_host from the "virtual_host" attribute in the "connection"
  section of grok's rabbitmq.conf configuration file.

  NOTE: as a class decorator, only methods whose names begin with
  `mock.patch.TEST_PREFIX` (defaults to "test") are decorated. Keep this in
  mind when decorating entire classes.
  """

  _RABBIT_MANAGEMENT_HEADERS = {"content-type": "application/json"}

  def __init__(self, clientLabel, kw=None, logger=logging):
    """
    clientLabel: prefix of the temporary vhost name; keep it short and limited
      to [._a-zA-Z]. A test class name (or similar) makes a reasonable label
      and helps with diagnostics.
    kw: name of keyword argument to add to the decorated function(s). Its value
      will be a reference to this instance of RabbitmqVirtualHostPatch.
      Ignored when this instance is used as context manager. Defaults to kw=None
      to avoid having it added to the keyword args.
    logger: object used for diagnostic messages (the logging module by default)
    """
    # True once applied successfully; False when not applied or after removal
    self.active = False

    self._clientLabel = clientLabel
    self._kw = kw
    self._logger = logger

    # Lazily-generated name of the override RabbitMQ virtual host (see _vhost)
    self._cachedVirtualHost = None

    # True while the temporary vhost exists on the broker and must be deleted
    self._virtualHostCreated = False

    # ConfigAttributePatch for rabbitmq.conf; created by start()
    self._configPatch = None


  @property
  def _vhost(self):
    """ Name of the temporary vhost: "<clientLabel>_<uuid1 hex>", generated on
    first access and then reused for the lifetime of this instance.
    """
    vhostName = self._cachedVirtualHost
    if vhostName is None:
      vhostName = "%s_%s" % (self._clientLabel, uuid.uuid1().hex,)
      self._cachedVirtualHost = vhostName
    return vhostName


  def __enter__(self):
    self.start()
    return self


  def __exit__(self, *args):
    self.stop()
    # Never swallow an exception raised inside the with-block
    return False


  def __call__(self, f):
    """ Implement the function or class decorator """
    if isinstance(f, types.TypeType):
      return self._decorateClass(f)

    @functools.wraps(f)
    def applyVirtualHostPatch(*args, **kwargs):
      self.start()
      try:
        if self._kw is not None:
          kwargs[self._kw] = self
        return f(*args, **kwargs)
      finally:
        self.stop()

    return applyVirtualHostPatch


  def _decorateClass(self, targetClass):
    """ Wrap each callable attribute of targetClass whose name starts with
    `mock.patch.TEST_PREFIX` in its own RabbitmqVirtualHostPatch instance.
    """
    for name in dir(targetClass):
      if not name.startswith(patch.TEST_PREFIX):
        continue

      candidate = getattr(targetClass, name)
      if not callable(candidate):
        continue

      perMethodPatch = RabbitmqVirtualHostPatch(
        self._clientLabel, self._kw, self._logger)
      setattr(targetClass, name, perMethodPatch(candidate))

    return targetClass


  def start(self):
    """ Create the temporary vhost, grant permissions on it, and patch the
    rabbitmq config to use it. Partial work is undone if any step fails.
    """
    assert not self.active

    # Everything is set up through the RabbitMQ Management Plugin HTTP API
    mgmt = amqp.connection.RabbitmqManagementConnectionParams()

    url = "http://%s:%s/api/vhosts/%s" % (mgmt.host, mgmt.port, self._vhost)

    try:
      # Step 1: create the temporary vhost
      try:
        reply = requests.put(
          url,
          headers=self._RABBIT_MANAGEMENT_HEADERS,
          auth=(mgmt.username, mgmt.password))
        reply.raise_for_status()

        self._virtualHostCreated = True
        self._logger.info("%s: created temporary rabbitmq vhost=%s",
                          self.__class__.__name__, self._vhost)
      except Exception:
        self._logger.exception(
          "Attempt to create temporary vhost=%s failed. url=%r",
          self._vhost, url)
        raise

      # Step 2: give the management user configure/write/read permissions
      # matching everything (".*") on the new vhost
      try:
        url = "http://%s:%s/api/permissions/%s/%s" % (
          mgmt.host, mgmt.port, self._vhost, mgmt.username)

        permissionsBody = json.dumps(
          {"configure": ".*", "write": ".*", "read": ".*"})
        reply = requests.put(
          url,
          headers=self._RABBIT_MANAGEMENT_HEADERS,
          data=permissionsBody,
          auth=(mgmt.username, mgmt.password))
        reply.raise_for_status()

        self._logger.info(
          "%s: Configured persmissions on temporary rabbitmq vhost=%s",
          self.__class__.__name__, self._vhost)
      except Exception:
        self._logger.exception(
          "Attempt to configure premissions on vhost=%s failed. url=%r",
          self._vhost, url)
        raise

      # Step 3: patch rabbitmq.conf so message_bus_connector and others
      # connect to the temporary vhost
      rabbitmqConfig = amqp.connection.RabbitmqConfig()
      self._configPatch = ConfigAttributePatch(
        rabbitmqConfig.CONFIG_NAME,
        rabbitmqConfig.baseConfigDir,
        (("connection", "virtual_host", self._vhost),))
      self._configPatch.start()

      self._logger.info("%s: overrode rabbitmq vhost=%s",
                        self.__class__.__name__, self._vhost)

      # Step 4: self-validation -- connection parameters must now resolve to
      # the temporary vhost
      effectiveVhost = amqp.connection.getRabbitmqConnectionParameters().vhost
      assert effectiveVhost == self._vhost, (
        "Expected vhost=%r, but got vhost=%r" % (self._vhost, effectiveVhost))

    except Exception:
      self._logger.exception("patch failed, deleting vhost=%s", self._vhost)
      self._removePatches()
      raise

    self.active = True
    self._logger.info("%s: applied patch", self.__class__.__name__)


  def stop(self):
    """ Undo everything start() did """
    assert self.active

    self._removePatches()
    self.active = False


  def _removePatches(self):
    """ Stop the config patch (if active) and delete the temporary vhost (if
    created); the vhost is deleted even if stopping the config patch raises.

    NOTE: also called internally to clean up a partially-applied patch.
    """
    try:
      configPatch = self._configPatch
      if configPatch is not None and configPatch.active:
        configPatch.stop()
    finally:
      if self._virtualHostCreated:
        self._deleteTemporaryVhost()
        self._virtualHostCreated = False

    self._logger.info("%s: removed patch", self.__class__.__name__)


  def _deleteTemporaryVhost(self):
    """ Delete the temporary RabbitMQ virtual host via the Management Plugin;
    failures are logged and re-raised.
    """
    mgmt = amqp.connection.RabbitmqManagementConnectionParams()

    url = "http://%s:%s/api/vhosts/%s" % (mgmt.host, mgmt.port, self._vhost)

    try:
      reply = requests.delete(
        url,
        headers=self._RABBIT_MANAGEMENT_HEADERS,
        auth=(mgmt.username, mgmt.password))
      reply.raise_for_status()

      self._logger.info("%s: deleted temporary rabbitmq vhost=%s",
                        self.__class__.__name__, self._vhost)
    except Exception:
      self._logger.exception(
        "Attempt to delete temporary vhost=%s failed. url=%r",
        self._vhost, url)
      raise
コード例 #15
0
  def start(self):
    """ Create the temporary vhost, grant permissions on it, and patch the
    rabbitmq config to use it. Partial work is undone via _removePatches() if
    any step fails, and the exception is re-raised.
    """
    assert not self.active

    # Everything is set up through the RabbitMQ Management Plugin HTTP API
    mgmt = amqp.connection.RabbitmqManagementConnectionParams()

    url = "http://%s:%s/api/vhosts/%s" % (mgmt.host, mgmt.port, self._vhost)

    try:
      # Step 1: create the temporary vhost
      try:
        reply = requests.put(
          url,
          headers=self._RABBIT_MANAGEMENT_HEADERS,
          auth=(mgmt.username, mgmt.password))
        reply.raise_for_status()

        self._virtualHostCreated = True
        self._logger.info("%s: created temporary rabbitmq vhost=%s",
                          self.__class__.__name__, self._vhost)
      except Exception:
        self._logger.exception(
          "Attempt to create temporary vhost=%s failed. url=%r",
          self._vhost, url)
        raise

      # Step 2: give the management user configure/write/read permissions
      # matching everything (".*") on the new vhost
      try:
        url = "http://%s:%s/api/permissions/%s/%s" % (
          mgmt.host, mgmt.port, self._vhost, mgmt.username)

        permissionsBody = json.dumps(
          {"configure": ".*", "write": ".*", "read": ".*"})
        reply = requests.put(
          url,
          headers=self._RABBIT_MANAGEMENT_HEADERS,
          data=permissionsBody,
          auth=(mgmt.username, mgmt.password))
        reply.raise_for_status()

        self._logger.info(
          "%s: Configured persmissions on temporary rabbitmq vhost=%s",
          self.__class__.__name__, self._vhost)
      except Exception:
        self._logger.exception(
          "Attempt to configure premissions on vhost=%s failed. url=%r",
          self._vhost, url)
        raise

      # Step 3: patch rabbitmq.conf so message_bus_connector and others
      # connect to the temporary vhost
      rabbitmqConfig = amqp.connection.RabbitmqConfig()
      self._configPatch = ConfigAttributePatch(
        rabbitmqConfig.CONFIG_NAME,
        rabbitmqConfig.baseConfigDir,
        (("connection", "virtual_host", self._vhost),))
      self._configPatch.start()

      self._logger.info("%s: overrode rabbitmq vhost=%s",
                        self.__class__.__name__, self._vhost)

      # Step 4: self-validation -- connection parameters must now resolve to
      # the temporary vhost
      effectiveVhost = amqp.connection.getRabbitmqConnectionParameters().vhost
      assert effectiveVhost == self._vhost, (
        "Expected vhost=%r, but got vhost=%r" % (self._vhost, effectiveVhost))

    except Exception:
      self._logger.exception("patch failed, deleting vhost=%s", self._vhost)
      self._removePatches()
      raise

    self.active = True
    self._logger.info("%s: applied patch", self.__class__.__name__)
コード例 #16
0
class ModelCheckpointStoragePatch(object):
  """ An instance of this class may be used as a decorator, class decorator
  or Context Manager for redirecting ModelCheckpoint storage to a temporary
  directory in-proc and in child processes.

  start() creates a fresh temporary directory and patches the "root"
  attribute of the "storage" section of model-checkpoint.conf to point at it;
  stop() removes that patch and deletes the temporary directory.

  NOTE: when decorating a class, only methods whose names begin with
  `mock.patch.TEST_PREFIX` are wrapped.
  """

  def __init__(self, kw=None, logger=logging):
    """
    kw: name of keyword argument to add to the decorated function(s). Its value
      will be a reference to this instance of ModelCheckpointStoragePatch.
      Ignored when this instance is used as context manager. Defaults to kw=None
      to avoid having it added to the keyword args.
    logger: object used for diagnostic messages (the logging module by default)
    """
    # True when applied successfully; False after successfully removed or not
    # applied
    self.active = False

    self._kw = kw
    self._logger = logger
    # Temporary directory created by start(); contains tempModelCheckpointDir
    self._tempParentDir = None
    # Directory that model checkpoint storage is redirected to
    self.tempModelCheckpointDir = None
    self._configPatch = None


  def __enter__(self):
    self.start()
    return self


  def __exit__(self, *args):
    self.stop()
    # Don't suppress exceptions raised inside the with-block
    return False


  def __call__(self, f):
    """ Implement the function or class decorator """
    if isinstance(f, types.TypeType):
      return self._decorateClass(f)

    @functools.wraps(f)
    def applyModelCheckpointPatch(*args, **kwargs):
      self.start()
      try:
        if self._kw is not None:
          kwargs[self._kw] = self
        return f(*args, **kwargs)
      finally:
        self.stop()

    return applyModelCheckpointPatch


  def _decorateClass(self, targetClass):
    """ Decorate the test methods in the given class. Honors
    `mock.patch.TEST_PREFIX` for choosing which methods to wrap. Each method
    gets its own ModelCheckpointStoragePatch instance.
    """
    for attrName in dir(targetClass):
      if attrName.startswith(patch.TEST_PREFIX):
        f = getattr(targetClass, attrName)
        if callable(f):
          decoratedFunc = ModelCheckpointStoragePatch(
            self._kw, self._logger)(f)
          setattr(targetClass, attrName, decoratedFunc)
    return targetClass


  def start(self):
    """ Create the temporary storage directory and apply the config patch.

    If anything fails after the temporary directory was created, the
    directory is removed before the exception propagates, so a failed start()
    doesn't leak it.
    """
    assert not self.active

    self._tempParentDir = tempfile.mkdtemp(
      prefix=self.__class__.__name__)

    try:
      self.tempModelCheckpointDir = os.path.join(self._tempParentDir,
                                                 "tempStorageRoot")
      os.mkdir(self.tempModelCheckpointDir)

      # Point storage/root of model-checkpoint.conf (in the config dir named
      # by $APPLICATION_CONFIG_PATH) at the temporary directory
      self._configPatch = ConfigAttributePatch(
        "model-checkpoint.conf",
        os.environ.get("APPLICATION_CONFIG_PATH"),
        (("storage", "root", self.tempModelCheckpointDir),))

      self._configPatch.start()
    except Exception:
      shutil.rmtree(self._tempParentDir, ignore_errors=True)
      raise

    self.active = True
    self._logger.info("%s: redirected model checkpoint storage to %s",
                      self.__class__.__name__, self.tempModelCheckpointDir)


  def stop(self):
    """ Remove the config patch and delete the temporary storage directory.

    The directory is deleted even if removing the config patch raises.
    """
    assert self.active

    try:
      self._configPatch.stop()
    finally:
      shutil.rmtree(self._tempParentDir)

    self.active = False
    self._logger.info("%s: removed model checkpoint storage override %s",
                      self.__class__.__name__, self.tempModelCheckpointDir)