Exemplo n.º 1
0
async def test_bind_address_ipv6(async_finalizer, client):
    """Verify that the server answers API calls on the IPv6 loopback address when bind-address is ::1."""

    @protocol.method(path="/test", operation="POST", client_types=["api"])
    async def test_endpoint():
        pass

    class TestSlice(ServerSlice):
        @protocol.handle(test_endpoint)
        async def test_endpoint_handle(self):
            return 200

    # Let the OS pick a port that is free on all IPv6 interfaces, then release it again
    probe = netutil.bind_sockets(0, "::", family=socket.AF_INET6)[0]
    free_port = probe.getsockname()[1]
    probe.close()

    # Point both the server and the client transport at ::1 on that port
    Config.load_config()
    settings = [
        ("server", "bind-port", str(free_port)),
        ("server", "bind-address", "::1"),
        ("client_rest_transport", "port", str(free_port)),
        ("client_rest_transport", "host", "::1"),
    ]
    for section, option, value in settings:
        Config.set(section, option, value)

    rest_server = Server()
    rest_server.add_slice(TestSlice("test"))
    await rest_server.start()
    async_finalizer(rest_server.stop)

    # The endpoint must be reachable on the loopback interface
    response = await client.test_endpoint()
    assert response.code == 200
Exemplo n.º 2
0
 async def prestart(self, server: protocol.Server) -> None:
     """Run the base prestart, then look up the autostarted agent manager and resource service slices."""
     await super().prestart(server)
     autostarted_slice = server.get_slice(SLICE_AUTOSTARTED_AGENT_MANAGER)
     self.autostarted_agent_manager = cast(AutostartedAgentManager, autostarted_slice)
     resource_slice = server.get_slice(SLICE_RESOURCE)
     self.resource_service = cast(ResourceService, resource_slice)
Exemplo n.º 3
0
    async def assert_port_bound():
        """Start a server with the test slice, check that the endpoint answers 200, then stop the server."""
        rest_server = Server()
        rest_server.add_slice(TestSlice("test"))
        await rest_server.start()
        # NOTE(review): the server is also stopped explicitly at the end, so on success stop() runs twice;
        # presumably Server.stop is safe to call repeatedly — confirm.
        async_finalizer(rest_server.stop)

        # The endpoint must be reachable on the loopback interface
        response = await client.test_endpoint()
        assert response.code == 200
        await rest_server.stop()
Exemplo n.º 4
0
 async def prestart(self, server: protocol.Server) -> None:
     """Run the base prestart, then resolve and store typed references to the collaborating slices."""
     await super().prestart(server)
     server_slice = server.get_slice(SLICE_SERVER)
     self.server_slice = cast(Server, server_slice)
     agent_manager_slice = server.get_slice(SLICE_AGENT_MANAGER)
     self.agent_manager = cast(AgentManager, agent_manager_slice)
     autostarted_slice = server.get_slice(SLICE_AUTOSTARTED_AGENT_MANAGER)
     self.autostarted_agent_manager = cast(AutostartedAgentManager, autostarted_slice)
     orchestration_slice = server.get_slice(SLICE_ORCHESTRATION)
     self.orchestration_service = cast(OrchestrationService, orchestration_slice)
     resource_slice = server.get_slice(SLICE_RESOURCE)
     self.resource_service = cast(ResourceService, resource_slice)
Exemplo n.º 5
0
async def compilerservice(server_config, init_dataclasses_and_load_schema):
    """Fixture yielding a started CompilerService, with a NotificationService started alongside it.

    Both services are prestarted against the same Server. On teardown the notification service is
    stopped first (prestop, then stop), followed by the compiler service.
    """
    rest_server = Server()
    compiler = CompilerService()
    await compiler.prestart(rest_server)
    await compiler.start()
    # The compiler is registered on the server before the notification service's prestart runs.
    # NOTE(review): presumably that prestart looks the compiler up via server.get_slice — confirm.
    rest_server.add_slice(compiler)
    notifier = NotificationService()
    await notifier.prestart(rest_server)
    await notifier.start()

    yield compiler

    for service in (notifier, compiler):
        await service.prestop()
        await service.stop()
Exemplo n.º 6
0
async def test_bind_address_ipv4(async_finalizer, client):
    """This test case checks that the Inmanta server doesn't bind on any interface other than 127.0.0.1 when
    bind-address is set to 127.0.0.1. Procedure:
        1) Get a port that is free on all interfaces.
        2) Bind that port on a non-loopback interface, so it's not available for the Inmanta server anymore.
        3) Start the Inmanta server with bind-address 127.0.0.1 and execute an API call.
    """
    import ipaddress

    @protocol.method(path="/test", operation="POST", client_types=["api"])
    async def test_endpoint():
        pass

    class TestSlice(ServerSlice):
        @protocol.handle(test_endpoint)
        async def test_endpoint_handle(self):
            return 200

    # Select a bind address which is not a loopback address. Filter on the address itself rather than on the
    # interface name: the loopback interface is only called "lo" on some platforms (e.g. "lo0" on macOS/BSD).
    candidate_addrs = {}
    for iface in netifaces.interfaces():
        ipv4_addrs = [
            entry["addr"]
            for entry in netifaces.ifaddresses(iface).get(socket.AF_INET, [])
            if not ipaddress.ip_address(entry["addr"]).is_loopback
        ]
        if ipv4_addrs:
            candidate_addrs[iface] = ipv4_addrs[0]
    # Fail with a clear message instead of an IndexError from random.choice() on an empty sequence.
    assert candidate_addrs, "This test requires at least one non-loopback interface with an IPv4 address"
    bind_iface = "eth0" if "eth0" in candidate_addrs else random.choice(sorted(candidate_addrs))
    bind_addr = candidate_addrs[bind_iface]

    # Get free port on all interfaces
    sock = netutil.bind_sockets(0, "0.0.0.0", family=socket.AF_INET)[0]
    _addr, free_port = sock.getsockname()
    sock.close()

    # Bind port on non-loopback interface
    sock = netutil.bind_sockets(free_port, bind_addr, family=socket.AF_INET)[0]
    try:
        # Configure server
        Config.load_config()
        Config.set("server", "bind-port", str(free_port))
        Config.set("server", "bind-address", "127.0.0.1")
        Config.set("client_rest_transport", "port", str(free_port))

        # Start server
        rs = Server()
        rs.add_slice(TestSlice("test"))
        await rs.start()
        async_finalizer(rs.stop)

        # Check if server is reachable on loopback interface
        result = await client.test_endpoint()
        assert result.code == 200
    finally:
        # Always release the port held on the non-loopback interface
        sock.close()
Exemplo n.º 7
0
async def test_generate_openapi_definition(server: Server, feature_manager: FeatureManager, patch_openapi_spec_validator: None):
    """Generate the OpenAPI JSON for all server slices and validate it against the OpenAPI v3 specification."""
    url_map = server._transport.get_global_url_map(server.get_slices().values())
    generated = OpenApiConverter(url_map, feature_manager).generate_openapi_json()
    assert generated
    openapi_v3_spec_validator.validate(json.loads(generated))
Exemplo n.º 8
0
    async def prestart(self, server: protocol.Server) -> None:
        """Remember the server, load the server storage, look up the compiler slice and set up the dashboard."""
        # NOTE(review): other slices' prestart implementations call `await super().prestart(server)` first;
        # this one does not — confirm the base implementation is intentionally skipped.
        self._server = server
        self._server_storage: Dict[str, str] = self.check_storage()
        compiler_slice = server.get_slice(SLICE_COMPILER)
        self.compiler: "CompilerService" = cast("CompilerService", compiler_slice)

        self.setup_dashboard()
async def compilerservice(server_config, init_dataclasses_and_load_schema):
    """Fixture yielding a CompilerService prestarted against a fresh Server and started.

    The service is not added to the server as a slice. On teardown it is prestopped and then stopped.
    """
    rest_server = Server()
    compiler = CompilerService()
    await compiler.prestart(rest_server)
    await compiler.start()

    yield compiler

    await compiler.prestop()
    await compiler.stop()
Exemplo n.º 10
0
async def test_agent_timeout(unused_tcp_port, no_tid_check, async_finalizer,
                             postgres_db, database_name):
    """Two agents with the same name share an environment; once one stops, its session must expire.

    The agent timeout is configured to 1 second. Both sessions must survive a sleep slightly longer than
    that while the agents run, and the stopped agent's session must disappear afterwards.
    """
    from inmanta.config import Config

    configure(unused_tcp_port, database_name, postgres_db.port)
    Config.set("server", "agent-timeout", "1")

    rest_server = Server()
    spy = SessionSpy()
    rest_server.get_slice(SLICE_SESSION_MANAGER).add_listener(spy)
    rest_server.add_slice(spy)
    await rest_server.start()
    async_finalizer(rest_server.stop)

    environment_id = uuid.uuid4()

    async def start_agent():
        """Start an agent named "agent" in environment_id and register it for cleanup."""
        new_agent = Agent("agent")
        await new_agent.add_end_point_name("agent")
        new_agent.set_environment(environment_id)
        await new_agent.start()
        async_finalizer(new_agent.stop)
        return new_agent

    # First agent: wait until its session shows up
    first_agent = await start_agent()
    await retry_limited(lambda: len(spy.get_sessions()) == 1, timeout=10)
    assert len(spy.get_sessions()) == 1
    await assert_agent_counter(first_agent, 1, 0)

    # Second agent: wait until both sessions are present
    second_agent = await start_agent()
    await retry_limited(lambda: len(spy.get_sessions()) == 2, timeout=10)
    assert len(spy.get_sessions()) == 2
    for running_agent in (first_agent, second_agent):
        await assert_agent_counter(running_agent, 1, 0)

    # Both sessions must outlive the 1s agent timeout while the agents are running
    await check_sessions(spy.get_sessions())
    await sleep(1.1)
    assert len(spy.get_sessions()) == 2
    await check_sessions(spy.get_sessions())

    # Stop the second agent so that its session expires
    await second_agent.stop()

    # timeout=2: ~1s for the agent timeout plus ~1s until the session bookkeeping is updated
    await retry_limited(lambda: len(spy.get_sessions()) == 1, timeout=2)
    print(spy.get_sessions())
    await check_sessions(spy.get_sessions())
    assert spy.expires == 1
    for running_agent in (first_agent, second_agent):
        await assert_agent_counter(running_agent, 1, 0)
Exemplo n.º 11
0
async def test_2way_protocol(unused_tcp_port, no_tid_check, postgres_db,
                             database_name):
    """Start a server with a SessionSpy slice and one agent, then check get_status_x through a client.

    The call must succeed and report exactly one agent whose status is "ok".
    """
    configure(unused_tcp_port, database_name, postgres_db.port)

    rs = Server()
    server = SessionSpy()
    rs.get_slice(SLICE_SESSION_MANAGER).add_listener(server)
    rs.add_slice(server)
    await rs.start()

    agent = Agent("agent")
    await agent.add_end_point_name("agent")
    agent.set_environment(uuid.uuid4())
    await agent.start()

    await retry_limited(lambda: len(server.get_sessions()) == 1, 10)
    assert len(server.get_sessions()) == 1
    await assert_agent_counter(agent, 1, 0)

    client = protocol.Client("client")
    status = await client.get_status_x(str(agent.environment))
    assert status.code == 200
    assert "agents" in status.result
    assert len(status.result["agents"]) == 1
    # Compare explicitly: `assert value, "ok"` would only check truthiness and use "ok" as the failure message.
    assert status.result["agents"][0]["status"] == "ok"

    # NOTE(review): the shutdown below only runs when every assertion above passes; the other tests
    # register cleanup with async_finalizer instead.
    await server.stop()

    await rs.stop()
    await agent.stop()
    await assert_agent_counter(agent, 1, 0)
Exemplo n.º 12
0
 async def start_server():
     """Start a Server whose session manager reports to a SessionSpy slice.

     The server is registered with async_finalizer for cleanup. Returns the tuple (spy, server).
     """
     rest_server = Server()
     spy = SessionSpy()
     session_manager = rest_server.get_slice(SLICE_SESSION_MANAGER)
     session_manager.add_listener(spy)
     rest_server.add_slice(spy)
     await rest_server.start()
     async_finalizer(rest_server.stop)
     return spy, rest_server
Exemplo n.º 13
0
def test_phase_3():
    """Check the order in which the server sequences the core slices and an extension's test slice."""
    expected_order = [
        SLICE_SESSION_MANAGER,
        SLICE_AGENT_MANAGER,
        SLICE_SERVER,
        SLICE_AUTOSTARTED_AGENT_MANAGER,
        SLICE_TRANSPORT,
        "testplugin.testslice",
    ]

    with splice_extension_in("test_module_path"):
        from inmanta_ext.testplugin.extension import XTestSlice

        server = Server()
        server.add_slice(XTestSlice())
        server.add_slice(inmanta.server.server.Server())
        server.add_slice(AgentManager())
        server.add_slice(AutostartedAgentManager())

        slice_names = [each.name for each in server._get_slice_sequence()]
        print(slice_names)
        assert slice_names == expected_order
Exemplo n.º 14
0
async def test_scheduler(server_config, init_dataclasses_and_load_schema,
                         caplog):
    """Test the scheduler part in isolation, mock out compile runner and listen to state updates.

    Compiles are requested for two environments. Every compile runs through a HangRunner that blocks until the
    test releases it, so the test decides when each compile finishes. Environment 1 is run to completion.
    Environment 2 is interrupted after two compiles by stopping the compiler service and creating a new one.
    The remaining compiles, including a re-run of the interrupted one, are then driven to completion.
    """
    class Collector(CompileStateListener):
        """
        Collect all state updates, optionally hang the processing of listeners
        """
        def __init__(self):
            # compiles whose compile_done got past the lock
            self.seen = []
            # compiles whose compile_done was entered, even if it is still waiting for the lock
            self.preseen = []
            # while hang() holds this lock, compile_done blocks before recording into `seen`
            self.lock = Semaphore(1)

        def reset(self):
            """Forget all recorded compiles."""
            self.seen = []
            self.preseen = []

        async def compile_done(self, compile: data.Compile):
            """Record the compile in `preseen` right away and in `seen` once the lock can be taken."""
            self.preseen.append(compile)
            print("Got compile done for ", compile.remote_id)
            async with self.lock:
                self.seen.append(compile)

        async def hang(self):
            """Take the lock so that subsequent compile_done calls block until release()."""
            await self.lock.acquire()

        def release(self):
            """Release the lock taken by hang()."""
            self.lock.release()

        def verify(self, envs: List[uuid.UUID]):
            """Assert that exactly the compiles with these remote ids passed compile_done, then reset."""
            assert sorted([x.remote_id for x in self.seen]) == sorted(envs)
            self.reset()

    class HangRunner(object):
        """
        compile runner mock, hang until released
        """
        def __init__(self):
            # starts at 0, so run() blocks until release() is called
            self.lock = Semaphore(0)
            self.started = False
            self.done = False
            self.version = None

        async def run(self, force_update: Optional[bool] = False):
            """Mark as started, block until released, mark as done and return (True, None)."""
            self.started = True
            await self.lock.acquire()
            self.done = True
            return True, None

        def release(self):
            """Let a blocked run() finish."""
            self.lock.release()

    class HookedCompilerService(CompilerService):
        """
        hook in the hangrunner
        """
        def __init__(self):
            super(HookedCompilerService, self).__init__()
            # remote_id -> HangRunner created for that compile
            self.locks = {}

        def _get_compile_runner(self, compile: data.Compile, project_dir: str):
            """Create a HangRunner for this compile and remember it by remote_id."""
            print("Get Run: ", compile.remote_id, compile.id)
            runner = HangRunner()
            self.locks[compile.remote_id] = runner
            return runner

        def get_runner(self, remote_id: uuid.UUID) -> HangRunner:
            """Return the runner created for remote_id, or None if no runner was created yet."""
            return self.locks.get(remote_id)

    # manual setup of server
    server = Server()
    cs = HookedCompilerService()
    await cs.prestart(server)
    await cs.start()
    server.add_slice(cs)
    notification_service = NotificationService()
    await notification_service.prestart(server)
    await notification_service.start()
    collector = Collector()
    cs.add_listener(collector)

    async def request_compile(env: data.Environment) -> uuid.UUID:
        """Request compile for given env, return remote_id"""
        u1 = uuid.uuid4()
        # add unique environment variables to prevent merging in request_recompile
        await cs.request_recompile(env,
                                   False,
                                   False,
                                   u1,
                                   env_vars={"uuid": str(u1)})
        results = await data.Compile.get_by_remote_id(env.id, u1)
        assert len(results) == 1
        assert results[0].remote_id == u1
        print("request: ", u1, results[0].id)
        return u1

    # setup projects in the database
    project = data.Project(name="test")
    await project.insert()
    env1 = data.Environment(name="dev",
                            project=project.id,
                            repo_url="",
                            repo_branch="")
    await env1.insert()
    env2 = data.Environment(name="dev2",
                            project=project.id,
                            repo_url="",
                            repo_branch="")
    await env2.insert()

    # setup series of compiles for two envs
    # e1 is for a plain run
    # e2 is for server restart
    e1 = [await request_compile(env1) for i in range(3)]
    e2 = [await request_compile(env2) for i in range(4)]
    print("env 1:", e1)

    async def check_compile_in_sequence(env: data.Environment,
                                        remote_ids: List[uuid.UUID], idx: int):
        """
        Check integrity of a compile sequence and progress the hangrunner.

        All compiles before remote_ids[idx] must be done. If idx < len(remote_ids), is_compiling must return
        200, remote_ids[idx] must have a started runner and the later compiles must not have a runner yet;
        the current runner is then released and awaited. If idx == len(remote_ids), wait until is_compiling
        returns 204.
        """
        before = remote_ids[:idx]

        for rid in before:
            prevrunner = cs.get_runner(rid)
            assert prevrunner.done

        if idx < len(remote_ids):
            current = remote_ids[idx]
            after = remote_ids[idx + 1:]

            assert await cs.is_compiling(env.id) == 200

            await retry_limited(lambda: cs.get_runner(current) is not None, 1)
            await retry_limited(lambda: cs.get_runner(current).started, 1)

            # compiles are run one at a time: later ones must not have been picked up yet
            for rid in after:
                nextrunner = cs.get_runner(rid)
                assert nextrunner is None

            cs.get_runner(current).release()
            # yield to the event loop so the released runner can continue
            await asyncio.sleep(0)
            await retry_limited(lambda: cs.get_runner(current).done, 1)

        else:

            async def isdone():
                return await cs.is_compiling(env.id) == 204

            await retry_limited(isdone, 1)

    # run through env1, entire sequence
    for i in range(4):
        await check_compile_in_sequence(env1, e1, i)
    collector.verify(e1)
    print("env1 done")

    print("env2 ", e2)
    # make event collector hang
    await collector.hang()
    # progress two steps into env2
    for i in range(2):
        await check_compile_in_sequence(env2, e2, i)

    # the listener is blocked on the lock: compile_done was entered twice but nothing got recorded in `seen`
    assert not collector.seen
    print(collector.preseen)
    await retry_limited(lambda: len(collector.preseen) == 2, 1)

    # test server restart
    await notification_service.prestop()
    await notification_service.stop()
    await cs.prestop()
    await cs.stop()

    # in the log, find cancel of compile(hangs) and handler(hangs)
    LogSequence(caplog, allow_errors=False).contains(
        "inmanta.util", logging.WARNING,
        "was cancelled").contains("inmanta.util", logging.WARNING,
                                  "was cancelled").no_more_errors()

    print("restarting")

    # restart new server
    # NOTE(review): the new service is prestarted against the same `server` but is not added to it with add_slice,
    # and the notification service is not restarted.
    cs = HookedCompilerService()
    await cs.prestart(server)
    await cs.start()
    collector = Collector()
    cs.add_listener(collector)

    # complete the sequence, expect re-run of third compile
    for i in range(3):
        print(i)
        await check_compile_in_sequence(env2, e2[2:], i)

    # all are re-run, entire sequence present
    collector.verify(e2)

    # NOTE(review): the restarted `cs` is never prestopped/stopped before the test ends — consider adding cleanup.
    await report_db_index_usage()
Exemplo n.º 15
0
 async def prestart(self, server: protocol.Server) -> None:
     """Run the base prestart, look up the compiler service and subscribe this slice as its listener."""
     await super().prestart(server)
     compiler_slice = server.get_slice(SLICE_COMPILER)
     self._compiler_service = cast(CompilerService, compiler_slice)
     self._compiler_service.add_listener(self)
Exemplo n.º 16
0
 async def prestart(self, server: protocol.Server) -> None:
     """Run the base prestart and keep a typed reference to the file service slice."""
     await super().prestart(server)
     file_service = server.get_slice(SLICE_FILE)
     self.file_slice = cast(FileService, file_service)
Exemplo n.º 17
0
 async def prestart(self, server: protocol.Server) -> None:
     """Run the base prestart and keep typed references to the server and agent manager slices."""
     await super().prestart(server)
     server_slice = server.get_slice(SLICE_SERVER)
     self.server_slice = cast(Server, server_slice)
     agent_manager_slice = server.get_slice(SLICE_AGENT_MANAGER)
     self.agentmanager = cast(AgentManager, agent_manager_slice)
Exemplo n.º 18
0
 async def prestart(self, server: protocol.Server) -> None:
     """Run the base prestart and keep a typed reference to the server slice."""
     await super().prestart(server)
     server_slice = server.get_slice(SLICE_SERVER)
     self.server_slice = cast(Server, server_slice)
Exemplo n.º 19
0
class InmantaBootloader(object):
    """The inmanta bootloader is responsible for:
    - discovering extensions
    - loading extensions
    - loading core and extension slices
    - starting the server and its slices in the correct order
    """
    def __init__(self) -> None:
        self.restserver = Server()
        # Set to True once start() has started the rest server.
        self.started = False
        # Filled in by start() from the application context.
        self.feature_manager: Optional[FeatureManager] = None

    async def start(self) -> None:
        """Load all slices, register each with the rest server and the feature manager, then start the server."""
        ctx = self.load_slices()
        feature_manager = ctx.get_feature_manager()
        self.feature_manager = feature_manager
        for mypart in ctx.get_slices():
            self.restserver.add_slice(mypart)
            # Register with the same feature manager instance that is stored on self.
            feature_manager.add_slice(mypart)
        await self.restserver.start()
        self.started = True

    async def stop(self, timeout: Optional[int] = None) -> None:
        """
        :param timeout: Raises TimeoutError when the server hasn't finished stopping after
                        this amount of seconds. This argument should only be used by test
                        cases. A falsy value (None or 0) means no timeout.
        """
        if not timeout:
            await self._stop()
        else:
            await asyncio.wait_for(self._stop(), timeout=timeout)

    async def _stop(self) -> None:
        """Stop the rest server, then the feature manager if start() created one."""
        await self.restserver.stop()
        if self.feature_manager is not None:
            self.feature_manager.stop()

    # Extension loading Phase I: from start to setup functions collected
    def _discover_plugin_packages(self) -> List[str]:
        """Discover all packages that are defined in the inmanta_ext namespace package. Filter available extensions based on
        enabled_extensions and disabled_extensions config in the server configuration.

        :return: A list of all subpackages defined in inmanta_ext (fully qualified names). The core extension is always
                 included exactly once.
        :raises PluginLoadFailed: An enabled extension is not available.
        """
        inmanta_ext = importlib.import_module(EXTENSION_NAMESPACE)
        prefix_length = len(EXTENSION_NAMESPACE) + 1
        # Map the short extension name (e.g. "core") to its fully qualified package name.
        available = {
            name[prefix_length:]: name
            for _finder, name, _ispkg in iter_namespace(inmanta_ext)
        }

        LOGGER.info("Discovered extensions: %s", ", ".join(available.keys()))

        extensions = []
        enabled = [x for x in config.server_enabled_extensions.get() if len(x)]

        if enabled:
            for ext in enabled:
                if ext not in available:
                    raise PluginLoadFailed(
                        f"Extension {ext} in config option {config.server_enabled_extensions.name} in section "
                        f"{config.server_enabled_extensions.section} is not available."
                    )

                extensions.append(available[ext])
        elif len(available) > 1:
            # More than core is available
            LOGGER.info(
                "Load extensions by setting configuration option %s in section %s. "
                "%d extensions available but none are enabled.",
                config.server_enabled_extensions.name,
                config.server_enabled_extensions.section,
                len(available) - 1,
            )

        # `extensions` holds fully qualified names, so compare against core's qualified name; comparing against the
        # short name "core" never matched and appended core a second time when it was explicitly enabled.
        core_package = available["core"]
        if core_package not in extensions:
            extensions.append(core_package)

        return extensions

    def _load_extension(self,
                        name: str) -> Callable[[ApplicationContext], None]:
        """Import the extension defined in the package in name and return the setup function that needs to be called for the
        extension to register its slices in the application context.

        :raises PluginLoadFailed: The package or its extension module cannot be imported.
        """
        try:
            importlib.import_module(name)
        except Exception as e:
            raise PluginLoadFailed(f"Could not load module {name}") from e

        try:
            mod = importlib.import_module(f"{name}.{EXTENSION_MODULE}")
            return mod.setup
        except Exception as e:
            raise PluginLoadFailed(
                f"Could not load module {name}.{EXTENSION_MODULE}") from e

    def _load_extensions(
            self) -> Dict[str, Callable[[ApplicationContext], None]]:
        """Discover all extensions, validate correct naming and load its setup function.

        Extensions that fail to load are logged as warnings and skipped. The returned dict is keyed by the short
        extension name (without the namespace prefix).
        """
        plugins: Dict[str, Callable[[ApplicationContext], None]] = {}
        for name in self._discover_plugin_packages():
            try:
                plugin = self._load_extension(name)
                assert name.startswith(f"{EXTENSION_NAMESPACE}.")
                name = name[len(EXTENSION_NAMESPACE) + 1:]
                plugins[name] = plugin
            except PluginLoadFailed:
                LOGGER.warning("Could not load extension %s",
                               name,
                               exc_info=True)
        return plugins

    # Extension loading Phase II: collect slices
    def _collect_slices(
        self, extensions: Dict[str, Callable[[ApplicationContext], None]]
    ) -> ApplicationContext:
        """
        Call the setup function on all extensions and let them register their slices in the ApplicationContext.
        Each setup function receives a ConstrainedApplicationContext wrapping the shared context and its extension name.
        """
        ctx = ApplicationContext()
        for name, setup in extensions.items():
            myctx = ConstrainedApplicationContext(ctx, name)
            setup(myctx)
        return ctx

    def load_slices(self) -> ApplicationContext:
        """
        Load all slices in the server
        """
        exts = self._load_extensions()
        return self._collect_slices(exts)
Exemplo n.º 20
0
 def __init__(self) -> None:
     """Create the rest server; the loader starts unstarted and without a feature manager."""
     self.restserver: Server = Server()
     self.started: bool = False
     self.feature_manager: Optional[FeatureManager] = None
Exemplo n.º 21
0
def feature_manager(server: Server) -> FeatureManager:
    """Return the feature manager held by the server slice registered on ``server``."""
    server_slice = server.get_slice(SLICE_SERVER)
    return server_slice.feature_manager
Exemplo n.º 22
0
async def test_notification_cleanup_on_start(init_dataclasses_and_load_schema,
                                             async_finalizer,
                                             server_config) -> None:
    """Notifications past an environment's retention are removed when the server starts.

    Two environments each get notifications created 366 days ago, 35 days ago and now. One keeps the default
    retention and must only lose the 366-day-old notification. The other has NOTIFICATION_RETENTION set to 30
    and must keep only the most recent notification.
    """
    project = data.Project(name="test")
    await project.insert()

    default_env = data.Environment(name="testenv", project=project.id)
    await default_env.insert()
    short_env = data.Environment(name="testenv2", project=project.id)
    await short_env.insert()
    await short_env.set(data.NOTIFICATION_RETENTION, 30)

    # Oldest first; each entry takes its own "now" so the values match the original construction
    timestamps = [
        datetime.datetime.now().astimezone() - datetime.timedelta(days=days_ago)
        for days_ago in (366, 35, 0)
    ]

    for env_id in (default_env.id, short_env.id):
        for created in timestamps:
            await data.Notification(
                title="Notification",
                message="Something happened",
                environment=env_id,
                severity=const.NotificationSeverity.message,
                uri="/api/v2/notification",
                created=created,
                read=False,
                cleared=False,
            ).insert()

    server = Server()
    notification_service = NotificationService()
    compiler_service = CompilerService()
    server.add_slice(compiler_service)
    server.add_slice(notification_service)
    await server.start()
    async_finalizer.add(server.stop)

    async def notification_count_is(env_id: uuid.UUID, expected: int) -> bool:
        """Return whether the environment currently holds exactly `expected` notifications."""
        remaining = await data.Notification.get_list(environment=env_id)
        return len(remaining) == expected

    await retry_limited(partial(notification_count_is, default_env.id, 2), timeout=10)
    default_remaining = await data.Notification.get_list(
        environment=default_env.id,
        order="DESC",
        order_by_column="created")
    # Only the oldest one is deleted
    assert len(default_remaining) == 2
    assert [n.created for n in default_remaining] == [timestamps[2], timestamps[1]]

    await retry_limited(partial(notification_count_is, short_env.id, 1), timeout=10)
    short_remaining = await data.Notification.get_list(
        environment=short_env.id,
        order="DESC",
        order_by_column="created")
    # Only the latest one is kept
    assert len(short_remaining) == 1
    assert short_remaining[0].created == timestamps[2]