class ServerConfig(TLSCMDConfig, MultiCommCMDConfig):
    """
    Configuration options for the HTTP server.

    Every option is declared through ``field``; the first positional argument
    is the human readable description of the option and the keyword arguments
    give its default value and, where needed, a custom argument action.
    """

    # Where the server listens
    port: int = field("Port to bind to", default=8080)
    addr: str = field("Address to bind to", default="127.0.0.1")

    # Optional directories, both default to None
    upload_dir: str = field(
        "Directory to store uploaded files in", default=None
    )
    static: str = field(
        "Directory to serve static content from", default=None
    )

    # Boolean switches, all off unless given
    js: bool = field(
        "Serve JavaScript API file at /api.js",
        default=False,
        action="store_true",
    )
    insecure: bool = field(
        "Start without TLS encryption", action="store_true", default=False
    )
    cors_domains: List[str] = field(
        "Domains to allow CORS for (see keys in defaults dict for aiohttp_cors.setup)",
        default_factory=list,
    )
    allow_caching: bool = field(
        "Allow caching of HTTP responses", action="store_true", default=False
    )

    # Objects configured on start. Each default is built by a factory so
    # that instances never share the same container.
    models: Model = field(
        "Models configured on start",
        default_factory=AsyncContextManagerList,
        action=list_action(AsyncContextManagerList),
        labeled=True,
    )
    sources: Sources = field(
        "Sources configured on start",
        default_factory=Sources,
        action=list_action(Sources),
        labeled=True,
    )
    scorers: AccuracyScorer = field(
        "Scorers configured on start",
        default_factory=AsyncContextManagerList,
        action=list_action(AsyncContextManagerList),
        labeled=True,
    )

    # Flat list of values, parsed by ParseRedirectsAction
    redirect: List[str] = field(
        "list of METHOD SOURCE_PATH DESTINATION_PATH pairs, number of elements must be divisible by 3",
        action=ParseRedirectsAction,
        default_factory=list,
    )
    portfile: pathlib.Path = field(
        "File to write bound port to when starting. Helpful when port 0 was requested to bind to any free port",
        default=None,
    )
def args(cls, args, *above) -> Dict[str, Arg]:
    """
    Register the ``directory``, ``predict`` and ``features`` arguments.

    Each argument is stored in ``args`` via ``cls.config_set`` using ``above``
    as the parent path. The same ``args`` object is returned.
    """
    # Default location for saved state: ~/.cache/dffml/scratch
    default_directory = os.path.join(
        os.path.expanduser("~"), ".cache", "dffml", "scratch"
    )
    definitions = (
        (
            "directory",
            Arg(
                default=default_directory,
                help="Directory where state should be saved",
            ),
        ),
        (
            "predict",
            Arg(type=str, help="Label or the value to be predicted"),
        ),
        (
            "features",
            Arg(
                nargs="+",
                required=True,
                type=Feature.load,
                action=list_action(Features),
                help="Features to train on",
            ),
        ),
    )
    # Registration order matches the order of the definitions above
    for name, definition in definitions:
        cls.config_set(args, above, name, definition)
    return args
class ServerConfig(TLSCMDConfig, MultiCommCMDConfig):
    """
    Configuration options for the server.

    Every option is declared through ``field``; the first positional argument
    is the description of the option and the keyword arguments give its
    default value and, where needed, a custom argument action.
    """

    # Where the server listens
    port: int = field("Port to bind to", default=8080)
    addr: str = field("Address to bind to", default="127.0.0.1")

    # Optional directories, both default to None
    upload_dir: str = field(
        "Directory to store uploaded files in", default=None
    )
    static: str = field(
        "Directory to serve static content from", default=None
    )

    # Boolean switches, off unless given
    js: bool = field(
        "Serve JavaScript API file at /api.js",
        action="store_true",
        default=False,
    )
    insecure: bool = field(
        "Start without TLS encryption", default=False, action="store_true"
    )
    cors_domains: List[str] = field(
        "Domains to allow CORS for (see keys in defaults dict for aiohttp_cors.setup)",
        default_factory=list,
    )

    # Objects configured on start, each default comes from a factory
    models: Model = field(
        "Models configured on start",
        labeled=True,
        action=list_action(AsyncContextManagerList),
        default_factory=AsyncContextManagerList,
    )
    sources: Sources = field(
        "Sources configured on start",
        labeled=True,
        action=list_action(Sources),
        default_factory=Sources,
    )
def test_list_action(self):
    """
    The action class made by ``list_action(Features)`` must store a
    ``Features`` instance on the namespace for a single value as well as for
    a list of values.
    """
    dest = "features"
    action_cls = list_action(Features)
    namespace = Namespace(**{dest: False})
    # (subTest label, values handed to the action, expected Features members)
    cases = (
        ("single", "feed", ("feed",)),
        ("multiple", ["feed", "face"], ("feed", "face")),
    )
    for label, values, members in cases:
        with self.subTest(**{label: dest}):
            action = action_cls(dest=dest, option_strings="")
            action(None, namespace, values)
            stored = getattr(namespace, dest, False)
            self.assertEqual(stored, Features(*members))
def args(cls, args, *above) -> Dict[str, Arg]:
    """
    Register the ``directory`` and ``features`` arguments.

    Each argument is stored in ``args`` via ``cls.config_set`` using ``above``
    as the parent path. The same ``args`` object is returned.
    """
    definitions = (
        # No options are given for the directory argument
        ("directory", Arg()),
        # One or more features, each loaded with Feature.load and gathered
        # by the Features list action
        (
            "features",
            Arg(
                nargs="+",
                required=True,
                type=Feature.load,
                action=list_action(Features),
            ),
        ),
    )
    for name, definition in definitions:
        cls.config_set(args, above, name, definition)
    return args
def args(cls, args, *above) -> Dict[str, Arg]:
    """
    Register the arguments of a scikit based model.

    The fixed ``directory``, ``predict`` and ``features`` arguments are
    registered first. After that one argument is registered for every
    parameter in the signature of ``cls.SCIKIT_MODEL``, typed by
    ``cls.type_for`` and defaulting to the parameter's own default.

    Each argument is stored in ``args`` via ``cls.config_set`` using ``above``
    as the parent path. The same ``args`` object is returned.
    """
    # Default location for saved state: ~/.cache/dffml/scikit-<entry point>
    default_directory = os.path.join(
        os.path.expanduser("~"),
        ".cache",
        "dffml",
        f"scikit-{entry_point_name}",
    )
    cls.config_set(
        args,
        above,
        "directory",
        Arg(
            default=default_directory,
            help="Directory where state should be saved",
        ),
    )
    cls.config_set(
        args,
        above,
        "predict",
        Arg(type=str, help="Label or the value to be predicted"),
    )
    cls.config_set(
        args,
        above,
        "features",
        Arg(
            nargs="+",
            required=True,
            type=Feature.load,
            action=list_action(Features),
            help="Features to train on",
        ),
    )
    for param in inspect.signature(cls.SCIKIT_MODEL).parameters.values():
        # TODO if param.default is an array then Args needs to get a
        # nargs="+"
        param_type = cls.type_for(param)
        # Identity check against the public sentinel. Comparing with == would
        # call the default's own __eq__, which for array-like defaults does
        # not produce a plain bool.
        if param.default is inspect.Parameter.empty:
            default = NoDefaultValue
        else:
            default = param.default
        cls.config_set(
            args, above, param.name, Arg(type=param_type, default=default)
        )
    return args
def args(cls, args, *above) -> Dict[str, Arg]:
    """
    Register the ``directory``, ``steps``, ``epochs``, ``hidden``,
    ``predict`` and ``features`` arguments.

    Each argument is stored in ``args`` via ``cls.config_set`` using ``above``
    as the parent path. The same ``args`` object is returned.
    """
    # Default location for saved state: ~/.cache/dffml/tensorflow
    default_directory = os.path.join(
        os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
    )
    definitions = (
        (
            "directory",
            Arg(
                default=default_directory,
                help="Directory where state should be saved",
            ),
        ),
        (
            "steps",
            Arg(
                type=int,
                default=3000,
                help="Number of steps to train the model",
            ),
        ),
        (
            "epochs",
            Arg(
                type=int,
                default=30,
                help="Number of iterations to pass over all repos in a source",
            ),
        ),
        (
            "hidden",
            Arg(
                type=int,
                nargs="+",
                default=[12, 40, 15],
                help="List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
            ),
        ),
        ("predict", Arg(help="Feature name holding truth value")),
        (
            "features",
            Arg(
                nargs="+",
                required=True,
                type=Feature.load,
                action=list_action(Features),
                help="Features to train on",
            ),
        ),
    )
    # Registration order matches the order of the definitions above
    for name, definition in definitions:
        cls.config_set(args, above, name, definition)
    return args
class Server(TLSCMD, MultiCommCMD, Routes):
    """
    HTTP server providing access to DFFML APIs
    """

    # Used for testing. When RUN_YIELD_START is not False, run() puts the
    # server on it and then waits on RUN_YIELD_FINISH instead of sleeping
    # forever.
    RUN_YIELD_START = False
    RUN_YIELD_FINISH = False
    INSECURE_NO_TLS = False

    arg_port = Arg("-port", help="Port to bind to", type=int, default=8080)
    arg_addr = Arg("-addr", help="Address to bind to", default="127.0.0.1")
    arg_upload_dir = Arg(
        "-upload-dir",
        help="Directory to store uploaded files in",
        default=None,
    )
    arg_static = Arg(
        "-static", help="Directory to serve static content from", default=None
    )
    arg_js = Arg(
        "-js",
        help="Serve JavaScript API file at /api.js",
        default=False,
        action="store_true",
    )
    arg_insecure = Arg(
        "-insecure",
        help="Start without TLS encryption",
        action="store_true",
        default=False,
    )
    arg_cors_domains = Arg(
        "-cors-domains",
        help="Domains to allow CORS for (see keys in defaults dict for aiohttp_cors.setup)",
        nargs="+",
        default=[],
    )
    arg_models = Arg(
        "-models",
        help="Models configured on start",
        nargs="+",
        default=AsyncContextManagerList(),
        type=Model.load_labeled,
        action=list_action(AsyncContextManagerList),
    )
    arg_sources = Arg(
        "-sources",
        help="Sources configured on start",
        nargs="+",
        default=Sources(),
        type=BaseSource.load_labeled,
        action=list_action(Sources),
    )

    async def start(self):
        """
        Create the TCP site on ``self.addr``:``self.port`` and start it.

        TLS is used unless ``self.insecure`` is set. After the site has
        started, ``self.port`` is overwritten with the port of the first
        listening socket, so it holds the real port even when 0 was requested.
        """
        if self.insecure:
            self.site = web.TCPSite(
                self.runner, host=self.addr, port=self.port
            )
        else:
            # NOTE(review): the ssl documentation describes
            # Purpose.SERVER_AUTH as the purpose for client-side sockets and
            # Purpose.CLIENT_AUTH as the one for server-side sockets. Confirm
            # whether SERVER_AUTH is deliberate here (for instance to require
            # client certificates verified against self.cert) before
            # changing it.
            ssl_context = ssl.create_default_context(
                purpose=ssl.Purpose.SERVER_AUTH, cafile=self.cert
            )
            ssl_context.load_cert_chain(self.cert, self.key)
            self.site = web.TCPSite(
                self.runner,
                host=self.addr,
                port=self.port,
                ssl_context=ssl_context,
            )
        await self.site.start()
        # Read back the port that was actually bound
        self.port = self.site._server.sockets[0].getsockname()[1]
        self.logger.info(f"Serving on {self.addr}:{self.port}")

    async def run(self):
        """
        Binds to port and starts HTTP server

        Once the server is started, the application is always cleaned up and
        the site is always stopped when this method exits, whether it exits
        normally or because of an exception.
        """
        # Create dictionaries to hold configured sources and models
        await self.setup()
        await self.start()
        try:
            # Load
            if self.mc_config is not None:
                # Restore atomic after config is set, allow setting for now
                atomic = self.mc_atomic
                self.mc_atomic = False
                try:
                    await self.register_directory(self.mc_config)
                finally:
                    # Put the original value back even if loading failed
                    self.mc_atomic = atomic
            # If we are testing then RUN_YIELD_START has a put() method which
            # receives this server and RUN_YIELD_FINISH has a wait() method
            # which returns when the test is done
            if self.RUN_YIELD_START is not False:
                await self.RUN_YIELD_START.put(self)
                await self.RUN_YIELD_FINISH.wait()
            else:  # pragma: no cov
                # Wait for ctrl-c
                while True:
                    await asyncio.sleep(60)
        finally:
            # Stop the site even when cleaning up the application fails, so
            # that the listening socket is never left bound
            try:
                await self.app.cleanup()
            finally:
                await self.site.stop()