Ejemplo n.º 1
0
 def test_tags_from_environment_and_constant(self):
     """Check tag merging when STATSD_TAGS and ``constant_tags`` share a key.

     The expected payload carries ``country:canada`` (the constructor value)
     together with ``age:45`` coming from the environment variable.
     """
     with preserve_envvars("STATSD_TAGS"):
         os.environ["STATSD_TAGS"] = "country:china,age:45"
         client = Statsd(constant_tags={"country": "canada"})

     # The client is built inside the block above; it is exercised out here.
     client._socket = FakeSocket()
     client.gauge("gt", 123.4)

     received = client.socket.recv()
     self.assertEqual("gt:123.4|g|#age:45,country:canada", received)
Ejemplo n.º 2
0
 def setUp(self):
     """Give each test a default Statsd client whose socket is a FakeSocket."""
     client = Statsd()
     client._socket = FakeSocket()
     self.statsd = client
Ejemplo n.º 3
0
    def __init__(
        self,
        storage: StorageInterface,
        origin_url: str,
        logging_class: Optional[str] = None,
        save_data_path: Optional[str] = None,
        max_content_size: Optional[int] = None,
        lister_name: Optional[str] = None,
        lister_instance_name: Optional[str] = None,
        metadata_fetcher_credentials: CredentialsType = None,
    ):
        """Validate the arguments and initialize the loader's attributes.

        Args:
          storage: kept as ``self.storage``
          origin_url: wrapped in an :class:`Origin` kept as ``self.origin``
          logging_class: logger name; when None, ``<module>.<class name>`` of
            the instance's class is used
          save_data_path: when truthy, must be stat-able, readable and writable
          max_content_size: converted with :func:`int` when truthy, otherwise
            stored as None
          lister_name: must not be the empty string; must be None iff
            ``lister_instance_name`` is
          lister_instance_name: must be None iff ``lister_name`` is
          metadata_fetcher_credentials: stored as given, or as an empty dict
            when falsy

        Raises:
          ValueError: if ``lister_name`` is empty, or if only one of
            ``lister_name`` / ``lister_instance_name`` is None
          PermissionError: if ``save_data_path`` is not readable and writable
          OSError: propagated from :func:`os.stat` on ``save_data_path``
        """
        if lister_name == "":
            raise ValueError("lister_name must not be the empty string")

        # Both lister arguments must be provided together, or not at all.
        has_lister = lister_name is not None
        has_instance = lister_instance_name is not None
        if has_instance and not has_lister:
            raise ValueError(
                f"lister_name is None but lister_instance_name is {lister_instance_name!r}"
            )
        if has_lister and not has_instance:
            raise ValueError(
                f"lister_instance_name is None but lister_name is {lister_name!r}"
            )

        self.storage = storage
        self.origin = Origin(url=origin_url)
        self.max_content_size = int(max_content_size) if max_content_size else None
        self.lister_name = lister_name
        self.lister_instance_name = lister_instance_name
        self.metadata_fetcher_credentials = metadata_fetcher_credentials or {}

        if logging_class is None:
            klass = self.__class__
            logging_class = "%s.%s" % (klass.__module__, klass.__name__)
        self.log = logging.getLogger(logging_class)

        # Raise the threshold of urllib3's connection pool logger to WARN.
        logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(
            logging.WARN
        )

        # possibly overridden in self.prepare method
        self.visit_date = datetime.datetime.now(tz=datetime.timezone.utc)

        self.loaded_snapshot_id = None

        if save_data_path:
            # os.stat raises if the path cannot be stat'ed (e.g. missing).
            os.stat(save_data_path)
            if not os.access(save_data_path, os.R_OK | os.W_OK):
                raise PermissionError("Permission denied: %r" % save_data_path)
        self.save_data_path = save_data_path

        self.parent_origins = None

        self.statsd = Statsd(
            namespace="swh_loader",
            constant_tags={"visit_type": self.visit_type},
        )
Ejemplo n.º 4
0
def statsd():
    """Yield a Statsd instance suitable for tests.

    The instance's ``_socket`` attribute is replaced with a FakeSocket before
    it is handed to the test, so the test can inspect what was sent through
    that object instead of a real socket.
    """

    from swh.core.statsd import Statsd

    client = Statsd()
    client._socket = FakeSocket()
    yield client
Ejemplo n.º 5
0
def test_tags_from_environment_with_substitution(monkeypatch):
    """Expect ``$VAR`` and ``${VAR}`` references in STATSD_TAGS to be expanded."""
    environment = {
        "HOSTNAME": "sweethome",
        "PORT": "42",
        "STATSD_TAGS": "country:china,age:45,host:$HOSTNAME,port:${PORT}",
    }
    for variable, value in environment.items():
        monkeypatch.setenv(variable, value)

    client = Statsd()
    client._socket = FakeSocket()
    client.gauge("gt", 123.4)

    received = client.socket.recv()
    assert received == "gt:123.4|g|#age:45,country:china,host:sweethome,port:42"
Ejemplo n.º 6
0
 def statsd(self):
     """Return a Statsd client tagged with this task's name and its worker.

     A client is stored on ``self._statsd`` only once ``current_app.conf``
     provides a truthy ``worker_name``. Until then every call builds a new,
     unstored client tagged with ``"unknown worker"``.
     """
     cached = self._statsd
     if cached:
         return cached

     worker_name = current_app.conf.get("worker_name")
     if not worker_name:
         # Not stored on self: a later call can still pick up the real name.
         return Statsd(
             constant_tags={"task": self.name, "worker": "unknown worker"}
         )

     self._statsd = Statsd(
         constant_tags={"task": self.name, "worker": worker_name}
     )
     return self._statsd
Ejemplo n.º 7
0
def test_envvar_port(monkeypatch):
    """Expect the port from STATSD_PORT and ``localhost`` for an empty STATSD_HOST."""
    for variable, value in (("STATSD_HOST", ""), ("STATSD_PORT", "12345")):
        monkeypatch.setenv(variable, value)

    client = Statsd()

    assert client.host == "localhost"
    assert client.port == 12345
Ejemplo n.º 8
0
def test_param_host(monkeypatch):
    """Expect an explicit ``host`` argument to win over STATSD_HOST."""
    for variable, value in (("STATSD_HOST", "test-value"), ("STATSD_PORT", "")):
        monkeypatch.setenv(variable, value)

    client = Statsd(host="actual-test-value")

    assert client.host == "actual-test-value"
    # An empty STATSD_PORT is expected to give port 8125.
    assert client.port == 8125
Ejemplo n.º 9
0
    def test_envvar_host(self):
        """Expect the host from STATSD_HOST and port 8125 for an empty STATSD_PORT."""
        overrides = {"STATSD_HOST": "test-value", "STATSD_PORT": ""}
        with preserve_envvars("STATSD_HOST", "STATSD_PORT"):
            for variable, value in overrides.items():
                os.environ[variable] = value
            client = Statsd()

        self.assertEqual(client.host, "test-value")
        self.assertEqual(client.port, 8125)
Ejemplo n.º 10
0
    def test_envvar_port(self):
        """Expect the port from STATSD_PORT and ``localhost`` for an empty STATSD_HOST."""
        overrides = {"STATSD_HOST": "", "STATSD_PORT": "12345"}
        with preserve_envvars("STATSD_HOST", "STATSD_PORT"):
            for variable, value in overrides.items():
                os.environ[variable] = value
            client = Statsd()

        self.assertEqual(client.host, "localhost")
        self.assertEqual(client.port, 12345)
Ejemplo n.º 11
0
def test_context_manager():
    """Expect metrics sent inside ``with Statsd()`` to arrive as one payload."""
    sock = FakeSocket()
    with Statsd() as client:
        client._socket = sock
        client.gauge("page.views", 123)
        client.timing("timer", 123)

    # Read only after the block exits: both metrics are expected together.
    received = sock.recv()
    assert received == "page.views:123|g\ntimer:123|ms"
Ejemplo n.º 12
0
def test_batched_buffer_autoflush():
    """Expect a flush after 50 buffered metrics, and the 51st on context exit."""
    sock = FakeSocket()
    with Statsd() as client:
        client._socket = sock
        for _ in range(51):
            client.increment("mycounter")
        first_batch = "\n".join(["mycounter:1|c"] * 50)
        assert first_batch == sock.recv()

    assert sock.recv() == "mycounter:1|c"
Ejemplo n.º 13
0
def test_tags_from_environment_warning(monkeypatch):
    """Expect one UserWarning naming only the malformed STATSD_TAGS entry."""
    monkeypatch.setenv("STATSD_TAGS", "valid:tag,invalid_tag")
    with pytest.warns(UserWarning) as caught:
        client = Statsd()

    assert len(caught) == 1
    warning_text = caught[0].message.args[0]
    assert "invalid_tag" in warning_text
    assert "valid:tag" not in warning_text
    # The well-formed entry is still expected to be kept as a constant tag.
    assert client.constant_tags == {"valid": "tag"}
Ejemplo n.º 14
0
    def test_tags_from_environment_warning(self):
        """Expect one UserWarning naming only the malformed STATSD_TAGS entry."""
        with preserve_envvars("STATSD_TAGS"):
            os.environ["STATSD_TAGS"] = "valid:tag,invalid_tag"
            with pytest.warns(UserWarning) as caught:
                client = Statsd()

        assert len(caught) == 1
        warning_text = caught[0].message.args[0]
        assert "invalid_tag" in warning_text
        assert "valid:tag" not in warning_text
        # The well-formed entry is still expected to be kept as a constant tag.
        assert client.constant_tags == {"valid": "tag"}
Ejemplo n.º 15
0
    def test_batched_buffer_autoflush(self):
        """Expect a flush after 50 buffered metrics, and the 51st on context exit."""
        sock = FakeSocket()
        with Statsd() as client:
            client._socket = sock
            for _ in range(51):
                client.increment("mycounter")
            first_batch = "\n".join(["mycounter:1|c"] * 50)
            self.assertEqual(first_batch, sock.recv())

        self.assertEqual("mycounter:1|c", sock.recv())
Ejemplo n.º 16
0
def test_tags_from_environment_and_constant(monkeypatch):
    """Check tag merging when STATSD_TAGS and ``constant_tags`` share a key.

    The expected payload carries ``country:canada`` (the constructor value)
    together with ``age:45`` coming from the environment variable.
    """
    monkeypatch.setenv("STATSD_TAGS", "country:china,age:45")

    client = Statsd(constant_tags={"country": "canada"})
    client._socket = FakeSocket()
    client.gauge("gt", 123.4)

    received = client.socket.recv()
    assert received == "gt:123.4|g|#age:45,country:canada"
Ejemplo n.º 17
0
def test_accessing_socket_multiple_times_returns_same_socket():
    """Expect ``.socket`` to hand back the socket already set on ``_socket``."""
    client = Statsd()
    injected = FakeSocket()
    client._socket = injected

    assert injected == client.socket
    # A different FakeSocket instance must not compare equal to it.
    assert FakeSocket() != client.socket
Ejemplo n.º 18
0
def test_accessing_socket_opens_socket():
    """Expect reading ``.socket`` on a fresh client to yield a non-None socket."""
    client = Statsd()
    try:
        opened = client.socket
        assert opened is not None
    finally:
        # Always release the socket, even when the assertion fails.
        client.close_socket()
Ejemplo n.º 19
0
def test_instantiating_does_not_connect():
    """Expect a freshly built client to have no socket yet."""
    assert Statsd()._socket is None
Ejemplo n.º 20
0
 def test_accessing_socket_opens_socket(self):
     """Expect reading ``.socket`` on a fresh client to yield a non-None socket."""
     client = Statsd()
     try:
         opened = client.socket
         self.assertIsNotNone(opened)
     finally:
         # Always release the socket, even when the assertion fails.
         client.close_socket()
Ejemplo n.º 21
0
def test_param_port(monkeypatch):
    """Expect an explicit ``port`` argument to win over STATSD_PORT."""
    for variable, value in (("STATSD_HOST", ""), ("STATSD_PORT", "12345")):
        monkeypatch.setenv(variable, value)

    client = Statsd(port=4321)

    assert client.host == "localhost"
    assert client.port == 4321
Ejemplo n.º 22
0
 def test_instantiating_does_not_connect(self):
     """Expect a freshly built client to have no socket yet."""
     client = Statsd()
     initial_socket = client._socket
     self.assertEqual(None, initial_socket)
Ejemplo n.º 23
0
class TestStatsd(unittest.TestCase):
    """Tests for the Statsd client, run against fake sockets instead of the network."""

    def setUp(self):
        """Set up a default Statsd instance and mock the socket."""
        self.statsd = Statsd()
        self.statsd._socket = FakeSocket()

    # Helpers

    def recv(self):
        """Return the next payload received by the fixture client's socket."""
        return self.statsd.socket.recv()

    def _recv_metric(self):
        """Pop one untagged payload and split it into ``(name, value, type)``.

        Expects the ``name:value|type`` form; all three parts are strings.
        """
        name_value, type_ = self.recv().split("|")
        name, value = name_value.split(":")
        return name, value, type_

    def assert_almost_equal(self, a, b, delta):
        """Fail unless ``a`` and ``b`` differ by at most ``delta``."""
        self.assertTrue(abs(a - b) <= delta, "%s - %s not within %s" % (a, b, delta))

    # Basic metric types

    def test_set(self):
        self.statsd.set("set", 123)
        assert self.recv() == "set:123|s"

    def test_gauge(self):
        self.statsd.gauge("gauge", 123.4)
        assert self.recv() == "gauge:123.4|g"

    def test_counter(self):
        self.statsd.increment("page.views")
        self.assertEqual("page.views:1|c", self.recv())

        self.statsd.increment("page.views", 11)
        self.assertEqual("page.views:11|c", self.recv())

        self.statsd.decrement("page.views")
        self.assertEqual("page.views:-1|c", self.recv())

        self.statsd.decrement("page.views", 12)
        self.assertEqual("page.views:-12|c", self.recv())

    def test_histogram(self):
        self.statsd.histogram("histo", 123.4)
        self.assertEqual("histo:123.4|h", self.recv())

    def test_tagged_gauge(self):
        self.statsd.gauge("gt", 123.4, tags={"country": "china", "age": 45})
        self.assertEqual("gt:123.4|g|#age:45,country:china", self.recv())

    def test_tagged_counter(self):
        self.statsd.increment("ct", tags={"country": "españa"})
        self.assertEqual("ct:1|c|#country:españa", self.recv())

    def test_tagged_histogram(self):
        self.statsd.histogram("h", 1, tags={"test_tag": "tag_value"})
        self.assertEqual("h:1|h|#test_tag:tag_value", self.recv())

    def test_sample_rate(self):
        """A zero sample rate sends nothing; 0.3 sends roughly 30% of the calls."""
        self.statsd.increment("c", sample_rate=0)
        assert not self.recv()
        for _ in range(10000):
            self.statsd.increment("sampled_counter", sample_rate=0.3)
        self.assert_almost_equal(3000, len(self.statsd.socket.payloads), 150)
        self.assertEqual("sampled_counter:1|c|@0.3", self.recv())

    def test_tags_and_samples(self):
        for _ in range(100):
            self.statsd.gauge("gst", 23, tags={"sampled": True}, sample_rate=0.9)

        self.assert_almost_equal(90, len(self.statsd.socket.payloads), 10)
        self.assertEqual("gst:23|g|@0.9|#sampled:True", self.recv())

    def test_timing(self):
        self.statsd.timing("t", 123)
        self.assertEqual("t:123|ms", self.recv())

    def test_metric_namespace(self):
        """Namespace prefixes all metric names."""
        self.statsd.namespace = "foo"
        self.statsd.gauge("gauge", 123.4)
        self.assertEqual("foo.gauge:123.4|g", self.recv())

    # Test Client level constant tags
    def test_gauge_constant_tags(self):
        self.statsd.constant_tags = {
            "bar": "baz",
        }
        self.statsd.gauge("gauge", 123.4)
        assert self.recv() == "gauge:123.4|g|#bar:baz"

    def test_counter_constant_tag_with_metric_level_tags(self):
        self.statsd.constant_tags = {
            "bar": "baz",
            "foo": True,
        }
        self.statsd.increment("page.views", tags={"extra": "extra"})
        self.assertEqual(
            "page.views:1|c|#bar:baz,extra:extra,foo:True", self.recv(),
        )

    def test_gauge_constant_tags_with_metric_level_tags_twice(self):
        metric_level_tag = {"foo": "bar"}
        self.statsd.constant_tags = {"bar": "baz"}
        self.statsd.gauge("gauge", 123.4, tags=metric_level_tag)
        assert self.recv() == "gauge:123.4|g|#bar:baz,foo:bar"

        # sending metrics multiple times with same metric-level tags
        # should not duplicate the tags being sent
        self.statsd.gauge("gauge", 123.4, tags=metric_level_tag)
        assert self.recv() == "gauge:123.4|g|#bar:baz,foo:bar"

    # Socket failures: the test passes as long as no exception propagates.

    def test_socket_error(self):
        """Sending through a BrokenSocket must not raise."""
        self.statsd._socket = BrokenSocket()
        self.statsd.gauge("no error", 1)

    def test_socket_timeout(self):
        """Sending through a SlowSocket must not raise."""
        self.statsd._socket = SlowSocket()
        self.statsd.gauge("no error", 1)

    # Timers

    def test_timed(self):
        """Measure the distribution of a function's run time."""

        @self.statsd.timed("timed.test")
        def func(a, b, c=1, d=1):
            """docstring"""
            time.sleep(0.5)
            return (a, b, c, d)

        self.assertEqual("func", func.__name__)
        self.assertEqual("docstring", func.__doc__)

        result = func(1, 2, d=3)
        # Assert it handles args and kwargs correctly.
        self.assertEqual(result, (1, 2, 1, 3))

        name, value, type_ = self._recv_metric()
        self.assertEqual("ms", type_)
        self.assertEqual("timed.test", name)
        self.assert_almost_equal(500, float(value), 100)

    def test_timed_exception(self):
        """
        Exception bubble out of the decorator and is reported
        to statsd as a dedicated counter.
        """

        @self.statsd.timed("timed.test")
        def func(a, b, c=1, d=1):
            """docstring"""
            time.sleep(0.5)
            return (a / b, c, d)

        self.assertEqual("func", func.__name__)
        self.assertEqual("docstring", func.__doc__)

        with self.assertRaises(ZeroDivisionError):
            func(1, 0)

        name, value, type_ = self._recv_metric()
        self.assertEqual("c", type_)
        self.assertEqual("timed.test_error_count", name)
        self.assertEqual(int(value), 1)

    def test_timed_no_metric(self):
        """Test using a decorator without providing a metric."""

        @self.statsd.timed()
        def func(a, b, c=1, d=1):
            """docstring"""
            time.sleep(0.5)
            return (a, b, c, d)

        self.assertEqual("func", func.__name__)
        self.assertEqual("docstring", func.__doc__)

        result = func(1, 2, d=3)
        # Assert it handles args and kwargs correctly.
        self.assertEqual(result, (1, 2, 1, 3))

        name, value, type_ = self._recv_metric()
        self.assertEqual("ms", type_)
        self.assertEqual("swh.core.tests.test_statsd.func", name)
        self.assert_almost_equal(500, float(value), 100)

    def test_timed_coroutine(self):
        """Measure the distribution of a coroutine function's run time."""
        import asyncio

        # ``asyncio.coroutine`` was removed in Python 3.11; a native coroutine
        # function is the supported way to define the decorated coroutine.
        @self.statsd.timed("timed.test")
        async def print_foo():
            """docstring"""
            await asyncio.sleep(0.5)
            print("foo")

        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(print_foo())
        finally:
            loop.close()

        name, value, type_ = self._recv_metric()
        self.assertEqual("ms", type_)
        self.assertEqual("timed.test", name)
        self.assert_almost_equal(500, float(value), 100)

    def test_timed_context(self):
        """Measure the distribution of a context's run time."""
        # In milliseconds
        with self.statsd.timed("timed_context.test") as timer:
            self.assertIsInstance(timer, TimedContextManagerDecorator)
            time.sleep(0.5)

        name, value, type_ = self._recv_metric()
        self.assertEqual("ms", type_)
        self.assertEqual("timed_context.test", name)
        self.assert_almost_equal(500, float(value), 100)
        self.assert_almost_equal(500, timer.elapsed, 100)

    def test_timed_context_exception(self):
        """
        Exception bubbles out of the `timed` context manager and is
        reported to statsd as a dedicated counter.
        """

        class ContextException(Exception):
            pass

        # Ensure the exception was raised.
        with self.assertRaises(ContextException):
            with self.statsd.timed("timed_context.test"):
                time.sleep(0.5)
                raise ContextException()

        # Ensure the error counter was recorded.
        name, value, type_ = self._recv_metric()
        self.assertEqual("c", type_)
        self.assertEqual("timed_context.test_error_count", name)
        self.assertEqual(int(value), 1)

    def test_timed_context_no_metric_name_exception(self):
        """Test that an exception occurs if using a context manager without a
        metric name.
        """

        # Ensure the exception was raised.
        with self.assertRaises(TypeError):
            with self.statsd.timed():
                time.sleep(0.5)

        # Ensure nothing was sent.
        self.assertIsNone(self.recv())

    def test_timed_start_stop_calls(self):
        timer = self.statsd.timed("timed_context.test")
        timer.start()
        time.sleep(0.5)
        timer.stop()

        name, value, type_ = self._recv_metric()
        self.assertEqual("ms", type_)
        self.assertEqual("timed_context.test", name)
        self.assert_almost_equal(500, float(value), 100)

    # Buffering

    def test_batched(self):
        self.statsd.open_buffer()
        self.statsd.gauge("page.views", 123)
        self.statsd.timing("timer", 123)
        self.statsd.close_buffer()

        self.assertEqual("page.views:123|g\ntimer:123|ms", self.recv())

    def test_context_manager(self):
        fake_socket = FakeSocket()
        with Statsd() as statsd:
            statsd._socket = fake_socket
            statsd.gauge("page.views", 123)
            statsd.timing("timer", 123)

        self.assertEqual("page.views:123|g\ntimer:123|ms", fake_socket.recv())

    def test_batched_buffer_autoflush(self):
        """A flush is expected after 50 buffered metrics; the 51st goes out on exit."""
        fake_socket = FakeSocket()
        with Statsd() as statsd:
            statsd._socket = fake_socket
            for _ in range(51):
                statsd.increment("mycounter")
            self.assertEqual(
                "\n".join(["mycounter:1|c"] * 50), fake_socket.recv(),
            )

        self.assertEqual("mycounter:1|c", fake_socket.recv())

    # Instantiation and socket handling

    def test_module_level_instance(self):
        from swh.core.statsd import statsd

        self.assertIsInstance(statsd, Statsd)

    def test_instantiating_does_not_connect(self):
        local_statsd = Statsd()
        self.assertIsNone(local_statsd._socket)

    def test_accessing_socket_opens_socket(self):
        local_statsd = Statsd()
        try:
            self.assertIsNotNone(local_statsd.socket)
        finally:
            local_statsd.close_socket()

    def test_accessing_socket_multiple_times_returns_same_socket(self):
        local_statsd = Statsd()
        fresh_socket = FakeSocket()
        local_statsd._socket = fresh_socket
        self.assertEqual(fresh_socket, local_statsd.socket)
        self.assertNotEqual(FakeSocket(), local_statsd.socket)

    # Configuration through environment variables

    def test_tags_from_environment(self):
        with preserve_envvars("STATSD_TAGS"):
            os.environ["STATSD_TAGS"] = "country:china,age:45"
            statsd = Statsd()

        statsd._socket = FakeSocket()
        statsd.gauge("gt", 123.4)
        self.assertEqual("gt:123.4|g|#age:45,country:china", statsd.socket.recv())

    def test_tags_from_environment_and_constant(self):
        """``country:canada`` from ``constant_tags`` is expected in the payload."""
        with preserve_envvars("STATSD_TAGS"):
            os.environ["STATSD_TAGS"] = "country:china,age:45"
            statsd = Statsd(constant_tags={"country": "canada"})
        statsd._socket = FakeSocket()
        statsd.gauge("gt", 123.4)
        self.assertEqual("gt:123.4|g|#age:45,country:canada", statsd.socket.recv())

    def test_tags_from_environment_warning(self):
        """A malformed STATSD_TAGS entry is expected to warn and be dropped."""
        with preserve_envvars("STATSD_TAGS"):
            os.environ["STATSD_TAGS"] = "valid:tag,invalid_tag"
            with pytest.warns(UserWarning) as record:
                statsd = Statsd()

        assert len(record) == 1
        assert "invalid_tag" in record[0].message.args[0]
        assert "valid:tag" not in record[0].message.args[0]
        assert statsd.constant_tags == {"valid": "tag"}

    # None values are expected not to be sent at all

    def test_gauge_doesnt_send_none(self):
        self.statsd.gauge("metric", None)
        assert self.recv() is None

    def test_increment_doesnt_send_none(self):
        self.statsd.increment("metric", None)
        assert self.recv() is None

    def test_decrement_doesnt_send_none(self):
        self.statsd.decrement("metric", None)
        assert self.recv() is None

    def test_timing_doesnt_send_none(self):
        self.statsd.timing("metric", None)
        assert self.recv() is None

    def test_histogram_doesnt_send_none(self):
        self.statsd.histogram("metric", None)
        assert self.recv() is None

    # Host and port resolution

    def test_param_host(self):
        with preserve_envvars("STATSD_HOST", "STATSD_PORT"):
            os.environ["STATSD_HOST"] = "test-value"
            os.environ["STATSD_PORT"] = ""
            local_statsd = Statsd(host="actual-test-value")

        self.assertEqual(local_statsd.host, "actual-test-value")
        self.assertEqual(local_statsd.port, 8125)

    def test_param_port(self):
        with preserve_envvars("STATSD_HOST", "STATSD_PORT"):
            os.environ["STATSD_HOST"] = ""
            os.environ["STATSD_PORT"] = "12345"
            local_statsd = Statsd(port=4321)

        self.assertEqual(local_statsd.host, "localhost")
        self.assertEqual(local_statsd.port, 4321)

    def test_envvar_host(self):
        with preserve_envvars("STATSD_HOST", "STATSD_PORT"):
            os.environ["STATSD_HOST"] = "test-value"
            os.environ["STATSD_PORT"] = ""
            local_statsd = Statsd()

        self.assertEqual(local_statsd.host, "test-value")
        self.assertEqual(local_statsd.port, 8125)

    def test_envvar_port(self):
        with preserve_envvars("STATSD_HOST", "STATSD_PORT"):
            os.environ["STATSD_HOST"] = ""
            os.environ["STATSD_PORT"] = "12345"
            local_statsd = Statsd()

        self.assertEqual(local_statsd.host, "localhost")
        self.assertEqual(local_statsd.port, 12345)

    def test_namespace_added(self):
        local_statsd = Statsd(namespace="test-namespace")
        local_statsd._socket = FakeSocket()

        local_statsd.gauge("gauge", 123.4)
        assert local_statsd.socket.recv() == "test-namespace.gauge:123.4|g"

    # The client used as a context manager

    def test_contextmanager_empty(self):
        """Entering and leaving the client's context without sending must not raise."""
        with self.statsd:
            pass

    def test_contextmanager_buffering(self):
        """Nothing is expected on the socket until the context exits."""
        with self.statsd as s:
            s.gauge("gauge", 123.4)
            s.gauge("gauge_other", 456.78)
            self.assertIsNone(s.socket.recv())

        self.assertEqual(self.recv(), "gauge:123.4|g\ngauge_other:456.78|g")

    def test_timed_elapsed(self):
        with self.statsd.timed("test_timer") as t:
            pass

        self.assertGreaterEqual(t.elapsed, 0)
        self.assertEqual(self.recv(), "test_timer:%s|ms" % t.elapsed)
Ejemplo n.º 24
0
def test_envvar_host(monkeypatch):
    """Expect the host from STATSD_HOST and port 8125 for an empty STATSD_PORT."""
    for variable, value in (("STATSD_HOST", "test-value"), ("STATSD_PORT", "")):
        monkeypatch.setenv(variable, value)

    client = Statsd()

    assert client.host == "test-value"
    assert client.port == 8125
Ejemplo n.º 25
0
 def test_accessing_socket_multiple_times_returns_same_socket(self):
     """Expect ``.socket`` to hand back the socket already set on ``_socket``."""
     client = Statsd()
     injected = FakeSocket()
     client._socket = injected

     self.assertEqual(injected, client.socket)
     # A different FakeSocket instance must not compare equal to it.
     self.assertNotEqual(FakeSocket(), client.socket)
Ejemplo n.º 26
0
def test_namespace_added():
    """Expect the constructor's ``namespace`` to prefix the metric name."""
    client = Statsd(namespace="test-namespace")
    client._socket = FakeSocket()
    client.gauge("gauge", 123.4)

    received = client.socket.recv()
    assert received == "test-namespace.gauge:123.4|g"
Ejemplo n.º 27
0
class BaseLoader:
    """Base class for (D)VCS loaders (e.g. Svn, Git, Mercurial, ...) or PackageLoader
    (e.g. PyPI, Npm, CRAN, ...)

    A loader retrieves origin information (git/mercurial/svn repositories, pypi/npm/...
    package artifacts), ingests the contents/directories/revisions/releases/snapshot
    read from those artifacts and sends them to the archive through the storage backend.

    The main entry point for the loader is the :func:`load` function.

    Two class methods (:func:`from_config`, :func:`from_configfile`) centralize and
    ease the loader instantiation from either a configuration dict or a configuration
    file.

    Some class examples:

    - :class:`SvnLoader`
    - :class:`GitLoader`
    - :class:`PyPILoader`
    - :class:`NpmLoader`

    Args:
      lister_name: Name of the lister which triggered this load.
        If provided, the loader will try to use the forge's API to retrieve extrinsic
        metadata
      lister_instance_name: Name of the lister instance which triggered this load.
        Must be None iff lister_name is, but it may be the empty string for listers
        with a single instance.
    """

    # Annotation only: no value is assigned here, and __init__ reads
    # ``self.visit_type`` without setting it.
    visit_type: str
    origin: Origin
    loaded_snapshot_id: Optional[Sha1Git]

    parent_origins: Optional[List[Origin]]
    """If the given origin is a "forge fork" (ie. created with the "Fork" button
    of GitHub-like forges), :meth:`build_extrinsic_origin_metadata` sets this to
    a list of origins it was forked from; closest parent first."""
    def __init__(
        self,
        storage: StorageInterface,
        origin_url: str,
        logging_class: Optional[str] = None,
        save_data_path: Optional[str] = None,
        max_content_size: Optional[int] = None,
        lister_name: Optional[str] = None,
        lister_instance_name: Optional[str] = None,
        metadata_fetcher_credentials: CredentialsType = None,
    ):
        """Validate the lister arguments and initialise the loader state.

        Args:
          storage: storage backend, kept as ``self.storage``
          origin_url: URL wrapped into the :class:`Origin` being loaded
          logging_class: logger name; defaults to ``<module>.<class name>`` of
            the instance's class
          save_data_path: when truthy, the path is stat'ed and must be both
            readable and writable
          max_content_size: stored as an ``int``; any falsy value (``None``,
            ``0``) is stored as ``None``
          lister_name: must not be the empty string, and must be ``None``
            exactly when ``lister_instance_name`` is ``None``
          lister_instance_name: see ``lister_name``
          metadata_fetcher_credentials: kept as given, or ``{}`` when falsy

        Raises:
          ValueError: when the two lister arguments are inconsistent
          PermissionError: when ``save_data_path`` is not readable and writable
          OSError: from :func:`os.stat` when ``save_data_path`` cannot be stat'ed
        """
        if lister_name == "":
            raise ValueError("lister_name must not be the empty string")
        if lister_name is None:
            if lister_instance_name is not None:
                raise ValueError(
                    f"lister_name is None but lister_instance_name is {lister_instance_name!r}"
                )
        elif lister_instance_name is None:
            raise ValueError(
                f"lister_instance_name is None but lister_name is {lister_name!r}"
            )

        self.storage = storage
        self.origin = Origin(url=origin_url)
        self.max_content_size = int(max_content_size) if max_content_size else None
        self.lister_name = lister_name
        self.lister_instance_name = lister_instance_name
        self.metadata_fetcher_credentials = metadata_fetcher_credentials or {}

        if logging_class is None:
            own_class = self.__class__
            logging_class = f"{own_class.__module__}.{own_class.__name__}"
        self.log = logging.getLogger(logging_class)

        # Side effect on the global logging configuration: raise the threshold
        # of urllib3's connection-pool logger to WARN.
        logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(
            logging.WARN)

        # possibly overridden in self.prepare method
        self.visit_date = datetime.datetime.now(tz=datetime.timezone.utc)

        self.loaded_snapshot_id = None

        if save_data_path:
            # os.stat raises if the path cannot be stat'ed; its result is unused.
            os.stat(save_data_path)
            can_read_write = os.access(save_data_path, os.R_OK | os.W_OK)
            if not can_read_write:
                raise PermissionError("Permission denied: %r" % save_data_path)

        self.save_data_path = save_data_path

        self.parent_origins = None

        self.statsd = Statsd(namespace="swh_loader",
                             constant_tags={"visit_type": self.visit_type})

    @classmethod
    def from_config(cls, storage: Dict[str, Any], **config: Any):
        """Build a loader out of a configuration dict.

        This is basically a backwards-compatibility shim for the CLI.

        Args:
          storage: instantiation config for the storage, expanded as keyword
            arguments of ``get_storage``
          config: the configuration dict for the loader, with the following keys:
            - credentials (optional): credentials list for the scheduler
            - any other kwargs passed to the loader.
            The legacy ``storage`` and ``celery`` keys are dropped if present.

        Returns:
          the instantiated loader
        """
        # Legacy keys which aren't used for this generation of loader.
        config.pop("storage", None)
        config.pop("celery", None)

        return cls(storage=get_storage(**storage), **config)

    @classmethod
    def from_configfile(cls, **kwargs: Any):
        """Build a loader from the configuration loaded from the
        SWH_CONFIG_FILENAME envvar, overridden by the keyword arguments whose
        value is not None.

        Args:
            kwargs: kwargs passed to the loader instantiation; entries set to
              None are ignored so they do not mask the loaded configuration

        """
        merged = dict(load_from_envvar(DEFAULT_CONFIG))
        for key, value in kwargs.items():
            if value is not None:
                merged[key] = value
        return cls.from_config(**merged)

    def save_data(self) -> None:
        """Save the data associated to the current load.

        Not implemented in this class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def get_save_data_path(self) -> str:
        """The path to which we archive the loader's raw data"""
        if not hasattr(self, "__save_data_path"):
            year = str(self.visit_date.year)

            assert self.origin
            url = self.origin.url.encode("utf-8")
            origin_url_hash = hashlib.sha1(url).hexdigest()

            path = "%s/sha1:%s/%s/%s" % (
                self.save_data_path,
                origin_url_hash[0:2],
                origin_url_hash,
                year,
            )

            os.makedirs(path, exist_ok=True)
            self.__save_data_path = path

        return self.__save_data_path

    def flush(self) -> Dict[str, int]:
        """Flush any potential buffered data not sent to swh-storage.
        Returns the same value as :meth:`swh.storage.interface.StorageInterface.flush`.
        """
        return self.storage.flush()

    def cleanup(self) -> None:
        """Last step executed by the loader.

        :meth:`load` calls it from its ``finally`` clause, after :meth:`flush`.
        Not implemented in this class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def _store_origin_visit(self) -> None:
        """Register the origin and a new visit of it in the storage.

        The first visit returned by ``storage.origin_visit_add`` is kept as
        ``self.visit``.
        """
        assert self.origin
        self.storage.origin_add([self.origin])

        assert isinstance(self.visit_type, str)
        new_visit = OriginVisit(
            origin=self.origin.url,
            date=self.visit_date,
            type=self.visit_type,
        )
        stored_visits = self.storage.origin_visit_add([new_visit])
        self.visit = list(stored_visits)[0]

    def prepare(self) -> None:
        """Second step executed by the loader to prepare some state needed by
           the loader.

        :meth:`load` calls it before the fetch/process/store loop.
        Not implemented in this class.

        Raises
           NotFound exception if the origin to ingest is not found.
           NotImplementedError: always, in this class.

        """
        raise NotImplementedError

    def get_origin(self) -> Origin:
        """Get the origin that is currently being loaded.
        self.origin should be set in :func:`prepare_origin`

        Returns:
          dict: an origin ready to be sent to storage by
          :func:`origin_add`.
        """
        assert self.origin
        return self.origin

    def fetch_data(self) -> bool:
        """Fetch the data from the source the loader is currently loading
           (ex: git/hg/svn/... repository).

        Not implemented in this class.

        Returns:
            a value that is interpreted as a boolean. If True, fetch_data needs
            to be called again to complete loading.

        Raises:
            NotImplementedError: always, in this class.

        """
        raise NotImplementedError

    def process_data(self) -> bool:
        """Run any additional processing between fetching and storing the data

        Returns:
            a value that is interpreted as a boolean. :meth:`load` combines it
            with the result of :meth:`fetch_data`
            (``self.process_data() and more_data_to_fetch``): another
            fetch/process/store iteration happens only when both are truthy,
            so a true value here is ignored if ``fetch_data`` already returned
            :const:`False`. This default implementation returns
            :const:`True`, which leaves the decision to ``fetch_data``.
        """
        return True

    def store_data(self):
        """Store fetched data in the database.

        Should call the :func:`maybe_load_xyz` methods, which handle the
        bundles sent to storage, rather than send directly.

        :meth:`load` calls it once per loop iteration, after
        :meth:`process_data`. Not implemented in this class.

        Raises:
            NotImplementedError: always, in this class.
        """
        raise NotImplementedError

    def load_status(self) -> Dict[str, str]:
        """Detailed loading status.

        Defaults to reporting an eventful load. :meth:`load` returns this value
        when the load completed without raising.

        Returns: a dictionary that is eventually passed back as the task's
          result to the scheduler, allowing tuning of the task recurrence
          mechanism. This implementation returns ``{"status": "eventful"}``.
        """
        return {
            "status": "eventful",
        }

    def post_load(self, success: bool = True) -> None:
        """Permit the loader to do some additional actions according to status
        after the loading is done. The flag success indicates the
        loading's status.

        Defaults to doing nothing.

        It is up to the implementer of this method to make sure this
        does not break.

        :meth:`load` calls it without argument once the visit status has been
        recorded on the success path, and with ``success=False`` from its
        failure handler.

        Args:
            success (bool): the success status of the loading

        """
        pass

    def visit_status(self) -> str:
        """Detailed visit status.

        Defaults to reporting a full visit. :meth:`load` uses the returned
        value as the ``status`` of the visit status it records when the load
        completed without raising.

        Returns:
            ``"full"`` in this implementation.
        """
        return "full"

    def pre_cleanup(self) -> None:
        """As a first step, will try and check for dangling data to cleanup.
        This should do its best to avoid raising issues.

        Defaults to doing nothing. :meth:`load` calls it first; an
        ``Exception`` raised here is logged and reported there, and the load
        continues.

        """
        pass

    def load(self) -> Dict[str, str]:
        r"""Loading logic for the loader to follow:

        - Call :meth:`pre_cleanup` (a failure there is logged and ignored)
        - Store the origin and the actual ``origin_visit`` to storage
        - Build and store the extrinsic origin metadata (a failure there is
          logged and does not abort the load)
        - Call :meth:`prepare` to prepare any eventual state

        - while True:

          - Call :meth:`fetch_data` to fetch the data to store
          - Call :meth:`process_data` to optionally run processing between
            :meth:`fetch_data` and :meth:`store_data`
          - Call :meth:`store_data` to store the data
          - Stop as soon as either :meth:`fetch_data` or :meth:`process_data`
            returned a false value

        - Record the visit status, then call :meth:`post_load`
        - In all cases, call :meth:`flush` then :meth:`cleanup` to clean up
          any eventual state put in place in :meth:`prepare` method.

        Returns:
            the result of :meth:`load_status` when nothing was raised;
            ``{"status": "uneventful"}`` when ``NotFound`` was raised;
            ``{"status": "failed"}`` for any other ``Exception``.

        Raises:
            BaseException: exceptions which do not derive from ``Exception``
              (most likely ``SystemExit`` or ``KeyboardInterrupt``) are
              re-raised once the failed visit status has been recorded.

        """
        try:
            with self.statsd_timed("pre_cleanup"):
                self.pre_cleanup()
        except Exception:
            # Best effort only: report the failure and carry on with the load.
            msg = "Cleaning up dangling data failed! Continue loading."
            self.log.warning(msg)
            sentry_sdk.capture_exception()

        self._store_origin_visit()

        assert (
            self.visit.visit
        ), "The method `_store_origin_visit` should set the visit (OriginVisit)"
        self.log.info("Load origin '%s' with type '%s'", self.origin.url,
                      self.visit.type)

        try:
            with self.statsd_timed("build_extrinsic_origin_metadata"):
                metadata = self.build_extrinsic_origin_metadata()
            self.load_metadata_objects(metadata)
        except Exception as e:
            sentry_sdk.capture_exception(e)
            # Do not fail the whole task if this is the only failure
            self.log.exception(
                "Failure while loading extrinsic origin metadata.",
                extra={
                    "swh_task_args": [],
                    "swh_task_kwargs": {
                        "origin": self.origin.url,
                        "lister_name": self.lister_name,
                        "lister_instance_name": self.lister_instance_name,
                    },
                },
            )

        # Time spent (in seconds) in each step, summed over all loop iterations.
        total_time_fetch_data = 0.0
        total_time_process_data = 0.0
        total_time_store_data = 0.0

        try:
            # Initially not a success, will be True when actually one
            success = False
            with self.statsd_timed("prepare"):
                self.prepare()

            while True:
                t1 = time.monotonic()
                more_data_to_fetch = self.fetch_data()
                t2 = time.monotonic()
                total_time_fetch_data += t2 - t1

                # Loop again only if both fetch_data and process_data ask for it.
                more_data_to_fetch = self.process_data() and more_data_to_fetch
                t3 = time.monotonic()
                total_time_process_data += t3 - t2

                self.store_data()
                t4 = time.monotonic()
                total_time_store_data += t4 - t3
                if not more_data_to_fetch:
                    break

            # Totals are converted from seconds to milliseconds; they are only
            # sent when the loop completed without raising.
            self.statsd_timing("fetch_data", total_time_fetch_data * 1000.0)
            self.statsd_timing("process_data",
                               total_time_process_data * 1000.0)
            self.statsd_timing("store_data", total_time_store_data * 1000.0)

            status = self.visit_status()
            visit_status = OriginVisitStatus(
                origin=self.origin.url,
                visit=self.visit.visit,
                type=self.visit_type,
                date=now(),
                status=status,
                snapshot=self.loaded_snapshot_id,
            )
            self.storage.origin_visit_status_add([visit_status])
            success = True
            with self.statsd_timed("post_load",
                                   tags={
                                       "success": success,
                                       "status": status
                                   }):
                self.post_load()
        except BaseException as e:
            # BaseException is caught so that a visit status is also recorded
            # when the load is interrupted (SystemExit, KeyboardInterrupt).
            success = False
            if isinstance(e, NotFound):
                status = "not_found"
                task_status = "uneventful"
            else:
                # A snapshot was already loaded: the visit is only partial.
                status = "partial" if self.loaded_snapshot_id else "failed"
                task_status = "failed"

            self.log.exception(
                "Loading failure, updating to `%s` status",
                status,
                extra={
                    "swh_task_args": [],
                    "swh_task_kwargs": {
                        "origin": self.origin.url,
                        "lister_name": self.lister_name,
                        "lister_instance_name": self.lister_instance_name,
                    },
                },
            )
            if not isinstance(e, (SystemExit, KeyboardInterrupt)):
                sentry_sdk.capture_exception()
            visit_status = OriginVisitStatus(
                origin=self.origin.url,
                visit=self.visit.visit,
                type=self.visit_type,
                date=now(),
                status=status,
                snapshot=self.loaded_snapshot_id,
            )
            self.storage.origin_visit_status_add([visit_status])
            with self.statsd_timed("post_load",
                                   tags={
                                       "success": success,
                                       "status": status
                                   }):
                self.post_load(success=success)
            if not isinstance(e, Exception):
                # e derives from BaseException but not Exception; this is most likely
                # SystemExit or KeyboardInterrupt, so we should re-raise it.
                raise
            return {"status": task_status}
        finally:
            # Runs on every path above; ``success`` and ``status`` have been
            # assigned by the try body or by the except clause at this point.
            with self.statsd_timed("flush",
                                   tags={
                                       "success": success,
                                       "status": status
                                   }):
                self.flush()
            with self.statsd_timed("cleanup",
                                   tags={
                                       "success": success,
                                       "status": status
                                   }):
                self.cleanup()

        return self.load_status()

    def load_metadata_objects(
            self, metadata_objects: List[RawExtrinsicMetadata]) -> None:
        if not metadata_objects:
            return

        authorities = {mo.authority for mo in metadata_objects}
        self.storage.metadata_authority_add(list(authorities))

        fetchers = {mo.fetcher for mo in metadata_objects}
        self.storage.metadata_fetcher_add(list(fetchers))

        self.storage.raw_extrinsic_metadata_add(metadata_objects)

    def build_extrinsic_origin_metadata(self) -> List[RawExtrinsicMetadata]:
        """Builds a list of full RawExtrinsicMetadata objects, using
        a metadata fetcher returned by :func:`get_fetcher_classes`."""
        if self.lister_name is None:
            self.log.debug(
                "lister_not provided, skipping extrinsic origin metadata")
            return []

        assert (self.lister_instance_name is not None
                ), "lister_instance_name is None, but lister_name is not"

        metadata = []

        fetcher_classes = get_fetchers_for_lister(self.lister_name)

        self.statsd_average("metadata_fetchers", len(fetcher_classes))

        for cls in fetcher_classes:
            metadata_fetcher = cls(
                origin=self.origin,
                lister_name=self.lister_name,
                lister_instance_name=self.lister_instance_name,
                credentials=self.metadata_fetcher_credentials,
            )
            with self.statsd_timed("fetch_one_metadata",
                                   tags={"fetcher": cls.FETCHER_NAME}):
                metadata.extend(metadata_fetcher.get_origin_metadata())
            if self.parent_origins is None:
                self.parent_origins = metadata_fetcher.get_parent_origins()
                self.statsd_average(
                    "metadata_parent_origins",
                    len(self.parent_origins),
                    tags={"fetcher": cls.FETCHER_NAME},
                )
        self.statsd_average("metadata_objects", len(metadata))

        return metadata

    def statsd_timed(self,
                     name: str,
                     tags: Dict[str, Any] = {}) -> ContextManager:
        """
        Wrapper for :meth:`swh.core.statsd.Statsd.timed`, which uses the standard
        metric name and tags for loaders.
        """
        return self.statsd.timed("operation_duration_seconds",
                                 tags={
                                     "operation": name,
                                     **tags
                                 })

    def statsd_timing(self,
                      name: str,
                      value: float,
                      tags: Dict[str, Any] = {}) -> None:
        """
        Wrapper for :meth:`swh.core.statsd.Statsd.timing`, which uses the standard
        metric name and tags for loaders.
        """
        self.statsd.timing("operation_duration_seconds",
                           value,
                           tags={
                               "operation": name,
                               **tags
                           })

    def statsd_average(self,
                       name: str,
                       value: Union[int, float],
                       tags: Dict[str, Any] = {}) -> None:
        """Increments both ``{name}_sum`` (by the ``value``) and ``{name}_count``
        (by ``1``), allowing to prometheus to compute the average ``value`` over
        time."""
        self.statsd.increment(f"{name}_sum", value, tags=tags)
        self.statsd.increment(f"{name}_count", tags=tags)