def __init__(self, config):
    """
    Keep the configuration and build the BrainAPI client from it.

    The client is given the access key, username, API url and network
    timeout read from ``config``.
    """
    self._config = config
    # Gather the connection settings first (same attribute read order as
    # before), then hand them to the client as keyword arguments.
    api_kwargs = dict(
        access_key=config.accesskey,
        username=config.username,
        api_url=config.url,
        timeout=config.network_timeout,
    )
    self._api = BrainAPI(**api_kwargs)
def __init__(self, config, name=None):
    """
    Create a Brain bound to ``config``.

    Arguments:
        config: Configuration used to reach the BRAIN service.
        name: BRAIN to connect to; ``config.brain`` is used when this is
            not given (or is empty).
    """
    self._config = config
    self._api = BrainAPI(config, config.network_timeout)
    self._timeout = self.config.network_timeout
    self.description = None
    self.name = name or self.config.brain
    # Server-side data, cached here and refreshed by update().
    self._status = self._info = self._state = self._sims = None
    self.latest_version = None
    self._user_info = get_user_info()
    self.update()
class BrainController():
    """
    Brain Controller.

    Thin wrapper that forwards BRAIN management calls (create, delete,
    training control, metrics) to a BrainAPI client. Every public method
    is wrapped with ``_handle_server_errors``.
    """

    def __init__(self, config):
        """
        Build the BrainAPI client from ``config``.

        param config: configuration providing accesskey, username, url and
            network_timeout
        """
        self._config = config
        self._api = BrainAPI(
            access_key=config.accesskey,
            username=config.username,
            api_url=config.url,
            timeout=config.network_timeout
        )

    @_handle_server_errors
    def create(self, name, ink_file=None, ink_str=None):
        """
        Creates a BRAIN.

        A path to an inkling file or a raw inkling string can be passed in
        as arguments to the function. If neither is present, a blank BRAIN
        is created. The inkling file is prioritized over the string.

        param name: string name of brain
        param ink_file: string path to inkling file
        param ink_str: string raw inkling string
        """
        response = self._api.create_brain(name, ink_file, ink_str)
        return response

    @_handle_server_errors
    def delete(self, name):
        """
        Deletes a BRAIN.

        param name: string name of brain
        """
        response = self._api.delete_brain(name)
        return response

    @_handle_server_errors
    def push_inkling(self, name, ink_file=None, ink_str=None):
        """
        Pushes inkling to server.

        A path to an inkling file or a raw inkling string can be passed in
        as arguments to the function. If neither is present an error is
        raised to the caller.

        param name: string name of brain
        param ink_file: string path to inkling file
        param ink_str: string raw inkling string
        """
        response = self._api.push_inkling(name, ink_file, ink_str)
        return response

    @_handle_server_errors
    def train_start(self, name):
        """
        Starts training for a BRAIN.

        param name: string name of brain
        """
        response = self._api.start_training(name)
        return response

    @_handle_server_errors
    def train_stop(self, name):
        """
        Stops training for a BRAIN.

        param name: string name of brain
        """
        response = self._api.stop_training(name)
        return response

    @_handle_server_errors
    def train_resume(self, name, version='latest'):
        """
        Resumes training for a BRAIN.

        param name: string name of brain
        param version: string version of BRAIN. Defaults to 'latest'
        """
        response = self._api.resume_training(name, version)
        return response

    @_handle_server_errors
    def status(self, name):
        """
        Retrieves status for a BRAIN.

        param name: string name of brain
        """
        response = self._api.get_brain_status(name)
        return response

    @_handle_server_errors
    def info(self, name):
        """
        Retrieves BRAIN information.

        param name: string name of brain
        """
        response = self._api.get_brain_info(name)
        return response

    @_handle_server_errors
    def sample_rate(self, name):
        """
        Retrieves the sample rate for a given BRAIN.

        The rate is the sum of ``sample_rate`` over every entry in the
        status response's ``simulators`` list. If the response does not
        have that shape, the error is logged and 0 is returned.

        param name: string name of brain
        """
        try:
            response = self._api.get_brain_status(name)
            rate = sum(sims['sample_rate'] for sims in response['simulators'])
            return rate
        except (TypeError, KeyError) as err:
            # Malformed/empty status response: report and fall back to 0.
            log.error(err)
            log.error('Unable to determine sample rate from response')
            return 0

    @_handle_server_errors
    def simulator_info(self, name):
        """
        Retrieves simulator information for a given BRAIN.

        param name: string name of brain
        """
        response = self._api.get_simulator_info(name)
        return response

    @_handle_server_errors
    def training_episode_metrics(self, name, version='latest'):
        """
        Retrieves training episode metrics for a given BRAIN.

        param name: string name of brain
        param version: string version of BRAIN. Defaults to 'latest'
        """
        response = self._api.training_episode_metrics(name, version)
        return response

    @_handle_server_errors
    def test_episode_metrics(self, name, version='latest'):
        """
        Retrieves test episode metrics for a given BRAIN.

        param name: string name of brain
        param version: string version of BRAIN. Defaults to 'latest'
        """
        response = self._api.test_episode_metrics(name, version)
        return response

    @_handle_server_errors
    def iteration_metrics(self, name, version='latest'):
        """
        Retrieves iteration metrics for a given BRAIN.

        param name: string name of brain
        param version: string version of BRAIN. Defaults to 'latest'
        """
        response = self._api.iteration_metrics(name, version)
        return response
def config(self, config):
    """Store a new configuration and rebuild the API client from it."""
    self._config = config
    # Positional order expected by BrainAPI: key, user, url, timeout.
    connection = (config.accesskey, config.username,
                  config.url, config.network_timeout)
    self._api = BrainAPI(*connection)
class Brain(object):
    """
    Manages communication with the BRAIN on the server.

    This class can be used to introspect information about a BRAIN on the
    server and is used to query status and other properties.

    Attributes:
        config: The configuration object used to connect to this BRAIN.
        description: A user generated description of this BRAIN.
        exists: Whether this BRAIN exists on the server.
        name: The name of this BRAIN.
        ready: Whether this BRAIN is ready for training (in predict mode:
            whether it is stopped or completed).
        state: Current state of this BRAIN on the server.
        version: The currently selected version of the BRAIN.
        latest_version: The latest version of the BRAIN.

    Example:
        import sys, bonsai_ai
        config = bonsai_ai.Config(sys.argv)
        brain = bonsai_ai.Brain(config)
        print(brain)
    """

    def __init__(self, config, name=None):
        """
        Construct Brain object by passing in a Config object with an
        optional name argument.

        Arguments:
            config: A configuration used to connect to the BRAIN.
            name: The name of the BRAIN to connect to. Falls back to
                ``config.brain`` when not given (or empty).
        """
        self._config = config
        self._api = BrainAPI(config.accesskey, config.username,
                             config.url, config.network_timeout)
        self._timeout = self.config.network_timeout
        self.description = None
        self.name = name if name else self.config.brain
        self._status = None
        self._info = None
        self._state = None
        self._sims = None
        self.latest_version = None
        self._user_info = get_user_info()
        # Fetch info/status/sims from the server immediately.
        self.update()

    def __repr__(self):
        return ('{{'
                'name: {self.name!r}, '
                'description: {self.description!r}, '
                'latest_version: {self.latest_version!r}, '
                'config: {self.config!r}'
                '}}').format(self=self)

    @property
    def config(self):
        """The configuration object used to connect to this BRAIN."""
        return self._config

    @config.setter
    def config(self, config):
        # Rebuild the API client for the new configuration and keep the
        # cached timeout consistent with it (it is set from config in
        # __init__ as well).
        self._config = config
        self._api = BrainAPI(config.accesskey, config.username,
                             config.url, config.network_timeout)
        self._timeout = config.network_timeout

    def _brain_url(self):
        """
        Utility function to obtain brain url from config.

        Example: http://localhost:5000/v1/nav/brain3
        """
        url_base = self.config.url
        url_path = '/v1/{user}/{brain}'.format(user=self.config.username,
                                               brain=self.name)
        return urljoin(url_base, url_path)

    def _websocket_url(self):
        """
        Returns the brain url with its scheme switched to a websocket scheme:
        http -> ws, https -> wss; any other scheme is left unchanged.
        """
        api_url = urlparse(self._brain_url())
        if api_url.scheme == 'http':
            split_ws_url = api_url._replace(scheme='ws')
        elif api_url.scheme == 'https':
            split_ws_url = api_url._replace(scheme='wss')
        else:
            split_ws_url = api_url
        return urlunparse(split_ws_url)

    def _prediction_url(self):
        """ Utility function to obtain prediction url from config """
        return '{ws_url}/{version}/predictions/ws'.format(
            ws_url=self._websocket_url(), version=self.version)

    def _simulation_url(self):
        """ Returns simulation url """
        return '{}/sims/ws'.format(self._websocket_url())

    def _request_header(self):
        """ Utility function to obtain header that is sent with requests """
        return {
            'Authorization': self.config.accesskey,
            'User-Agent': self._user_info
        }

    def _proxy_header(self):
        """
        Utility function to obtain proxy that is sent with requests.

        Returns a ``{scheme: proxy_url}`` dict built from ``config.proxy``,
        or None when no proxy is configured.
        """
        if self.config.proxy:
            url_components = urlparse(self.config.proxy)
            proxy_dict = {url_components.scheme: self.config.proxy}
            return proxy_dict
        else:
            return None

    def update(self):
        """
        Refreshes description, status, and other information with the
        current state of the BRAIN on the server. Called by default when
        constructing a new Brain object, and by most properties.

        Errors are never raised from here: a request timeout is logged and
        any other failure prints a warning. Fields fetched before the
        failure keep their new values; the remaining fields keep their
        previous values.
        """
        try:
            log.brain('Getting {} info...'.format(self.name))
            self._info = self._api.get_brain_info(self.name)
            if self._info['versions']:
                # versions[0] is treated as the latest; presumably the
                # server lists newest first -- TODO confirm.
                self.latest_version = self._info['versions'][0]['version']
            else:
                self.latest_version = 0

            log.brain('Getting {} status...'.format(self.name))
            self._status = self._api.get_brain_status(self.name)
            log.brain('Getting {} sims...'.format(self.name))
            self._sims = self._api.get_simulator_info(self.name)
            self._state = self._status['state']
        except requests.exceptions.Timeout as e:
            log.error('Request timeout in bonsai_ai.Brain: ' + repr(e))
        except Exception as e:
            # Deliberately best-effort, but say what went wrong instead of
            # discarding the exception.
            print('WARNING: ignoring failed update in bonsai_ai.Brain: ' +
                  repr(e))

    @property
    def ready(self):
        """
        Returns True when the BRAIN is ready for training.

        In predict mode (``config.predict``) it instead returns True when
        the BRAIN state is STOPPED or COMPLETED.
        """
        self.update()
        if self.config.predict:
            return self._state == STOPPED or self._state == COMPLETED
        return self._state == IN_PROGRESS

    @property
    def exists(self):
        """
        Returns True when the BRAIN exists (i.e. update succeeded).

        In predict mode the selected version must also be present in the
        BRAIN's version list.
        """
        self.update()
        if self.config.predict:
            return self._state is not None and self._version_exists()
        return self._state is not None

    def sim_exists(self, sim_name):
        """
        Returns True if ``sim_name`` is among the simulators reported for
        this BRAIN after refreshing; False when no simulators are known.
        """
        self.update()
        if not self._sims:
            return False
        return sim_name in self._sims

    def _version_exists(self):
        """Returns True if the selected version is in the cached BRAIN info."""
        return any(v['version'] == self.version
                   for v in self._info['versions'])

    @property
    def state(self):
        """ Returns the current state of the target BRAIN """
        self.update()
        return self._state

    @property
    def status(self):
        """ Returns the current status of the target BRAIN """
        self.update()
        return self._status

    @property
    def sample_rate(self):
        """
        Returns the sample rate in iterations/second for all simulators
        connected to the brain (0 if the status lacks that information).
        """
        self.update()
        try:
            rate = sum(sims['sample_rate']
                       for sims in self._status['simulators'])
            return rate
        except (TypeError, KeyError):
            log.info('Unable to retrieve sample rate from BRAIN ')
            return 0

    @property
    def version(self):
        """
        Returns the current BRAIN version number: the configured
        ``brain_version``, or ``latest_version`` when that is 0.
        """
        if self.config.brain_version == 0:
            return self.latest_version
        else:
            return self.config.brain_version

    def training_episode_metrics(self, version=None):
        """
        Returns data about each training episode for a given version of a
        BRAIN. Defaults to configured version if none is given.

        :param version: Version of your brain. Defaults to configured version.
        """
        self.update()
        if version is None:
            version = self.version
        return self._api.training_episode_metrics(self.name, version)

    def iteration_metrics(self, version=None):
        """
        Returns iteration data for a given version of a BRAIN. Defaults to
        configured version if none is given.

        Iterations contain data for the number of iterations that have
        occurred in a simulation and at what timestamp. This data gets
        logged about once every 100 iterations. This can be useful for long
        episodes when other metrics may not be getting data.

        :param version: Version of your brain. Defaults to configured version.
        """
        self.update()
        if version is None:
            version = self.version
        return self._api.iteration_metrics(self.name, version)

    def test_episode_metrics(self, version=None):
        """
        Returns test pass data for a given version of a BRAIN. Defaults to
        configured version if none is given.

        Test pass episodes occur once every 20 training episodes during
        training for a given version of a BRAIN. The value is
        representative of the AI's performance at a regular interval of
        training.

        :param version: Version of your brain. Defaults to configured version.
        """
        self.update()
        if version is None:
            version = self.version
        return self._api.test_episode_metrics(self.name, version)
def brain_api(train_config):
    """Return a BrainAPI client constructed from ``train_config``."""
    return BrainAPI(train_config)
def __init__(self, config):
    """Keep ``config`` and create a BrainAPI client using its network timeout."""
    self._config = config
    timeout = config.network_timeout
    self._api = BrainAPI(config=config, timeout=timeout)
def config(self, config):
    """Replace the configuration and rebuild the API client to match it."""
    self._config = config
    timeout = config.network_timeout
    self._api = BrainAPI(config, timeout)