Example #1
0
def create_simple(id, last_weights=False):
    """Load job *id*, ensure its weights exist locally and return the built model.

    :param id: id of the job to load through ``JobBackend``.
    :param last_weights: use the "latest" weights file instead of the "best" one.
    :return: tuple ``(job_model, built_model)`` with the weights loaded.
    :raises Exception: if no usable local weights file exists and the backend
        returns no weight URL for the job.
    """
    # urllib.urlopen only exists on Python 2; pick the right urlopen for both.
    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib2 import urlopen

    job_backend = JobBackend()
    job_backend.load(id)
    job_model = job_backend.get_job_model()

    if last_weights:
        weights_path = job_model.get_weights_filepath_latest()
    else:
        weights_path = job_model.get_weights_filepath_best()

    # A zero-byte file is treated like a missing one and downloaded again.
    if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0:
        # NOTE(review): the "best" weight URL is used even when last_weights=True.
        weight_url = job_backend.get_best_weight_url(id)
        if not weight_url:
            raise Exception("No weights available for this job.")

        print("Download weights %s to %s .." % (weight_url, weights_path))
        ensure_dir(os.path.dirname(weights_path))

        # Fetch first, then write, so a failed request leaves no file behind;
        # both the response and the file are closed even on error.
        response = urlopen(weight_url)
        try:
            content = response.read()
        finally:
            response.close()

        with open(weights_path, 'wb') as f:
            f.write(content)

    # `model` must resolve to the name defined outside this function. The built
    # model is therefore kept in `built_model`: assigning to `model` anywhere in
    # this function would make this line raise UnboundLocalError.
    model.job_prepare(job_model)

    general_logger = GeneralLogger(job_backend=job_backend)
    trainer = Trainer(job_backend, general_logger)

    job_model.set_input_shape(trainer)

    built_model = job_model.get_built_model(trainer)

    job_model.load_weights(built_model, weights_path)

    return job_model, built_model
Example #2
0
    def __init__(self, job_id=None, api_key=None):
        """Initialise job state and create the client connection to the trainer.

        :param job_id: id of an existing job. Anything containing '/' is
            rejected, because that is a model name rather than a job id.
        :param api_key: API key; falls back to the API_KEY environment variable.
        :raises Exception: if *job_id* contains '/'.
        """
        self.event_listener = EventListener()
        self.api_key = api_key or os.getenv('API_KEY')

        if job_id and '/' in job_id:
            raise Exception('job_id needs to be a job id, not a model name.')

        # Registered with the shutdown hook before the remaining state is set.
        on_shutdown.started_jobs.append(self)

        self.job_id = job_id
        self.client = None
        self.job = None

        # Progress / timing counters.
        self.last_batch_time = time.time()
        self.start_time = time.time()
        self.current_epoch = self.current_batch = self.total_epochs = 0
        self.made_batches = self.batches_per_second = 0

        # `ended` becomes True once the done, abort or crash method was called.
        self.ended = False

        # `running` is True while the syncer client runs.
        self.running = False
        self.monitoring_thread = None
        self.general_logger_stdout = GeneralLogger(job_backend=self)
        self.general_logger_error = GeneralLogger(job_backend=self, error=True)

        # Connection target: API_HOST / API_PORT, with defaults.
        self.host = os.getenv('API_HOST')
        self.port = int(os.getenv('API_PORT') or 8051)
        if self.host in (None, '', 'false'):
            self.host = 'trainer.aetros.com'

        self.last_progress_call = None
        self.job_ids = []
        self.in_request = False
        self.stop_requested = False

        handlers = [
            ('stop', self.external_stop),
            ('aborted', self.external_aborted),
            ('registration', self.on_registration),
            ('registration_failed', self.on_registration_failed),
        ]
        for event_name, handler in handlers:
            self.event_listener.on(event_name, handler)

        self.client = JobClient(self.host, self.port, self.event_listener)
Example #3
0
    def setup(self, x=None, nb_epoch=1, batch_size=16):
        """Register the model, create and start an AETROS job, return a Keras callback.

        :param x: training data; if it is a Keras ``Iterator``, its own
            ``batch_size`` replaces the *batch_size* argument.
        :param nb_epoch: epoch count stored in the model settings.
        :param batch_size: batch size stored in the model settings.
        :return: the configured ``KerasLogger``, also kept as ``self.callback``.
        """
        graph = self.model_to_graph(self.model)

        from keras.preprocessing.image import Iterator

        # A Keras iterator carries its own batch size.
        batch_size = x.batch_size if isinstance(x, Iterator) else batch_size

        optimizer_name = ''
        if hasattr(self.model, 'optimizer'):
            optimizer_name = type(self.model.optimizer).__name__

        self.job_backend.ensure_model(
            self.id,
            self.model.to_json(),
            settings={'epochs': nb_epoch, 'batchSize': batch_size, 'optimizer': optimizer_name},
            type=self.model_type,
            graph=graph,
        )

        created_job_id = self.job_backend.create(self.id, insights=self.insights)
        self.job_backend.start()

        print(
            "AETROS job '%s' created and started. Open http://%s/trainer/app#/job=%s to monitor the training."
            % (created_job_id, self.job_backend.host, created_job_id))

        general_logger = GeneralLogger(self.job_backend.load_light_job(), job_backend=self.job_backend)
        trainer = self.trainer = Trainer(self.job_backend, general_logger)

        # Daemon thread, so it does not keep the interpreter alive on exit.
        monitor = self.monitoringThread = MonitoringThread(self.job_backend, trainer)
        monitor.daemon = True
        monitor.start()

        trainer.model = self.model
        trainer.data_train = {'x': x}

        callback = self.callback = KerasLogger(trainer, self.job_backend, general_logger)
        callback.log_epoch = False
        callback.model = self.model
        callback.confusion_matrix = self.confusion_matrix

        return callback
class ServerCommand:
    """``aetros server``: run this machine as a job server for AETROS Trainer.

    ``main()`` connects to the trainer via ``ServerClient`` and receives
    job-queue events. It starts queued jobs as ``python -m aetros start <id>``
    subprocesses, at most ``max_parallel_jobs`` at once and highest priority
    first, and sends a system utilization report on every main-loop iteration.
    """
    model = None
    job_model = None

    def __init__(self, logger):
        self.logger = logger
        # time of the previous utilization report, used for network rates
        self.last_utilization = None
        # interface name -> byte counters from the previous utilization report
        self.last_net = {}
        # interface names included in utilization reports
        self.nets = []
        self.server = None
        # set by end(): process_queue() starts no new jobs any more
        self.ending = False
        # the main loop in main() runs while this is True
        self.active = True
        # priority -> list of job dicts waiting to be started
        self.queue = {}
        # job id -> job dict for every job known to this server (waiting or running)
        self.queuedMap = {}

        self.general_logger_stdout = None
        self.general_logger_stderr = None

        self.executed_jobs = 0
        self.max_parallel_jobs = 2
        # 0 means no limit on the total number of executed jobs
        self.max_jobs = 0
        self.ssh_key_path = None

        # running job subprocesses; each carries its job dict as ``process.job``
        self.job_processes = []
        self.registered = False
        self.show_stdout = False

    def main(self, args):
        """Parse the command line *args*, connect to the trainer and run the server loop.

        Blocks until the connection is closed or the user presses Ctrl+C.
        """
        import aetros.const

        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                         prog=aetros.const.__prog__ + ' server')
        parser.add_argument('name', nargs='?', help="Server name")
        parser.add_argument('--generate-ssh-key', help="Generates automatically a ssh key, register them in AETROS in your account, and delete them when the command exits.")
        parser.add_argument('--max-parallel', help="How many jobs should be run at the same time.")
        parser.add_argument('--max-jobs', help="How many jobs are allowed to run in total.")
        parser.add_argument('--host', help="Default trainer.aetros.com. Read from environment variable API_HOST.")
        # NOTE(review): --port is parsed but not applied to the config here.
        parser.add_argument('--port', help="Default 8051. Read from environment variable API_PORT.")
        parser.add_argument('--show-stdout', action='store_true', help="Show all stdout of all jobs")

        parsed_args = parser.parse_args(args)

        if not parsed_args.name:
            parser.print_help()
            sys.exit()

        config = read_config()

        # Apply the host override before anything talks to the trainer. The SSH
        # key is then registered at, and later deleted from, the same host the
        # server connects to.
        if parsed_args.host:
            config['host'] = parsed_args.host

        if parsed_args.max_parallel:
            self.max_parallel_jobs = int(parsed_args.max_parallel)
        if parsed_args.max_jobs:
            self.max_jobs = int(parsed_args.max_jobs)
        if parsed_args.show_stdout:
            self.show_stdout = True

        event_listener = EventListener()

        event_listener.on('registration', self.registration_complete)
        event_listener.on('failed', self.connection_failed)
        event_listener.on('queue-jobs', self.queue_jobs)
        event_listener.on('unqueue-jobs', self.unqueue_jobs)
        event_listener.on('queue-ok', self.queue_ok)
        event_listener.on('stop-job', self.stop_job)
        event_listener.on('close', self.on_client_close)

        signal.signal(signal.SIGUSR1, self.on_signusr1)

        ssh_key_registered = False
        if parsed_args.generate_ssh_key:

            self.ssh_key_path = os.path.expanduser('~/.ssh/id_' + parsed_args.name.replace('/', '__') + '_rsa')
            if not os.path.exists(self.ssh_key_path):
                self.logger.info('Generate SSH key')
                subprocess.check_output(['ssh-keygen', '-q', '-N', '', '-t', 'rsa', '-b', '4048', '-f', self.ssh_key_path])

            self.logger.info('Register SSH key at ' + config['host'])

            with open(self.ssh_key_path + '.pub', 'r') as f:
                public_key = f.read()

            response = self._post_ssh_key(config, '/api/server/ssh-key', {
                'name': parsed_args.name,
                'secure_key': parsed_args.generate_ssh_key,
                'key': public_key,
            }, 'Could not register SSH key in AETROS Trainer.')

            # response.content is bytes on Python 3 and never equals 'true';
            # compare the decoded body instead.
            ssh_key_registered = response.text == 'true'

        def delete_ssh_key():
            """Unregister the generated SSH key at the trainer and delete the key files."""
            with open(self.ssh_key_path + '.pub', 'r') as f:
                public_key = f.read()

            self.logger.info('Delete SSH key at ' + config['host'])
            self._post_ssh_key(config, '/api/server/ssh-key/delete', {
                'secure_key': parsed_args.generate_ssh_key,
                'key': public_key,
            }, 'Could not delete SSH key in AETROS Trainer.')

            os.unlink(self.ssh_key_path)
            os.unlink(self.ssh_key_path + '.pub')

        if parsed_args.generate_ssh_key and ssh_key_registered:
            import atexit
            atexit.register(delete_ssh_key)

        if self.ssh_key_path:
            config['ssh_key'] = self.ssh_key_path

        self.server = ServerClient(config, event_listener, self.logger)

        self.general_logger_stdout = GeneralLogger(job_backend=self)
        self.general_logger_stderr = GeneralLogger(job_backend=self, error=True)

        sys.stdout = sys.__stdout__ = self.general_logger_stdout
        sys.stderr = sys.__stderr__ = self.general_logger_stderr

        self.server.configure(parsed_args.name)
        self.logger.info('Connecting to ' + config['host'])
        self.server.start()
        self.write_log("\n")

        try:
            # Once per second: report utilization and start waiting jobs.
            while self.active:
                if self.registered:
                    self.server.send_message({'type': 'utilization', 'values': self.collect_system_utilization()})
                    self.process_queue()

                time.sleep(1)
        except KeyboardInterrupt:
            self.logger.warning('Aborted')
            self.stop()

    def _post_ssh_key(self, config, path, data, error_message):
        """POST an SSH-key request to the trainer API; raise via raise_response_exception unless HTTP 200."""
        auth = None
        if 'auth_user' in config:
            auth = HTTPBasicAuth(config['auth_user'], config['auth_pw'])

        url = 'https://' + config['host'] + path
        response = requests.post(url, data, auth=auth, verify=config['ssl_verify'], headers={'Accept': 'application/json'})

        if response.status_code != 200:
            raise_response_exception(error_message, response)

        return response

    def on_signusr1(self, signal, frame):
        """SIGUSR1 handler: log how many jobs are waiting, running and allowed in parallel."""
        self.logger.info("%d queued, %d running, %d max" % (self.queued_count(), len(self.job_processes), self.max_parallel_jobs))

    def on_client_close(self, params):
        """Connection closed: let the main loop exit."""
        self.active = False
        self.logger.warning('Closed')

    def write_log(self, message):
        """Send a log *message* to the trainer."""
        self.server.send_message({'type': 'log', 'message': message})

    def stop(self):
        """Stop the main loop, kill all job processes, flush the loggers and close the connection."""
        self.active = False

        for p in self.job_processes:
            p.kill()

        self.general_logger_stdout.flush()
        self.general_logger_stderr.flush()
        self.server.close()

    def end(self):
        """Stop starting new jobs, wait for the running ones, report them and stop."""
        self.ending = True

        for p in self.job_processes:
            p.wait()

        self.check_finished_jobs()
        self.stop()

    def connection_failed(self, params):
        """Connection failed: leave the loop and exit with status 1."""
        self.active = False
        sys.exit(1)

    def stop_job(self, id):
        """Forget job *id*; if it is still waiting, remove it from the wait list.

        Removing it from the queue is enough, since the job process itself
        terminates when the job is aborted.
        """
        job = self.queuedMap.get(id)
        if job is None:
            return

        self.logger.info("Queued job removed %s (priority: %s) " % (job['id'], job['priority']))

        # self.queue maps priority -> list; look inside that priority's list
        # (testing `job in self.queue` would hash the dict and raise TypeError).
        waiting = self.queue.get(job['priority'], [])
        if job in waiting:
            waiting.remove(job)

        del self.queuedMap[id]

    def unqueue_jobs(self, jobs):
        """Remove every id in *jobs* from the wait lists and from the known jobs."""
        for id in jobs:
            if id in self.queuedMap:
                self.logger.info('Removed job %s from queue.' % (id, ))

                for priority in self.queue:
                    if self.queuedMap[id] in self.queue[priority]:
                        self.queue[priority].remove(self.queuedMap[id])

                del self.queuedMap[id]

    def queue_jobs(self, jobs):
        """Add the jobs in *jobs* (id -> job dict with 'priority') to the wait lists.

        Jobs that are already known or already running are skipped; the rest of
        the list is still processed.
        """
        self.logger.debug('Got queue list with %d items.' % (len(jobs), ))

        # Reap finished processes first so is_job_running() sees current state.
        self.check_finished_jobs()

        for id, job in jobs.items():
            priority = job['priority']
            job['id'] = id

            if self.is_job_queued(id) or self.is_job_running(id):
                self.logger.debug("Requested job %s is already known. Exclude that one." % (id, ))
                continue

            self.logger.info("Queued job %s (priority:%d) in %s ..." % (
               job['id'], job['priority'], os.getcwd()
            ))

            self.queuedMap[job['id']] = job

            # add the job into the wait list
            self.queue.setdefault(priority, []).append(job)

    def is_job_queued(self, id):
        """Return True if job *id* is known to this server (waiting or running)."""
        return id in self.queuedMap

    def queued_count(self):
        """Return the number of jobs waiting in all priority lists."""
        return sum(len(jobs) for jobs in self.queue.values())

    def is_job_running(self, id):
        """Return True if one of the job subprocesses belongs to job *id*."""
        for process in self.job_processes:
            job = getattr(process, 'job')
            if job['id'] == id:
                return True

        return False

    def queue_ok(self, id):
        """
        We queued the job, told the server so and server said we're ok to start the job now.

        Unknown ids are ignored. A job that is already in its wait list or
        already running is not added again, because that would execute it twice.

        :param id: id of the confirmed job
        :return: None
        """
        job = self.queuedMap.get(id)
        if job is None:
            # unqueued/stopped in the meantime, or forgotten on re-registration
            self.logger.debug("Confirmed job %s is not queued anymore. Ignored." % (id, ))
            return

        priority = job['priority']

        self.logger.debug("Queued job confirmed %s (priority: %s) " % (job['id'], priority))

        # add the job into the wait list
        waiting = self.queue.setdefault(priority, [])
        if job in waiting or self.is_job_running(id):
            return

        waiting.append(job)

    def check_finished_jobs(self):
        """Report exited job processes to the trainer and drop them from the running list.

        Exit status 0 is logged as finished. Any other status is reported as
        'job-failed'; this includes negative codes, which on POSIX mean the
        process was terminated by a signal.
        """
        still_running = []
        for process in self.job_processes:
            job = getattr(process, 'job')
            exit_code = process.poll()

            if exit_code is None:
                still_running.append(process)
                continue

            if exit_code == 0:
                self.logger.info('Finished job %s. Exit status: %s' % (job['id'], str(exit_code)))
            else:
                reason = 'Failed job %s. Exit status: %s' % (job['id'], str(exit_code))
                self.logger.error(reason)
                self.server.send_message({'type': 'job-failed', 'id': job['id'], 'error': reason})

            if job['id'] in self.queuedMap:
                del self.queuedMap[job['id']]

        # Keep exactly the processes seen alive above. Polling a second time
        # could drop a job that exited in between without reporting its result.
        self.job_processes = still_running

    def process_queue(self):
        """Start the highest-priority waiting job if a slot is free and limits allow it."""
        self.check_finished_jobs()

        if self.ending:
            return

        if len(self.job_processes) >= self.max_parallel_jobs:
            return

        if self.max_jobs and self.executed_jobs >= self.max_jobs:
            self.logger.warning('Limit of max jobs %d/%d reached. Waiting for active jobs to finish ...' % (self.executed_jobs, self.max_jobs))
            self.end()
            return

        # sort by priority: The higher the sooner the job starts
        for priority in sorted(self.queue, reverse=True):
            q = self.queue[priority]

            if len(q) > 0:
                # registered and free space for new jobs, so execute another one
                self.execute_job(q.pop(0))
                break

    def execute_job(self, job):
        """Start *job* as a ``python -m aetros start <id>`` subprocess.

        The child's stdout/stderr are attached to the general loggers.
        """
        self.logger.info("Execute job %s (priority=%s) in %s ..." % (job['id'], job['priority'], os.getcwd()))

        self.executed_jobs += 1

        with open(os.devnull, 'r+b', 0) as DEVNULL:
            my_env = os.environ.copy()

            if self.ssh_key_path is not None:
                my_env['AETROS_SSH_KEY'] = self.ssh_key_path

            args = [sys.executable, '-m', 'aetros', 'start', job['id']]
            self.logger.info('$ ' + ' '.join(args))
            self.server.send_message({'type': 'job-executed', 'id': job['id']})

            process = subprocess.Popen(args, bufsize=1,
                stdin=DEVNULL, stderr=subprocess.PIPE, stdout=subprocess.PIPE, env=my_env)

            self.general_logger_stdout.attach(process.stdout)
            self.general_logger_stderr.attach(process.stderr)

            setattr(process, 'job', job)
            self.job_processes.append(process)

    def registration_complete(self, params):
        """Registration succeeded: reset the queue state and send system information."""
        self.registered = True

        # upon registration, we need to clear the queue, since the server sends us immediately
        # all to be enqueued jobs after registration/re-connection
        self.queue = {}
        self.queuedMap = {}

        self.logger.info("As server %s under account %s registered." % (params['server'], params['username']))
        self.server.send_message({'type': 'system', 'values': self.collect_system_information()})

    def collect_system_information(self):
        """Return static machine information and refresh the list of reported interfaces."""
        values = {}
        mem = psutil.virtual_memory()
        values['memory_total'] = mem.total

        import cpuinfo
        cpu = cpuinfo.get_cpu_info()
        values['cpu_name'] = cpu['brand']
        values['cpu'] = [cpu['hz_actual_raw'][0], cpu['count']]
        values['nets'] = {}
        values['disks'] = {}
        values['boot_time'] = psutil.boot_time()

        for disk in psutil.disk_partitions():
            try:
                name = self.get_disk_name(disk[1])
                values['disks'][name] = psutil.disk_usage(disk[1]).total
            except Exception:
                # suppress Operation not permitted
                pass

        nets = []
        try:
            for id, net in psutil.net_if_stats().items():
                if not id.startswith('lo') and net.isup:
                    nets.append(id)
                    values['nets'][id] = net.speed or 1000
        except Exception:
            # suppress Operation not permitted
            pass

        # Replace rather than extend: this runs on every (re-)registration, and
        # duplicate names would make collect_system_utilization report zero rates.
        self.nets = nets

        return values

    def get_disk_name(self, name):
        """Return the basename for macOS '/Volumes/...' mount points, else *name* unchanged."""
        if name.startswith("/Volumes"):
            return os.path.basename(name)

        return name

    def collect_system_utilization(self):
        """Return current CPU, memory, disk, network, job and process statistics."""
        values = {}

        values['cpu'] = psutil.cpu_percent(interval=0.2, percpu=True)
        mem = psutil.virtual_memory()
        values['memory'] = mem.percent
        values['disks'] = {}
        values['jobs'] = {'parallel': self.max_parallel_jobs, 'enqueued': self.queued_count(), 'running': len(self.job_processes)}
        values['nets'] = {}
        values['processes'] = []

        for disk in psutil.disk_partitions():
            try:
                name = self.get_disk_name(disk[1])
                values['disks'][name] = psutil.disk_usage(disk[1]).used
            except Exception:
                # best effort: skip mount points that cannot be read
                pass

        net_stats = psutil.net_io_counters(pernic=True)
        # One timestamp for all interfaces, taken right after the counters were read.
        now = time.time()
        elapsed = now - self.last_utilization if self.last_utilization else 0
        for id in self.nets:
            net = net_stats.get(id)
            if net is None:
                # interface disappeared since collect_system_information() ran
                continue

            current = {
                'recv': net.bytes_recv,
                'sent': net.bytes_sent,
                'upload': 0,
                'download': 0
            }

            if id in self.last_net and elapsed > 0:
                current['upload'] = (net.bytes_sent - self.last_net[id]['sent']) / elapsed
                current['download'] = (net.bytes_recv - self.last_net[id]['recv']) / elapsed

            values['nets'][id] = current
            self.last_net[id] = dict(current)

        for p in psutil.process_iter():
            try:
                cpu = p.cpu_percent()
                if cpu > 1 or p.memory_percent() > 1:
                    values['processes'].append([
                        p.pid,
                        p.name(),
                        p.username(),
                        p.create_time(),
                        p.status(),
                        p.num_threads(),
                        p.memory_percent(),
                        cpu
                    ])
            except (OSError, psutil.Error):
                # process vanished or is not accessible
                pass

        try:
            if hasattr(os, 'getloadavg'):
                values['loadavg'] = os.getloadavg()
            else:
                values['loadavg'] = ''
        except OSError:
            pass

        self.last_utilization = now
        return values
Example #6
0
class ServerCommand:
    """``aetros server``: run this machine as a job server for AETROS Trainer.

    ``main`` parses the CLI options, optionally registers a temporary SSH key,
    connects a ``ServerClient`` and then loops once per second: while
    registered it sends a system-utilization sample and reaps finished jobs.
    Jobs announced through the ``'jobs'`` event are started as
    ``python -m aetros start <id>`` child processes.
    """

    model = None
    job_model = None

    def __init__(self, logger):
        """Create an idle server command; nothing is connected until ``main``.

        :param logger: logger used for all status output.
        """
        self.logger = logger
        self.last_utilization = None  # time.time() of the previous utilization sample
        self.last_net = {}  # interface name -> counters of the previous sample
        self.nets = []  # interfaces reported by collect_system_information
        self.server = None  # ServerClient, created in main()
        self.ending = False
        self.active = True  # the main loop runs while this is True
        self.config = {}
        # Guards job_processes / started_jobs, which are touched both from the
        # 'jobs' event callback (sync_jobs) and from the main loop.
        self.lock = Lock()

        self.general_logger_stdout = None
        self.general_logger_stderr = None

        self.executed_jobs = 0
        self.started_jobs = {}  # "<full_id>-<time>" -> True, so each job starts once
        self.max_jobs = 0  # NOTE(review): parsed from --max-jobs but not enforced in this class
        self.ssh_key_private = None
        self.ssh_key_public = None

        self.resources_limit = {}
        self.enabled_gpus = []  # GPU indices offered to the server

        self.job_processes = {}  # full job id -> subprocess.Popen of the running job
        self.registered = False
        self.show_stdout = False

    def main(self, args):
        """Entry point of ``aetros server``: parse ``args`` and serve until stopped.

        Reads the home config, applies CLI overrides, optionally registers a
        temporary SSH key (deleted again at interpreter exit), redirects
        sys.stdout/sys.stderr into GeneralLogger instances and then loops once
        per second while ``self.active`` is true.

        :param args: argument list following ``server`` on the command line.
        """
        import atexit

        import aetros.const

        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                         prog=aetros.const.__prog__ + ' server')
        parser.add_argument('name', nargs='?', help="Server name")
        parser.add_argument('--generate-ssh-key', help="Generates automatically a ssh key, register them in AETROS in "
                                                       "your account, and delete them when the server exits. "
                                                       "You should prefer 'aetros authenticate' command as its safer.")

        parser.add_argument('--allow-host-execution', action='store_true', help="Whether a job can run on this server "
            "directly, without a virtual (docker) container.\nSecurity risk and makes resource limitation useless.")

        parser.add_argument('--max-memory',
            help="How many RAM is available. In gigabyte. Per default all available memory.")
        parser.add_argument('--max-cpus',
            help="How many cores are available. Per default all available CPU cores.")
        parser.add_argument('--max-gpus',
            help="How many GPUs are available. Comma separate list of device ids. "
                 "Per default all available GPU cards. Use 'aetros gpu' to see the ids.")

        parser.add_argument('--no-gpus', action='store_true', help="Disable all GPUs")

        parser.add_argument('--max-jobs', help="How many jobs are allowed to run in total until the process exits automatically.")
        parser.add_argument('--host', help="Default trainer.aetros.com. Read from the global configuration ~/aetros.yml.")
        parser.add_argument('--show-stdout', action='store_true', help="Show all stdout of all jobs. Only for debugging necessary.")

        parsed_args = parser.parse_args(args)

        if not parsed_args.name:
            parser.print_help()
            sys.exit()

        self.config = read_home_config()

        # Apply --host immediately so every later use of self.config['host']
        # (including the SSH-key log messages) sees the effective host.
        if parsed_args.host:
            self.config['host'] = parsed_args.host

        if parsed_args.max_jobs:
            self.max_jobs = int(parsed_args.max_jobs)

        if parsed_args.max_memory:
            self.resources_limit['memory'] = int(parsed_args.max_memory)

        if parsed_args.max_cpus:
            self.resources_limit['cpus'] = int(parsed_args.max_cpus)

        self.resources_limit['host_execution'] = parsed_args.allow_host_execution

        gpus = []
        try:
            gpus = aetros.cuda_gpu.get_ordered_devices()
            self.enabled_gpus.extend(range(len(gpus)))
        except aetros.cuda_gpu.CudaNotImplementedException:
            pass  # no CUDA support: serve without GPUs

        if parsed_args.max_gpus:
            available = list(range(len(gpus)))
            self.enabled_gpus = []

            for gpu_id in parsed_args.max_gpus.split(','):
                gpu_id = int(gpu_id)
                if gpu_id not in available:
                    raise Exception('--max-gpus ' + str(gpu_id) + ' not available on the system. GPUs ' + str(available) + ' detected.')

                self.enabled_gpus.append(gpu_id)

        elif parsed_args.no_gpus:
            self.enabled_gpus = []

        if parsed_args.show_stdout:
            self.show_stdout = True

        event_listener = EventListener()

        event_listener.on('registration', self.registration_complete)
        event_listener.on('failed', self.connection_failed)
        event_listener.on('jobs', self.sync_jobs)
        event_listener.on('close', self.on_client_close)

        if hasattr(signal, 'SIGUSR1'):
            signal.signal(signal.SIGUSR1, self.on_signusr1)

        ssh_key_registered = False
        if parsed_args.generate_ssh_key:
            self.logger.info('Generate SSH key')

            ssh_key = paramiko.RSAKey.generate(4096)
            self.ssh_key_private = ssh_key.key.private_bytes(
                serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()
            ).decode()
            self.ssh_key_public = 'rsa ' + ssh_key.get_base64() + ' ' + parsed_args.name

            self.logger.info('Register SSH key at ' + self.config['host'])

            data = {
                'name': parsed_args.name,
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }

            try:
                response = aetros.api.http_request('server/ssh-key', json_body=data, method='post')
            except aetros.api.ApiError as e:
                if 'access_denied' in e.error:
                    print("error: Could not connect to " + self.config['url'] +
                          ': Access denied. --generate-ssh-key seems to be wrong. Incorrect host? See "aetros id"')
                    sys.exit(1)
                raise

            ssh_key_registered = response == True

        def delete_ssh_key():
            # Runs at interpreter exit: unregister the temporary key again.
            self.logger.info('Delete SSH key at ' + self.config['host'])

            data = {
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }
            response = aetros.api.http_request('server/ssh-key/delete', json_body=data)
            if not response:
                self.logger.error('Could not delete SSH key in AETROS Trainer.')

        if parsed_args.generate_ssh_key and ssh_key_registered:
            atexit.register(delete_ssh_key)

        if self.ssh_key_private:
            self.config['ssh_key_base64'] = self.ssh_key_private

        self.server = ServerClient(self.config, event_listener, self.logger)

        self.general_logger_stdout = GeneralLogger(job_backend=self, redirect_to=sys.__stdout__)
        self.general_logger_stderr = GeneralLogger(job_backend=self, redirect_to=sys.__stderr__)

        sys.stdout = self.general_logger_stdout
        sys.stderr = self.general_logger_stderr

        self.server.configure(parsed_args.name)
        self.logger.debug('Connecting to ' + self.config['host'])
        self.server.start()
        self.write_log("\n")

        try:
            while self.active:
                if self.registered:
                    self.server.send_message({'type': 'utilization', 'values': self.collect_system_utilization()}, '')
                    self.check_finished_jobs()

                time.sleep(1)
        except SystemExit:
            self.logger.warning('Killed')
            self.stop()
        except KeyboardInterrupt:
            self.stop()

    def on_signusr1(self, signal, frame):
        """SIGUSR1 handler: log a status line and the id of every running job."""
        self.logger.info("ending=%s, active=%s, registered=%s, %d running, %d messages, %d connection_tries" % (
            str(self.ending),
            str(self.active),
            str(self.registered),
            len(self.job_processes),
            len(self.server.queue),
            self.server.connection_tries,
        ))

        for full_id in list(self.job_processes):
            self.logger.info("Running " + full_id)

    def on_client_close(self, params):
        """'close' event handler: leave the main loop."""
        self.active = False
        self.logger.warning('Closed')

    def write_log(self, message):
        """Send ``message`` to the server as a log message; always returns True."""
        self.server.send_message({'type': 'log', 'message': message}, '')
        return True

    def stop(self):
        """Interrupt all running jobs, wait for them, then close the client.

        Sets ``active`` to False so the main loop ends.
        """
        self.active = False

        # Snapshot: sync_jobs may add entries while we are waiting.
        processes = list(self.job_processes.values())

        self.logger.warning('Killing %d jobs ' % (len(processes),))

        # Jobs run in their own process group (see execute_job). On Windows,
        # Popen.send_signal only accepts CTRL_C_EVENT, CTRL_BREAK_EVENT and
        # SIGTERM, and CTRL_BREAK_EVENT is the one delivered to a
        # CREATE_NEW_PROCESS_GROUP child; SIGINT would raise ValueError there.
        interrupt = signal.CTRL_BREAK_EVENT if os.name == 'nt' else signal.SIGINT

        for p in processes:
            if p.poll() is not None:
                continue  # already exited
            try:
                p.send_signal(interrupt)
            except OSError:
                pass  # exited between poll() and send_signal()

        for p in processes:
            p.wait()

        if self.general_logger_stdout is not None:
            self.general_logger_stdout.flush()
        if self.general_logger_stderr is not None:
            self.general_logger_stderr.flush()
        if self.server is not None:
            self.server.close()

    def end(self):
        """Finish gracefully: wait for every running job, report it, then stop."""
        self.ending = True

        # job_processes maps id -> process; wait on the processes, not on the
        # id strings (which have no wait()).
        for p in list(self.job_processes.values()):
            p.wait()

        self.check_finished_jobs()
        self.stop()

    def connection_failed(self, params):
        """'failed' event handler: stop the main loop and exit with status 1."""
        self.active = False
        sys.exit(1)

    def sync_jobs(self, jobs):
        """Start every job in ``jobs`` that has not been started yet.

        :param jobs: mapping full job id -> resource assignment dict. The
            assignment's ``'time'`` is combined with the id, so the same job
            re-sent with the same timestamp is started only once.
        """
        # ``with`` releases the lock even if execute_job raises; a bare
        # acquire()/release() would leave it held and deadlock the main loop.
        # NOTE(review): jobs that disappear from ``jobs`` are not stopped here.
        with self.lock:
            for full_id, resources_assigned in jobs.items():
                started_id = full_id + '-' + str(resources_assigned['time'])

                if started_id in self.started_jobs:
                    continue  # same job id + timestamp seen before

                self.started_jobs[started_id] = True
                self.execute_job(full_id, resources_assigned)

    def check_finished_jobs(self):
        """Reap exited job processes, log their status and notify the server."""
        with self.lock:
            finished = []

            for full_job_id, process in list(self.job_processes.items()):
                exit_code = process.poll()
                if exit_code is None:
                    continue  # still running

                if exit_code == 0:
                    self.logger.info('Finished job %s. Exit status: %s' % (full_job_id, str(exit_code)))
                else:
                    # > 0: the job failed; < 0 (POSIX): it was killed by signal -exit_code.
                    self.logger.error('Failed job %s. Exit status: %s' % (full_job_id, str(exit_code)))

                self.server.send_message({'type': 'job-finished', 'id': full_job_id}, '')
                finished.append(full_job_id)

            for full_job_id in finished:
                del self.job_processes[full_job_id]

    def execute_job(self, full_id, resources_assigned):
        """Start job ``full_id`` as a ``python -m aetros start`` child process.

        :param full_id: full job id passed through to ``aetros start``.
        :param resources_assigned: dict with ``'gpus'`` (iterable of device ids,
            possibly empty), ``'cpus'`` and ``'memory'`` (numbers).
        """
        self.logger.info("Execute job %s ..." % (full_id, ))
        self.executed_jobs += 1

        with open(os.devnull, 'r+b', 0) as DEVNULL:
            my_env = os.environ.copy()
            my_env['AETROS_ATTY'] = '1'

            if self.ssh_key_private is not None:
                my_env['AETROS_SSH_KEY_BASE64'] = self.ssh_key_private

            args = [sys.executable, '-m', 'aetros', 'start']
            if resources_assigned['gpus']:
                for gpu_id in resources_assigned['gpus']:
                    # str(): Popen and the ' '.join below require strings; ids
                    # are not guaranteed to arrive as str.
                    args += ['--gpu-device', str(gpu_id)]

            args += ['--cpu', str(int(resources_assigned['cpus']))]
            args += ['--memory', str(int(resources_assigned['memory']))]

            args += [full_id]
            self.logger.info('$ ' + ' '.join(args))
            self.server.send_message({'type': 'job-executed', 'id': full_id}, '')

            # JobBackend sends SIGINT to its current process group, which would
            # also reach this process if the child shared our group. Give the
            # child its own group/session so that does not happen.
            kwargs = {}
            if os.name == 'nt':
                kwargs['creationflags'] = subprocess.CREATE_NEW_PROCESS_GROUP
            else:
                kwargs['preexec_fn'] = os.setsid

            # Reuse the DEVNULL handle (closed by the with-block) instead of
            # opening a new devnull file per job that was never closed.
            output = subprocess.PIPE if self.show_stdout else DEVNULL

            process = subprocess.Popen(args, bufsize=1, env=my_env, stdin=DEVNULL,
                                       stderr=output, stdout=output, **kwargs)

            if self.show_stdout:
                self.general_logger_stdout.attach(process.stdout, read_line=True)
                self.general_logger_stderr.attach(process.stderr, read_line=True)

            self.job_processes[full_id] = process

    def registration_complete(self, params):
        """'registration' event handler: mark registered and send system info."""
        self.registered = True
        self.logger.info("Server connected to %s as %s under account %s registered." % (self.config['host'], params['server'], params['username']))
        self.server.send_message({'type': 'system', 'values': self.collect_system_information()}, '')

    def collect_system_information(self):
        """Return static host facts (memory, CPU, GPUs, disks, NICs).

        Also records the reported network interfaces in ``self.nets`` so that
        collect_system_utilization samples the same interfaces.
        """
        values = {}
        mem = psutil.virtual_memory()
        values['memory_total'] = mem.total

        import cpuinfo
        cpu = cpuinfo.get_cpu_info()
        values['resources_limit'] = self.resources_limit
        values['cpu_name'] = cpu['brand']
        values['cpu'] = [cpu['hz_advertised_raw'][0], cpu['count']]
        values['nets'] = {}
        values['disks'] = {}
        values['gpus'] = {}
        values['boot_time'] = psutil.boot_time()

        try:
            for gpu_id, gpu in enumerate(aetros.cuda_gpu.get_ordered_devices()):
                gpu['available'] = gpu_id in self.enabled_gpus
                values['gpus'][gpu_id] = gpu
        except aetros.cuda_gpu.CudaNotImplementedException:
            pass

        for disk in psutil.disk_partitions():
            mountpoint = disk[1]
            try:
                values['disks'][self.get_disk_name(mountpoint)] = psutil.disk_usage(mountpoint).total
            except Exception:
                pass  # e.g. "Operation not permitted" on some mount points

        try:
            for nic, stats in psutil.net_if_stats().items():
                # skip loopback ('lo', 'lo0', ...) and interfaces that are down
                if nic.startswith('lo') or not stats.isup:
                    continue
                # avoid duplicates if this runs more than once (re-registration)
                if nic not in self.nets:
                    self.nets.append(nic)
                values['nets'][nic] = stats.speed or 1000
        except Exception:
            pass  # suppress "Operation not permitted"

        return values

    def get_disk_name(self, name):
        """Shorten macOS ``/Volumes/<x>`` mount points to ``<x>``; return others as-is."""
        if 0 == name.find("/Volumes"):
            return os.path.basename(name)

        return name

    def collect_system_utilization(self):
        """Return a snapshot of CPU, memory, disk, GPU, network and process usage.

        Network upload/download rates are derived from the byte counters of
        the previous call, so the first call reports 0 for them.
        """
        values = {}

        values['cpu'] = psutil.cpu_percent(interval=0.2, percpu=True)
        mem = psutil.virtual_memory()
        values['memory'] = mem.percent
        values['disks'] = {}
        values['jobs'] = {'running': len(self.job_processes)}
        values['nets'] = {}
        values['processes'] = []
        values['gpus'] = {}

        try:
            for gpu_id, gpu in enumerate(aetros.cuda_gpu.get_ordered_devices()):
                values['gpus'][gpu_id] = aetros.cuda_gpu.get_memory(gpu['device'])
        except aetros.cuda_gpu.CudaNotImplementedException:
            pass

        for disk in psutil.disk_partitions():
            mountpoint = disk[1]
            try:
                values['disks'][self.get_disk_name(mountpoint)] = psutil.disk_usage(mountpoint).used
            except Exception:
                pass

        net_stats = psutil.net_io_counters(pernic=True)
        for nic in self.nets:
            if nic not in net_stats:
                continue
            counters = net_stats[nic]
            entry = {
                'recv': counters.bytes_recv,
                'sent': counters.bytes_sent,
                'upload': 0,
                'download': 0,
            }
            values['nets'][nic] = entry

            if nic in self.last_net and self.last_utilization:
                elapsed = time.time() - self.last_utilization
                entry['upload'] = (counters.bytes_sent - self.last_net[nic]['sent']) / elapsed
                entry['download'] = (counters.bytes_recv - self.last_net[nic]['recv']) / elapsed

            self.last_net[nic] = dict(entry)

        for p in psutil.process_iter():
            try:
                cpu = p.cpu_percent()
                if cpu > 1 or p.memory_percent() > 1:
                    values['processes'].append([
                        p.pid,
                        p.name(),
                        p.username(),
                        p.create_time(),
                        p.status(),
                        p.num_threads(),
                        p.memory_percent(),
                        cpu
                    ])
            except (OSError, psutil.Error):
                pass  # process vanished or access denied

        try:
            if hasattr(os, 'getloadavg'):
                values['loadavg'] = os.getloadavg()
            else:
                values['loadavg'] = ''
        except OSError:
            pass

        self.last_utilization = time.time()
        return values
Exemple #7
0
    def main(self, args):
        """Entry point of ``aetros server``: parse ``args`` and serve until stopped.

        Reads the home config, applies CLI overrides, optionally registers a
        temporary SSH key (deleted again at interpreter exit), redirects
        sys.stdout/sys.stderr into GeneralLogger instances and then loops once
        per second while ``self.active`` is true.

        :param args: argument list following ``server`` on the command line.
        """
        import atexit

        import aetros.const

        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                         prog=aetros.const.__prog__ + ' server')
        parser.add_argument('name', nargs='?', help="Server name")
        parser.add_argument('--generate-ssh-key', help="Generates automatically a ssh key, register them in AETROS in "
                                                       "your account, and delete them when the server exits. "
                                                       "You should prefer 'aetros authenticate' command as its safer.")

        parser.add_argument('--allow-host-execution', action='store_true', help="Whether a job can run on this server "
            "directly, without a virtual (docker) container.\nSecurity risk and makes resource limitation useless.")

        parser.add_argument('--max-memory',
            help="How many RAM is available. In gigabyte. Per default all available memory.")
        parser.add_argument('--max-cpus',
            help="How many cores are available. Per default all available CPU cores.")
        parser.add_argument('--max-gpus',
            help="How many GPUs are available. Comma separate list of device ids. "
                 "Per default all available GPU cards. Use 'aetros gpu' to see the ids.")

        parser.add_argument('--no-gpus', action='store_true', help="Disable all GPUs")

        parser.add_argument('--max-jobs', help="How many jobs are allowed to run in total until the process exits automatically.")
        parser.add_argument('--host', help="Default trainer.aetros.com. Read from the global configuration ~/aetros.yml.")
        parser.add_argument('--show-stdout', action='store_true', help="Show all stdout of all jobs. Only for debugging necessary.")

        parsed_args = parser.parse_args(args)

        if not parsed_args.name:
            parser.print_help()
            sys.exit()

        self.config = read_home_config()

        # Apply --host immediately so every later use of self.config['host']
        # (including the SSH-key log messages) sees the effective host.
        if parsed_args.host:
            self.config['host'] = parsed_args.host

        if parsed_args.max_jobs:
            self.max_jobs = int(parsed_args.max_jobs)

        if parsed_args.max_memory:
            self.resources_limit['memory'] = int(parsed_args.max_memory)

        if parsed_args.max_cpus:
            self.resources_limit['cpus'] = int(parsed_args.max_cpus)

        self.resources_limit['host_execution'] = parsed_args.allow_host_execution

        gpus = []
        try:
            gpus = aetros.cuda_gpu.get_ordered_devices()
            self.enabled_gpus.extend(range(len(gpus)))
        except aetros.cuda_gpu.CudaNotImplementedException:
            pass  # no CUDA support: serve without GPUs

        if parsed_args.max_gpus:
            available = list(range(len(gpus)))
            self.enabled_gpus = []

            for gpu_id in parsed_args.max_gpus.split(','):
                gpu_id = int(gpu_id)
                if gpu_id not in available:
                    raise Exception('--max-gpus ' + str(gpu_id) + ' not available on the system. GPUs ' + str(available) + ' detected.')

                self.enabled_gpus.append(gpu_id)

        elif parsed_args.no_gpus:
            self.enabled_gpus = []

        if parsed_args.show_stdout:
            self.show_stdout = True

        event_listener = EventListener()

        event_listener.on('registration', self.registration_complete)
        event_listener.on('failed', self.connection_failed)
        event_listener.on('jobs', self.sync_jobs)
        event_listener.on('close', self.on_client_close)

        if hasattr(signal, 'SIGUSR1'):
            signal.signal(signal.SIGUSR1, self.on_signusr1)

        ssh_key_registered = False
        if parsed_args.generate_ssh_key:
            self.logger.info('Generate SSH key')

            ssh_key = paramiko.RSAKey.generate(4096)
            self.ssh_key_private = ssh_key.key.private_bytes(
                serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()
            ).decode()
            self.ssh_key_public = 'rsa ' + ssh_key.get_base64() + ' ' + parsed_args.name

            self.logger.info('Register SSH key at ' + self.config['host'])

            data = {
                'name': parsed_args.name,
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }

            try:
                response = aetros.api.http_request('server/ssh-key', json_body=data, method='post')
            except aetros.api.ApiError as e:
                if 'access_denied' in e.error:
                    print("error: Could not connect to " + self.config['url'] +
                          ': Access denied. --generate-ssh-key seems to be wrong. Incorrect host? See "aetros id"')
                    sys.exit(1)
                raise

            ssh_key_registered = response == True

        def delete_ssh_key():
            # Runs at interpreter exit: unregister the temporary key again.
            self.logger.info('Delete SSH key at ' + self.config['host'])

            data = {
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }
            response = aetros.api.http_request('server/ssh-key/delete', json_body=data)
            if not response:
                self.logger.error('Could not delete SSH key in AETROS Trainer.')

        if parsed_args.generate_ssh_key and ssh_key_registered:
            atexit.register(delete_ssh_key)

        if self.ssh_key_private:
            self.config['ssh_key_base64'] = self.ssh_key_private

        self.server = ServerClient(self.config, event_listener, self.logger)

        self.general_logger_stdout = GeneralLogger(job_backend=self, redirect_to=sys.__stdout__)
        self.general_logger_stderr = GeneralLogger(job_backend=self, redirect_to=sys.__stderr__)

        sys.stdout = self.general_logger_stdout
        sys.stderr = self.general_logger_stderr

        self.server.configure(parsed_args.name)
        self.logger.debug('Connecting to ' + self.config['host'])
        self.server.start()
        self.write_log("\n")

        try:
            while self.active:
                if self.registered:
                    self.server.send_message({'type': 'utilization', 'values': self.collect_system_utilization()}, '')
                    self.check_finished_jobs()

                time.sleep(1)
        except SystemExit:
            self.logger.warning('Killed')
            self.stop()
        except KeyboardInterrupt:
            self.stop()
class ServerCommand:
    model = None
    job_model = None

    def __init__(self, logger):
        self.logger = logger
        self.last_utilization = None
        self.last_net = {}
        self.nets = []
        self.server = None
        self.ending = False
        self.active = True
        self.config = {}
        self.lock = Lock()

        self.general_logger_stdout = None
        self.general_logger_stderr = None

        self.executed_jobs = 0
        self.started_jobs = {}
        self.max_jobs = 0
        self.ssh_key_private = None
        self.ssh_key_public = None

        self.resources_limit = {}
        self.enabled_gpus = []

        self.job_processes = {}
        self.registered = False
        self.show_stdout = False

    def main(self, args):
        """Entry point of ``aetros server``: parse ``args`` and serve until stopped.

        Reads the home config, applies CLI overrides, optionally registers a
        temporary SSH key (deleted again at interpreter exit), redirects
        sys.stdout/sys.stderr into GeneralLogger instances and then loops once
        per second while ``self.active`` is true. Ctrl-C stops the server.

        :param args: argument list following ``server`` on the command line.
        """
        import atexit

        import aetros.const

        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                         prog=aetros.const.__prog__ + ' server')
        parser.add_argument('name', nargs='?', help="Server name")
        parser.add_argument('--generate-ssh-key', help="Generates automatically a ssh key, register them in AETROS in "
                                                       "your account, and delete them when the server exits. "
                                                       "You should prefer 'aetros register' command as its safer.")

        parser.add_argument('--allow-host-execution', action='store_true', help="Whether a job can run on this server "
            "directly, without a virtual (docker) container.\nSecurity risk and makes resource limitation useless.")

        parser.add_argument('--max-memory',
            help="How many RAM is available. In gigabyte. Per default all available memory.")
        parser.add_argument('--max-cpus',
            help="How many cores are available. Per default all available CPU cores.")
        parser.add_argument('--max-gpus',
            help="How many GPUs are available. Comma separate list of device ids (pciBusId). "
                 "Per default all available GPU cards. Use 'aetros gpu' to see the ids.")

        parser.add_argument('--no-gpus', action='store_true', help="Disable all GPUs")

        parser.add_argument('--max-jobs', help="How many jobs are allowed to run in total until the process exits automatically.")
        parser.add_argument('--host', help="Default trainer.aetros.com. Read from the global configuration ~/aetros.yml.")
        parser.add_argument('--show-stdout', action='store_true', help="Show all stdout of all jobs. Only for debugging necessary.")

        parsed_args = parser.parse_args(args)

        if not parsed_args.name:
            parser.print_help()
            sys.exit()

        self.config = read_home_config()

        # Apply --host immediately so every later use of self.config['host']
        # (including the SSH-key log messages) sees the effective host.
        if parsed_args.host:
            self.config['host'] = parsed_args.host

        if parsed_args.max_jobs:
            self.max_jobs = int(parsed_args.max_jobs)

        if parsed_args.max_memory:
            self.resources_limit['memory'] = int(parsed_args.max_memory)

        if parsed_args.max_cpus:
            self.resources_limit['cpus'] = int(parsed_args.max_cpus)

        self.resources_limit['host_execution'] = parsed_args.allow_host_execution

        gpus = []
        try:
            gpus = aetros.cuda_gpu.get_ordered_devices()
            self.enabled_gpus.extend(range(len(gpus)))
        except Exception:
            # Deliberately broad: GPU detection is best-effort and must not
            # prevent the server from starting.
            pass

        if parsed_args.max_gpus:
            available = list(range(len(gpus)))
            self.enabled_gpus = []

            for gpu_id in parsed_args.max_gpus.split(','):
                gpu_id = int(gpu_id)
                if gpu_id not in available:
                    raise Exception('--max-gpus ' + str(gpu_id) + ' not available on the system. GPUs ' + str(available) + ' detected.')

                self.enabled_gpus.append(gpu_id)

        elif parsed_args.no_gpus:
            self.enabled_gpus = []

        if parsed_args.show_stdout:
            self.show_stdout = True

        event_listener = EventListener()

        event_listener.on('registration', self.registration_complete)
        event_listener.on('failed', self.connection_failed)
        event_listener.on('jobs', self.sync_jobs)
        event_listener.on('close', self.on_client_close)

        if hasattr(signal, 'SIGUSR1'):
            signal.signal(signal.SIGUSR1, self.on_signusr1)

        ssh_key_registered = False
        if parsed_args.generate_ssh_key:
            self.logger.info('Generate SSH key')

            ssh_key = paramiko.RSAKey.generate(4096)
            self.ssh_key_private = ssh_key.key.private_bytes(
                serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()
            ).decode()
            self.ssh_key_public = 'rsa ' + ssh_key.get_base64() + ' ' + parsed_args.name

            self.logger.info('Register SSH key at ' + self.config['host'])

            data = {
                'name': parsed_args.name,
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }

            response = aetros.api.http_request('server/ssh-key', json_body=data, method='post')

            ssh_key_registered = response == True

        def delete_ssh_key():
            # Runs at interpreter exit: unregister the temporary key again.
            self.logger.info('Delete SSH key at ' + self.config['host'])

            data = {
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }
            response = aetros.api.http_request('server/ssh-key/delete', json_body=data)
            if not response:
                self.logger.error('Could not delete SSH key in AETROS Trainer.')

        if parsed_args.generate_ssh_key and ssh_key_registered:
            atexit.register(delete_ssh_key)

        if self.ssh_key_private:
            self.config['ssh_key_base64'] = self.ssh_key_private

        self.server = ServerClient(self.config, event_listener, self.logger)

        self.general_logger_stdout = GeneralLogger(job_backend=self, redirect_to=sys.__stdout__)
        self.general_logger_stderr = GeneralLogger(job_backend=self, redirect_to=sys.__stderr__)

        sys.stdout = self.general_logger_stdout
        sys.stderr = self.general_logger_stderr

        self.server.configure(parsed_args.name)
        self.logger.debug('Connecting to ' + self.config['host'])
        self.server.start()
        self.write_log("\n")

        try:
            while self.active:
                if self.registered:
                    self.server.send_message({'type': 'utilization', 'values': self.collect_system_utilization()})
                    self.check_finished_jobs()

                time.sleep(1)
        except KeyboardInterrupt:
            self.logger.warning('Aborted')
            self.stop()

    def on_signusr1(self, signal, frame):
        """SIGUSR1 handler: log a one-line status summary, then each running job id.

        ``signal`` and ``frame`` are supplied by the signal module and unused.
        """
        template = "ending=%s, active=%s, registered=%s, %d running, %d messages, %d connection_tries"
        fields = (
            str(self.ending),
            str(self.active),
            str(self.registered),
            len(self.job_processes),
            len(self.server.queue),
            self.server.connection_tries,
        )
        self.logger.info(template % fields)

        for running_id in self.job_processes:
            self.logger.info("Running " + running_id)

    def on_client_close(self, params):
        self.active = False
        self.logger.warning('Closed')

    def write_log(self, message):
        self.server.send_message({'type': 'log', 'message': message})
        return True

    def stop(self):
        self.active = False

        for p in six.itervalues(self.job_processes):
            p.kill()
            time.sleep(0.1)
            p.terminate()

        self.general_logger_stdout.flush()
        self.general_logger_stderr.flush()
        self.server.close()

    def end(self):
        self.ending = True

        for p in self.job_processes:
            p.wait()

        self.check_finished_jobs()
        self.stop()

    def connection_failed(self, params):
        self.active = False
        sys.exit(1)

    def sync_jobs(self, jobs):
        self.lock.acquire()

        # make sure we started all ids from "jobs".
        # if we have still active jobs not in jobs_ids, stop them
        for full_id, resources_assigned in six.iteritems(jobs):
            started_id = full_id + '-' + str(resources_assigned['time'])

            if started_id in self.started_jobs:
                # we got the same job id + timestamp twice, just ignore it
                continue

            self.started_jobs[started_id] = True
            self.execute_job(full_id, resources_assigned)

        self.lock.release()

    def check_finished_jobs(self):
        """Reap job processes that have exited and report them to the server.

        For every process in ``self.job_processes`` whose ``poll()`` returns an
        exit code, this method:

        * logs the outcome;
        * sends a 'job-finished' message;
        * removes the entry from ``self.job_processes``.

        Still-running processes (``poll()`` returns None) are left untouched.
        """
        self.lock.acquire()
        try:
            # Iterate a snapshot so entries can be deleted while looping.
            for full_job_id, process in list(self.job_processes.items()):
                exit_code = process.poll()
                if exit_code is None:
                    continue  # still running

                if exit_code == 0:
                    self.logger.info('Finished job %s. Exit status: %s' % (full_job_id, str(exit_code)))
                elif exit_code > 0:
                    self.logger.error('Failed job %s. Exit status: %s' % (full_job_id, str(exit_code)))
                else:
                    # Popen reports -N when the child was terminated by signal N (POSIX).
                    self.logger.error('Job %s was terminated by signal %d.' % (full_job_id, -exit_code))

                self.server.send_message({'type': 'job-finished', 'id': full_job_id})
                del self.job_processes[full_job_id]
        finally:
            # Always release, otherwise a failing send_message deadlocks sync_jobs.
            self.lock.release()

    def execute_job(self, full_id, resources_assigned):
        """Spawn ``python -m aetros start <full_id>`` as a child process.

        Args:
            full_id: full job id passed as the last CLI argument.
            resources_assigned: mapping from the server. Its ``'gpus'`` entry
                (if present and non-empty) becomes ``--gpu-device`` arguments.

        The child runs in its own process group/session. It gets
        ``AETROS_ATTY=1`` and, when set, ``AETROS_SSH_KEY_BASE64`` in its
        environment. A 'job-executed' message is sent to the server. The Popen
        object is stored in ``self.job_processes[full_id]``.
        """
        self.logger.info("Execute job %s ..." % (full_id, ))
        self.executed_jobs += 1

        with open(os.devnull, 'r+b', 0) as DEVNULL:
            my_env = os.environ.copy()
            my_env['AETROS_ATTY'] = '1'

            if self.ssh_key_private is not None:
                my_env['AETROS_SSH_KEY_BASE64'] = self.ssh_key_private

            args = [sys.executable, '-m', 'aetros', 'start']
            for gpu_id in resources_assigned.get('gpus') or []:
                # argv entries must be strings; the server may send numeric ids.
                args += ['--gpu-device', str(gpu_id)]

            args += [full_id]
            self.logger.info('$ ' + ' '.join(args))
            self.server.send_message({'type': 'job-executed', 'id': full_id})

            # JobBackend sends SIGINT to its current process group, so it would also
            # reach its parent if both shared a group. Put the child in a new
            # process group so the ServerCommand process does not receive that SIGINT.
            kwargs = {}
            if os.name == 'nt':
                kwargs['creationflags'] = subprocess.CREATE_NEW_PROCESS_GROUP
            else:
                kwargs['preexec_fn'] = os.setsid

            if self.show_stdout:
                output = subprocess.PIPE
            else:
                # Nobody reads the pipes unless show_stdout is set. An undrained
                # PIPE blocks the child as soon as the OS pipe buffer is full.
                output = DEVNULL

            process = subprocess.Popen(args, bufsize=1, env=my_env, stdin=DEVNULL,
                stderr=output, stdout=output, **kwargs)

            if self.show_stdout:
                self.general_logger_stdout.attach(process.stdout, read_line=True)
                self.general_logger_stderr.attach(process.stderr, read_line=True)

            self.job_processes[full_id] = process

    def registration_complete(self, params):
        """Handle the server's 'registration' event.

        Marks this server as registered and logs the host, server name
        (``params['server']``) and account (``params['username']``). It then
        sends the result of ``collect_system_information()`` as a 'system'
        message.
        """
        self.registered = True

        host = self.config['host']
        message = "Server connected to %s as %s under account %s registered." % (
            host, params['server'], params['username'])
        self.logger.info(message)

        system_info = self.collect_system_information()
        self.server.send_message({'type': 'system', 'values': system_info})

    def collect_system_information(self):
        """Collect static hardware information about this machine.

        Returns:
            dict with keys:
            * ``memory_total``, ``resources_limit``, ``cpu_name``, ``boot_time``;
            * ``cpu``: [advertised hz, core count];
            * ``gpus``: index -> device info plus an ``available`` flag;
            * ``disks``: name -> total bytes;
            * ``nets``: interface -> speed.

        Side effect: appends every up, non-'lo*' network interface to
        ``self.nets`` (once each), which ``collect_system_utilization`` reads.
        """
        values = {}
        mem = psutil.virtual_memory()
        values['memory_total'] = mem.total

        import cpuinfo
        cpu = cpuinfo.get_cpu_info()

        # py-cpuinfo >= 6 renamed 'brand' -> 'brand_raw' and
        # 'hz_advertised_raw' -> 'hz_advertised' (an (hz, exponent) tuple).
        # Prefer the old keys so older versions behave exactly as before.
        cpu_name = cpu.get('brand') or cpu.get('brand_raw')
        hz = cpu.get('hz_advertised_raw') or cpu.get('hz_advertised')

        values['resources_limit'] = self.resources_limit
        values['cpu_name'] = cpu_name
        # Some platforms report no advertised frequency; send 0 instead of crashing.
        values['cpu'] = [hz[0] if hz else 0, cpu['count']]
        values['nets'] = {}
        values['disks'] = {}
        values['gpus'] = {}
        values['boot_time'] = psutil.boot_time()

        try:
            for gpu_id, gpu in enumerate(aetros.cuda_gpu.get_ordered_devices()):
                gpu['available'] = gpu_id in self.enabled_gpus
                values['gpus'][gpu_id] = gpu
        except Exception:
            # Best effort: GPU detection may fail (presumably no CUDA available).
            pass

        for disk in psutil.disk_partitions():
            try:
                name = self.get_disk_name(disk[1])
                values['disks'][name] = psutil.disk_usage(disk[1]).total
            except Exception:
                # suppress Operation not permitted
                pass

        try:
            for iface, net in psutil.net_if_stats().items():
                if not iface.startswith('lo') and net.isup:
                    # This method runs on every (re-)registration; avoid
                    # accumulating duplicate entries in self.nets.
                    if iface not in self.nets:
                        self.nets.append(iface)
                    values['nets'][iface] = net.speed or 1000
        except Exception:
            # suppress Operation not permitted
            pass

        return values

    def get_disk_name(self, name):
        """Return a display name for the mount point ``name``.

        Paths starting with ``/Volumes`` (presumably macOS mounted volumes) are
        shortened to their last path component. Any other path is returned
        unchanged.
        """
        return os.path.basename(name) if name.startswith("/Volumes") else name

    def collect_system_utilization(self):
        """Sample current resource usage of this machine.

        Returns:
            dict with keys:
            * ``cpu``: per-CPU percent over a 0.2 s sample;
            * ``memory``: percent;
            * ``disks``: name -> used bytes;
            * ``jobs``: running job count;
            * ``nets``: per interface in ``self.nets`` -> recv/sent totals and
              upload/download byte rates since the previous call;
            * ``processes``: processes above 1% CPU or memory;
            * ``gpus``: index -> memory info;
            * ``loadavg``: when the OS supports it.

        Updates ``self.last_net`` and ``self.last_utilization`` for the next
        rate computation.
        """
        values = {}

        values['cpu'] = psutil.cpu_percent(interval=0.2, percpu=True)
        values['memory'] = psutil.virtual_memory().percent
        values['disks'] = {}
        values['jobs'] = {'running': len(self.job_processes)}
        values['nets'] = {}
        values['processes'] = []
        values['gpus'] = {}

        try:
            for gpu_id, gpu in enumerate(aetros.cuda_gpu.get_ordered_devices()):
                values['gpus'][gpu_id] = aetros.cuda_gpu.get_memory(gpu['device'])
        except Exception:
            # Best effort: GPU queries may fail (presumably no CUDA available).
            pass

        for disk in psutil.disk_partitions():
            try:
                name = self.get_disk_name(disk[1])
                values['disks'][name] = psutil.disk_usage(disk[1]).used
            except Exception:
                pass

        # Take one timestamp together with the counters. Rates then use a
        # consistent interval, and the division is guarded against zero.
        net_stats = psutil.net_io_counters(pernic=True)
        now = time.time()
        elapsed = (now - self.last_utilization) if self.last_utilization else 0

        for iface in self.nets:
            net = net_stats.get(iface)
            if net is None:
                # Interface vanished since collect_system_information; a
                # KeyError here would otherwise abort the caller's main loop.
                continue

            stats = {
                'recv': net.bytes_recv,
                'sent': net.bytes_sent,
                'upload': 0,
                'download': 0
            }

            if iface in self.last_net and elapsed > 0:
                stats['upload'] = (net.bytes_sent - self.last_net[iface]['sent']) / elapsed
                stats['download'] = (net.bytes_recv - self.last_net[iface]['recv']) / elapsed

            values['nets'][iface] = stats
            self.last_net[iface] = dict(stats)

        for p in psutil.process_iter():
            try:
                cpu = p.cpu_percent()
                mem = p.memory_percent()
                if cpu > 1 or mem > 1:
                    values['processes'].append([
                        p.pid,
                        p.name(),
                        p.username(),
                        p.create_time(),
                        p.status(),
                        p.num_threads(),
                        mem,
                        cpu
                    ])
            except (OSError, psutil.Error):
                # Process exited or access denied while being inspected.
                pass

        try:
            if hasattr(os, 'getloadavg'):
                values['loadavg'] = os.getloadavg()
            else:
                values['loadavg'] = ''
        except OSError:
            pass

        self.last_utilization = now
        return values
    def main(self, args):
        """Entry point of the ``aetros server`` command.

        Steps, in order:

        1. Parse ``args`` and apply the requested resource limits.
        2. Select GPUs.
        3. Optionally generate a temporary SSH key and register it.
        4. Connect a ``ServerClient`` with event handlers.
        5. Redirect ``sys.stdout``/``sys.stderr`` through ``GeneralLogger``.
        6. Loop while ``self.active``. Once registered, each iteration sends a
           'utilization' message and reaps finished jobs, then sleeps 1 second.

        Prints the help and exits when no server name is given. On Ctrl+C
        (KeyboardInterrupt) it logs 'Aborted' and calls ``self.stop()``. Any
        other exception from the loop propagates without calling ``stop()``.
        """
        import aetros.const

        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                         prog=aetros.const.__prog__ + ' server')
        parser.add_argument('name', nargs='?', help="Server name")
        parser.add_argument('--generate-ssh-key', help="Generates automatically a ssh key, register them in AETROS in "
                                                       "your account, and delete them when the server exits. "
                                                       "You should prefer 'aetros register' command as its safer.")

        parser.add_argument('--allow-host-execution', action='store_true', help="Whether a job can run on this server "
            "directly, without a virtual (docker) container.\nSecurity risk and makes resource limitation useless.")

        parser.add_argument('--max-memory',
            help="How many RAM is available. In gigabyte. Per default all available memory.")
        parser.add_argument('--max-cpus',
            help="How many cores are available. Per default all available CPU cores.")
        # NOTE(review): the help text mentions pciBusId. The code below, however,
        # parses each entry with int() and validates it as an index into
        # aetros.cuda_gpu.get_ordered_devices(). Confirm which is intended.
        parser.add_argument('--max-gpus',
            help="How many GPUs are available. Comma separate list of device ids (pciBusId)."
                 "Per default all available GPU cards. Use 'aetros gpu' too see the ids.")

        parser.add_argument('--no-gpus', action='store_true', help="Disable all GPUs")

        parser.add_argument('--max-jobs', help="How many jobs are allowed to run in total until the process exists automatically.")
        parser.add_argument('--host', help="Default trainer.aetros.com. Read from the global configuration ~/aetros.yml.")
        parser.add_argument('--show-stdout', action='store_true', help="Show all stdout of all jobs. Only for debugging necessary.")

        parsed_args = parser.parse_args(args)

        if not parsed_args.name:
            parser.print_help()
            sys.exit()

        self.config = read_home_config()

        # NOTE(review): max_jobs is only stored here; this method never reads it.
        # Presumably it is enforced elsewhere, so verify.
        if parsed_args.max_jobs:
            self.max_jobs = int(parsed_args.max_jobs)

        if parsed_args.max_memory:
            self.resources_limit['memory'] = int(parsed_args.max_memory)

        if parsed_args.max_cpus:
            self.resources_limit['cpus'] = int(parsed_args.max_cpus)

        self.resources_limit['host_execution'] = parsed_args.allow_host_execution

        # By default enable every detected GPU, by index. Detection errors are
        # ignored, leaving the GPU list empty.
        gpus = []
        try:
            gpus = aetros.cuda_gpu.get_ordered_devices()
            for i in range(len(gpus)):
                self.enabled_gpus.append(i)
        except Exception: pass

        # --max-gpus replaces the default list with validated indexes.
        # --no-gpus only applies when --max-gpus was not given.
        if parsed_args.max_gpus:
            self.enabled_gpus = []

            for i in parsed_args.max_gpus.split(','):
                i = int(i)
                if i < 0 or i >= len(gpus):
                    raise Exception('--max-gpus ' + str(i) + ' not available on the system. GPUs ' + str([i for i in range(len(gpus))])+ ' detected.')

                self.enabled_gpus.append(i)

        elif parsed_args.no_gpus:
            self.enabled_gpus = []

        if parsed_args.show_stdout:
            self.show_stdout = True

        # Wire the ServerClient events to the handlers of this command.
        event_listener = EventListener()

        event_listener.on('registration', self.registration_complete)
        event_listener.on('failed', self.connection_failed)
        event_listener.on('jobs', self.sync_jobs)
        event_listener.on('close', self.on_client_close)

        # SIGUSR1 does not exist on every platform, hence the hasattr check.
        if hasattr(signal, 'SIGUSR1'):
            signal.signal(signal.SIGUSR1, self.on_signusr1)

        ssh_key_registered = False
        if parsed_args.generate_ssh_key:
            self.logger.info('Generate SSH key')

            # 4096-bit RSA key. The private part is kept as an unencrypted PEM
            # string; later it is passed to jobs via AETROS_SSH_KEY_BASE64
            # (see execute_job) and placed into self.config.
            ssh_key = paramiko.RSAKey.generate(4096)
            self.ssh_key_private = ssh_key.key.private_bytes(
                serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()
            ).decode()
            # NOTE(review): OpenSSH public-key lines normally start with the key
            # type 'ssh-rsa' (paramiko's get_name()), not 'rsa'. Confirm that the
            # AETROS API expects this prefix.
            self.ssh_key_public = 'rsa ' + ssh_key.get_base64() + ' ' + parsed_args.name

            self.logger.info('Register SSH key at ' + self.config['host'])

            # The value passed to --generate-ssh-key is sent as 'secure_key'.
            data = {
                'name': parsed_args.name,
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }

            response = aetros.api.http_request('server/ssh-key', json_body=data, method='post')

            ssh_key_registered = response == True

        def delete_ssh_key():
            """Unregister the generated public key; log an error on a falsy response."""
            self.logger.info('Delete SSH key at ' + self.config['host'])

            data = {
                'secure_key': parsed_args.generate_ssh_key,
                'key': self.ssh_key_public,
            }
            response = aetros.api.http_request('server/ssh-key/delete', json_body=data)
            if not response:
                self.logger.error('Could not delete SSH key in AETROS Trainer.')

        # Only clean up a key that was actually registered; runs at interpreter exit.
        if parsed_args.generate_ssh_key and ssh_key_registered:
            import atexit
            atexit.register(delete_ssh_key)

        if parsed_args.host:
            self.config['host'] = parsed_args.host

        if self.ssh_key_private:
            self.config['ssh_key_base64'] = self.ssh_key_private

        self.server = ServerClient(self.config, event_listener, self.logger)

        # From here on, everything printed goes through GeneralLogger instances
        # that also write to the original process streams.
        self.general_logger_stdout = GeneralLogger(job_backend=self, redirect_to=sys.__stdout__)
        self.general_logger_stderr = GeneralLogger(job_backend=self, redirect_to=sys.__stderr__)

        sys.stdout = self.general_logger_stdout
        sys.stderr = self.general_logger_stderr

        self.server.configure(parsed_args.name)
        self.logger.debug('Connecting to ' + self.config['host'])
        self.server.start()
        self.write_log("\n")

        # Main loop: ends when self.active becomes False (e.g. on the 'close'
        # event) or on Ctrl+C.
        try:
            while self.active:
                if self.registered:
                    self.server.send_message({'type': 'utilization', 'values': self.collect_system_utilization()})
                    self.check_finished_jobs()

                time.sleep(1)
        except KeyboardInterrupt:
            self.logger.warning('Aborted')
            self.stop()
    def main(self, args):
        """Entry point of the ``aetros upload-weights`` command.

        1. Load the job given by the positional ``id``. If ``id`` contains a
           '/' but no '@', a job is first created for it.
        2. Build and compile its model, loading the weights (the best or,
           with ``--latest``, the latest weights file, or an explicit
           ``--weights`` path) to validate them.
        3. Upload the weights as 'best.hdf5' with the optional ``--accuracy``.

        Prints the help and exits when no id is given.

        Raises:
            Exception: if the job cannot be found.
        """
        import sys

        from aetros import keras_model_utils

        import aetros.const
        from aetros.backend import JobBackend
        from aetros.logger import GeneralLogger
        from aetros.Trainer import Trainer

        parser = argparse.ArgumentParser(
            formatter_class=argparse.RawTextHelpFormatter,
            prog=aetros.const.__prog__ + ' upload-weights')
        parser.add_argument('id', nargs='?', help='model name or job id')
        parser.add_argument(
            '--secure-key',
            help="Secure key. Alternatively use API_KEY environment varibale.")
        parser.add_argument(
            '--weights',
            help=
            "Weights path. Per default we try to find it in the ./weights/ folder."
        )
        parser.add_argument(
            '--accuracy',
            help=
            "If you specified model name, you should also specify the accuracy this weights got."
        )
        parser.add_argument(
            '--latest',
            action="store_true",
            help="Instead of best epoch we upload latest weights.")

        parsed_args = parser.parse_args(args)

        # 'id' is optional for argparse (nargs='?'), but everything below needs
        # it. Without this check, `'/' in None` raises TypeError.
        if not parsed_args.id:
            parser.print_help()
            sys.exit()

        job_backend = JobBackend(api_token=parsed_args.secure_key)

        # A model name ('owner/model') without a job reference gets a new job.
        if '/' in parsed_args.id and '@' not in parsed_args.id:
            job_backend.create(parsed_args.id)

        job_backend.load(parsed_args.id)

        if job_backend.job is None:
            raise Exception("Job not found")

        job_model = job_backend.get_job_model()

        # --latest was previously parsed but ignored; honour it here.
        if parsed_args.latest:
            weights_path = job_model.get_weights_filepath_latest()
        else:
            weights_path = job_model.get_weights_filepath_best()

        if parsed_args.weights:
            weights_path = parsed_args.weights

        print(("Validate weights in %s ..." % (weights_path, )))

        keras_model_utils.job_prepare(job_model)

        general_logger = GeneralLogger()
        trainer = Trainer(job_backend, general_logger)

        job_model.set_input_shape(trainer)

        print("Loading model ...")
        model_provider = job_model.get_model_provider()
        model = model_provider.get_model(trainer)

        loss = model_provider.get_loss(trainer)
        optimizer = model_provider.get_optimizer(trainer)

        print("Compiling ...")
        model_provider.compile(trainer, model, loss, optimizer)

        # Loading the weights into the compiled model is the validation step:
        # an incompatible file fails here, before anything is uploaded.
        print(("Validate weights %s ..." % (weights_path, )))
        job_model.load_weights(model, weights_path)
        print("Validated.")

        print("Uploading weights to %s of %s ..." %
              (job_backend.job_id, job_backend.model_id))

        # NOTE(review): the upload name stays 'best.hdf5' even with --latest,
        # because the server contract for other names is not visible here.
        job_backend.upload_weights(
            'best.hdf5', weights_path,
            float(parsed_args.accuracy) if parsed_args.accuracy else None)

        print("Done")