Example #1
def logging_wrapper(job, f, ip, port):
    """Wrapper to execute user passed functions remotely after
    setting up logging

    ip and port should specify somewhere we can push logging messages
    over zmq and have something useful happen to them
    """
    handler = NestedSetup([
        ZeroMQPushHandler("tcp://" + ip + ":" + port, level="DEBUG"),
        FileHandler(os.path.join(job["workdir"], job["description"]+".log"),
                    level="DEBUG", bubble=True)
    ])
    logger = Logger(job["description"])
    with handler.applicationbound():
        try:
            if job.get("tmpdir"):
                os.chdir(job["tmpdir"])
            else:
                os.chdir(job["workdir"])
            f(job, logger=logger)
        except:
            if job.get("tmpdir"):
                open(os.path.join(job["tmpdir"], ".error"), 'a').close()
            logger.exception("Task failed with traceback:")
            raise
        return job
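A minimal usage sketch for the wrapper above, with a hypothetical job dict and task function; the paths and the collector endpoint are assumptions, and the directories are expected to exist:

def count_records(job, logger=None):
    # logbook loggers format with str.format-style braces when args are passed
    logger.info("processing {}", job["description"])

job = {
    "description": "count-records",
    "workdir": "/tmp/jobs/count-records",  # hypothetical working directory
    "tmpdir": None,
}
# Push log records to a collector assumed to be listening on tcp://127.0.0.1:5010.
logging_wrapper(job, count_records, "127.0.0.1", "5010")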
Example #2
    def error(self, id_=None, error_code=None, error_msg=None):
        if isinstance(id_, Exception):
            # XXX: for an unknown reason 'log' is None in this branch,
            # therefore it needs to be instantiated before use
            global log
            if not log:
                log = Logger('IB Broker')
            log.exception(id_)

        if isinstance(error_code, EClientErrors.CodeMsgPair):
            error_msg = error_code.msg()
            error_code = error_code.code()

        if isinstance(error_code, int):
            if error_code in (502, 503, 326):
                # 502: Couldn't connect to TWS.
                # 503: The TWS is out of date and must be upgraded.
                # 326: Unable connect as the client id is already in use.
                self.unrecoverable_error = True

            if error_code < 1000:
                log.error("[{}] {} ({})".format(error_code, error_msg, id_))
            else:
                log.info("[{}] {} ({})".format(error_code, error_msg, id_))
        else:
            log.error("[{}] {} ({})".format(error_code, error_msg, id_))
Example #3
def rpc_server(socket, protocol, dispatcher):
    log = Logger('rpc_server')
    log.debug('starting up...')
    while True:
        try:
            message = socket.recv_multipart()
        except Exception as e:
            log.warning('Failed to receive message from client, ignoring...')
            log.exception(e)
            continue

        log.debug('Received message %s from %r', message[-1], message[0])

        # assuming protocol is thread-safe and dispatcher is thread-safe, as long
        # as it is immutable

        def handle_client(message):
            try:
                request = protocol.parse_request(message[-1])
            except RPCError as e:
                log.exception(e)
                response = e.error_respond()
            else:
                response = dispatcher.dispatch(request)
                log.debug('Response okay: %r', response)

            # send reply
            message[-1] = response.serialize()
            log.debug('Replying %s to %r', message[-1], message[0])
            socket.send_multipart(message)

        gevent.spawn(handle_client, message)
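A hedged wiring sketch for rpc_server above, assuming tinyrpc-style protocol and dispatcher objects, a gevent-friendly ZeroMQ ROUTER socket, and the names already imported by this module (Logger, gevent, RPCError):

import zmq.green as zmq
from tinyrpc.protocols.jsonrpc import JSONRPCProtocol
from tinyrpc.dispatch import RPCDispatcher

dispatcher = RPCDispatcher()

@dispatcher.public
def ping():
    return "pong"

context = zmq.Context()
socket = context.socket(zmq.ROUTER)
socket.bind("tcp://127.0.0.1:5001")  # assumed endpoint
rpc_server(socket, JSONRPCProtocol(), dispatcher)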
Example #5
def rpc_server(socket, protocol, dispatcher):
    log = Logger('rpc_server')
    log.debug('starting up...')
    while True:
        try:
            message = socket.recv_multipart()
        except Exception as e:
            log.warning('Failed to receive message from client, ignoring...')
            log.exception(e)
            continue

        log.debug('Received message %s from %r' % (message[-1], message[0]))

        # assuming protocol is thread-safe and dispatcher is thread-safe, as long
        # as it is immutable

        def handle_client(message):
            try:
                request = protocol.parse_request(message[-1])
            except RPCError as e:
                log.exception(e)
                response = e.error_respond()
            else:
                response = dispatcher.dispatch(request)
                log.debug('Response okay: %r' % response)

            # send reply
            message[-1] = response.serialize()
            log.debug('Replying %s to %r' % (message[-1], message[0]))
            socket.send_multipart(message)

        gevent.spawn(handle_client, message)
Example #6
@attr.s
class Genome:
    path = attr.ib(default=Path(), converter=Path)

    def __attrs_post_init__(self):
        self.name = self.path.name
        self.log = Logger(self.name)
        try:
            self.accession_id = self.id_(self.name)
        except AttributeError:
            self.accession_id = "missing"
            self.log.exception("Invalid accession ID")

    @staticmethod
    def id_(name):
        return re.search("GCA_[0-9]*.[0-9]", name).group()

    def get_contigs(self):
        """Get a list of of Seq objects for genome and calculate
            the total the number of contigs.
            """
        try:
            with gzip.open(self.path, "rt") as handle:
                self.contigs = [
                    seq.seq for seq in SeqIO.parse(handle, "fasta")
                ]
            self.count_contigs = len(self.contigs)
        except UnicodeDecodeError:
            self.log.exception()

    def get_assembly_size(self):
        """Calculate the sum of all contig lengths"""
        self.assembly_size = sum((len(str(seq)) for seq in self.contigs))

    def get_unknowns(self):
        """Count the number of unknown bases, i.e. not [ATCG]"""
        # TODO: It might be useful to allow the user to define p.
        p = re.compile("[^ATCG]")
        self.unknowns = sum(
            (len(re.findall(p, str(seq))) for seq in self.contigs))

    def get_distance(self, dmx_mean):
        name = Path(self.path).name
        self.distance = dmx_mean.loc[name][1]

    def get_stats(self, dmx_mean):
        self.get_contigs()
        self.get_assembly_size()
        self.get_unknowns()
        self.get_distance(dmx_mean)
        data = {
            "contigs": self.count_contigs,
            "assembly_size": self.assembly_size,
            "unknowns": self.unknowns,
            "distance": self.distance,
        }
        self.stats = pd.DataFrame(data, index=[self.name])
        self.stats.to_csv(snakemake.input.fasta + ".csv")
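A hypothetical invocation of the attrs-based Genome above; the file name carries a GCA accession so id_() can extract it, and the path itself is an assumption:

genome = Genome(path="genomes/GCA_000005845.2_ASM584v2.fa.gz")
print(genome.accession_id)      # -> "GCA_000005845.2"
genome.get_contigs()            # requires the gzipped FASTA to exist
genome.get_assembly_size()
print(genome.count_contigs, genome.assembly_size)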
Example #7
def main():
    logbook.concurrency.enable_gevent()

    global log
    StderrHandler().push_application()
    log = Logger('xbbs.worker')
    inst = XbbsWorker()

    XBBS_CFG_DIR = os.getenv("XBBS_CFG_DIR", "/etc/xbbs")
    with open(path.join(XBBS_CFG_DIR, "worker.toml"), "r") as fcfg:
        cfg = CONFIG_VALIDATOR.validate(toml.load(fcfg))

    job_request = msgs.JobRequest(capabilities=cfg["capabilities"]).pack()

    gevent.signal_handler(signal.SIGUSR1, gevent.util.print_run_info)

    log.info(cfg)
    while True:
        with inst.zmq.socket(zmq.REQ) as jobs:
            jobs.connect(cfg["job_endpoint"])

            while True:
                jobs.send(job_request)
                log.debug("waiting for job...")
                # the coordinator sends a heartbeat each minute, so 1.5 minutes
                # should be a sane duration to assume coordinator death on
                if jobs.poll(90000) == 0:
                    # breaking the inner loop will cause a reconnect
                    # since the coordinator is presumed dead, drop requests yet
                    # unsent to it
                    jobs.set(zmq.LINGER, 0)
                    log.debug("dropping socket after a heartbeat timeout")
                    break
                try:
                    msg = jobs.recv()
                    if len(msg) == 0:
                        # drop null msgs
                        continue
                    process_job_msg(inst, msg)
                except KeyboardInterrupt:
                    log.exception("interrupted")
                    return
                except Exception as e:
                    log.exception("job error", e)
Example #8
    def test_exception(self):
        logger = Logger('sentry.tests.test_contrib.test_logbook')
        handler = SentryHandler('INFO')
        with handler.applicationbound():
            try:
                raise ValueError('foo')
            except:
                logger.exception('foo bar')
        
        event = Event.objects.all()[0]

        self.assertEquals(event.type, 'sentry.events.Exception')
        self.assertEquals(event.time_spent, 0)
        self.assertTrue('sentry.interfaces.Exception' in event.data)
        event_data = event.data['sentry.interfaces.Exception']
        self.assertTrue('type' in event_data)
        self.assertEquals(event_data['type'], 'ValueError')
        self.assertTrue('value' in event_data)
        self.assertEquals(event_data['value'], 'foo')
        
        tags = dict(event.tags)
        self.assertTrue('level' in tags)
        self.assertEquals(tags['level'], 'error')
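Outside the test suite, the same handler can be bound around application code; a short hedged sketch, assuming SentryHandler is importable from Sentry's logbook integration as in the test above:

from logbook import Logger

log = Logger('myapp')
handler = SentryHandler('WARNING')
with handler.applicationbound():
    try:
        1 / 0
    except ZeroDivisionError:
        log.exception('division failed')  # recorded as a Sentry exception event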
Example #10
class Worker(object):
    redis_worker_namespace_prefix = 'rq:worker:'
    redis_workers_keys = 'rq:workers'

    @classmethod
    def all(cls):
        """Returns an iterable of all Workers.
        """
        conn = get_current_connection()
        reported_working = conn.smembers(cls.redis_workers_keys)
        return compact(map(cls.find_by_key, reported_working))

    @classmethod
    def find_by_key(cls, worker_key):
        """Returns a Worker instance, based on the naming conventions for
        naming the internal Redis keys.  Can be used to reverse-lookup Workers
        by their Redis keys.
        """
        prefix = cls.redis_worker_namespace_prefix
        name = worker_key[len(prefix):]
        if not worker_key.startswith(prefix):
            raise ValueError('Not a valid RQ worker key: %s' % (worker_key, ))

        conn = get_current_connection()
        if not conn.exists(worker_key):
            return None

        name = worker_key[len(prefix):]
        worker = cls([], name)
        queues = conn.hget(worker.key, 'queues')
        worker._state = conn.hget(worker.key, 'state') or '?'
        if queues:
            worker.queues = map(Queue, queues.split(','))
        return worker

    def __init__(self, queues, name=None, rv_ttl=500, connection=None):  # noqa
        if connection is None:
            connection = get_current_connection()
        self.connection = connection
        if isinstance(queues, Queue):
            queues = [queues]
        self._name = name
        self.queues = queues
        self.validate_queues()
        self.rv_ttl = rv_ttl
        self._state = 'starting'
        self._is_horse = False
        self._horse_pid = 0
        self._stopped = False
        self.log = Logger('worker')
        self.failed_queue = get_failed_queue(connection=self.connection)

    def validate_queues(self):  # noqa
        """Sanity check for the given queues."""
        if not iterable(self.queues):
            raise ValueError('Argument queues not iterable.')
        for queue in self.queues:
            if not isinstance(queue, Queue):
                raise NoQueueError('Give each worker at least one Queue.')

    def queue_names(self):
        """Returns the queue names of this worker's queues."""
        return map(lambda q: q.name, self.queues)

    def queue_keys(self):
        """Returns the Redis keys representing this worker's queues."""
        return map(lambda q: q.key, self.queues)

    @property  # noqa
    def name(self):
        """Returns the name of the worker, under which it is registered to the
        monitoring system.

        By default, the name of the worker is constructed from the current
        (short) host name and the current PID.
        """
        if self._name is None:
            hostname = socket.gethostname()
            shortname, _, _ = hostname.partition('.')
            self._name = '%s.%s' % (shortname, self.pid)
        return self._name

    @property
    def key(self):
        """Returns the worker's Redis hash key."""
        return self.redis_worker_namespace_prefix + self.name

    @property
    def pid(self):
        """The current process ID."""
        return os.getpid()

    @property
    def horse_pid(self):
        """The horse's process ID.  Only available in the worker.  Will return
        0 in the horse part of the fork.
        """
        return self._horse_pid

    @property
    def is_horse(self):
        """Returns whether or not this is the worker or the work horse."""
        return self._is_horse

    def procline(self, message):
        """Changes the current procname for the process.

        This can be used to make `ps -ef` output more readable.
        """
        procname.setprocname('rq: %s' % (message, ))

    def register_birth(self):  # noqa
        """Registers its own birth."""
        self.log.debug('Registering birth of worker %s' % (self.name, ))
        if self.connection.exists(self.key) and \
                not self.connection.hexists(self.key, 'death'):
            raise ValueError('There exists an active worker named \'%s\' '
                             'already.' % (self.name, ))
        key = self.key
        now = time.time()
        queues = ','.join(self.queue_names())
        with self.connection.pipeline() as p:
            p.delete(key)
            p.hset(key, 'birth', now)
            p.hset(key, 'queues', queues)
            p.sadd(self.redis_workers_keys, key)
            p.execute()

    def register_death(self):
        """Registers its own death."""
        self.log.debug('Registering death')
        with self.connection.pipeline() as p:
            # We cannot use self.state = 'dead' here, because that would
            # rollback the pipeline
            p.srem(self.redis_workers_keys, self.key)
            p.hset(self.key, 'death', time.time())
            p.expire(self.key, 60)
            p.execute()

    def set_state(self, new_state):
        self._state = new_state
        self.connection.hset(self.key, 'state', new_state)

    def get_state(self):
        return self._state

    state = property(get_state, set_state)

    @property
    def stopped(self):
        return self._stopped

    def _install_signal_handlers(self):
        """Installs signal handlers for handling SIGINT and SIGTERM
        gracefully.
        """
        def request_force_stop(signum, frame):
            """Terminates the application (cold shutdown).
            """
            self.log.warning('Cold shut down.')

            # Take down the horse with the worker
            if self.horse_pid:
                msg = 'Taking down horse %d with me.' % self.horse_pid
                self.log.debug(msg)
                try:
                    os.kill(self.horse_pid, signal.SIGKILL)
                except OSError as e:
                    # ESRCH ("No such process") is fine with us
                    if e.errno != errno.ESRCH:
                        self.log.debug('Horse already down.')
                        raise
            raise SystemExit()

        def request_stop(signum, frame):
            """Stops the current worker loop but waits for child processes to
            end gracefully (warm shutdown).
            """
            self.log.debug('Got %s signal.' % signal_name(signum))

            signal.signal(signal.SIGINT, request_force_stop)
            signal.signal(signal.SIGTERM, request_force_stop)

            if self.is_horse:
                self.log.debug('Ignoring signal %s.' % signal_name(signum))
                return

            msg = 'Warm shut down. Press Ctrl+C again for a cold shutdown.'
            self.log.warning(msg)
            self._stopped = True
            self.log.debug('Stopping after current horse is finished.')

        signal.signal(signal.SIGINT, request_stop)
        signal.signal(signal.SIGTERM, request_stop)

    def work(self, burst=False):  # noqa
        """Starts the work loop.

        Pops and performs all jobs on the current list of queues.  When all
        queues are empty, block and wait for new jobs to arrive on any of the
        queues, unless `burst` mode is enabled.

        The return value indicates whether any jobs were processed.
        """
        self._install_signal_handlers()

        did_perform_work = False
        self.register_birth()
        self.state = 'starting'
        try:
            while True:
                if self.stopped:
                    self.log.info('Stopping on request.')
                    break
                self.state = 'idle'
                qnames = self.queue_names()
                self.procline('Listening on %s' % ','.join(qnames))
                self.log.info('')
                self.log.info('*** Listening on %s...' % \
                        green(', '.join(qnames)))
                wait_for_job = not burst
                try:
                    result = Queue.dequeue_any(self.queues, wait_for_job, \
                            connection=self.connection)
                    if result is None:
                        break
                except UnpickleError as e:
                    msg = '*** Ignoring unpickleable data on %s.' % \
                            green(e.queue.name)
                    self.log.warning(msg)
                    self.log.debug('Data follows:')
                    self.log.debug(e.raw_data)
                    self.log.debug('End of unreadable data.')
                    self.failed_queue.push_job_id(e.job_id)
                    continue

                job, queue = result
                self.log.info(
                    '%s: %s (%s)' %
                    (green(queue.name), blue(job.description), job.id))

                self.state = 'busy'
                self.fork_and_perform_job(job)

                did_perform_work = True
        finally:
            if not self.is_horse:
                self.register_death()
        return did_perform_work

    def fork_and_perform_job(self, job):
        """Spawns a work horse to perform the actual work and passes it a job.
        The worker will wait for the work horse and make sure it executes
        within the given timeout bounds, or will end the work horse with
        SIGALRM.
        """
        child_pid = os.fork()
        if child_pid == 0:
            self.main_work_horse(job)
        else:
            self._horse_pid = child_pid
            self.procline('Forked %d at %d' % (child_pid, time.time()))
            while True:
                try:
                    os.waitpid(child_pid, 0)
                    break
                except OSError as e:
                    # In case we encountered an OSError due to EINTR (which is
                    # caused by a SIGINT or SIGTERM signal during
                    # os.waitpid()), we simply ignore it and enter the next
                    # iteration of the loop, waiting for the child to end.  In
                    # any other case, this is some other unexpected OS error,
                    # which we don't want to catch, so we re-raise those ones.
                    if e.errno != errno.EINTR:
                        raise

    def main_work_horse(self, job):
        """This is the entry point of the newly spawned work horse."""
        # After fork()'ing, always assure we are generating random sequences
        # that are different from the worker.
        random.seed()
        self._is_horse = True
        self.log = Logger('horse')

        success = self.perform_job(job)

        # os._exit() is the way to exit from children after a fork(), in
        # contrast to the regular sys.exit()
        os._exit(int(not success))

    def perform_job(self, job):
        """Performs the actual work of a job.  Will/should only be called
        inside the work horse's process.
        """
        self.procline('Processing %s from %s since %s' %
                      (job.func_name, job.origin, time.time()))

        try:
            with death_pentalty_after(job.timeout or 180):
                rv = job.perform()
        except Exception as e:
            fq = self.failed_queue
            self.log.exception(red(str(e)))
            self.log.warning('Moving job to %s queue.' % fq.name)

            fq.quarantine(job, exc_info=traceback.format_exc())
            return False

        if rv is None:
            self.log.info('Job OK')
        else:
            self.log.info('Job OK, result = %s' % (yellow(unicode(rv)), ))

        if rv is not None:
            p = self.connection.pipeline()
            p.hset(job.key, 'result', dumps(rv))
            p.expire(job.key, self.rv_ttl)
            p.execute()
        else:
            # Cleanup immediately
            job.delete()

        return True
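A hedged usage sketch for the Worker above, assuming the old rq-style helpers this module imports (Queue, get_failed_queue) plus a connection helper such as use_connection to make a Redis connection current:

from redis import Redis

use_connection(Redis())      # assumed helper that sets the current connection
q = Queue('default')
worker = Worker([q])
worker.work(burst=True)      # drain the queues once, then return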
Example #11
class Genome:
    def __init__(self, genome, assembly_summary=None):
        """
        :param genome: Path to genome
        :returns: Path to genome and name of the genome
        :rtype:
        """
        self.path = os.path.abspath(genome)
        self.species_dir, self.fasta = os.path.split(self.path)
        self.name = os.path.splitext(self.fasta)[0]
        self.log = Logger(self.name)
        self.qc_dir = os.path.join(self.species_dir, "qc")
        self.stats_file = os.path.join(self.qc_dir, self.name + ".csv")
        self.sketch_file = os.path.join(self.qc_dir, self.name + ".msh")
        self.assembly_summary = assembly_summary
        self.metadata = defaultdict(lambda: "missing")
        self.xml = defaultdict(lambda: "missing")
        try:
            self.accession_id = re.search("GCA_[0-9]*.[0-9]",
                                          self.name).group()
            self.metadata["accession"] = self.accession_id
        except AttributeError:
            self.accession_id = "missing"
            self.log.exception("Invalid accession ID")
        # Don't do this here
        if isinstance(self.assembly_summary, pd.DataFrame):
            try:
                biosample = assembly_summary.loc[self.accession_id].biosample
                self.metadata["biosample_id"] = biosample
            except (AttributeError, KeyError):
                self.log.exception("Unable to get biosample ID")

    @staticmethod
    def id_(name):
        return re.search("GCA_[0-9]*.[0-9]", name).group()

    def get_contigs(self):
        """
        Return a list of Bio.Seq.Seq objects for the fasta file and calculate
        the total number of contigs.
        """
        try:
            self.contigs = [seq.seq for seq in SeqIO.parse(self.path, "fasta")]
            self.count_contigs = len(self.contigs)
        except UnicodeDecodeError:
            self.log.exception()

    def get_assembly_size(self):
        """Calculate the sum of all contig lengths"""
        # TODO: map or reduce might be more elegant here
        self.assembly_size = sum((len(str(seq)) for seq in self.contigs))

    def get_unknowns(self):
        """Count the number of unknown bases, i.e. not [ATCG]"""
        # TODO: Would it be useful to allow the user to define p?
        p = re.compile("[^ATCG]")
        self.unknowns = sum(
            (len(re.findall(p, str(seq))) for seq in self.contigs))

    def get_distance(self, dmx_mean):
        self.distance = dmx_mean.loc[self.name]

    def sketch(self):
        cmd = "mash sketch '{}' -o '{}'".format(self.path, self.sketch_file)
        if os.path.isfile(self.sketch_file):
            pass
        else:
            subprocess.Popen(cmd, shell=True,
                             stderr=subprocess.DEVNULL).wait()

    def get_stats(self, dmx_mean):
        if not os.path.isfile(self.stats_file):
            self.get_contigs()
            self.get_assembly_size()
            self.get_unknowns()
            self.get_distance(dmx_mean)
            data = {
                "contigs": self.count_contigs,
                "assembly_size": self.assembly_size,
                "unknowns": self.unknowns,
                "distance": self.distance,
            }
            self.stats = pd.DataFrame(data, index=[self.name])
            self.stats.to_csv(self.stats_file)

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    def efetch(self, db):
        """
        Use NCBI's efetch tools to get xml for genome's biosample id or SRA id
        """
        if db == "biosample":
            db_id = db + "_id"
        elif db == "sra":
            db_id = db + "_id"
        cmd = "esearch -db {} -query {} | " "efetch -format docsum".format(
            db, self.metadata[db_id])
        # Make efetch timeout and retry after 30 seconds
        time_limit = 30
        if self.metadata[db_id] != "missing":
            try:
                p = subprocess.run(
                    cmd,
                    shell="True",
                    stdout=subprocess.PIPE,
                    stderr=subprocess.DEVNULL,
                    timeout=time_limit,
                )
                xml = p.stdout
                self.xml[db] = xml
            except subprocess.TimeoutExpired:
                self.log.error("Retrying efetch after timeout")
                raise subprocess.TimeoutExpired(cmd, time_limit)
            except Exception:
                self.log.exception(db)

    def parse_biosample(self):
        """
        Get what we need out of the xml returned by efetch("biosample"), including
        the SRA ID and fields of interest as defined in Metadata.biosample_fields.
        """
        attributes = [
            "BioSample",
            "geo_loc_name",
            "collection_date",
            "strain",
            "isolation_source",
            "host",
            "collected_by",
            "sample_type",
            "sample_name",
            "host_disease",
            "isolate",
            "host_health_state",
            "serovar",
            "env_biome",
            "env_feature",
            "ref_biomaterial",
            "env_material",
            "isol_growth_condt",
            "num_replicons",
            "sub_species",
            "host_age",
            "genotype",
            "host_sex",
            "serotype",
            "host_disease_outcome",
        ]
        try:
            tree = ET.fromstring(self.xml["biosample"])
            sra = tree.find(
                'DocumentSummary/SampleData/BioSample/Ids/Id/[@db="SRA"]')
            try:
                self.metadata["sra_id"] = sra.text
            except AttributeError:
                self.metadata["sra_id"] = "missing"
            for name in attributes:
                xp = ("DocumentSummary/SampleData/BioSample/Attributes/"
                      'Attribute/[@harmonized_name="{}"]'.format(name))
                attrib = tree.find(xp)
                try:
                    self.metadata[name] = attrib.text
                except AttributeError:
                    self.metadata[name] = "missing"
        except ParseError:
            self.log.exception()

    def parse_sra(self):
        try:
            tree = ET.fromstring(self.xml["sra"])
            elements = tree.iterfind("DocumentSummary/Runs/Run/[@acc]")
            srr_accessions = []
            for el in elements:
                items = el.items()
                acc = [i[1] for i in items if i[0] == "acc"]
                acc = acc[0]
                srr_accessions.append(acc)
            self.metadata["srr_accessions"] = ",".join(srr_accessions)
        except ParseError:
            self.log.exception("Parse error for SRA XML")

    def get_metadata(self):
        self.efetch("biosample")
        self.parse_biosample()
        self.efetch("sra")
        self.parse_sra()
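A hypothetical driver for the metadata path above; it assumes NCBI's Entrez Direct tools (esearch/efetch) are on PATH, and an assembly_summary DataFrame indexed by accession with a 'biosample' column, as the constructor expects:

import pandas as pd

summary = pd.DataFrame(
    {"biosample": ["SAMN00000000"]},          # hypothetical biosample ID
    index=["GCA_000005845.2"],
)
g = Genome("genomes/GCA_000005845.2_ASM584v2.fasta", assembly_summary=summary)
g.get_metadata()                              # efetch biosample + sra, then parse both
print(dict(g.metadata))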
Example #12
File: worker.py Project: jgelens/rq
class Worker(object):
    redis_worker_namespace_prefix = 'rq:worker:'
    redis_workers_keys = 'rq:workers'

    @classmethod
    def all(cls, connection=None):
        """Returns an iterable of all Workers.
        """
        if connection is None:
            connection = get_current_connection()
        reported_working = connection.smembers(cls.redis_workers_keys)
        workers = [cls.find_by_key(key, connection) for key in
                reported_working]
        return compact(workers)

    @classmethod
    def find_by_key(cls, worker_key, connection=None):
        """Returns a Worker instance, based on the naming conventions for
        naming the internal Redis keys.  Can be used to reverse-lookup Workers
        by their Redis keys.
        """
        prefix = cls.redis_worker_namespace_prefix
        name = worker_key[len(prefix):]
        if not worker_key.startswith(prefix):
            raise ValueError('Not a valid RQ worker key: %s' % (worker_key,))

        if connection is None:
            connection = get_current_connection()
        if not connection.exists(worker_key):
            return None

        name = worker_key[len(prefix):]
        worker = cls([], name)
        queues = connection.hget(worker.key, 'queues')
        worker._state = connection.hget(worker.key, 'state') or '?'
        if queues:
            worker.queues = map(Queue, queues.split(','))
        return worker


    def __init__(self, queues, name=None, default_result_ttl=500,
            connection=None):  # noqa
        if connection is None:
            connection = get_current_connection()
        self.connection = connection
        if isinstance(queues, Queue):
            queues = [queues]
        self._name = name
        self.queues = queues
        self.validate_queues()
        self.default_result_ttl = default_result_ttl
        self._state = 'starting'
        self._is_horse = False
        self._horse_pid = 0
        self._stopped = False
        self.log = Logger('worker')
        self.failed_queue = get_failed_queue(connection=self.connection)


    def validate_queues(self):  # noqa
        """Sanity check for the given queues."""
        if not iterable(self.queues):
            raise ValueError('Argument queues not iterable.')
        for queue in self.queues:
            if not isinstance(queue, Queue):
                raise NoQueueError('Give each worker at least one Queue.')

    def queue_names(self):
        """Returns the queue names of this worker's queues."""
        return map(lambda q: q.name, self.queues)

    def queue_keys(self):
        """Returns the Redis keys representing this worker's queues."""
        return map(lambda q: q.key, self.queues)


    @property  # noqa
    def name(self):
        """Returns the name of the worker, under which it is registered to the
        monitoring system.

        By default, the name of the worker is constructed from the current
        (short) host name and the current PID.
        """
        if self._name is None:
            hostname = socket.gethostname()
            shortname, _, _ = hostname.partition('.')
            self._name = '%s.%s' % (shortname, self.pid)
        return self._name

    @property
    def key(self):
        """Returns the worker's Redis hash key."""
        return self.redis_worker_namespace_prefix + self.name

    @property
    def pid(self):
        """The current process ID."""
        return os.getpid()

    @property
    def horse_pid(self):
        """The horse's process ID.  Only available in the worker.  Will return
        0 in the horse part of the fork.
        """
        return self._horse_pid

    @property
    def is_horse(self):
        """Returns whether or not this is the worker or the work horse."""
        return self._is_horse

    def procline(self, message):
        """Changes the current procname for the process.

        This can be used to make `ps -ef` output more readable.
        """
        setprocname('rq: %s' % (message,))


    def register_birth(self):  # noqa
        """Registers its own birth."""
        self.log.debug('Registering birth of worker %s' % (self.name,))
        if self.connection.exists(self.key) and \
                not self.connection.hexists(self.key, 'death'):
            raise ValueError(
                    'There exists an active worker named \'%s\' '
                    'already.' % (self.name,))
        key = self.key
        now = time.time()
        queues = ','.join(self.queue_names())
        with self.connection.pipeline() as p:
            p.delete(key)
            p.hset(key, 'birth', now)
            p.hset(key, 'queues', queues)
            p.sadd(self.redis_workers_keys, key)
            p.execute()

    def register_death(self):
        """Registers its own death."""
        self.log.debug('Registering death')
        with self.connection.pipeline() as p:
            # We cannot use self.state = 'dead' here, because that would
            # rollback the pipeline
            p.srem(self.redis_workers_keys, self.key)
            p.hset(self.key, 'death', time.time())
            p.expire(self.key, 60)
            p.execute()

    def set_state(self, new_state):
        self._state = new_state
        self.connection.hset(self.key, 'state', new_state)

    def get_state(self):
        return self._state

    state = property(get_state, set_state)

    @property
    def stopped(self):
        return self._stopped

    def _install_signal_handlers(self):
        """Installs signal handlers for handling SIGINT and SIGTERM
        gracefully.
        """

        def request_force_stop(signum, frame):
            """Terminates the application (cold shutdown).
            """
            self.log.warning('Cold shut down.')

            # Take down the horse with the worker
            if self.horse_pid:
                msg = 'Taking down horse %d with me.' % self.horse_pid
                self.log.debug(msg)
                try:
                    os.kill(self.horse_pid, signal.SIGKILL)
                except OSError as e:
                    # ESRCH ("No such process") is fine with us
                    if e.errno != errno.ESRCH:
                        self.log.debug('Horse already down.')
                        raise
            raise SystemExit()

        def request_stop(signum, frame):
            """Stops the current worker loop but waits for child processes to
            end gracefully (warm shutdown).
            """
            self.log.debug('Got signal %s.' % signal_name(signum))

            signal.signal(signal.SIGINT, request_force_stop)
            signal.signal(signal.SIGTERM, request_force_stop)

            msg = 'Warm shut down requested.'
            self.log.warning(msg)

            # If shutdown is requested in the middle of a job, wait until
            # finish before shutting down
            if self.state == 'busy':
                self._stopped = True
                self.log.debug('Stopping after current horse is finished. '
                               'Press Ctrl+C again for a cold shutdown.')
            else:
                raise StopRequested()

        signal.signal(signal.SIGINT, request_stop)
        signal.signal(signal.SIGTERM, request_stop)


    def work(self, burst=False):  # noqa
        """Starts the work loop.

        Pops and performs all jobs on the current list of queues.  When all
        queues are empty, block and wait for new jobs to arrive on any of the
        queues, unless `burst` mode is enabled.

        The return value indicates whether any jobs were processed.
        """
        self._install_signal_handlers()

        did_perform_work = False
        self.register_birth()
        self.log.info('RQ worker started, version %s' % VERSION)
        self.state = 'starting'
        try:
            while True:
                if self.stopped:
                    self.log.info('Stopping on request.')
                    break
                self.state = 'idle'
                qnames = self.queue_names()
                self.procline('Listening on %s' % ','.join(qnames))
                self.log.info('')
                self.log.info('*** Listening on %s...' % \
                        green(', '.join(qnames)))
                wait_for_job = not burst
                try:
                    result = Queue.dequeue_any(self.queues, wait_for_job, \
                            connection=self.connection)
                    if result is None:
                        break
                except StopRequested:
                    break
                except UnpickleError as e:
                    msg = '*** Ignoring unpickleable data on %s.' % \
                            green(e.queue.name)
                    self.log.warning(msg)
                    self.log.debug('Data follows:')
                    self.log.debug(e.raw_data)
                    self.log.debug('End of unreadable data.')
                    self.failed_queue.push_job_id(e.job_id)
                    continue

                self.state = 'busy'

                job, queue = result
                self.log.info('%s: %s (%s)' % (green(queue.name),
                    blue(job.description), job.id))

                self.fork_and_perform_job(job)

                did_perform_work = True
        finally:
            if not self.is_horse:
                self.register_death()
        return did_perform_work

    def fork_and_perform_job(self, job):
        """Spawns a work horse to perform the actual work and passes it a job.
        The worker will wait for the work horse and make sure it executes
        within the given timeout bounds, or will end the work horse with
        SIGALRM.
        """
        child_pid = os.fork()
        if child_pid == 0:
            self.main_work_horse(job)
        else:
            self._horse_pid = child_pid
            self.procline('Forked %d at %d' % (child_pid, time.time()))
            while True:
                try:
                    os.waitpid(child_pid, 0)
                    break
                except OSError as e:
                    # In case we encountered an OSError due to EINTR (which is
                    # caused by a SIGINT or SIGTERM signal during
                    # os.waitpid()), we simply ignore it and enter the next
                    # iteration of the loop, waiting for the child to end.  In
                    # any other case, this is some other unexpected OS error,
                    # which we don't want to catch, so we re-raise those ones.
                    if e.errno != errno.EINTR:
                        raise

    def main_work_horse(self, job):
        """This is the entry point of the newly spawned work horse."""
        # After fork()'ing, always assure we are generating random sequences
        # that are different from the worker.
        random.seed()

        # Always ignore Ctrl+C in the work horse, as it might abort the
        # currently running job.
        # The main worker catches the Ctrl+C and requests graceful shutdown
        # after the current work is done.  When cold shutdown is requested, it
        # kills the current job anyway.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        signal.signal(signal.SIGTERM, signal.SIG_DFL)

        self._is_horse = True
        self.log = Logger('horse')

        success = self.perform_job(job)

        # os._exit() is the way to exit from childs after a fork(), in
        # constrast to the regular sys.exit()
        os._exit(int(not success))

    def perform_job(self, job):
        """Performs the actual work of a job.  Will/should only be called
        inside the work horse's process.
        """
        self.procline('Processing %s from %s since %s' % (
            job.func_name,
            job.origin, time.time()))

        try:
            with death_penalty_after(job.timeout or 180):
                rv = job.perform()

            # Pickle the result in the same try-except block since we need to
            # use the same exc handling when pickling fails
            pickled_rv = dumps(rv)
        except Exception as e:
            fq = self.failed_queue
            self.log.exception(red(str(e)))
            self.log.warning('Moving job to %s queue.' % fq.name)

            fq.quarantine(job, exc_info=traceback.format_exc())
            return False

        if rv is None:
            self.log.info('Job OK')
        else:
            self.log.info('Job OK, result = %s' % (yellow(unicode(rv)),))

        # Expire results
        has_result = rv is not None
        explicit_ttl_requested = job.result_ttl is not None
        should_expire = has_result or explicit_ttl_requested
        if should_expire:
            p = self.connection.pipeline()
            p.hset(job.key, 'result', pickled_rv)

            if explicit_ttl_requested:
                ttl = job.result_ttl
            else:
                ttl = self.default_result_ttl
            if ttl >= 0:
                p.expire(job.key, ttl)
            p.execute()
        else:
            # Cleanup immediately
            job.delete()

        return True
Example #13
class BTgymBaseData:
    """
    Base BTgym data provider class.
    Provides core data loading, sampling, splitting and converting functionality.
    Do not use directly.

    Enables Pipe::

        CSV[source data]-->pandas[for efficient sampling]-->bt.feeds

    """

    def __init__(
            self,
            filename=None,
            parsing_params=None,
            sampling_params=None,
            name='base_data',
            data_names=('default_asset',),
            task=0,
            frozen_time_split=None,
            log_level=WARNING,
            _config_stack=None,
            **kwargs
    ):
        """
        Args:

            filename:                       Str or list of str, should be given either here or when calling read_csv(),
                                            see `Notes`.

            specific_params CSV to Pandas parsing

            sep:                            ';'
            header:                         0
            index_col:                      0
            parse_dates:                    True
            names:                          ['open', 'high', 'low', 'close', 'volume']

            specific_params Pandas to BT.feeds conversion

            timeframe=1:                    1 minute.
            datetime:                       0
            open:                           1
            high:                           2
            low:                            3
            close:                          4
            volume:                         -1
            openinterest:                   -1

            specific_params Sampling

            sample_class_ref:               None - if not None, then the sample() method will return an instance of the
                                            specified class, which itself must be a subclass of BaseBTgymDataset;
                                            else it returns an instance of the base data class.

            start_weekdays:                 [0, 1, 2, 3, ] - Only weekdays from the list will be used for sample start.
            start_00:                       True - sample start time will be set to first record of the day
                                            (usually 00:00).
            sample_duration:                {'days': 1, 'hours': 23, 'minutes': 55} - Maximum sample time duration
                                            in days, hours, minutes
            time_gap:                       {'days': 0, 'hours': 5, 'minutes': 0} - Data omission threshold:
                                            maximum no-data time gap allowed within a sample, in days and hours.
                                            Thus, if set to < 1 day, samples containing weekend and holiday gaps
                                            will be rejected.
            test_period:                    {'days': 0, 'hours': 0, 'minutes': 0} - setting this param to non-zero
                                            duration forces instance.data split to train / test subsets with test
                                            subset duration equal to `test_period` with `time_gap` tolerance. Train data
                                            always precedes test one:
                                            [0_record<-train_data->split_point_record<-test_data->last_record].
            sample_expanding:               None, reserved for child classes.

        Note:
            - CSV file can contain duplicate records, checks will be performed and all duplicates will be removed;

            - CSV file should be properly sorted by date_time in ascending order, no sorting checks performed.

            - When supplying a list of file names, the files should also be listed in ascending order by their
              time period; correct sampling will not be possible otherwise.

            - Default parameters are source-specific and made to correctly parse 1 minute Forex generic ASCII
              data files from www.HistData.com. Tune according to your data source.
        """
        self.filename = filename

        if parsing_params is None:
            self.parsing_params = dict(
                # Default parameters for source-specific CSV datafeed class,
                # correctly parses 1 minute Forex generic ASCII
                # data files from www.HistData.com:

                # CSV to Pandas params.
                sep=';',
                header=0,
                index_col=0,
                parse_dates=True,
                names=['open', 'high', 'low', 'close', 'volume'],

                # Pandas to BT.feeds params:
                timeframe=1,  # 1 minute.
                datetime=0,
                open=1,
                high=2,
                low=3,
                close=4,
                volume=-1,
                openinterest=-1,
            )
        else:
            self.parsing_params = parsing_params

        if sampling_params is None:
            self.sampling_params = dict(
                # Sampling params:
                start_weekdays=[],  # Only weekdays from the list will be used for episode start.
                start_00=False,  # Sample start time will be set to first record of the day (usually 00:00).
                sample_duration=dict(  # Maximum sample time duration in days, hours, minutes:
                    days=0,
                    hours=0,
                    minutes=0
                ),
                time_gap=dict(  # Maximum data time gap allowed within sample in days, hours. Thereby,
                    days=0,  # if set to be < 1 day, samples containing weekends and holidays gaps will be rejected.
                    hours=0,
                ),
                test_period=dict(  # Time period to take test samples from, in days, hours, minutes:
                    days=0,
                    hours=0,
                    minutes=0
                ),
                expanding=False,
            )
        else:
            self.sampling_params = sampling_params

        self.name = name
        # String will be used as key name for bt_feed data-line:

        self.task = task
        self.log_level = log_level
        self.data_names = data_names
        self.data_name = self.data_names[0]

        self.data = None  # Will hold actual data as pandas dataframe
        self.is_ready = False

        self.global_timestamp = 0
        self.start_timestamp = 0
        self.final_timestamp = 0

        self.data_stat = None  # Dataset descriptive statistic as pandas dataframe
        self.data_range_delta = None  # Dataset total duration timedelta
        self.max_time_gap = None
        self.time_gap = None
        self.max_sample_len_delta = None
        self.sample_duration = None
        self.sample_num_records = 0
        self.start_weekdays = {0, 1, 2, 3, 4, 5, 6}
        self.start_00 = False
        self.expanding = False

        self.sample_instance = None

        self.test_range_delta = None
        self.train_range_delta = None
        self.test_num_records = 0
        self.train_num_records = 0
        self.total_num_records = 0
        self.train_interval = [0, 0]
        self.test_interval = [0, 0]
        self.test_period = {'days': 0, 'hours': 0, 'minutes': 0}
        self.train_period = {'days': 0, 'hours': 0, 'minutes': 0}
        self._test_period_backshift_delta = datetime.timedelta(**{'days': 0, 'hours': 0, 'minutes': 0})
        self.sample_num = 0
        self.task = 0
        self.metadata = {'sample_num': 0, 'type': None}

        self.set_params(self.parsing_params)
        self.set_params(self.sampling_params)

        self._config_stack = copy.deepcopy(_config_stack)
        try:
            nested_config = self._config_stack.pop()

        except (IndexError, AttributeError) as e:
            # If the stack is empty, this instance's samples are not themselves supposed to be sampled.
            nested_config = dict(
                class_ref=None,
                kwargs=dict(
                    parsing_params=self.parsing_params,
                    sample_params=None,
                    name='data_stream',
                    task=self.task,
                    log_level=self.log_level,
                    _config_stack=None,
                )
            )
        # Configure sample instance parameters:
        self.nested_class_ref = nested_config['class_ref']
        self.nested_params = nested_config['kwargs']
        self.sample_name = '{}_w_{}_'.format(self.nested_params['name'], self.task)
        self.nested_params['_config_stack'] = self._config_stack

        # Logging:
        StreamHandler(sys.stdout).push_application()
        self.log = Logger('{}_{}'.format(self.name, self.task), level=self.log_level)

        # Legacy parameter dictionary, left here for BTgym API_shell:
        self.params = {}
        self.params.update(self.parsing_params)
        self.params.update(self.sampling_params)

        if frozen_time_split is not None:
            self.frozen_time_split = datetime.datetime(**frozen_time_split)

        else:
            self.frozen_time_split = None

        self.frozen_split_timestamp = None

    def set_params(self, params_dict):
        """
        Batch attribute setter.

        Args:
            params_dict: dictionary of parameters to be set as instance attributes.
        """
        for key, value in params_dict.items():
            setattr(self, key, value)

    def set_logger(self, level=None, task=None):
        """
        Sets logbook logger.

        Args:
            level:  logbook.level, int
            task:   task id, int

        """
        if task is not None:
            self.task = task

        if level is not None:
            self.log = Logger('{}_{}'.format(self.name, self.task), level=level)

    def set_global_timestamp(self, timestamp):
        if self.data is not None:
            self.global_timestamp = self.data.index[0].timestamp()

    def reset(self, data_filename=None, **kwargs):
        """
        Gets instance ready.

        Args:
            data_filename:  [opt] string or list of strings.
            kwargs:         not used.

        """
        self._reset(data_filename=data_filename, **kwargs)

    def _reset(self, data_filename=None, timestamp=None, **kwargs):

        self.read_csv(data_filename)

        # Add global timepoints:
        self.start_timestamp = self.data.index[0].timestamp()
        self.final_timestamp = self.data.index[-1].timestamp()

        if self.frozen_time_split is not None:
            frozen_index = self.data.index.get_loc(self.frozen_time_split, method='ffill')
            self.frozen_split_timestamp = self.data.index[frozen_index].timestamp()
            self.set_global_timestamp(self.frozen_split_timestamp)

        else:
            self.frozen_split_timestamp = None
            self.set_global_timestamp(timestamp)

        self.log.debug(
            'time stamps start: {}, current: {} final: {}'.format(
                self.start_timestamp,
                self.global_timestamp,
                self.final_timestamp
            )
        )

        # Maximum data time gap allowed within sample as pydatetimedelta obj:
        self.max_time_gap = datetime.timedelta(**self.time_gap)

        # Max. gap number of records:
        self.max_gap_num_records = int(self.max_time_gap.total_seconds() / (60 * self.timeframe))

        # ... maximum episode time duration:
        self.max_sample_len_delta = datetime.timedelta(**self.sample_duration)

        # Maximum possible number of data records (rows) within episode:
        self.sample_num_records = int(self.max_sample_len_delta.total_seconds() / (60 * self.timeframe))

        self.backshift_num_records = round(self._test_period_backshift_delta.total_seconds() / (60 * self.timeframe))

        # Train/test timedeltas:
        if self.train_period is None or self.test_period == -1:
            # No train data assumed, test only:
            self.train_num_records = 0
            self.test_num_records = self.data.shape[0] - self.backshift_num_records
            break_point = self.backshift_num_records
            self.train_interval = [0, 0]
            self.test_interval = [self.backshift_num_records, self.data.shape[0]]

        else:
            # Train and maybe test data assumed:
            if self.test_period is not None:
                self.test_range_delta = datetime.timedelta(**self.test_period)
                self.test_num_records = round(self.test_range_delta.total_seconds() / (60 * self.timeframe))
                self.train_num_records = self.data.shape[0] - self.test_num_records
                break_point = self.train_num_records
                self.train_interval = [0, break_point]
                self.test_interval = [break_point - self.backshift_num_records, self.data.shape[0]]
            else:
                self.test_num_records = 0
                self.train_num_records = self.data.shape[0]
                break_point = self.train_num_records
                self.train_interval = [0, break_point]
                self.test_interval = [0, 0]

        if self.train_num_records > 0:
            try:
                assert self.train_num_records + self.max_gap_num_records >= self.sample_num_records

            except AssertionError:
                self.log.exception(
                    'Train subset should contain at least one sample, ' +
                    'got: train_set size: {} rows, sample_size: {} rows, tolerance: {} rows'.
                    format(self.train_num_records, self.sample_num_records, self.max_gap_num_records)
                )
                raise AssertionError

        if self.test_num_records > 0:
            try:
                assert self.test_num_records + self.max_gap_num_records >= self.sample_num_records

            except AssertionError:
                self.log.exception(
                    'Test subset should contain at least one sample, ' +
                    'got: test_set size: {} rows, sample_size: {} rows, tolerance: {} rows'.
                    format(self.test_num_records, self.sample_num_records, self.max_gap_num_records)
                )
                raise AssertionError

        self.sample_num = 0
        self.is_ready = True

    def read_csv(self, data_filename=None, force_reload=False):
        """
        Populates instance by loading data: CSV file --> pandas dataframe.

        Args:
            data_filename: [opt] csv data filename as string or list of such strings.
            force_reload:  ignore loaded data.
        """
        if self.data is not None and not force_reload:
            data_range = pd.to_datetime(self.data.index)
            self.total_num_records = self.data.shape[0]
            self.data_range_delta = (data_range[-1] - data_range[0]).to_pytimedelta()
            self.log.debug('data has been already loaded. Use `force_reload=True` to reload')
            return
        if data_filename:
            self.filename = data_filename  # override data source if one is given
        if isinstance(self.filename, str):
            self.filename = [self.filename]

        dataframes = []
        for filename in self.filename:
            try:
                assert filename and os.path.isfile(filename)
                current_dataframe = pd.read_csv(
                    filename,
                    sep=self.sep,
                    header=self.header,
                    index_col=self.index_col,
                    parse_dates=self.parse_dates,
                    names=self.names,
                )

                # Check and remove duplicate datetime indexes:
                duplicates = current_dataframe.index.duplicated(keep='first')
                how_bad = duplicates.sum()
                if how_bad > 0:
                    current_dataframe = current_dataframe[~duplicates]
                    self.log.warning(
                        'Found {} duplicated date_time records in <{}>. '
                        'Removed all but first occurrences.'.format(how_bad, filename)
                    )

                dataframes += [current_dataframe]
                self.log.info('Loaded {} records from <{}>.'.format(dataframes[-1].shape[0], filename))

            except Exception:
                msg = 'Data file <{}> not specified / not found / parser error.'.format(str(filename))
                self.log.error(msg)
                raise FileNotFoundError(msg)

        self.data = pd.concat(dataframes)
        data_range = pd.to_datetime(self.data.index)
        self.total_num_records = self.data.shape[0]
        self.data_range_delta = (data_range[-1] - data_range[0]).to_pytimedelta()

    def describe(self):
        """
        Returns summary dataset statistic as pandas dataframe:

            - records count,
            - data mean,
            - data std dev,
            - min value,
            - 25% percentile,
            - 50% percentile,
            - 75% percentile,
            - max value

        for every data column.
        """
        # Pretty straightforward, using standard pandas utility.
        # The only caveat: if the actual data has not been loaded yet, we need to load it, describe it
        # and unload it again, thus avoiding passing big files to the BT server:
        flush_data = False
        try:
            assert not self.data.empty

        except (AssertionError, AttributeError) as e:
            self.read_csv()
            flush_data = True

        self.data_stat = self.data.describe()
        self.log.info('Data summary:\n{}'.format(self.data_stat.to_string()))

        if flush_data:
            self.data = None
            self.log.info('Flushed data.')

        return self.data_stat

    def to_btfeed(self):
        """
        Performs BTgymData-->bt.feed conversion.

        Returns:
             dict of type: {data_line_name: bt.datafeed instance}.
        """
        def bt_timeframe(minutes):
            timeframe = TimeFrame.Minutes
            if minutes == 1440:
                timeframe = TimeFrame.Days
            return timeframe
        try:
            assert not self.data.empty
            btfeed = btfeeds.PandasDirectData(
                dataname=self.data,
                timeframe=bt_timeframe(self.timeframe),
                datetime=self.datetime,
                open=self.open,
                high=self.high,
                low=self.low,
                close=self.close,
                volume=self.volume,
                openinterest=self.openinterest
            )
            btfeed.numrecords = self.data.shape[0]
            return {self.data_name: btfeed}

        except (AssertionError, AttributeError) as e:
            msg = 'Instance holds no data. Hint: forgot to call .read_csv()?'
            self.log.error(msg)
            raise AssertionError(msg)

    def sample(self, **kwargs):
        return self._sample(**kwargs)

    def _sample(
            self,
            get_new=True,
            sample_type=0,
            b_alpha=1.0,
            b_beta=1.0,
            force_interval=False,
            interval=None,
            **kwargs
    ):
        """
        Samples continuous subset of data.

        Args:
            get_new (bool):                     sample new (True) or reuse (False) last made sample;
            sample_type (int or bool):          0 (train) or 1 (test) - get sample from train or test data subsets
                                                respectively.
            b_alpha (float):                    beta-distribution sampling alpha > 0, valid for train episodes.
            b_beta (float):                     beta-distribution sampling beta > 0, valid for train episodes.
            force_interval(bool):               use exact sampling interval (should be given)
            interval(iterable of int, len2):    exact interval to sample from when force_interval=True

        Returns:
        if no `sample_class_ref` param has been set:
            BTgymDataset instance with number of records ~ max_episode_len,
            where `~` tolerance is set by `time_gap` param;
        else:
            `sample_class_ref` instance with the same number of records as above.

        Note:
                Train sample start position within the interval is drawn from a beta-distribution
                with default parameters b_alpha=1, b_beta=1, i.e. a uniform one.
                The beta-distribution makes skewed sampling possible, e.g. giving recent episodes
                a higher probability of being sampled: b_alpha=10, b_beta=0.8.
                Test samples are always drawn uniformly.

        """
        try:
            assert self.is_ready

        except AssertionError:
            msg = 'sampling attempt: data not ready. Hint: forgot to call data.reset()?'
            self.log.error(msg)
            raise RuntimeError(msg)

        try:
            assert sample_type in [0, 1]

        except AssertionError:
            msg = 'sampling attempt: expected sample type be in {}, got: {}'.format([0, 1], sample_type)
            self.log.error(msg)
            raise ValueError(msg)

        if force_interval:
            try:
                assert interval is not None and len(list(interval)) == 2

            except AssertionError:
                msg = 'sampling attempt: got force_interval=True, expected interval=[a,b], got: <{}>'.format(interval)
                self.log.error(msg)
                raise ValueError(msg)

        if self.sample_instance is None or get_new:
            if sample_type == 0:
                # Get beta_distributed sample in train interval:
                if force_interval:
                    sample_interval = interval
                else:
                    sample_interval = self.train_interval

                self.sample_instance = self._sample_interval(
                    sample_interval,
                    force_interval=force_interval,
                    b_alpha=b_alpha,
                    b_beta=b_beta,
                    name='train_' + self.sample_name,
                    **kwargs
                )

            else:
                # Get uniform sample in test interval:
                if force_interval:
                    sample_interval = interval
                else:
                    sample_interval = self.test_interval

                self.sample_instance = self._sample_interval(
                    sample_interval,
                    force_interval=force_interval,
                    b_alpha=1,
                    b_beta=1,
                    name='test_' + self.sample_name,
                    **kwargs
                )
            self.sample_instance.metadata['type'] = sample_type  # TODO: can move inside sample()
            self.sample_instance.metadata['sample_num'] = self.sample_num
            self.sample_instance.metadata['parent_sample_num'] = copy.deepcopy(self.metadata['sample_num'])
            self.sample_instance.metadata['parent_sample_type'] = copy.deepcopy(self.metadata['type'])
            self.sample_num += 1

        else:
            # Do nothing:
            self.log.debug('Reusing sample, id: {}'.format(self.sample_instance.filename))

        return self.sample_instance

    def _sample_random(
            self,
            sample_type=0,
            timestamp=None,
            name='random_sample_',
            interval=None,
            force_interval=False,
            **kwargs
    ):
        """
        Randomly samples continuous subset of data.

        Args:
            name:        str, sample filename id

        Returns:
             BTgymDataset instance with number of records ~ max_episode_len,
             where `~` tolerance is set by `time_gap` param.
        """
        try:
            assert not self.data.empty

        except (AssertionError, AttributeError) as e:
            self.log.exception('Instance holds no data. Hint: forgot to call .read_csv()?')
            raise AssertionError

        if force_interval:
            raise NotImplementedError('Force_interval for random sampling not implemented.')

        self.log.debug('Maximum sample time duration set to: {}.'.format(self.max_sample_len_delta))
        self.log.debug('Respective number of steps: {}.'.format(self.sample_num_records))
        self.log.debug('Maximum allowed data time gap set to: {}.\n'.format(self.max_time_gap))

        sampled_data = None
        sample_len = 0

        # Sanity check param:
        max_attempts = 100
        attempts = 0

        # Keep sampling random entry points until all conditions are met:
        while attempts <= max_attempts:

            # Randomly sample record (row) from entire datafeed:
            first_row = int((self.data.shape[0] - self.sample_num_records - 1) * random.random())
            sample_first_day = self.data[first_row:first_row + 1].index[0]
            self.log.debug('Sample start: {}, weekday: {}.'.format(sample_first_day, sample_first_day.weekday()))

            # Keep sampling until good day:
            while sample_first_day.weekday() not in self.start_weekdays and attempts <= max_attempts:
                self.log.debug('Not a good day to start, resampling...')
                first_row = int((self.data.shape[0] - self.sample_num_records - 1) * random.random())
                sample_first_day = self.data[first_row:first_row + 1].index[0]
                self.log.debug('Sample start: {}, weekday: {}.'.format(sample_first_day, sample_first_day.weekday()))
                attempts += 1

            # Check if managed to get proper weekday:
            assert attempts <= max_attempts, \
                'Quitting after {} sampling attempts. Hint: check sampling params / dataset consistency.'. \
                format(attempts)

            # If 00 option set, get index of first record of that day:
            if self.start_00:
                adj_timedate = sample_first_day.date()
                self.log.debug('Start time adjusted to <00:00>')

            else:
                adj_timedate = sample_first_day

            first_row = self.data.index.get_loc(adj_timedate, method='nearest')

            # Easy part:
            last_row = first_row + self.sample_num_records  # + 1
            sampled_data = self.data[first_row: last_row]
            sample_len = (sampled_data.index[-1] - sampled_data.index[0]).to_pytimedelta()
            self.log.debug('Actual sample duration: {}.'.format(sample_len))
            self.log.debug('Total sample time gap: {}.'.format(self.max_sample_len_delta - sample_len))

            # Perform data gap check:
            if self.max_sample_len_delta - sample_len < self.max_time_gap:
                self.log.debug('Sample accepted.')
                # If sample OK - compose and return sample:
                new_instance = self.nested_class_ref(**self.nested_params)
                new_instance.filename = name + 'n{}_at_{}'.format(self.sample_num, adj_timedate)
                self.log.info('Sample id: <{}>.'.format(new_instance.filename))
                new_instance.data = sampled_data
                new_instance.metadata['type'] = 'random_sample'
                new_instance.metadata['first_row'] = first_row
                new_instance.metadata['last_row'] = last_row

                return new_instance

            else:
                self.log.debug('Duration too big, resampling...\n')
                attempts += 1

        # Got here -> sanity check failed:
        msg = (
            '\nQuitting after {} sampling attempts.\n' +
            'Full sample duration: {}\n' +
            'Total sample time gap: {}\n' +
            'Sample start time: {}\n' +
            'Sample finish time: {}\n' +
            'Hint: check sampling params / dataset consistency.'
        ).format(
            attempts,
            sample_len,
            sample_len - self.max_sample_len_delta,
            sampled_data.index[0],
            sampled_data.index[-1]

        )
        self.log.error(msg)
        raise RuntimeError(msg)

    def _sample_interval(
            self,
            interval,
            b_alpha=1.0,
            b_beta=1.0,
            name='interval_sample_',
            force_interval=False,
            **kwargs
    ):
        """
        Samples a continuous subset of data
        such that all episode records lie within the positions specified by `interval`.
        Episode start position within the interval is drawn from a beta-distribution parametrised by `b_alpha, b_beta`;
        by default the distribution is uniform.

        Args:
            interval:       tuple, list or 1d-array of integers of length 2: [lower_row_number, upper_row_number];
            b_alpha:        float > 0, sampling B-distribution alpha param, def=1;
            b_beta:         float > 0, sampling B-distribution beta param, def=1;
            name:           str, sample filename id
            force_interval: bool,  if true: force exact interval sampling


        Returns:
             - BTgymDataset instance such that:
                1. number of records ~ max_episode_len, subj. to `time_gap` param;
                2. actual episode start position is sampled from `interval`;
             - `False` if it is not possible to sample instance with set args.
        """
        try:
            assert not self.data.empty

        except (AssertionError, AttributeError) as e:
            self.log.exception('Instance holds no data. Hint: forgot to call .read_csv()?')
            raise AssertionError

        try:
            assert len(interval) == 2

        except AssertionError:
            self.log.exception(
                'Invalid interval arg: expected list or tuple of size 2, got: {}'.format(interval)
            )
            raise AssertionError

        if force_interval:
            return self._sample_exact_interval(interval, name)

        try:
            assert b_alpha > 0 and b_beta > 0

        except AssertionError:
            self.log.exception(
                'Expected positive B-distribution [alpha, beta] params, got: {}'.format([b_alpha, b_beta])
            )
            raise AssertionError

        if interval[-1] - interval[0] + self.max_gap_num_records > self.sample_num_records:
            sample_num_records = self.sample_num_records
        else:
            sample_num_records = interval[-1] - interval[0]
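        # I.e. if the interval (even with gap tolerance added) cannot hold a full-size sample,
        # shrink the sample to the interval length.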

        self.log.debug('Sample interval: {}'.format(interval))
        self.log.debug('Maximum sample time duration set to: {}.'.format(self.max_sample_len_delta))
        self.log.debug('Sample number of steps (adjusted to interval): {}.'.format(sample_num_records))
        self.log.debug('Maximum allowed data time gap set to: {}.\n'.format(self.max_time_gap))

        sampled_data = None
        sample_len = 0

        # Sanity check param:
        max_attempts = 100
        attempts = 0

        # Keep sampling random entry points until all conditions are met:
        while attempts <= max_attempts:

            first_row = interval[0] + int(
                (interval[-1] - interval[0] - sample_num_records) * random_beta(a=b_alpha, b=b_beta)
            )
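            # Note: random_beta() yields a value in [0, 1]; scaling it by the interval's free room
            # places the episode start, so e.g. b_alpha > 1 with b_beta < 1 skews start positions
            # towards the right (most recent) edge of the interval.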

            #print('_sample_interval_sample_num_records: ', sample_num_records)
            #print('_sample_interval_first_row: ', first_row)

            sample_first_day = self.data[first_row:first_row + 1].index[0]
            self.log.debug(
                'Sample start row: {}, day: {}, weekday: {}.'.
                format(first_row, sample_first_day, sample_first_day.weekday())
            )

            # Keep sampling until good day:
            while sample_first_day.weekday() not in self.start_weekdays and attempts <= max_attempts:
                self.log.debug('Not a good day to start, resampling...')
                first_row = interval[0] + round(
                    (interval[-1] - interval[0] - sample_num_records) * random_beta(a=b_alpha, b=b_beta)
                )
                #print('r_sample_interval_sample_num_records: ', sample_num_records)
                #print('r_sample_interval_first_row: ', first_row)
                sample_first_day = self.data[first_row:first_row + 1].index[0]
                self.log.debug(
                    'Sample start row: {}, day: {}, weekday: {}.'.
                    format(first_row, sample_first_day, sample_first_day.weekday())
                )
                attempts += 1

            # Check if managed to get proper weekday:
            try:
                assert attempts <= max_attempts

            except AssertionError:
                self.log.exception(
                    'Quitting after {} sampling attempts. Hint: check sampling params / dataset consistency.'.
                    format(attempts)
                )
                raise RuntimeError

            # If 00 option set, get index of first record of that day:
            if self.start_00:
                adj_timedate = sample_first_day.date()
                self.log.debug('Start time adjusted to <00:00>')
                first_row = self.data.index.get_loc(adj_timedate, method='nearest')

            else:
                adj_timedate = sample_first_day

            # first_row = self.data.index.get_loc(adj_timedate, method='nearest')

            # Easy part:
            last_row = first_row + sample_num_records  # + 1
            sampled_data = self.data[first_row: last_row]

            self.log.debug(
                'first_row: {}, last_row: {}, data_shape: {}'.format(
                    first_row,
                    last_row,
                    sampled_data.shape
                )
            )
            sample_len = (sampled_data.index[-1] - sampled_data.index[0]).to_pytimedelta()
            self.log.debug('Actual sample duration: {}.'.format(sample_len))
            self.log.debug('Total sample time gap: {}.'.format(self.max_sample_len_delta - sample_len))

            # Perform data gap check:
            if self.max_sample_len_delta - sample_len < self.max_time_gap:
                self.log.debug('Sample accepted.')
                # If sample OK - return new dataset:
                new_instance = self.nested_class_ref(**self.nested_params)
                new_instance.filename = name + 'num_{}_at_{}'.format(self.sample_num, adj_timedate)
                self.log.info('New sample id: <{}>.'.format(new_instance.filename))
                new_instance.data = sampled_data
                new_instance.metadata['type'] = 'interval_sample'
                new_instance.metadata['first_row'] = first_row
                new_instance.metadata['last_row'] = last_row

                return new_instance

            else:
                self.log.debug('Attempt {}: gap is too big, resampling, ...\n'.format(attempts))
                attempts += 1

        # Got here -> sanity check failed:
        msg = (
                '\nQuitting after {} sampling attempts.\n' +
                'Full sample duration: {}\n' +
                'Total sample time gap: {}\n' +
                'Sample start time: {}\n' +
                'Sample finish time: {}\n' +
                'Hint: check sampling params / dataset consistency.'
        ).format(
            attempts,
            sample_len,
            sample_len - self.max_sample_len_delta,
            sampled_data.index[0],
            sampled_data.index[-1]

        )
        self.log.error(msg)
        raise RuntimeError(msg)

    def _sample_aligned_interval(
            self,
            interval,
            align_left=False,
            b_alpha=1.0,
            b_beta=1.0,
            name='interval_sample_',
            force_interval=False,
            **kwargs
    ):
        """
        Samples a continuous subset of data
        such that all episode records lie within the positions specified by `interval`.
        Episode start position within the interval is drawn from a beta-distribution parametrised by `b_alpha, b_beta`;
        by default the distribution is uniform.

        Args:
            interval:       tuple, list or 1d-array of integers of length 2: [lower_row_number, upper_row_number];
            align_left:     if True, try to align the sample to the beginning of the interval;
            b_alpha:        float > 0, sampling B-distribution alpha param, def=1;
            b_beta:         float > 0, sampling B-distribution beta param, def=1;
            name:           str, sample filename id
            force_interval: bool,  if true: force exact interval sampling

        Returns:
             - BTgymDataset instance such that:
                1. number of records ~ max_episode_len, subj. to `time_gap` param;
                2. actual episode start position is sampled from `interval`;
             - `False` if it is not possible to sample instance with set args.
        """
        try:
            assert not self.data.empty

        except (AssertionError, AttributeError) as e:
            self.log.exception('Instance holds no data. Hint: forgot to call .read_csv()?')
            raise AssertionError

        try:
            assert len(interval) == 2

        except AssertionError:
            self.log.exception(
                'Invalid interval arg: expected list or tuple of size 2, got: {}'.format(interval)
            )
            raise AssertionError

        if force_interval:
            return self._sample_exact_interval(interval, name)

        try:
            assert b_alpha > 0 and b_beta > 0

        except AssertionError:
            self.log.exception(
                'Expected positive B-distribution [alpha, beta] params, got: {}'.format([b_alpha, b_beta])
            )
            raise AssertionError

        sample_num_records = self.sample_num_records

        self.log.debug('Maximum sample time duration set to: {}.'.format(self.max_sample_len_delta))
        self.log.debug('Respective number of steps: {}.'.format(sample_num_records))
        self.log.debug('Maximum allowed data time gap set to: {}.\n'.format(self.max_time_gap))

        # Sanity check param:
        if align_left:
            max_attempts = interval[-1] - interval[0]
        else:
            # Sanity check:
            max_attempts = 100

        attempts = 0
        align_shift = 0

        # Sample an entry point as close to the interval beginning as possible, until all conditions are met:
        while attempts <= max_attempts:
            if align_left:
                first_row = interval[0] + align_shift

            else:
                first_row = interval[0] + int(
                    (interval[-1] - interval[0] - sample_num_records) * random_beta(a=b_alpha, b=b_beta)
                )

            #print('_sample_interval_sample_num_records: ', sample_num_records)
            self.log.debug('_sample_interval_first_row: {}'.format(first_row))

            sample_first_day = self.data[first_row:first_row + 1].index[0]
            self.log.debug('Sample start: {}, weekday: {}.'.format(sample_first_day, sample_first_day.weekday()))

            # Keep sampling until good day:
            while sample_first_day.weekday() not in self.start_weekdays and attempts <= max_attempts:
                align_shift += 1

                self.log.debug('Not a good day to start, resampling...')

                if align_left:
                    first_row = interval[0] + align_shift
                else:

                    first_row = interval[0] + int(
                        (interval[-1] - interval[0] - sample_num_records) * random_beta(a=b_alpha, b=b_beta)
                    )
                #print('r_sample_interval_sample_num_records: ', sample_num_records)
                self.log.debug('_sample_interval_first_row: {}'.format(first_row))

                sample_first_day = self.data[first_row:first_row + 1].index[0]

                self.log.debug('Sample start: {}, weekday: {}.'.format(sample_first_day, sample_first_day.weekday()))

                attempts += 1

            # Check if managed to get proper weekday:
            try:
                assert attempts <= max_attempts

            except AssertionError:
                self.log.exception(
                    'Quitting after {} sampling attempts. Hint: check sampling params / dataset consistency.'.
                    format(attempts)
                )
                raise RuntimeError

            # If 00 option set, get index of first record of that day:
            if self.start_00:
                adj_timedate = sample_first_day.date()
                self.log.debug('Start time adjusted to <00:00>')
                first_row = self.data.index.get_loc(adj_timedate, method='nearest')

            else:
                adj_timedate = sample_first_day

            # first_row = self.data.index.get_loc(adj_timedate, method='nearest')

            # Easy part:
            last_row = first_row + sample_num_records  # + 1
            sampled_data = self.data[first_row: last_row]
            sample_len = (sampled_data.index[-1] - sampled_data.index[0]).to_pytimedelta()
            self.log.debug('Actual sample duration: {}.'.format(sample_len))
            self.log.debug('Total sample time gap: {}.'.format(sample_len - self.max_sample_len_delta))

            # Perform data gap check:
            if sample_len - self.max_sample_len_delta < self.max_time_gap:
                self.log.debug('Sample accepted.')
                # If sample OK - return new dataset:
                new_instance = self.nested_class_ref(**self.nested_params)
                new_instance.filename = name + 'num_{}_at_{}'.format(self.sample_num, adj_timedate)
                self.log.info('New sample id: <{}>.'.format(new_instance.filename))
                new_instance.data = sampled_data
                new_instance.metadata['type'] = 'interval_sample'
                new_instance.metadata['first_row'] = first_row
                new_instance.metadata['last_row'] = last_row

                return new_instance

            else:
                self.log.debug('Attempt {}: duration too big, resampling, ...\n'.format(attempts))
                attempts += 1
                align_shift += 1

        # Got here -> sanity check failed:
        msg = ('Quitting after {} sampling attempts. ' +
               'Hint: check sampling params / dataset consistency.').format(attempts)
        self.log.error(msg)
        raise RuntimeError(msg)

    def _sample_exact_interval(self, interval, name='interval_sample_', **kwargs):
        """
        Samples exactly defined interval.

        Args:
            interval:   tuple, list or 1d-array of integers of length 2: [lower_row_number, upper_row_number];
            name:       str, sample filename id

        Returns:
             BTgymDataset instance.

        """
        try:
            assert not self.data.empty

        except (AssertionError, AttributeError) as e:
            self.log.exception('Instance holds no data. Hint: forgot to call .read_csv()?')
            raise AssertionError

        try:
            assert len(interval) == 2

        except AssertionError:
            self.log.exception(
                'Invalid interval arg: expected list or tuple of size 2, got: {}'.format(interval)
            )
            raise AssertionError

        first_row = interval[0]
        last_row = interval[-1]
        sampled_data = self.data[first_row: last_row]

        sample_first_day = self.data[first_row:first_row + 1].index[0]

        new_instance = self.nested_class_ref(**self.nested_params)
        new_instance.filename = name + 'num_{}_at_{}'.format(self.sample_num, sample_first_day)
        self.log.info('New sample id: <{}>.'.format(new_instance.filename))
        new_instance.data = sampled_data
        new_instance.metadata['type'] = 'interval_sample'
        new_instance.metadata['first_row'] = first_row
        new_instance.metadata['last_row'] = last_row

        return new_instance
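
# Illustrative usage sketch (not part of the example above). It assumes the class defined here
# is exposed as `BTgymDataset` and that the constructor accepts the `filename`, `sample_duration`
# and `time_gap` keyword arguments used as attributes above; the CSV path is hypothetical.
from btgym import BTgymDataset  # assumed import location

dataset = BTgymDataset(
    filename='./data/some_M1_bars.csv',                       # hypothetical minute-bar file
    sample_duration={'days': 1, 'hours': 23, 'minutes': 55},  # maximum episode duration
    time_gap={'hours': 5},                                    # maximum tolerated data gap
)
dataset.read_csv()  # CSV --> pandas dataframe
dataset.reset()     # build train/test intervals, mark instance ready for sampling
train_sample = dataset.sample(sample_type=0, b_alpha=10, b_beta=0.8)  # skew towards recent data
test_sample = dataset.sample(sample_type=1)                           # uniform test sample
data_feed = train_sample.to_btfeed()                                  # {data_line_name: bt.feed}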
Example #14
0
class AMLDG():
    """
    Asynchronous implementation of MLDG algorithm (by Da Li et al.)
    for one-shot adaptation in dynamically changing environments.

    Papers:
        Da Li et al.,
         "Learning to Generalize: Meta-Learning for Domain Generalization"
         https://arxiv.org/abs/1710.03463

        Maruan Al-Shedivat et al.,
        "Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments"
        https://arxiv.org/abs/1710.03641

    """
    def __init__(
            self,
            env,
            task,
            log_level,
            aac_class_ref=SubAAC,
            runner_config=None,
            aac_lambda=1.0,
            guided_lambda=1.0,
            rollout_length=20,
            trial_source_target_cycle=(1, 0),
            num_episodes_per_trial=1,  # one-shot adaptation
            _aux_render_modes=('action_prob', 'value_fn', 'lstm_1_h',
                               'lstm_2_h'),
            name='AMLDG',
            **kwargs):
        try:
            self.aac_class_ref = aac_class_ref
            self.task = task
            self.name = name
            self.summary_writer = None

            StreamHandler(sys.stdout).push_application()
            self.log = Logger('{}_{}'.format(name, task), level=log_level)
            self.rollout_length = rollout_length

            if runner_config is None:
                self.runner_config = {
                    'class_ref': BaseSynchroRunner,
                    'kwargs': {},
                }
            else:
                self.runner_config = runner_config

            self.env_list = env

            assert isinstance(self.env_list, list) and len(self.env_list) == 2, \
                'Expected pair of environments, got: {}'.format(self.env_list)

            # Instantiate two sub-trainers: one for meta-test and one for meta-train environments:

            self.runner_config['kwargs']['data_sample_config'] = {
                'mode': 1
            }  # master
            self.runner_config['kwargs']['name'] = 'master'

            self.train_aac = aac_class_ref(
                env=self.env_list[0],  # train data will be master environment
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                rollout_length=self.rollout_length,
                trial_source_target_cycle=trial_source_target_cycle,
                num_episodes_per_trial=num_episodes_per_trial,
                _use_target_policy=False,
                _use_global_network=True,
                _aux_render_modes=_aux_render_modes,
                name=self.name + '/metaTrain',
                **kwargs)

            self.runner_config['kwargs']['data_sample_config'] = {
                'mode': 0
            }  # slave
            self.runner_config['kwargs']['name'] = 'slave'

            self.test_aac = aac_class_ref(
                env=self.env_list[-1],  # test data -> slave env.
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                rollout_length=self.rollout_length,
                trial_source_target_cycle=trial_source_target_cycle,
                num_episodes_per_trial=num_episodes_per_trial,
                _use_target_policy=False,
                _use_global_network=False,
                global_step_op=self.train_aac.global_step,
                global_episode_op=self.train_aac.global_episode,
                inc_episode_op=self.train_aac.inc_episode,
                _aux_render_modes=_aux_render_modes,
                name=self.name + '/metaTest',
                **kwargs)

            self.local_steps = self.train_aac.local_steps
            self.model_summary_freq = self.train_aac.model_summary_freq

            self._make_train_op()

            self.test_aac.model_summary_op = tf.summary.merge(
                [
                    self.test_aac.model_summary_op,
                    self._combine_meta_summaries()
                ],
                name='meta_model_summary')

        except:
            msg = 'AMLDG.__init__() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)

    def _make_train_op(self):
        """
        Defines tensors holding training op graph for meta-train, meta-test and meta-optimisation.
        """
        # Handy aliases:
        pi = self.train_aac.local_network  # local meta-train policy
        pi_prime = self.test_aac.local_network  # local meta-test policy
        pi_global = self.train_aac.network  # global shared policy

        self.test_aac.sync = self.test_aac.sync_pi = tf.group(
            *[v1.assign(v2) for v1, v2 in zip(pi_prime.var_list, pi.var_list)])

        # Shared counters:
        self.global_step = self.train_aac.global_step
        self.global_episode = self.train_aac.global_episode

        self.test_aac.global_step = self.train_aac.global_step
        self.test_aac.global_episode = self.train_aac.global_episode
        self.test_aac.inc_episode = self.train_aac.inc_episode
        self.train_aac.inc_episode = None
        self.inc_step = self.train_aac.inc_step

        # Meta-opt. loss:
        self.loss = self.train_aac.loss + self.test_aac.loss

        # Clipped gradients:
        self.train_aac.grads, _ = tf.clip_by_global_norm(
            tf.gradients(self.train_aac.loss, pi.var_list), 40.0)
        self.test_aac.grads, _ = tf.clip_by_global_norm(
            tf.gradients(self.test_aac.loss, pi_prime.var_list), 40.0)
        # Aliases:
        pi.grads = self.train_aac.grads
        pi_prime.grads = self.test_aac.grads

        # # Learned meta opt. scaling (equivalent to learned meta-update step-size),
        # # conditioned on test input:
        # self.meta_grads_scale = tf.reduce_mean(pi_prime.meta_grads_scale)
        # meta_grads_scale_var_list = [var for var in pi_prime.var_list if 'meta_grads_scale' in var.name]
        # meta_grads_scale_var_list_global = [
        #     var for var in pi_global.var_list if 'meta_grads_scale' in var.name
        # ]

        # self.log.warning('meta_grads_scale_var_list: {}'.format(meta_grads_scale_var_list))

        # Meta_optimisation gradients as sum of meta-train and meta-test gradients:
        self.grads = []
        for g1, g2 in zip(pi.grads, pi_prime.grads):
            if g1 is not None and g2 is not None:
                meta_g = g1 + g2
                # meta_g = (1 - self.meta_grads_scale) * g1 + self.meta_grads_scale * g2
            else:
                meta_g = None  # need to map correctly to vars

            self.grads.append(meta_g)

        # # Second order grads for learned grad. scaling param:
        # meta_grads_scale_grads, _ = tf.clip_by_global_norm(
        #     tf.gradients([g for g in self.grads if g is not None], meta_grads_scale_var_list),
        #     40.0
        # )
        # # Second order grads wrt global variables:
        # meta_grads_scale_grads_and_vars = list(zip(meta_grads_scale_grads, meta_grads_scale_var_list_global))

        # self.log.warning('meta_grads_scale_grads:\n{}'.format(meta_grads_scale_grads))
        # self.log.warning('meta_grads_scale_grads_and_vars:\n{}'.format(meta_grads_scale_grads_and_vars))

        #self.log.warning('self.grads_len: {}'.format(len(list(self.grads))))

        # Gradients to update local meta-test policy (from train data):
        train_grads_and_vars = list(zip(pi.grads, pi_prime.var_list))

        # self.log.warning('train_grads_and_vars_len: {}'.format(len(train_grads_and_vars)))

        # Meta-gradients to be sent to parameter server:
        meta_grads_and_vars = list(
            zip(self.grads,
                pi_global.var_list))  #+ meta_grads_scale_grads_and_vars

        # Remove empty entries:
        meta_grads_and_vars = [(g, v) for (g, v) in meta_grads_and_vars
                               if g is not None]

        # for item in meta_grads_and_vars:
        #     self.log.warning('\nmeta_g_v: {}'.format(item))

        # Set global_step increment equal to observation space batch size:
        obs_space_keys = list(self.train_aac.local_network.on_state_in.keys())
        assert 'external' in obs_space_keys, \
            'Expected observation space to contain `external` mode, got: {}'.format(obs_space_keys)
        self.train_aac.inc_step = self.train_aac.global_step.assign_add(
            tf.shape(self.train_aac.local_network.on_state_in['external'])[0])

        # Pi to pi_prime local adaptation op:
        self.train_op = self.train_aac.optimizer.apply_gradients(
            train_grads_and_vars)

        # Optimizer for meta-update, sharing same learn rate (change?):
        self.optimizer = tf.train.AdamOptimizer(
            self.train_aac.train_learn_rate, epsilon=1e-5)

        # Global meta-optimisation op:
        self.meta_train_op = self.optimizer.apply_gradients(
            meta_grads_and_vars)

        self.log.debug('meta_train_op defined')

    def _combine_meta_summaries(self):
        """
        Additional summaries here.
        """
        meta_model_summaries = [
            tf.summary.scalar('meta_grad_global_norm',
                              tf.global_norm(self.grads)),
            tf.summary.scalar('total_meta_loss', self.loss),
            # tf.summary.scalar('meta_grad_scale', self.meta_grads_scale)
        ]
        return meta_model_summaries

    def start(self, sess, summary_writer, **kwargs):
        """
        Executes all initializing operations,
        starts environment runner[s].
        Supposed to be called by parent worker just before training loop starts.

        Args:
            sess:           tf session object.
            kwargs:         not used by default.
        """
        try:
            # Copy weights from global to local:
            sess.run(self.train_aac.sync_pi)
            sess.run(self.test_aac.sync_pi)

            # Start thread_runners:
            self.train_aac._start_runners(  # master first
                sess,
                summary_writer,
                init_context=None,
                data_sample_config=self.train_aac.get_sample_config(mode=1))
            self.test_aac._start_runners(
                sess,
                summary_writer,
                init_context=None,
                data_sample_config=self.test_aac.get_sample_config(mode=0))

            self.summary_writer = summary_writer
            self.log.notice('Runners started.')

        except:
            msg = 'start() exception occurred' + \
                '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)

    def process(self, sess):
        """
        Meta-train/test procedure for one-shot learning.
        Single call runs single meta-test episode.

        Args:
            sess (tensorflow.Session):   tf session obj.

        """
        try:
            # Copy from parameter server:
            sess.run(self.train_aac.sync_pi)
            sess.run(self.test_aac.sync_pi)
            # self.log.warning('Init Sync ok.')

            # Get data configuration
            # (we want both data streams to come from the same trial,
            # and the trial we get can be of either source or target domain type);
            # note: data_config counters get updated once per .process() call
            train_data_config = self.train_aac.get_sample_config(
                mode=1)  # master env., samples trial
            test_data_config = self.train_aac.get_sample_config(
                mode=0)  # slave env, catches up with same trial

            # self.log.warning('train_data_config: {}'.format(train_data_config))
            # self.log.warning('test_data_config: {}'.format(test_data_config))

            # If this step data comes from source or target domain
            # (i.e. is it either meta-optimised or true test episode):
            is_target = train_data_config['trial_config']['sample_type']
            done = False

            # Collect initial meta-train trajectory rollout:
            train_data = self.train_aac.get_data(
                data_sample_config=train_data_config, force_new_episode=True)
            feed_dict = self.train_aac.process_data(sess,
                                                    train_data,
                                                    is_train=True)

            # self.log.warning('Init Train data ok.')

            # Disable possibility of master data runner acquiring new trials,
            # in case the meta-train episode terminates earlier than the meta-test one -
            # we then need to get additional meta-train trajectories from exactly the same distribution (trial):
            train_data_config['trial_config']['get_new'] = 0

            roll_num = 0

            # Collect entire meta-test episode rollout by rollout:
            while not done:
                # self.log.warning('Roll #{}'.format(roll_num))

                write_model_summary = \
                    self.local_steps % self.model_summary_freq == 0

                # self.log.warning(
                #     'Train data trial_num: {}'.format(
                #         np.asarray(train_data['on_policy'][0]['state']['metadata']['trial_num'])
                #     )
                # )

                # Paranoid checks against data sampling logic faults to prevent possible cheating:
                train_trial_chksum = np.average(
                    train_data['on_policy'][0]['state']['metadata']
                    ['trial_num'])

                # Update pi_prime parameters wrt collected train data:
                if write_model_summary:
                    fetches = [self.train_op, self.train_aac.model_summary_op]
                else:
                    fetches = [self.train_op]

                fetched = sess.run(fetches, feed_dict=feed_dict)

                # self.log.warning('Train gradients ok.')

                # Collect test rollout using updated pi_prime policy:
                test_data = self.test_aac.get_data(
                    data_sample_config=test_data_config)

                # Has the meta-test episode just ended?
                done = np.asarray(test_data['terminal']).any()

                # self.log.warning(
                #     'Test data trial_num: {}'.format(
                #         np.asarray(test_data['on_policy'][0]['state']['metadata']['trial_num'])
                #     )
                # )

                test_trial_chksum = np.average(
                    test_data['on_policy'][0]['state']['metadata']
                    ['trial_num'])

                # Ensure slave runner data consistency, can correct if episode just started:
                if roll_num == 0 and train_trial_chksum != test_trial_chksum:
                    test_data = self.test_aac.get_data(
                        data_sample_config=test_data_config,
                        force_new_episode=True)
                    done = np.asarray(test_data['terminal']).any()
                    faulty_chksum = test_trial_chksum
                    test_trial_chksum = np.average(
                        test_data['on_policy'][0]['state']['metadata']
                        ['trial_num'])

                    self.log.warning('Test trial corrected: {} -> {}'.format(
                        faulty_chksum, test_trial_chksum))

                # self.log.warning(
                #     'roll # {}: train_trial_chksum: {}, test_trial_chksum: {}'.
                #         format(roll_num, train_trial_chksum, test_trial_chksum)
                # )

                if train_trial_chksum != test_trial_chksum:
                    # Still got error? - highly probable algorithm logic fault. Issue warning.
                    msg = 'Train/test trials mismatch found!\nGot train trials: {},\nTest trials: {}'. \
                        format(
                        train_data['on_policy'][0]['state']['metadata']['trial_num'][0],
                        test_data['on_policy'][0]['state']['metadata']['trial_num'][0]
                        )
                    msg2 = 'Train data config: {}\n Test data config: {}'.format(
                        train_data_config, test_data_config)

                    self.log.warning(msg)
                    self.log.warning(msg2)

                # Check episode type for consistency; if failed - another data sampling logic fault, warn:
                try:
                    assert (np.asarray(test_data['on_policy'][0]['state']
                                       ['metadata']['type']) == 1).any()
                    assert (np.asarray(train_data['on_policy'][0]['state']
                                       ['metadata']['type']) == 0).any()
                except AssertionError:
                    msg = 'Train/test episodes types mismatch found!\nGot train ep. type: {},\nTest ep.type: {}'. \
                        format(
                        train_data['on_policy'][0]['state']['metadata']['type'],
                        test_data['on_policy'][0]['state']['metadata']['type']
                    )
                    self.log.warning(msg)

                # self.log.warning('Test data ok.')

                if not is_target:
                    # Process test data and perform meta-optimisation step:
                    feed_dict.update(
                        self.test_aac.process_data(sess,
                                                   test_data,
                                                   is_train=True))

                    if write_model_summary:
                        meta_fetches = [
                            self.meta_train_op, self.test_aac.model_summary_op,
                            self.inc_step
                        ]
                    else:
                        meta_fetches = [self.meta_train_op, self.inc_step]

                    meta_fetched = sess.run(meta_fetches, feed_dict=feed_dict)

                    # self.log.warning('Meta-gradients ok.')
                else:
                    # True test, no updates sent to parameter server:
                    meta_fetched = [None, None]

                    # self.log.warning('Meta-opt. rollout ok.')

                if write_model_summary:
                    meta_model_summary = meta_fetched[-2]
                    model_summary = fetched[-1]

                else:
                    meta_model_summary = None
                    model_summary = None

                # Next step housekeeping:
                # copy from parameter server:
                sess.run(self.train_aac.sync_pi)
                sess.run(self.test_aac.sync_pi)
                # self.log.warning('Sync ok.')

                # Collect next train trajectory rollout:
                train_data = self.train_aac.get_data(
                    data_sample_config=train_data_config)
                feed_dict = self.train_aac.process_data(sess,
                                                        train_data,
                                                        is_train=True)
                # self.log.warning('Train data ok.')

                # Write down summaries:
                self.test_aac.process_summary(sess, test_data,
                                              meta_model_summary)
                self.train_aac.process_summary(sess, train_data, model_summary)
                self.local_steps += 1
                roll_num += 1
        except:
            msg = 'process() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)
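
# Minimal standalone sketch (added for illustration, not from the original source) of the
# meta-gradient combination performed in _make_train_op() above: per-variable meta-train and
# meta-test gradients are summed, and variables missing either gradient are dropped.
def combine_meta_gradients(train_grads, test_grads, global_vars):
    """Return (gradient, variable) pairs for the meta-optimisation step."""
    meta_grads = [
        g1 + g2 if g1 is not None and g2 is not None else None
        for g1, g2 in zip(train_grads, test_grads)
    ]
    # Keep only entries that actually carry a gradient:
    return [(g, v) for g, v in zip(meta_grads, global_vars) if g is not None]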
Example #15
0
class Paxos(object):
    def __init__(self, transport, on_learn, on_prepare=None,
            on_stale=None, quorum_timeout=3,
            logger_group=None,
        ):
        self._logger = Logger('paxos')
        if logger_group is not None:
            logger_group.add_logger(self._logger)

        self.transport = transport
        self.on_learn = on_learn
        self.on_prepare = on_prepare
        self.on_stale = on_stale
        self.quorum_timeout = quorum_timeout

        self.id = 0
        self.max_seen_id = 0
        self.last_accepted_id = 0
        self._logger.debug('2 last_accepted_id=%(last_accepted_id)s' % self.__dict__)

        self.proposed_value = None
        self.deferred = None
        self.queue = deque() # queue of (value, deferred) to propose
        self._learn_queue = [] # sorted list with learn requests which come out of order

        # delayed calls for timeouts
        self._accepted_timeout = None
        self._acks_timeout = None
        self._waiting_to_learn_id = deque()
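
    # Message flow implemented below, for one round with this node as proposer:
    #   propose() -> broadcast 'paxos_prepare' -> peers answer 'paxos_ack';
    #   on ack quorum -> broadcast 'paxos_accept' -> peers answer 'paxos_accepted';
    #   on accept quorum -> broadcast 'paxos_learn' -> on_learn() runs and the deferred fires.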

    def recv(self, message, client):
        message = shlex.split(message)
        command = getattr(self, message[0])
        command(client=client, *message[1:])

    def propose(self, value):
        deferred = Deferred()
        if self.proposed_value is None:
            self._start_paxos(value, deferred)
        else:
            self.queue.append((value, deferred))
            self._logger.debug('Request for %s was queued, queue size is %s (because we are proposing %s now)' % (
                    value,
                    len(self.queue),
                    self.proposed_value,
                )
            )
        return deferred

    def _start_paxos(self, value, deferred):
        """Starts paxos iteration proposing given value."""
        self.id = self.max_seen_id + 1
        self.proposed_value = value
        self.deferred = deferred

        self._num_acks_to_wait = self.transport.quorum_size

        def _timeout_callback():
            self._logger.info('+++ prepare timeout')
            # TODO sometimes self.deferred is None when this callback is called
            self.deferred.errback(PrepareTimeout())
            self.deferred = None
            self.proposed_value = None

        self._acks_timeout = reactor.callLater(self.quorum_timeout, _timeout_callback)
        self.transport.broadcast('paxos_prepare %s %s' % (self.id, self.last_accepted_id))

    def paxos_prepare(self, num, last_accepted_id, client):
        num = int(num)
        last_accepted_id = int(last_accepted_id)

        if last_accepted_id > self.last_accepted_id:
            # Move to the "stale" state
            self._logger.debug('stale last_accepted_id(%s) < self.last_accepted_id(%s)' % (
                last_accepted_id, self.last_accepted_id
            ))
            self.on_stale(last_accepted_id)
        else:
            if num > self.max_seen_id:
                if self.on_prepare is not None:
                    self.on_prepare(num, client)

                self.max_seen_id = num
                self._send_to(client, 'paxos_ack %s' % num)

    def paxos_ack(self, num, client):
        num = int(num)
        if self.proposed_value is not None and num == self.id:
            self._num_acks_to_wait -= 1
            if self._num_acks_to_wait == 0:
                _stop_waiting(self._acks_timeout)

                self._num_accepts_to_wait = self.transport.quorum_size

                def _timeout_callback():
                    self._logger.info('+++ accept timeout')
                    self.deferred.errback(AcceptTimeout())
                    self.deferred = None
                    self.proposed_value = None

                self._accepted_timeout = reactor.callLater(
                    self.quorum_timeout,
                    _timeout_callback
                )
                self.transport.broadcast('paxos_accept %s "%s"' % (self.id, escape(self.proposed_value)))

    def paxos_accept(self, num, value, client):
        num = int(num)
        if num == self.max_seen_id:
            if self.id == num:
                # we have a deferred to return result in this round
                self._waiting_to_learn_id.append((num, self.deferred))
            else:
                # maybe we have a deferred, but it is for another Paxos round
                self._waiting_to_learn_id.append((num, None))
            self._send_to(client, 'paxos_accepted %s' % num)

    def paxos_accepted(self, num, client):
        num = int(num)
        if self.proposed_value is not None and num == self.id:
            self._num_accepts_to_wait -= 1
            if self._num_accepts_to_wait == 0:
                _stop_waiting(self._accepted_timeout)
                self.transport.broadcast('paxos_learn %s "%s"' % (self.id, escape(self.proposed_value)))

    def paxos_learn(self, num, value, client):
        self._logger.info('paxos.learn %s' % value)

        num = int(num)
        if self._waiting_to_learn_id and num == self._waiting_to_learn_id[0][0]:
            num, deferred = self._waiting_to_learn_id.popleft()

            try:
                result = self.on_learn(num, value, client)
            except Exception as e:
                self._logger.exception('paxos.learn %s' % value)
                result = e

            self.last_accepted_id = num
            self._logger.debug('1 last_accepted_id=%(last_accepted_id)s' % self.__dict__)

            if deferred is not None and value == self.proposed_value:
                # this works for current round coordinator only
                # because it must return result to the client
                # and to start a new round for next request

                if isinstance(result, Exception):
                    self._logger.warning('returning error from paxos.learn %s, %s' % (value, result))
                    deferred.errback(result)
                else:
                    self._logger.warning('returning success from paxos.learn %s' % value)
                    deferred.callback(result)

                self._logger.debug('queue size: %s' % len(self.queue))
                if self.queue:
                    # start new Paxos instance
                    # for next value from the queue
                    next_value, deferred = self.queue.pop()
                    self._logger.debug('next value from the queue: %s' % next_value)
                    self._start_paxos(next_value, deferred)
                else:
                    self.proposed_value = None
                    self.deferred = None

            if self._learn_queue:
                self._logger.debug('relearning remembered values')
                # clear queue because it will be filled again if needed
                queue, self._learn_queue = self._learn_queue, []
                for args in queue:
                    self.paxos_learn(*args)

        else:
            # value belongs to a different round id: remember it for relearning
            self._learn_queue.append((num, value, client))
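
The handlers above exchange plain string messages such as 'paxos_accept <id> "<value>"', 'paxos_accepted <id>' and 'paxos_learn <id> "<value>"' over the transport's broadcast and send primitives. Below is a minimal standalone sketch of that wire format; the encode/decode helpers are illustrative only and are not part of the original class, which relies on its own escape() helper.

def encode(command, round_id, value=None):
    # e.g. 'paxos_accepted 7' or 'paxos_accept 7 "set x=1"'
    if value is None:
        return '%s %s' % (command, round_id)
    return '%s %s "%s"' % (command, round_id, value)

def decode(message):
    parts = message.split(' ', 2)
    command, round_id = parts[0], int(parts[1])
    value = parts[2].strip('"') if len(parts) == 3 else None
    return command, round_id, value

assert decode(encode('paxos_accept', 7, 'set x=1')) == ('paxos_accept', 7, 'set x=1')
assert decode(encode('paxos_accepted', 7)) == ('paxos_accepted', 7, None)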
Example #16
0
class MetaAAC_1_0():
    """
    Meta-trainer class.
    INITIAL: Implementation of MLDG algorithm tuned
    for adaptation in dynamically changing environments

    Papers:
        Da Li et al.,
         "Learning to Generalize: Meta-Learning for Domain Generalization"
         https://arxiv.org/abs/1710.03463

        Maruan Al-Shedivat et al.,
        "Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments"
        https://arxiv.org/abs/1710.03641



    """
    def __init__(self,
                 env,
                 task,
                 log_level,
                 aac_class_ref=SubAAC,
                 runner_config=None,
                 aac_lambda=1.0,
                 guided_lambda=1.0,
                 trial_source_target_cycle=(1, 0),
                 num_episodes_per_trial=1,
                 _aux_render_modes=('action_prob', 'value_fn', 'lstm_1_h',
                                    'lstm_2_h'),
                 name='MetaAAC',
                 **kwargs):
        try:
            self.aac_class_ref = aac_class_ref
            self.task = task
            self.name = name
            StreamHandler(sys.stdout).push_application()
            self.log = Logger('{}_{}'.format(name, task), level=log_level)

            # with tf.variable_scope(self.name):
            if runner_config is None:
                self.runner_config = {
                    'class_ref': BaseSynchroRunner,
                    'kwargs': {},
                }
            else:
                self.runner_config = runner_config

            self.env_list = env

            assert isinstance(self.env_list, list) and len(self.env_list) == 2, \
                'Expected pair of environments, got: {}'.format(self.env_list)

            # Instantiate two sub-trainers: one for the test and one for the train environment:

            self.runner_config['kwargs']['data_sample_config'] = {
                'mode': 0
            }  # slave
            self.runner_config['kwargs']['name'] = 'slave'

            self.train_aac = aac_class_ref(
                env=self.env_list[-1],  # train data will be the slave environment
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                trial_source_target_cycle=trial_source_target_cycle,
                num_episodes_per_trial=num_episodes_per_trial,
                _use_target_policy=False,
                _use_global_network=True,
                _aux_render_modes=_aux_render_modes,
                name=self.name + '_sub_Train',
                **kwargs)

            self.runner_config['kwargs']['data_sample_config'] = {
                'mode': 1
            }  # master
            self.runner_config['kwargs']['name'] = 'master'

            self.test_aac = aac_class_ref(
                env=self.env_list[0],  # test data - master env.
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                trial_source_target_cycle=trial_source_target_cycle,
                num_episodes_per_trial=num_episodes_per_trial,
                _use_target_policy=False,
                _use_global_network=False,
                global_step_op=self.train_aac.global_step,
                global_episode_op=self.train_aac.global_episode,
                inc_episode_op=self.train_aac.inc_episode,
                _aux_render_modes=_aux_render_modes,
                name=self.name + '_sub_Test',
                **kwargs)

            self.local_steps = self.train_aac.local_steps
            self.model_summary_freq = self.train_aac.model_summary_freq
            #self.model_summary_op = self.train_aac.model_summary_op

            self._make_train_op()
            self.test_aac.model_summary_op = tf.summary.merge(
                [
                    self.test_aac.model_summary_op,
                    self._combine_meta_summaries()
                ],
                name='meta_model_summary')

        except:
            msg = 'MetaAAC_1_0.__init__() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)

    def _make_train_op(self):
        """

        Defines:
            tensors holding training op graph for sub trainers and self;
        """
        pi = self.train_aac.local_network
        pi_prime = self.test_aac.local_network

        self.test_aac.sync = self.test_aac.sync_pi = tf.group(
            *[v1.assign(v2) for v1, v2 in zip(pi_prime.var_list, pi.var_list)])

        self.global_step = self.train_aac.global_step
        self.global_episode = self.train_aac.global_episode

        self.test_aac.global_step = self.train_aac.global_step
        self.test_aac.global_episode = self.train_aac.global_episode
        self.test_aac.inc_episode = self.train_aac.inc_episode
        self.train_aac.inc_episode = None
        self.inc_step = self.train_aac.inc_step

        # Meta-loss:
        self.loss = 0.5 * self.train_aac.loss + 0.5 * self.test_aac.loss

        # Clipped gradients:
        self.train_aac.grads, _ = tf.clip_by_global_norm(
            tf.gradients(self.train_aac.loss, pi.var_list), 40.0)
        self.log.warning('self.train_aac.grads: {}'.format(
            len(list(self.train_aac.grads))))

        # self.test_aac.grads, _ = tf.clip_by_global_norm(
        #     tf.gradients(self.test_aac.loss, pi_prime.var_list),
        #     40.0
        # )
        # Meta-gradient:
        grads_i, _ = tf.clip_by_global_norm(
            tf.gradients(self.train_aac.loss, pi.var_list), 40.0)

        grads_i_next, _ = tf.clip_by_global_norm(
            tf.gradients(self.test_aac.loss, pi_prime.var_list), 40.0)

        self.grads = []
        for g1, g2 in zip(grads_i, grads_i_next):
            if g1 is not None and g2 is not None:
                meta_g = 0.5 * g1 + 0.5 * g2
            else:
                meta_g = None

            self.grads.append(meta_g)

        #self.log.warning('self.grads_len: {}'.format(len(list(self.grads))))

        # Gradients to update local copy of pi_prime (from train data):
        train_grads_and_vars = list(
            zip(self.train_aac.grads, pi_prime.var_list))

        self.log.warning('train_grads_and_vars_len: {}'.format(
            len(train_grads_and_vars)))

        # Meta-gradients to be sent to parameter server:
        meta_grads_and_vars = list(
            zip(self.grads, self.train_aac.network.var_list))

        self.log.warning('meta_grads_and_vars_len: {}'.format(
            len(meta_grads_and_vars)))

        # Set global_step increment equal to observation space batch size:
        obs_space_keys = list(self.train_aac.local_network.on_state_in.keys())

        assert 'external' in obs_space_keys, \
            'Expected observation space to contain `external` mode, got: {}'.format(obs_space_keys)
        self.train_aac.inc_step = self.train_aac.global_step.assign_add(
            tf.shape(self.train_aac.local_network.on_state_in['external'])[0])

        self.train_op = self.train_aac.optimizer.apply_gradients(
            train_grads_and_vars)

        # Optimizer for meta-update:
        self.optimizer = tf.train.AdamOptimizer(
            self.train_aac.train_learn_rate, epsilon=1e-5)
        # TODO: own alpha learning rate
        self.meta_train_op = self.optimizer.apply_gradients(
            meta_grads_and_vars)

        self.log.debug('meta_train_op defined')

    def _combine_meta_summaries(self):

        meta_model_summaries = [
            tf.summary.scalar("meta_grad_global_norm",
                              tf.global_norm(self.grads)),
            tf.summary.scalar("total_meta_loss", self.loss),
        ]

        return meta_model_summaries

    def start(self, sess, summary_writer, **kwargs):
        """
        Executes all initializing operations,
        starts environment runner[s].
        Supposed to be called by parent worker just before training loop starts.

        Args:
            sess:            tf session object.
            summary_writer:  tf summary writer instance.
            kwargs:          not used by default.
        """
        try:
            # Copy weights from global to local:
            sess.run(self.train_aac.sync_pi)
            sess.run(self.test_aac.sync_pi)

            # Start thread_runners:
            self.test_aac._start_runners(sess, summary_writer)  # master first
            self.train_aac._start_runners(sess, summary_writer)

            self.summary_writer = summary_writer
            self.log.notice('Runners started.')

        except:
            msg = 'start() exception occurred' + \
                '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)

    def process(self, sess):
        """
        Meta-train step.

        Args:
            sess (tensorflow.Session):   tf session obj.

        """
        try:
            # Say `No` to redundant summaries:
            write_model_summary = \
                self.local_steps % self.model_summary_freq == 0

            # Copy from parameter server:
            sess.run(self.train_aac.sync_pi)
            sess.run(self.test_aac.sync_pi)

            #self.log.warning('Sync ok.')

            # Collect train trajectory:
            train_data = self.train_aac.get_data()
            feed_dict = self.train_aac.process_data(sess,
                                                    train_data,
                                                    is_train=True)

            #self.log.warning('Train data ok.')

            # Update pi_prime parameters wrt collected data:
            if write_model_summary:
                fetches = [self.train_op, self.train_aac.model_summary_op]
            else:
                fetches = [self.train_op]

            fetched = sess.run(fetches, feed_dict=feed_dict)

            #self.log.warning('Train gradients ok.')

            # Collect test trajectory wrt updated pi_prime parameters:
            test_data = self.test_aac.get_data()
            feed_dict.update(
                self.test_aac.process_data(sess, test_data, is_train=True))

            #self.log.warning('Test data ok.')

            # Perform meta-update:
            if write_model_summary:
                meta_fetches = [
                    self.meta_train_op, self.test_aac.model_summary_op,
                    self.inc_step
                ]
            else:
                meta_fetches = [self.meta_train_op, self.inc_step]

            meta_fetched = sess.run(meta_fetches, feed_dict=feed_dict)

            #self.log.warning('Meta-gradients ok.')

            if write_model_summary:
                meta_model_summary = meta_fetched[-2]
                model_summary = fetched[-1]

            else:
                meta_model_summary = None
                model_summary = None

            # Write down summaries:
            self.test_aac.process_summary(sess, test_data, meta_model_summary)
            self.train_aac.process_summary(sess, train_data, model_summary)
            self.local_steps += 1

            # TODO: ...what about sampling control?

        except:
            msg = 'process() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)
Example #17
0
class Genome:
    def __init__(self, genome, assembly_summary=None):
        """
        :param genome: Path to genome
        :returns: Path to genome and name of the genome
        :rtype:
        """
        self.path = os.path.abspath(genome)
        self.species_dir, self.fasta = os.path.split(self.path)
        self.name = os.path.splitext(self.fasta)[0]
        self.log = Logger(self.name)
        self.qc_dir = os.path.join(self.species_dir, "qc")
        self.msh = os.path.join(self.qc_dir, self.name + ".msh")
        self.stats_path = os.path.join(self.qc_dir, self.name + '.csv')
        if os.path.isfile(self.stats_path):
            self.stats = pd.read_csv(self.stats_path, index_col=0)
        self.assembly_summary = assembly_summary
        self.metadata = defaultdict(lambda: 'missing')
        self.xml = defaultdict(lambda: 'missing')
        try:
            self.accession_id = re.match(r'GCA_.*\.\d', self.name).group()
            self.metadata["accession"] = self.accession_id
        except AttributeError:
            # Raise custom exception
            self.accession_id = "missing"
            self.log.error("Invalid accession ID")
            self.log.exception()
        if isinstance(self.assembly_summary, pd.DataFrame):
            try:
                biosample = assembly_summary.loc[self.accession_id].biosample
                self.metadata["biosample_id"] = biosample
            except (AttributeError, KeyError):
                self.log.info("Unable to get biosample ID")
        self.log.info("Instantiated")

    def get_contigs(self):
        """
        Return a list of Bio.Seq.Seq objects for the fasta file and calculate
        the total number of contigs.
        """
        try:
            self.contigs = [seq.seq for seq in SeqIO.parse(self.path, "fasta")]
            self.count_contigs = len(self.contigs)
            self.log.info("Contigs: {}".format(self.count_contigs))
        except UnicodeDecodeError:
            self.log.exception()

    def get_assembly_size(self):
        """Calculate the sum of all contig lengths"""
        # TODO: map or reduce might be more elegant here
        self.assembly_size = sum((len(str(seq)) for seq in self.contigs))
        self.log.info("Assembly Size: {}".format(self.assembly_size))

    def get_unknowns(self):
        """Count the number of unknown bases, i.e. not [ATCG]"""
        # TODO: Would it be useful to allow the user to define p?
        p = re.compile("[^ATCG]")
        self.unknowns = sum(
            (len(re.findall(p, str(seq))) for seq in self.contigs))
        self.log.info("Unknowns: {}".format(self.unknowns))

    def get_distance(self, dmx_mean):
        self.distance = dmx_mean.loc[self.name]
        self.log.info("Distance: {}".format(self.distance))

    def sketch(self):
        cmd = "mash sketch '{}' -o '{}'".format(self.path, self.msh)
        if os.path.isfile(self.msh):
            self.log.info("Sketch file already exists")
        else:
            subprocess.Popen(cmd, shell="True",
                             stderr=subprocess.DEVNULL).wait()
            self.log.info("Sketch file created")

    def get_stats(self, dmx_mean):
        if not os.path.isfile(self.stats_path):
            self.get_contigs()
            self.get_assembly_size()
            self.get_unknowns()
            self.get_distance(dmx_mean)
            data = {
                "contigs": self.count_contigs,
                "assembly_size": self.assembly_size,
                "unknowns": self.unknowns,
                "distance": self.distance
            }
            self.stats = pd.DataFrame(data, index=[self.name])
            self.stats.to_csv(self.stats_path)
            self.log.info("Generated stats and wrote to disk")

    one_minute = 60000

    # Retry up to 3 times within 10 seconds,
    # waiting 100 ms between retries
    @retry(stop_max_attempt_number=3, stop_max_delay=10000, wait_fixed=100)
    def efetch(self, db):
        """
        Use NCBI's efetch tools to get xml for genome's biosample id or SRA id
        """
        if db == "biosample":
            db_id = db + "_id"
        elif db == "sra":
            db_id = db + "_id"
        cmd = ("esearch -db {} -query {} | "
               "efetch -format docsum".format(db, self.metadata[db_id]))
        # Make efetch timeout and retry after 30 seconds
        time_limit = 30
        if self.metadata[db_id] != 'missing':
            try:
                p = subprocess.run(cmd,
                                   shell="True",
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.DEVNULL,
                                   timeout=time_limit)
                xml = p.stdout
                self.xml[db] = xml
                self.log.info("{} XML downloaded".format(db))
            except subprocess.TimeoutExpired:
                self.log.error("Retrying efetch after timeout")
                raise subprocess.TimeoutExpired(cmd, time_limit)
            except Exception:
                self.log.error(db)
                self.log.exception()

    def parse_biosample(self):
        """
        Get what we need to get out of the xml returned by efetch("biosample")
        Including the SRA ID and fields of interest as defined in
        Metadata.biosample_fields
        """
        try:
            tree = ET.fromstring(self.xml["biosample"])
            sra = tree.find(
                'DocumentSummary/SampleData/BioSample/Ids/Id/[@db="SRA"]')
            self.log.info("Parsed biosample XML")
            try:
                self.metadata["sra_id"] = sra.text
            except AttributeError:
                self.metadata["sra_id"] = "missing"
            for name in Metadata.Metadata.biosample_fields:
                xp = ('DocumentSummary/SampleData/BioSample/Attributes/'
                      'Attribute/[@harmonized_name="{}"]'.format(name))
                attrib = tree.find(xp)
                try:
                    self.metadata[name] = attrib.text
                except AttributeError:
                    self.metadata[name] = "missing"

        except ParseError:
            self.log.error("Parse error for biosample XML")

    def parse_sra(self):
        try:
            tree = ET.fromstring(self.xml["sra"])
            elements = tree.iterfind("DocumentSummary/Runs/Run/[@acc]")
            self.log.info("Parsed SRA XML")
            srr_accessions = []
            for el in elements:
                items = el.items()
                acc = [i[1] for i in items if i[0] == 'acc']
                acc = acc[0]
                srr_accessions.append(acc)
            self.metadata["srr_accessions"] = ','.join(srr_accessions)
        except ParseError:
            self.log.error("Parse error for SRA XML")

    def get_metadata(self):
        self.efetch("biosample")
        self.parse_biosample()
        self.efetch("sra")
        self.parse_sra()
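
The accession handling in Genome.__init__ above reduces to a regular-expression match with a defaultdict fallback to "missing". A standalone sketch of just that pattern, with made-up sample names:

import re
from collections import defaultdict

def parse_accession(name):
    metadata = defaultdict(lambda: 'missing')
    match = re.match(r'GCA_.*\.\d', name)
    if match:
        metadata['accession'] = match.group()
    return metadata

print(parse_accession('GCA_000005845.2_ASM584v2')['accession'])  # GCA_000005845.2
print(parse_accession('unexpected_name')['accession'])           # missing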
Example #18
0
class BaseDataGenerator():
    """
    Base synthetic data provider class.
    """
    def __init__(self,
                 episode_duration=None,
                 timeframe=1,
                 generator_fn=null_generator,
                 generator_params=None,
                 name='BaseSyntheticDataGenerator',
                 data_names=('default_asset', ),
                 global_time=None,
                 task=0,
                 log_level=WARNING,
                 _nested_class_ref=None,
                 _nested_params=None,
                 **kwargs):
        """

        Args:
            episode_duration:       dict, duration of episode in days/hours/mins
            generator_fn:           callable, should return generated data as 1D np.array
            generator_params:       dict, parameters passed to generator_fn
            timeframe:              int, data periodicity in minutes
            name:                   str
            data_names:             iterable of str
            global_time:            dict {y, m, d} to set custom global time (only for plotting)
            task:                   int
            log_level:              logbook.Logger level
            **kwargs:

        """
        # Logging:
        self.log_level = log_level
        self.task = task
        self.name = name
        self.filename = self.name + '_sample'

        self.data_names = data_names
        self.data_name = self.data_names[0]
        self.sample_instance = None
        self.metadata = {
            'sample_num': 0,
            'type': None,
            'parent_sample_type': None
        }

        self.data = None
        self.data_stat = None

        self.sample_num = 0
        self.is_ready = False

        if _nested_class_ref is None:
            self.nested_class_ref = BaseDataGenerator
        else:
            self.nested_class_ref = _nested_class_ref

        if _nested_params is None:
            self.nested_params = dict(
                episode_duration=episode_duration,
                timeframe=timeframe,
                generator_fn=generator_fn,
                generator_params=generator_params,
                name=name,
                data_names=data_names,
                task=task,
                log_level=log_level,
                _nested_class_ref=_nested_class_ref,
                _nested_params=_nested_params,
            )
        else:
            self.nested_params = _nested_params

        StreamHandler(sys.stdout).push_application()
        self.log = Logger('{}_{}'.format(self.name, self.task),
                          level=self.log_level)

        # Default sample time duration:
        if episode_duration is None:
            self.episode_duration = dict(
                days=0,
                hours=23,
                minutes=55,
            )
        else:
            self.episode_duration = episode_duration

        # Btfeed parsing setup:
        self.timeframe = timeframe
        self.names = ['open']
        self.datetime = 0
        self.open = 1
        self.high = -1
        self.low = -1
        self.close = -1
        self.volume = -1
        self.openinterest = -1

        # base data feed related:
        self.params = {}
        if global_time is None:
            self.global_time = datetime.datetime(year=2018, month=1, day=1)
        else:
            self.global_time = datetime.datetime(**global_time)

        self.global_timestamp = self.global_time.timestamp()

        # Infer time indexes and sample number of records:
        self.train_index = pd.timedelta_range(
            start=datetime.timedelta(days=0, hours=0, minutes=0),
            end=datetime.timedelta(**self.episode_duration),
            freq='{}min'.format(self.timeframe))
        self.test_index = pd.timedelta_range(
            start=self.train_index[-1] +
            datetime.timedelta(minutes=self.timeframe),
            periods=len(self.train_index),
            freq='{}min'.format(self.timeframe))
        self.train_index += self.global_time
        self.test_index += self.global_time
        self.episode_num_records = len(self.train_index)

        self.generator_fn = generator_fn

        if generator_params is None:
            self.generator_params = {}

        else:
            self.generator_params = generator_params

    def set_logger(self, level=None, task=None):
        """
        Sets logbook logger.

        Args:
            level:  logbook.level, int
            task:   task id, int

        """
        if task is not None:
            self.task = task

        if level is not None:
            self.log = Logger('{}_{}'.format(self.name, self.task),
                              level=level)

    def reset(self, **kwargs):
        self.sample_num = 0
        self.is_ready = True

    def read_csv(self, **kwargs):
        self.data = self.generate_data(self.generator_params)

    def generate_data(self, generator_params, type=0):
        """
        Generates data trajectory (episode)

        Args:
            generator_params:       dict, data generating parameters
            type:                   0 - generate train data | 1 - generate test data

        Returns:
            data as pandas dataframe
        """
        assert type in [
            0, 1
        ], 'Expected sample type to be either 0 (train) or 1 (test), got: {}'.format(
            type)
        # Generate datapoints:
        data_array = self.generator_fn(num_points=self.episode_num_records,
                                       **generator_params)
        assert len(data_array.shape) == 1 and data_array.shape[0] == self.episode_num_records,\
            'Expected generated data to be 1D array of length {},  got data shape: {}'.format(
                self.episode_num_records,
                data_array.shape
            )
        negs = data_array[data_array < 0]
        if negs.any():
            self.log.warning(
                ' Set to zero {} negative generated values'.format(
                    negs.shape[0]))
            data_array[data_array < 0] = 0.0
        # Make dataframe:
        if type:
            index = self.test_index
        else:
            index = self.train_index

        # data_dict = {name: data_array for name in self.names}
        # data_dict['hh:mm:ss'] = index
        df = pd.DataFrame(data={name: data_array
                                for name in self.names},
                          index=index)
        # df = df.set_index('hh:mm:ss')
        return df

    def sample(self, get_new=True, sample_type=0, **kwargs):
        """
        Samples continuous subset of data.

        Args:
            get_new (bool):                 not used;
            sample_type (int or bool):      0 (train) or 1 (test) - get sample from train or test data subsets
                                            respectively.

        Returns:
            Dataset instance with number of records ~ max_episode_len,

        """
        try:
            assert sample_type in [0, 1]

        except AssertionError:
            self.log.exception(
                'Sampling attempt: expected sample type to be in {}, got: {}'.\
                format([0, 1], sample_type)
            )
            raise AssertionError

        if self.metadata['type'] is not None:
            if self.metadata['type'] != sample_type:
                self.log.warning(
                    'Attempted to sample type {} given current sample type {}, overridden.'
                    .format(sample_type, self.metadata['type']))
                sample_type = self.metadata['type']

        # Generate data:
        sampled_data = self.generate_data(self.generator_params,
                                          type=sample_type)
        self.sample_instance = self.nested_class_ref(**self.nested_params)
        self.sample_instance.filename += '_{}'.format(self.sample_num)
        self.log.info('New sample id: <{}>.'.format(
            self.sample_instance.filename))
        self.sample_instance.data = sampled_data

        # Add_metadata
        self.sample_instance.metadata['type'] = 'synthetic_data_sample'
        self.sample_instance.metadata['first_row'] = 0
        self.sample_instance.metadata['last_row'] = self.episode_num_records
        self.sample_instance.metadata['type'] = sample_type
        self.sample_instance.metadata['sample_num'] = self.sample_num
        self.sample_instance.metadata['parent_sample_num'] = self.metadata[
            'sample_num']
        self.sample_instance.metadata['parent_sample_type'] = self.metadata[
            'type']
        self.sample_num += 1

        return self.sample_instance

    def describe(self):
        """
        Returns summary dataset statistic as pandas dataframe:

            - records count,
            - data mean,
            - data std dev,
            - min value,
            - 25% percentile,
            - 50% percentile,
            - 75% percentile,
            - max value

        for every data column.
        """
        # Pretty straightforward, using standard pandas utility.
        # The only caveat here is that if actual data has not been loaded yet, need to load, describe and unload again,
        # thus avoiding passing big files to BT server:
        flush_data = False
        try:
            assert not self.data.empty
            pass

        except (AssertionError, AttributeError) as e:
            self.read_csv()
            flush_data = True

        self.data_stat = self.data.describe()
        self.log.info('Data summary:\n{}'.format(self.data_stat.to_string()))

        if flush_data:
            self.data = None
            self.log.info('Flushed data.')

        return self.data_stat

    def to_btfeed(self):
        """
        Performs BTgymData-->bt.feed conversion.

        Returns:
             dict of type: {data_line_name: bt.datafeed instance}.
        """
        try:
            assert not self.data.empty
            btfeed = btfeeds.PandasDirectData(dataname=self.data,
                                              timeframe=self.timeframe,
                                              datetime=self.datetime,
                                              open=self.open,
                                              high=self.high,
                                              low=self.low,
                                              close=self.close,
                                              volume=self.volume,
                                              openinterest=self.openinterest)
            btfeed.numrecords = self.data.shape[0]
            return {self.data_name: btfeed}

        except (AssertionError, AttributeError) as e:
            msg = 'Instance holds no data. Hint: forgot to call .read_csv()?'
            self.log.error(msg)
            raise AssertionError(msg)

    def set_global_timestamp(self, timestamp):
        pass
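
A hypothetical usage sketch of the generator above, assuming BaseDataGenerator is in scope and that, as generate_data() expects, generator_fn returns a 1D array of length num_points. The sine generator and its parameters are illustrative assumptions, not part of the original module:

import numpy as np

def sine_generator(num_points, period=360, **kwargs):
    # strictly positive series, so the negative-value clipping branch is never hit
    t = np.arange(num_points)
    return 100.0 + np.sin(2 * np.pi * t / period)

data_source = BaseDataGenerator(
    episode_duration={'days': 0, 'hours': 4, 'minutes': 0},
    timeframe=1,
    generator_fn=sine_generator,
)
data_source.reset()
train_sample = data_source.sample(sample_type=0)   # nested generator instance
print(train_sample.data.describe())                # summary of the generated 'open' column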
Example #19
0
class PortfolioEnv(BTgymEnv):
    """
        OpenAI Gym API shell for the Backtrader backtesting/trading library with multiple assets support.
        Action space is a dictionary of continuous actions, one for every asset.
        This setup closely relates to the continuous portfolio optimisation problem definition.

        Setup explanation:

            0. Problem definition.
            Consider setup with one riskless asset acting as broker account cash and K (by default - one) risky assets.
            For every risky asset there exists track of historic price records referred as `data-line`.
            Apart from assets data lines there possibly exists number of exogenous data lines holding some
            information and statistics, e.g. economic indexes, encoded news, macroeconomic indicators, weather forecasts
            etc. which are considered relevant and valuable for decision-making.
            It is supposed for this setup that:
            i. there is no interest rate for base (riskless) asset;
            ii. short selling is not permitted;
            iii. transaction costs are modelled via broker commission;
            iv. 'market liquidity' and 'capital impact' assumptions are met;
            v. time indexes match for all data lines provided;

            1. Assets and datalines.
            This environment expects Dataset to be instance of `btgym.datafeed.multi.BTgymMultiData`, which sets
            number,  specifications and sampling synchronisation for historic data for all assets and data lines.

            Namely, one should define data_config dictionary of `data lines` and list of `assets`.
            `data_config` specifies all data sources used by strategy, while `assets` defines subset of `data lines`
            which is supposed to hold historic data for risky portfolio assets.

            Internally every episodic asset data is converted to single bt.feed and added to environment strategy
            as separate named data_line (see backtrader docs for extensive explanation of data_lines concept).
            Every non-asset data line is also added as a bt.feed, with the difference that it is not 'tradable', i.e. it is
            impossible to issue trade orders on such a line.
            Strategy is expected to properly handle all received data-lines.

                Example::

                    1. Four data streams added via Dataset.data_config,
                       portfolio consists of four assets, added via strategy_params, cash is EUR:

                        data_config = {
                            'usd': {'filename': '.../DAT_ASCII_EURUSD_M1_2017.csv'},
                            'gbp': {'filename': '.../DAT_ASCII_EURGBP_M1_2017.csv'},
                            'jpy': {'filename': '.../DAT_ASCII_EURJPY_M1_2017.csv'},
                            'chf': {'filename': '.../DAT_ASCII_EURCHF_M1_2017.csv'},
                        }
                        cash_name = 'eur'
                        assets_names = ['usd', 'gbp', 'jpy', 'chf']

                    2. Three streams added, only two of them form portfolio; DXY stream is `decision-making` only:
                        data_config = {
                            'usd': {'filename': '.../DAT_ASCII_EURUSD_M1_2017.csv'},
                            'gbp': {'filename': '.../DAT_ASCII_EURGBP_M1_2017.csv'},
                            'DXY': {'filename': '.../DAT_ASCII_DXY_M1_2017.csv'},
                        }
                        cash_name = 'eur'
                        assets_names = ['usd', 'gbp']


            2. btgym.spaces.ActionDictSpace and order execution.
            ActionDictSpace is an extension of OpenAI Gym DictSpace providing domain-specific functionality.
            Strategy expects to receive separate action for every K+1 asset in form of dictionary:
            `{cash_name: a[0], asset_name_1: a[1], ..., asset_name_K: a[K]}` for K risky assets added,
            where base actions are real numbers: `a[i] in [0,1], 0<=i<=K, SUM{a[i]} = 1`. Whole action should be
            interpreted as order to adjust portfolio to have share `a[i] * 100% for i-th  asset`.

            Therefore, base actions are gym.spaces.Box and for K assets environment action space will be a shallow
            DictSpace of K+1 continuous spaces: `{cash_name: gym.spaces.Box(low=0, high=1),
            asset_name_1: gym.spaces.Box(low=0, high=1), ..., asset_name_K: gym.spaces.Box(low=0, high=1)}`

            3. TODO: refine order execution control, see: https://community.backtrader.com/topic/152/multi-asset-ranking-and-rebalancing/2?page=1

                Example::

                    if cash asset is 'eur',
                    risky assets added are: ['chf', 'gbp', 'jpy', 'usd'],
                    and data lines added via BTgymMultiData are:
                    {
                        'chf': eurchf_hist_data_source,
                        'gbp': eurgbp_hist_data_source,
                        'jpy': eurjpy_hist_data_source,
                        'usd': eurusd_hist_data_source,
                    },
                    then:

                    env.action.space will be:
                        DictSpace(
                            {
                                'eur': gym.spaces.Box(low=0, high=1, dtype=np.float32),
                                'chf': gym.spaces.Box(low=0, high=1, dtype=np.float32),
                                'gbp': gym.spaces.Box(low=0, high=1, dtype=np.float32),
                                'jpy': gym.spaces.Box(low=0, high=1, dtype=np.float32),
                                'usd': gym.spaces.Box(low=0, high=1, dtype=np.float32),
                            }
                        )

                    single environment action instance (as seen inside strategy or passed to environment via .step()):
                        {
                            'eur': 0.3
                            'chf': 0.1,
                            'gbp': 0.1,
                            'jpy': 0.2,
                            'usd': 0.3,
                        }

                    or vector (unlike multi-asset discrete setup, there is no binary/one hot encoding):
                        (0.3, 0.1, 0.1, 0.2, 0.3)

                    which says to broker: "... adjust positions to get 30% in base EUR asset (cash), and amounts of
                    10%, 10%, 20% and 30% of current portfolio value in CHF, GBP, JPY and USD respectively".

                    Note that under the hood broker uses `order_target_percent` for every risky asset and can issue
                    'sell', 'buy' or 'close' orders depending on positive/negative difference of current to desired
                    share of asset.

            4. Observation space: is a nested DictSpace, where the 'external' part of the space should hold specifications
            for every data line added (note that the cash asset does not have its own data line).

                Example::

                    if data lines added via BTgymMultiData are:
                        'chf', 'gbp', 'jpy', 'usd';

                    environment observation space can be DictSpace:
                     {
                        'external': DictSpace(
                            {
                                'usd': spaces.Box(low=-1000, high=1000, shape=(128, 1, num_features), dtype=np.float32),
                                'gbp': spaces.Box(low=-1000, high=1000, shape=(128, 1, num_features), dtype=np.float32),
                                'chf': spaces.Box(low=-1000, high=1000, shape=(128, 1, num_features), dtype=np.float32),
                                'jpy': spaces.Box(low=-1000, high=1000, shape=(128, 1, num_features), dtype=np.float32),
                            }
                        ),
                        'raw': spaces.Box(...),
                        'internal': spaces.Box(...),
                        'datetime': spaces.Box(...),
                        'metadata': DictSpace(...)
                    }

                    refer to strategies declarations for full code.

        """

    # Datafeed Server management:
    data_master = True
    data_network_address = 'tcp://127.0.0.1:'  # using localhost.
    data_port = 4999
    data_server = None
    data_server_pid = None
    data_context = None
    data_socket = None
    data_server_response = None

    # Dataset:
    dataset = None  # BTgymDataset instance.
    dataset_stat = None

    # Backtrader engine:
    engine = None  # bt.Cerebro subclass for server to execute.

    # Strategy:
    strategy = None  # strategy to use if no <engine> class been passed.

    # Server and network:
    server = None  # Server process.
    context = None  # ZMQ context.
    socket = None  # ZMQ socket, client side.
    port = 5500  # network port to use.
    network_address = 'tcp://127.0.0.1:'  # using localhost.
    ctrl_actions = ('_done', '_reset', '_stop', '_getstat', '_render'
                    )  # server control messages.
    server_response = None

    # Connection timeout:
    connect_timeout = 60  # server connection timeout in seconds.
    # connect_timeout_step = 0.01  # time between retries in seconds.

    # Rendering:
    render_enabled = True
    render_modes = [
        'human',
        'episode',
    ]
    # `episode` - plotted episode results.
    # `human` - raw_state observation in conventional human-readable format.
    #  <obs_space_key> - rendering of arbitrary state presented in observation_space with same key.

    renderer = None  # Rendering support.
    rendered_rgb = dict()  # Keep last rendered images for each mode.

    # Logging and id:
    log = None
    log_level = None  # logbook level: NOTICE, WARNING, INFO, DEBUG etc. or its integer equivalent;
    verbose = 0  # verbosity mode, valid only if no `log_level` arg has been provided:
    # 0 - WARNING, 1 - INFO, 2 - DEBUG.
    task = 0
    asset_names = ('default_asset', )
    data_lines_names = ('default_asset', )
    cash_name = 'default_cash'

    random_seed = None

    closed = True

    def __init__(self, engine, dataset=None, **kwargs):
        """
        This class requires dataset, strategy, engine instances to be passed explicitly.

        Args:
            dataset(btgym.datafeed):                        BTgymDataDomain instance;
            engine(bt.Cerebro):                             environment simulation engine, any bt.Cerebro subclass,

        Keyword Args:
            network_address=`tcp://127.0.0.1:` (str):       BTGym_server address.
            port=5500 (int):                                network port to use for server - API_shell communication.
            data_master=True (bool):                        let this environment control over data_server;
            data_network_address=`tcp://127.0.0.1:` (str):  data_server address.
            data_port=4999 (int):                           network port to use for server -- data_server communication.
            connect_timeout=60 (int):                       server connection timeout in seconds.
            render_enabled=True (bool):                     enable rendering for this environment;
            render_modes=['human', 'episode'] (list):       `episode` - plotted episode results;
                                                            `human` - raw_state observation.
            **render_args (any):                            any render-related args, passed through to renderer class.
            verbose=0 (int):                                verbosity mode, {0 - WARNING, 1 - INFO, 2 - DEBUG}
            log_level=None (int):                           logbook level {DEBUG=10, INFO=11, NOTICE=12, WARNING=13},
                                                            overrides `verbose` arg;
            log=None (logbook.Logger):                      external logbook logger,
                                                            overrides `log_level` and `verbose` args.
            task=0 (int):                                   environment id


        """
        self.dataset = dataset
        self.engine = engine
        # Parameters and default values:
        self.params = dict(
            engine={},
            dataset={},
            strategy={},
            render={},
        )
        # Update self attributes, remove used kwargs:
        for key in dir(self):
            if key in kwargs.keys():
                setattr(self, key, kwargs.pop(key))

        self.metadata = {'render.modes': self.render_modes}

        # Logging and verbosity control:
        if self.log is None:
            StreamHandler(sys.stdout).push_application()
            if self.log_level is None:
                log_levels = [(0, NOTICE), (1, INFO), (2, DEBUG)]
                self.log_level = WARNING
                for key, value in log_levels:
                    if key == self.verbose:
                        self.log_level = value
            self.log = Logger('BTgymPortfolioShell_{}'.format(self.task),
                              level=self.log_level)

        # Random seeding:
        np.random.seed(self.random_seed)

        # Network parameters:
        self.network_address += str(self.port)
        self.data_network_address += str(self.data_port)

        # Set server rendering:
        if self.render_enabled:
            self.renderer = BTgymRendering(self.metadata['render.modes'],
                                           log_level=self.log_level,
                                           **kwargs)

        else:
            self.renderer = BTgymNullRendering()
            self.log.info(
                'Rendering disabled. Call to render() will return null-plug image.'
            )

        # Append logging:
        self.renderer.log = self.log

        # Update params -1: pull from renderer, remove used kwargs:
        self.params['render'].update(self.renderer.params)
        for key in self.params['render'].keys():
            if key in kwargs.keys():
                _ = kwargs.pop(key)

        # self.assets = list(self.dataset.assets)

        if self.data_master:
            try:
                assert self.dataset is not None

            except AssertionError:
                msg = 'Dataset instance should be provided for data_master environment.'
                self.log.error(msg)
                raise ValueError(msg)

            # Append logging:
            self.dataset.set_logger(self.log_level, self.task)

            # Update params -2: pull from dataset, remove used kwargs:
            self.params['dataset'].update(self.dataset.params)
            for key in self.params['dataset'].keys():
                if key in kwargs.keys():
                    _ = kwargs.pop(key)

        # Connect/Start data server (and get dataset statistic):
        self.log.info('Connecting data_server...')
        self._start_data_server()
        self.log.info('...done.')
        # After starting data-server we have self.assets attribute, dataset statistics etc. filled.

        # Define observation space shape, minimum / maximum values and agent action space.
        # Retrieve values from configured engine or...

        # ...Update params -4:
        # Pull strategy defaults to environment params dict :
        for t_key, t_value in self.engine.strats[0][0][0].params._gettuple():
            self.params['strategy'][t_key] = t_value

        # Update it with values from strategy 'passed-to params':
        for key, value in self.engine.strats[0][0][2].items():
            self.params['strategy'][key] = value

        self.asset_names = self.params['strategy']['asset_names']
        self.cash_name = self.params['strategy']['cash_name']

        self.params['strategy']['initial_action'] = self.get_initial_action()
        self.params['strategy'][
            'initial_portfolio_action'] = self.get_initial_action()

        self.server_actions = {
            name: self.params['strategy']['portfolio_actions']
            for name in self.asset_names
        }

        try:
            assert set(self.asset_names).issubset(set(self.data_lines_names))

        except AssertionError:
            msg = 'Assets names should be subset of data_lines names, but got: assets: {}, data_lines: {}'.format(
                set(self.asset_names), set(self.data_lines_names))
            self.log.error(msg)
            raise ValueError(msg)

        try:
            assert self.params['strategy']['portfolio_actions'] is None

        except AssertionError:
            self.log.debug(
                'For continuous action space strategy.params[`portfolio_actions`] should be `None`, corrected.'
            )
            self.params['strategy']['portfolio_actions'] = None

        # ... Push it all back (don't ask):
        for key, value in self.params['strategy'].items():
            self.engine.strats[0][0][2][key] = value

        # For 'raw_state' min/max values,
        # the only way is to infer from raw Dataset price values (we already got those from data_server):
        if 'raw_state' in self.params['strategy']['state_shape'].keys():
            # Exclude 'volume' from columns we count:
            self.dataset_columns.remove('volume')

            # print(self.params['strategy'])
            # print('self.engine.strats[0][0][2]:', self.engine.strats[0][0][2])
            # print('self.engine.strats[0][0][0].params:', self.engine.strats[0][0][0].params._gettuple())

            # Override with absolute price min and max values:
            self.params['strategy']['state_shape']['raw_state'].low = \
                self.engine.strats[0][0][2]['state_shape']['raw_state'].low = \
                np.zeros(self.params['strategy']['state_shape']['raw_state'].shape) + \
                self.dataset_stat.loc['min', self.dataset_columns].min()

            self.params['strategy']['state_shape']['raw_state'].high = \
                self.engine.strats[0][0][2]['state_shape']['raw_state'].high = \
                np.zeros(self.params['strategy']['state_shape']['raw_state'].shape) + \
                self.dataset_stat.loc['max', self.dataset_columns].max()

            self.log.info(
                'Inferring `state_raw` high/low values from dataset: {:.6f} / {:.6f}.'
                .format(
                    self.dataset_stat.loc['min', self.dataset_columns].min(),
                    self.dataset_stat.loc['max', self.dataset_columns].max()))

        # Set observation space shape from engine/strategy parameters:
        self.observation_space = DictSpace(
            self.params['strategy']['state_shape'])

        self.log.debug('Obs. shape: {}'.format(self.observation_space.spaces))

        # Set action space and corresponding server messages:
        self.action_space = ActionDictSpace(
            base_actions=self.params['strategy']['portfolio_actions'],  # None
            assets=list(self.asset_names) + [self.cash_name])

        self.log.debug('Act. space shape: {}'.format(self.action_space.spaces))

        # Finally:
        self.server_response = None
        self.env_response = None

        # if not self.data_master:
        self._start_server()
        self.closed = False

        self.log.info('Environment is ready.')

    def get_initial_action(self):
        action = {asset: np.asarray([0.0]) for asset in self.asset_names}
        action[self.cash_name] = np.asarray([1.0])
        return action

    def step(self, action):
        """
        Implementation of OpenAI Gym env.step() method.
        Makes a step in the environment.

        Args:
            action:     int or dict, action compatible to env.action_space

        Returns:
            tuple (Observation, Reward, Info, Done)

        """
        # Are you in the list, ready to go and all that?
        if self.action_space.contains(action) \
                and not self._closed \
                and (self.socket is not None) \
                and not self.socket.closed:
            pass

        else:
            msg = ('\nAt least one of these is true:\n' +
                   'Action error: (space is {}, action sent is {}): {}\n' +
                   'Environment closed: {}\n' +
                   'Network error [socket does not exist or is closed]: {}\n' +
                   'Hint: forgot to call reset()?').format(
                       self.action_space,
                       action,
                       not self.action_space.contains(action),
                       self._closed,
                       not self.socket or self.socket.closed,
                   )
            self.log.exception(msg)
            raise AssertionError(msg)

        # print('step: ', action, action_as_dict)
        env_response = self._comm_with_timeout(socket=self.socket,
                                               message={'action': action})
        if env_response['status'] != 'ok':
            msg = '.step(): server unreachable with status: <{}>.'.format(
                env_response['status'])
            self.log.error(msg)
            raise ConnectionError(msg)

        self.env_response = env_response['message']

        return self.env_response
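
As the class docstring above explains, a valid action is one weight per risky asset plus cash, each in [0, 1] and summing to 1. The sketch below shows one illustrative way to turn raw agent outputs into such a dictionary via a softmax; this helper is not part of the environment itself, and the asset names are taken from the docstring example:

import numpy as np

def to_portfolio_action(raw_logits, asset_names, cash_name='eur'):
    names = [cash_name] + list(asset_names)
    weights = np.exp(raw_logits - np.max(raw_logits))   # numerically stable softmax
    weights /= weights.sum()
    return {name: np.asarray([w]) for name, w in zip(names, weights)}

action = to_portfolio_action(
    raw_logits=np.array([0.5, 0.1, 0.1, 0.2, 0.3]),
    asset_names=['chf', 'gbp', 'jpy', 'usd'],
)
assert abs(sum(v[0] for v in action.values()) - 1.0) < 1e-9   # weights sum to one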
Example #20
0
class Casca(Bot):
    def __init__(self, *args, **kwargs):
        config_file = os.path.join(os.getcwd(), "config.yaml")

        with open(config_file) as f:
            self.config = yaml.safe_load(f)

        super().__init__(*args, **kwargs)

        # Define the logging set up.
        redirect_logging()
        StreamHandler(sys.stderr).push_application()

        self.logger = Logger("Casca_Best_Bot")
        self.logger.level = getattr(logbook,
                                    self.config.get("log_level", "INFO"),
                                    logbook.INFO)

        # Set the root logger level, too.
        logging.root.setLevel(self.logger.level)

        self._loaded = False

    async def on_ready(self):
        if self._loaded:
            return

        self.logger.info(
            "LOADED Casca | LOGGED IN AS: {0.user.name}#{0.user.discriminator}.\n----------------------------------------------------------------------------------------------------"
            .format(self))

        for cog in extensions:
            try:
                self.load_extension(cog)
            except Exception as e:
                self.logger.critical(
                    "Could not load extension `{}` -> `{}`".format(cog, e))
                self.logger.exception()
            else:
                self.logger.info("Loaded extension {}.".format(cog))

        self._loaded = True

    async def on_message(self, message):
        if not message.server:
            return

        if message.server.id not in Whitelisted_Servers:
            return

        if message.channel.id not in Whitelisted_Channels:
            return

        self.logger.info("MESSAGE: {message.content}".format(
            message=message, bot=" [BOT]" if message.author.bot else ""))
        self.logger.info("FROM: {message.author.name}".format(message=message))

        if message.server is not None:
            self.logger.info(
                "CHANNEL: {message.channel.name}".format(message=message))
            self.logger.info(
                "SERVER: {0.server.name}\n----------------------------------------------------------------------------------------------------"
                .format(message))

        await super().on_message(message)

    async def on_command_error(self, e, ctx):
        if isinstance(e, (commands.errors.BadArgument,
                          commands.errors.MissingRequiredArgument)):
            await self.send_message(ctx.message.channel,
                                    "```ERROR: {}```".format(' '.join(e.args)))
            return

    async def on_command(self, command, ctx):
        await self.delete_message(ctx.message)

    def run(self):
        try:
            super().run(self.config["bot"]["token"], bot=True)
        except discord.errors.LoginFailure as e:
            self.logger.error("LOGIN FAILURE: {}".format(e.args[0]))
            sys.exit(2)
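
The bot above reads a config.yaml from the current working directory and uses at least config['bot']['token'] plus an optional top-level log_level. A minimal hypothetical config, written out and read back with yaml.safe_load as in __init__ (the token is a placeholder, not a real value):

import yaml

example_config = {
    "bot": {"token": "YOUR_DISCORD_TOKEN"},   # placeholder, not a real token
    "log_level": "INFO",
}
with open("config.yaml", "w") as f:
    yaml.safe_dump(example_config, f)

with open("config.yaml") as f:
    config = yaml.safe_load(f)
assert config["bot"]["token"] == "YOUR_DISCORD_TOKEN"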
Example #21
0
File: twitter.py Project: m00n/mutc
class TwitterThread(QThread):
    newTweets = pyqtSignal(object, object)

    RATE_CHECK_INTERVAL = 60 * 10

    def __init__(self, parent, subscriptions, limit_config):
        QThread.__init__(self, parent)

        self.subscriptions = subscriptions
        self.limit_config = limit_config

        self.ticks = 1
        self.tick_count = 60

        self.running = True
        self.force_check = threading.Event()

        self.logger = Logger("twitter-thread")
        self.rate_logger = Logger("twitter-limits")

        #
        self.last_rate_check = time()
        self.ticks_for_account = {}
        self.tick_counter = {}

    def run(self):
        while self.running:
            self.check_intervals()

            if time() - self.last_rate_check > self.RATE_CHECK_INTERVAL:
                self.rate_logger.info("Recalculating ticks")
                self.calc_rates()

            sleep(self.ticks)

    def check_intervals(self):
        subscriptions = self.get_subscriptions()
        accounts = set(subscription.account for subscription in subscriptions
                        if subscription.account.me)

        for account in accounts:
            #__rticks = self.tick_counter.get(account)
            #if __rticks and __rticks % 5 == 0:
                #print >>sys.stderr, account, self.tick_counter.get(account)

            if account not in self.tick_counter:
                self.calc_rates()
                self.tick_counter[account] = 1 # force checking

            #print account, self.tick_counter[account]
            self.tick_counter[account] -= 1
            if self.tick_counter[account] == 0:
                self.tick_counter[account] = self.ticks_for_account[account]
                self.check_subscriptions(account)

    def get_subscriptions(self):
        with self.subscriptions:
            return self.subscriptions.values()

    def calc_rates(self):
        subscriptions = self.get_subscriptions()
        calls_per_account = defaultdict(int)

        for subscription in subscriptions:
            calls_per_account[subscription.account] += subscription.calls

        for account, calls in calls_per_account.items():
            if account.me:
                rate_info = safe_api_request(account.api.rate_limit_status)
                ticks = calc_ticks(rate_info, calls, **self.limit_config)
                self.ticks_for_account[account] = ticks
                self.tick_counter[account] = ticks

                self.rate_logger.debug(
                    "{0}; calls: {1}({2}); ticks: {3}",
                    repr(account.me.screen_name),
                    calls,
                    rate_info["remaining_hits"],
                    ticks
                )

        self.last_rate_check = time()

    def check_subscriptions(self, account=None):
        subscriptions = self.get_subscriptions()

        self.logger.debug("Checking {0} subscriptions", len(subscriptions))

        for subscription in subscriptions:
            if account and subscription.account != account:
                continue

            if subscription.account.api:
                try:
                    tweets = subscription.update()
                except tweepy.TweepError as error:
                    self.logger.exception("Error while fetching tweets")
                except Exception as exc:
                    self.logger.exception("Unexpected exception")
                else:
                    if tweets:
                        self.logger.debug("{0} new tweets for {1}/{2}",
                            len(tweets),
                            subscription.account,
                            subscription.subscription_type
                        )
                        self.newTweets.emit(subscription, tweets)

    def stepped_sleep(self):
        for x in xrange(self.tick_count):
            sleep(self.ticks)
            if self.force_check.is_set():
                self.force_check.clear()
                break
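
`calc_ticks` and `safe_api_request` are imported from elsewhere in the mutc project and are not shown above. Purely as an illustration of the idea (the real helper may differ), a minimal `calc_ticks` could spread the remaining rate-limit budget over the rest of the window:

from time import time

def calc_ticks(rate_info, calls, min_ticks=60, max_ticks=900, reserve=0.1):
    # Hypothetical sketch: number of one-second ticks to wait between checks
    # for an account, so that `calls` API requests per check stay within the
    # remaining budget reported by rate_limit_status().
    remaining = max(int(rate_info["remaining_hits"] * (1.0 - reserve)), 1)
    # Assumes the rate-limit payload also carries an epoch reset timestamp:
    window = max(rate_info.get("reset_time_in_seconds", time() + 900) - time(), 1.0)
    interval = window * calls / float(remaining)
    return int(min(max(interval, min_ticks), max_ticks))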
Example #22
0
                if yield_trace:
                    yield line
                if not line and f.tell() == num_lines:
                    break

    def _parse_data(self, tracepoint, data):
        """
        Parse payload(data) for tracepoint - if we have it.
        """
        rv = data
        try:
            rv = PARSERS[tracepoint](data)
        except ParserError as e:
            # Specific parser failure: log it and fall back to the raw payload.
            log.exception(e)
            log.warn('Error parsing {tp} with {data}'.format(tp=tracepoint, data=data))
        except Exception:
            # No parser registered for this tracepoint (or it failed unexpectedly);
            # keep the raw payload as-is.
            pass
        finally:
            return rv if rv else data

    def _check_tracer(self, line):
        """
        Return tracer (typically 'nop')
        """
        match = re.match(self._TRACER_PATTERN, line.strip())
        if match:
            return match.groupdict()['tracer']
        return None

    def _check_buffer_entries(self, line):
        """
Example #23
0
    :return int: The port in which the slave will listen.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--file", dest='ips_source_file_name', type=str,
                        help='The name of the file which will '
                             'contain the list of ips (can be relative or absolute path)', required=True)
    parser.add_argument("--divide", dest='divide_ips', action='store_true',
                        help='A flag which indicates whether an ips range should be a single job for one '
                             'slave or should the master divide the range to single ips and each ip will '
                             'be considered a job for a slave', default=False)
    parser.add_argument("--flags", dest='flags', type=str,
                        help='The flags of the script (see the README.md file)', required=True)

    parser.add_argument("-p", "--ports", dest='ports', type=str,
                        help='The ports which will be scanned for the ips.', required=True)

    return parser.parse_args()


if __name__ == '__main__':
    try:
        arguments = parse_arguments()
        ips_to_scan = _retrieve_ips_to_scan(arguments.ips_source_file_name, divide_ips=arguments.divide_ips)
        start_master(ips_to_scan=ips_to_scan, flags=arguments.flags, ports=arguments.ports)
    except (KeyboardInterrupt, SystemExit):
        # We want to be able to abort the running of the code without a strange log :)
        raise
    except Exception:
        logger.exception()
        raise
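
The module-level `logger` used in the `except Exception` branch is defined elsewhere in this project; a minimal sketch of how it might be set up with logbook (the logger name and handler choice are assumptions):

import sys
from logbook import Logger, StreamHandler

# Route log records to stdout for the whole process, then create the
# module-level logger referenced by logger.exception() above:
StreamHandler(sys.stdout).push_application()
logger = Logger('master')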
Example #24
0
class MLDG():
    """
    Asynchronous implementation of MLDG algorithm
    for continuous adaptation in dynamically changing environments.

    Papers:
        Da Li et al.,
         "Learning to Generalize: Meta-Learning for Domain Generalization"
         https://arxiv.org/abs/1710.03463

        Maruan Al-Shedivat et al.,
        "Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments"
        https://arxiv.org/abs/1710.03641

    """
    def __init__(
            self,
            env,
            task,
            log_level,
            aac_class_ref=SubAAC,
            runner_config=None,
            aac_lambda=1.0,
            guided_lambda=1.0,
            rollout_length=20,
            train_support=300,
            fast_adapt_num_steps=10,
            fast_adapt_batch_size=32,
            trial_source_target_cycle=(1, 0),
            num_episodes_per_trial=1,  # one-shot adaptation
            _aux_render_modes=('action_prob', 'value_fn', 'lstm_1_h', 'lstm_2_h'),
            name='MLDG',
            **kwargs
    ):
        try:
            self.aac_class_ref = aac_class_ref
            self.task = task
            self.name = name
            self.summary_writer = None

            StreamHandler(sys.stdout).push_application()
            self.log = Logger('{}_{}'.format(name, task), level=log_level)

            self.rollout_length = rollout_length
            self.train_support = train_support  # number of train experiences to collect
            self.train_batch_size = int(self.train_support / self.rollout_length)
            self.fast_adapt_num_steps = fast_adapt_num_steps
            self.fast_adapt_batch_size = fast_adapt_batch_size

            if runner_config is None:
                self.runner_config = {
                    'class_ref': BaseSynchroRunner,
                    'kwargs': {},
                }
            else:
                self.runner_config = runner_config

            self.env_list = env

            assert isinstance(self.env_list, list) and len(self.env_list) == 2, \
                'Expected pair of environments, got: {}'.format(self.env_list)

            # Instantiate two sub-trainers: one for test and one for train environments:

            self.runner_config['kwargs']['data_sample_config'] = {'mode': 1}  # master
            self.runner_config['kwargs']['name'] = 'master'

            self.train_aac = aac_class_ref(
                env=self.env_list[0],  # train data will be master environment TODO: really dumb data control. improve.
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                rollout_length=self.rollout_length,
                trial_source_target_cycle=trial_source_target_cycle,
                num_episodes_per_trial=num_episodes_per_trial,
                _use_target_policy=False,
                _use_global_network=True,
                _aux_render_modes=_aux_render_modes,
                name=self.name + '_sub_Train',
                **kwargs
            )

            self.runner_config['kwargs']['data_sample_config'] = {'mode': 0}  # master
            self.runner_config['kwargs']['name'] = 'slave'

            self.test_aac = aac_class_ref(
                env=self.env_list[-1],  # test data -> slave env.
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                rollout_length=self.rollout_length,
                trial_source_target_cycle=trial_source_target_cycle,
                num_episodes_per_trial=num_episodes_per_trial,
                _use_target_policy=False,
                _use_global_network=False,
                global_step_op=self.train_aac.global_step,
                global_episode_op=self.train_aac.global_episode,
                inc_episode_op=self.train_aac.inc_episode,
                _aux_render_modes=_aux_render_modes,
                name=self.name + '_sub_Test',
                **kwargs
            )

            self.local_steps = self.train_aac.local_steps
            self.model_summary_freq = self.train_aac.model_summary_freq
            #self.model_summary_op = self.train_aac.model_summary_op

            self._make_train_op()
            self.test_aac.model_summary_op = tf.summary.merge(
                [self.test_aac.model_summary_op, self._combine_meta_summaries()],
                name='meta_model_summary'
            )

        except:
            msg = 'MLDG.__init__() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)

    def _make_train_op(self):
        """

        Defines:
            tensors holding training op graph for sub trainers and self;
        """
        pi = self.train_aac.local_network
        pi_prime = self.test_aac.local_network

        self.test_aac.sync = self.test_aac.sync_pi = tf.group(
            *[v1.assign(v2) for v1, v2 in zip(pi_prime.var_list, pi.var_list)]
        )

        self.global_step = self.train_aac.global_step
        self.global_episode = self.train_aac.global_episode

        self.test_aac.global_step = self.train_aac.global_step
        self.test_aac.global_episode = self.train_aac.global_episode
        self.test_aac.inc_episode = self.train_aac.inc_episode
        self.train_aac.inc_episode = None
        self.inc_step = self.train_aac.inc_step

        # Meta-loss:
        self.loss = 0.5 * self.train_aac.loss + 0.5 * self.test_aac.loss

        # Clipped gradients:
        self.train_aac.grads, _ = tf.clip_by_global_norm(
            tf.gradients(self.train_aac.loss, pi.var_list),
            40.0
        )
        self.log.warning('self.train_aac.grads: {}'.format(len(list(self.train_aac.grads))))

        # self.test_aac.grads, _ = tf.clip_by_global_norm(
        #     tf.gradients(self.test_aac.loss, pi_prime.var_list),
        #     40.0
        # )
        # Meta-gradient:
        grads_i, _ = tf.clip_by_global_norm(
            tf.gradients(self.train_aac.loss, pi.var_list),
            40.0
        )

        grads_i_next, _ = tf.clip_by_global_norm(
            tf.gradients(self.test_aac.loss, pi_prime.var_list),
            40.0
        )

        self.grads = []
        for g1, g2 in zip(grads_i, grads_i_next):
            if g1 is not None and g2 is not None:
                meta_g = 0.5 * g1 + 0.5 * g2
            else:
                meta_g = None

            self.grads.append(meta_g)

        #self.log.warning('self.grads_len: {}'.format(len(list(self.grads))))

        # Gradients to update local copy of pi_prime (from train data):
        train_grads_and_vars = list(zip(self.train_aac.grads, pi_prime.var_list))

        # self.log.warning('train_grads_and_vars_len: {}'.format(len(train_grads_and_vars)))

        # Meta-gradients to be sent to parameter server:
        meta_grads_and_vars = list(zip(self.grads, self.train_aac.network.var_list))

        # self.log.warning('meta_grads_and_vars_len: {}'.format(len(meta_grads_and_vars)))

        # Set global_step increment equal to observation space batch size:
        obs_space_keys = list(self.train_aac.local_network.on_state_in.keys())

        assert 'external' in obs_space_keys, \
            'Expected observation space to contain `external` mode, got: {}'.format(obs_space_keys)
        self.train_aac.inc_step = self.train_aac.global_step.assign_add(
            tf.shape(self.train_aac.local_network.on_state_in['external'])[0]
        )

        self.train_op = self.train_aac.optimizer.apply_gradients(train_grads_and_vars)

        # Optimizer for meta-update:
        self.optimizer = tf.train.AdamOptimizer(self.train_aac.train_learn_rate, epsilon=1e-5)
        # TODO: own alpha learning rate
        self.meta_train_op = self.optimizer.apply_gradients(meta_grads_and_vars)

        self.log.debug('meta_train_op defined')

    def _combine_meta_summaries(self):

        meta_model_summaries = [
            tf.summary.scalar("meta_grad_global_norm", tf.global_norm(self.grads)),
            tf.summary.scalar("total_meta_loss", self.loss),
        ]

        return meta_model_summaries

    def start(self, sess, summary_writer, **kwargs):
        """
        Executes all initializing operations,
        starts environment runner[s].
        Supposed to be called by parent worker just before training loop starts.

        Args:
            sess:           tf session object.
            kwargs:         not used by default.
        """
        try:
            # Copy weights from global to local:
            sess.run(self.train_aac.sync_pi)
            sess.run(self.test_aac.sync_pi)

            # Start thread_runners:
            self.train_aac._start_runners(   # master first
                sess,
                summary_writer,
                init_context=None,
                data_sample_config=self.train_aac.get_sample_config(mode=1)
            )
            self.test_aac._start_runners(
                sess,
                summary_writer,
                init_context=None,
                data_sample_config=self.test_aac.get_sample_config(mode=0)
            )

            self.summary_writer = summary_writer
            self.log.notice('Runners started.')

        except:
            msg = 'start() exception occurred' + \
                '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)

    def fast_adapt_step(self, sess, batch_size, on_policy_batch, off_policy_batch, rp_batch, make_summary=False):
        """
        One step of test_policy adaptation.

        Args:
            sess:                   tensorflow.Session obj.
            batch_size:             train mini-batch size
            on_policy_batch:        `on_policy` train data
            off_policy_batch:       `off_policy` train data or None
            rp_batch:               'reward_prediction` train data or None
            make_summary:           bool, if True - compute model summary

        Returns:
            model summary or None
        """
        # Sample from train distribution:
        on_mini_batch = self.train_aac.sample_batch(on_policy_batch, batch_size)
        off_mini_batch = self.train_aac.sample_batch(off_policy_batch, batch_size)
        rp_mini_batch = self.train_aac.sample_batch(rp_batch, batch_size)

        feed_dict = self.train_aac._get_main_feeder(sess, on_mini_batch, off_mini_batch, rp_mini_batch, True)

        if make_summary:
            fetches = [self.train_op, self.train_aac.model_summary_op]
        else:
            fetches = [self.train_op]

        # Update pi_prime parameters wrt sampled data:
        fetched = sess.run(fetches, feed_dict=feed_dict)

        # self.log.warning('Train gradients ok.')

        if make_summary:
            summary = fetched[-1]

        else:
            summary = None

        return summary

    def train_step(self, sess, data_config):
        """
        Collects train task data and updates test policy parameters (fast adaptation).

        Args:
            sess:                   tensorflow.Session obj.
            data_config:            configuration dictionary of type `btgym.datafeed.base.EnvResetConfig`

        Returns:
            batched train data

        """
        # Collect train distribution:
        train_batch = self.train_aac.get_batch(
            size=self.train_batch_size,
            require_terminal=True,
            same_trial=True,
            data_sample_config=data_config
        )

        # for rollout in train_batch['on_policy']:
        #     self.log.warning(
        #         'Train data trial_num: {}'.format(
        #             np.asarray(rollout['state']['metadata']['trial_num'])
        #         )
        #     )

        # Process time-flat-alike (~iid) to treat as empirical data distribution over train task:
        on_policy_batch, off_policy_batch, rp_batch = self.train_aac.process_batch(sess, train_batch)

        # self.log.warning('Train data ok.')

        local_step = sess.run(self.global_step)
        local_episode = sess.run(self.global_episode)
        model_summary = None

        # Extract all non-empty summaries:
        ep_summary = [summary for summary in train_batch['ep_summary'] if summary is not None]

        # Perform number of test policy updates wrt. collected train data:
        for i in range(self.fast_adapt_num_steps):
            model_summary = self.fast_adapt_step(
                sess,
                batch_size=self.fast_adapt_batch_size,
                on_policy_batch=on_policy_batch,
                off_policy_batch=off_policy_batch,
                rp_batch=rp_batch,
                make_summary=(local_step + i) % self.model_summary_freq == 0
            )
            # self.log.warning('Batch {} Train gradients ok.'.format(i))

            # Write down summaries:
            train_summary = dict(
                render_summary=[None],
                test_ep_summary=[None],
                ep_summary=[ep_summary.pop() if len(ep_summary) > 0 else None]
            )
            self.train_aac.process_summary(
                sess,
                train_summary,
                model_summary,
                step=local_step + i,
                episode=local_episode + i
            )

        return on_policy_batch, off_policy_batch, rp_batch

    def meta_train_step(self, sess, data_config, on_policy_batch, off_policy_batch, rp_batch):
        """
        Collects data from source domain test task and performs meta-update to shared parameters vector.
        Writes down relevant summaries.

        Args:
            sess:                   tensorflow.Session obj.
            data_config:            configuration dictionary of type `btgym.datafeed.base.EnvResetConfig`
            on_policy_batch:        `on_policy` train data
            off_policy_batch:       `off_policy` train data or None
            rp_batch:               'reward_prediction` train data or None

        """
        done = False
        while not done:
            # Say `No` to redundant summaries:
            write_model_summary = \
                self.local_steps % self.model_summary_freq == 0

            # Collect test trajectory wrt updated test_policy parameters:
            test_data = self.test_aac.get_data(
                init_context=None,
                data_sample_config=data_config
            )
            test_batch_size = 0  # TODO: adjust on/off/rp sizes
            for rollout in test_data['on_policy']:
                test_batch_size += len(rollout['position'])

            test_feed_dict = self.test_aac.process_data(sess, test_data, is_train=True)

            # self.log.warning('Test data rollout for step {} ok.'.format(self.local_steps))
            #
            # self.log.warning(
            #     'Test data trial_num: {}'.format(
            #         np.asarray(test_data['on_policy'][0]['state']['metadata']['trial_num'])
            #     )
            # )

            # Sample train data of same size:
            feed_dict = self.train_aac._get_main_feeder(
                sess,
                self.train_aac.sample_batch(on_policy_batch, test_batch_size),
                self.train_aac.sample_batch(off_policy_batch, test_batch_size),
                self.train_aac.sample_batch(rp_batch, test_batch_size),
                True
            )
            # Add test trajectory:
            feed_dict.update(test_feed_dict)

            # Perform meta-update:
            if write_model_summary:
                meta_fetches = [self.meta_train_op, self.test_aac.model_summary_op, self.inc_step]
            else:
                meta_fetches = [self.meta_train_op, self.inc_step]

            meta_fetched = sess.run(meta_fetches, feed_dict=feed_dict)

            # self.log.warning('Meta-gradients ok.')

            if write_model_summary:
                meta_model_summary = meta_fetched[-2]

            else:
                meta_model_summary = None

            # Write down summaries:
            self.test_aac.process_summary(sess, test_data, meta_model_summary)
            self.local_steps += 1

            # Has the test episode ended?
            done = np.asarray(test_data['terminal']).any()

    def meta_test_step(self, sess, data_config, on_policy_batch, off_policy_batch, rp_batch):
        """
        Validates adapted policy on data from target domain test task.
        Writes down relevant summaries.

        Args:
            sess:                   tensorflow.Session obj.
            data_config:            configuration dictionary of type `btgym.datafeed.base.EnvResetConfig`
            on_policy_batch:        `on_policy` train data
            off_policy_batch:       `off_policy` train data or None
            rp_batch:               'reward_prediction` train data or None

        """
        done = False
        while not done:
            # Collect test trajectory:
            test_data = self.test_aac.get_data(
                init_context=None,
                data_sample_config=data_config
            )

            # self.log.warning('Target test rollout ok.')
            # self.log.warning(
            #     'Test data target trial_num: {}'.format(
            #         np.asarray(test_data['on_policy'][0]['state']['metadata']['trial_num'])
            #     )
            # )
            # self.log.warning('target_render_ep_summary: {}'.format(test_data['render_summary']))

            # Write down summaries:
            self.test_aac.process_summary(sess, test_data)

            # Has the test episode ended?
            done = np.asarray(test_data['terminal']).any()

    def process(self, sess):
        """
        Meta-train procedure for one-shot learning.

        Args:
            sess (tensorflow.Session):   tf session obj.

        """
        try:
            # Copy from parameter server:
            sess.run(self.train_aac.sync_pi)
            sess.run(self.test_aac.sync_pi)

            #self.log.warning('Sync ok.')

            # Decide on data configuration for train/test trajectories,
            # such as all data will come from same trial (maybe different episodes)
            # and trial type as well (~from source or target domain):
            # note: data_config counters get updated once per process() call
            train_data_config = self.train_aac.get_sample_config(mode=1)  # master env., draws trial
            test_data_config = self.train_aac.get_sample_config(mode=0)   # slave env, catches up with same trial

            # If data comes from source or target domain:
            is_target = train_data_config['trial_config']['sample_type']

            # self.log.warning('PROCESS_train_data_config: {}'.format(train_data_config))
            # self.log.warning('PROCESS_test_data_config: {}'.format(test_data_config))

            # Fast adaptation step:
            # collect train trajectories, process time-flat-alike (~iid) to treat as empirical data distribution
            # over train task and adapt test_policy wrt. train experience:
            on_policy_batch, off_policy_batch, rp_batch = self.train_step(sess, train_data_config)

            # Slow adaptation step:
            if is_target:
                # Meta-test:
                # self.log.warning('Running meta-test episode...')
                self.meta_test_step(sess, test_data_config, on_policy_batch, off_policy_batch, rp_batch)

            else:
                # Meta-train:
                # self.log.warning('Running meta-train episode...')
                self.meta_train_step(sess, test_data_config, on_policy_batch, off_policy_batch, rp_batch)

        except:
            msg = 'process() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)
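
The meta-gradient above is a plain element-wise average of the clipped train and test gradients (`meta_g = 0.5 * g1 + 0.5 * g2`, with `None` entries preserved). The same combination rule in isolation, as a small NumPy sketch with illustrative names:

import numpy as np

def combine_meta_grads(grads_train, grads_test, alpha=0.5):
    # Element-wise convex combination of per-variable gradients; None entries
    # (variables with no gradient) stay None, mirroring _make_train_op() above.
    combined = []
    for g1, g2 in zip(grads_train, grads_test):
        if g1 is None or g2 is None:
            combined.append(None)
        else:
            combined.append(alpha * g1 + (1.0 - alpha) * g2)
    return combined

# Two variables: the first has gradients from both losses, the second from neither.
g_train = [np.array([0.2, -0.4]), None]
g_test = [np.array([0.6, 0.0]), None]
print(combine_meta_grads(g_train, g_test))  # -> [array([ 0.4, -0.2]), None]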
Example #25
0
class Worker(multiprocessing.Process):
    """
    Distributed tf worker class.

    Sets up environment, trainer and starts training process in supervised session.
    """
    env_list = None

    def __init__(self,
                 env_config,
                 policy_config,
                 trainer_config,
                 cluster_spec,
                 job_name,
                 task,
                 log_dir,
                 log_level,
                 max_env_steps,
                 random_seed=None,
                 render_last_env=False,
                 test_mode=False):
        """

        Args:
            env_config:         environment class_config_dict.
            policy_config:      model policy estimator class_config_dict.
            trainer_config:     algorithm class_config_dict.
            cluster_spec:       tf.cluster specification.
            job_name:           worker or parameter server.
            task:               integer number, 0 is chief worker.
            log_dir:            for tb summaries and checkpoints.
            log_level:          int, logbook.level
            max_env_steps:      number of environment steps to run training on
            random_seed:        int or None
            render_last_env:    bool, if True - render enabled for last environment in a list; first otherwise
            test_mode:          if True - use Atari mode, BTGym otherwise.

            Note:
                - Conventional `self.global_step` refers to number of environment steps,
                    summarized over all environment instances, not to number of policy optimizer train steps.

                - Every worker can run several environments in parallel, as specified by `cluster_config['num_envs']`.
                    If you use 4 workers and num_envs=4 => total number of environments is 16. Every env instance has
                    its own ThreadRunner process.

                - When using replay memory, keep in mind that every ThreadRunner keeps its own replay memory,
                    so if memory_size = 2000, num_workers=4, num_envs=4 => total replay memory size equals 32 000 frames.
        """
        super(Worker, self).__init__()
        self.env_class = env_config['class_ref']
        self.env_kwargs = env_config['kwargs']
        self.policy_config = policy_config
        self.trainer_class = trainer_config['class_ref']
        self.trainer_kwargs = trainer_config['kwargs']
        self.cluster_spec = cluster_spec
        self.job_name = job_name
        self.task = task
        self.log_dir = log_dir
        self.max_env_steps = max_env_steps
        self.log_level = log_level
        self.log = None
        self.test_mode = test_mode
        self.random_seed = random_seed
        self.render_last_env = render_last_env

    def run(self):
        """Worker runtime body.
        """
        # Logging:
        StreamHandler(sys.stdout).push_application()
        self.log = Logger('Worker_{}'.format(self.task), level=self.log_level)

        tf.reset_default_graph()

        if self.test_mode:
            import gym

        # Define cluster:
        cluster = tf.train.ClusterSpec(self.cluster_spec).as_cluster_def()

        # Start tf.server:
        if self.job_name in 'ps':
            server = tf.train.Server(
                cluster,
                job_name=self.job_name,
                task_index=self.task,
                config=tf.ConfigProto(device_filters=["/job:ps"])
            )
            self.log.debug('parameters_server started.')
            # Just block here:
            server.join()

        else:
            server = tf.train.Server(
                cluster,
                job_name='worker',
                task_index=self.task,
                config=tf.ConfigProto(
                    intra_op_parallelism_threads=1,  # original was: 1
                    inter_op_parallelism_threads=2  # original was: 2
                )
            )
            self.log.debug('tf.server started.')

            self.log.debug('making environments:')
            # Making as many environments as many entries in env_config `port` list:
            # TODO: Hacky-II: only one example over all parallel environments can be data-master [and renderer]
            # TODO: measure data_server lags, maybe launch several instances
            self.env_list = []
            env_kwargs = self.env_kwargs.copy()
            env_kwargs['log_level'] = self.log_level
            port_list = env_kwargs.pop('port')
            data_port_list = env_kwargs.pop('data_port')
            data_master = env_kwargs.pop('data_master')
            render_enabled = env_kwargs.pop('render_enabled')

            render_list = [False for entry in port_list]
            if render_enabled:
                if self.render_last_env:
                    render_list[-1] = True
                else:
                    render_list[0] = True

            data_master_list = [False for entry in port_list]
            if data_master:
                data_master_list[0] = True

            # Parallel envs. numbering:
            if len(port_list) > 1:
                task_id = 0.0
            else:
                task_id = 0

            for port, data_port, is_render, is_master in zip(port_list, data_port_list, render_list, data_master_list):
                # Get random seed for environments:
                env_kwargs['random_seed'] = random.randint(0, 2 ** 30)

                if not self.test_mode:
                    # Assume BTgym env. class:
                    self.log.debug('setting env at port_{} is data_master: {}'.format(port, data_master))
                    self.log.debug('env_kwargs:')
                    for k, v in env_kwargs.items():
                        self.log.debug('{}: {}'.format(k, v))
                    try:
                        self.env_list.append(
                            self.env_class(
                                port=port,
                                data_port=data_port,
                                data_master=is_master,
                                render_enabled=is_render,
                                task=self.task + task_id,
                                **env_kwargs
                            )
                        )
                        data_master = False
                        self.log.info('set BTGym environment {} @ port:{}, data_port:{}'.
                                      format(self.task + task_id, port, data_port))
                        task_id += 0.01

                    except:
                        self.log.exception(
                            'failed to make BTGym environment at port_{}.'.format(port)
                        )
                        raise RuntimeError

                else:
                    # Assume atari testing:
                    try:
                        self.env_list.append(self.env_class(env_kwargs['gym_id']))
                        self.log.debug('set Gym/Atari environment.')

                    except:
                        self.log.exception('failed to make Gym/Atari environment')
                        raise RuntimeError

            self.log.debug('Defining trainer...')

            # Define trainer:
            trainer = self.trainer_class(
                env=self.env_list,
                task=self.task,
                policy_config=self.policy_config,
                log_level=self.log_level,
                cluster_spec=self.cluster_spec,
                random_seed=self.random_seed,
                **self.trainer_kwargs,
            )

            self.log.debug('trainer ok.')

            # Saver-related:
            variables_to_save = [v for v in tf.global_variables() if not 'local' in v.name]
            local_variables = [v for v in tf.global_variables() if 'local' in v.name] + tf.local_variables()
            init_op = tf.variables_initializer(variables_to_save)
            local_init_op = tf.variables_initializer(local_variables)
            init_all_op = tf.global_variables_initializer()

            saver = _FastSaver(variables_to_save)

            self.log.debug('VARIABLES TO SAVE:')
            for v in variables_to_save:
                self.log.debug('{}: {}'.format(v.name, v.get_shape()))

            def init_fn(ses):
                self.log.info("initializing all parameters.")
                ses.run(init_all_op)

            config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(self.task)])
            logdir = os.path.join(self.log_dir, 'train')
            summary_dir = logdir + "_{}".format(self.task)

            summary_writer = tf.summary.FileWriter(summary_dir)

            self.log.debug('before tf.train.Supervisor... ')

            # TODO: switch to tf.train.MonitoredTrainingSession
            sv = tf.train.Supervisor(
                is_chief=(self.task == 0),
                logdir=logdir,
                saver=saver,
                summary_op=None,
                init_op=init_op,
                local_init_op=local_init_op,
                init_fn=init_fn,
                #ready_op=tf.report_uninitialized_variables(variables_to_save),
                ready_op=tf.report_uninitialized_variables(),
                global_step=trainer.global_step,
                save_model_secs=300,
            )
            self.log.info("connecting to the parameter server... ")

            with sv.managed_session(server.target, config=config) as sess, sess.as_default():
                #sess.run(trainer.sync)
                trainer.start(sess, summary_writer)

                # Note: `self.global_step` refers to number of environment steps
                # summarized over all environment instances, not to number of policy optimizer train steps.
                global_step = sess.run(trainer.global_step)
                self.log.notice("started training at step: {}".format(global_step))

                while not sv.should_stop() and global_step < self.max_env_steps:
                    trainer.process(sess)
                    global_step = sess.run(trainer.global_step)

                # Ask for all the services to stop:
                for env in self.env_list:
                    env.close()

                sv.stop()
            self.log.notice('reached {} steps, exiting.'.format(global_step))
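
The Note in the constructor docstring does the environment and replay-memory bookkeeping in prose; the same arithmetic, using the docstring's own numbers:

memory_size = 2000   # frames kept by each ThreadRunner's replay memory
num_workers = 4
num_envs = 4         # cluster_config['num_envs']

total_envs = num_workers * num_envs              # 16 environment instances
total_replay_frames = memory_size * total_envs   # 32 000 frames overall
print(total_envs, total_replay_frames)           # 16 32000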
Example #26
0
class RunnerThread(threading.Thread):
    """
    Async. framework code comes from OpenAI repository under MIT licence:
    https://github.com/openai/universe-starter-agent

    Despite the fact that BTgym is not a real-time environment [yet], the thread-runner approach is still here. From
    original `universe-starter-agent`:
    `...One of the key distinctions between a normal environment and a universe environment
    is that a universe environment is _real time_.  This means that there should be a thread
    that would constantly interact with the environment and tell it what to do.  This thread is here.`

    Another idea is to see ThreadRunner as all-in-one data provider, thus shaping data distribution
    fed to estimator from single place.
    So, replay memory is also here, as well as some service functions (collecting summary data).
    """
    def __init__(self,
                 env,
                 policy,
                 task,
                 rollout_length,
                 episode_summary_freq,
                 env_render_freq,
                 test,
                 ep_summary,
                 runner_fn_ref=BaseEnvRunnerFn,
                 memory_config=None,
                 log_level=WARNING,
                 **kwargs):
        """

        Args:
            env:                    environment instance
            policy:                 policy instance
            task:                   int
            rollout_length:         int
            episode_summary_freq:   int
            env_render_freq:        int
            test:                   Atari or BTgym
            ep_summary:             tf.summary
            runner_fn_ref:          callable defining runner execution logic
            memory_config:          replay memory configuration dictionary
            log_level:              int, logbook.level
        """
        threading.Thread.__init__(self)
        self.queue = queue.Queue(5)
        self.rollout_length = rollout_length
        self.env = env
        self.last_features = None
        self.policy = policy
        self.runner_fn_ref = runner_fn_ref
        self.daemon = True
        self.sess = None
        self.summary_writer = None
        self.episode_summary_freq = episode_summary_freq
        self.env_render_freq = env_render_freq
        self.task = task
        self.test = test
        self.ep_summary = ep_summary
        self.memory_config = memory_config
        self.log_level = log_level
        StreamHandler(sys.stdout).push_application()
        self.log = Logger('ThreadRunner_{}'.format(self.task),
                          level=self.log_level)

    def start_runner(self, sess, summary_writer, **kwargs):
        try:
            self.sess = sess
            self.summary_writer = summary_writer
            self.start()

        except:
            msg = 'start() exception occurred.\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError

    def run(self):
        """Just keep running."""
        try:
            with self.sess.as_default():
                self._run()

        except:
            msg = 'RunTime exception occurred.\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError

    def _run(self):
        rollout_provider = self.runner_fn_ref(
            self.sess, self.env, self.policy, self.task, self.rollout_length,
            self.summary_writer, self.episode_summary_freq,
            self.env_render_freq, self.test, self.ep_summary,
            self.memory_config, self.log)
        while True:
            # the timeout variable exists because apparently, if one worker dies, the other workers
            # won't die with it, unless the timeout is set to some large number.  This is an empirical
            # observation.

            self.queue.put(next(rollout_provider), timeout=600.0)
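
On the other side of that bounded queue, a trainer typically blocks on `queue.get()` to pull the next rollout; a minimal sketch of that consumer pattern (the actual btgym trainer code is not shown here):

import queue

def pull_rollout(runner, timeout=600.0):
    # Blocking fetch of the next rollout produced by a RunnerThread instance;
    # mirrors the producer's queue.put(..., timeout=600.0) above.
    try:
        return runner.queue.get(timeout=timeout)
    except queue.Empty:
        raise RuntimeError('runner produced no rollout within {}s'.format(timeout))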
Example #27
0
                if yield_trace:
                    yield line
                if not line and f.tell() == num_lines:
                    break

    def _parse_data(self, tracepoint, data):
        """
        Parse payload(data) for tracepoint - if we have it.
        """
        rv = data
        try:
            rv = PARSERS[tracepoint](data)
        except ParserError as e:
            # Specific parser failure: log it and fall back to the raw payload.
            log.exception(e)
            log.warn('Error parsing {tp} with {data}'.format(tp=tracepoint,
                                                             data=data))
        except Exception:
            # No parser registered for this tracepoint (or it failed unexpectedly);
            # keep the raw payload as-is.
            pass
        finally:
            return rv if rv else data

    def _check_tracer(self, line):
        """
        Return tracer (typically 'nop')
        """
        match = re.match(self._TRACER_PATTERN, line.strip())
        if match:
            return match.groupdict()['tracer']
        return None

    def _check_buffer_entries(self, line):
Example #28
0
class BTgymEnv(gym.Env):
    """
    OpenAI Gym API shell for Backtrader backtesting/trading library.
    """
    # Datafeed Server management:
    data_master = True
    data_network_address = 'tcp://127.0.0.1:'  # using localhost.
    data_port = 4999
    data_server = None
    data_server_pid = None
    data_context = None
    data_socket = None
    data_server_response = None

    # Dataset:
    dataset = None  # BTgymDataset instance.
    dataset_stat = None

    # Backtrader engine:
    engine = None  # bt.Cerebro subclass for server to execute.

    # Strategy:
    strategy = None  # strategy to use if no <engine> class been passed.

    # Server and network:
    server = None  # Server process.
    context = None  # ZMQ context.
    socket = None  # ZMQ socket, client side.
    port = 5500  # network port to use.
    network_address = 'tcp://127.0.0.1:'  # using localhost.
    ctrl_actions = ('_done', '_reset', '_stop', '_getstat', '_render'
                    )  # server control messages.
    server_response = None

    # Connection timeout:
    connect_timeout = 60  # server connection timeout in seconds.
    #connect_timeout_step = 0.01  # time between retries in seconds.

    # Rendering:
    render_enabled = True
    render_modes = [
        'human',
        'episode',
    ]
    # `episode` - plotted episode results.
    # `human` - raw_state observation in conventional human-readable format.
    #  <obs_space_key> - rendering of arbitrary state presented in observation_space with same key.

    renderer = None  # Rendering support.
    rendered_rgb = dict()  # Keep last rendered images for each mode.

    # Logging and id:
    log = None
    log_level = None  # logbook level: NOTICE, WARNING, INFO, DEBUG etc. or its integer equivalent;
    verbose = 0  # verbosity mode, valid only if no `log_level` arg has been provided:
    # 0 - WARNING, 1 - INFO, 2 - DEBUG.
    task = 0

    closed = True

    def __init__(self, **kwargs):
        """
        Keyword Args:

            filename=None (str, list):                      csv data file.
            **datafeed_args (any):                          any datafeed-related args, passed through to
                                                            default btgym.datafeed class.
            dataset=None (btgym.datafeed):                  BTgymDataDomain instance,
                                                            overrides `filename` or any other datafeed-related args.
            strategy=None (btgym.strategy):                 strategy to be used by `engine`, any subclass of
                                                            btgym.strategy.base.BTgymBaseStrategy
            engine=None (bt.Cerebro):                       environment simulation engine, any bt.Cerebro subclass,
                                                            overrides `strategy` arg.
            network_address=`tcp://127.0.0.1:` (str):       BTGym_server address.
            port=5500 (int):                                network port to use for server - API_shell communication.
            data_master=True (bool):                        let this environment control over data_server;
            data_network_address=`tcp://127.0.0.1:` (str):  data_server address.
            data_port=4999 (int):                           network port to use for server -- data_server communication.
            connect_timeout=60 (int):                       server connection timeout in seconds.
            render_enabled=True (bool):                     enable rendering for this environment;
            render_modes=['human', 'episode'] (list):       `episode` - plotted episode results;
                                                            `human` - raw_state observation.
            **render_args (any):                            any render-related args, passed through to renderer class.
            verbose=0 (int):                                verbosity mode, {0 - WARNING, 1 - INFO, 2 - DEBUG}
            log_level=None (int):                           logbook level {DEBUG=10, INFO=11, NOTICE=12, WARNING=13},
                                                            overrides `verbose` arg;
            log=None (logbook.Logger):                      external logbook logger,
                                                            overrides `log_level` and `verbose` args.
            task=0 (int):                                   environment id

        Environment kwargs applying logic::

            if <engine> kwarg is given:
                do not use default engine and strategy parameters;
                ignore <strategy> kwarg and all strategy and engine-related kwargs.

            else (no <engine>):
                use default engine parameters;
                if any engine-related kwarg is given:
                    override corresponding default parameter;

                if <strategy> is given:
                    do not use default strategy parameters;
                    if any strategy related kwarg is given:
                        override corresponding strategy parameter;

                else (no <strategy>):
                    use default strategy parameters;
                    if any strategy related kwarg is given:
                        override corresponding strategy parameter;

            if <dataset> kwarg is given:
                do not use default dataset parameters;
                ignore dataset related kwargs;

            else (no <dataset>):
                use default dataset parameters;
                    if  any dataset related kwarg is given:
                        override corresponding dataset parameter;

            If any <other> kwarg is given:
                override corresponding default parameter.
        """
        # Parameters and default values:
        self.params = dict(

            # Backtrader engine mandatory parameters:
            engine=dict(
                start_cash=10.0,  # initial trading capital.
                broker_commission=
                0.001,  # trade execution commission, default is 0.1% of operation value.
                fixed_stake=10,  # single trade stake is fixed type by def.
            ),
            # Dataset mandatory parameters:
            dataset=dict(filename=None, ),
            strategy=dict(state_shape=dict(), ),
            render=dict(),
        )
        p2 = dict(  # IS HERE FOR REFERENCE ONLY
            # Strategy related parameters:
            # Observation state shape is dictionary of Gym spaces,
            # at least should contain `raw_state` field.
            # By convention first dimension of every Gym Box space is time embedding one;
            # one can define any shape; should match env.observation_space.shape.
            # observation space state min/max values,
            # For `raw_state' - absolute min/max values from BTgymDataset will be used.
            state_shape=dict(raw_state=spaces.Box(
                shape=(10, 4), low=-100, high=100, dtype=np.float32)),
            drawdown_call=
            None,  # episode maximum drawdown threshold, default is 90% of initial value.
            portfolio_actions=None,
            # agent actions,
            # should consist with BTgymStrategy order execution logic;
            # defaults are: 0 - 'do nothing', 1 - 'buy', 2 - 'sell', 3 - 'close position'.
            skip_frame=None,
            # Number of environment steps to skip before returning next response,
            # e.g. if set to 10 -- agent will interact with environment every 10th episode step;
            # Every other step agent's action is assumed to be 'hold'.
            # Note: INFO part of environment response is a list of all skipped frame's info's,
            #       i.e. [info[-9], info[-8], ..., info[0]].
        )
        # Update self attributes, remove used kwargs:
        for key in dir(self):
            if key in kwargs.keys():
                setattr(self, key, kwargs.pop(key))

        self.metadata = {'render.modes': self.render_modes}

        # Logging and verbosity control:
        if self.log is None:
            StreamHandler(sys.stdout).push_application()
            if self.log_level is None:
                log_levels = [(0, NOTICE), (1, INFO), (2, DEBUG)]
                self.log_level = WARNING
                for key, value in log_levels:
                    if key == self.verbose:
                        self.log_level = value
            self.log = Logger('BTgymAPIshell_{}'.format(self.task),
                              level=self.log_level)

        # Network parameters:
        self.network_address += str(self.port)
        self.data_network_address += str(self.data_port)

        # Set server rendering:
        if self.render_enabled:
            self.renderer = BTgymRendering(self.metadata['render.modes'],
                                           log_level=self.log_level,
                                           **kwargs)

        else:
            self.renderer = BTgymNullRendering()
            self.log.info(
                'Rendering disabled. Call to render() will return null-plug image.'
            )

        # Append logging:
        self.renderer.log = self.log

        # Update params -1: pull from renderer, remove used kwargs:
        self.params['render'].update(self.renderer.params)
        for key in self.params['render'].keys():
            if key in kwargs.keys():
                _ = kwargs.pop(key)

        if self.data_master:
            # DATASET preparation, only data_master executes this:
            #
            if self.dataset is not None:
                # If BTgymDataset instance has been passed:
                # do nothing.
                msg = 'Custom Dataset class used.'

            else:
                # If no BTgymDataset has been passed,
                # Make default dataset with given CSV file:
                # os.path.isfile() returns False (it does not raise) for a
                # missing file, so test its return value explicitly:
                if not os.path.isfile(str(self.params['dataset']['filename'])):
                    raise FileNotFoundError(
                        'Dataset source data file not specified/not found')

                # Use kwargs to instantiate dataset:
                self.dataset = BTgymDataset(**kwargs)
                msg = 'Base Dataset class used.'

            # Append logging:
            self.dataset.set_logger(self.log_level, self.task)

            # Update params -2: pull from dataset, remove used kwargs:
            self.params['dataset'].update(self.dataset.params)
            for key in self.params['dataset'].keys():
                if key in kwargs.keys():
                    _ = kwargs.pop(key)

            self.log.info(msg)

        # Connect/Start data server (and get dataset statistic):
        self.log.info('Connecting data_server...')
        self._start_data_server()
        self.log.info('...done.')
        # ENGINE preparation:

        # Update params -3: pull engine-related kwargs, remove used:
        for key in self.params['engine'].keys():
            if key in kwargs.keys():
                self.params['engine'][key] = kwargs.pop(key)

        if self.engine is not None:
            # If full-blown bt.Cerebro() subclass has been passed:
            # Update info:
            msg = 'Custom Cerebro class used.'
            self.strategy = msg
            for key in self.params['engine'].keys():
                self.params['engine'][key] = msg

        # Note: either way, bt.observers.DrawDown observer [and logger] will be added to any BTgymStrategy instance
        # by BTgymServer process at runtime.

        else:
            # Default configuration for Backtrader computational engine (Cerebro),
            # if no bt.Cerebro() custom subclass has been passed,
            # get base class Cerebro(), using kwargs on top of defaults:
            self.engine = bt.Cerebro()
            msg = 'Base Cerebro class used.'

            # First, set STRATEGY configuration:
            if self.strategy is not None:
                # If custom strategy has been passed:
                msg2 = 'Custom Strategy class used.'

            else:
                # Base class strategy :
                self.strategy = BTgymBaseStrategy
                msg2 = 'Base Strategy class used.'

            # Add, using kwargs on top of defaults:
            #self.log.debug('kwargs for strategy: {}'.format(kwargs))
            strat_idx = self.engine.addstrategy(self.strategy, **kwargs)

            msg += ' ' + msg2

            # Second, set Cerebro-level configuration:
            self.engine.broker.setcash(self.params['engine']['start_cash'])
            self.engine.broker.setcommission(
                self.params['engine']['broker_commission'])
            self.engine.addsizer(bt.sizers.SizerFix,
                                 stake=self.params['engine']['fixed_stake'])

        self.log.info(msg)

        # Define observation space shape, minimum / maximum values and agent action space.
        # Retrieve values from configured engine or...

        # ...Update params -4:
        # Pull strategy defaults to environment params dict :
        for t_key, t_value in self.engine.strats[0][0][0].params._gettuple():
            self.params['strategy'][t_key] = t_value

        # Update it with values from strategy 'passed-to params':
        for key, value in self.engine.strats[0][0][2].items():
            self.params['strategy'][key] = value

        # ... Push it all back (don't ask):
        for key, value in self.params['strategy'].items():
            self.engine.strats[0][0][2][key] = value

        # For 'raw_state' min/max values,
        # the only way is to infer from raw Dataset price values (we already got those from data_server):
        if 'raw_state' in self.params['strategy']['state_shape'].keys():
            # Exclude 'volume' from columns we count:
            self.dataset_columns.remove('volume')

            #print(self.params['strategy'])
            #print('self.engine.strats[0][0][2]:', self.engine.strats[0][0][2])
            #print('self.engine.strats[0][0][0].params:', self.engine.strats[0][0][0].params._gettuple())

            # Override with absolute price min and max values:
            self.params['strategy']['state_shape']['raw_state'].low =\
                self.engine.strats[0][0][2]['state_shape']['raw_state'].low =\
                np.zeros(self.params['strategy']['state_shape']['raw_state'].shape) +\
                self.dataset_stat.loc['min', self.dataset_columns].min()

            self.params['strategy']['state_shape']['raw_state'].high = \
                self.engine.strats[0][0][2]['state_shape']['raw_state'].high = \
                np.zeros(self.params['strategy']['state_shape']['raw_state'].shape) + \
                self.dataset_stat.loc['max', self.dataset_columns].max()

            self.log.info(
                'Inferring `raw_state` high/low values from dataset: {:.6f} / {:.6f}.'
                .format(
                    self.dataset_stat.loc['min', self.dataset_columns].min(),
                    self.dataset_stat.loc['max', self.dataset_columns].max()))

        # Set observation space shape from engine/strategy parameters:
        self.observation_space = DictSpace(
            self.params['strategy']['state_shape'])

        self.log.debug('Obs. shape: {}'.format(self.observation_space.spaces))
        #self.log.debug('Obs. min:\n{}\nmax:\n{}'.format(self.observation_space.low, self.observation_space.high))

        # Set action space and corresponding server messages:
        self.action_space = spaces.Discrete(
            len(self.params['strategy']['portfolio_actions']))
        self.server_actions = self.params['strategy']['portfolio_actions']

        # Finally:
        self.server_response = None
        self.env_response = None

        #if not self.data_master:
        self._start_server()
        self.closed = False

        self.log.info('Environment is ready.')

    def _seed(self, seed=None):
        """
        Sets env. random seed.

        Args:
            seed:   int or None
        """
        np.random.seed(seed)

    @staticmethod
    def _comm_with_timeout(
        socket,
        message,
    ):
        """
        Exchanges messages via socket, timeout sensitive.

        Args:
            socket: zmq connected socket to communicate via;
            message: message to send;

        Note:
            socket zmq.RCVTIMEO and zmq.SNDTIMEO should be set to some finite number of milliseconds.

        Returns:
            dictionary:
                `status`: communication result;
                `message`: received message if status == `ok` or None;
                `time`: remote side response time.
        """
        response = dict(
            status='ok',
            message=None,
        )
        try:
            socket.send_pyobj(message)

        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                response['status'] = 'send_failed_due_to_connect_timeout'

            else:
                response['status'] = 'send_failed_for_unknown_reason'
            return response

        start = time.time()
        try:
            response['message'] = socket.recv_pyobj()
            response['time'] = time.time() - start

        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                response['status'] = 'receive_failed_due_to_connect_timeout'

            else:
                response['status'] = 'receive_failed_for_unknown_reason'
            return response

        return response

    def _start_server(self):
        """
        Configures backtrader REQ/REP server instance and starts server process.
        """

        # Ensure network resources:
        # 1. Release client-side, if any:
        if self.context:
            self.context.destroy()
            self.socket = None

        # 2. Kill any process using server port:
        cmd = "kill $( lsof -i:{} -t ) > /dev/null 2>&1".format(self.port)
        os.system(cmd)

        # Set up client channel:
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.setsockopt(zmq.RCVTIMEO, self.connect_timeout * 1000)
        self.socket.setsockopt(zmq.SNDTIMEO, self.connect_timeout * 1000)
        self.socket.connect(self.network_address)

        # Configure and start server:
        self.server = BTgymServer(
            cerebro=self.engine,
            render=self.renderer,
            network_address=self.network_address,
            data_network_address=self.data_network_address,
            connect_timeout=self.connect_timeout,
            log_level=self.log_level,
            task=self.task,
        )
        self.server.daemon = False
        self.server.start()
        # Wait for server to startup:
        time.sleep(1)

        # Check connection:
        self.log.info('Server started, pinging {} ...'.format(
            self.network_address))

        self.server_response = self._comm_with_timeout(
            socket=self.socket, message={'ctrl': 'ping!'})
        if self.server_response['status'] == 'ok':
            self.log.info('Server seems ready with response: <{}>'.format(
                self.server_response['message']))

        else:
            msg = 'Server unreachable with status: <{}>.'.format(
                self.server_response['status'])
            self.log.error(msg)
            raise ConnectionError(msg)

        self._closed = False

    def _stop_server(self):
        """
        Stops BT server process, releases network resources.
        """
        if self.server:

            if self._force_control_mode():
                # In case server is running and client side is ok:
                self.socket.send_pyobj({'ctrl': '_stop'})
                self.server_response = self.socket.recv_pyobj()

            else:
                self.server.terminate()
                self.server.join()
                self.server_response = 'Server process terminated.'

            self.log.info('{} Exit code: {}'.format(self.server_response,
                                                    self.server.exitcode))

        # Release client-side, if any:
        if self.context:
            self.context.destroy()
            self.socket = None

    def _force_control_mode(self):
        """Puts BT server to control mode.
        """
        # Check if there are any faults with server process and connection:
        network_error = [
            (not self.server or not self.server.is_alive(),
             'No running server found. Hint: forgot to call reset()?'),
            (not self.context
             or self.context.closed, 'No network connection found.'),
        ]
        for (err, msg) in network_error:
            if err:
                self.log.info(msg)
                self.server_response = msg
                return False

        # If everything works, insist on going 'control':
        self.server_response = {}
        attempt = 0

        while 'ctrl' not in self.server_response:
            self.socket.send_pyobj({'ctrl': '_done'})
            self.server_response = self.socket.recv_pyobj()
            attempt += 1
            self.log.debug(
                'FORCE CONTROL MODE attempt: {}.\nResponse: {}'.format(
                    attempt, self.server_response))

        return True

    def _assert_response(self, response):
        """
        Simple watcher:
        roughly checks if we are really talking to the environment (i.e. an episode is running).
        Raises an exception if the response received is not as expected.
        """
        try:
            assert type(response) == tuple and len(response) == 4

        except AssertionError:
            msg = 'Unexpected environment response: {}\nHint: Forgot to call reset() or reset_data()?'.format(
                response)
            self.log.exception(msg)
            raise AssertionError(msg)

        self.log.debug('Response checker received:\n{}\nas type: {}'.format(
            response, type(response)))

    def _print_space(self, space, _tab=''):
        """
        Parses observation space shape or response.

        Args:
            space: gym observation space or state.

        Returns:
            description as string.
        """
        response = ''
        if type(space) in [dict, OrderedDict]:
            for key, value in space.items():
                response += '\n{}{}:{}\n'.format(
                    _tab, key, self._print_space(value, '   '))

        elif type(space) in [spaces.Dict, DictSpace]:
            for s in space.spaces:
                response += self._print_space(s, '   ')

        elif type(space) in [tuple, list]:
            for i in space:
                response += self._print_space(i, '   ')

        elif type(space) == np.ndarray:
            response += '\n{}array of shape: {}, low: {}, high: {}'.format(
                _tab, space.shape, space.min(), space.max())

        else:
            response += '\n{}{}, '.format(_tab, space)
            try:
                response += 'low: {}, high: {}'.format(space.low.min(),
                                                       space.high.max())

            except (KeyError, AttributeError, ArithmeticError,
                    ValueError) as e:
                pass
                #response += '\n{}'.format(e)

        return response

    def reset(self, **kwargs):
        """
        Implementation of OpenAI Gym env.reset method. Starts new episode. Episode data are sampled
        according to data provider class logic, controlled via kwargs. Refer to `BTgym_Server` and data provider
        classes for details.

        Args:
            kwargs:         any kwargs; this dictionary is passed through to BTgym_server side without any checks and
                            modifications; currently used for data sampling control;

        Returns:
            observation space state

        Notes:
            Currently accepted kwargs are::


                episode_config=dict(
                    get_new=True,
                    sample_type=0,
                    b_alpha=1,
                    b_beta=1
                ),
                trial_config=dict(
                    get_new=True,
                    sample_type=0,
                    b_alpha=1,
                    b_beta=1
                )

        """
        # Data Server check:
        if self.data_master:
            if not self.data_server or not self.data_server.is_alive():
                self.log.info('No running data_server found, starting...')
                self._start_data_server()

            # Domain dataset status check:
            self.data_server_response = self._comm_with_timeout(
                socket=self.data_socket, message={'ctrl': '_get_info'})
            if not self.data_server_response['message']['dataset_is_ready']:
                self.log.info(
                    'Data domain `reset()` called prior to `reset_data()` with [possibly inconsistent] defaults.'
                )
                self.reset_data()

        # Server process check:
        if not self.server or not self.server.is_alive():
            self.log.info('No running server found, starting...')
            self._start_server()

        if self._force_control_mode():
            self.server_response = self._comm_with_timeout(socket=self.socket,
                                                           message={
                                                               'ctrl':
                                                               '_reset',
                                                               'kwargs': kwargs
                                                           })
            # Get initial environment response:
            self.env_response = self.step(0)

            # Check (once) if it is really (o,r,d,i) tuple:
            self._assert_response(self.env_response)

            # Check (once) if state_space is as expected:
            try:
                #assert self.observation_space.contains(self.env_response[0])
                pass

            except (AssertionError, AttributeError) as e:
                msg1 = self._print_space(self.observation_space.spaces)
                msg2 = self._print_space(self.env_response[0])
                msg3 = ''
                for step_info in self.env_response[-1]:
                    msg3 += '{}\n'.format(step_info)
                msg = ('\nState observation shape/range mismatch!\n' +
                       'Space set by env: \n{}\n' +
                       'Space returned by server: \n{}\n' +
                       'Full response:\n{}\n' + 'Reward: {}\n' + 'Done: {}\n' +
                       'Info:\n{}\n' +
                       'Hint: Wrong Strategy.get_state() parameters?').format(
                           msg1,
                           msg2,
                           self.env_response[0],
                           self.env_response[1],
                           self.env_response[2],
                           msg3,
                       )
                self.log.exception(msg)
                self._stop_server()
                raise AssertionError(msg)

            return self.env_response[0]  #["raw_state"][np.newaxis]

        else:
            msg = 'Something went wrong. env.reset() can not get response from server.'
            self.log.exception(msg)
            raise ChildProcessError(msg)
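
# A usage sketch of reset() with the kwargs documented above, assuming `env`
# is an instance of this environment class; the dictionaries are passed
# through to the server-side data provider unchanged, and the values shown
# are simply the defaults quoted in the docstring.
init_obs = env.reset(
    episode_config=dict(get_new=True, sample_type=0, b_alpha=1, b_beta=1),
    trial_config=dict(get_new=True, sample_type=0, b_alpha=1, b_beta=1),
)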

    def step(self, action):
        """
        Implementation of OpenAI Gym env.step() method.
        Makes a step in the environment.

        Args:
            action:     int, number representing action from env.action_space

        Returns:
            tuple (Observation, Reward, Done, Info)

        """
        # Are you in the list, ready to go and all that?
        if self.action_space.contains(action)\
            and not self._closed\
            and (self.socket is not None)\
            and not self.socket.closed:
            pass

        else:
            msg = ('\nAt least one of these is true:\n' +
                   'Action error: (space is {}, action sent is {}): {}\n' +
                   'Environment closed: {}\n' +
                   "Network error [socket doesn't exist or is closed]: {}\n" +
                   'Hint: forgot to call reset()?').format(
                       self.action_space,
                       action,
                       not self.action_space.contains(action),
                       self._closed,
                       not self.socket or self.socket.closed,
                   )
            self.log.exception(msg)
            raise AssertionError(msg)

        # Send action to backtrader engine, receive environment response
        env_response = self._comm_with_timeout(
            socket=self.socket,
            message={'action': self.server_actions[action]})
        if env_response['status'] != 'ok':
            msg = '.step(): server unreachable with status: <{}>.'.format(
                env_response['status'])
            self.log.error(msg)
            raise ConnectionError(msg)

        # Unpack server message; keep only the `raw_state` component as observation:
        new_state, reward, done, info = env_response['message']
        new_state = new_state["raw_state"][np.newaxis]
        self.env_response = new_state, reward, done, info

        return self.env_response
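
# A minimal interaction loop over step(), assuming `env` as above; actions are
# sampled at random from the discrete action space purely for illustration.
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)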

    def close(self):
        """
        Implementation of OpenAI Gym env.close method.
        Puts BTgym server in Control Mode.
        """
        self.log.debug('close.call()')
        self._stop_server()
        self._stop_data_server()
        self.log.info('Environment closed.')

    def get_stat(self):
        """
        Returns last run episode statistics.

        Note:
            when invoked, forces running episode to terminate.
        """
        if self._force_control_mode():
            self.socket.send_pyobj({'ctrl': '_getstat'})
            return self.socket.recv_pyobj()

        else:
            return self.server_response
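
# A shutdown sketch, assuming `env` as above: get_stat() forces any running
# episode to terminate and returns its statistics, after which close() stops
# the backtrader and data server processes and releases network resources.
episode_stat = env.get_stat()
env.close()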

    def render(self, mode='other_mode', close=False):
        """
        Implementation of OpenAI Gym env.render method.
        Visualises current environment state.

        Args:
            `mode`:     str, any of these::

                            `human` - current state observation as price lines;
                            `episode` - plotted results of the last completed episode;
                            [other_key] - rendering corresponding to any custom observation space key.
        """
        if close:
            return None

        if not self._closed\
            and self.socket\
            and not self.socket.closed:
            pass

        else:
            msg = ("\nCan't get renderings."
                   '\nAt least one of these is true:\n' +
                   'Environment closed: {}\n' +
                   "Network error [socket doesn't exist or is closed]: {}\n" +
                   'Hint: forgot to call reset()?').format(
                       self._closed,
                       not self.socket or self.socket.closed,
                   )
            self.log.warning(msg)
            return None
        if mode not in self.render_modes:
            raise ValueError('Unexpected render mode {}'.format(mode))
        self.socket.send_pyobj({'ctrl': '_render', 'mode': mode})

        rgb_array_dict = self.socket.recv_pyobj()

        self.rendered_rgb.update(rgb_array_dict)

        return self.rendered_rgb[mode]
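
# A rendering sketch, assuming `env` as above and that rendering was enabled
# when the environment was constructed; each call returns an RGB array for the
# requested mode, cached in `env.rendered_rgb`.
price_frame = env.render(mode='human')      # current state as price lines
episode_plot = env.render(mode='episode')   # last completed episode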

    def _stop(self):
        """
        Finishes current episode if any, does nothing otherwise. Leaves server running.
        """
        if self._force_control_mode():
            self.log.info('Episode stop forced.')

    def _restart_server(self):
        """Restarts server.
        """
        self._stop_server()
        self._start_server()
        self.log.info('Server restarted.')

    def _start_data_server(self):
        """
        For data_master environment:
            - configures data_server REQ/REP instance and starts server process.

        For others:
            - establishes network connection to existing data_server.
        """
        self.data_server = None

        # Ensure network resources:
        # 1. Release client-side, if any:
        if self.data_context:
            self.data_context.destroy()
            self.data_socket = None

        # Only data_master launches/stops data_server process:
        if self.data_master:
            # 2. Kill any process using server port:
            cmd = "kill $( lsof -i:{} -t ) > /dev/null 2>&1".format(
                self.data_port)
            os.system(cmd)

            # Configure and start server:
            self.data_server = BTgymDataFeedServer(
                dataset=self.dataset,
                network_address=self.data_network_address,
                log_level=self.log_level,
                task=self.task)
            self.data_server.daemon = False
            self.data_server.start()
            # Wait for server to startup
            time.sleep(1)

        # Set up client channel:
        self.data_context = zmq.Context()
        self.data_socket = self.data_context.socket(zmq.REQ)
        self.data_socket.setsockopt(zmq.RCVTIMEO, self.connect_timeout * 1000)
        self.data_socket.setsockopt(zmq.SNDTIMEO, self.connect_timeout * 1000)
        self.data_socket.connect(self.data_network_address)

        # Check connection:
        self.log.debug('Pinging data_server at: {} ...'.format(
            self.data_network_address))

        self.data_server_response = self._comm_with_timeout(
            socket=self.data_socket, message={'ctrl': 'ping!'})
        if self.data_server_response['status'] == 'ok':
            self.log.debug(
                'Data_server seems ready with response: <{}>'.format(
                    self.data_server_response['message']))

        else:
            msg = 'Data_server unreachable with status: <{}>.'.\
                format(self.data_server_response['status'])
            self.log.error(msg)
            raise ConnectionError(msg)

        # Get info and statistic:
        self.dataset_stat, self.dataset_columns, self.data_server_pid = \
            self._get_dataset_info()

    def _stop_data_server(self):
        """
        For data_master:
            - stops data_server process, releases network resources.
        """
        if self.data_master:
            if self.data_server is not None and self.data_server.is_alive():
                # In case server is running and is ok:
                self.data_socket.send_pyobj({'ctrl': '_stop'})
                self.data_server_response = self.data_socket.recv_pyobj()

            else:
                self.data_server.terminate()
                self.data_server.join()
                self.data_server_response = 'Data_server process terminated.'

            self.log.info('{} Exit code: {}'.format(self.data_server_response,
                                                    self.data_server.exitcode))

        if self.data_context:
            self.data_context.destroy()
            self.data_socket = None

    def _restart_data_server(self):
        """
        Restarts data_server.
        """
        if self.data_master:
            self._stop_data_server()
            self._start_data_server()

    def _get_dataset_info(self):
        """
        Retrieves dataset descriptive statistic.
        """
        self.data_socket.send_pyobj({'ctrl': '_get_info'})
        self.data_server_response = self.data_socket.recv_pyobj()

        return self.data_server_response['dataset_stat'],\
               self.data_server_response['dataset_columns'],\
               self.data_server_response['pid']
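
# A sketch of how the values returned above are consumed during environment
# construction (see the raw_state low/high inference earlier): `dataset_stat`
# is assumed to behave like a pandas DataFrame of descriptive statistics
# indexed by statistic name, and `dataset_columns` like a list of column labels.
stat, columns, data_server_pid = env._get_dataset_info()
price_low = stat.loc['min', columns].min()    # absolute dataset minimum
price_high = stat.loc['max', columns].max()   # absolute dataset maximum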

    def reset_data(self, **kwargs):
        """
        Resets data provider class used, whatever it means for that class. Gets data_server ready to provide data.
        Supposed to be called before first env.reset().

        Note:
            when invoked, forces running episode to terminate.

        Args:
            **kwargs:   data provider class .reset() method specific.
        """
        if self.closed:
            self._start_server()
            if self.data_master:
                self._start_data_server()
            self.closed = False

        else:
            _ = self._force_control_mode()

        if self.data_master:
            if self.data_server is None or not self.data_server.is_alive():
                self._restart_data_server()

            self.data_server_response = self._comm_with_timeout(
                socket=self.data_socket,
                message={
                    'ctrl': '_reset_data',
                    'kwargs': kwargs
                })
            if self.data_server_response['status'] == 'ok':
                self.log.debug(
                    'Dataset seems ready with response: <{}>'.format(
                        self.data_server_response['message']))

            else:
                msg = 'Data_server unreachable with status: <{}>.'. \
                    format(self.data_server_response['status'])
                self.log.error(msg)
                raise SystemExit(msg)

        else:
            pass
Example #29
0
class Worker(multiprocessing.Process):
    """
    Distributed tf worker class.

    Sets up environment, trainer and starts training process in supervised session.
    """
    env_list = None

    def __init__(self,
                 env_config,
                 policy_config,
                 trainer_config,
                 cluster_spec,
                 job_name,
                 task,
                 log_dir,
                 log_ckpt_subdir,
                 initial_ckpt_dir,
                 save_secs,
                 log_level,
                 max_env_steps,
                 random_seed=None,
                 render_last_env=True,
                 test_mode=False):
        """

        Args:
            env_config:             environment class_config_dict.
            policy_config:          model policy estimator class_config_dict.
            trainer_config:         algorithm class_config_dict.
            cluster_spec:           tf.cluster specification.
            job_name:               worker or parameter server.
            task:                   integer number, 0 is chief worker.
            log_dir:                path for tb summaries and current checkpoints.
            log_ckpt_subdir:        log_dir subdirectory to store current checkpoints
            initial_ckpt_dir:       path for checkpoint to load as pre-trained model.
            save_secs:              int, save model checkpoint every N secs.
            log_level:              int, logbook.level
            max_env_steps:          number of environment steps to run training on
            random_seed:            int or None
            render_last_env:        bool, if True and there is more than one environment specified for each worker,
                                    only allows rendering for last environment in a list;
                                    allows rendering for all environments of a chief worker otherwise;
            test_mode:              if True - use Atari mode, BTGym otherwise.

            Note:
                - Conventional `self.global_step` refers to number of environment steps,
                    summarized over all environment instances, not to number of policy optimizer train steps.

                - Every worker can run several environments in parallel, as specified by `cluster_config['num_envs']`.
                    If 4 workers are used and num_envs=4, the total number of environments is 16. Every env instance has
                    its own ThreadRunner process.

                - When using replay memory, keep in mind that every ThreadRunner keeps its own replay memory;
                    if memory_size=2000, num_workers=4, num_envs=4, the total replay memory size equals 32 000 frames.
        """
        super(Worker, self).__init__()
        self.env_class = env_config['class_ref']
        self.env_kwargs = env_config['kwargs']
        self.policy_config = policy_config
        self.trainer_class = trainer_config['class_ref']
        self.trainer_kwargs = trainer_config['kwargs']
        self.cluster_spec = cluster_spec
        self.job_name = job_name
        self.task = task
        self.is_chief = (self.task == 0)
        self.log_dir = log_dir
        self.save_secs = save_secs
        self.max_env_steps = max_env_steps
        self.log_level = log_level
        self.log = None
        self.test_mode = test_mode
        self.random_seed = random_seed
        self.render_last_env = render_last_env

        # Saver and summaries path:
        self.current_ckpt_dir = self.log_dir + log_ckpt_subdir
        self.initial_ckpt_dir = initial_ckpt_dir
        self.summary_dir = self.log_dir + '/worker_{}'.format(self.task)

        # print(log_ckpt_subdir)
        # print(self.log_dir)
        # print(self.current_ckpt_dir)
        # print(self.initial_ckpt_dir)
        # print(self.summary_dir)

        self.summary_writer = None
        self.config = None
        self.saver = None

    def _restore_model_params(self, sess, save_path):
        """
        Restores model parameters from specified location.

        Args:
            sess:       tf.Session obj.
            save_path:  path where parameters were previously saved.

        Returns: True if model has been successfully loaded, False otherwise.
        """
        if save_path is None:
            return False

        assert self.saver is not None, 'FastSaver has not been configured.'

        try:
            # Look for valid checkpoint:
            ckpt_state = tf.train.get_checkpoint_state(save_path)
            if ckpt_state is not None and ckpt_state.model_checkpoint_path:
                self.saver.restore(sess, ckpt_state.model_checkpoint_path)

            else:
                self.log.notice(
                    'no saved model parameters found in:\n{}'.format(
                        save_path))
                return False

        except (ValueError, tf.errors.NotFoundError,
                tf.errors.InvalidArgumentError) as e:
            self.log.notice(
                'failed to restore model parameters from:\n{}'.format(
                    save_path))
            return False

        return True

    def _save_model_params(self, sess, global_step):
        """
        Saves model checkpoint to predefined location.

        Args:
            sess:           tf.Session obj.
            global_step:    global step number is appended to save_path to create the checkpoint filenames
        """
        assert self.saver is not None, 'FastSaver has not been configured.'
        self.saver.save(sess,
                        save_path=self.current_ckpt_dir + '/model_parameters',
                        global_step=global_step)

    def run(self):
        """Worker runtime body.
        """
        # Logging:
        StreamHandler(sys.stdout).push_application()
        self.log = Logger('Worker_{}'.format(self.task), level=self.log_level)
        try:
            tf.reset_default_graph()

            if self.test_mode:
                import gym

            # Define cluster:
            cluster = tf.train.ClusterSpec(self.cluster_spec).as_cluster_def()

            # Start tf.server:
            if self.job_name == 'ps':
                server = tf.train.Server(
                    cluster,
                    job_name=self.job_name,
                    task_index=self.task,
                    config=tf.ConfigProto(device_filters=["/job:ps"]))
                self.log.debug('parameters_server started.')
                # Just block here:
                server.join()

            else:
                server = tf.train.Server(
                    cluster,
                    job_name='worker',
                    task_index=self.task,
                    config=tf.ConfigProto(
                        intra_op_parallelism_threads=4,  # original was: 1
                        inter_op_parallelism_threads=4,  # original was: 2
                    ))
                self.log.debug('tf.server started.')

                self.log.debug('making environments:')
                # Making as many environments as many entries in env_config `port` list:
                # TODO: Hacky-II: only one of the parallel environments can be data-master [and renderer]
                # TODO: measure data_server lags, maybe launch several instances
                self.env_list = []
                env_kwargs = self.env_kwargs.copy()
                env_kwargs['log_level'] = self.log_level
                port_list = env_kwargs.pop('port')
                data_port_list = env_kwargs.pop('data_port')
                data_master = env_kwargs.pop('data_master')
                render_enabled = env_kwargs.pop('render_enabled')

                render_list = [False for entry in port_list]
                if render_enabled:
                    if self.render_last_env:
                        render_list[-1] = True
                    else:
                        render_list = [True for entry in port_list]
                        # render_list[0] = True

                data_master_list = [False for entry in port_list]
                if data_master:
                    data_master_list[0] = True

                # Parallel envs. numbering:
                if len(port_list) > 1:
                    task_id = 0.0
                else:
                    task_id = 0

                for port, data_port, is_render, is_master in zip(
                        port_list, data_port_list, render_list,
                        data_master_list):
                    # Get random seed for environments:
                    env_kwargs['random_seed'] = random.randint(0, 2**30)

                    if not self.test_mode:
                        # Assume BTgym env. class:
                        self.log.debug(
                            'setting env at port_{} is data_master: {}'.format(
                                port, data_master))
                        self.log.debug('env_kwargs:')
                        for k, v in env_kwargs.items():
                            self.log.debug('{}: {}'.format(k, v))
                        try:
                            self.env_list.append(
                                self.env_class(port=port,
                                               data_port=data_port,
                                               data_master=is_master,
                                               render_enabled=is_render,
                                               task=self.task + task_id,
                                               **env_kwargs))
                            data_master = False
                            self.log.info(
                                'set BTGym environment {} @ port:{}, data_port:{}'
                                .format(self.task + task_id, port, data_port))
                            task_id += 0.01

                        except Exception as e:
                            self.log.exception(
                                'failed to make BTGym environment at port_{}.'.
                                format(port))
                            raise e

                    else:
                        # Assume atari testing:
                        try:
                            self.env_list.append(
                                self.env_class(env_kwargs['gym_id']))
                            self.log.debug('set Gym/Atari environment.')

                        except Exception as e:
                            self.log.exception(
                                'failed to make Gym/Atari environment')
                            raise e

                self.log.debug('Defining trainer...')

                # Define trainer:
                trainer = self.trainer_class(
                    env=self.env_list,
                    task=self.task,
                    policy_config=self.policy_config,
                    log_level=self.log_level,
                    cluster_spec=self.cluster_spec,
                    random_seed=self.random_seed,
                    **self.trainer_kwargs,
                )

                self.log.debug('trainer ok.')

                # Saver-related:
                variables_to_save = [
                    v for v in tf.global_variables() if not 'local' in v.name
                ]
                local_variables = [
                    v for v in tf.global_variables() if 'local' in v.name
                ] + tf.local_variables()
                init_op = tf.initializers.variables(variables_to_save)
                local_init_op = tf.initializers.variables(local_variables)
                init_all_op = tf.global_variables_initializer()

                def init_fn(_sess):
                    self.log.notice("initializing all parameters...")
                    _sess.run(init_all_op)

                # def init_fn_scaff(scaffold, _sess):
                #     self.log.notice("initializing all parameters...")
                #     _sess.run(init_all_op)

                # self.log.warning('VARIABLES TO SAVE:')
                # for v in variables_to_save:
                #     self.log.warning(v)
                #
                # self.log.warning('LOCAL VARS:')
                # for v in local_variables:
                #     self.log.warning(v)

                self.saver = FastSaver(var_list=variables_to_save,
                                       max_to_keep=1,
                                       save_relative_paths=True)

                self.config = tf.ConfigProto(device_filters=[
                    "/job:ps", "/job:worker/task:{}/cpu:0".format(self.task)
                ])

                sess_manager = tf.train.SessionManager(
                    local_init_op=local_init_op,
                    ready_op=None,
                    ready_for_local_init_op=tf.report_uninitialized_variables(
                        variables_to_save),
                    graph=None,
                    recovery_wait_secs=90,
                )
                with sess_manager.prepare_session(
                        master=server.target,
                        init_op=init_op,
                        config=self.config,
                        init_fn=init_fn,
                ) as sess:

                    # Try to restore pre-trained model
                    pre_trained_restored = self._restore_model_params(
                        sess, self.initial_ckpt_dir)
                    _ = sess.run(trainer.reset_global_step)

                    if not pre_trained_restored:
                        # If not - try to recover current checkpoint:
                        current_restored = self._restore_model_params(
                            sess, self.current_ckpt_dir)

                    else:
                        current_restored = False

                    if not pre_trained_restored and not current_restored:
                        self.log.notice('training from scratch...')

                    self.log.info("connecting to the parameter server... ")

                    self.summary_writer = tf.summary.FileWriter(
                        self.summary_dir, sess.graph)
                    trainer.start(sess, self.summary_writer)

                    # Note: `self.global_step` refers to number of environment steps
                    # summarized over all environment instances, not to number of policy optimizer train steps.
                    global_step = sess.run(trainer.global_step)
                    self.log.notice(
                        "started training at step: {}".format(global_step))

                    last_saved_time = datetime.datetime.now()
                    last_saved_step = global_step

                    while global_step < self.max_env_steps:
                        trainer.process(sess)
                        global_step = sess.run(trainer.global_step)

                        time_delta = datetime.datetime.now() - last_saved_time
                        if (self.is_chief
                                and time_delta.total_seconds() > self.save_secs):
                            self._save_model_params(sess, global_step)
                            train_speed = (global_step - last_saved_step) / (
                                time_delta.total_seconds() + 1)
                            self.log.notice(
                                'env. step: {}; cluster speed: {:.0f} step/sec; checkpoint saved.'
                                .format(global_step, train_speed))
                            last_saved_time = datetime.datetime.now()
                            last_saved_step = global_step

                # Ask for all the services to stop:
                for env in self.env_list:
                    env.close()

                self.log.notice(
                    'reached {} steps, exiting.'.format(global_step))

        except Exception as e:
            self.log.exception(e)
            raise e
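
# A configuration sketch for the Worker above; every value is an illustrative
# assumption. `env_config`, `policy_config` and `trainer_config` are
# class_config_dicts assumed to be defined elsewhere, and `cluster_spec` must
# be acceptable to tf.train.ClusterSpec.
cluster_spec = {
    'ps': ['127.0.0.1:12222'],                         # one parameter server
    'worker': ['127.0.0.1:12223', '127.0.0.1:12224'],  # two workers
}
chief = Worker(
    env_config=env_config,
    policy_config=policy_config,
    trainer_config=trainer_config,
    cluster_spec=cluster_spec,
    job_name='worker',
    task=0,                                  # task 0 is the chief worker
    log_dir='./tmp/train_logs',
    log_ckpt_subdir='/current_train_checkpoint',
    initial_ckpt_dir=None,                   # no pre-trained model to restore
    save_secs=600,                           # checkpoint every 10 minutes
    log_level=log_level,                     # a logbook level, e.g. logbook.NOTICE
    max_env_steps=10**6,
)
chief.start()                                # multiprocessing.Process.start()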
Example #30
0
class AMLDG:
    """
    Training framework for a combined model-based/model-free setup with a non-parametric data model.
    Compensates for model bias by jointly learning an optimal policy for modelled data (generated trajectories)
    and for the real data the model is based upon.
    Based on an objective identical to that of the MLDG algorithm (by Da Li et al.).

    This class is basically an AAC wrapper: it relies on two sub-AAC classes to make separate policy networks
    and training loops.

    Note that the 'actor' and 'critic' names used here are not related to the same-named entities used in A3C and
    other actor-critic RL algorithms; they are closer to the 'generator' and 'discriminator' terms used in
    adversarial training: the 'actor trainer' optimises the RL objective on synthetic data generated by some model,
    while the 'critic trainer' tries to compensate for model bias by optimising the same objective on the real data
    the model has been fitted with.

    Papers:
        Da Li et al.,
         "Learning to Generalize: Meta-Learning for Domain Generalization"
         https://arxiv.org/abs/1710.03463

        Maruan Al-Shedivat et al.,
        "Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments"
        https://arxiv.org/abs/1710.03641

    """
    def __init__(self,
                 env,
                 task,
                 log_level,
                 aac_class_ref=OUpAAC,
                 runner_config=None,
                 opt_decay_steps=None,
                 opt_end_learn_rate=None,
                 opt_learn_rate=1e-4,
                 opt_max_env_steps=10**7,
                 aac_lambda=1.0,
                 guided_lambda=0.0,
                 tau=.5,
                 rollout_length=20,
                 train_phase=True,
                 name='TrainAMLDG',
                 **kwargs):
        try:
            self.aac_class_ref = aac_class_ref
            self.task = task
            self.name = name
            self.summary_writer = None
            self.train_phase = train_phase
            self.tau = tau

            self.opt_learn_rate = opt_learn_rate
            self.opt_max_env_steps = opt_max_env_steps

            if opt_end_learn_rate is None:
                self.opt_end_learn_rate = self.opt_learn_rate
            else:
                self.opt_end_learn_rate = opt_end_learn_rate

            if opt_decay_steps is None:
                self.opt_decay_steps = self.opt_max_env_steps
            else:
                self.opt_decay_steps = opt_decay_steps

            StreamHandler(sys.stdout).push_application()
            self.log = Logger('{}_{}'.format(name, task), level=log_level)
            self.rollout_length = rollout_length

            if runner_config is None:
                self.runner_config = {
                    'class_ref': OUpRunner,
                    'kwargs': {},
                }
            else:
                self.runner_config = runner_config

            self.env_list = env

            assert isinstance(self.env_list, list) and len(self.env_list) == 2, \
                'Expected pair of environments, got: {}'.format(self.env_list)

            # Instantiate two sub-trainers: one for training on modeled data (actor, or generator) and one
            # for training on real data (critic, or discriminator):
            self.runner_config['kwargs'] = {
                'data_sample_config': {
                    'mode': 0
                },  # synthetic train data
                'name': 'actor',
                'test_deterministic': not self.train_phase,
            }
            self.actor_aac = aac_class_ref(
                env=self.env_list[-1],  # test data -> slave env.
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                opt_learn_rate=self.opt_learn_rate,
                opt_max_env_steps=self.opt_max_env_steps,
                opt_end_learn_rate=self.opt_end_learn_rate,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                rollout_length=self.rollout_length,
                episode_train_test_cycle=(1, 0) if self.train_phase else
                (0, 1),
                _use_target_policy=False,
                _use_global_network=True,
                name=self.name + '/actor',
                **kwargs)
            # Change for critic:
            self.runner_config['kwargs'] = {
                'data_sample_config': {
                    'mode': 1
                },  # real train data
                'name': 'critic',
                # enable train exploration on [formally] test data:
                'test_deterministic': not self.train_phase,
            }
            self.critic_aac = aac_class_ref(
                env=self.env_list[0],  # real train data will be the master environment
                task=self.task,
                log_level=log_level,
                runner_config=self.runner_config,
                opt_learn_rate=self.opt_learn_rate,
                opt_max_env_steps=self.opt_max_env_steps,
                opt_end_learn_rate=self.opt_end_learn_rate,
                aac_lambda=aac_lambda,
                guided_lambda=guided_lambda,
                rollout_length=self.rollout_length,
                episode_train_test_cycle=(0, 1),  # always real
                _use_target_policy=False,
                _use_global_network=False,
                global_step_op=self.actor_aac.global_step,
                global_episode_op=self.actor_aac.global_episode,
                inc_episode_op=self.actor_aac.inc_episode,
                name=self.name + '/critic',
                **kwargs)

            self.local_steps = self.critic_aac.local_steps
            self.model_summary_freq = self.critic_aac.model_summary_freq

            self._make_train_op()

        except Exception as e:
            msg = 'AMLDG.__init__() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise e

    def _make_train_op(self):
        """
        Defines tensors holding training ops.
        """
        # Handy aliases:
        pi_critic = self.critic_aac.local_network  # local critic policy
        pi_actor = self.actor_aac.local_network  # local actor policy
        pi_global = self.actor_aac.network  # global shared policy

        # From local actor to local critic:
        self.critic_aac.sync_pi = tf.group(*[
            v1.assign(v2)
            for v1, v2 in zip(pi_critic.var_list, pi_actor.var_list)
        ])

        # Inherited counters:
        self.global_step = self.actor_aac.global_step
        self.global_episode = self.actor_aac.global_episode
        self.inc_episode = self.actor_aac.inc_episode
        self.reset_global_step = self.actor_aac.reset_global_step

        # Clipped gradients for critic (critic's train op is disabled by `_use_global_network=False`
        # to avoid actor's name scope violation):
        self.critic_aac.grads, _ = tf.clip_by_global_norm(
            tf.gradients(self.critic_aac.loss, pi_critic.var_list), 40.0)
        # Placeholders for stored gradients values, include None's to correctly map Vars:
        self.actor_aac.grads_placeholders = [
            tf.placeholder(shape=grad.shape, dtype=grad.dtype)
            if grad is not None else None for grad in self.actor_aac.grads
        ]
        self.critic_aac.grads_placeholders = [
            tf.placeholder(shape=grad.shape, dtype=grad.dtype)
            if grad is not None else None for grad in self.critic_aac.grads
        ]

        # Gradients to update local critic policy with stored actor's gradients:
        critic_grads_and_vars = list(
            zip(self.actor_aac.grads_placeholders, pi_critic.var_list))

        # Final gradients to be sent to parameter server:
        self.grads = [
            self.tau * g1 +
            (1 - self.tau) * g2 if g1 is not None and g2 is not None else None
            for g1, g2 in zip(self.actor_aac.grads_placeholders,
                              self.critic_aac.grads_placeholders)
        ]
        global_grads_and_vars = list(zip(self.grads, pi_global.var_list))

        # debug_global_grads_and_vars = list(zip(self.actor_aac.grads_placeholders, pi_global.var_list))
        # debug_global_grads_and_vars = [(g, v) for (g, v) in debug_global_grads_and_vars if g is not None]

        # Remove None entries:
        global_grads_and_vars = [(g, v) for (g, v) in global_grads_and_vars
                                 if g is not None]
        critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
                                 if g is not None]
        self.actor_aac.grads = [
            g for g in self.actor_aac.grads if g is not None
        ]
        self.critic_aac.grads = [
            g for g in self.critic_aac.grads if g is not None
        ]
        self.actor_aac.grads_placeholders = [
            pl for pl in self.actor_aac.grads_placeholders if pl is not None
        ]
        self.critic_aac.grads_placeholders = [
            pl for pl in self.critic_aac.grads_placeholders if pl is not None
        ]

        self.inc_step = self.actor_aac.inc_step

        # Op to update critic with gradients from actor:
        self.critic_aac.optimizer = tf.train.AdamOptimizer(
            self.actor_aac.learn_rate_decayed, epsilon=1e-5)
        self.update_critic_op = self.critic_aac.optimizer.apply_gradients(
            critic_grads_and_vars)

        # Use actor optimizer to update global policy instance:
        self.train_op = self.actor_aac.optimizer.apply_gradients(
            global_grads_and_vars)

        self.log.debug('all_train_ops defined')
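
# A numeric illustration of the gradient mix defined above: what is sent to the
# parameter server is a per-variable convex combination
# g = tau * g_actor + (1 - tau) * g_critic of the two sub-trainers' gradients.
import numpy as np

tau = 0.5
g_actor = np.array([0.2, -1.0, 0.5])
g_critic = np.array([0.4, 0.0, -0.5])
g_combined = tau * g_actor + (1 - tau) * g_critic   # -> [0.3, -0.5, 0.0]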

    def start(self, sess, summary_writer, **kwargs):
        """
        Executes all initializing operations,
        starts environment runner[s].
        Supposed to be called by parent worker just before training loop starts.

        Args:
            sess:           tf session object.
            kwargs:         not used by default.
        """
        try:
            # Copy weights from global to local:
            sess.run(self.critic_aac.sync_pi)
            sess.run(self.actor_aac.sync_pi)

            # Start thread_runners:
            self.critic_aac._start_runners(  # master first
                sess,
                summary_writer,
                init_context=None,
                data_sample_config=self.critic_aac.get_sample_config(mode=1))
            self.actor_aac._start_runners(
                sess,
                summary_writer,
                init_context=None,
                data_sample_config=self.actor_aac.get_sample_config(mode=0))

            self.summary_writer = summary_writer
            self.log.notice('Runners started.')

        except:
            msg = 'start() exception occurred' + \
                '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise RuntimeError(msg)

    def process(self, sess):
        if self.train_phase:
            self.process_train(sess)

        else:
            self.process_test(sess)

    def process_test(self, sess):
        """
        Evaluation loop.
        Args:
            sess (tensorflow.Session):   tf session obj.
        """
        try:
            # sess.run(self.critic_aac.sync_pi)
            # sess.run(self.actor_aac.sync_pi)

            actor_data = self.actor_aac.get_data()
            critic_data = self.critic_aac.get_data()

            # Write down summaries:
            self.actor_aac.process_summary(sess, actor_data)
            self.critic_aac.process_summary(sess, critic_data)
            self.local_steps += 1

        except Exception as e:
            msg = 'process_test() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise e

    def process_train(self, sess):
        """
        Train procedure.

        Args:
            sess (tensorflow.Session):   tf session obj.

        """
        try:
            # Copy from parameter server:
            sess.run(self.critic_aac.sync_pi)
            sess.run(self.actor_aac.sync_pi)
            # self.log.warning('Train Sync ok.')

            # Get data configuration (redundant):
            actor_data_config = {
                'episode_config': {
                    'get_new': 1,
                    'sample_type': 0,
                    'b_alpha': 1.0,
                    'b_beta': 1.0
                },
                'trial_config': {
                    'get_new': 1,
                    'sample_type': 0,
                    'b_alpha': 1.0,
                    'b_beta': 1.0
                }
            }
            critic_data_config = {
                'episode_config': {
                    'get_new': 1,
                    'sample_type': 1,
                    'b_alpha': 1.0,
                    'b_beta': 1.0
                },
                'trial_config': {
                    'get_new': 1,
                    'sample_type': 1,
                    'b_alpha': 1.0,
                    'b_beta': 1.0
                }
            }

            # self.log.warning('actor_data_config: {}'.format(actor_data_config))
            # self.log.warning('critic_data_config: {}'.format(critic_data_config))

            # Collect synthetic train trajectory rollout:
            actor_data = self.actor_aac.get_data(
                data_sample_config=actor_data_config)
            actor_feed_dict = self.actor_aac.process_data(
                sess,
                actor_data,
                is_train=True,
                pi=self.actor_aac.local_network)

            # self.log.warning('Actor data ok.')

            write_model_summary = \
                self.local_steps % self.model_summary_freq == 0

            # Get gradients from actor:
            if write_model_summary:
                actor_fetches = [
                    self.actor_aac.grads, self.inc_step,
                    self.actor_aac.model_summary_op
                ]

            else:
                actor_fetches = [self.actor_aac.grads, self.inc_step]

            # self.log.warning('self.actor_aac.grads: \n{}'.format(self.actor_aac.grads))
            # self.log.warning('self.actor_aac.model_summary_op: \n{}'.format(self.actor_aac.model_summary_op))

            actor_fetched = sess.run(actor_fetches, feed_dict=actor_feed_dict)
            actor_grads_values = actor_fetched[0]
            # self.log.warning('Actor gradients ok.')

            # Start preparing gradients feeder:
            grads_feed_dict = {
                self.actor_aac.local_network.train_phase: True,
                self.critic_aac.local_network.train_phase: True,
            }
            grads_feed_dict.update({
                pl: value
                for pl, value in zip(self.actor_aac.grads_placeholders,
                                     actor_grads_values)
            })

            # Update critic with gradients collected from generated data:
            sess.run(self.update_critic_op, feed_dict=grads_feed_dict)
            # self.log.warning('Critic update ok.')

            # Collect real train trajectory rollout using updated critic policy:
            critic_data = self.critic_aac.get_data(
                data_sample_config=critic_data_config)
            critic_feed_dict = self.critic_aac.process_data(
                sess,
                critic_data,
                is_train=True,
                pi=self.critic_aac.local_network)
            # self.log.warning('Critic data ok.')

            # Get gradients from critic:
            if write_model_summary:
                critic_fetches = [
                    self.critic_aac.grads, self.critic_aac.model_summary_op
                ]

            else:
                critic_fetches = [self.critic_aac.grads]

            critic_fetched = sess.run(critic_fetches,
                                      feed_dict=critic_feed_dict)
            critic_grads_values = critic_fetched[0]
            # self.log.warning('Critic gradients ok.')

            # Update gradients feeder with critic's:
            grads_feed_dict.update({
                pl: value
                for pl, value in zip(self.critic_aac.grads_placeholders,
                                     critic_grads_values)
            })

            # Finally send combined gradients update to parameters server:
            sess.run([self.train_op], feed_dict=grads_feed_dict)
            # sess.run([self.actor_aac.train_op], feed_dict=actor_feed_dict)

            # self.log.warning('Final gradients ok.')

            if write_model_summary:
                critic_model_summary = critic_fetched[-1]
                actor_model_summary = actor_fetched[-1]

            else:
                critic_model_summary = None
                actor_model_summary = None

            # Write down summaries:
            self.actor_aac.process_summary(sess, actor_data,
                                           actor_model_summary)
            self.critic_aac.process_summary(sess, critic_data,
                                            critic_model_summary)
            self.local_steps += 1

        except Exception as e:
            msg = 'process_train() exception occurred' + \
                  '\n\nPress `Ctrl-C` or jupyter:[Kernel]->[Interrupt] for clean exit.\n'
            self.log.exception(msg)
            raise e
Example #31
0
import configparser
import os
import sys

# initialise the logger
from logbook import Logger
log = Logger("CoreConfig")

# try loading in the configuration
CONFIG_PATH = os.path.expanduser("~/.propagator/propagator.cfg")
try:
    CONFIG_CFGP = configparser.ConfigParser()
    CONFIG_CFGP.read(CONFIG_PATH)
except:
    log.exception()
    sys.exit(1)

# exported log sections
try:
    config_general = CONFIG_CFGP["general"]
except KeyError:
    config_general = {}

try:
    config_smtp = CONFIG_CFGP["smtp"]
except KeyError:
    config_smtp = {}

try:
    config_amqp = CONFIG_CFGP["amqp"]
except KeyError:
    config_amqp = {}
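
# A sketch of reading values from the sections loaded above; the key names and
# the defaults are assumptions for illustration only.
smtp_host = config_smtp.get("host", "localhost")
smtp_port = int(config_smtp.get("port", 25))
log.info("SMTP target: {}:{}", smtp_host, smtp_port)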
Example #32
0
    _thread = QThread()
    _vpl = VNGameProcessListener()

    _vpl.vngame_exe_finished.connect(_app.exit)
    _thread.started.connect(_vpl.listen)

    _vpl.moveToThread(_thread)
    _thread.start()

    _gui.show()
    logger.info("GUI init done")

    try:
        logger.info("cwd='{cwd}'", cwd=os.getcwd())
        main()
    except Exception as _e:
        _extra = (f"\r\nCheck '{SCRIPT_LOG_PATH}' for more information."
                  if SCRIPT_LOG_PATH.exists() else "")
        _msg = (f"Error launching Winter War!\r\n\r\n"
                f"{type(_e).__name__}: {_e}\r\n"
                f"{_extra}\r\n")
        _gui.warn("Error", _msg)
        # noinspection PyBroadException
        try:
            logger.exception("error running script")
            _vpl.stop()
        except Exception:
            pass

    sys.exit(_app.exec_())
Example #33
0
class AdLoader(object):
    def __init__(self,
                 index,
                 hosts=None,
                 distance_cutoff=0.38,
                 logger=None,
                 exceptions_to_reraise=None):
        self._iss = ImageSignatureService()
        self._es = Elasticsearch(hosts=hosts)
        self._aes = AdES(self._es,
                         index=index,
                         distance_cutoff=distance_cutoff)

        if not logger:
            self.logger = Logger(self.__class__.__name__)

        else:
            self.logger = logger

        if not exceptions_to_reraise:
            self.exceptions_to_reraise = tuple()

        else:
            self.exceptions_to_reraise = tuple(exceptions_to_reraise)

        # Ensure the index to be used exists.
        self.create_index()

        self.num_images_inserted = 0
        self.num_images_updated = 0
        self.num_images_errored = 0

    def create_index(self):
        self._es.indices.create(self._aes.index, ignore=400)

    def delete_index(self):
        self._es.indices.delete(index=self._aes.index)

    def wipe_index(self):
        self.delete_index()
        self.create_index()

    def refresh_index(self):
        self._es.indices.refresh(index=self._aes.index)
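
# A usage sketch for the loader above; the index name, hosts and all image /
# source values are illustrative assumptions.
loader = AdLoader(index='ads', hosts=['localhost:9200'])
loader.add_image_url(
    image_url='http://example.com/banner.png',
    source_url='http://example.com/landing-page',
    email='user@example.com',
    age=30,
    gender='f',
    interests=['sports'],
)
print(loader.num_images_inserted, loader.num_images_updated,
      loader.num_images_errored)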

    def _add_image_to_index(self, image_signature, image, image_url,
                            source_url, email, age, gender, interests):
        source = {
            SOURCE_URL_KEY: source_url,
            SOURCE_DOMAIN_KEY: furl(source_url).netloc,
            SOURCE_EMAIL_KEY: email,
            SOURCE_AGE_KEY: age,
            SOURCE_GENDER_KEY: gender,
            SOURCE_INTERESTS_KEY: interests
        }
        _id = hashlib.sha512(image_signature.encode('utf-8')).hexdigest()
        existing_document = None
        try:
            get_result = self._es.get(index=self._aes.index,
                                      doc_type=self._aes.doc_type,
                                      id=_id)
            existing_document = get_result['_source']

        except NotFoundError:
            pass

        if existing_document:
            existing_sources = existing_document['metadata'][
                METADATA_SOURCES_KEY]
            existing_sources.append(source)
            self._es.update(index=self._aes.index,
                            doc_type=self._aes.doc_type,
                            id=_id,
                            body={'doc': existing_document})
            self.num_images_updated += 1

        else:
            metadata = {
                METADATA_SOURCES_KEY: [source],
                METADATA_IMAGE_KEY: image
            }
            self._aes.add_image_signature_base64(_id,
                                                 image_signature,
                                                 path=image_url,
                                                 metadata=metadata)
            self.num_images_inserted += 1

    def _add_image(self,
                   image_signature,
                   image,
                   image_url,
                   source_url,
                   email,
                   age,
                   gender,
                   interests,
                   retry_num=0):
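        # Elasticsearch raises ConflictError when concurrent writers race on
        # the same document; retry the add a bounded number of times
        # (MAX_CONFLICT_RETRIES) before counting the image as errored.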
        try:
            self._add_image_to_index(image_signature, image, image_url,
                                     source_url, email, age, gender, interests)

        except ConflictError:
            if retry_num >= MAX_CONFLICT_RETRIES:
                self.logger.exception(
                    "Giving up on {} after {} conflict retries",
                    image_url, MAX_CONFLICT_RETRIES)
                self.num_images_errored += 1

            else:
                self.logger.warning(
                    ADD_IMAGE_URL_CONFLICT_MSG_FORMAT.format(image_url))
                retry_num += 1
                self._add_image(image_signature,
                                image,
                                image_url,
                                source_url,
                                email,
                                age,
                                gender,
                                interests,
                                retry_num=retry_num)

    def add_image_url(self, image_url, source_url, email, age, gender,
                      interests):
        try:
            self.logger.debug(ADD_IMAGE_URL_MSG_FORMAT.format(image_url))
            image_signature, image = self._iss.get_image_signature_from_url(
                image_url)
            if image_signature:
                self._add_image(image_signature, image, image_url, source_url,
                                email, age, gender, interests)

            else:
                self.logger.warning(
                    ADD_IMAGE_URL_NO_SIGNATURE_MSG_FORMAT.format(image_url))

        except self.exceptions_to_reraise:
            raise

        except Exception:
            self.logger.exception("Failed to add image from {}", image_url)
            self.num_images_errored += 1

    def add_image_bytes(self, image_bytes, source_url, email, age, gender,
                        interests):
        try:
            self.logger.debug(ADD_IMAGE_URL_MSG_FORMAT.format('from bytes'))
            image_signature, image = self._iss.get_image_signature_from_bytes(
                image_bytes)
            if image_signature:
                self._add_image(image_signature, image, image_signature,
                                source_url, email, age, gender, interests)

            else:
                self.logger.warning(
                    ADD_IMAGE_URL_NO_SIGNATURE_MSG_FORMAT.format('from bytes'))

        except self.exceptions_to_reraise:
            raise

        except Exception:
            self.logger.exception("Failed to add image from bytes")
            self.num_images_errored += 1

    def get_image_match_by_image_url(self, image_url):
        return self._aes.search_image(image_url)

    def get_image_match_by_image_signature_base64(self, image_signature):
        return self._aes.search_image_signature_base64(image_signature)
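
# A minimal usage sketch, assuming a reachable Elasticsearch node and the
# project-specific ImageSignatureService/AdES dependencies; the index name,
# URLs, and profile values below are illustrative only.
loader = AdLoader("ad-images", hosts=["localhost:9200"],
                  exceptions_to_reraise=[KeyboardInterrupt])
loader.add_image_url("http://example.com/ad.jpg",
                     "http://example.com/landing-page",
                     email="user@example.com", age=30, gender="f",
                     interests=["sports"])
loader.refresh_index()
matches = loader.get_image_match_by_image_url("http://example.com/ad.jpg")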