コード例 #1
0
    def __init__(self,
                 name,
                 address,
                 label,
                 password,
                 env_id=None,
                 seed=None,
                 fps=60,
                 start_timeout=None,
                 observer=False,
                 skip_network_calibration=False):
        """Store connection settings and wire a rewarder client to a factory.

        Args:
            name: Identifier for this connection (stored only).
            address: Rewarder address; prefixed with ``ws://`` to build the
                WebSocket URL (presumably ``host:port`` -- confirm).
            label: Logging label; passed to the reward buffer and set on the
                WebSocket factory.
            password: Stored only; this constructor does not use it.
            env_id: Optional environment id (stored only).
            seed: Optional seed (stored only).
            fps: Frame rate setting (stored only).
            start_timeout: Stored only.
            observer: Stored only.
            skip_network_calibration: Stored only.
        """
        self.name = name
        self.address = address
        self.label = label
        self.password = password
        self.env_id = env_id
        # Previously hard-coded to None, which silently discarded the
        # caller's seed.
        self.seed = seed
        self.fps = fps
        self.start_timeout = start_timeout
        self.observer = observer
        self.skip_network_calibration = skip_network_calibration

        self.reward_buffer = reward_buffer.RewardBuffer(label=self.label)

        # The factory carries the shared reward buffer and label; the client
        # presumably reads them through its ``factory`` attribute.
        factory = websocket.WebSocketClientFactory('ws://' + address)
        factory.reward_buffer = self.reward_buffer
        factory.label = self.label
        self.rewarder_client = rewarder_client.RewarderClient()
        self.rewarder_client.factory = factory
        # Invoke the connection hook directly, without a real network
        # connection. NOTE(review): confirm RewarderClient.onConnect tolerates
        # a None argument.
        self.rewarder_client.onConnect(None)
コード例 #2
0
    def connect_upstream(self, tries=1, max_attempts=7,
                         remote='localhost:15900'):
        """Connect this proxy to the upstream rewarder over WebSocket.

        The client's ``authorization`` header is forwarded upstream, as is its
        ``openai-observer`` header when set. A failed attempt is retried
        after ``1.5 ** tries`` seconds until ``max_attempts`` attempts have
        been made. After that, the client connection is dropped.

        Args:
            tries: 1-based index of the current attempt.
            max_attempts: Maximum number of attempts before giving up.
            remote: ``host:port`` of the upstream rewarder. Defaults to the
                address that was previously hard-coded here.
        """
        if self._closed:
            logger.info(
                "[RewardProxyServer] [%d] Attempted to connect upstream although client connection is already closed. Aborting",
                self.id)
            return

        endpoint = endpoints.clientFromString(reactor, 'tcp:' + remote)
        client_factory = websocket.WebSocketClientFactory('ws://' + remote)
        headers = {'authorization': self._request.headers['authorization']}
        observer = self._request.headers.get('openai-observer')
        if observer:
            headers['openai-observer'] = observer
        client_factory.headers = headers
        client_factory.protocol = RewardServerClient
        client_factory.proxy_server = self
        client_factory.endpoint = endpoint

        logger.info(
            "[RewardProxyServer] [%d] Connecting to upstream %s (try %d/%d)",
            self.id, remote, tries, max_attempts)

        def _connect_callback(client):
            if self._closed:
                # The downstream client went away while we were connecting;
                # don't keep an orphaned upstream connection open.
                logger.info(
                    '[RewardProxyServer] [%d] Upstream connection %s established after client connection closed; dropping it',
                    self.id, remote)
                client.transport.loseConnection()
                return
            logger.info(
                '[RewardProxyServer] [%d] Upstream connection %s established',
                self.id, remote)
            self.client = client
            if self.factory.logfile_dir:
                self.begin_recording()

        def _connect_errback(reason):
            if tries < max_attempts:
                # Somewhat arbitrary exponential backoff: should be
                # pretty rare, and indicate that we're just starting
                # up.
                delay = 1.5**tries
                logger.info(
                    '[RewardProxyServer] [%d] Connection to %s failed: %s. Try %d/%d; going to retry in %fs',
                    self.id, remote, reason, tries, max_attempts, delay)
                reactor.callLater(delay,
                                  self.connect_upstream,
                                  tries=tries + 1,
                                  max_attempts=max_attempts,
                                  remote=remote)
            else:
                logger.error(
                    '[RewardProxyServer] [%d] Connection to %s failed: %s. Completed %d/%d attempts; disconnecting.',
                    self.id, remote, reason, tries, max_attempts)
                self.transport.loseConnection()

        endpoint.connect(client_factory).addCallbacks(_connect_callback,
                                                      _connect_errback)
コード例 #3
0
    def _connect(
        self,
        name,
        address,
        env_id,
        seed,
        fps,
        i,
        network,
        env_status,
        reward_buffer,
        label,
        password,
        start_timeout,
        observer,
        skip_network_calibration,
        attempt=0,
        elapsed_sleep_time=0,
    ):
        """Open rewarder connection ``i`` to ``address`` and register its client.

        The connection comes up in stages:

        1. TCP connect.
        2. WebSocket handshake.
        3. Optional network calibration.
        4. When ``env_id`` is set, an initial ``env.reset``.

        If one of the first three stages fails, the whole attempt is
        rescheduled after ``min(2 * attempt + 1, 10)`` seconds while
        ``elapsed_sleep_time < start_timeout``. Once that budget is spent, or
        when a later step fails, the error is stored in ``self.errors[i]``.
        On success the client is stored in ``self.clients[i]``, unless
        connection ``i`` was closed in the meantime.

        NOTE(review): the body ``yield``s Deferreds, so it only runs when
        wrapped with ``defer.inlineCallbacks`` (the decorator is outside this
        excerpt). Confirm it is applied; the ``reactor.callLater`` retry
        relies on it too.
        """
        endpoint = endpoints.clientFromString(reactor, 'tcp:' + address)
        factory = websocket.WebSocketClientFactory('ws://' + address)
        factory.protocol = rewarder_client.RewarderClient

        assert password, "Missing password: {} for rewarder session".format(
            password)
        factory.headers = {
            'authorization': utils.basic_auth_encode(password),
            'openai-observer': 'true' if observer else 'false'
        }
        factory.i = i

        # Various important objects
        factory.endpoint = endpoint
        factory.env_status = env_status
        factory.reward_buffer = reward_buffer

        # Helpful strings
        factory.label = label
        factory.address = address

        # Arguments to always send to the remote reset call
        factory.arg_env_id = env_id
        factory.arg_fps = fps

        def record_error(e):
            """Store ``e`` as this connection's fatal error, unless it is closed."""
            if isinstance(e, failure.Failure):
                e = e.value

            with self.lock:
                # drop error on the floor if we're already closed
                if self._already_closed(factory.i):
                    extra_logger.info(
                        '[%s] Ignoring error for already closed connection: %s',
                        label, e)
                elif factory.i not in self.clients:
                    extra_logger.info(
                        '[%s] Received error for connection which has not been fully initialized: %s',
                        label, e)
                    # We could handle this better, but right now we
                    # just mark this as a fatal error for the
                    # backend. Often it actually is.
                    self.errors[factory.i] = e
                else:
                    extra_logger.info(
                        '[%s] Recording fatal error for connection: %s', label,
                        e)
                    self.errors[factory.i] = e

        def retriable_error(e, error_message):
            """Reschedule this connection attempt, or record ``e`` once out of time."""
            if isinstance(e, failure.Failure):
                e = e.value

            # Read the closed state under the lock, as record_error does.
            with self.lock:
                already_closed = self._already_closed(factory.i)
            if already_closed:
                logger.error(
                    '[%s] Got error, but giving up on reconnecting, since %d already disconnected',
                    factory.label, factory.i)
                return

            # Also need to handle DNS errors, so let's just handle everything for now.
            #
            # reason.trap(twisted.internet.error.ConnectError, error.ConnectionError)
            # NOTE(review): assumes start_timeout is a number; None would
            # raise TypeError in this comparison on Python 3.
            if elapsed_sleep_time < start_timeout:
                # Linear backoff: 1s, 3s, 5s, ... capped at 10s.
                sleep = min((2 * attempt + 1), 10)
                logger.error(
                    '[%s] Waiting on rewarder: %s. Retry in %ds (slept %ds/%ds): %s',
                    factory.label, error_message, sleep, elapsed_sleep_time,
                    start_timeout, e)
                reactor.callLater(
                    sleep,
                    self._connect,
                    name=name,
                    address=address,
                    env_id=env_id,
                    seed=seed,
                    fps=fps,
                    i=i,
                    network=network,
                    env_status=env_status,
                    reward_buffer=reward_buffer,
                    label=label,
                    attempt=attempt + 1,
                    elapsed_sleep_time=elapsed_sleep_time + sleep,
                    start_timeout=start_timeout,
                    password=password,
                    observer=observer,
                    skip_network_calibration=skip_network_calibration,
                )
            else:
                logger.error('[%s] %s. Retries exceeded (slept %ds/%ds): %s',
                             factory.label, error_message, elapsed_sleep_time,
                             start_timeout, e)
                record_error(e)

        factory.record_error = record_error

        try:
            # retry_msg names the stage in progress; a non-empty value marks
            # the failure as retriable.
            retry_msg = 'establish rewarder TCP connection'
            client = yield endpoint.connect(factory)
            extra_logger.info('[%s] Rewarder TCP connection established',
                              factory.label)

            retry_msg = 'complete WebSocket handshake'
            yield client.waitForWebsocketConnection()
            extra_logger.info('[%s] Websocket client successfully connected',
                              factory.label)

            if not skip_network_calibration:
                retry_msg = 'run network calibration'
                yield network.calibrate(client)
                extra_logger.info('[%s] Network calibration complete',
                                  factory.label)

            # From here on, failures are recorded rather than retried.
            retry_msg = ''

            if factory.arg_env_id is not None:
                # We aren't picky about episode ID: we may have
                # already received an env.describe message
                # telling us about a resetting environment, which
                # we don't need to bump post.
                #
                # tl;dr hardcoding 0.0 here avoids a double reset.
                reply = yield self._send_env_reset(client,
                                                   seed=seed,
                                                   episode_id='0')
            else:
                # No env_id requested, so we just proceed without a reset
                reply = None
            # We're connected and have measured the
            # network. Mark everything as ready to go.
            with self.lock:
                if factory.i not in self.names_by_id:
                    # ID has been popped!
                    logger.info(
                        '[%s] Rewarder %d started, but has already been closed',
                        factory.label, factory.i)
                    client.close()
                else:
                    if reply is None:
                        logger.info(
                            '[%s] Attached to running environment without reset',
                            factory.label)
                    else:
                        _, _, rep = reply
                        logger.info(
                            '[%s] Initial reset complete: episode_id=%s',
                            factory.label, rep['headers']['episode_id'])
                    # Register only while the connection is still open;
                    # storing the just-closed client would leave a stale
                    # entry in self.clients.
                    self.clients[factory.i] = client
        except Exception as e:
            if retry_msg:
                retriable_error(e, 'failed to ' + retry_msg)
            else:
                record_error(e)
コード例 #4
0
ファイル: rewarder_session.py プロジェクト: zhp0260/universe
    def _connect(
        self,
        name,
        address,
        env_id,
        seed,
        fps,
        i,
        network,
        env_status,
        reward_buffer,
        label,
        password,
        start_timeout,
        observer,
        skip_network_calibration,
        attempt=0,
        elapsed_sleep_time=0,
    ):
        """Open rewarder connection ``i`` to ``address`` using Deferred callbacks.

        The connection comes up in stages:

        1. TCP connect (``endpoint.connect``).
        2. WebSocket handshake, signalled through ``factory.deferred``
           (presumably fired by the client protocol).
        3. Optional network calibration.
        4. When ``env_id`` is set, an initial ``env.reset``.

        A failure in the first three stages reschedules the whole attempt
        after ``min(2 * attempt + 1, 10)`` seconds while
        ``elapsed_sleep_time < start_timeout``. Otherwise the error is stored
        in ``self.errors[i]``. On success the client is stored in
        ``self.clients[i]``, unless connection ``i`` was closed in the
        meantime.
        """
        endpoint = endpoints.clientFromString(reactor, 'tcp:' + address)
        factory = websocket.WebSocketClientFactory('ws://' + address)
        factory.protocol = rewarder_client.RewarderClient

        assert password, "Missing password: {} for rewarder session".format(
            password)
        factory.headers = {
            'authorization': utils.basic_auth_encode(password),
            'openai-observer': 'true' if observer else 'false'
        }
        factory.i = i

        # Various important objects
        factory.endpoint = endpoint
        factory.env_status = env_status
        factory.reward_buffer = reward_buffer

        # Helpful strings
        factory.label = label
        factory.address = address

        # Arguments to always send to the remote reset call
        factory.arg_env_id = env_id
        factory.arg_fps = fps

        def record_error(e):
            """Store ``e`` as this connection's fatal error, unless it is closed."""
            if isinstance(e, failure.Failure):
                e = e.value

            with self.lock:
                # drop error on the floor if we're already closed
                if self._already_closed(factory.i):
                    extra_logger.info(
                        '[%s] Ignoring error for already closed connection: %s',
                        label, e)
                else:
                    extra_logger.info(
                        '[%s] Recording fatal error for connection: %s', label,
                        e)
                    self.errors[factory.i] = e

        def websocket_failed(e, error_message):
            """Reschedule this connection attempt, or record ``e`` once out of time.

            ``error_message`` names the stage that failed. Each errback passes
            its own message as an argument. A shared function attribute
            would be overwritten by the last assignment before the failure
            fires.
            """
            if isinstance(e, failure.Failure):
                e = e.value

            # Read the closed state under the lock, as record_error does.
            with self.lock:
                already_closed = self._already_closed(factory.i)
            if already_closed:
                logger.error(
                    '[%s] Giving up on reconnecting, since %d already disconnected',
                    factory.label, factory.i)
                return

            # Also need to handle DNS errors, so let's just handle everything for now.
            #
            # reason.trap(twisted.internet.error.ConnectError, error.ConnectionError)
            # NOTE(review): assumes start_timeout is a number; None would
            # raise TypeError in this comparison on Python 3.
            if elapsed_sleep_time < start_timeout:
                # Linear backoff: 1s, 3s, 5s, ... capped at 10s.
                sleep = min((2 * attempt + 1), 10)
                logger.error(
                    '[%s] Waiting on rewarder: %s. Retry in %ds (slept %ds/%ds): %s',
                    factory.label, error_message, sleep,
                    elapsed_sleep_time, start_timeout, e)
                reactor.callLater(
                    sleep,
                    self._connect,
                    name=name,
                    address=address,
                    env_id=env_id,
                    seed=seed,
                    fps=fps,
                    i=i,
                    network=network,
                    env_status=env_status,
                    reward_buffer=reward_buffer,
                    label=label,
                    attempt=attempt + 1,
                    elapsed_sleep_time=elapsed_sleep_time + sleep,
                    start_timeout=start_timeout,
                    password=password,
                    observer=observer,
                    skip_network_calibration=skip_network_calibration,
                )
            else:
                logger.error('[%s] %s. Retries exceeded (slept %ds/%ds): %s',
                             factory.label, error_message,
                             elapsed_sleep_time, start_timeout, e)
                record_error(e)

        def retriable_record_error(e):
            """Record an error reported for this connection.

            Every error is passed to record_error. Errors that arrive for an
            open connection not yet registered in ``self.clients`` are also
            logged as such.
            """
            if isinstance(e, failure.Failure):
                e = e.value

            with self.lock:
                if (factory.i in self.names_by_id
                        and factory.i not in self.clients):
                    extra_logger.info(
                        '[%s] Received error for connection which has not been fully initialized: %s',
                        label, e)
                # We could handle this better, but right now every error is
                # marked as fatal for the backend. Often it actually is.
                # record_error re-acquires self.lock, so the lock must be
                # reentrant.
                record_error(e)

        factory.record_error = retriable_record_error

        def fail(reason):
            factory.record_error(reason)

        def connected(client):
            extra_logger.info('[%s] Websocket client successfully connected',
                              factory.label)

            def reset_success(reply):
                # We're connected and have measured the
                # network. Mark everything as ready to go.
                with self.lock:
                    if factory.i not in self.names_by_id:
                        # ID has been popped!
                        logger.info(
                            '[%s] Rewarder %d started, but has already been closed',
                            factory.label, factory.i)
                        client.close()
                    else:
                        if reply is None:
                            logger.info(
                                '[%s] Attached to running environment without reset',
                                factory.label)
                        else:
                            _, _, rep = reply
                            logger.info(
                                '[%s] Initial reset complete: episode_id=%s',
                                factory.label, rep['headers']['episode_id'])
                        # Register only while the connection is still open;
                        # storing the just-closed client would leave a stale
                        # entry in self.clients.
                        self.clients[factory.i] = client

            def calibrate_success(_result):
                extra_logger.info('[%s] Network calibration complete',
                                  factory.label)

                if factory.arg_env_id is not None:
                    # We aren't picky about episode ID: we may have
                    # already received an env.describe message
                    # telling us about a resetting environment, which
                    # we don't need to bump post.
                    #
                    # tl;dr hardcoding 0.0 here avoids a double reset.
                    reset_d = self._send_env_reset(client, seed=seed,
                                                   episode_id='0')
                    reset_d.addCallback(reset_success)
                    reset_d.addErrback(fail)
                else:
                    # No env_id requested, so we just proceed without a reset
                    reset_success(None)

            if skip_network_calibration:
                calibrate_success(network)
            else:
                calibration_d = network.calibrate(client)
                calibration_d.addCallback(calibrate_success)
                calibration_d.addErrback(
                    websocket_failed,
                    'WebSocket handshake established but calibration failed')
                calibration_d.addErrback(fail)

        handshake_d = defer.Deferred()
        handshake_d.addCallback(connected)
        handshake_d.addErrback(
            websocket_failed,
            'TCP connection established but WebSocket handshake failed')
        handshake_d.addErrback(fail)
        factory.deferred = handshake_d

        def connection_succeeded(conn):
            extra_logger.info('[%s] Rewarder TCP connection established',
                              factory.label)

        def connection_failed(reason):
            # Only reached if websocket_failed itself raised. Route the
            # failure into the handshake chain; an AlreadyCalledError
            # propagates as before.
            reason = error.Error('[{}] Connection failed: {}'.format(
                factory.label, reason.value))
            handshake_d.errback(utils.format_error(reason))

        res = endpoint.connect(factory)
        res.addCallback(connection_succeeded)
        res.addErrback(websocket_failed,
                       'Could not establish rewarder TCP connection')
        res.addErrback(connection_failed)