Exemple #1
0
    def test_initial_interval(self, random, mean):
        """
        Without a stored ``last_run``, ``lease_maintenance_service`` schedules
        its first run after a delay that lies in the window of width
        ``range`` centred on ``mean``.
        """
        clock = Clock()
        # A range no wider than the mean keeps ``mean - range / 2`` positive.
        range_ = timedelta(seconds=random.uniform(0, mean.total_seconds()))
        last_run_path = FilePath(self.useFixture(TempDir()).join(u"last-run"))

        service = lease_maintenance_service(
            dummy_maintain_leases,
            clock,
            last_run_path,
            random,
            mean,
            range_,
        )
        service.startService()
        # Exactly one delayed call is expected: the first maintenance run.
        [maintenance_call] = clock.getDelayedCalls()

        now = datetime.utcfromtimestamp(clock.seconds())
        half_range = range_ / 2
        expected_centre = now + mean
        scheduled_at = datetime.utcfromtimestamp(maintenance_call.getTime())
        self.assertThat(
            scheduled_at,
            between(expected_centre - half_range, expected_centre + half_range),
        )
Exemple #2
0
    def test_download_is_initiated_in_new_thread(self):
        """
        When ``maas_meta_last_modified`` reports a time 15 minutes in the
        past, starting ``ImageDownloadService`` hands ``_run_import`` to
        ``deferToThread`` with the sources and proxies fetched over RPC.
        """
        clock = Clock()
        maas_meta_last_modified = self.patch(
            tftppath, 'maas_meta_last_modified')
        # The previous local name ("one_week") did not match its value:
        # this is a 15-minute interval.
        fifteen_minutes = timedelta(minutes=15).total_seconds()
        maas_meta_last_modified.return_value = (
            clock.seconds() - fifteen_minutes)
        http_proxy = factory.make_simple_http_url()
        https_proxy = factory.make_simple_http_url()
        rpc_client = Mock()
        client_call = Mock()
        # Successive RPC calls answer first with the boot sources, then with
        # the proxy configuration.
        client_call.side_effect = [
            defer.succeed(dict(sources=sentinel.sources)),
            defer.succeed(dict(
                http=urlparse(http_proxy),
                https=urlparse(https_proxy))),
            ]
        rpc_client.getClientNow.return_value = defer.succeed(client_call)
        rpc_client.maas_url = factory.make_simple_http_url()

        # We could patch out 'import_boot_images' instead here but I
        # don't do that for 2 reasons:
        # 1. It requires spinning the reactor again before being able to
        # test the result.
        # 2. It means there's no thread to clean up after the test.
        deferToThread = self.patch(boot_images, 'deferToThread')
        deferToThread.return_value = defer.succeed(None)
        service = ImageDownloadService(
            rpc_client, sentinel.tftp_root, clock)
        service.startService()
        self.assertThat(
            deferToThread, MockCalledOnceWith(
                _run_import, sentinel.sources, rpc_client.maas_url,
                http_proxy=http_proxy, https_proxy=https_proxy))
Exemple #3
0
	def test_cssRewriterFixesUrls(self):
		"""
		The CSS rewriter appends ?cachebreakers to the url(...)s inside
		the .css file.  If a file mentioned by a url(...) is modified, the
		processed .css is updated.
		"""
		clock = Clock()
		fc = FileCache(lambda: clock.seconds(), 1)
		parent, t = self._makeTree()
		root = BetterFile(parent.path, fileCache=fc, rewriteCss=True)
		site = server.Site(root)

		def requestStyleCss():
			return self._requestPostpathAndRender(
				root, ['sub', 'style.css'], path='/sub/style.css', site=site)

		d = requestStyleCss()

		expect = """\
/* CSSResource processed %(md5original)s */
div { background-image: url(http://127.0.0.1/not-modified.png); }
td { background-image: url(https://127.0.0.1/not-modified.png); }
p { background-image: url(../one.png?cb=%(md5one)s); }
q { background-image: url(two.png?cb=%(md5two)s); }
b { background-image: url(sub%%20sub/three.png?cb=%(md5three)s); }
i { background-image: url(/sub/sub%%20sub/three.png?cb=%(md5three)s); }
"""

		def assertCacheBrokenLinks((request, child)):
			out = "".join(request.written)
			self.assertEqual(expect % t, out,
				"\nExpected:\n\n%s\n\nGot:\n\n%s" % (expect % t, out))
			expectedBreaker = hashlib.md5(expect % t).hexdigest()
			self.assertEqual(expectedBreaker, child.getCacheBreaker())
		d.addCallback(assertCacheBrokenLinks)

		def modifyThreePngAndMakeRequest(_):
			parent.child('sub').child('sub sub').child('three.png').setContent("replacement")
			return requestStyleCss()
		d.addCallback(modifyThreePngAndMakeRequest)

		def assertNotUpdatedLinks((request, child)):
			out = "".join(request.written)
			# Still the same links, because we didn't advance the clock.
			self.assertEqual(expect % t, out)
		d.addCallback(assertNotUpdatedLinks)

		def advanceClockAndMakeRequest(_):
			clock.advance(1)
			return requestStyleCss()
		d.addCallback(advanceClockAndMakeRequest)

		def assertUpdatedLinks((request, child)):
			out = "".join(request.written)
			t2 = t.copy()
			t2['md5three'] = t['md5replacement']
			self.assertEqual(expect % t2, out)
		d.addCallback(assertUpdatedLinks)

		return d
Exemple #4
0
	def test_rewriteCss(self):
		"""
		Test that CSS processing works, and verify the header: the output is
		a "CSSResource processed <32 hex digits>" comment line followed by
		the original CSS unchanged.
		"""
		clock = Clock()
		fc = FileCache(lambda: clock.seconds(), 1)
		temp = FilePath(self.mktemp() + '.css')
		with temp.open('wb') as f:
			f.write("p { color: red; }\n")

		# BetterFile(temp.path) would not work because the processing happens
		# in getChild.  So, create a BetterFile for the .css file's parent dir.
		bf = BetterFile(temp.parent().path, fileCache=fc, rewriteCss=True)
		d = self._requestPostpathAndRender(bf, [temp.basename()])

		headerRe = re.compile(r"/\* CSSResource processed ([0-9a-f]{32}?) \*/")

		def assertProcessedContent(result):
			# Unpack in the body: tuple parameters are Python-2-only (PEP 3113).
			request, _ = result
			out = "".join(request.written)
			lines = out.split("\n")
			# Check the line count first so that a short body fails with a
			# readable message instead of an IndexError below.
			self.assertEqual(3, len(lines), repr(out))
			self.assertTrue(headerRe.match(lines[0]), lines[0])
			self.assertEqual("p { color: red; }", lines[1])
			self.assertEqual("", lines[2])
		d.addCallback(assertProcessedContent)
		return d
Exemple #5
0
 def test_initiates_download_if_15_minutes_has_passed(self):
     """
     Starting the service triggers a download when the metadata was last
     modified 15 minutes before the clock's current time.
     """
     clock = Clock()
     service = ImageDownloadService(sentinel.service, sentinel.tftp_root,
                                    clock)
     _start_download = self.patch_download(service, None)
     # Previously misnamed "one_week_ago": the offset is 15 minutes.
     fifteen_minutes_ago = (
         clock.seconds() - timedelta(minutes=15).total_seconds())
     self.patch(tftppath,
                "maas_meta_last_modified").return_value = fifteen_minutes_ago
     service.startService()
     self.assertThat(_start_download, MockCalledOnceWith())
Exemple #6
0
 def test_no_download_if_15_minutes_has_not_passed(self):
     """
     Starting the service does not trigger a download when the metadata
     was modified less than 15 minutes (here: 15 minutes minus one second)
     before the clock's current time.
     """
     clock = Clock()
     service = ImageDownloadService(sentinel.service, sentinel.tftp_root,
                                    clock)
     _start_download = self.patch_download(service, None)
     # Previously misnamed "one_week": the interval is 15 minutes.
     fifteen_minutes = timedelta(minutes=15).total_seconds()
     self.patch(tftppath,
                "maas_meta_last_modified").return_value = clock.seconds()
     clock.advance(fifteen_minutes - 1)
     service.startService()
     self.assertThat(_start_download, MockNotCalled())
Exemple #7
0
    def test_periodic_client_enquire_link(self):
        # Script the fake server's reply to each ENQUIRELINK request the
        # client is expected to send (request ids "0" and "1").
        scripted_exchanges = [
            ('0',
             ("<ENQRequest>"
              "<requestId>0</requestId>"
              "<enqCmd>ENQUIRELINK</enqCmd>"
              "</ENQRequest>"),
             ("<ENQResponse>"
              "<requestId>0</requestId>"
              "<enqCmd>ENQUIRELINKRSP</enqCmd>"
              "</ENQResponse>")),
            ('1',
             ("<ENQRequest>"
              "<requestId>1</requestId>"
              "<enqCmd>ENQUIRELINK</enqCmd>"
              "</ENQRequest>"),
             ("<ENQResponse>"
              "<requestId>1</requestId>"
              "<enqCmd>ENQUIRELINKRSP</enqCmd>"
              "</ENQResponse>")),
        ]
        for request_id, request_body, response_body in scripted_exchanges:
            expected_request = utils.mk_packet(request_id, request_body)
            canned_response = utils.mk_packet(request_id, response_body)
            self.server.responses[expected_request] = canned_response

        fake_clock = Clock()
        start = fake_clock.seconds()
        self.client.clock = fake_clock
        self.client.enquire_link_interval = 120
        self.client.timeout_period = 20
        self.client.authenticated = True
        self.client.start_periodic_enquire_link()

        # Just after the first request: next timeout is timeout_period away.
        fake_clock.advance(0.01)
        self.assert_next_timeout(start + 20)

        # The first response arrives and cancels that timeout.
        yield self.client.wait_for_data()
        self.assert_timeout_cancelled()

        # Just after the second request (one interval later).
        fake_clock.advance(120.01)
        self.assert_next_timeout(start + 140)

        # The second response arrives and cancels its timeout too.
        yield self.client.wait_for_data()
        self.assert_timeout_cancelled()
Exemple #8
0
	def test_httpRequest(self):
		"""
		With C{httpCachePublic=False}, the caching headers set on this
		request are private and expire C{cacheTime} seconds after the
		clock's current time (the epoch, for a fresh L{Clock}).
		"""
		clock = Clock()
		options = ResponseCacheOptions(
			cacheTime=3600, httpCachePublic=False, httpsCachePublic=True)
		request = DummyRequest([])
		setCachingHeadersOnRequest(request, options, getTime=clock.seconds)

		expected = {
			'Cache-Control': ['max-age=3600, private'],
			'Date': ['Thu, 01 Jan 1970 00:00:00 GMT'],
			'Expires': ['Thu, 01 Jan 1970 01:00:00 GMT'],
		}
		actual = dict(request.responseHeaders.getAllRawHeaders())
		self.assertEqual(expected, actual)
Exemple #9
0
    def is_locked(self, reactor: Clock) -> bool:
        """ Tell whether this unspent tx is still time-locked.

            It is locked when ``self.timelock`` is set and is not earlier than
            the reactor's current time truncated to whole seconds.

            :param reactor: reactor to get the current time
            :type reactor: :py:class:`twisted.internet.Reactor`

            :return: True if the unspent tx is locked, False if it can be spent
            :rtype: bool
        """
        if self.timelock is None:
            return False
        now = int(reactor.seconds())
        return not (self.timelock < now)
Exemple #10
0
	def test_noCache(self):
		"""
		If C{cacheTime} is 0, the response gets C{max-age=0, private} and an
		C{Expires: -1} header, and a Date header is still set.
		"""
		clock = Clock()
		# Both flags say public=True, yet ", private" is still expected.
		options = ResponseCacheOptions(
			cacheTime=0, httpCachePublic=True, httpsCachePublic=True)
		request = DummyRequest([])
		setCachingHeadersOnRequest(request, options, getTime=clock.seconds)

		expected = {
			'Cache-Control': ['max-age=0, private'],
			# A Date header is set even in this case.
			'Date': ['Thu, 01 Jan 1970 00:00:00 GMT'],
			'Expires': ['-1'],
		}
		actual = dict(request.responseHeaders.getAllRawHeaders())
		self.assertEqual(expected, actual)
Exemple #11
0
	def test_cssCached(self):
		"""
		The processed CSS file is cached, and updated when the underlying
		file changes once the fake clock has been advanced (the FileCache is
		built with a lifetime argument of 1).
		"""
		clock = Clock()
		fc = FileCache(lambda: clock.seconds(), 1)
		temp = FilePath(self.mktemp() + '.css')
		temp.setContent("p { color: red; }\n")

		bf = BetterFile(temp.parent().path, fileCache=fc, rewriteCss=True)

		def requestCss():
			return self._requestPostpathAndRender(bf, [temp.basename()])

		def bodyAfterHeader(result):
			# result is a (request, child) pair; it is unpacked here because
			# tuple parameters are Python-2-only (PEP 3113).  The first
			# output line (the processing header) is skipped.
			request, _ = result
			return "".join(request.written).split("\n")[1:]

		d = requestCss()

		def assertColorRed(result):
			self.assertEqual(["p { color: red; }", ""], bodyAfterHeader(result))
		d.addCallback(assertColorRed)

		def modifyUnderlyingAndMakeRequest(_):
			with temp.open('wb') as f:
				f.write("p { color: green; }\n")
			return requestCss()
		d.addCallback(modifyUnderlyingAndMakeRequest)

		def assertStillColorRed(result):
			# The clock has not moved, so the cached (red) version is served.
			self.assertEqual(["p { color: red; }", ""], bodyAfterHeader(result))
		d.addCallback(assertStillColorRed)

		def advanceClockAndMakeRequest(_):
			clock.advance(1)
			return requestCss()
		d.addCallback(advanceClockAndMakeRequest)

		def assertColorGreen(result):
			self.assertEqual(["p { color: green; }", ""], bodyAfterHeader(result))
		d.addCallback(assertColorGreen)

		return d
class _SharedMixin(SystemTestMixin):
    """
    Common base for the Foolscap and HTTP ``IStorageServer`` test mixins.

    ``setUp`` starts one node, locates its ``StorageServer``, swaps that
    server's clock for a test-controlled ``Clock`` and asks the subclass for
    the storage client under test via ``_get_istorage_server``.
    """

    # Test method names this transport cannot run yet; setUp skips them.
    SKIP_TESTS = set()  # type: Set[str]

    def _get_istorage_server(self):
        # Subclass hook; its result is yielded in setUp, so a Deferred is fine.
        raise NotImplementedError("implement in subclass")

    @inlineCallbacks
    def setUp(self):
        test_name = self._testMethodName
        if test_name in self.SKIP_TESTS:
            raise SkipTest("Test {} is still not supported".format(test_name))

        AsyncTestCase.setUp(self)

        self.basedir = "test_istorageserver/" + self.id()
        yield SystemTestMixin.setUp(self)
        yield self.set_up_nodes(1)

        # The first StorageServer among the single client's services.
        self.server = next(
            (service for service in self.clients[0].services
             if isinstance(service, StorageServer)),
            None,
        )
        assert self.server is not None, "Couldn't find StorageServer"

        fake_clock = Clock()
        fake_clock.advance(123456)
        self._clock = fake_clock
        self.server._clock = fake_clock
        self.storage_client = yield self._get_istorage_server()

    def fake_time(self):
        """Report the test-controlled clock's current time."""
        return self._clock.seconds()

    def fake_sleep(self, seconds):
        """Move the test-controlled clock forward by ``seconds``."""
        self._clock.advance(seconds)

    @inlineCallbacks
    def tearDown(self):
        AsyncTestCase.tearDown(self)
        yield SystemTestMixin.tearDown(self)
Exemple #13
0
    def ret():
        # Run the wrapped ``f`` against a fresh fake ``Clock``.  Keep advancing
        # the clock until no delayed calls remain, then re-raise any failure
        # of the Deferred ``f`` returned.
        error = [None]

        clock = Clock()

        d = f(clock)

        @d.addErrback
        def on_error(failure):
            # Stash (and thereby consume) the failure so it can be re-raised
            # synchronously once the clock has been drained.
            error[0] = failure

        while clock.getDelayedCalls():
            # Jump to the latest pending call's time.  The wait is clamped at
            # zero but still applied: ``advance(0)`` fires calls due *now*
            # (e.g. ``callLater(0, ...)``).  The former ``== 0`` early exit
            # abandoned those calls without running them.
            now = clock.seconds()
            latest = max(call.getTime() for call in clock.getDelayedCalls())
            clock.advance(max(0, latest - now))

        if error[0]:
            error[0].raiseException()
Exemple #14
0
	def test_requestAlreadyHasHeaders(self):
		"""
		When the request given to L{setCachingHeadersOnRequest} already has
		headers, its Date/Expires/Cache-Control headers are overwritten and
		unrelated headers are left alone.
		"""
		clock = Clock()
		options = ResponseCacheOptions(
			cacheTime=3600, httpCachePublic=False, httpsCachePublic=True)
		request = DummyRequest([])
		preexisting = [
			('cache-control', ['X', 'Y']),
			('date', ['whenever']),
			('expires', ['sometime']),
			('extra', ['one', 'two']),
		]
		for headerName, headerValues in preexisting:
			request.responseHeaders.setRawHeaders(headerName, headerValues)

		setCachingHeadersOnRequest(request, options, getTime=clock.seconds)

		expected = {
			'Cache-Control': ['max-age=3600, private'],
			'Date': ['Thu, 01 Jan 1970 00:00:00 GMT'],
			'Expires': ['Thu, 01 Jan 1970 01:00:00 GMT'],
			'Extra': ['one', 'two'],
		}
		actual = dict(request.responseHeaders.getAllRawHeaders())
		self.assertEqual(expected, actual)
Exemple #15
0
class APITestsMixin(APIAssertionsMixin):
    """
    Helpers for writing tests for the Docker Volume Plugin API.
    """
    NODE_A = uuid4()
    NODE_B = uuid4()

    def initialize(self):
        """
        Create initial objects for the ``VolumePlugin``.
        """
        self.volume_plugin_reactor = Clock()
        self.flocker_client = SimpleCountingProxy(FakeFlockerClient())

    def test_pluginactivate(self):
        """
        ``/Plugin.Activate`` indicates the plugin is a volume driver.
        """
        # Docker 1.8, at least, sends "null" as the body. Our test
        # infrastructure has the opposite bug so just going to send some
        # other garbage as the body (12345) to demonstrate that it's
        # ignored as per the spec which declares no body.
        return self.assertResult(b"POST", b"/Plugin.Activate", 12345, OK,
                                 {u"Implements": [u"VolumeDriver"]})

    def test_remove(self):
        """
        ``/VolumeDriver.Remove`` returns a successful result.
        """
        return self.assertResult(b"POST", b"/VolumeDriver.Remove",
                                 {u"Name": u"vol"}, OK, {u"Err": None})

    def test_unmount(self):
        """
        ``/VolumeDriver.Unmount`` returns a successful result.
        """
        return self.assertResult(b"POST", b"/VolumeDriver.Unmount",
                                 {u"Name": u"vol"}, OK, {u"Err": None})

    def test_create_with_opts(self):
        """
        Calling the ``/VolumeDriver.Create`` API with an ``Opts`` value
        in the request body JSON ignores this parameter and creates
        a volume with the given name.
        """
        name = u"testvolume"
        d = self.assertResult(b"POST", b"/VolumeDriver.Create",
                              {u"Name": name, 'Opts': {'ignored': 'ignored'}},
                              OK, {u"Err": None})
        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(self.assertItemsEqual, [
            Dataset(dataset_id=UUID(dataset_id_from_name(name)),
                    primary=self.NODE_A,
                    maximum_size=DEFAULT_SIZE,
                    metadata={u"name": name})])
        return d

    def create(self, name):
        """
        Call the ``/VolumeDriver.Create`` API to create a volume with the
        given name.

        :param unicode name: The name of the volume to create.

        :return: ``Deferred`` that fires once the volume has been created
            and the API response has been checked.
        """
        return self.assertResult(b"POST", b"/VolumeDriver.Create",
                                 {u"Name": name}, OK, {u"Err": None})

    def test_create_creates(self):
        """
        ``/VolumeDriver.Create`` creates a new dataset in the configuration.
        """
        name = u"myvol"
        d = self.create(name)
        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(self.assertItemsEqual, [
            Dataset(dataset_id=UUID(dataset_id_from_name(name)),
                    primary=self.NODE_A,
                    maximum_size=DEFAULT_SIZE,
                    metadata={u"name": name})])
        return d

    def test_create_duplicate_name(self):
        """
        If a dataset with the given name already exists,
        ``/VolumeDriver.Create`` succeeds without creating a new volume.
        """
        name = u"thename"
        # Create a dataset out-of-band with matching name but non-matching
        # dataset ID:
        d = self.flocker_client.create_dataset(
            self.NODE_A, DEFAULT_SIZE, metadata={u"name": name})
        d.addCallback(lambda _: self.create(name))
        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(lambda results: self.assertEqual(len(results), 1))
        return d

    def test_create_duplicate_name_race_condition(self):
        """
        If a dataset with the given name is created while the
        ``/VolumeDriver.Create`` call is in flight, the call does not
        result in an error.
        """
        name = u"thename"

        # Create a dataset out-of-band with matching dataset ID and name
        # which the docker plugin won't be able to see.
        def create_after_list():
            # Clean up the patched version:
            del self.flocker_client.list_datasets_configuration
            # But first time we're called, we create dataset and lie about
            # its existence:
            d = self.flocker_client.create_dataset(
                self.NODE_A, DEFAULT_SIZE,
                metadata={u"name": name},
                dataset_id=UUID(dataset_id_from_name(name)))
            d.addCallback(lambda _: [])
            return d
        self.flocker_client.list_datasets_configuration = create_after_list

        return self.create(name)

    def _flush_volume_plugin_reactor_on_endpoint_render(self):
        """
        This method patches ``self.app`` so that after any endpoint is
        rendered, the reactor used by the volume plugin is advanced repeatedly
        until there are no more ``delayedCalls`` pending on the reactor.
        """
        real_execute_endpoint = self.app.execute_endpoint

        def patched_execute_endpoint(*args, **kwargs):
            val = real_execute_endpoint(*args, **kwargs)
            while self.volume_plugin_reactor.getDelayedCalls():
                pending_calls = self.volume_plugin_reactor.getDelayedCalls()
                # Step to the earliest pending call so calls scheduled by
                # callbacks are run at their intended times.
                next_expiration = min(t.getTime() for t in pending_calls)
                now = self.volume_plugin_reactor.seconds()
                self.volume_plugin_reactor.advance(
                    max(0.0, next_expiration - now))
            return val
        self.patch(self.app, 'execute_endpoint', patched_execute_endpoint)

    def test_mount(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with matching
        name to the current node and then waits for the dataset to
        actually arrive.
        """
        name = u"myvol"
        dataset_id = UUID(dataset_id_from_name(name))
        # Create dataset on a different node:
        d = self.flocker_client.create_dataset(
            self.NODE_B, DEFAULT_SIZE, metadata={u"name": name},
            dataset_id=dataset_id)

        self._flush_volume_plugin_reactor_on_endpoint_render()

        # Pretend that it takes 5 seconds for the dataset to get established on
        # Node A.
        self.volume_plugin_reactor.callLater(
            5.0, self.flocker_client.synchronize_state)

        d.addCallback(lambda _:
                      self.assertResult(
                          b"POST", b"/VolumeDriver.Mount",
                          {u"Name": name}, OK,
                          {u"Err": None,
                           u"Mountpoint": u"/flocker/{}".format(dataset_id)}))
        d.addCallback(lambda _: self.flocker_client.list_datasets_state())

        def final_assertions(datasets):
            # "state" rather than "d" so the Deferred's name is not reused.
            self.assertEqual([self.NODE_A],
                             [state.primary for state in datasets
                              if state.dataset_id == dataset_id])
            # There should be fewer than 20 calls to list_datasets_state over
            # the course of 5 seconds.
            self.assertLess(
                self.flocker_client.num_calls('list_datasets_state'), 20)
        d.addCallback(final_assertions)

        return d

    def test_mount_timeout(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with matching
        name to the current node and then waits for the dataset to
        actually arrive. If it does not arrive within 120 seconds, then it
        returns an error up to docker.
        """
        name = u"myvol"
        dataset_id = UUID(dataset_id_from_name(name))
        # Create dataset on a different node:
        d = self.flocker_client.create_dataset(
            self.NODE_B, DEFAULT_SIZE, metadata={u"name": name},
            dataset_id=dataset_id)

        self._flush_volume_plugin_reactor_on_endpoint_render()

        # Pretend that it takes 500 seconds for the dataset to get established
        # on Node A. This should be longer than the timeout.
        self.volume_plugin_reactor.callLater(
            500.0, self.flocker_client.synchronize_state)

        d.addCallback(lambda _:
                      self.assertResult(
                          b"POST", b"/VolumeDriver.Mount",
                          {u"Name": name}, OK,
                          {u"Err": u"Timed out waiting for dataset to mount.",
                           u"Mountpoint": u""}))
        return d

    def test_mount_already_exists(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with matching
        name to the current node and then waits for the dataset to
        actually arrive when used by the volumes that already exist and
        don't have a special dataset ID.
        """
        name = u"myvol"

        d = self.flocker_client.create_dataset(
            self.NODE_A, DEFAULT_SIZE, metadata={u"name": name})

        def created(dataset):
            self.flocker_client.synchronize_state()
            result = self.assertResult(
                b"POST", b"/VolumeDriver.Mount",
                {u"Name": name}, OK,
                {u"Err": None,
                 u"Mountpoint": u"/flocker/{}".format(
                     dataset.dataset_id)})
            result.addCallback(lambda _:
                               self.flocker_client.list_datasets_state())
            result.addCallback(lambda ds: self.assertEqual(
                [self.NODE_A], [state.primary for state in ds
                                if state.dataset_id == dataset.dataset_id]))
            return result
        d.addCallback(created)
        return d

    def test_unknown_mount(self):
        """
        ``/VolumeDriver.Mount`` returns an error when asked to mount a
        non-existent volume.
        """
        name = u"myvol"
        return self.assertResult(
            b"POST", b"/VolumeDriver.Mount",
            {u"Name": name}, OK,
            {u"Err": u"Could not find volume with given name."})

    def test_path(self):
        """
        ``/VolumeDriver.Path`` returns the mount path of the given volume if
        it is currently known.
        """
        name = u"myvol"
        dataset_id = UUID(dataset_id_from_name(name))

        d = self.create(name)
        # The dataset arrives as state:
        d.addCallback(lambda _: self.flocker_client.synchronize_state())

        d.addCallback(lambda _: self.assertResponseCode(
            b"POST", b"/VolumeDriver.Mount", {u"Name": name}, OK))

        d.addCallback(lambda _:
                      self.assertResult(
                          b"POST", b"/VolumeDriver.Path",
                          {u"Name": name}, OK,
                          {u"Err": None,
                           u"Mountpoint": u"/flocker/{}".format(dataset_id)}))
        return d

    def test_path_existing(self):
        """
        ``/VolumeDriver.Path`` returns the mount path of the given volume if
        it is currently known, including for a dataset that was created
        not by the plugin.
        """
        name = u"myvol"

        d = self.flocker_client.create_dataset(
            self.NODE_A, DEFAULT_SIZE, metadata={u"name": name})

        def created(dataset):
            self.flocker_client.synchronize_state()
            return self.assertResult(
                b"POST", b"/VolumeDriver.Path",
                {u"Name": name}, OK,
                {u"Err": None,
                 u"Mountpoint": u"/flocker/{}".format(dataset.dataset_id)})
        d.addCallback(created)
        return d

    def test_unknown_path(self):
        """
        ``/VolumeDriver.Path`` returns an error when asked for the mount path
        of a non-existent volume.
        """
        name = u"myvol"
        return self.assertResult(
            b"POST", b"/VolumeDriver.Path",
            {u"Name": name}, OK,
            {u"Err": u"Could not find volume with given name."})

    def test_non_local_path(self):
        """
        ``/VolumeDriver.Path`` returns an error when asked for the mount path
        of a volume that is not mounted locally.

        This can happen as a result of ``docker inspect`` on a container
        that has been created but is still waiting for its volume to
        arrive from another node. It seems like Docker may also call this
        after ``/VolumeDriver.Create``, so again while waiting for a
        volume to arrive.
        """
        name = u"myvol"
        dataset_id = UUID(dataset_id_from_name(name))

        # Create dataset on node B:
        d = self.flocker_client.create_dataset(
            self.NODE_B, DEFAULT_SIZE, metadata={u"name": name},
            dataset_id=dataset_id)
        d.addCallback(lambda _: self.flocker_client.synchronize_state())

        # Ask for path on node A:
        d.addCallback(lambda _:
                      self.assertResult(
                          b"POST", b"/VolumeDriver.Path",
                          {u"Name": name}, OK,
                          {u"Err": "Volume not available.",
                           u"Mountpoint": u""}))
        return d

    @capture_logging(lambda self, logger:
                     self.assertEqual(
                         len(logger.flushTracebacks(CustomException)), 1))
    def test_unexpected_error_reporting(self, logger):
        """
        If an unexpected error occurs Docker gets back a useful error message.
        """
        def error():
            raise CustomException("I've made a terrible mistake")
        self.patch(self.flocker_client, "list_datasets_configuration",
                   error)
        return self.assertResult(
            b"POST", b"/VolumeDriver.Path",
            {u"Name": u"whatever"}, OK,
            {u"Err": "CustomException: I've made a terrible mistake"})

    @capture_logging(None)
    def test_bad_request(self, logger):
        """
        If a ``BadRequest`` exception is raised it is converted to appropriate
        JSON.
        """
        def error():
            raise make_bad_request(code=423, Err=u"no good")
        self.patch(self.flocker_client, "list_datasets_configuration",
                   error)
        return self.assertResult(
            b"POST", b"/VolumeDriver.Path",
            {u"Name": u"whatever"}, 423,
            {u"Err": "no good"})
class MemCacheTestCase(CommandMixin, TestCase):
    """
    Test client protocol class L{MemCacheProtocol}.
    """

    def setUp(self):
        """
        Create a memcache client, connect it to a string protocol, and make it
        use a deterministic clock.
        """
        self.proto = MemCacheProtocol()
        self.clock = Clock()
        # Route the protocol's timeout scheduling through the fake clock so
        # tests can drive time explicitly with self.clock.advance().
        self.proto.callLater = self.clock.callLater
        self.transport = StringTransportWithDisconnection()
        self.transport.protocol = self.proto
        self.proto.makeConnection(self.transport)


    def _test(self, d, send, recv, result):
        """
        Implementation of C{_test} which checks that the command sends C{send}
        data, and that upon reception of C{recv} the result is C{result}.

        @param d: the resulting deferred from the memcache command.
        @type d: C{Deferred}

        @param send: the expected data to be sent.
        @type send: C{str}

        @param recv: the data to simulate as reception.
        @type recv: C{str}

        @param result: the expected result.
        @type result: C{any}
        """
        def cb(res):
            self.assertEqual(res, result)
        self.assertEqual(self.transport.value(), send)
        d.addCallback(cb)
        self.proto.dataReceived(recv)
        return d


    def test_invalidGetResponse(self):
        """
        If the value returned doesn't match the expected key of the current
        C{get} command, an error is raised in L{MemCacheProtocol.dataReceived}.
        """
        self.proto.get("foo")
        s = "spamegg"
        self.assertRaises(RuntimeError,
            self.proto.dataReceived,
            "VALUE bar 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))


    def test_invalidMultipleGetResponse(self):
        """
        If the value returned doesn't match one of the expected keys of the
        current multiple C{get} command, an error is raised in
        L{MemCacheProtocol.dataReceived}.
        """
        self.proto.getMultiple(["foo", "bar"])
        s = "spamegg"
        self.assertRaises(RuntimeError,
            self.proto.dataReceived,
            "VALUE egg 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))


    def test_timeOut(self):
        """
        Test the timeout on outgoing requests: when timeout is detected, all
        current commands fail with a L{TimeoutError}, and the connection is
        closed.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        # Firing d3 from connectionLost lets the test wait for the
        # disconnection as well as for the failed commands.
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)
        def checkMessage(error):
            self.assertEqual(str(error), "Connection timeout")
        d1.addCallback(checkMessage)
        return gatherResults([d1, d2, d3])


    def test_timeoutRemoved(self):
        """
        When a request gets a response, no pending timeout call remains around.
        """
        d = self.proto.get("foo")

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, "bar"))
            self.assertEqual(len(self.clock.calls), 0)
        d.addCallback(check)
        return d


    def test_timeOutRaw(self):
        """
        Test the timeout when raw mode was started: the timeout is not reset
        until all the data has been received, so we can have a L{TimeoutError}
        when waiting for raw data.
        """
        d1 = self.proto.get("foo")
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        # Only 5 of the announced 10 bytes arrive before the timeout.
        self.proto.dataReceived("VALUE foo 0 10\r\n12345")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        return gatherResults([d1, d2])


    def test_timeOutStat(self):
        """
        Test the timeout when stat command has started: the timeout is not
        reset until the final B{END} is received.
        """
        d1 = self.proto.stats()
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived("STAT foo bar\r\n")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        return gatherResults([d1, d2])


    def test_timeoutPipelining(self):
        """
        When two requests are sent, a timeout call remains around for the
        second request, and its timeout time is correct.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, "bar"))
            self.assertEqual(len(self.clock.calls), 1)
            # Advance one second at a time so the timeout fires exactly when
            # it is due rather than as part of one large jump.
            for _ in range(self.proto.persistentTimeOut):
                self.clock.advance(1)
            return self.assertFailure(d2, TimeoutError).addCallback(checkTime)
        def checkTime(ignored):
            # Check that the timeout happened C{self.proto.persistentTimeOut}
            # after the last response
            self.assertEqual(
                self.clock.seconds(), 2 * self.proto.persistentTimeOut - 1)
        d1.addCallback(check)
        return d1


    def test_timeoutNotReset(self):
        """
        Check that the timeout is not reset for every command, but keeps the
        timeout from the first command without a response.
        """
        d1 = self.proto.get("foo")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        d2 = self.proto.get("bar")
        self.clock.advance(1)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)
        return gatherResults([d1, d2, d3])


    def test_timeoutCleanDeferreds(self):
        """
        C{timeoutConnection} cleans the list of commands that it fires with
        C{TimeoutError}: C{connectionLost} doesn't try to fire them again, but
        sets the disconnected state so that future commands fail with a
        C{RuntimeError}.
        """
        d1 = self.proto.get("foo")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        d2 = self.proto.get("bar")
        self.assertFailure(d2, RuntimeError)
        return gatherResults([d1, d2])


    def test_connectionLost(self):
        """
        When disconnection occurs while commands are still outstanding, the
        commands fail.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        self.transport.loseConnection()
        done = DeferredList([d1, d2], consumeErrors=True)
        def checkFailures(results):
            for success, result in results:
                self.assertFalse(success)
                result.trap(ConnectionDone)
        return done.addCallback(checkFailures)


    def test_tooLongKey(self):
        """
        An error is raised when trying to use a too long key: the called
        command returns a L{Deferred} which fails with a L{ClientError}.
        """
        d1 = self.assertFailure(self.proto.set("a" * 500, "bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment("a" * 500), ClientError)
        d3 = self.assertFailure(self.proto.get("a" * 500), ClientError)
        d4 = self.assertFailure(
            self.proto.append("a" * 500, "bar"), ClientError)
        d5 = self.assertFailure(
            self.proto.prepend("a" * 500, "bar"), ClientError)
        d6 = self.assertFailure(
            self.proto.getMultiple(["foo", "a" * 500]), ClientError)
        return gatherResults([d1, d2, d3, d4, d5, d6])


    def test_invalidCommand(self):
        """
        When an unknown command is sent directly (not through public API), the
        server answers with an B{ERROR} token, and the command fails with
        L{NoSuchCommand}.
        """
        d = self.proto._set("egg", "foo", "bar", 0, 0, "")
        self.assertEqual(self.transport.value(), "egg foo 0 0 3\r\nbar\r\n")
        self.assertFailure(d, NoSuchCommand)
        self.proto.dataReceived("ERROR\r\n")
        return d


    def test_clientError(self):
        """
        Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
        token, the originating command fails with L{ClientError}, and the error
        contains the text sent by the server.
        """
        a = "eggspamm"
        d = self.proto.set("foo", a)
        self.assertEqual(self.transport.value(),
                         "set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ClientError)
        def check(err):
            self.assertEqual(str(err), "We don't like egg and spam")
        d.addCallback(check)
        self.proto.dataReceived("CLIENT_ERROR We don't like egg and spam\r\n")
        return d


    def test_serverError(self):
        """
        Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
        token, the originating command fails with L{ServerError}, and the error
        contains the text sent by the server.
        """
        a = "eggspamm"
        d = self.proto.set("foo", a)
        self.assertEqual(self.transport.value(),
                         "set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ServerError)
        def check(err):
            self.assertEqual(str(err), "zomg")
        d.addCallback(check)
        self.proto.dataReceived("SERVER_ERROR zomg\r\n")
        return d


    def test_unicodeKey(self):
        """
        Using a non-string key (unicode or int) as argument to commands raises
        an error.
        """
        d1 = self.assertFailure(self.proto.set(u"foo", "bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment(u"egg"), ClientError)
        d3 = self.assertFailure(self.proto.get(1), ClientError)
        d4 = self.assertFailure(self.proto.delete(u"bar"), ClientError)
        d5 = self.assertFailure(self.proto.append(u"foo", "bar"), ClientError)
        d6 = self.assertFailure(self.proto.prepend(u"foo", "bar"), ClientError)
        d7 = self.assertFailure(
            self.proto.getMultiple(["egg", 1]), ClientError)
        return gatherResults([d1, d2, d3, d4, d5, d6, d7])


    def test_unicodeValue(self):
        """
        Using a non-string value raises an error.
        """
        return self.assertFailure(self.proto.set("foo", u"bar"), ClientError)


    def test_pipelining(self):
        """
        Multiple requests can be sent subsequently to the server, and the
        protocol orders the responses correctly and dispatches them to the
        corresponding client command.
        """
        d1 = self.proto.get("foo")
        d1.addCallback(self.assertEqual, (0, "bar"))
        d2 = self.proto.set("bar", "spamspamspam")
        d2.addCallback(self.assertEqual, True)
        d3 = self.proto.get("egg")
        d3.addCallback(self.assertEqual, (0, "spam"))
        self.assertEqual(self.transport.value(),
            "get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n")
        # All three responses arrive in a single chunk.
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n"
                                "STORED\r\n"
                                "VALUE egg 0 4\r\nspam\r\nEND\r\n")
        return gatherResults([d1, d2, d3])


    def test_getInChunks(self):
        """
        If the value retrieved by a C{get} arrives in chunks, the protocol
        is able to reconstruct it and to produce the correct value.
        """
        d = self.proto.get("foo")
        d.addCallback(self.assertEqual, (0, "0123456789"))
        self.assertEqual(self.transport.value(), "get foo\r\n")
        self.proto.dataReceived("VALUE foo 0 10\r\n0123456")
        self.proto.dataReceived("789")
        self.proto.dataReceived("\r\nEND")
        self.proto.dataReceived("\r\n")
        return d


    def test_append(self):
        """
        L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
        method: it returns a L{Deferred} which is called back with C{True} when
        the operation succeeds.
        """
        return self._test(self.proto.append("foo", "bar"),
            "append foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)


    def test_prepend(self):
        """
        L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
        method: it returns a L{Deferred} which is called back with C{True} when
        the operation succeeds.
        """
        return self._test(self.proto.prepend("foo", "bar"),
            "prepend foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)


    def test_gets(self):
        """
        L{MemCacheProtocol.get} handles an additional cas result when
        C{withIdentifier} is C{True} and forwards it in the resulting
        L{Deferred}.
        """
        return self._test(self.proto.get("foo", True), "gets foo\r\n",
            "VALUE foo 0 3 1234\r\nbar\r\nEND\r\n", (0, "1234", "bar"))


    def test_emptyGets(self):
        """
        Test getting a non-available key with gets: it succeeds but returns
        C{None} as value, C{0} as flag and an empty cas value.
        """
        return self._test(self.proto.get("foo", True), "gets foo\r\n",
            "END\r\n", (0, "", None))


    def test_getsMultiple(self):
        """
        L{MemCacheProtocol.getMultiple} handles an additional cas field in the
        returned tuples if C{withIdentifier} is C{True}.
        """
        return self._test(self.proto.getMultiple(["foo", "bar"], True),
            "gets foo bar\r\n",
            "VALUE foo 0 3 1234\r\negg\r\nVALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
            {'bar': (0, '2345', 'spam'), 'foo': (0, '1234', 'egg')})


    def test_getsMultipleWithEmpty(self):
        """
        When getting a non-available key with L{MemCacheProtocol.getMultiple}
        when C{withIdentifier} is C{True}, the other keys are retrieved
        correctly, and the non-available key gets a tuple of C{0} as flag,
        C{None} as value, and an empty cas value.
        """
        return self._test(self.proto.getMultiple(["foo", "bar"], True),
            "gets foo bar\r\n",
            "VALUE foo 0 3 1234\r\negg\r\nEND\r\n",
            {'bar': (0, '', None), 'foo': (0, '1234', 'egg')})


    def test_checkAndSet(self):
        """
        L{MemCacheProtocol.checkAndSet} passes an additional cas identifier
        that the server handles to check if the data has to be updated.
        """
        return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
            "cas foo 0 0 3 1234\r\nbar\r\n", "STORED\r\n", True)


    def test_casUnknowKey(self):
        """
        When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the
        resulting L{Deferred} fires with C{False}.
        """
        return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
            "cas foo 0 0 3 1234\r\nbar\r\n", "EXISTS\r\n", False)
Exemple #17
0
class ClusterStateServiceTests(TestCase):
    """
    Unit tests for ``ClusterStateService``: how reported node states are
    merged into a ``DeploymentState`` and when stale state is discarded.
    """
    WITH_APPS = NodeState(
        hostname=u"192.0.2.56", uuid=uuid4(),
        applications=[APP1, APP2],
    )
    WITH_MANIFESTATION = NodeState(
        hostname=u"host2",
        manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
        devices={}, paths={},
    )

    def setUp(self):
        super(ClusterStateServiceTests, self).setUp()
        self.clock = Clock()

    def service(self):
        """
        Start a ``ClusterStateService`` driven by ``self.clock`` and register
        its ``stopService`` as a test cleanup.
        """
        started = ClusterStateService(self.clock)
        started.startService()
        self.addCleanup(started.stopService)
        return started

    def test_applications(self):
        """
        The applications of a reported node state show up unchanged in
        ``ClusterStateService.as_deployment``.
        """
        svc = self.service()
        svc.apply_changes([self.WITH_APPS])
        actual = svc.as_deployment()
        expected = DeploymentState(nodes=[self.WITH_APPS])
        self.assertEqual(actual, expected)

    def test_other_manifestations(self):
        """
        Manifestations of a reported node state are carried over into the
        nodes that ``ClusterStateService.as_deployment`` builds.
        """
        svc = self.service()
        svc.apply_changes([self.WITH_MANIFESTATION])
        actual = svc.as_deployment()
        expected = DeploymentState(nodes={self.WITH_MANIFESTATION})
        self.assertEqual(actual, expected)

    def test_partial_update(self):
        """
        A later update that leaves part of a node's state unknown
        (``applications=None``) only replaces the parts it does know.
        """
        svc = self.service()
        apps_known = NodeState(hostname=u"host1", applications=[APP1])
        apps_unknown = NodeState(
            hostname=u"host1", applications=None,
            manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
            devices={}, paths={})
        svc.apply_changes([apps_known, apps_unknown])

        actual = svc.as_deployment()
        merged = NodeState(
            hostname=u"host1",
            manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
            devices={}, paths={},
            applications=[APP1])
        self.assertEqual(actual, DeploymentState(nodes=[merged]))

    def test_update(self):
        """
        A second update for the same hostname replaces what the first one
        reported.
        """
        svc = self.service()
        updates = [
            NodeState(hostname=u"host1", applications=[app])
            for app in (APP1, APP2)
        ]
        svc.apply_changes(updates)

        actual = svc.as_deployment()
        latest = NodeState(
            hostname=u"host1", applications=frozenset([APP2]))
        self.assertEqual(actual, DeploymentState(nodes=[latest]))

    def test_multiple_hosts(self):
        """
        ``ClusterStateService.as_deployment`` combines the states reported for
        several different hosts.
        """
        svc = self.service()
        per_host = [(u"host1", APP1), (u"host2", APP2)]
        svc.apply_changes([
            NodeState(hostname=host, applications=[app])
            for host, app in per_host
        ])

        actual = svc.as_deployment()
        expected = DeploymentState(nodes=[
            NodeState(hostname=host, applications=frozenset([app]))
            for host, app in per_host
        ])
        self.assertEqual(actual, expected)

    def test_manifestation_path(self):
        """
        ``manifestation_path`` reports the filesystem path recorded for a
        dataset on the node with the given UUID.
        """
        node_uuid = uuid4()
        svc = self.service()
        reported = NodeState(
            hostname=u"host1", uuid=node_uuid,
            manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
            paths={MANIFESTATION.dataset_id: FilePath(b"/xxx/yyy")},
            devices={})
        svc.apply_changes([reported])

        found = svc.manifestation_path(node_uuid, MANIFESTATION.dataset_id)
        self.assertEqual(found, FilePath(b"/xxx/yyy"))

    def test_expiration(self):
        """
        State that has not been refreshed for longer than the hard-coded
        expiration period is dropped.
        """
        svc = self.service()
        svc.apply_changes([self.WITH_APPS])

        advance_rest(self.clock)
        state_before_wipe = svc.as_deployment()
        advance_some(self.clock)
        state_after_wipe = svc.as_deployment()

        self.assertEqual(
            [state_before_wipe, state_after_wipe],
            [DeploymentState(nodes=[self.WITH_APPS]), DeploymentState()],
        )

    def test_expiration_from_inactivity(self):
        """
        State coming from a change source is dropped once that source has
        shown no activity for longer than the hard-coded expiration period.
        """
        svc = self.service()
        origin = ChangeSource()

        # T1: the source is active and reports some state.
        origin.set_last_activity(self.clock.seconds())
        svc.apply_changes_from_source(origin, [self.WITH_APPS])

        # T2: a little later the source shows activity again.
        advance_some(self.clock)
        origin.set_last_activity(self.clock.seconds())

        # T3: EXPIRATION_TIME has elapsed since T1, but not since T2.
        advance_rest(self.clock)
        state_before_wipe = svc.as_deployment()

        # T4: EXPIRATION_TIME has now elapsed since T2 as well.
        advance_some(self.clock)
        state_after_wipe = svc.as_deployment()

        # Activity at T2 keeps the state alive at T3; it is gone by T4.
        self.assertEqual(
            [state_before_wipe, state_after_wipe],
            [DeploymentState(nodes=[self.WITH_APPS]), DeploymentState()],
        )

    def test_updates_different_key(self):
        """
        A wipe created by an ``IClusterStateChange`` with one wipe key is not
        overwritten by a later ``IClusterStateChange`` with a different key.
        """
        svc = self.service()
        apps_only = self.WITH_APPS
        manifestations_only = NodeState(
            hostname=apps_only.hostname, uuid=apps_only.uuid,
            manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
            devices={}, paths={})

        # T1: the application state is reported.
        svc.apply_changes([apps_only])

        # T2: a little later, unrelated manifestation state is reported.
        advance_some(self.clock)
        svc.apply_changes([manifestations_only])

        # T3: EXPIRATION_TIME has elapsed since T1.
        advance_rest(self.clock)
        state_before_wipe = svc.as_deployment()

        # T4: EXPIRATION_TIME has elapsed since T2.
        advance_some(self.clock)
        state_after_wipe = svc.as_deployment()

        # The T1 state expires at T3; the T2 state expires separately at T4.
        self.assertEqual(
            [state_before_wipe, state_after_wipe],
            [DeploymentState(nodes=[manifestations_only]), DeploymentState()],
        )

    def test_update_with_same_key(self):
        """
        Re-applying an update with the same key postpones its expiration.
        """
        svc = self.service()

        # T1: initial report.
        svc.apply_changes([self.WITH_APPS])

        # T2: the very same state is reported again.
        advance_some(self.clock)
        svc.apply_changes([self.WITH_APPS])

        # T3: EXPIRATION_TIME has elapsed since T1 but not since T2.
        advance_rest(self.clock)

        # The refresh at T2 keeps the state alive at T3.
        actual = svc.as_deployment()
        self.assertEqual(actual, DeploymentState(nodes=[self.WITH_APPS]))
Exemple #18
0
class TestCase(unittest.TestCase):
    """
    Shared test helpers: a fake ``Clock`` started at the current wall-clock
    time, throw-away wallets in temporary directories, peer construction,
    and assertions about tips and consensus state.
    """
    def setUp(self):
        self.tmpdirs = []
        self.clock = Clock()
        self.clock.advance(time.time())
        self.log = logger.new()

    def tearDown(self):
        self.clean_tmpdirs()

    def _create_test_wallet(self):
        """ Build a Wallet in a fresh temp directory, generate 20 keys with
            password b'MYPASS' and return it locked again.
            :rtype: Wallet
        """
        directory = tempfile.mkdtemp()
        # Registered so tearDown removes it.
        self.tmpdirs.append(directory)

        test_wallet = Wallet(directory=directory)
        test_wallet.unlock(b'MYPASS')
        test_wallet.generate_keys(count=20)
        test_wallet.lock()
        return test_wallet

    def create_peer(self,
                    network,
                    peer_id=None,
                    wallet=None,
                    tx_storage=None,
                    unlock_wallet=True,
                    wallet_index=False,
                    capabilities=None):
        """ Build, start and return a HathorManager bound to the test clock.

            A fresh PeerId and a fresh test wallet are created when none are
            given; ``unlock_wallet`` only applies to that freshly created wallet.
        """
        peer_id = PeerId() if peer_id is None else peer_id
        if not wallet:
            wallet = self._create_test_wallet()
            if unlock_wallet:
                wallet.unlock(b'MYPASS')

        manager = HathorManager(
            self.clock,
            peer_id=peer_id,
            network=network,
            wallet=wallet,
            tx_storage=tx_storage,
            wallet_index=wallet_index,
            capabilities=capabilities,
        )
        # Test-only tuning applied before the manager is started.
        manager.avg_time_between_blocks = 0.0001
        manager.test_mode = TestMode.TEST_ALL_WEIGHT
        manager._full_verification = True
        manager.start()
        self.run_to_completion()
        return manager

    def run_to_completion(self):
        """ Walk the clock forward to the time of each delayed call returned
            by ``getDelayedCalls()`` so those calls get executed.
        """
        clock = self.clock
        # NOTE(review): iterates whatever getDelayedCalls() returns while
        # advance() runs calls; if that is the clock's live list, calls
        # scheduled or removed during iteration affect what is visited.
        for pending in clock.getDelayedCalls():
            clock.advance(pending.getTime() - clock.seconds())

    def set_random_seed(self, seed=None):
        """ Seed both ``random`` and ``numpy.random`` with ``seed`` (a random
            32-bit value when omitted) and remember it in ``self.random_seed``.
        """
        chosen = numpy.random.randint(2**32) if seed is None else seed
        self.random_seed = chosen
        random.seed(self.random_seed)
        numpy.random.seed(self.random_seed)

    def assertTipsEqual(self, manager1, manager2):
        """ Both managers expose the same set of all tips and of tx tips. """
        storage1, storage2 = manager1.tx_storage, manager2.tx_storage
        self.assertEqual(set(storage1.get_all_tips()),
                         set(storage2.get_all_tips()))
        self.assertEqual(set(storage1.get_tx_tips()),
                         set(storage2.get_tx_tips()))

    def assertTipsNotEqual(self, manager1, manager2):
        """ The two managers differ in their set of all tips. """
        self.assertNotEqual(set(manager1.tx_storage.get_all_tips()),
                            set(manager2.tx_storage.get_all_tips()))

    def assertConsensusEqual(self, manager1, manager2):
        """ Both managers agree on tx/block count, conflicts, and on whether
            each transaction is voided (not necessarily by the same set).
        """
        storage1, storage2 = manager1.tx_storage, manager2.tx_storage
        self.assertEqual(storage1.get_count_tx_blocks(),
                         storage2.get_count_tx_blocks())
        for tx1 in storage1.get_all_transactions():
            tx2 = storage2.get_transaction(tx1.hash)
            meta1, meta2 = tx1.get_metadata(), tx2.get_metadata()
            self.assertEqual(meta1.conflict_with, meta2.conflict_with)
            # Soft check: voided-ness must match, the voiding sets need not.
            if meta1.voided_by is not None:
                self.assertGreaterEqual(len(meta1.voided_by), 1)
                self.assertGreaterEqual(len(meta2.voided_by), 1)
            else:
                self.assertIsNone(meta2.voided_by)
            # A hard check (meta1.voided_by == meta2.voided_by) is disabled.

    def assertConsensusValid(self, manager):
        """ Run the block or transaction consensus check on every stored tx. """
        for tx in manager.tx_storage.get_all_transactions():
            check = (self.assertBlockConsensusValid if tx.is_block
                     else self.assertTransactionConsensusValid)
            check(tx)

    def assertBlockConsensusValid(self, block):
        """ A non-voided, non-genesis block must have a non-voided parent. """
        self.assertTrue(block.is_block)
        # Blocks without parents (genesis) are accepted as-is.
        if block.parents and block.get_metadata().voided_by is None:
            parent_meta = block.get_block_parent().get_metadata()
            self.assertIsNone(parent_meta.voided_by)

    def assertTransactionConsensusValid(self, tx):
        """ Check voidance rules for a transaction: self-voiding requires a
            conflict, and voidance of spent txs and parents is inherited.
        """
        self.assertFalse(tx.is_block)
        meta = tx.get_metadata()
        # A transaction that voids itself must conflict with something.
        if meta.voided_by and tx.hash in meta.voided_by:
            self.assertTrue(meta.conflict_with)

        spent_metas = (tx.get_spent_tx(txin).get_metadata()
                       for txin in tx.inputs)
        for spent_meta in spent_metas:
            if spent_meta.voided_by is None:
                continue
            self.assertIsNotNone(meta.voided_by)
            self.assertTrue(spent_meta.voided_by.issubset(meta.voided_by))

        for parent in tx.get_parents():
            parent_meta = parent.get_metadata()
            if parent_meta.voided_by is None:
                continue
            self.assertIsNotNone(meta.voided_by)
            self.assertTrue(parent_meta.voided_by.issubset(meta.voided_by))

    def clean_tmpdirs(self):
        """ Recursively delete every temporary directory recorded so far. """
        for directory in self.tmpdirs:
            shutil.rmtree(directory)

    def clean_pending(self, required_to_quiesce=True):
        """
        Cancel every pending delayed call on the global reactor.

        Trial fails a test that leaves delayed calls behind. That is the right
        check for code that must release control once done, but some code is
        legitimately long-running and keeps rescheduling work (a "crash-only"
        design that recovers on startup rather than offering a shutdown hook).
        Tests of such code can call this at the end to clear the reactor.

        When ``required_to_quiesce`` is true the test still fails if any call
        was pending, after they have all been cancelled.

        Copied from: https://github.com/zooko/pyutil/blob/master/pyutil/testutil.py#L68
        """
        pending = reactor.getDelayedCalls()
        was_busy = bool(pending)
        for call in pending:
            if not call.active():
                print('WEIRDNESS! pending timed call not active!')
                continue
            call.cancel()
        if required_to_quiesce and was_busy:
            self.fail(
                'Reactor was still active when it was required to be quiescent.'
            )

    def get_address(self, index: int) -> Optional[str]:
        """ Return the ``index``-th key of a fixed-mnemonic HDWallet, or None
            when ``index`` is at or beyond the wallet's gap limit.
        """
        from hathor.wallet import HDWallet
        mnemonic = (
            'bind daring above film health blush during tiny neck slight clown salmon '
            'wine brown good setup later omit jaguar tourist rescue flip pet salute'
        )

        hd_wallet = HDWallet(words=mnemonic)
        hd_wallet._manually_initialize()

        if index < hd_wallet.gap_limit:
            return list(hd_wallet.keys.keys())[index]
        return None
Exemple #19
0
class EINVALTestCase(TestCase):
    """
    Sometimes, L{os.listdir} will raise C{EINVAL}.  This is a transient error,
    and L{CachingFilePath.listdir} should work around it by retrying the
    C{listdir} operation until it succeeds.
    """

    def setUp(self):
        """
        Create a L{CachingFilePath} for the test to use.  Its C{_sleep} is
        replaced by a fake L{Clock}'s C{advance}, so every "sleep" between
        retries moves simulated time forward instead of really waiting.
        """
        self.cfp = CachingFilePath(self.mktemp())
        self.clock = Clock()
        self.cfp._sleep = self.clock.advance

    def test_testValidity(self):
        """
        If C{listdir} is replaced on a L{CachingFilePath}, we should be able to
        observe exceptions raised by the replacement.  This verifies that the
        test patching done here is actually testing something.
        """
        class CustomException(Exception):
            "Just for testing."

        def blowUp(dirname):
            raise CustomException()

        self.cfp._listdir = blowUp
        self.assertRaises(CustomException, self.cfp.listdir)
        self.assertRaises(CustomException, self.cfp.children)

    def test_retryLoop(self):
        """
        L{CachingFilePath} should catch C{EINVAL} and respond by retrying the
        C{listdir} operation until it succeeds.
        """
        calls = []

        def raiseEINVAL(dirname):
            calls.append(dirname)
            # Fail the first four attempts, succeed on the fifth.
            if len(calls) < 5:
                raise OSError(EINVAL, "This should be caught by the test.")
            return ['a', 'b', 'c']

        self.cfp._listdir = raiseEINVAL
        self.assertEqual(self.cfp.listdir(), ['a', 'b', 'c'])
        self.assertEqual(self.cfp.children(), [
            CachingFilePath(pathjoin(self.cfp.path, 'a')),
            CachingFilePath(pathjoin(self.cfp.path, 'b')),
            CachingFilePath(pathjoin(self.cfp.path, 'c')),
        ])

    def requireTimePassed(self, filenames):
        """
        Create a replacement for listdir() which only fires after a certain
        amount of time.

        The replacement raises C{EINVAL} while the simulated clock is below
        20 seconds, recording the clock reading of each failed attempt in
        C{self.calls}; afterwards it returns C{filenames}.
        """
        self.calls = []

        def thunk(dirname):
            now = self.clock.seconds()
            if now < 20.0:
                self.calls.append(now)
                raise OSError(EINVAL, "Not enough time has passed yet.")
            else:
                return filenames

        self.cfp._listdir = thunk

    def assertRequiredTimePassed(self):
        """
        Assert that calls to the simulated time.sleep() installed by
        C{requireTimePassed} have been invoked the required number of times.
        """
        # Waiting should be growing by *2 each time until the additional wait
        # exceeds BACKOFF_MAX (5), at which point we should wait for 5s each
        # time.
        def cumulative(values):
            current = 0.0
            for value in values:
                current += value
                yield current

        self.assertEqual(
            self.calls,
            list(cumulative(
                [0.0, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 5.0, 5.0]
            ))
        )

    def test_backoff(self):
        """
        L{CachingFilePath} will wait for an increasing interval up to
        C{BACKOFF_MAX} between calls to listdir().
        """
        self.requireTimePassed(['a', 'b', 'c'])
        self.assertEqual(self.cfp.listdir(), ['a', 'b', 'c'])
        # The result alone only shows that enough simulated time elapsed;
        # this checks the actual schedule of retry delays.
        self.assertRequiredTimePassed()
Exemple #20
0
class TestCase(unittest.TestCase):
    """Base class for Hathor tests.

    Each test gets a fake ``Clock`` advanced to the current wall-clock time, a
    fresh copy of ``PEER_ID_POOL`` to draw peer ids from, a ``Random``
    instance, and helpers that create started ``HathorManager`` peers and check
    consensus invariants. Directories registered in ``self.tmpdirs`` are
    removed in ``tearDown``.
    """
    # No defaults on purpose: subclasses must choose the sync versions, or
    # callers must pass them explicitly to ``create_peer``.
    _enable_sync_v1: bool
    _enable_sync_v2: bool
    use_memory_storage: bool = USE_MEMORY_STORAGE

    def setUp(self):
        _set_test_mode(TestMode.TEST_ALL_WEIGHT)
        self.tmpdirs = []
        self.clock = Clock()
        # Start the fake clock at "now" rather than at 0.
        self.clock.advance(time.time())
        self.log = logger.new()
        self.reset_peer_id_pool()
        self.rng = Random()

    def tearDown(self):
        self.clean_tmpdirs()

    def reset_peer_id_pool(self) -> None:
        """Replace this test's free peer id pool with a fresh copy."""
        self._free_peer_id_pool = self.new_peer_id_pool()

    def new_peer_id_pool(self) -> List[PeerId]:
        """Return a new list holding every peer id in ``PEER_ID_POOL``."""
        return PEER_ID_POOL.copy()

    def get_random_peer_id_from_pool(self,
                                     pool: Optional[List[PeerId]] = None,
                                     rng: Optional[Random] = None) -> PeerId:
        """Pick a random peer id from ``pool`` and remove it from the pool.

        :param pool: list to draw from; defaults to this test's free pool.
        :param rng: random generator to use; defaults to ``self.rng``.
        :raises RuntimeError: if the pool is empty.
        """
        if pool is None:
            pool = self._free_peer_id_pool
        if not pool:
            raise RuntimeError('no more peer ids on the pool')
        if rng is None:
            rng = self.rng
        # Use the resolved generator so a caller-supplied ``rng`` is honoured.
        peer_id = rng.choice(pool)
        pool.remove(peer_id)
        return peer_id

    def _create_test_wallet(self):
        """ Generate a Wallet with a number of keypairs for testing
            :rtype: Wallet
        """
        tmpdir = tempfile.mkdtemp()
        self.tmpdirs.append(tmpdir)

        wallet = Wallet(directory=tmpdir)
        wallet.unlock(b'MYPASS')
        wallet.generate_keys(count=20)
        # Returned locked; ``create_peer`` unlocks it again if requested.
        wallet.lock()
        return wallet

    def create_peer(self,
                    network,
                    peer_id=None,
                    wallet=None,
                    tx_storage=None,
                    unlock_wallet=True,
                    wallet_index=False,
                    capabilities=None,
                    full_verification=True,
                    enable_sync_v1=None,
                    enable_sync_v2=None,
                    checkpoints=None):
        """Build, start and return a ``HathorManager`` driven by ``self.clock``.

        Missing ``peer_id``, ``wallet`` and ``tx_storage`` are created on the
        fly (storage is in-memory or RocksDB depending on
        ``use_memory_storage``). ``enable_sync_v1``/``enable_sync_v2`` default
        to the class attributes ``_enable_sync_v1``/``_enable_sync_v2``; at
        least one must be enabled. The clock is run until the delayed calls
        pending after ``manager.start()`` have been processed.
        """
        if enable_sync_v1 is None:
            assert hasattr(self, '_enable_sync_v1'), (
                '`_enable_sync_v1` has no default by design, either set one on '
                'the test class or pass `enable_sync_v1` by argument')
            enable_sync_v1 = self._enable_sync_v1
        if enable_sync_v2 is None:
            assert hasattr(self, '_enable_sync_v2'), (
                '`_enable_sync_v2` has no default by design, either set one on '
                'the test class or pass `enable_sync_v2` by argument')
            enable_sync_v2 = self._enable_sync_v2
        assert enable_sync_v1 or enable_sync_v2, 'enable at least one sync version'

        if peer_id is None:
            peer_id = PeerId()
        if not wallet:
            wallet = self._create_test_wallet()
            if unlock_wallet:
                wallet.unlock(b'MYPASS')
        if tx_storage is None:
            if self.use_memory_storage:
                from hathor.transaction.storage.memory_storage import TransactionMemoryStorage
                tx_storage = TransactionMemoryStorage()
            else:
                from hathor.transaction.storage.rocksdb_storage import TransactionRocksDBStorage
                directory = tempfile.mkdtemp()
                self.tmpdirs.append(directory)
                tx_storage = TransactionRocksDBStorage(directory)
        manager = HathorManager(
            self.clock,
            peer_id=peer_id,
            network=network,
            wallet=wallet,
            tx_storage=tx_storage,
            wallet_index=wallet_index,
            capabilities=capabilities,
            rng=self.rng,
            enable_sync_v1=enable_sync_v1,
            enable_sync_v2=enable_sync_v2,
            checkpoints=checkpoints,
        )

        # XXX: just making sure that tests set this up correctly
        if enable_sync_v2:
            assert SyncVersion.V2 in manager.connections._sync_factories
        else:
            assert SyncVersion.V2 not in manager.connections._sync_factories
        if enable_sync_v1:
            assert SyncVersion.V1 in manager.connections._sync_factories
        else:
            assert SyncVersion.V1 not in manager.connections._sync_factories

        manager.avg_time_between_blocks = 0.0001
        manager._full_verification = full_verification
        manager.start()
        self.run_to_completion()
        return manager

    def run_to_completion(self):
        """ This will advance the test's clock until all calls scheduled are done.

        NOTE(review): this iterates the clock's live delayed-call list while
        ``advance`` mutates it, so calls added or removed during advancing may
        be skipped or processed -- it is not a strict "run until empty" loop.
        """
        for call in self.clock.getDelayedCalls():
            amount = call.getTime() - self.clock.seconds()
            self.clock.advance(amount)

    def assertTipsEqual(self, manager1, manager2):
        """Assert both managers report the same all-tips and tx-tips sets."""
        s1 = set(manager1.tx_storage.get_all_tips())
        s2 = set(manager2.tx_storage.get_all_tips())
        self.assertEqual(s1, s2)

        s1 = set(manager1.tx_storage.get_tx_tips())
        s2 = set(manager2.tx_storage.get_tx_tips())
        self.assertEqual(s1, s2)

    def assertTipsNotEqual(self, manager1, manager2):
        """Assert the two managers report different all-tips sets."""
        s1 = set(manager1.tx_storage.get_all_tips())
        s2 = set(manager2.tx_storage.get_all_tips())
        self.assertNotEqual(s1, s2)

    def assertConsensusEqual(self, manager1, manager2):
        """Assert both managers agree on counts, conflicts and voidedness.

        For every transaction in ``manager1`` the same transaction in
        ``manager2`` must have the same ``conflict_with`` set and must be voided
        if and only if the first one is (the exact ``voided_by`` sets are not
        compared).
        """
        self.assertEqual(manager1.tx_storage.get_count_tx_blocks(),
                         manager2.tx_storage.get_count_tx_blocks())
        for tx1 in manager1.tx_storage.get_all_transactions():
            tx2 = manager2.tx_storage.get_transaction(tx1.hash)
            tx1_meta = tx1.get_metadata()
            tx2_meta = tx2.get_metadata()
            # conflict_with's type is Optional[List[bytes]], so we convert to a set because order does not matter.
            self.assertEqual(set(tx1_meta.conflict_with or []),
                             set(tx2_meta.conflict_with or []))
            # Soft verification
            if tx1_meta.voided_by is None:
                # If tx1 is not voided, then tx2 must be not voided.
                self.assertIsNone(tx2_meta.voided_by)
            else:
                # If tx1 is voided, then tx2 must be voided.
                self.assertGreaterEqual(len(tx1_meta.voided_by), 1)
                self.assertGreaterEqual(len(tx2_meta.voided_by), 1)
            # Hard verification
            # self.assertEqual(tx1_meta.voided_by, tx2_meta.voided_by)

    def assertConsensusValid(self, manager):
        """Check the consensus invariants of every block/tx in ``manager``."""
        for tx in manager.tx_storage.get_all_transactions():
            if tx.is_block:
                self.assertBlockConsensusValid(tx)
            else:
                self.assertTransactionConsensusValid(tx)

    def assertBlockConsensusValid(self, block):
        """A non-voided, non-genesis block must have a non-voided block parent."""
        self.assertTrue(block.is_block)
        if not block.parents:
            # Genesis
            return
        meta = block.get_metadata()
        if meta.voided_by is None:
            parent = block.get_block_parent()
            parent_meta = parent.get_metadata()
            self.assertIsNone(parent_meta.voided_by)

    def assertTransactionConsensusValid(self, tx):
        """Check the consensus invariants of a single transaction.

        * a self-voided transaction must have at least one conflict;
        * two conflicting transactions cannot both be executed;
        * voidedness propagates from spent outputs and from parents, i.e. the
          spent tx's / parent's ``voided_by`` must be a subset of this tx's.
        """
        self.assertFalse(tx.is_block)
        meta = tx.get_metadata()
        if meta.voided_by and tx.hash in meta.voided_by:
            # If a transaction voids itself, then it must have at
            # least one conflict.
            self.assertTrue(meta.conflict_with)

        is_tx_executed = bool(not meta.voided_by)
        for h in meta.conflict_with or []:
            tx2 = tx.storage.get_transaction(h)
            meta2 = tx2.get_metadata()
            is_tx2_executed = bool(not meta2.voided_by)
            self.assertFalse(is_tx_executed and is_tx2_executed)

        for txin in tx.inputs:
            spent_tx = tx.get_spent_tx(txin)
            spent_meta = spent_tx.get_metadata()

            if spent_meta.voided_by is not None:
                self.assertIsNotNone(meta.voided_by)
                self.assertTrue(spent_meta.voided_by)
                self.assertTrue(meta.voided_by)
                self.assertTrue(spent_meta.voided_by.issubset(meta.voided_by))

        for parent in tx.get_parents():
            parent_meta = parent.get_metadata()
            if parent_meta.voided_by is not None:
                self.assertIsNotNone(meta.voided_by)
                self.assertTrue(parent_meta.voided_by)
                self.assertTrue(meta.voided_by)
                self.assertTrue(parent_meta.voided_by.issubset(meta.voided_by))

    def clean_tmpdirs(self):
        """Recursively delete every directory registered in ``self.tmpdirs``."""
        for tmpdir in self.tmpdirs:
            shutil.rmtree(tmpdir)

    def clean_pending(self, required_to_quiesce=True):
        """
        This handy method cleans all pending tasks from the reactor.

        When writing a unit test, consider the following question:

            Is the code that you are testing required to release control once it
            has done its job, so that it is impossible for it to later come around
            (with a delayed reactor task) and do anything further?

        If so, then trial will usefully test that for you -- if the code under
        test leaves any pending tasks on the reactor then trial will fail it.

        On the other hand, some code is *not* required to release control -- some
        code is allowed to continuously maintain control by rescheduling reactor
        tasks in order to do ongoing work.  Trial will incorrectly require that
        code to clean up all its tasks from the reactor.

        Most people think that such code should be amended to have an optional
        "shutdown" operation that releases all control, but on the contrary it is
        good design for some code to *not* have a shutdown operation, but instead
        to have a "crash-only" design in which it recovers from crash on startup.

        If the code under test is of the "long-running" kind, which is *not*
        required to shutdown cleanly in order to pass tests, then you can simply
        call testutil.clean_pending() at the end of the unit test, and trial will
        be satisfied.

        Copy from: https://github.com/zooko/pyutil/blob/master/pyutil/testutil.py#L68
        """
        pending = reactor.getDelayedCalls()
        active = bool(pending)
        for p in pending:
            if p.active():
                p.cancel()
            else:
                print('WEIRDNESS! pending timed call not active!')
        if required_to_quiesce and active:
            self.fail(
                'Reactor was still active when it was required to be quiescent.'
            )

    def get_address(self, index: int) -> Optional[str]:
        """ Generate a fixed HD Wallet and return an address

        Returns ``None`` when ``index`` is not below the wallet's ``gap_limit``.
        """
        from hathor.wallet import HDWallet
        words = (
            'bind daring above film health blush during tiny neck slight clown salmon '
            'wine brown good setup later omit jaguar tourist rescue flip pet salute'
        )

        hd = HDWallet(words=words)
        hd._manually_initialize()

        if index >= hd.gap_limit:
            return None

        return list(hd.keys.keys())[index]
Exemple #21
0
 def patch_message_rate_clock(self):
     '''Make ``MessageRateStore.get_seconds`` read from a new fake clock.

     :return: the ``Clock`` whose ``seconds()`` now backs message-rate timing.
     '''
     fake_clock = Clock()

     def fake_get_seconds(_store):
         return fake_clock.seconds()

     self.patch(MessageRateStore, 'get_seconds', fake_get_seconds)
     return fake_clock
class TransportTestCase(object):
    """PT client and server connect over a string transport.

    We bypass the communication between client and server and intercept the
    messages sent over the string transport.

    Subclasses must define ``self.transport`` (the name passed to
    ``get_transport_class``) and ``self.args`` (extra transport CLI arguments).
    """

    def setUp(self):
        """Set the reactor's callLater to our clock's callLater function
        and build the protocols.

        The original ``reactor.callLater`` and ``sys.argv`` are remembered so
        that ``tearDown`` can restore them; otherwise both global patches leak
        into every test that runs afterwards in the same process.
        """
        self._orig_call_later = reactor.callLater
        self._orig_argv = sys.argv
        self.clock = Clock()
        reactor.callLater = self.clock.callLater
        self.dump = []
        self.proto_client = self._build_protocol(const.CLIENT)
        self.proto_server = self._build_protocol(const.SERVER)
        self.pt_client = self.proto_client.circuit.transport
        self.pt_server = self.proto_server.circuit.transport
        self._proxy(self.proto_client, self.proto_server)
        self._bypass_connection(self.proto_client, self.proto_server)

    def _proxy(self, client, server):
        """Proxy the communication between client and server and record the
        intercepted messages, with the clock time, in each end's ``history``
        dictionary (keys ``'rcv'`` and ``'snd'``).
        """
        def decorate_intercept(end):
            old_rcv_f = end.circuit.transport.receivedDownstream
            old_snd_f = end.circuit.transport.sendDownstream

            def intercept(old_f, direction):
                def new_f(data):
                    msgs = old_f(data)
                    end.history[direction].append((self.clock.seconds(), msgs))
                    # Pass the wrapped method's result through so callers that
                    # use it are unaffected by the interception.
                    return msgs
                return new_f
            end.circuit.transport.receivedDownstream = intercept(old_rcv_f, 'rcv')
            end.circuit.transport.sendDownstream = intercept(old_snd_f, 'snd')
        decorate_intercept(client)
        decorate_intercept(server)

    def _bypass_connection(self, client, server):
        """Instead of requiring TCP connections between client and server
        transports, we directly pass the data written from one end to the
        received function at the other.

        Each write is also logged in ``self.dump`` as a
        ``(clock time, direction * len(data))`` tuple.
        """
        def curry_bypass_connection(up, down, direction):
            old_write = up.circuit.downstream.write

            def write(data):
                old_write(data)
                down.dataReceived(data)
                self.dump.append((self.clock.seconds(), direction * len(data)))
            return write
        client.circuit.downstream.write = curry_bypass_connection(client, server, const.OUT)
        server.circuit.downstream.write = curry_bypass_connection(server, client, const.IN)

    def _build_protocol(self, mode):
        """Build client and server protocols for an end point.

        :raises ValueError: if ``mode`` is neither ``const.CLIENT`` nor
            ``const.SERVER``.
        """
        addr_tuple = (HOST, str(PORT))
        address = IPv4Address('TCP', HOST, PORT)
        pt_config = self._build_transport_configuration(mode)
        transport_class = self._configure_transport_class(mode, pt_config)
        f_server = net.StaticDestinationServerFactory(addr_tuple, mode, transport_class, pt_config)
        protocol_server = self._set_protocol(f_server, address)
        f_client = net.StaticDestinationClientFactory(protocol_server.circuit, const.CLIENT)
        protocol_client = self._set_protocol(f_client, address)
        if mode == const.CLIENT:
            return protocol_client
        elif mode == const.SERVER:
            return protocol_server
        else:
            raise ValueError("Transport mode '%s' not recognized." % mode)

    def _set_protocol(self, factory, address):
        """Make protocol connection with a Twisted string transport."""
        protocol = factory.buildProtocol(address)
        protocol.makeConnection(proto_helpers.StringTransport())
        protocol.history = {'rcv': [], 'snd': []}
        return protocol

    def _build_transport_configuration(self, mode):
        """Configure transport as a managed transport."""
        pt_config = transport_config.TransportConfig()
        pt_config.setStateLocation(const.TEMP_DIR)
        pt_config.setObfsproxyMode("managed")
        pt_config.setListenerMode(mode)
        return pt_config

    def _configure_transport_class(self, mode, pt_config):
        """Use the global arguments to configure the transport.

        This replaces ``sys.argv`` so the obfsproxy CLI parser can be reused;
        ``tearDown`` restores the original value.
        """
        transport_args = [mode, ADDR, "--dest=%s" % ADDR] + self.args
        sys.argv = [sys.argv[0],
                    "--log-file", join(const.TEMP_DIR, "%s.log" % mode),
                    "--log-min-severity", "debug"]
        sys.argv.append("wfpad")  # use wfpad transport
        sys.argv += transport_args
        parser = set_up_cli_parsing()
        consider_cli_args(parser.parse_args())
        transport_class = get_transport_class(self.transport, mode)
        transport_class.setup(pt_config)
        p = ArgumentParser()
        transport_class.register_external_mode_cli(p)
        args = p.parse_args(transport_args)
        transport_class.validate_external_mode_cli(args)
        return transport_class

    def _lose_protocol_connection(self, protocol):
        """Disconnect client and server transports."""
        protocol.circuit.upstream.transport.loseConnection()
        protocol.circuit.downstream.transport.loseConnection()

    def advance_next_delayed_call(self):
        """Advance clock to first delayed call in reactor."""
        first_delayed_call = self.clock.getDelayedCalls()[0]
        self.clock.advance(first_delayed_call.getTime() - self.clock.seconds())

    def is_timeout(self, call):
        """Check if the call has actually timed out."""
        return isinstance(call.args[0], TimeoutError)

    def advance_delayed_calls(self, max_dcalls=NUM_DCALLS, no_timeout=True):
        """Advance clock to the point all delayed calls up to that moment have
        been called.

        At most ``max_dcalls`` iterations are performed. When ``no_timeout``
        is true, calls whose first argument is a ``TimeoutError`` are pushed
        to the back of the queue instead of being fired; the loop stops once
        such a call comes round a second time.
        """
        i, timeouts = 0, []
        while self.clock.getDelayedCalls() and i < max_dcalls:
            i += 1
            dcall = self.clock.getDelayedCalls()[0]
            if no_timeout:
                if len(dcall.args) > 0 and self.is_timeout(dcall):
                    if dcall in timeouts:
                        break
                    self._queue_first_call()
                    timeouts.append(dcall)
                    continue
            self.advance_next_delayed_call()

    def _queue_first_call(self, delay=10000.0):
        """Put the first delayed call to the last position."""
        # NOTE(review): ``delay`` is stored as an absolute clock time (not
        # relative to ``self.clock.seconds()``) and ``clock.calls`` is not
        # re-sorted after appending -- confirm this is intended.
        timeout = self.clock.calls.pop(0)
        timeout.time = delay
        self.clock.calls.append(timeout)

    def tearDown(self):
        """Close connections, advance all delayed calls, then undo the global
        patches made in ``setUp``."""
        try:
            # Need to wait a bit because obfsproxy network.Circuit.circuitCompleted
            # defers 0.02s a dummy call to dataReceived to flush connection.
            self._lose_protocol_connection(self.proto_client)
            self._lose_protocol_connection(self.proto_server)
            self.advance_delayed_calls()
        finally:
            # Restore only after advancing, so anything scheduled through
            # reactor.callLater above still lands on self.clock.
            reactor.callLater = self._orig_call_later
            sys.argv = self._orig_argv
Exemple #23
0
class APITestsMixin(APIAssertionsMixin):
    """
    Helpers for writing tests for the Docker Volume Plugin API.
    """
    NODE_A = uuid4()
    NODE_B = uuid4()

    def initialize(self):
        """
        Create the fake reactor and counting Flocker client that the
        ``VolumePlugin`` under test will use.

        The plugin's conditional_create operation needs time to pass, so a
        ``LoopingCall`` advances the fake clock by one millisecond every
        millisecond. The clock stays fake so tests can still skip ahead.
        """
        self.volume_plugin_reactor = Clock()
        self.flocker_client = SimpleCountingProxy(FakeFlockerClient())

        tick = 0.001
        ticker = LoopingCall(
            lambda: self.volume_plugin_reactor.advance(tick))
        self.looping = ticker
        ticker.start(tick)
        self.addCleanup(ticker.stop)

    def test_pluginactivate(self):
        """
        ``/Plugin.Activate`` reports that the plugin implements a volume
        driver.
        """
        # Docker 1.8 (at least) sends "null" as the body, which our test
        # infrastructure cannot send. An arbitrary body (12345) is sent
        # instead to show it is ignored, since the spec declares no body.
        expected = {u"Implements": [u"VolumeDriver"]}
        return self.assertResult(b"POST", b"/Plugin.Activate", 12345, OK,
                                 expected)

    def test_remove(self):
        """
        A ``/VolumeDriver.Remove`` request gets a result with no error.
        """
        request_body = {u"Name": u"vol"}
        expected = {u"Err": None}
        return self.assertResult(b"POST", b"/VolumeDriver.Remove",
                                 request_body, OK, expected)

    def test_unmount(self):
        """
        A ``/VolumeDriver.Unmount`` request gets a result with no error.
        """
        request_body = {u"Name": u"vol"}
        expected = {u"Err": None}
        return self.assertResult(b"POST", b"/VolumeDriver.Unmount",
                                 request_body, OK, expected)

    def test_create_with_profile(self):
        """
        Calling ``/VolumeDriver.Create`` with ``Opts`` containing a
        ``profile`` of gold, silver or bronze creates a dataset with the given
        name, the default size, and the profile recorded in its metadata.
        """
        profile = sampled_from(["gold", "silver", "bronze"]).example()
        name = random_name(self)
        request_body = {
            u"Name": name,
            'Opts': {u"profile": profile},
        }
        d = self.assertResult(b"POST", b"/VolumeDriver.Create", request_body,
                              OK, {u"Err": None})

        def check_datasets(result):
            expected_dataset = Dataset(
                dataset_id=result[0].dataset_id,
                primary=self.NODE_A,
                maximum_size=int(DEFAULT_SIZE.to_Byte()),
                metadata={
                    u"name": name,
                    u"clusterhq:flocker:profile": unicode(profile),
                })
            return self.assertItemsEqual(result, [expected_dataset])

        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(list)
        d.addCallback(check_datasets)
        return d

    def test_create_with_size(self):
        """
        Calling the ``/VolumeDriver.Create`` API with an ``Opts`` value of
        ``size=<somesize>`` in the request body JSON creates a volume with the
        given name whose ``maximum_size`` is the parsed size.

        ``<somesize>`` is a random integer between 1 and 75 followed by a unit
        expression drawn from ``volume_expression``.
        """
        name = random_name(self)
        size = integers(min_value=1, max_value=75).example()
        expression = volume_expression.example()
        size_opt = str(size) + expression
        d = self.assertResult(b"POST", b"/VolumeDriver.Create", {
            u"Name": name,
            'Opts': {
                u"size": size_opt
            }
        }, OK, {u"Err": None})

        real_size = int(parse_num(size_opt).to_Byte())
        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(list)
        d.addCallback(
            lambda result: self.assertItemsEqual(result, [
                Dataset(dataset_id=result[0].dataset_id,
                        primary=self.NODE_A,
                        maximum_size=real_size,
                        metadata={
                            u"name": name,
                            u"maximum_size": unicode(real_size)
                        })
            ]))
        return d

    @given(expr=volume_expression, size=integers(min_value=75, max_value=100))
    def test_parsenum_size(self, expr, size):
        """
        ``parse_num`` interprets the size expressions generated by
        ``volume_expression`` as gibibyte (GiB) sizes.

        :param str expr: A unit expression generated by ``volume_expression``.
        :param int size: The numeric part of the volume size.
        """
        expected = int(GiB(size).to_Byte())
        parsed = parse_num(str(size) + expr)
        return self.assertEqual(expected, int(parsed.to_Byte()))

    @given(expr=sampled_from(["KB", "MB", "GB", "TB", ""]),
           size=integers(min_value=1, max_value=100))
    def test_parsenum_all_sizes(self, expr, size):
        """
        Send standard size expressions to ``parse_num`` in
        many sizes, we expect to get correct size results.

        :param str expr: A unit suffix ("KB", "MB", "GB", "TB" or "" for bytes).
        :param int size: The numeric part of the volume size.
        """
        # Compare suffixes by value: the previous ``expr is "KB"`` checks only
        # worked because of string interning (and warn on Python 3.8+).
        unit_for_suffix = {
            "KB": KiB,
            "MB": MiB,
            "GB": GiB,
            "TB": TiB,
        }
        unit = unit_for_suffix.get(expr, Byte)
        expected_size = int(unit(size).to_Byte())
        return self.assertEqual(expected_size,
                                int(parse_num(str(size) + expr).to_Byte()))

    @given(size=sampled_from(
        [u"foo10Gb", u"10bar10", "10foogib", "10Gfoo", "GIB", "bar10foo"]))
    def test_parsenum_bad_size(self, size):
        """
        ``parse_num`` falls back to Flocker's ``DEFAULT_SIZE`` when given a
        size expression it cannot understand.

        :param str size: A malformed volume size expression.
        """
        default_bytes = int(DEFAULT_SIZE.to_Byte())
        parsed_bytes = int(parse_num(size).to_Byte())
        return self.assertEqual(default_bytes, parsed_bytes)

    def create(self, name):
        """
        Create a volume named ``name`` through the ``/VolumeDriver.Create``
        API and check that the call reports no error.

        :param unicode name: The name of the volume to create.

        :return: ``Deferred`` that fires once the create request has been
            checked.
        """
        request_body = {u"Name": name}
        return self.assertResult(b"POST", b"/VolumeDriver.Create",
                                 request_body, OK, {u"Err": None})

    def test_create_creates(self):
        """
        ``/VolumeDriver.Create`` adds a new dataset, with the default size and
        the requested name, to the configuration.
        """
        name = u"myvol"

        def check_datasets(result):
            expected_dataset = Dataset(
                dataset_id=result[0].dataset_id,
                primary=self.NODE_A,
                maximum_size=int(DEFAULT_SIZE.to_Byte()),
                metadata={u"name": name})
            return self.assertItemsEqual(result, [expected_dataset])

        d = self.create(name)
        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(list)
        d.addCallback(check_datasets)
        return d

    def test_create_duplicate_name(self):
        """
        If a dataset with the given name already exists,
        ``/VolumeDriver.Create`` succeeds without creating a new volume.
        """
        name = u"thename"

        def assert_single_dataset(results):
            self.assertEqual(len(list(results)), 1)

        # Pre-create, out-of-band, a dataset with the same name (its dataset
        # ID is assigned by the client, so it will not match the plugin's).
        d = self.flocker_client.create_dataset(self.NODE_A,
                                               int(DEFAULT_SIZE.to_Byte()),
                                               metadata={u"name": name})
        d.addCallback(lambda _: self.create(name))
        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(assert_single_dataset)
        return d

    def test_create_duplicate_name_race_condition(self):
        """
        If a dataset with the given name is created while the
        ``/VolumeDriver.Create`` call is in flight, the call does not
        result in an error.
        """
        name = u"thename"

        # Replace list_datasets_configuration with a one-shot version that
        # creates a dataset with the same name out-of-band but reports an
        # empty configuration, so the docker plugin won't see that dataset.
        def create_after_list():
            # Remove the instance-level patch so later calls reach the real
            # method again:
            del self.flocker_client.list_datasets_configuration
            # On this first call, create the dataset but lie about its
            # existence by returning an empty configuration:
            d = self.flocker_client.create_dataset(self.NODE_A,
                                                   int(DEFAULT_SIZE.to_Byte()),
                                                   metadata={u"name": name})
            d.addCallback(
                lambda _: DatasetsConfiguration(tag=u"1234", datasets={}))
            return d

        self.flocker_client.list_datasets_configuration = create_after_list

        return self.create(name)

    def _flush_volume_plugin_reactor_on_endpoint_render(self):
        """
        Patch ``self.app.execute_endpoint`` so that, after each endpoint is
        rendered, the volume plugin's fake reactor is stepped forward to the
        earliest pending delayed call, repeatedly, until none remain.
        """
        original_execute = self.app.execute_endpoint

        def execute_and_drain(*args, **kwargs):
            result = original_execute(*args, **kwargs)
            while True:
                pending = self.volume_plugin_reactor.getDelayedCalls()
                if not pending:
                    break
                earliest = min(call.getTime() for call in pending)
                delta = earliest - self.volume_plugin_reactor.seconds()
                # Never move the clock backwards.
                self.volume_plugin_reactor.advance(max(0.0, delta))
            return result

        self.patch(self.app, 'execute_endpoint', execute_and_drain)

    def test_mount(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with the
        matching name to the current node, then waits for the dataset to
        actually arrive before responding.
        """
        name = u"myvol"
        dataset_id = uuid4()

        # The dataset starts out on a different node.
        d = self.flocker_client.create_dataset(self.NODE_B,
                                               int(DEFAULT_SIZE.to_Byte()),
                                               metadata={u"name": name},
                                               dataset_id=dataset_id)

        self._flush_volume_plugin_reactor_on_endpoint_render()

        # Pretend the dataset takes 5 seconds to get established on Node A.
        self.volume_plugin_reactor.callLater(
            5.0, self.flocker_client.synchronize_state)

        def mount(_):
            return self.assertResult(
                b"POST", b"/VolumeDriver.Mount", {u"Name": name}, OK, {
                    u"Err": None,
                    u"Mountpoint": u"/flocker/{}".format(dataset_id)
                })

        def final_assertions(datasets):
            primaries = [dataset.primary for dataset in datasets
                         if dataset.dataset_id == dataset_id]
            self.assertEqual([self.NODE_A], primaries)
            # There should be fewer than 20 calls to list_datasets_state
            # over the course of the 5 simulated seconds.
            self.assertLess(
                self.flocker_client.num_calls('list_datasets_state'), 20)

        d.addCallback(mount)
        d.addCallback(lambda _: self.flocker_client.list_datasets_state())
        d.addCallback(final_assertions)
        return d

    def test_mount_timeout(self):
        """
        ``/VolumeDriver.Mount`` makes the current node the primary of the
        matching dataset and waits for it to arrive. If it has not arrived
        within 120 seconds, an error is reported back to Docker.
        """
        name = u"myvol"
        dataset_id = uuid4()
        expected = {
            u"Err": u"Timed out waiting for dataset to mount.",
            u"Mountpoint": u"",
        }

        # The dataset starts out on a different node:
        d = self.flocker_client.create_dataset(
            self.NODE_B, int(DEFAULT_SIZE.to_Byte()),
            metadata={u"name": name}, dataset_id=dataset_id)

        self._flush_volume_plugin_reactor_on_endpoint_render()

        # The dataset only reaches Node A after 500 seconds, which is well
        # past the timeout.
        self.volume_plugin_reactor.callLater(
            500.0, self.flocker_client.synchronize_state)

        d.addCallback(lambda _: self.assertResult(
            b"POST", b"/VolumeDriver.Mount", {u"Name": name}, OK, expected))
        return d

    def test_mount_already_exists(self):
        """
        ``/VolumeDriver.Mount`` also works for volumes that already exist and
        were not given a plugin-chosen dataset ID: the current node becomes
        the primary and the call waits for the dataset to arrive.
        """
        name = u"myvol"

        def mount_created(dataset):
            self.flocker_client.synchronize_state()
            expected = {
                u"Err": None,
                u"Mountpoint": u"/flocker/{}".format(dataset.dataset_id),
            }
            mounted = self.assertResult(
                b"POST", b"/VolumeDriver.Mount", {u"Name": name}, OK, expected)

            def primaries(state):
                return [entry.primary for entry in state
                        if entry.dataset_id == dataset.dataset_id]

            mounted.addCallback(
                lambda _: self.flocker_client.list_datasets_state())
            mounted.addCallback(lambda state: self.assertEqual(
                [self.NODE_A], primaries(state)))
            return mounted

        d = self.flocker_client.create_dataset(
            self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={u"name": name})
        d.addCallback(mount_created)
        return d

    def test_unknown_mount(self):
        """
        Asking ``/VolumeDriver.Mount`` to mount a volume that does not exist
        yields an error.
        """
        request = {u"Name": u"myvol"}
        expected = {u"Err": u"Could not find volume with given name."}
        return self.assertResult(
            b"POST", b"/VolumeDriver.Mount", request, OK, expected)

    def test_path(self):
        """
        ``/VolumeDriver.Path`` returns the mount path of the given volume if
        it is currently known.
        """
        name = u"myvol"

        d = self.create(name)
        # The dataset arrives as state:
        d.addCallback(lambda _: self.flocker_client.synchronize_state())

        d.addCallback(lambda _: self.assertResponseCode(
            b"POST", b"/VolumeDriver.Mount", {u"Name": name}, OK))
        d.addCallback(
            lambda _: self.flocker_client.list_datasets_configuration())

        def check_path(datasets_config):
            # Only the dataset created above is expected in the
            # configuration; its key is the ID used in the mountpoint.
            # ``next(iter(...))`` replaces ``.keys()[0]``, which raises
            # TypeError on Python 3 because ``keys()`` returns a view that
            # cannot be indexed.
            dataset_id = next(iter(datasets_config.datasets))
            return self.assertResult(
                b"POST", b"/VolumeDriver.Path", {u"Name": name}, OK, {
                    u"Err": None,
                    u"Mountpoint": u"/flocker/{}".format(dataset_id),
                })

        d.addCallback(check_path)
        return d

    def test_path_existing(self):
        """
        ``/VolumeDriver.Path`` also reports the mount path of a known volume
        whose dataset was created outside the plugin.
        """
        name = u"myvol"

        def query_path(dataset):
            self.flocker_client.synchronize_state()
            mountpoint = u"/flocker/{}".format(dataset.dataset_id)
            return self.assertResult(
                b"POST", b"/VolumeDriver.Path", {u"Name": name}, OK,
                {u"Err": None, u"Mountpoint": mountpoint})

        d = self.flocker_client.create_dataset(
            self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={u"name": name})
        d.addCallback(query_path)
        return d

    def test_unknown_path(self):
        """
        Asking ``/VolumeDriver.Path`` for the mount path of a volume that does
        not exist yields an error.
        """
        request = {u"Name": u"myvol"}
        expected = {u"Err": u"Could not find volume with given name."}
        return self.assertResult(
            b"POST", b"/VolumeDriver.Path", request, OK, expected)

    def test_non_local_path(self):
        """
        ``/VolumeDriver.Path`` reports an error when the requested volume is
        not mounted locally.

        Docker can trigger this by running ``docker inspect`` on a container
        that has been created but is still waiting for its volume to arrive
        from another node. Docker also appears to call it right after
        ``/VolumeDriver.Create``, again while the volume is still in transit.
        """
        name = u"myvol"
        dataset_id = uuid4()
        expected = {
            u"Err": "Volume not available.",
            u"Mountpoint": u"",
        }

        # The dataset lives on node B...
        d = self.flocker_client.create_dataset(
            self.NODE_B, int(DEFAULT_SIZE.to_Byte()),
            metadata={u"name": name}, dataset_id=dataset_id)
        d.addCallback(lambda _: self.flocker_client.synchronize_state())
        # ...but the path is requested on node A.
        d.addCallback(lambda _: self.assertResult(
            b"POST", b"/VolumeDriver.Path", {u"Name": name}, OK, expected))
        return d

    @capture_logging(lambda self, logger: self.assertEqual(
        len(logger.flushTracebacks(CustomException)), 1))
    def test_unexpected_error_reporting(self, logger):
        """
        An unexpected exception while handling a request is turned into a
        readable error message for Docker, and its traceback is logged
        exactly once.
        """
        def broken_list_datasets_configuration():
            raise CustomException("I've made a terrible mistake")

        self.patch(self.flocker_client, "list_datasets_configuration",
                   broken_list_datasets_configuration)
        expected = {u"Err": "CustomException: I've made a terrible mistake"}
        return self.assertResult(
            b"POST", b"/VolumeDriver.Path", {u"Name": u"whatever"}, OK,
            expected)

    @capture_logging(None)
    def test_bad_request(self, logger):
        """
        A ``BadRequest`` raised while handling a request is rendered as a
        JSON response that carries its status code and ``Err`` body.
        """
        def reject():
            raise make_bad_request(code=423, Err=u"no good")

        self.patch(self.flocker_client, "list_datasets_configuration", reject)
        request = {u"Name": u"whatever"}
        return self.assertResult(
            b"POST", b"/VolumeDriver.Path", request, 423, {u"Err": "no good"})

    def test_unsupported_method(self):
        """
        Requesting a known path with a method the plugin does not support
        yields a 405 Not Allowed response code.
        """
        method, path = b"BAD_METHOD", b"/VolumeDriver.Path"
        return self.assertResponseCode(method, path, None, NOT_ALLOWED)

    def test_unknown_uri(self):
        """
        Requesting a path the plugin does not know about yields a 404 Not
        Found response code.
        """
        method, path = b"BAD_METHOD", b"/xxxnotthere"
        return self.assertResponseCode(method, path, None, NOT_FOUND)

    def test_empty_host(self):
        """
        An empty ``Host`` header does not break the Docker plugin. Requests
        are handled normally; for example, ``Plugin.Activate`` still returns
        its ``Implements`` response.
        """
        headers = {b"Host": [""]}
        expected = {u"Implements": [u"VolumeDriver"]}
        return self.assertResult(b"POST", b"/Plugin.Activate", 12345, OK,
                                 expected, additional_headers=headers)
class MemCacheTests(CommandMixin, TestCase):
    """
    Test client protocol class L{MemCacheProtocol}.
    """
    def setUp(self):
        """
        Create a memcache client, connect it to a string transport, and make
        it use a deterministic clock.
        """
        self.proto = MemCacheProtocol()
        self.clock = Clock()
        self.proto.callLater = self.clock.callLater
        self.transport = StringTransportWithDisconnection()
        self.transport.protocol = self.proto
        self.proto.makeConnection(self.transport)

    def _test(self, d, send, recv, result):
        """
        Implementation of C{_test} which checks that the command sends C{send}
        data, and that upon reception of C{recv} the result is C{result}.

        @param d: the resulting deferred from the memcache command.
        @type d: C{Deferred}

        @param send: the expected data to be sent.
        @type send: C{bytes}

        @param recv: the data to simulate as reception.
        @type recv: C{bytes}

        @param result: the expected result.
        @type result: C{any}
        """
        def cb(res):
            self.assertEqual(res, result)

        self.assertEqual(self.transport.value(), send)
        d.addCallback(cb)
        self.proto.dataReceived(recv)
        return d

    def test_invalidGetResponse(self):
        """
        If the value returned doesn't match the expected key of the current
        C{get} command, an error is raised in L{MemCacheProtocol.dataReceived}.
        """
        self.proto.get(b"foo")
        self.assertRaises(
            RuntimeError,
            self.proto.dataReceived,
            b"VALUE bar 0 7\r\nspamegg\r\nEND\r\n",
        )

    def test_invalidMultipleGetResponse(self):
        """
        If the value returned doesn't match one of the expected keys of the
        current multiple C{get} command, an error is raised in
        L{MemCacheProtocol.dataReceived}.
        """
        self.proto.getMultiple([b"foo", b"bar"])
        self.assertRaises(
            RuntimeError,
            self.proto.dataReceived,
            b"VALUE egg 0 7\r\nspamegg\r\nEND\r\n",
        )

    def test_invalidEndResponse(self):
        """
        If an END is received in response to an operation that isn't C{get},
        C{gets}, or C{stats}, an error is raised in
        L{MemCacheProtocol.dataReceived}.
        """
        self.proto.set(b"key", b"value")
        self.assertRaises(RuntimeError, self.proto.dataReceived, b"END\r\n")

    def test_timeOut(self):
        """
        Test the timeout on outgoing requests: when timeout is detected, all
        current commands fail with a L{TimeoutError}, and the connection is
        closed.
        """
        d1 = self.proto.get(b"foo")
        d2 = self.proto.get(b"bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)

        def checkMessage(error):
            self.assertEqual(str(error), "Connection timeout")

        d1.addCallback(checkMessage)
        self.assertFailure(d3, ConnectionDone)
        return gatherResults([d1, d2, d3])

    def test_timeoutRemoved(self):
        """
        When a request gets a response, no pending timeout call remains around.
        """
        d = self.proto.get(b"foo")

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived(b"VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, b"bar"))
            self.assertEqual(len(self.clock.calls), 0)

        d.addCallback(check)
        return d

    def test_timeOutRaw(self):
        """
        Test the timeout when raw mode was started: the timeout is not reset
        until all the data has been received, so we can have a L{TimeoutError}
        when waiting for raw data.
        """
        d1 = self.proto.get(b"foo")
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived(b"VALUE foo 0 10\r\n12345")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, ConnectionDone)
        return gatherResults([d1, d2])

    def test_timeOutStat(self):
        """
        Test the timeout when stat command has started: the timeout is not
        reset until the final B{END} is received.
        """
        d1 = self.proto.stats()
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived(b"STAT foo bar\r\n")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, ConnectionDone)
        return gatherResults([d1, d2])

    def test_timeoutPipelining(self):
        """
        When two requests are sent, a timeout call remains around for the
        second request, and its timeout time is correct.
        """
        d1 = self.proto.get(b"foo")
        d2 = self.proto.get(b"bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived(b"VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, b"bar"))
            self.assertEqual(len(self.clock.calls), 1)
            # Step one second at a time; the loop index itself is unused.
            for _ in range(self.proto.persistentTimeOut):
                self.clock.advance(1)
            return self.assertFailure(d2, TimeoutError).addCallback(checkTime)

        def checkTime(ignored):
            # Check that the timeout happened C{self.proto.persistentTimeOut}
            # after the last response
            self.assertEqual(self.clock.seconds(),
                             2 * self.proto.persistentTimeOut - 1)

        d1.addCallback(check)
        self.assertFailure(d3, ConnectionDone)
        return d1

    def test_timeoutNotReset(self):
        """
        Check that the timeout is not reset for every command, but keeps the
        timeout from the first command without a response.
        """
        d1 = self.proto.get(b"foo")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        d2 = self.proto.get(b"bar")
        self.clock.advance(1)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)
        self.assertFailure(d3, ConnectionDone)
        return gatherResults([d1, d2, d3])

    def test_timeoutCleanDeferreds(self):
        """
        C{timeoutConnection} cleans the list of commands that it fires with
        C{TimeoutError}: C{connectionLost} doesn't try to fire them again, but
        sets the disconnected state so that future commands fail with a
        C{RuntimeError}.
        """
        d1 = self.proto.get(b"foo")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        d2 = self.proto.get(b"bar")
        self.assertFailure(d2, RuntimeError)
        return gatherResults([d1, d2])

    def test_connectionLost(self):
        """
        When disconnection occurs while commands are still outstanding, the
        commands fail.
        """
        d1 = self.proto.get(b"foo")
        d2 = self.proto.get(b"bar")
        self.transport.loseConnection()
        done = DeferredList([d1, d2], consumeErrors=True)

        def checkFailures(results):
            for success, result in results:
                self.assertFalse(success)
                result.trap(ConnectionDone)

        return done.addCallback(checkFailures)

    def test_tooLongKey(self):
        """
        An error is raised when trying to use a too long key: the called
        command returns a L{Deferred} which fails with a L{ClientError}.
        """
        d1 = self.assertFailure(self.proto.set(b"a" * 500, b"bar"),
                                ClientError)
        d2 = self.assertFailure(self.proto.increment(b"a" * 500), ClientError)
        d3 = self.assertFailure(self.proto.get(b"a" * 500), ClientError)
        d4 = self.assertFailure(self.proto.append(b"a" * 500, b"bar"),
                                ClientError)
        d5 = self.assertFailure(self.proto.prepend(b"a" * 500, b"bar"),
                                ClientError)
        d6 = self.assertFailure(self.proto.getMultiple([b"foo", b"a" * 500]),
                                ClientError)
        return gatherResults([d1, d2, d3, d4, d5, d6])

    def test_invalidCommand(self):
        """
        When an unknown command is sent directly (not through public API), the
        server answers with an B{ERROR} token, and the command fails with
        L{NoSuchCommand}.
        """
        d = self.proto._set(b"egg", b"foo", b"bar", 0, 0, b"")
        self.assertEqual(self.transport.value(), b"egg foo 0 0 3\r\nbar\r\n")
        self.assertFailure(d, NoSuchCommand)
        self.proto.dataReceived(b"ERROR\r\n")
        return d

    def test_clientError(self):
        """
        Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
        token, the originating command fails with L{ClientError}, and the error
        contains the text sent by the server.
        """
        a = b"eggspamm"
        d = self.proto.set(b"foo", a)
        self.assertEqual(self.transport.value(),
                         b"set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ClientError)

        def check(err):
            self.assertEqual(str(err), repr(b"We don't like egg and spam"))

        d.addCallback(check)
        self.proto.dataReceived(b"CLIENT_ERROR We don't like egg and spam\r\n")
        return d

    def test_serverError(self):
        """
        Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
        token, the originating command fails with L{ServerError}, and the error
        contains the text sent by the server.
        """
        a = b"eggspamm"
        d = self.proto.set(b"foo", a)
        self.assertEqual(self.transport.value(),
                         b"set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ServerError)

        def check(err):
            self.assertEqual(str(err), repr(b"zomg"))

        d.addCallback(check)
        self.proto.dataReceived(b"SERVER_ERROR zomg\r\n")
        return d

    def test_unicodeKey(self):
        """
        Using a key that is not C{bytes} (a text string or an integer) as an
        argument to commands makes them fail with L{ClientError}.
        """
        d1 = self.assertFailure(self.proto.set("foo", b"bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment("egg"), ClientError)
        d3 = self.assertFailure(self.proto.get(1), ClientError)
        d4 = self.assertFailure(self.proto.delete("bar"), ClientError)
        d5 = self.assertFailure(self.proto.append("foo", b"bar"), ClientError)
        d6 = self.assertFailure(self.proto.prepend("foo", b"bar"), ClientError)
        d7 = self.assertFailure(self.proto.getMultiple([b"egg", 1]),
                                ClientError)
        return gatherResults([d1, d2, d3, d4, d5, d6, d7])

    def test_unicodeValue(self):
        """
        Using a value that is not C{bytes} makes C{set} fail with
        L{ClientError}.
        """
        return self.assertFailure(self.proto.set(b"foo", "bar"), ClientError)

    def test_pipelining(self):
        """
        Multiple requests can be sent subsequently to the server, and the
        protocol orders the responses correctly and dispatches them to the
        corresponding client command.
        """
        d1 = self.proto.get(b"foo")
        d1.addCallback(self.assertEqual, (0, b"bar"))
        d2 = self.proto.set(b"bar", b"spamspamspam")
        d2.addCallback(self.assertEqual, True)
        d3 = self.proto.get(b"egg")
        d3.addCallback(self.assertEqual, (0, b"spam"))
        self.assertEqual(
            self.transport.value(),
            b"get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n",
        )
        self.proto.dataReceived(b"VALUE foo 0 3\r\nbar\r\nEND\r\n"
                                b"STORED\r\n"
                                b"VALUE egg 0 4\r\nspam\r\nEND\r\n")
        return gatherResults([d1, d2, d3])

    def test_getInChunks(self):
        """
        If the value retrieved by a C{get} arrives in chunks, the protocol
        is able to reconstruct it and to produce the right value.
        """
        d = self.proto.get(b"foo")
        d.addCallback(self.assertEqual, (0, b"0123456789"))
        self.assertEqual(self.transport.value(), b"get foo\r\n")
        self.proto.dataReceived(b"VALUE foo 0 10\r\n0123456")
        self.proto.dataReceived(b"789")
        self.proto.dataReceived(b"\r\nEND")
        self.proto.dataReceived(b"\r\n")
        return d

    def test_append(self):
        """
        L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
        method: it returns a L{Deferred} which is called back with C{True} when
        the operation succeeds.
        """
        return self._test(
            self.proto.append(b"foo", b"bar"),
            b"append foo 0 0 3\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_prepend(self):
        """
        L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
        method: it returns a L{Deferred} which is called back with C{True} when
        the operation succeeds.
        """
        return self._test(
            self.proto.prepend(b"foo", b"bar"),
            b"prepend foo 0 0 3\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_gets(self):
        """
        L{MemCacheProtocol.get} handles an additional cas result when
        C{withIdentifier} is C{True} and forwards it in the resulting
        L{Deferred}.
        """
        return self._test(
            self.proto.get(b"foo", True),
            b"gets foo\r\n",
            b"VALUE foo 0 3 1234\r\nbar\r\nEND\r\n",
            (0, b"1234", b"bar"),
        )

    def test_emptyGets(self):
        """
        Test getting a non-available key with gets: it succeeds but returns
        L{None} as value, C{0} as flag and an empty cas value.
        """
        return self._test(self.proto.get(b"foo", True), b"gets foo\r\n",
                          b"END\r\n", (0, b"", None))

    def test_getsMultiple(self):
        """
        L{MemCacheProtocol.getMultiple} handles an additional cas field in the
        returned tuples if C{withIdentifier} is C{True}.
        """
        return self._test(
            self.proto.getMultiple([b"foo", b"bar"], True),
            b"gets foo bar\r\n",
            b"VALUE foo 0 3 1234\r\negg\r\n"
            b"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
            {
                b"bar": (0, b"2345", b"spam"),
                b"foo": (0, b"1234", b"egg")
            },
        )

    def test_getsMultipleIterableKeys(self):
        """
        L{MemCacheProtocol.getMultiple} accepts any iterable of keys.
        """
        return self._test(
            self.proto.getMultiple(iter([b"foo", b"bar"]), True),
            b"gets foo bar\r\n",
            b"VALUE foo 0 3 1234\r\negg\r\n"
            b"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
            {
                b"bar": (0, b"2345", b"spam"),
                b"foo": (0, b"1234", b"egg")
            },
        )

    def test_getsMultipleWithEmpty(self):
        """
        When getting a non-available key with L{MemCacheProtocol.getMultiple}
        when C{withIdentifier} is C{True}, the other keys are retrieved
        correctly, and the non-available key gets a tuple of C{0} as flag,
        L{None} as value, and an empty cas value.
        """
        return self._test(
            self.proto.getMultiple([b"foo", b"bar"], True),
            b"gets foo bar\r\n",
            b"VALUE foo 0 3 1234\r\negg\r\nEND\r\n",
            {
                b"bar": (0, b"", None),
                b"foo": (0, b"1234", b"egg")
            },
        )

    def test_checkAndSet(self):
        """
        L{MemCacheProtocol.checkAndSet} passes an additional cas identifier
        that the server handles to check if the data has to be updated.
        """
        return self._test(
            self.proto.checkAndSet(b"foo", b"bar", cas=b"1234"),
            b"cas foo 0 0 3 1234\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_casUnknowKey(self):
        """
        When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the
        resulting L{Deferred} fires with C{False}.
        """
        return self._test(
            self.proto.checkAndSet(b"foo", b"bar", cas=b"1234"),
            b"cas foo 0 0 3 1234\r\nbar\r\n",
            b"EXISTS\r\n",
            False,
        )
class SchedulerStartupTests(TestCase):
    """
    Tests covering how a L{Scheduler} behaves around startup and shutdown
    of the store's service.
    """
    def setUp(self):
        self.clock = Clock()
        self.store = Store()

    def tearDown(self):
        return self.stopStoreService()

    def now(self):
        """
        Return the fake clock's current time as a L{Time}.
        """
        seconds = self.clock.seconds()
        return Time.fromPOSIXTimestamp(seconds)

    def time(self, offset):
        """
        Return the L{Time} that is C{offset} seconds away from L{now}.
        """
        current = self.now()
        return current + timedelta(seconds=offset)

    def makeScheduler(self):
        """
        Adapt the store to L{IScheduler} and point the resulting scheduler at
        the fake clock, both for scheduling calls and for reading the time.
        """
        scheduler = IScheduler(self.store)
        scheduler.callLater, scheduler.now = self.clock.callLater, self.now
        return scheduler

    def _scheduleEvent(self, scheduler, offset):
        """
        Schedule a fresh L{TestEvent} on C{scheduler}, C{offset} seconds from
        now.
        """
        event = TestEvent(store=self.store)
        scheduler.schedule(event, self.time(offset))

    def startStoreService(self):
        """
        Start the store's L{IService}.
        """
        IService(self.store).startService()

    def stopStoreService(self):
        """
        Stop the store's L{IService} if it is running and return the result
        of C{stopService}; return C{None} when it is not running.
        """
        service = IService(self.store)
        if not service.running:
            return None
        return service.stopService()

    def test_schedulerStartsWhenServiceStarts(self):
        """
        Starting the store's L{IService} also starts the child service
        registered under C{SITE_SCHEDULER}.
        """
        service = IService(self.store)
        service.startService()
        self.assertTrue(service.getServiceNamed(SITE_SCHEDULER).running)

    def test_scheduleWhileStopped(self):
        """
        Scheduling on a L{Scheduler} that has not been started does not put a
        transient timed event on the clock.
        """
        self._scheduleEvent(self.makeScheduler(), 1)
        self.assertEqual(self.clock.calls, [])

    def test_scheduleWithRunningService(self):
        """
        A scheduler created and installed on a store whose service is already
        running puts a transient timed event on the clock as soon as it is
        used.
        """
        self.startStoreService()
        self._scheduleEvent(self.makeScheduler(), 1)
        self.assertEqual(len(self.clock.calls), 1)

    def test_schedulerStartedWithPastEvent(self):
        """
        A scheduler holding a TimedEvent in the past gets a timed call as soon
        as the store's service starts, without running the event
        synchronously.
        """
        self._scheduleEvent(self.makeScheduler(), -1)
        self.assertEqual(self.clock.calls, [])
        self.startStoreService()
        self.assertEqual(len(self.clock.calls), 1)

    def test_schedulerStartedWithFutureEvent(self):
        """
        A scheduler holding a TimedEvent in the future gets a timed call as
        soon as the store's service starts.
        """
        self._scheduleEvent(self.makeScheduler(), 1)
        self.assertEqual(self.clock.calls, [])
        self.startStoreService()
        self.assertEqual(len(self.clock.calls), 1)

    def test_schedulerStopped(self):
        """
        Stopping the store's service cleans up the scheduler's transient
        timed event.
        """
        # Reuse the running-service scenario to get one pending timed call.
        self.test_scheduleWithRunningService()

        def cbStopped(ignored):
            self.assertEqual(self.clock.calls, [])

        return self.stopStoreService().addCallback(cbStopped)
Example #26
0
class LeasesTests(AsyncTestCase):
    """
    Tests for ``LeaseService`` and ``update_leases``.
    """
    def setUp(self):
        super(LeasesTests, self).setUp()
        self.clock = Clock()
        persistence_path = FilePath(self.mktemp())
        self.persistence_service = ConfigurationPersistenceService(
            self.clock, persistence_path)
        self.persistence_service.startService()
        self.addCleanup(self.persistence_service.stopService)

    def test_update_leases_saves_changed_leases(self):
        """
        ``update_leases`` persists the updated leases and leaves the rest of
        the stored configuration untouched.
        """
        node_id = uuid4()
        dataset_id = uuid4()
        original_leases = Leases().acquire(
            datetime.fromtimestamp(0, UTC), uuid4(), node_id)

        def update(leases):
            return leases.acquire(
                datetime.fromtimestamp(1000, UTC), dataset_id, node_id)

        def check_saved(_):
            actual = self.persistence_service.get()
            expected = LATEST_TEST_DEPLOYMENT.set(
                leases=update(original_leases))
            self.assertEqual(actual, expected)

        d = self.persistence_service.save(
            LATEST_TEST_DEPLOYMENT.set(leases=original_leases))
        d.addCallback(
            lambda _: update_leases(update, self.persistence_service))
        d.addCallback(check_saved)
        return d

    def test_update_leases_result(self):
        """
        The ``Deferred`` returned by ``update_leases`` fires with the updated
        ``Leases`` instance.
        """
        node_id = uuid4()
        dataset_id = uuid4()
        empty_leases = Leases()

        def update(leases):
            return leases.acquire(
                datetime.fromtimestamp(1000, UTC), dataset_id, node_id)

        d = update_leases(update, self.persistence_service)
        d.addCallback(lambda updated_leases: self.assertEqual(
            updated_leases, update(empty_leases)))
        return d

    def test_expired_lease_removed(self):
        """
        Once a lease expires it is removed from the persisted configuration,
        while leases that have not yet expired are kept.
        """
        timestep = 100
        node_id = uuid4()
        first_id, second_id = uuid4(), uuid4()
        start = datetime.fromtimestamp(self.clock.seconds(), UTC)
        # The first lease expires after ``timestep`` seconds, the second
        # after twice that.
        leases = Leases().acquire(start, first_id, node_id, timestep)
        leases = leases.acquire(start, second_id, node_id, timestep * 2)
        d = self.persistence_service.save(Deployment(leases=leases))

        def check_expiry(_):
            observed = []
            # Probe just before/after each expiry: t=99, 101, 199, 201.
            for step in (timestep - 1, 2, timestep - 2, 2):
                self.clock.advance(step)
                observed.append(self.persistence_service.get().leases)
            without_first = leases.remove(first_id)
            self.assertTupleEqual(
                tuple(observed),
                (leases, without_first, without_first,
                 without_first.remove(second_id)))

        d.addCallback(check_expiry)
        return d

    @capture_logging(None)
    def test_expire_lease_logging(self, logger):
        """
        Expiry of a lease emits a ``_LOG_EXPIRE`` message naming the dataset
        and the node.
        """
        node_id = uuid4()
        dataset_id = uuid4()
        now = datetime.fromtimestamp(self.clock.seconds(), UTC)
        leases = Leases().acquire(now, dataset_id, node_id, 1)

        def advance_past_expiry(_):
            # Discard anything logged so far so only expiry is inspected.
            logger.reset()
            self.clock.advance(1000)
            assertHasMessage(self, logger, _LOG_EXPIRE, {
                u"dataset_id": dataset_id, u"node_id": node_id})

        d = self.persistence_service.save(Deployment(leases=leases))
        d.addCallback(advance_past_expiry)
        return d
class MemCacheTestCase(TestCase):
    """
    Tests for the client protocol class L{MemCacheProtocol}.
    """

    def setUp(self):
        """
        Create a memcache client, connect it to a string protocol, and make it
        use a deterministic clock.
        """
        self.proto = MemCacheProtocol()
        self.clock = Clock()
        self.proto.callLater = self.clock.callLater
        self.transport = StringTransportWithDisconnection()
        self.transport.protocol = self.proto
        self.proto.makeConnection(self.transport)


    def _test(self, d, send, recv, result):
        """
        Shortcut method for classic tests.

        Checks that C{send} has already been written to the transport, feeds
        C{recv} to the protocol, and checks that C{d} fires with C{result}.

        @param d: the resulting deferred from the memcache command.
        @type d: C{Deferred}

        @param send: the expected data to be sent.
        @type send: C{str}

        @param recv: the data to simulate as reception.
        @type recv: C{str}

        @param result: the expected result.
        @type result: C{any}

        @return: C{d}, with the result check attached as a callback.
        """
        def cb(res):
            self.assertEqual(res, result)
        self.assertEqual(self.transport.value(), send)
        d.addCallback(cb)
        self.proto.dataReceived(recv)
        return d


    def test_get(self):
        """
        L{MemCacheProtocol.get} should return a L{Deferred} which is
        called back with the value and the flag associated with the given key
        if the server returns a successful result.
        """
        return self._test(self.proto.get("foo"), "get foo\r\n",
            "VALUE foo 0 3\r\nbar\r\nEND\r\n", (0, "bar"))


    def test_emptyGet(self):
        """
        Test getting a non-available key: it should succeed but return C{None}
        as value and C{0} as flag.
        """
        return self._test(self.proto.get("foo"), "get foo\r\n",
            "END\r\n", (0, None))


    def test_set(self):
        """
        L{MemCacheProtocol.set} should return a L{Deferred} which is
        called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.set("foo", "bar"),
            "set foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)


    def test_add(self):
        """
        L{MemCacheProtocol.add} should return a L{Deferred} which is
        called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.add("foo", "bar"),
            "add foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)


    def test_replace(self):
        """
        L{MemCacheProtocol.replace} should return a L{Deferred} which
        is called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.replace("foo", "bar"),
            "replace foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)


    def test_errorAdd(self):
        """
        Test an erroneous add: if a L{MemCacheProtocol.add} is called but the
        key already exists on the server, it returns a B{NOT STORED} answer,
        which should call back the resulting L{Deferred} with C{False}.
        """
        return self._test(self.proto.add("foo", "bar"),
            "add foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)


    def test_errorReplace(self):
        """
        Test an erroneous replace: if a L{MemCacheProtocol.replace} is called
        but the key doesn't exist on the server, it returns a B{NOT STORED}
        answer, which should call back the resulting L{Deferred} with
        C{False}.
        """
        return self._test(self.proto.replace("foo", "bar"),
            "replace foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)


    def test_delete(self):
        """
        L{MemCacheProtocol.delete} should return a L{Deferred} which is
        called back with C{True} when the server notifies a success.
        """
        return self._test(self.proto.delete("bar"), "delete bar\r\n",
            "DELETED\r\n", True)


    def test_errorDelete(self):
        """
        Test an error during a delete: if the key doesn't exist on the server,
        it returns a B{NOT FOUND} answer which should call back the resulting
        L{Deferred} with C{False}.
        """
        return self._test(self.proto.delete("bar"), "delete bar\r\n",
            "NOT FOUND\r\n", False)


    def test_increment(self):
        """
        Test incrementing a variable: L{MemCacheProtocol.increment} should
        return a L{Deferred} which is called back with the incremented value of
        the given key.
        """
        return self._test(self.proto.increment("foo"), "incr foo 1\r\n",
            "4\r\n", 4)


    def test_decrement(self):
        """
        Test decrementing a variable: L{MemCacheProtocol.decrement} should
        return a L{Deferred} which is called back with the decremented value of
        the given key.
        """
        return self._test(
            self.proto.decrement("foo"), "decr foo 1\r\n", "5\r\n", 5)


    def test_incrementVal(self):
        """
        L{MemCacheProtocol.increment} takes an optional argument C{value} which
        should replace the default value of 1 when specified.
        """
        return self._test(self.proto.increment("foo", 8), "incr foo 8\r\n",
            "4\r\n", 4)


    def test_decrementVal(self):
        """
        L{MemCacheProtocol.decrement} takes an optional argument C{value} which
        should replace the default value of 1 when specified.
        """
        return self._test(self.proto.decrement("foo", 3), "decr foo 3\r\n",
            "5\r\n", 5)


    def test_stats(self):
        """
        Test retrieving server statistics via the L{MemCacheProtocol.stats}
        command: it should parse the data sent by the server and call back the
        resulting L{Deferred} with a dictionary of the received statistics.
        """
        return self._test(self.proto.stats(), "stats\r\n",
            "STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
            {"foo": "bar", "egg": "spam"})


    def test_statsWithArgument(self):
        """
        L{MemCacheProtocol.stats} takes an optional C{str} argument which,
        if specified, is sent along with the I{STAT} command.  The I{STAT}
        responses from the server are parsed as key/value pairs and returned
        as a C{dict} (as in the case where the argument is not specified).
        """
        return self._test(self.proto.stats("blah"), "stats blah\r\n",
            "STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
            {"foo": "bar", "egg": "spam"})


    def test_version(self):
        """
        Test version retrieval via the L{MemCacheProtocol.version} command: it
        should return a L{Deferred} which is called back with the version sent
        by the server.
        """
        return self._test(self.proto.version(), "version\r\n",
            "VERSION 1.1\r\n", "1.1")


    def test_flushAll(self):
        """
        L{MemCacheProtocol.flushAll} should return a L{Deferred} which is
        called back with C{True} if the server acknowledges success.
        """
        return self._test(self.proto.flushAll(), "flush_all\r\n",
            "OK\r\n", True)


    def test_invalidGetResponse(self):
        """
        If the key of a returned value doesn't match the key requested by the
        current command, L{MemCacheProtocol.dataReceived} should raise a
        L{RuntimeError}.
        """
        self.proto.get("foo")
        s = "spamegg"
        self.assertRaises(RuntimeError,
            self.proto.dataReceived,
            "VALUE bar 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))


    def test_timeOut(self):
        """
        Test the timeout on outgoing requests: when timeout is detected, all
        current commands should fail with a L{TimeoutError}, and the
        connection should be closed.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)
        def checkMessage(error):
            self.assertEqual(str(error), "Connection timeout")
        d1.addCallback(checkMessage)
        return gatherResults([d1, d2, d3])


    def test_timeoutRemoved(self):
        """
        When a request gets a response, no pending timeout call should remain
        around.
        """
        d = self.proto.get("foo")

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, "bar"))
            self.assertEqual(len(self.clock.calls), 0)
        d.addCallback(check)
        return d


    def test_timeOutRaw(self):
        """
        Test the timeout when raw mode was started: the timeout should not be
        reset until all the data has been received, so we can have a
        L{TimeoutError} when waiting for raw data.
        """
        d1 = self.proto.get("foo")
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived("VALUE foo 0 10\r\n12345")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        return gatherResults([d1, d2])


    def test_timeOutStat(self):
        """
        Test the timeout when stat command has started: the timeout should not
        be reset until the final B{END} is received.
        """
        d1 = self.proto.stats()
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived("STAT foo bar\r\n")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        return gatherResults([d1, d2])


    def test_timeoutPipelining(self):
        """
        When two requests are sent, a timeout call should remain around for the
        second request, and its timeout time should be correct.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, "bar"))
            self.assertEqual(len(self.clock.calls), 1)
            # Advance one second at a time rather than all at once.
            for _ in range(self.proto.persistentTimeOut):
                self.clock.advance(1)
            return self.assertFailure(d2, TimeoutError).addCallback(checkTime)
        def checkTime(ignored):
            # Check that the timeout happened C{self.proto.persistentTimeOut}
            # after the last response
            self.assertEqual(self.clock.seconds(),
                    2 * self.proto.persistentTimeOut - 1)
        d1.addCallback(check)
        return d1


    def test_timeoutNotReset(self):
        """
        Check that the timeout is not reset for every command, but keeps the
        timeout from the first command without a response.
        """
        d1 = self.proto.get("foo")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        d2 = self.proto.get("bar")
        self.clock.advance(1)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)
        return gatherResults([d1, d2, d3])


    def test_tooLongKey(self):
        """
        Test that an error is raised when trying to use a too long key: the
        called command should return a L{Deferred} which fails with a
        L{ClientError}.
        """
        d1 = self.assertFailure(self.proto.set("a" * 500, "bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment("a" * 500), ClientError)
        d3 = self.assertFailure(self.proto.get("a" * 500), ClientError)
        d4 = self.assertFailure(self.proto.append("a" * 500, "bar"), ClientError)
        d5 = self.assertFailure(self.proto.prepend("a" * 500, "bar"), ClientError)
        return gatherResults([d1, d2, d3, d4, d5])


    def test_invalidCommand(self):
        """
        When an unknown command is sent directly (not through public API), the
        server answers with an B{ERROR} token, and the command should fail with
        L{NoSuchCommand}.
        """
        d = self.proto._set("egg", "foo", "bar", 0, 0, "")
        self.assertEqual(self.transport.value(), "egg foo 0 0 3\r\nbar\r\n")
        self.assertFailure(d, NoSuchCommand)
        self.proto.dataReceived("ERROR\r\n")
        return d


    def test_clientError(self):
        """
        Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
        token, the originating command should fail with L{ClientError}, and the
        error should contain the text sent by the server.
        """
        a = "eggspamm"
        d = self.proto.set("foo", a)
        self.assertEqual(self.transport.value(),
                         "set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ClientError)
        def check(err):
            self.assertEqual(str(err), "We don't like egg and spam")
        d.addCallback(check)
        self.proto.dataReceived("CLIENT_ERROR We don't like egg and spam\r\n")
        return d


    def test_serverError(self):
        """
        Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
        token, the originating command should fail with L{ServerError}, and the
        error should contain the text sent by the server.
        """
        a = "eggspamm"
        d = self.proto.set("foo", a)
        self.assertEqual(self.transport.value(),
                         "set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ServerError)
        def check(err):
            self.assertEqual(str(err), "zomg")
        d.addCallback(check)
        self.proto.dataReceived("SERVER_ERROR zomg\r\n")
        return d


    def test_unicodeKey(self):
        """
        Using a non-string key as argument to commands should make them fail
        with L{ClientError}.
        """
        d1 = self.assertFailure(self.proto.set(u"foo", "bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment(u"egg"), ClientError)
        d3 = self.assertFailure(self.proto.get(1), ClientError)
        d4 = self.assertFailure(self.proto.delete(u"bar"), ClientError)
        d5 = self.assertFailure(self.proto.append(u"foo", "bar"), ClientError)
        d6 = self.assertFailure(self.proto.prepend(u"foo", "bar"), ClientError)
        return gatherResults([d1, d2, d3, d4, d5, d6])


    def test_unicodeValue(self):
        """
        Using a non-string value should make the command fail with
        L{ClientError}.
        """
        return self.assertFailure(self.proto.set("foo", u"bar"), ClientError)


    def test_pipelining(self):
        """
        Test that multiple requests can be sent subsequently to the server, and
        that the protocol orders the responses correctly and dispatches them to
        the corresponding client command.
        """
        d1 = self.proto.get("foo")
        d1.addCallback(self.assertEqual, (0, "bar"))
        d2 = self.proto.set("bar", "spamspamspam")
        d2.addCallback(self.assertEqual, True)
        d3 = self.proto.get("egg")
        d3.addCallback(self.assertEqual, (0, "spam"))
        self.assertEqual(self.transport.value(),
            "get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n")
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n"
                                "STORED\r\n"
                                "VALUE egg 0 4\r\nspam\r\nEND\r\n")
        return gatherResults([d1, d2, d3])


    def test_getInChunks(self):
        """
        If the value retrieved by a C{get} arrives in chunks, the protocol
        should be able to reconstruct it and to produce the right value.
        """
        d = self.proto.get("foo")
        d.addCallback(self.assertEqual, (0, "0123456789"))
        self.assertEqual(self.transport.value(), "get foo\r\n")
        self.proto.dataReceived("VALUE foo 0 10\r\n0123456")
        self.proto.dataReceived("789")
        self.proto.dataReceived("\r\nEND")
        self.proto.dataReceived("\r\n")
        return d


    def test_append(self):
        """
        L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
        method: it should return a L{Deferred} which is called back with
        C{True} when the operation succeeds.
        """
        return self._test(self.proto.append("foo", "bar"),
            "append foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)


    def test_prepend(self):
        """
        L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
        method: it should return a L{Deferred} which is called back with
        C{True} when the operation succeeds.
        """
        return self._test(self.proto.prepend("foo", "bar"),
            "prepend foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)


    def test_gets(self):
        """
        L{MemCacheProtocol.get} should handle an additional cas result when
        C{withIdentifier} is C{True} and forward it in the resulting
        L{Deferred}.
        """
        return self._test(self.proto.get("foo", True), "gets foo\r\n",
            "VALUE foo 0 3 1234\r\nbar\r\nEND\r\n", (0, "1234", "bar"))


    def test_emptyGets(self):
        """
        Test getting a non-available key with gets: it should succeed but
        return C{None} as value, C{0} as flag and an empty cas value.
        """
        return self._test(self.proto.get("foo", True), "gets foo\r\n",
            "END\r\n", (0, "", None))


    def test_checkAndSet(self):
        """
        L{MemCacheProtocol.checkAndSet} passes an additional cas identifier
        that the server should handle to check if the data has to be updated.
        """
        return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
            "cas foo 0 0 3 1234\r\nbar\r\n", "STORED\r\n", True)


    def test_casUnknowKey(self):
        """
        When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the
        resulting L{Deferred} should fire with C{False}.
        """
        return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
            "cas foo 0 0 3 1234\r\nbar\r\n", "EXISTS\r\n", False)
Exemple #28
0
 def patch_message_rate_clock(self):
     """Replace MessageRateStore's time source with a fake clock.

     ``MessageRateStore.get_seconds`` is patched (via ``self.patch``) with a
     function that ignores its single argument (presumably the store
     instance) and reports the fake clock's current time.

     :return: the ``Clock`` now backing ``get_seconds``; advance it to move
         the time the store sees.
     """
     fake_clock = Clock()

     def _fake_get_seconds(_ignored):
         return fake_clock.seconds()

     self.patch(MessageRateStore, 'get_seconds', _fake_get_seconds)
     return fake_clock
class TransportTestCase(object):
    """PT client and server connect over a string transport.

    We bypass the communication between client and server and intercept the
    messages sent over the string transport.

    Subclasses are expected to provide ``self.transport`` (the transport name
    handed to ``get_transport_class``) and ``self.args`` (extra transport
    command-line arguments); this class reads both but never sets them.
    """
    def setUp(self):
        """Set the reactor's callLater to our clock's callLater function
        and build the protocols.

        The reactor's original ``callLater`` is remembered so that
        ``tearDown`` can put it back. If building the protocols fails here,
        it is restored immediately, because test frameworks do not call
        ``tearDown`` after a failing ``setUp``.
        """
        self.clock = Clock()
        self._original_call_later = reactor.callLater
        reactor.callLater = self.clock.callLater
        try:
            self.dump = []
            self.proto_client = self._build_protocol(const.CLIENT)
            self.proto_server = self._build_protocol(const.SERVER)
            self.pt_client = self.proto_client.circuit.transport
            self.pt_server = self.proto_server.circuit.transport
            self._proxy(self.proto_client, self.proto_server)
            self._bypass_connection(self.proto_client, self.proto_server)
        except Exception:
            reactor.callLater = self._original_call_later
            raise

    def _proxy(self, client, server):
        """Proxy the communication between client and server and dump
        intercepted data into a dictionary.

        Each end's ``receivedDownstream`` / ``sendDownstream`` is wrapped so
        that every call records ``(clock time, return value)`` in
        ``end.history['rcv']`` / ``end.history['snd']``. The wrapped
        method's return value is passed through unchanged.
        """
        def decorate_intercept(end):
            old_rcv_f = end.circuit.transport.receivedDownstream
            old_snd_f = end.circuit.transport.sendDownstream

            def intercept(old_f, direction):
                def new_f(data):
                    msgs = old_f(data)
                    end.history[direction].append((self.clock.seconds(), msgs))
                    # Stay transparent to callers of the wrapped method.
                    return msgs

                return new_f

            end.circuit.transport.receivedDownstream = intercept(
                old_rcv_f, 'rcv')
            end.circuit.transport.sendDownstream = intercept(old_snd_f, 'snd')

        decorate_intercept(client)
        decorate_intercept(server)

    def _bypass_connection(self, client, server):
        """Instead of requiring TCP connections between client and server
        transports, we directly pass the data written from one end to the
        received function at the other.

        Every write is also logged in ``self.dump`` as
        ``(clock time, direction * len(data))``.
        """
        def curry_bypass_connection(up, down, direction):
            old_write = up.circuit.downstream.write

            def write(data):
                old_write(data)
                down.dataReceived(data)
                self.dump.append((self.clock.seconds(), direction * len(data)))

            return write

        client.circuit.downstream.write = curry_bypass_connection(
            client, server, const.OUT)
        server.circuit.downstream.write = curry_bypass_connection(
            server, client, const.IN)

    def _build_protocol(self, mode):
        """Build client and server protocols for an end point.

        Returns the client-side protocol when ``mode`` is ``const.CLIENT``
        and the server-side one when it is ``const.SERVER``.

        :raises ValueError: if ``mode`` is neither.
        """
        addr_tuple = (HOST, str(PORT))
        address = IPv4Address('TCP', HOST, PORT)
        pt_config = self._build_transport_configuration(mode)
        transport_class = self._configure_transport_class(mode, pt_config)
        f_server = net.StaticDestinationServerFactory(addr_tuple, mode,
                                                      transport_class,
                                                      pt_config)
        protocol_server = self._set_protocol(f_server, address)
        f_client = net.StaticDestinationClientFactory(protocol_server.circuit,
                                                      const.CLIENT)
        protocol_client = self._set_protocol(f_client, address)
        if mode == const.CLIENT:
            return protocol_client
        elif mode == const.SERVER:
            return protocol_server
        else:
            raise ValueError("Transport mode '%s' not recognized." % mode)

    def _set_protocol(self, factory, address):
        """Make protocol connection with a Twisted string transport.

        The returned protocol also gets an empty ``history`` dict that
        ``_proxy`` fills in.
        """
        protocol = factory.buildProtocol(address)
        protocol.makeConnection(proto_helpers.StringTransport())
        protocol.history = {'rcv': [], 'snd': []}
        return protocol

    def _build_transport_configuration(self, mode):
        """Configure transport as a managed transport."""
        pt_config = transport_config.TransportConfig()
        pt_config.setStateLocation(const.TEMP_DIR)
        pt_config.setObfsproxyMode("managed")
        pt_config.setListenerMode(mode)
        return pt_config

    def _configure_transport_class(self, mode, pt_config):
        """Use the global arguments to configure the transport.

        ``sys.argv`` is temporarily replaced with an obfsproxy command line
        (so ``parser.parse_args()`` can read it) and restored before
        returning, so the change does not leak into the rest of the process.
        """
        transport_args = [mode, ADDR, "--dest=%s" % ADDR] + self.args
        saved_argv = sys.argv
        sys.argv = [
            sys.argv[0], "--log-file",
            join(const.TEMP_DIR, "%s.log" % mode), "--log-min-severity",
            "debug"
        ]
        sys.argv.append("wfpad")  # use wfpad transport
        sys.argv += transport_args
        try:
            parser = set_up_cli_parsing()
            consider_cli_args(parser.parse_args())
            transport_class = get_transport_class(self.transport, mode)
            transport_class.setup(pt_config)
            p = ArgumentParser()
            transport_class.register_external_mode_cli(p)
            args = p.parse_args(transport_args)
            transport_class.validate_external_mode_cli(args)
        finally:
            sys.argv = saved_argv
        return transport_class

    def _lose_protocol_connection(self, protocol):
        """Disconnect client and server transports."""
        protocol.circuit.upstream.transport.loseConnection()
        protocol.circuit.downstream.transport.loseConnection()

    def advance_next_delayed_call(self):
        """Advance clock to first delayed call in reactor."""
        first_delayed_call = self.clock.getDelayedCalls()[0]
        self.clock.advance(first_delayed_call.getTime() - self.clock.seconds())

    def is_timeout(self, call):
        """Check if the call has actually timed out.

        True when the delayed call's first positional argument is a
        ``TimeoutError`` instance.
        """
        return isinstance(call.args[0], TimeoutError)

    def advance_delayed_calls(self, max_dcalls=NUM_DCALLS, no_timeout=True):
        """Advance clock to the point all delayed calls up to that moment have
        been called.

        At most ``max_dcalls`` iterations are performed. With ``no_timeout``
        set, delayed calls carrying a ``TimeoutError`` are moved to the back
        of the queue instead of fired; the loop stops once such a call comes
        round a second time.
        """
        i, timeouts = 0, []
        while len(self.clock.getDelayedCalls()) > 0 and i < max_dcalls:
            i += 1
            dcall = self.clock.getDelayedCalls()[0]
            if no_timeout:
                if len(dcall.args) > 0 and self.is_timeout(dcall):
                    if dcall in timeouts:
                        break
                    self._queue_first_call()
                    timeouts.append(dcall)
                    continue
            self.advance_next_delayed_call()

    def _queue_first_call(self, delay=10000.0):
        """Put the first delayed call to the last position."""
        timeout = self.clock.calls.pop(0)
        # NOTE(review): ``delay`` becomes the call's absolute time, not an
        # offset from ``self.clock.seconds()``.
        timeout.time = delay
        self.clock.calls.append(timeout)

    def tearDown(self):
        """Close connections, advance all delayed calls and restore the
        reactor's original ``callLater``."""
        # Need to wait a bit because obfsproxy network.Circuit.circuitCompleted
        # defers 0.02s a dummy call to dataReceived to flush connection.
        try:
            self._lose_protocol_connection(self.proto_client)
            self._lose_protocol_connection(self.proto_server)
            self.advance_delayed_calls()
        finally:
            reactor.callLater = self._original_call_later
class SchedTest:
    """
    Scheduler tests, meant to be mixed into a concrete test case.

    Subclasses must provide ``self.store`` and ``self.siteStore``; both are
    read here but never assigned.
    """
    def tearDown(self):
        """
        Stop the site store's service that C{setUp} started and return the
        result of C{stopService}.
        """
        return IService(self.siteStore).stopService()


    def setUp(self):
        """
        Route the site store's scheduler through a fake clock, then start the
        site store's service.
        """
        self.clock = Clock()

        scheduler = IScheduler(self.siteStore)
        self.stubTime(scheduler)
        IService(self.siteStore).startService()


    def now(self):
        """
        Return the fake clock's current time as a C{Time}.
        """
        return Time.fromPOSIXTimestamp(self.clock.seconds())


    def stubTime(self, scheduler):
        """
        Make C{scheduler} use this test's fake clock for both scheduling
        (C{callLater}) and reading the current time (C{now}).
        """
        scheduler.callLater = self.clock.callLater
        scheduler.now = self.now


    def test_implementsSchedulerInterface(self):
        """
        Verify that IScheduler is declared as implemented.
        """
        self.assertTrue(IScheduler.providedBy(IScheduler(self.store)))


    def test_scheduler(self):
        """
        Test that the ordering and timing of scheduled calls is correct.
        """
        # create 3 timed events.  the first one fires.  the second one fires,
        # then reschedules itself.  the third one should never fire because the
        # reactor is shut down first.  assert that the first and second fire
        # only once, and that the third never fires.
        s = self.store

        t1 = TestEvent(testCase=self,
                       name=u't1',
                       store=s, runAgain=None)
        t2 = TestEvent(testCase=self,
                       name=u't2',
                       store=s, runAgain=2)
        t3 = TestEvent(testCase=self,
                       name=u't3',
                       store=s, runAgain=None)

        now = self.now()
        self.ts = [t1, t2, t3]

        S = IScheduler(s)

        # Schedule them out of order to make sure behavior doesn't
        # depend on tasks arriving in soonest-to-latest order.
        S.schedule(t2, now + timedelta(seconds=3))
        S.schedule(t1, now + timedelta(seconds=1))
        S.schedule(t3, now + timedelta(seconds=100))

        self.clock.pump([2, 2, 2])
        self.assertEqual(t1.runCount, 1)
        self.assertEqual(t2.runCount, 2)
        self.assertEqual(t3.runCount, 0)


    def test_unscheduling(self):
        """
        Test the unscheduleFirst method of the scheduler.
        """
        sch = IScheduler(self.store)
        t1 = TestEvent(testCase=self, name=u't1', store=self.store)
        t2 = TestEvent(testCase=self, name=u't2', store=self.store, runAgain=None)

        sch.schedule(t1, self.now() + timedelta(seconds=1))
        sch.schedule(t2, self.now() + timedelta(seconds=2))
        sch.unscheduleFirst(t1)
        self.clock.advance(3)
        self.assertEqual(t1.runCount, 0)
        self.assertEqual(t2.runCount, 1)


    def test_inspection(self):
        """
        Test that the L{scheduledTimes} method returns an iterable of all the
        times at which a particular item is scheduled to run.
        """
        now = self.now() + timedelta(seconds=1)
        off = timedelta(seconds=3)
        sch = IScheduler(self.store)
        runnable = TestEvent(store=self.store, name=u'Only event')
        sch.schedule(runnable, now)
        sch.schedule(runnable, now + off)
        sch.schedule(runnable, now + off + off)

        self.assertEqual(
            list(sch.scheduledTimes(runnable)),
            [now, now + off, now + off + off])


    def test_scheduledTimesDuringRun(self):
        """
        L{Scheduler.scheduledTimes} should not include scheduled times that have
        already triggered.
        """
        futureTimes = []
        scheduler = IScheduler(self.store)
        # The hook receives the runner item itself; record which of its
        # scheduled times are still pending each time it runs.
        runner = HookRunner(
            store=self.store,
            hook=lambda item: futureTimes.append(
                list(scheduler.scheduledTimes(item))))

        then = self.now() + timedelta(seconds=1)
        scheduler.schedule(runner, self.now())
        scheduler.schedule(runner, then)
        self.clock.advance(1)
        self.assertEqual(futureTimes, [[then], []])


    def test_deletedRunnable(self):
        """
        Verify that if a scheduled item is deleted,
        L{TimedEvent.invokeRunnable} just deletes the L{TimedEvent} without
        raising an exception.
        """
        now = self.now()
        scheduler = IScheduler(self.store)
        runnable = TestEvent(store=self.store, name=u'Only event')
        scheduler.schedule(runnable, now)

        runnable.deleteFromStore()

        # Invoke it manually to avoid timing complexity.
        timedEvent = self.store.findUnique(
            TimedEvent, TimedEvent.runnable == runnable)
        timedEvent.invokeRunnable()

        self.assertEqual(
            self.store.findUnique(
                TimedEvent,
                TimedEvent.runnable == runnable,
                default=None),
            None)
Exemple #31
0
class ClusterStateServiceTests(TestCase):
    """
    Tests for ``ClusterStateService``: merging of node state changes into a
    ``DeploymentState`` and expiry of stale state.
    """
    WITH_APPS = NodeState(
        hostname=u"192.0.2.56",
        uuid=uuid4(),
        applications={APP1.name: APP1, APP2.name: APP2},
    )
    WITH_MANIFESTATION = NodeState(
        hostname=u"host2",
        manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
        devices={},
        paths={},
    )

    def setUp(self):
        super(ClusterStateServiceTests, self).setUp()
        self.clock = Clock()

    def service(self):
        """
        Build a ``ClusterStateService`` driven by ``self.clock``, start it,
        and register its shutdown as test cleanup.

        :return: The running ``ClusterStateService``.
        """
        svc = ClusterStateService(self.clock)
        svc.startService()
        self.addCleanup(svc.stopService)
        return svc

    def test_applications(self):
        """
        Applications in a reported node state appear unchanged in the result
        of ``ClusterStateService.as_deployment``.
        """
        svc = self.service()
        svc.apply_changes([self.WITH_APPS])
        actual = svc.as_deployment()
        self.assertEqual(actual, DeploymentState(nodes=[self.WITH_APPS]))

    def test_other_manifestations(self):
        """
        Manifestations in a reported node state are carried over to the
        nodes of the deployment built by ``ClusterStateService.as_deployment``.
        """
        svc = self.service()
        svc.apply_changes([self.WITH_MANIFESTATION])
        actual = svc.as_deployment()
        self.assertEqual(
            actual, DeploymentState(nodes={self.WITH_MANIFESTATION}))

    def test_partial_update(self):
        """
        A change that leaves part of a node's state unknown (``None``) only
        overwrites the parts it does know about.
        """
        svc = self.service()
        apps_only = NodeState(
            hostname=u"host1", applications={APP1.name: APP1})
        apps_unknown = NodeState(
            hostname=u"host1",
            applications=None,
            manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
            devices={},
            paths={},
        )
        svc.apply_changes([apps_only, apps_unknown])
        actual = svc.as_deployment()
        merged = NodeState(
            hostname=u"host1",
            manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
            devices={},
            paths={},
            applications={APP1.name: APP1},
        )
        self.assertEqual(actual, DeploymentState(nodes=[merged]))

    def test_update(self):
        """
        A later change for a hostname replaces the state previously recorded
        for that hostname.
        """
        svc = self.service()
        svc.apply_changes([
            NodeState(hostname=u"host1", applications={APP1.name: APP1}),
            NodeState(hostname=u"host1", applications={APP2.name: APP2}),
        ])
        actual = svc.as_deployment()
        latest = NodeState(hostname=u"host1", applications={APP2.name: APP2})
        self.assertEqual(actual, DeploymentState(nodes=[latest]))

    def test_multiple_hosts(self):
        """
        ``ClusterStateService.as_deployment`` combines the state reported for
        several different hosts.
        """
        svc = self.service()
        first = NodeState(hostname=u"host1", applications={APP1.name: APP1})
        second = NodeState(hostname=u"host2", applications={APP2.name: APP2})
        svc.apply_changes([first, second])
        actual = svc.as_deployment()
        self.assertEqual(actual, DeploymentState(nodes=[first, second]))

    def test_manifestation_path(self):
        """
        ``manifestation_path`` looks up the filesystem path reported for a
        dataset on the node with the given UUID.
        """
        node_uuid = uuid4()
        svc = self.service()
        dataset_path = FilePath(b"/xxx/yyy")
        svc.apply_changes([
            NodeState(
                hostname=u"host1",
                uuid=node_uuid,
                manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
                paths={MANIFESTATION.dataset_id: dataset_path},
                devices={},
            )
        ])
        found = svc.manifestation_path(node_uuid, MANIFESTATION.dataset_id)
        self.assertEqual(found, dataset_path)

    def test_expiration(self):
        """
        State older than the hard-coded expiration period (in seconds) is
        wiped.
        """
        svc = self.service()
        svc.apply_changes([self.WITH_APPS])

        advance_rest(self.clock)
        state_before_wipe = svc.as_deployment()

        advance_some(self.clock)
        state_after_wipe = svc.as_deployment()

        self.assertEqual(
            [state_before_wipe, state_after_wipe],
            [DeploymentState(nodes=[self.WITH_APPS]), DeploymentState()],
        )

    def test_expiration_from_inactivity(self):
        """
        State from a source that has shown no activity for longer than the
        hard-coded expiration period is wiped.
        """
        svc = self.service()
        source = ChangeSource()

        # T1: the source is active and delivers a change.
        source.set_last_activity(self.clock.seconds())
        svc.apply_changes_from_source(source, [self.WITH_APPS])

        # T2: a little later, the source is active again.
        advance_some(self.clock)
        source.set_last_activity(self.clock.seconds())

        # T3: EXPIRATION_TIME has now passed since T1.
        advance_rest(self.clock)
        state_before_wipe = svc.as_deployment()

        # T4: EXPIRATION_TIME has now passed since T2.
        advance_some(self.clock)
        state_after_wipe = svc.as_deployment()

        # Activity at T2 keeps the state alive at T3; it is gone at T4.
        self.assertEqual(
            [state_before_wipe, state_after_wipe],
            [DeploymentState(nodes=[self.WITH_APPS]), DeploymentState()],
        )

    def test_updates_different_key(self):
        """
        A wipe created by an ``IClusterStateChange`` with one wipe key is not
        overwritten by a later ``IClusterStateChange`` with a different key.
        """
        svc = self.service()
        app_node = self.WITH_APPS
        app_node_2 = NodeState(
            hostname=app_node.hostname,
            uuid=app_node.uuid,
            manifestations={MANIFESTATION.dataset_id: MANIFESTATION},
            devices={},
            paths={},
        )

        # T1: the application state arrives.
        svc.apply_changes([app_node])

        # T2: unrelated (manifestation) state arrives for the same node.
        advance_some(self.clock)
        svc.apply_changes([app_node_2])

        # T3: EXPIRATION_TIME has now passed since T1.
        advance_rest(self.clock)
        state_before_wipe = svc.as_deployment()

        # T4: EXPIRATION_TIME has now passed since T2.
        advance_some(self.clock)
        state_after_wipe = svc.as_deployment()

        # The T1 state goes at T3; the T2 state separately goes at T4.
        self.assertEqual(
            [state_before_wipe, state_after_wipe],
            [DeploymentState(nodes=[app_node_2]), DeploymentState()],
        )

    def test_update_with_same_key(self):
        """
        Re-applying a change with the same key postpones its wipe.
        """
        svc = self.service()

        # T1: the state is applied.
        svc.apply_changes([self.WITH_APPS])

        # T2: the identical state is applied again.
        advance_some(self.clock)
        svc.apply_changes([self.WITH_APPS])

        # T3: EXPIRATION_TIME has passed since T1, but not since T2.
        advance_rest(self.clock)

        # The refresh at T2 keeps the state alive at T3.
        self.assertEqual(
            svc.as_deployment(),
            DeploymentState(nodes=[self.WITH_APPS]),
        )
class SubStoreSchedulerReentrancy(TestCase):
    """
    Test re-entrant scheduling calls on an item run by a SubScheduler.
    """
    def setUp(self):
        """
        Create a store containing one substore, point the parent scheduler's
        C{callLater} and both schedulers' C{now} at a deterministic L{Clock},
        and start the parent store's service.
        """
        self.clock = Clock()

        self.dbdir = filepath.FilePath(self.mktemp())
        self.store = Store(self.dbdir)
        self.substoreItem = SubStore.createNew(self.store, ['sub'])
        self.substore = self.substoreItem.open()

        self.scheduler = IScheduler(self.store)
        self.subscheduler = IScheduler(self.substore)

        # Only the parent scheduler's callLater is redirected to the fake
        # clock; both schedulers read the current time from that same clock.
        self.scheduler.callLater = self.clock.callLater
        self.scheduler.now = lambda: Time.fromPOSIXTimestamp(self.clock.seconds())
        self.subscheduler.now = lambda: Time.fromPOSIXTimestamp(self.clock.seconds())

        IService(self.store).startService()


    def tearDown(self):
        """
        Stop the parent store's service; the returned value is handed back to
        the test framework.
        """
        return IService(self.store).stopService()


    def _scheduleRunner(self, now, offset):
        """
        Create a L{ScheduleCallingItem} in the substore and schedule it on the
        subscheduler for C{now + offset}.

        The item is constructed with C{rescheduleFor} set to
        C{now + offset + 10}; the tests below expect that to become its next
        scheduled time once it has run.

        @param now: POSIX timestamp the offsets are relative to.
        @param offset: Seconds after C{now} at which the item first runs.
        @return: The scheduled L{ScheduleCallingItem}.
        """
        scheduledAt = Time.fromPOSIXTimestamp(now + offset)
        rescheduleFor = Time.fromPOSIXTimestamp(now + offset + 10)
        runnable = ScheduleCallingItem(store=self.substore, rescheduleFor=rescheduleFor)
        self.subscheduler.schedule(runnable, scheduledAt)
        return runnable


    def _parentHook(self):
        """
        Return the L{_SubSchedulerParentHook} in the parent store that belongs
        to C{self.substoreItem}.
        """
        return self.store.findUnique(
            _SubSchedulerParentHook,
            _SubSchedulerParentHook.subStore == self.substoreItem)


    def testSchedule(self):
        """
        Test the schedule method, as invoked from the run method of an item
        being run by the subscheduler.
        """
        now = self.clock.seconds()
        runnable = self._scheduleRunner(now, 10)

        self.clock.advance(11)

        self.assertEqual(
            list(self.subscheduler.scheduledTimes(runnable)),
            [Time.fromPOSIXTimestamp(now + 20)])

        self.assertEqual(
            list(self.scheduler.scheduledTimes(self._parentHook())),
            [Time.fromPOSIXTimestamp(now + 20)])


    def testScheduleWithLaterTimedEvents(self):
        """
        Like L{testSchedule}, but use a SubScheduler which has pre-existing
        TimedEvents which are beyond the new runnable's scheduled time (to
        trigger the reschedule-using code-path in
        _SubSchedulerParentHook._schedule).
        """
        now = self.clock.seconds()
        when = Time.fromPOSIXTimestamp(now + 30)
        null = NullRunnable(store=self.substore)
        self.subscheduler.schedule(null, when)
        runnable = self._scheduleRunner(now, 10)

        self.clock.advance(11)

        self.assertEqual(
            list(self.subscheduler.scheduledTimes(runnable)),
            [Time.fromPOSIXTimestamp(now + 20)])

        self.assertEqual(
            list(self.subscheduler.scheduledTimes(null)),
            [Time.fromPOSIXTimestamp(now + 30)])

        # Relative to ``now`` like every other expectation here; the previous
        # absolute timestamp only matched because Clock starts at 0.
        self.assertEqual(
            list(self.scheduler.scheduledTimes(self._parentHook())),
            [Time.fromPOSIXTimestamp(now + 20)])


    def testScheduleWithEarlierTimedEvents(self):
        """
        Like L{testSchedule}, but use a SubScheduler which has pre-existing
        TimedEvents which are before the new runnable's scheduled time.
        """
        now = self.clock.seconds()
        when = Time.fromPOSIXTimestamp(now + 15)
        null = NullRunnable(store=self.substore)
        self.subscheduler.schedule(null, when)
        runnable = self._scheduleRunner(now, 10)

        self.clock.advance(11)

        self.assertEqual(
            list(self.subscheduler.scheduledTimes(runnable)),
            [Time.fromPOSIXTimestamp(now + 20)])

        self.assertEqual(
            list(self.subscheduler.scheduledTimes(null)),
            [Time.fromPOSIXTimestamp(now + 15)])

        self.assertEqual(
            list(self.scheduler.scheduledTimes(self._parentHook())),
            [Time.fromPOSIXTimestamp(now + 15)])


    def testMultipleEventsPerTick(self):
        """
        Test running several runnables in a single tick of the subscheduler.
        """
        now = self.clock.seconds()
        runnables = [
            self._scheduleRunner(now, 10),
            self._scheduleRunner(now, 11),
            self._scheduleRunner(now, 12)]

        self.clock.advance(13)

        for n, runnable in enumerate(runnables):
            self.assertEqual(
                list(self.subscheduler.scheduledTimes(runnable)),
                [Time.fromPOSIXTimestamp(now + n + 20)])

        self.assertEqual(
            list(self.scheduler.scheduledTimes(self._parentHook())),
            [Time.fromPOSIXTimestamp(now + 20)])
Exemple #33
0
class MemCacheTestCase(TestCase):
    """
    Test client protocol class L{MemCacheProtocol}.
    """
    def setUp(self):
        """
        Create a memcache client, connect it to a string transport, and make
        it use a deterministic clock.
        """
        self.proto = MemCacheProtocol()
        self.clock = Clock()
        self.proto.callLater = self.clock.callLater
        self.transport = StringTransportWithDisconnection()
        self.transport.protocol = self.proto
        self.proto.makeConnection(self.transport)

    def _test(self, d, send, recv, result):
        """
        Shortcut method for classic tests.

        @param d: the resulting deferred from the memcache command.
        @type d: C{Deferred}

        @param send: the expected data to be sent.
        @type send: C{str}

        @param recv: the data to simulate as reception.
        @type recv: C{str}

        @param result: the expected result.
        @type result: C{any}
        """
        def cb(res):
            self.assertEqual(res, result)

        self.assertEqual(self.transport.value(), send)
        d.addCallback(cb)
        self.proto.dataReceived(recv)
        return d

    def test_get(self):
        """
        L{MemCacheProtocol.get} should return a L{Deferred} which is
        called back with the value and the flag associated with the given key
        if the server returns a successful result.
        """
        return self._test(self.proto.get("foo"), "get foo\r\n",
                          "VALUE foo 0 3\r\nbar\r\nEND\r\n", (0, "bar"))

    def test_emptyGet(self):
        """
        Test getting a non-available key: it should succeed but return C{None}
        as value and C{0} as flag.
        """
        return self._test(self.proto.get("foo"), "get foo\r\n", "END\r\n",
                          (0, None))

    def test_set(self):
        """
        L{MemCacheProtocol.set} should return a L{Deferred} which is
        called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.set("foo", "bar"),
                          "set foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_add(self):
        """
        L{MemCacheProtocol.add} should return a L{Deferred} which is
        called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.add("foo", "bar"),
                          "add foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_replace(self):
        """
        L{MemCacheProtocol.replace} should return a L{Deferred} which
        is called back with C{True} when the operation succeeds.
        """
        return self._test(self.proto.replace("foo", "bar"),
                          "replace foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_errorAdd(self):
        """
        Test an erroneous add: if a L{MemCacheProtocol.add} is called but the
        key already exists on the server, it returns a B{NOT STORED} answer,
        which should callback the resulting L{Deferred} with C{False}.
        """
        return self._test(self.proto.add("foo", "bar"),
                          "add foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)

    def test_errorReplace(self):
        """
        Test an erroneous replace: if a L{MemCacheProtocol.replace} is called
        but the key doesn't exist on the server, it returns a B{NOT STORED}
        answer, which should callback the resulting L{Deferred} with C{False}.
        """
        return self._test(self.proto.replace("foo", "bar"),
                          "replace foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n",
                          False)

    def test_delete(self):
        """
        L{MemCacheProtocol.delete} should return a L{Deferred} which is
        called back with C{True} when the server notifies a success.
        """
        return self._test(self.proto.delete("bar"), "delete bar\r\n",
                          "DELETED\r\n", True)

    def test_errorDelete(self):
        """
        Test an error during a delete: if the key doesn't exist on the server,
        it returns a B{NOT FOUND} answer which should callback the resulting
        L{Deferred} with C{False}.
        """
        return self._test(self.proto.delete("bar"), "delete bar\r\n",
                          "NOT FOUND\r\n", False)

    def test_increment(self):
        """
        Test incrementing a variable: L{MemCacheProtocol.increment} should
        return a L{Deferred} which is called back with the incremented value of
        the given key.
        """
        return self._test(self.proto.increment("foo"), "incr foo 1\r\n",
                          "4\r\n", 4)

    def test_decrement(self):
        """
        Test decrementing a variable: L{MemCacheProtocol.decrement} should
        return a L{Deferred} which is called back with the decremented value of
        the given key.
        """
        return self._test(self.proto.decrement("foo"), "decr foo 1\r\n",
                          "5\r\n", 5)

    def test_incrementVal(self):
        """
        L{MemCacheProtocol.increment} takes an optional argument C{value} which
        should replace the default value of 1 when specified.
        """
        return self._test(self.proto.increment("foo", 8), "incr foo 8\r\n",
                          "4\r\n", 4)

    def test_decrementVal(self):
        """
        L{MemCacheProtocol.decrement} takes an optional argument C{value} which
        should replace the default value of 1 when specified.
        """
        return self._test(self.proto.decrement("foo", 3), "decr foo 3\r\n",
                          "5\r\n", 5)

    def test_stats(self):
        """
        Test retrieving server statistics via the L{MemCacheProtocol.stats}
        command: it should parse the data sent by the server and call back the
        resulting L{Deferred} with a dictionary of the received statistics.
        """
        return self._test(self.proto.stats(), "stats\r\n",
                          "STAT foo bar\r\nSTAT egg spam\r\nEND\r\n", {
                              "foo": "bar",
                              "egg": "spam"
                          })

    def test_version(self):
        """
        Test version retrieval via the L{MemCacheProtocol.version} command: it
        should return a L{Deferred} which is called back with the version sent
        by the server.
        """
        return self._test(self.proto.version(), "version\r\n",
                          "VERSION 1.1\r\n", "1.1")

    def test_flushAll(self):
        """
        L{MemCacheProtocol.flushAll} should return a L{Deferred} which is
        called back with C{True} if the server acknowledges success.
        """
        return self._test(self.proto.flushAll(), "flush_all\r\n", "OK\r\n",
                          True)

    def test_invalidGetResponse(self):
        """
        If the key in the returned value doesn't match the key of the current
        request, L{MemCacheProtocol.dataReceived} should raise an error.
        """
        self.proto.get("foo")
        s = "spamegg"
        self.assertRaises(RuntimeError, self.proto.dataReceived,
                          "VALUE bar 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))

    def test_timeOut(self):
        """
        Test the timeout on outgoing requests: when timeout is detected, all
        current commands should fail with a L{TimeoutError}, and the
        connection should be closed.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)

        def checkMessage(error):
            self.assertEqual(str(error), "Connection timeout")

        d1.addCallback(checkMessage)
        return gatherResults([d1, d2, d3])

    def test_timeoutRemoved(self):
        """
        When a request gets a response, no pending timeout call should remain
        around.
        """
        d = self.proto.get("foo")

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, "bar"))
            self.assertEqual(len(self.clock.calls), 0)

        d.addCallback(check)
        return d

    def test_timeOutRaw(self):
        """
        Test the timeout when raw mode was started: the timeout should not be
        reset until all the data has been received, so we can have a
        L{TimeoutError} when waiting for raw data.
        """
        d1 = self.proto.get("foo")
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived("VALUE foo 0 10\r\n12345")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        return gatherResults([d1, d2])

    def test_timeOutStat(self):
        """
        Test the timeout when stat command has started: the timeout should not
        be reset until the final B{END} is received.
        """
        d1 = self.proto.stats()
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived("STAT foo bar\r\n")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        return gatherResults([d1, d2])

    def test_timeoutPipelining(self):
        """
        When two requests are sent, a timeout call should remain around for the
        second request, and its timeout time should be correct.
        """
        d1 = self.proto.get("foo")
        d2 = self.proto.get("bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, "bar"))
            self.assertEqual(len(self.clock.calls), 1)
            # Step one second at a time so the timeout fires at the exact
            # tick it is due.
            for _ in range(self.proto.persistentTimeOut):
                self.clock.advance(1)
            return self.assertFailure(d2, TimeoutError).addCallback(checkTime)

        def checkTime(ignored):
            # Check that the timeout happened C{self.proto.persistentTimeOut}
            # after the last response
            self.assertEqual(self.clock.seconds(),
                             2 * self.proto.persistentTimeOut - 1)

        d1.addCallback(check)
        return d1

    def test_timeoutNotReset(self):
        """
        Check that the timeout is not reset for every command, but keeps the
        timeout from the first command without a response.
        """
        d1 = self.proto.get("foo")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        d2 = self.proto.get("bar")
        self.clock.advance(1)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)
        return gatherResults([d1, d2, d3])

    def test_tooLongKey(self):
        """
        Test that an error is raised when trying to use a too long key: the
        called command should return a L{Deferred} which fails with a
        L{ClientError}.
        """
        d1 = self.assertFailure(self.proto.set("a" * 500, "bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment("a" * 500), ClientError)
        d3 = self.assertFailure(self.proto.get("a" * 500), ClientError)
        d4 = self.assertFailure(self.proto.append("a" * 500, "bar"),
                                ClientError)
        d5 = self.assertFailure(self.proto.prepend("a" * 500, "bar"),
                                ClientError)
        return gatherResults([d1, d2, d3, d4, d5])

    def test_invalidCommand(self):
        """
        When an unknown command is sent directly (not through public API), the
        server answers with an B{ERROR} token, and the command should fail with
        L{NoSuchCommand}.
        """
        d = self.proto._set("egg", "foo", "bar", 0, 0, "")
        self.assertEqual(self.transport.value(), "egg foo 0 0 3\r\nbar\r\n")
        self.assertFailure(d, NoSuchCommand)
        self.proto.dataReceived("ERROR\r\n")
        return d

    def test_clientError(self):
        """
        Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
        token, the originating command should fail with L{ClientError}, and the
        error should contain the text sent by the server.
        """
        a = "eggspamm"
        d = self.proto.set("foo", a)
        self.assertEqual(self.transport.value(),
                         "set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ClientError)

        def check(err):
            self.assertEqual(str(err), "We don't like egg and spam")

        d.addCallback(check)
        self.proto.dataReceived("CLIENT_ERROR We don't like egg and spam\r\n")
        return d

    def test_serverError(self):
        """
        Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
        token, the originating command should fail with L{ServerError}, and the
        error should contain the text sent by the server.
        """
        a = "eggspamm"
        d = self.proto.set("foo", a)
        self.assertEqual(self.transport.value(),
                         "set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ServerError)

        def check(err):
            self.assertEqual(str(err), "zomg")

        d.addCallback(check)
        self.proto.dataReceived("SERVER_ERROR zomg\r\n")
        return d

    def test_unicodeKey(self):
        """
        Using a non-string key as argument to commands should raise an error.
        """
        d1 = self.assertFailure(self.proto.set(u"foo", "bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment(u"egg"), ClientError)
        d3 = self.assertFailure(self.proto.get(1), ClientError)
        d4 = self.assertFailure(self.proto.delete(u"bar"), ClientError)
        d5 = self.assertFailure(self.proto.append(u"foo", "bar"), ClientError)
        d6 = self.assertFailure(self.proto.prepend(u"foo", "bar"), ClientError)
        return gatherResults([d1, d2, d3, d4, d5, d6])

    def test_unicodeValue(self):
        """
        Using a non-string value should raise an error.
        """
        return self.assertFailure(self.proto.set("foo", u"bar"), ClientError)

    def test_pipelining(self):
        """
        Test that multiple requests can be sent subsequently to the server, and
        that the protocol orders the responses correctly and dispatches them to
        the corresponding client command.
        """
        d1 = self.proto.get("foo")
        d1.addCallback(self.assertEqual, (0, "bar"))
        d2 = self.proto.set("bar", "spamspamspam")
        d2.addCallback(self.assertEqual, True)
        d3 = self.proto.get("egg")
        d3.addCallback(self.assertEqual, (0, "spam"))
        self.assertEqual(
            self.transport.value(),
            "get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n")
        self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n"
                                "STORED\r\n"
                                "VALUE egg 0 4\r\nspam\r\nEND\r\n")
        return gatherResults([d1, d2, d3])

    def test_getInChunks(self):
        """
        If the value retrieved by a C{get} arrives in chunks, the protocol
        should be able to reconstruct it and produce the correct value.
        """
        d = self.proto.get("foo")
        d.addCallback(self.assertEqual, (0, "0123456789"))
        self.assertEqual(self.transport.value(), "get foo\r\n")
        self.proto.dataReceived("VALUE foo 0 10\r\n0123456")
        self.proto.dataReceived("789")
        self.proto.dataReceived("\r\nEND")
        self.proto.dataReceived("\r\n")
        return d

    def test_append(self):
        """
        L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
        method: it should return a L{Deferred} which is called back with
        C{True} when the operation succeeds.
        """
        return self._test(self.proto.append("foo", "bar"),
                          "append foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_prepend(self):
        """
        L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
        method: it should return a L{Deferred} which is called back with
        C{True} when the operation succeeds.
        """
        return self._test(self.proto.prepend("foo", "bar"),
                          "prepend foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)

    def test_gets(self):
        """
        L{MemCacheProtocol.get} should handle an additional cas result when
        C{withIdentifier} is C{True} and forward it in the resulting
        L{Deferred}.
        """
        return self._test(self.proto.get("foo", True), "gets foo\r\n",
                          "VALUE foo 0 3 1234\r\nbar\r\nEND\r\n",
                          (0, "1234", "bar"))

    def test_emptyGets(self):
        """
        Test getting a non-available key with gets: it should succeed but
        return C{None} as value, C{0} as flag and an empty cas value.
        """
        return self._test(self.proto.get("foo", True), "gets foo\r\n",
                          "END\r\n", (0, "", None))

    def test_checkAndSet(self):
        """
        L{MemCacheProtocol.checkAndSet} passes an additional cas identifier that the
        server should handle to check if the data has to be updated.
        """
        return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
                          "cas foo 0 0 3 1234\r\nbar\r\n", "STORED\r\n", True)

    def test_casUnknowKey(self):
        """
        When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the resulting
        L{Deferred} should fire with C{False}.
        """
        return self._test(self.proto.checkAndSet("foo", "bar", cas="1234"),
                          "cas foo 0 0 3 1234\r\nbar\r\n", "EXISTS\r\n", False)
Exemple #34
0
class LeasesTests(TestCase):
    """
    Tests for ``LeaseService`` and ``update_leases``.
    """
    def setUp(self):
        """
        Start a ``ConfigurationPersistenceService`` driven by a fake clock and
        backed by a fresh temporary path; stop it again on cleanup.
        """
        self.clock = Clock()
        config_path = FilePath(self.mktemp())
        self.persistence_service = ConfigurationPersistenceService(
            self.clock, config_path)
        self.persistence_service.startService()
        self.addCleanup(self.persistence_service.stopService)

    def test_update_leases_saves_changed_leases(self):
        """
        ``update_leases`` only changes the leases stored in the configuration.
        """
        node_id = uuid4()
        dataset_id = uuid4()
        original_leases = Leases().acquire(
            datetime.fromtimestamp(0, UTC), uuid4(), node_id)

        def acquire_lease(leases):
            # Add a lease on ``dataset_id`` for ``node_id``, taken at t=1000.
            return leases.acquire(
                datetime.fromtimestamp(1000, UTC), dataset_id, node_id)

        def check_saved(_):
            stored = self.persistence_service.get()
            expected = TEST_DEPLOYMENT.set(
                leases=acquire_lease(original_leases))
            self.assertEqual(stored, expected)

        saving = self.persistence_service.save(
            TEST_DEPLOYMENT.set(leases=original_leases))
        saving.addCallback(
            lambda _: update_leases(acquire_lease, self.persistence_service))
        saving.addCallback(check_saved)
        return saving

    def test_update_leases_result(self):
        """
        ``update_leases`` returns a ``Deferred`` firing with the updated
        ``Leases`` instance.
        """
        node_id = uuid4()
        dataset_id = uuid4()
        # The expectation assumes the freshly started service holds no
        # leases yet.
        starting_leases = Leases()

        def acquire_lease(leases):
            return leases.acquire(
                datetime.fromtimestamp(1000, UTC), dataset_id, node_id)

        def check_result(updated_leases):
            self.assertEqual(updated_leases, acquire_lease(starting_leases))

        result = update_leases(acquire_lease, self.persistence_service)
        result.addCallback(check_result)
        return result

    def test_expired_lease_removed(self):
        """
        A lease that has expired is removed from the persisted
        configuration.
        """
        timestep = 100
        node_id = uuid4()
        first_id, second_id = uuid4(), uuid4()
        acquired_at = datetime.fromtimestamp(self.clock.seconds(), UTC)
        # The first lease expires after ``timestep`` seconds, the second after
        # twice that.
        leases = Leases().acquire(acquired_at, first_id, node_id, timestep)
        leases = leases.acquire(acquired_at, second_id, node_id, timestep * 2)
        saving = self.persistence_service.save(Deployment(leases=leases))

        def check_expiry(_):
            snapshots = []
            # Step to just before and just after each expiry: t=99, 101,
            # 199, 201.
            for step in (timestep - 1, 2, timestep - 2, 2):
                self.clock.advance(step)
                snapshots.append(self.persistence_service.get().leases)

            without_first = leases.remove(first_id)
            self.assertTupleEqual(
                tuple(snapshots),
                (leases, without_first, without_first,
                 without_first.remove(second_id)))

        saving.addCallback(check_expiry)
        return saving

    @capture_logging(None)
    def test_expire_lease_logging(self, logger):
        """
        An expired lease is logged.
        """
        node_id = uuid4()
        dataset_id = uuid4()
        acquired_at = datetime.fromtimestamp(self.clock.seconds(), UTC)
        leases = Leases().acquire(acquired_at, dataset_id, node_id, 1)

        def check_logged(_):
            # Only look at messages emitted once time moves past the expiry.
            logger.reset()
            self.clock.advance(1000)
            expected_fields = {
                u"dataset_id": dataset_id,
                u"node_id": node_id,
            }
            assertHasMessage(self, logger, _LOG_EXPIRE, expected_fields)

        saving = self.persistence_service.save(Deployment(leases=leases))
        saving.addCallback(check_logged)
        return saving
Exemple #35
0
class EINVALTestCase(TestCase):
    """
    Sometimes, L{os.listdir} will raise C{EINVAL}.  This is a transient error,
    and L{CachingFilePath.listdir} should work around it by retrying the
    C{listdir} operation until it succeeds.
    """

    def setUp(self):
        """
        Create a L{CachingFilePath} for the test to use.  Its C{_sleep} hook
        is replaced with a fake clock's C{advance}, so any retry delay moves
        simulated time forward instead of blocking.
        """
        self.cfp = CachingFilePath(self.mktemp())
        self.clock = Clock()
        self.cfp._sleep = self.clock.advance


    def test_testValidity(self):
        """
        If C{listdir} is replaced on a L{CachingFilePath}, we should be able to
        observe exceptions raised by the replacement.  This verifies that the
        test patching done here is actually testing something.
        """
        class CustomException(Exception): "Just for testing."
        def blowUp(dirname):
            raise CustomException()
        self.cfp._listdir = blowUp
        self.assertRaises(CustomException, self.cfp.listdir)
        self.assertRaises(CustomException, self.cfp.children)


    def test_retryLoop(self):
        """
        L{CachingFilePath} should catch C{EINVAL} and respond by retrying the
        C{listdir} operation until it succeeds.
        """
        calls = []
        def raiseEINVAL(dirname):
            calls.append(dirname)
            # Fail the first four attempts, succeed on the fifth.
            if len(calls) < 5:
                raise OSError(EINVAL, "This should be caught by the test.")
            return ['a', 'b', 'c']
        self.cfp._listdir = raiseEINVAL
        self.assertEqual(self.cfp.listdir(), ['a', 'b', 'c'])
        self.assertEqual(self.cfp.children(), [
                CachingFilePath(pathjoin(self.cfp.path, 'a')),
                CachingFilePath(pathjoin(self.cfp.path, 'b')),
                CachingFilePath(pathjoin(self.cfp.path, 'c')),])


    def requireTimePassed(self, filenames):
        """
        Install a replacement for listdir() which raises C{EINVAL} until the
        fake clock reaches 20 seconds, then returns C{filenames}.  The
        simulated time of every failed attempt is recorded in C{self.calls}.
        """
        self.calls = []
        def thunk(dirname):
            now = self.clock.seconds()
            if now < 20.0:
                self.calls.append(now)
                raise OSError(EINVAL, "Not enough time has passed yet.")
            else:
                return filenames
        self.cfp._listdir = thunk


    def assertRequiredTimePassed(self):
        """
        Assert that the failed attempts recorded by C{requireTimePassed}
        happened at the simulated times implied by the expected back-off
        schedule.
        """
        # Waiting should be growing by *2 each time until the additional wait
        # exceeds BACKOFF_MAX (5), at which point we should wait for 5s each
        # time.
        def cumulative(values):
            current = 0.0
            for value in values:
                current += value
                yield current

        self.assertEqual(self.calls,
                         list(cumulative(
                    [0.0, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 5.0, 5.0])))


    def test_backoff(self):
        """
        L{CachingFilePath} will wait for an increasing interval up to
        C{BACKOFF_MAX} between calls to listdir().
        """
        self.requireTimePassed(['a', 'b', 'c'])
        self.assertEqual(self.cfp.listdir(), ['a', 'b', 'c'])
        # Without this the test only checks that listdir eventually succeeds,
        # not that the retries were spaced by the promised back-off.
        self.assertRequiredTimePassed()


    def test_siblingExtensionSearch(self):
        """
        L{FilePath.siblingExtensionSearch} is unfortunately not implemented in
        terms of L{FilePath.listdir}, so we need to verify that it will also
        retry.
        """
        filenames = [self.cfp.basename() + '.a',
                     self.cfp.basename() + '.b',
                     self.cfp.basename() + '.c']
        # A real list, not a lazy map: it is iterated and then indexed below.
        siblings = [self.cfp.sibling(name) for name in filenames]
        for sibling in siblings:
            sibling.touch()
        self.requireTimePassed(filenames)
        self.assertEqual(self.cfp.siblingExtensionSearch("*"),
                         siblings[0])
        self.assertRequiredTimePassed()
Exemple #36
0
class APITestsMixin(APIAssertionsMixin):
    """
    Helpers for writing tests for the Docker Volume Plugin API.
    """

    NODE_A = uuid4()
    NODE_B = uuid4()

    def initialize(self):
        """
        Create initial objects for the ``VolumePlugin``.
        """
        self.volume_plugin_reactor = Clock()
        self.flocker_client = SimpleCountingProxy(FakeFlockerClient())
        # The conditional_create operation used by the plugin relies on
        # the passage of time... so make sure time passes! We still use a
        # fake clock since some tests want to skip ahead.
        self.looping = LoopingCall(lambda: self.volume_plugin_reactor.advance(0.001))
        self.looping.start(0.001)
        self.addCleanup(self.looping.stop)

    def _random_docker_id(self):
        """
        Generate a random 64-character lowercase hexadecimal string, used as
        the ``ID`` field of ``Mount``/``Unmount`` requests.

        :return unicode: The generated identifier.
        """
        return unicode("".join(random.choice("0123456789abcdef") for _ in xrange(64)))

    def test_pluginactivate(self):
        """
        ``/Plugin.Activate`` indicates the plugin is a volume driver.
        """
        # Docker 1.8, at least, sends "null" as the body. Our test
        # infrastructure has the opposite bug so just going to send some
        # other garbage as the body (12345) to demonstrate that it's
        # ignored as per the spec which declares no body.
        return self.assertResult(b"POST", b"/Plugin.Activate", 12345, OK, {u"Implements": [u"VolumeDriver"]})

    def test_remove(self):
        """
        ``/VolumeDriver.Remove`` returns a successful result.
        """
        return self.assertResult(b"POST", b"/VolumeDriver.Remove", {u"Name": u"vol"}, OK, {u"Err": u""})

    def test_unmount(self):
        """
        ``/VolumeDriver.Unmount`` returns a successful result.
        """
        unmount_id = self._random_docker_id()
        return self.assertResult(
            b"POST", b"/VolumeDriver.Unmount", {u"Name": u"vol", u"ID": unmount_id}, OK, {u"Err": u""}
        )

    def test_unmount_no_id(self):
        """
        ``/VolumeDriver.Unmount`` returns a successful result.

        No ID for backward compatibility with Docker < 1.12
        """
        return self.assertResult(b"POST", b"/VolumeDriver.Unmount", {u"Name": u"vol"}, OK, {u"Err": u""})

    def test_create_with_profile(self):
        """
        Calling the ``/VolumeDriver.Create`` API with an ``Opts`` value
        of "profile=[gold,silver,bronze]" in the request body JSON creates a
        volume with the given name and the requested profile.
        """
        profile = sampled_from(["gold", "silver", "bronze"]).example()
        name = random_name(self)
        d = self.assertResult(
            b"POST", b"/VolumeDriver.Create", {u"Name": name, "Opts": {u"profile": profile}}, OK, {u"Err": u""}
        )
        d.addCallback(lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(list)
        d.addCallback(
            lambda result: self.assertItemsEqual(
                result,
                [
                    Dataset(
                        dataset_id=result[0].dataset_id,
                        primary=self.NODE_A,
                        maximum_size=int(DEFAULT_SIZE.to_Byte()),
                        metadata={NAME_FIELD: name, u"clusterhq:flocker:profile": unicode(profile)},
                    )
                ],
            )
        )
        return d

    def test_create_with_size(self):
        """
        Calling the ``/VolumeDriver.Create`` API with an ``Opts`` value
        of "size=<somesize>" in the request body JSON creates a volume
        with the given name and a random size between 1G and 75G.
        """
        name = random_name(self)
        size = integers(min_value=1, max_value=75).example()
        expression = volume_expression.example()
        size_opt = str(size) + expression
        d = self.assertResult(
            b"POST", b"/VolumeDriver.Create", {u"Name": name, "Opts": {u"size": size_opt}}, OK, {u"Err": u""}
        )

        real_size = int(parse_num(size_opt).to_Byte())
        d.addCallback(lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(list)
        d.addCallback(
            lambda result: self.assertItemsEqual(
                result,
                [
                    Dataset(
                        dataset_id=result[0].dataset_id,
                        primary=self.NODE_A,
                        maximum_size=real_size,
                        metadata={NAME_FIELD: name, u"maximum_size": unicode(real_size)},
                    )
                ],
            )
        )
        return d

    @given(expr=volume_expression, size=integers(min_value=75, max_value=100))
    def test_parsenum_size(self, expr, size):
        """
        Send different forms of size expressions
        to ``parse_num``, we expect G(Gigabyte) size results.

        :param expr str: A string representing the size expression
        :param size int: An integer representing the volume size
        """
        expected_size = int(GiB(size).to_Byte())
        self.assertEqual(expected_size, int(parse_num(str(size) + expr).to_Byte()))

    @given(expr=sampled_from(["KB", "MB", "GB", "TB", ""]), size=integers(min_value=1, max_value=100))
    def test_parsenum_all_sizes(self, expr, size):
        """
        Send standard size expressions to ``parse_num`` in
        many sizes, we expect to get correct size results.

        :param expr str: A string representing the size expression
        :param size int: An integer representing the volume size
        """
        # Compare suffixes by value; ``is`` only worked by accident of string
        # interning. An empty (or unknown) suffix means plain bytes.
        unit_for_suffix = {"KB": KiB, "MB": MiB, "GB": GiB, "TB": TiB}
        expected_size = int(unit_for_suffix.get(expr, Byte)(size).to_Byte())
        self.assertEqual(expected_size, int(parse_num(str(size) + expr).to_Byte()))

    @given(size=sampled_from([u"foo10Gb", u"10bar10", "10foogib", "10Gfoo", "GIB", "bar10foo"]))
    def test_parsenum_bad_size(self, size):
        """
        Send unacceptable size expressions, upon error
        users should expect to receive Flocker's ``DEFAULT_SIZE``

        :param size str: A string representing the bad volume size
        """
        self.assertEqual(int(DEFAULT_SIZE.to_Byte()), int(parse_num(size).to_Byte()))

    def create(self, name):
        """
        Call the ``/VolumeDriver.Create`` API to create a volume with the
        given name.

        :param unicode name: The name of the volume to create.

        :return: ``Deferred`` that fires when the volume has been created.
        """
        return self.assertResult(b"POST", b"/VolumeDriver.Create", {u"Name": name}, OK, {u"Err": u""})

    def test_create_creates(self):
        """
        ``/VolumeDriver.Create`` creates a new dataset in the configuration.
        """
        name = u"myvol"
        d = self.create(name)
        d.addCallback(lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(list)
        d.addCallback(
            lambda result: self.assertItemsEqual(
                result,
                [
                    Dataset(
                        dataset_id=result[0].dataset_id,
                        primary=self.NODE_A,
                        maximum_size=int(DEFAULT_SIZE.to_Byte()),
                        metadata={NAME_FIELD: name},
                    )
                ],
            )
        )
        return d

    def test_create_duplicate_name(self):
        """
        If a dataset with the given name already exists,
        ``/VolumeDriver.Create`` succeeds without creating a new volume.
        """
        name = u"thename"
        # Create a dataset out-of-band with matching name but non-matching
        # dataset ID:
        d = self.flocker_client.create_dataset(self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name})
        d.addCallback(lambda _: self.create(name))
        d.addCallback(lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(lambda results: self.assertEqual(len(list(results)), 1))
        return d

    def test_create_duplicate_name_race_condition(self):
        """
        If a dataset with the given name is created while the
        ``/VolumeDriver.Create`` call is in flight, the call does not
        result in an error.
        """
        name = u"thename"

        # Create a dataset out-of-band with matching dataset ID and name
        # which the docker plugin won't be able to see.
        def create_after_list():
            # Clean up the patched version:
            del self.flocker_client.list_datasets_configuration
            # But first time we're called, we create dataset and lie about
            # its existence:
            d = self.flocker_client.create_dataset(
                self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name}
            )
            d.addCallback(lambda _: DatasetsConfiguration(tag=u"1234", datasets={}))
            return d

        self.flocker_client.list_datasets_configuration = create_after_list

        return self.create(name)

    def _flush_volume_plugin_reactor_on_endpoint_render(self):
        """
        This method patches ``self.app`` so that after any endpoint is
        rendered, the reactor used by the volume plugin is advanced repeatedly
        until there are no more ``delayedCalls`` pending on the reactor.
        """
        real_execute_endpoint = self.app.execute_endpoint

        def patched_execute_endpoint(*args, **kwargs):
            val = real_execute_endpoint(*args, **kwargs)
            while self.volume_plugin_reactor.getDelayedCalls():
                pending_calls = self.volume_plugin_reactor.getDelayedCalls()
                next_expiration = min(t.getTime() for t in pending_calls)
                now = self.volume_plugin_reactor.seconds()
                # Jump straight to the earliest pending call (never backwards).
                self.volume_plugin_reactor.advance(max(0.0, next_expiration - now))
            return val

        self.patch(self.app, "execute_endpoint", patched_execute_endpoint)

    def test_mount(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with matching
        name to the current node and then waits for the dataset to
        actually arrive.
        """
        name = u"myvol"
        dataset_id = uuid4()
        mount_id = self._random_docker_id()

        # Create dataset on a different node:
        d = self.flocker_client.create_dataset(
            self.NODE_B, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name}, dataset_id=dataset_id
        )

        self._flush_volume_plugin_reactor_on_endpoint_render()

        # Pretend that it takes 5 seconds for the dataset to get established on
        # Node A.
        self.volume_plugin_reactor.callLater(5.0, self.flocker_client.synchronize_state)

        d.addCallback(
            lambda _: self.assertResult(
                b"POST",
                b"/VolumeDriver.Mount",
                {u"Name": name, u"ID": mount_id},
                OK,
                {u"Err": u"", u"Mountpoint": u"/flocker/{}".format(dataset_id)},
            )
        )
        d.addCallback(lambda _: self.flocker_client.list_datasets_state())

        def final_assertions(datasets):
            self.assertEqual([self.NODE_A], [state.primary for state in datasets if state.dataset_id == dataset_id])
            # There should be less than 20 calls to list_datasets_state over
            # the course of 5 seconds.
            self.assertLess(self.flocker_client.num_calls("list_datasets_state"), 20)

        d.addCallback(final_assertions)

        return d

    def test_mount_no_id(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with matching
        name to the current node and then waits for the dataset to
        actually arrive.

        No ID for backward compatibility with Docker < 1.12
        """
        name = u"myvol"
        dataset_id = uuid4()

        # Create dataset on a different node:
        d = self.flocker_client.create_dataset(
            self.NODE_B, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name}, dataset_id=dataset_id
        )

        self._flush_volume_plugin_reactor_on_endpoint_render()

        # Pretend that it takes 5 seconds for the dataset to get established on
        # Node A.
        self.volume_plugin_reactor.callLater(5.0, self.flocker_client.synchronize_state)

        d.addCallback(
            lambda _: self.assertResult(
                b"POST",
                b"/VolumeDriver.Mount",
                {u"Name": name},
                OK,
                {u"Err": u"", u"Mountpoint": u"/flocker/{}".format(dataset_id)},
            )
        )
        d.addCallback(lambda _: self.flocker_client.list_datasets_state())

        def final_assertions(datasets):
            self.assertEqual([self.NODE_A], [state.primary for state in datasets if state.dataset_id == dataset_id])
            # There should be less than 20 calls to list_datasets_state over
            # the course of 5 seconds.
            self.assertLess(self.flocker_client.num_calls("list_datasets_state"), 20)

        d.addCallback(final_assertions)

        return d

    def test_mount_timeout(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with matching
        name to the current node and then waits for the dataset to
        actually arrive. If it does not arrive within 120 seconds, then it
        returns an error up to docker.
        """
        name = u"myvol"
        dataset_id = uuid4()
        mount_id = self._random_docker_id()
        # Create dataset on a different node:
        d = self.flocker_client.create_dataset(
            self.NODE_B, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name}, dataset_id=dataset_id
        )

        self._flush_volume_plugin_reactor_on_endpoint_render()

        # Pretend that it takes 500 seconds for the dataset to get established
        # on Node A. This should be longer than the timeout.
        self.volume_plugin_reactor.callLater(500.0, self.flocker_client.synchronize_state)

        d.addCallback(
            lambda _: self.assertResult(
                b"POST",
                b"/VolumeDriver.Mount",
                {u"Name": name, u"ID": mount_id},
                OK,
                {u"Err": u"Timed out waiting for dataset to mount.", u"Mountpoint": u""},
            )
        )
        return d

    def test_mount_already_exists(self):
        """
        ``/VolumeDriver.Mount`` sets the primary of the dataset with matching
        name to the current node and then waits for the dataset to
        actually arrive when used by the volumes that already exist and
        don't have a special dataset ID.
        """
        name = u"myvol"
        mount_id = self._random_docker_id()

        d = self.flocker_client.create_dataset(self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name})

        def created(dataset):
            self.flocker_client.synchronize_state()
            result = self.assertResult(
                b"POST",
                b"/VolumeDriver.Mount",
                {u"Name": name, u"ID": mount_id},
                OK,
                {u"Err": u"", u"Mountpoint": u"/flocker/{}".format(dataset.dataset_id)},
            )
            result.addCallback(lambda _: self.flocker_client.list_datasets_state())
            result.addCallback(
                lambda states: self.assertEqual(
                    [self.NODE_A], [state.primary for state in states if state.dataset_id == dataset.dataset_id]
                )
            )
            return result

        d.addCallback(created)
        return d

    def test_unknown_mount(self):
        """
        ``/VolumeDriver.Mount`` returns an error when asked to mount a
        non-existent volume.
        """
        name = u"myvol"
        mount_id = self._random_docker_id()
        return self.assertResult(
            b"POST",
            b"/VolumeDriver.Mount",
            {u"Name": name, u"ID": mount_id},
            OK,
            {u"Err": u"Could not find volume with given name."},
        )

    def test_path(self):
        """
        ``/VolumeDriver.Path`` returns the mount path of the given volume if
        it is currently known.
        """
        name = u"myvol"

        d = self.create(name)
        # The dataset arrives as state:
        d.addCallback(lambda _: self.flocker_client.synchronize_state())

        d.addCallback(lambda _: self.assertResponseCode(b"POST", b"/VolumeDriver.Mount", {u"Name": name}, OK))
        d.addCallback(lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(
            lambda datasets_config: self.assertResult(
                b"POST",
                b"/VolumeDriver.Path",
                {u"Name": name},
                OK,
                {u"Err": u"", u"Mountpoint": u"/flocker/{}".format(list(datasets_config.datasets)[0])},
            )
        )
        return d

    def test_path_existing(self):
        """
        ``/VolumeDriver.Path`` returns the mount path of the given volume if
        it is currently known, including for a dataset that was created
        not by the plugin.
        """
        name = u"myvol"

        d = self.flocker_client.create_dataset(self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name})

        def created(dataset):
            self.flocker_client.synchronize_state()
            return self.assertResult(
                b"POST",
                b"/VolumeDriver.Path",
                {u"Name": name},
                OK,
                {u"Err": u"", u"Mountpoint": u"/flocker/{}".format(dataset.dataset_id)},
            )

        d.addCallback(created)
        return d

    def test_unknown_path(self):
        """
        ``/VolumeDriver.Path`` returns an error when asked for the mount path
        of a non-existent volume.
        """
        name = u"myvol"
        return self.assertResult(
            b"POST", b"/VolumeDriver.Path", {u"Name": name}, OK, {u"Err": u"Could not find volume with given name."}
        )

    def test_non_local_path(self):
        """
        ``/VolumeDriver.Path`` returns an error when asked for the mount path
        of a volume that is not mounted locally.

        This can happen as a result of ``docker inspect`` on a container
        that has been created but is still waiting for its volume to
        arrive from another node. It seems like Docker may also call this
        after ``/VolumeDriver.Create``, so again while waiting for a
        volume to arrive.
        """
        name = u"myvol"
        dataset_id = uuid4()

        # Create dataset on node B:
        d = self.flocker_client.create_dataset(
            self.NODE_B, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name}, dataset_id=dataset_id
        )
        d.addCallback(lambda _: self.flocker_client.synchronize_state())

        # Ask for path on node A:
        d.addCallback(
            lambda _: self.assertResult(
                b"POST",
                b"/VolumeDriver.Path",
                {u"Name": name},
                OK,
                {u"Err": "Volume not available.", u"Mountpoint": u""},
            )
        )
        return d

    @capture_logging(lambda self, logger: self.assertEqual(len(logger.flushTracebacks(CustomException)), 1))
    def test_unexpected_error_reporting(self, logger):
        """
        If an unexpected error occurs Docker gets back a useful error message.
        """

        def error():
            raise CustomException("I've made a terrible mistake")

        self.patch(self.flocker_client, "list_datasets_configuration", error)
        return self.assertResult(
            b"POST",
            b"/VolumeDriver.Path",
            {u"Name": u"whatever"},
            OK,
            {u"Err": "CustomException: I've made a terrible mistake"},
        )

    @capture_logging(None)
    def test_bad_request(self, logger):
        """
        If a ``BadRequest`` exception is raised it is converted to appropriate
        JSON.
        """

        def error():
            raise make_bad_request(code=423, Err=u"no good")

        self.patch(self.flocker_client, "list_datasets_configuration", error)
        return self.assertResult(b"POST", b"/VolumeDriver.Path", {u"Name": u"whatever"}, 423, {u"Err": "no good"})

    def test_unsupported_method(self):
        """
        If an unsupported method is requested the 405 Not Allowed response
        code is returned.
        """
        return self.assertResponseCode(b"BAD_METHOD", b"/VolumeDriver.Path", None, NOT_ALLOWED)

    def test_unknown_uri(self):
        """
        If an unknown URI path is requested the 404 Not Found response code is
        returned.
        """
        return self.assertResponseCode(b"BAD_METHOD", b"/xxxnotthere", None, NOT_FOUND)

    def test_empty_host(self):
        """
        If an empty host header is sent to the Docker plugin it does not blow
        up, instead operating normally. E.g. for ``Plugin.Activate`` call
        returns the ``Implements`` response.
        """
        return self.assertResult(
            b"POST",
            b"/Plugin.Activate",
            12345,
            OK,
            {u"Implements": [u"VolumeDriver"]},
            additional_headers={b"Host": [""]},
        )

    def test_get(self):
        """
        ``/VolumeDriver.Get`` returns the mount path of the given volume if
        it is currently known.
        """
        name = u"myvol"

        d = self.create(name)
        # The dataset arrives as state:
        d.addCallback(lambda _: self.flocker_client.synchronize_state())

        d.addCallback(lambda _: self.assertResponseCode(b"POST", b"/VolumeDriver.Mount", {u"Name": name}, OK))
        d.addCallback(lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(
            lambda datasets_config: self.assertResult(
                b"POST",
                b"/VolumeDriver.Get",
                {u"Name": name},
                OK,
                {
                    u"Err": u"",
                    u"Volume": {
                        u"Name": name,
                        u"Mountpoint": u"/flocker/{}".format(list(datasets_config.datasets)[0]),
                    },
                },
            )
        )
        return d

    def test_get_existing(self):
        """
        ``/VolumeDriver.Get`` returns the mount path of the given volume if
        it is currently known, including for a dataset that was created
        not by the plugin.
        """
        name = u"myvol"

        d = self.flocker_client.create_dataset(self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name})

        def created(dataset):
            self.flocker_client.synchronize_state()
            return self.assertResult(
                b"POST",
                b"/VolumeDriver.Get",
                {u"Name": name},
                OK,
                {u"Err": u"", u"Volume": {u"Name": name, u"Mountpoint": u"/flocker/{}".format(dataset.dataset_id)}},
            )

        d.addCallback(created)
        return d

    def test_unknown_get(self):
        """
        ``/VolumeDriver.Get`` returns an error when asked for the mount path
        of a non-existent volume.
        """
        name = u"myvol"
        return self.assertResult(
            b"POST", b"/VolumeDriver.Get", {u"Name": name}, OK, {u"Err": u"Could not find volume with given name."}
        )

    def test_non_local_get(self):
        """
        ``/VolumeDriver.Get`` returns an empty mount point when asked about a
        volume that is not mounted locally.
        """
        name = u"myvol"
        dataset_id = uuid4()

        # Create dataset on node B:
        d = self.flocker_client.create_dataset(
            self.NODE_B, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name}, dataset_id=dataset_id
        )
        d.addCallback(lambda _: self.flocker_client.synchronize_state())

        # Ask for path on node A:
        d.addCallback(
            lambda _: self.assertResult(
                b"POST",
                b"/VolumeDriver.Get",
                {u"Name": name},
                OK,
                {u"Err": u"", u"Volume": {u"Name": name, u"Mountpoint": u""}},
            )
        )
        return d

    def test_list(self):
        """
        ``/VolumeDriver.List`` returns the mount path of the given volume if
        it is currently known and an empty mount point for non-local
        volumes.
        """
        name = u"myvol"
        remote_name = u"myvol3"

        d = gatherResults(
            [
                self.flocker_client.create_dataset(
                    self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: name}
                ),
                self.flocker_client.create_dataset(
                    self.NODE_B, int(DEFAULT_SIZE.to_Byte()), metadata={NAME_FIELD: remote_name}
                ),
            ]
        )

        def expected_volumes(datasets_config):
            # Only the local dataset has a mount point; look its ID up by the
            # same metadata field the datasets were created with.
            local_dataset_id = [
                key
                for (key, value) in datasets_config.datasets.items()
                if value.metadata.get(NAME_FIELD) == name
            ][0]
            return sorted(
                [
                    {u"Name": name, u"Mountpoint": u"/flocker/{}".format(local_dataset_id)},
                    {u"Name": remote_name, u"Mountpoint": u""},
                ]
            )

        # The datasets arrive as state:
        d.addCallback(lambda _: self.flocker_client.synchronize_state())
        d.addCallback(lambda _: self.flocker_client.list_datasets_configuration())
        d.addCallback(
            lambda datasets_config: self.assertResult(
                b"POST",
                b"/VolumeDriver.List",
                {},
                OK,
                {u"Err": u"", u"Volumes": expected_volumes(datasets_config)},
            )
        )
        return d

    def test_list_no_metadata_name(self):
        """
        ``/VolumeDriver.List`` omits volumes that don't have a metadata field
        for their name.
        """
        d = self.flocker_client.create_dataset(self.NODE_A, int(DEFAULT_SIZE.to_Byte()), metadata={})
        d.addCallback(
            lambda _: self.assertResult(b"POST", b"/VolumeDriver.List", {}, OK, {u"Err": u"", u"Volumes": []})
        )
        return d
Exemple #37
0
class TestWindowManager(VumiTestCase):

    @inlineCallbacks
    def setUp(self):
        """
        Build a ``WindowManager`` (window size 10, flight lifetime 10) on the
        persistence helper's Redis manager, with ``get_clock`` patched to
        return a test-controlled ``Clock``, and create ``self.window_id``.
        """
        self.persistence_helper = self.add_helper(PersistenceHelper())
        redis_manager = yield self.persistence_helper.get_redis_manager()
        self.window_id = 'window_id'

        # Patch the clock lookup before constructing the manager, in case
        # the manager asks for its clock while being built.
        self.clock = Clock()
        self.patch(WindowManager, 'get_clock', lambda _: self.clock)

        self.wm = WindowManager(
            redis_manager, window_size=10, flight_lifetime=10)
        self.add_cleanup(self.wm.stop)
        yield self.wm.create_window(self.window_id)
        self.redis = self.wm.redis

    @inlineCallbacks
    def test_windows(self):
        """The window created in setUp is listed by ``get_windows``."""
        known_windows = yield self.wm.get_windows()
        self.assertTrue(self.window_id in known_windows)

    def test_strict_window_recreation(self):
        """Re-creating an existing window with ``strict=True`` fails."""
        recreation = self.wm.create_window(self.window_id, strict=True)
        return self.assertFailure(recreation, WindowException)

    @inlineCallbacks
    def test_window_recreation(self):
        """
        Non-strict re-creation of an existing window succeeds, and the value
        it yields is expected to equal the current clock reading.
        """
        expected_time = self.clock.seconds()
        returned_time = yield self.wm.create_window(self.window_id)
        self.assertEqual(returned_time, expected_time)

    @inlineCallbacks
    def test_window_removal(self):
        """
        Removing a window that still holds an item fails with
        ``WindowException``. After the item has been handed out via
        ``get_next_key``, the test expects removal to succeed with ``None``.
        """
        yield self.wm.add(self.window_id, 1)
        premature_removal = self.wm.remove_window(self.window_id)
        yield self.assertFailure(premature_removal, WindowException)

        next_key = yield self.wm.get_next_key(self.window_id)
        stored_item = yield self.wm.get_data(self.window_id, next_key)
        self.assertEqual(stored_item, 1)

        removal_result = yield self.wm.remove_window(self.window_id)
        self.assertEqual(removal_result, None)

    @inlineCallbacks
    def test_adding_to_window(self):
        """Every added item shows up in the window's Redis list."""
        item_count = 10
        for item in range(item_count):
            yield self.wm.add(self.window_id, item)
        list_key = self.wm.window_key(self.window_id)
        stored_length = yield self.redis.llen(list_key)
        self.assertEqual(stored_length, item_count)

    @inlineCallbacks
    def test_fetching_from_window(self):
        """
        ``get_next_key`` hands out at most the window size (10 in setUp)
        keys, and their data comes back in insertion order. Removing one
        key makes room for the next waiting item.
        """
        for value in range(12):
            yield self.wm.add(self.window_id, value)

        # Fill the window: each of the first ten requests yields a key.
        issued_keys = []
        while len(issued_keys) < 10:
            key = yield self.wm.get_next_key(self.window_id)
            self.assertTrue(key)
            issued_keys.append(key)

        # The window is full, so an eleventh request gets nothing.
        overflow_key = yield self.wm.get_next_key(self.window_id)
        self.assertEqual(overflow_key, None)

        # Payloads come back in the same order they were added.
        for expected_value, key in enumerate(issued_keys):
            payload = yield self.wm.get_data(self.window_id, key)
            self.assertEqual(payload, expected_value)

        # Freeing one slot lets the next waiting item into the window.
        yield self.wm.remove_key(self.window_id, issued_keys[0])
        refill_key = yield self.wm.get_next_key(self.window_id)
        self.assertTrue(refill_key)

    @inlineCallbacks
    def test_set_and_external_id(self):
        """An external id mapping can be read back in both directions."""
        yield self.wm.set_external_id(
            self.window_id, "flight_key", "external_id")

        external = yield self.wm.get_external_id(self.window_id, "flight_key")
        self.assertEqual(external, "external_id")

        internal = yield self.wm.get_internal_id(self.window_id, "external_id")
        self.assertEqual(internal, "flight_key")

    @inlineCallbacks
    def test_remove_key_removes_external_and_internal_id(self):
        """Removing a flight key also drops both directions of its id map."""
        yield self.wm.set_external_id(
            self.window_id, "flight_key", "external_id")
        yield self.wm.remove_key(self.window_id, "flight_key")

        external = yield self.wm.get_external_id(self.window_id, "flight_key")
        self.assertEqual(external, None)

        internal = yield self.wm.get_internal_id(self.window_id, "external_id")
        self.assertEqual(internal, None)

    @inlineCallbacks
    def assert_count_waiting(self, window_id, amount):
        """
        Assert that ``count_waiting`` for ``window_id`` equals ``amount``.

        Returns a Deferred (``inlineCallbacks``); callers must yield it.
        """
        waiting = yield self.wm.count_waiting(window_id)
        self.assertEqual(waiting, amount)

    @inlineCallbacks
    def assert_expired_keys(self, window_id, amount):
        """
        Assert that ``get_expired_flight_keys`` reports exactly ``amount``
        keys for ``window_id``.

        Returns a Deferred (``inlineCallbacks``). Callers must ``yield`` it;
        otherwise a failed assertion only ends up as an unhandled error in a
        discarded Deferred instead of failing the calling test.
        """
        expired_keys = yield self.wm.get_expired_flight_keys(window_id)
        self.assertEqual(len(expired_keys), amount)

    @inlineCallbacks
    def assert_in_flight(self, window_id, amount):
        """
        Assert that ``count_in_flight`` for ``window_id`` equals ``amount``.

        Returns a Deferred (``inlineCallbacks``); callers must yield it.
        """
        in_flight = yield self.wm.count_in_flight(window_id)
        self.assertEqual(in_flight, amount)

    @inlineCallbacks
    def slide_window(self, limit=10):
        """
        Request the next key from ``self.window_id`` ``limit`` times, one
        after another, discarding the returned keys.
        """
        for _attempt in range(limit):
            yield self.wm.get_next_key(self.window_id)

    @inlineCallbacks
    def test_expiry_of_acks(self):
        """
        Items kept in flight past ``flight_lifetime`` (10 in setUp) are
        reported by ``get_expired_flight_keys``. The expired count is
        expected to grow by 10 per cycle as the 30 added items pass through
        the window, leaving nothing in flight or waiting at the end.
        """

        def mock_clock_time(wm):
            # Replaces WindowManager.get_clocktime; ``wm`` is the manager
            # instance (named so it does not shadow the test's ``self``).
            return wm._clocktime

        self.patch(WindowManager, 'get_clocktime', mock_clock_time)
        self.wm._clocktime = 0

        for i in range(30):
            yield self.wm.add(self.window_id, i)

        # We're manually setting the clock instead of using clock.advance()
        # so we can wait for the deferreds to finish before continuing to the
        # next clear_expired_flight_keys run since LoopingCall() will only fire
        # again if the previous run has completed.
        #
        # The assert_* helpers are inlineCallbacks functions returning
        # Deferreds; they must be yielded or their failures never reach this
        # test.
        for clocktime, expected_expired in [(10, 10), (20, 20), (30, 30)]:
            yield self.slide_window()
            self.wm._clocktime = clocktime
            yield self.wm.clear_expired_flight_keys()
            yield self.assert_expired_keys(self.window_id, expected_expired)

        yield self.assert_in_flight(self.window_id, 0)
        yield self.assert_count_waiting(self.window_id, 0)

    @inlineCallbacks
    def test_monitor_windows(self):
        """
        ``_monitor_windows`` passes each key that enters a window to the key
        callback. It hands out no new keys until in-flight ones are removed.
        When run with cleanup enabled, it is expected to delete the emptied
        windows and report each one to the cleanup callback.
        """
        yield self.wm.remove_window(self.window_id)

        window_ids = ['window_id_1', 'window_id_2']
        for window_id in window_ids:
            yield self.wm.create_window(window_id)
            for i in range(20):
                yield self.wm.add(window_id, i)

        key_callbacks = {}

        def callback(window_id, key):
            key_callbacks.setdefault(window_id, []).append(key)

        cleanup_callbacks = []

        def cleanup_callback(window_id):
            cleanup_callbacks.append(window_id)

        def assert_keys_seen(amount):
            # Check each window by id. Indexing ``dict.values()`` positionally
            # depends on dict ordering and raises TypeError on Python 3.
            for window_id in window_ids:
                self.assertEqual(len(key_callbacks.get(window_id, [])), amount)

        @inlineCallbacks
        def remove_seen_keys():
            # Remove every key reported so far. Keys removed on an earlier
            # pass are removed again; the original test relied on that being
            # harmless.
            for window_id, keys in key_callbacks.items():
                for key in keys:
                    yield self.wm.remove_key(window_id, key)

        yield self.wm._monitor_windows(callback, False)

        self.assertEqual(set(key_callbacks.keys()), set(window_ids))
        assert_keys_seen(10)

        yield self.wm._monitor_windows(callback, False)

        # Nothing should've changed since we haven't removed anything.
        assert_keys_seen(10)

        yield remove_seen_keys()

        yield self.wm._monitor_windows(callback, False)
        # Everything should've been processed now
        assert_keys_seen(20)

        # Now run again but cleanup the empty windows
        self.assertEqual(set((yield self.wm.get_windows())), set(window_ids))
        yield remove_seen_keys()

        yield self.wm._monitor_windows(callback, True, cleanup_callback)
        assert_keys_seen(20)
        self.assertEqual((yield self.wm.get_windows()), [])
        self.assertEqual(set(cleanup_callbacks), set(window_ids))