Esempio n. 1
0
class ServerConfig(TLSCMDConfig, MultiCommCMDConfig):
    """
    Options for the HTTP server command.

    Covers where to listen (``port``, ``addr``, ``insecure``, ``portfile``),
    what to serve (``upload_dir``, ``static``, ``js``), HTTP behaviour
    (``cors_domains``, ``allow_caching``, ``redirect``), and the models,
    sources and scorers configured when the server starts.
    """

    port: int = field(
        "Port to bind to", default=8080,
    )
    addr: str = field(
        "Address to bind to", default="127.0.0.1",
    )
    upload_dir: str = field(
        "Directory to store uploaded files in", default=None,
    )
    static: str = field(
        "Directory to serve static content from", default=None,
    )
    js: bool = field(
        "Serve JavaScript API file at /api.js",
        default=False,
        action="store_true",
    )
    insecure: bool = field(
        "Start without TLS encryption", action="store_true", default=False,
    )
    cors_domains: List[str] = field(
        "Domains to allow CORS for (see keys in defaults dict for aiohttp_cors.setup)",
        default_factory=lambda: [],
    )
    allow_caching: bool = field(
        "Allow caching of HTTP responses", action="store_true", default=False,
    )
    # The three collections below are filled from the command line through
    # list_action(...) into a fresh container per config instance.
    models: Model = field(
        "Models configured on start",
        default_factory=lambda: AsyncContextManagerList(),
        action=list_action(AsyncContextManagerList),
        labeled=True,
    )
    sources: Sources = field(
        "Sources configured on start",
        default_factory=lambda: Sources(),
        action=list_action(Sources),
        labeled=True,
    )
    scorers: AccuracyScorer = field(
        "Scorers configured on start",
        default_factory=lambda: AsyncContextManagerList(),
        action=list_action(AsyncContextManagerList),
        labeled=True,
    )
    # Flat list parsed by ParseRedirectsAction; every three consecutive
    # entries form one METHOD / SOURCE_PATH / DESTINATION_PATH redirect.
    redirect: List[str] = field(
        "list of METHOD SOURCE_PATH DESTINATION_PATH triples, number of elements must be divisible by 3",
        action=ParseRedirectsAction,
        default_factory=lambda: [],
    )
    portfile: pathlib.Path = field(
        "File to write bound port to when starting. Helpful when port 0 was requested to bind to any free port",
        default=None,
    )
Esempio n. 2
0
    def args(cls, args, *above) -> Dict[str, Arg]:
        """
        Register this model's command line arguments on ``args``.

        Registers ``directory`` (defaulting to ``~/.cache/dffml/scratch``),
        then ``predict`` and ``features``. Each one goes through
        ``cls.config_set`` together with ``above``, in that order. The same
        ``args`` object is returned.
        """
        scratch_dir = os.path.join(
            os.path.expanduser("~"), ".cache", "dffml", "scratch"
        )
        registrations = (
            (
                "directory",
                Arg(
                    default=scratch_dir,
                    help="Directory where state should be saved",
                ),
            ),
            (
                "predict",
                Arg(type=str, help="Label or the value to be predicted"),
            ),
            (
                "features",
                Arg(
                    nargs="+",
                    required=True,
                    type=Feature.load,
                    action=list_action(Features),
                    help="Features to train on",
                ),
            ),
        )
        for option_name, option in registrations:
            cls.config_set(args, above, option_name, option)
        return args
Esempio n. 3
0
class ServerConfig(TLSCMDConfig, MultiCommCMDConfig):
    """
    Settings for the HTTP server command.

    Covers the listening address, served content, TLS toggle, CORS domains,
    and the models and sources set up when the server starts.
    """

    # -- where to listen --
    port: int = field("Port to bind to", default=8080)
    addr: str = field("Address to bind to", default="127.0.0.1")

    # -- what to serve --
    upload_dir: str = field("Directory to store uploaded files in", default=None)
    static: str = field("Directory to serve static content from", default=None)
    js: bool = field(
        "Serve JavaScript API file at /api.js",
        action="store_true",
        default=False,
    )

    # -- transport / HTTP behaviour --
    insecure: bool = field(
        "Start without TLS encryption",
        default=False,
        action="store_true",
    )
    cors_domains: List[str] = field(
        "Domains to allow CORS for (see keys in defaults dict for aiohttp_cors.setup)",
        default_factory=lambda: [],
    )

    # -- collections configured at start (filled via list_action) --
    models: Model = field(
        "Models configured on start",
        labeled=True,
        action=list_action(AsyncContextManagerList),
        default_factory=lambda: AsyncContextManagerList(),
    )
    sources: Sources = field(
        "Sources configured on start",
        labeled=True,
        action=list_action(Sources),
        default_factory=lambda: Sources(),
    )
Esempio n. 4
0
 def test_list_action(self):
     """
     list_action(Features) stores its value(s) on the namespace as Features,
     both for a single value and for a list of values.
     """
     dest = "features"
     action_class = list_action(Features)
     namespace = Namespace(**{dest: False})
     # (subTest keyword, value handed to the action, expected Features args)
     scenarios = (
         ("single", "feed", ("feed",)),
         ("multiple", ["feed", "face"], ("feed", "face")),
     )
     for scenario, values, expected in scenarios:
         with self.subTest(**{scenario: dest}):
             action = action_class(dest=dest, option_strings="")
             action(None, namespace, values)
             self.assertEqual(
                 getattr(namespace, dest, False), Features(*expected)
             )
Esempio n. 5
0
 def args(cls, args, *above) -> Dict[str, Arg]:
     """
     Register ``directory`` (no default) and the required ``features`` list.

     Both go through ``cls.config_set`` with ``above``. ``args`` is returned.
     """
     features_option = Arg(
         nargs="+",
         required=True,
         type=Feature.load,
         action=list_action(Features),
     )
     for option_name, option in (
         ("directory", Arg()),
         ("features", features_option),
     ):
         cls.config_set(args, above, option_name, option)
     return args
Esempio n. 6
0
    def args(cls, args, *above) -> Dict[str, Arg]:
        """
        Register the command line arguments for the model class ``cls.SCIKIT_MODEL``.

        Fixed arguments come first:

        - ``directory`` defaults to
          ``~/.cache/dffml/scikit-<entry_point_name>``.
        - ``predict`` is the value to predict.
        - ``features`` is required.

        Then one argument is added per parameter in
        ``inspect.signature(cls.SCIKIT_MODEL)``, typed with
        ``cls.type_for(param)``. Parameters without a default get
        ``NoDefaultValue``. ``entry_point_name`` is resolved from an
        enclosing scope not shown here (NOTE(review): confirm).

        Returns the same ``args`` object after registration.
        """
        cls.config_set(
            args,
            above,
            "directory",
            Arg(
                default=os.path.join(
                    os.path.expanduser("~"),
                    ".cache",
                    "dffml",
                    f"scikit-{entry_point_name}",
                ),
                help="Directory where state should be saved",
            ),
        )
        cls.config_set(
            args,
            above,
            "predict",
            Arg(type=str, help="Label or the value to be predicted"),
        )

        cls.config_set(
            args,
            above,
            "features",
            Arg(
                nargs="+",
                required=True,
                type=Feature.load,
                action=list_action(Features),
                help="Features to train on",
            ),
        )

        for param in inspect.signature(cls.SCIKIT_MODEL).parameters.values():
            # TODO if param.default is an array then Args needs to get a
            # nargs="+"
            # Test for the "no default" sentinel by identity. ``==`` would
            # call the default value's own __eq__. For array-like defaults
            # (e.g. NumPy arrays) that gives an element-wise result whose
            # truth value is ambiguous and raises.
            if param.default is inspect.Parameter.empty:
                default = NoDefaultValue
            else:
                default = param.default
            cls.config_set(
                args,
                above,
                param.name,
                Arg(type=cls.type_for(param), default=default),
            )
        return args
Esempio n. 7
0
 def args(cls, args, *above) -> Dict[str, Arg]:
     """
     Register this model's command line arguments on ``args``.

     State is saved under ``~/.cache/dffml/tensorflow`` by default. Arguments
     are registered in this order:

     - ``directory``
     - ``steps``
     - ``epochs``
     - ``hidden``
     - ``predict``
     - ``features``

     Each goes through ``cls.config_set`` with ``above``. The same ``args``
     object is returned.
     """
     state_directory = os.path.join(
         os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
     )
     hidden_help = (
         "List length is the number of hidden layers in the network. "
         "Each entry in the list is the number of nodes in that hidden layer"
     )
     registrations = (
         (
             "directory",
             Arg(
                 default=state_directory,
                 help="Directory where state should be saved",
             ),
         ),
         (
             "steps",
             Arg(
                 type=int,
                 default=3000,
                 help="Number of steps to train the model",
             ),
         ),
         (
             "epochs",
             Arg(
                 type=int,
                 default=30,
                 help="Number of iterations to pass over all repos in a source",
             ),
         ),
         (
             "hidden",
             Arg(
                 type=int,
                 nargs="+",
                 default=[12, 40, 15],
                 help=hidden_help,
             ),
         ),
         ("predict", Arg(help="Feature name holding truth value")),
         (
             "features",
             Arg(
                 nargs="+",
                 required=True,
                 type=Feature.load,
                 action=list_action(Features),
                 help="Features to train on",
             ),
         ),
     )
     for option_name, option in registrations:
         cls.config_set(args, above, option_name, option)
     return args
Esempio n. 8
0
class Server(TLSCMD, MultiCommCMD, Routes):
    """
    HTTP server providing access to DFFML APIs
    """

    # Used for testing.
    # - RUN_YIELD_START is awaited via ``.put(self)``, so it is expected to be
    #   queue-like.
    # - RUN_YIELD_FINISH is awaited via ``.wait()``, so it is expected to be
    #   event-like.
    # - Leaving RUN_YIELD_START as False makes ``run`` serve until cancelled.
    RUN_YIELD_START = False
    RUN_YIELD_FINISH = False
    # NOTE(review): not read within this class; presumably consumed by test
    # helpers elsewhere — confirm before removing.
    INSECURE_NO_TLS = False

    arg_port = Arg("-port", help="Port to bind to", type=int, default=8080)
    arg_addr = Arg("-addr", help="Address to bind to", default="127.0.0.1")
    arg_upload_dir = Arg(
        "-upload-dir",
        help="Directory to store uploaded files in",
        default=None,
    )
    arg_static = Arg("-static",
                     help="Directory to serve static content from",
                     default=None)
    arg_js = Arg(
        "-js",
        help="Serve JavaScript API file at /api.js",
        default=False,
        action="store_true",
    )
    arg_insecure = Arg(
        "-insecure",
        help="Start without TLS encryption",
        action="store_true",
        default=False,
    )
    arg_cors_domains = Arg(
        "-cors-domains",
        help="Domains to allow CORS for (see keys in defaults dict for aiohttp_cors.setup)",
        nargs="+",
        default=[],
    )
    # NOTE(review): the default containers for -models and -sources below are
    # created once, when the class body runs. Every Server that does not
    # override them therefore shares the same object.
    arg_models = Arg(
        "-models",
        help="Models configured on start",
        nargs="+",
        default=AsyncContextManagerList(),
        type=Model.load_labeled,
        action=list_action(AsyncContextManagerList),
    )
    arg_sources = Arg(
        "-sources",
        help="Sources configured on start",
        nargs="+",
        default=Sources(),
        type=BaseSource.load_labeled,
        action=list_action(Sources),
    )

    async def start(self):
        """
        Create the aiohttp TCP site (with TLS unless ``insecure``) and start it.

        After the site starts, ``self.port`` is overwritten with the port the
        listening socket is actually bound to. A request for port 0 is
        therefore replaced by the port the OS chose.
        """
        if self.insecure:
            self.site = web.TCPSite(self.runner,
                                    host=self.addr,
                                    port=self.port)
        else:
            # NOTE(review): per the Python ``ssl`` docs, Purpose.SERVER_AUTH
            # builds a context meant for *client* sockets
            # (verify_mode=CERT_REQUIRED). On a server socket, CERT_REQUIRED
            # demands a client certificate, checked here against
            # ``self.cert``. Purpose.CLIENT_AUTH is the documented choice for
            # server sockets. Confirm whether client-certificate
            # authentication is intended before changing.
            ssl_context = ssl.create_default_context(
                purpose=ssl.Purpose.SERVER_AUTH, cafile=self.cert)
            ssl_context.load_cert_chain(self.cert, self.key)
            self.site = web.TCPSite(
                self.runner,
                host=self.addr,
                port=self.port,
                ssl_context=ssl_context,
            )
        await self.site.start()
        # Read back the bound port (uses aiohttp's private ``_server``).
        self.port = self.site._server.sockets[0].getsockname()[1]
        self.logger.info(f"Serving on {self.addr}:{self.port}")

    async def run(self):
        """
        Binds to port and starts HTTP server

        Runs ``setup`` and ``start``. If ``mc_config`` is set, that directory
        is registered with ``mc_atomic`` temporarily disabled; the previous
        value is restored even if registration fails.

        The server then either hands itself to the test harness through
        ``RUN_YIELD_START`` and waits for ``RUN_YIELD_FINISH``, or sleeps
        until cancelled.

        Once the site has started, any exit path cleans up the app and stops
        the site, including a failure while registering ``mc_config``.
        """
        # Presumably creates the app/runner and the model/source registries
        # (``setup`` lives outside this class — confirm).
        await self.setup()
        await self.start()
        try:
            # Load the multicomm config directory, if one was given
            if self.mc_config is not None:
                # Allow setting while the config is loaded, then restore
                # atomic whether or not loading succeeds
                atomic = self.mc_atomic
                self.mc_atomic = False
                try:
                    await self.register_directory(self.mc_config)
                finally:
                    self.mc_atomic = atomic
            if self.RUN_YIELD_START is not False:
                # Testing: publish this server, then block until released
                await self.RUN_YIELD_START.put(self)
                await self.RUN_YIELD_FINISH.wait()
            else:  # pragma: no cov
                # Wait for ctrl-c
                while True:
                    await asyncio.sleep(60)
        finally:
            await self.app.cleanup()
            await self.site.stop()