def test_disconnect(self, monkeypatch):
    """After connect() followed by disconnect(), the proxy's `skt` is None."""
    fake_socket = MockSocket.create()
    # Replace socket.socket with a zero-argument factory returning the mock.
    monkeypatch.setattr(socket, 'socket', lambda: fake_socket)

    proxy = AgentProxy('localhost:7000')
    proxy.connect()
    proxy.disconnect()

    assert proxy.skt is None
def test_exchange(self, monkeypatch):
    """_exchange() on an 'action' message must return its 'data' list.

    NOTE(review): this relies on the mock socket handing the sent message
    back as the reply -- confirm against MockSocket.
    """
    fake_socket = MockSocket.create()
    # Replace socket.socket with a zero-argument factory returning the mock.
    monkeypatch.setattr(socket, 'socket', lambda: fake_socket)

    proxy = AgentProxy('localhost:7000')
    proxy.connect()

    message = {'response': 'action', 'data': [1, 2, 3]}
    assert proxy._exchange(message) == [1, 2, 3]
def test_wrong_protocol_response(self, monkeypatch):
    """_exchange() must raise AgentProxyException for {'wrong': 'response'}."""
    fake_socket = MockSocket.create()
    # Replace socket.socket with a zero-argument factory returning the mock.
    monkeypatch.setattr(socket, 'socket', lambda: fake_socket)

    proxy = AgentProxy('localhost:7000')
    proxy.connect()

    try:
        proxy._exchange({'wrong': 'response'})
    except AgentProxyException as err:
        # The text is compared verbatim, spelling included.
        assert str(err) == 'wring message format'
    else:
        # No exception at all is a failure of this test.
        assert False
class TrainingBase(object):
    """Base class driving an episode loop against an agent reached via AgentProxy.

    Subclasses override `episode`; `run` handles connecting, looping,
    reconnecting after an AgentProxyException, and disconnecting.
    """

    def __init__(self):
        """Read the run settings from `options` and create the agent proxy."""
        self.exploit = options.get('exploit', False)
        self.max_episodes = options.get('environment/max_episodes', 1)
        self.infinite_run = options.get('environment/infinite_run', False)

        # Prefer the explicit server address; otherwise fall back to the
        # server's bind address, and finally to localhost:7001.
        rlx_address = options.get('rlx_server_address', None)
        if rlx_address is None:
            rlx_address = options.get('relaax_rlx_server/bind', 'localhost:7001')
        self.agent = AgentProxy(rlx_address)

    def initialize_agent(self, retry=6):
        """Connect to the server and initialize the agent.

        Args:
            retry: value forwarded to `self.agent.connect`.
        """
        # connect to the server
        self.agent.connect(retry)
        # give agent a moment to load and initialize
        self.agent.init(self.exploit)

    def run(self):
        """Run episodes until `max_episodes` is reached, or forever if `infinite_run`.

        Every non-None value returned by `episode` is reported as the
        'episode_reward' scalar metric. When an AgentProxyException is raised,
        the agent is re-initialized with retry=10 and the same episode number
        is attempted again. Any other Exception is logged and ends the run.
        The agent is always disconnected on the way out.
        """
        try:
            self.initialize_agent()
            number = 0
            while (number < self.max_episodes) or self.infinite_run:
                try:
                    episode_reward = self.episode(number)
                    if episode_reward is not None:
                        self.agent.metrics.scalar('episode_reward', episode_reward)
                    number += 1
                except AgentProxyException as e:
                    log.error('Agent connection lost: %s' % str(e))
                    log.error(
                        'Reconnecting to another Agent, retrying to connect 10 times...'
                    )
                    try:
                        self.initialize_agent(retry=10)
                        continue
                    except Exception:
                        # Only ordinary errors are converted; KeyboardInterrupt
                        # and SystemExit must propagate instead of being logged
                        # and swallowed by the handler below.
                        raise Exception('Can\'t reconnect, exiting...')
        except Exception as e:
            log.error("Error while running agent: %s" % str(e))
            log.debug(traceback.format_exc())
        finally:
            # disconnect from the server
            self.agent.disconnect()

    def episode(self, number):
        """Run one episode; meant to be overridden by subclasses.

        Args:
            number: zero-based index of the episode, as counted by `run`.

        Returns:
            The episode reward to report, or None to report nothing. This
            base implementation does nothing and returns None.
        """
        pass
def test_error_response(self, monkeypatch):
    """An 'error' response must surface as AgentProxyException.

    Without a 'message' key the exception text is 'unknown error';
    with one, the text is that message.
    """
    fake_socket = MockSocket.create()
    # Replace socket.socket with a zero-argument factory returning the mock.
    monkeypatch.setattr(socket, 'socket', lambda: fake_socket)

    proxy = AgentProxy('localhost:7000')
    proxy.connect()

    # Error response with no message attached.
    try:
        proxy._exchange({'response': 'error'})
    except AgentProxyException as err:
        assert str(err) == 'unknown error'
    else:
        assert False

    # Error response carrying an explicit message.
    try:
        proxy._exchange({'response': 'error', 'message': 'some error'})
    except AgentProxyException as err:
        assert str(err) == 'some error'
    else:
        assert False
def test_connect_with_address_as_tuple(self, monkeypatch):
    """connect() must work when the address is a (host, port) tuple."""
    fake_socket = MockSocket.create()
    # Replace socket.socket with a zero-argument factory returning the mock.
    monkeypatch.setattr(socket, 'socket', lambda: fake_socket)

    address = ('localhost', 7000)
    proxy = AgentProxy(address)
    proxy.connect()

    # After connecting, the proxy's socket compares equal to the mock.
    assert proxy.skt == fake_socket