def test_message_string():
    """Round-trip a string through a message, then check default handling."""
    expected = "mlagents!"
    outgoing = OutgoingMessage()
    outgoing.write_string(expected)
    incoming = IncomingMessage(outgoing.buffer)
    assert incoming.read_string() == expected
    # The only written value has been consumed; further reads are expected to
    # return the built-in default ("") or the caller-supplied default.
    assert incoming.read_string() == ""
    assert incoming.read_string(default_value=expected) == expected
def test_message_int32():
    """Round-trip an int32 through a message, then check default handling."""
    expected = 1337
    outgoing = OutgoingMessage()
    outgoing.write_int32(expected)
    incoming = IncomingMessage(outgoing.buffer)
    assert incoming.read_int32() == expected
    # The only written value has been consumed; further reads are expected to
    # return the built-in default (0) or the caller-supplied default.
    assert incoming.read_int32() == 0
    assert incoming.read_int32(default_value=expected) == expected
def test_message_float_list():
    """Round-trip a list of floats through a message, then check defaults."""
    expected = [1.0, 3.0, 9.0]
    outgoing = OutgoingMessage()
    outgoing.write_float32_list(expected)
    incoming = IncomingMessage(outgoing.buffer)
    # Python floats are 64-bit but are written as float32, so a round trip is
    # only exact for values float32 can represent exactly. Every value in
    # `expected` is such a value, which is why plain equality is safe here.
    assert incoming.read_float32_list() == expected
    # Nothing is left to read: expect the built-in default ([]) first, then
    # the caller-supplied default.
    assert incoming.read_float32_list() == []
    assert incoming.read_float32_list(default_value=expected) == expected
def test_message_bool():
    """Round-trip several bools through a message, then check defaults."""
    expected = [True, False]
    outgoing = OutgoingMessage()
    for flag in expected:
        outgoing.write_bool(flag)
    incoming = IncomingMessage(outgoing.buffer)
    # Read back exactly as many values as were written, in order.
    assert [incoming.read_bool() for _ in expected] == expected
    # All written values are consumed; reads should now yield defaults.
    assert incoming.read_bool() is False
    assert incoming.read_bool(default_value=True) is True
def process_side_channel_message(self, data: bytes) -> None:
    """
    Separates the data received from Unity into individual messages for each
    registered side channel and calls on_message_received on them.

    ``data`` is a concatenation of records. Each record is laid out as:

    * 16 bytes: the channel id, decoded with ``uuid.UUID(bytes_le=...)``;
    * 4 bytes: the payload length, a little-endian signed int (``"<i"``);
    * ``length`` bytes: the payload, wrapped in an ``IncomingMessage``.

    Records for channel ids not present in ``self._side_channels_dict`` are
    skipped with a warning.

    :param data: The packed message sent by Unity
    :raises UnityEnvironmentException: if a record header cannot be decoded,
        or if a record's payload is shorter than its declared length.
    """
    offset = 0
    while offset < len(data):
        try:
            # uuid.UUID raises ValueError if fewer than 16 bytes remain.
            channel_id = uuid.UUID(bytes_le=bytes(data[offset:offset + 16]))
            offset += 16
            # unpack_from raises struct.error if fewer than 4 bytes remain.
            message_len, = struct.unpack_from("<i", data, offset)
            offset += 4
            message_data = data[offset:offset + message_len]
            offset += message_len
        except (struct.error, ValueError, IndexError) as err:
            # Chain the low-level decoding error so the root cause stays
            # visible in the traceback.
            raise UnityEnvironmentException(
                "There was a problem reading a message in a SideChannel. "
                "Please make sure the version of MLAgents in Unity is "
                "compatible with the Python version.") from err
        # Slicing never raises on a truncated buffer; it just returns fewer
        # bytes, so the length has to be checked explicitly.
        if len(message_data) != message_len:
            raise UnityEnvironmentException(
                "The message received by the side channel {0} was "
                "unexpectedly short. Make sure your Unity Environment "
                "sending side channel data properly.".format(channel_id))
        if channel_id in self._side_channels_dict:
            incoming_message = IncomingMessage(message_data)
            self._side_channels_dict[channel_id].on_message_received(
                incoming_message)
        else:
            get_logger(__name__).warning(
                f"Unknown side channel data received. Channel type: {channel_id}."
            )
def on_message_received(self, msg: IncomingMessage) -> None:
    """
    Is called by the environment to the side channel. Can be called
    multiple times per step if multiple messages are meant for that
    SideChannel.

    Stores the raw bytes of ``msg`` by appending them to
    ``self._received_messages``.
    """
    raw_payload = msg.get_raw_bytes()
    self._received_messages.append(raw_payload)
def on_message_received(self, msg: IncomingMessage) -> None:
    """
    Is called by the environment to the side channel. Can be called
    multiple times per step if multiple messages are meant for that
    SideChannel.

    Decodes the raw bytes of ``msg`` into a ``(key, value)`` pair with
    ``self.deserialize_int_list_prop`` and stores the value under that key in
    ``self._int_list_properties``, overwriting any previous entry.
    """
    payload = msg.get_raw_bytes()
    # Presumably `value` is a list of ints, going by the helper's name;
    # verify against deserialize_int_list_prop.
    key, value = self.deserialize_int_list_prop(payload)
    self._int_list_properties[key] = value
def test_environment_parameters():
    """
    Check the wire format of EnvironmentParametersChannel float parameters,
    that each set_float_parameter call produces one message, and that
    registering the sender as the handler for its own data raises.
    """
    sender = EnvironmentParametersChannel()
    # We use a raw bytes channel to interpret the data
    receiver = RawBytesChannel(sender.channel_id)

    sender.set_float_parameter("param-1", 0.1)
    data = UnityEnvironment._generate_side_channel_data(
        {sender.channel_id: sender})
    UnityEnvironment._parse_side_channel_message(
        {receiver.channel_id: receiver}, data)
    # One message: key string, data-type int32, then the float32 value.
    message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
    key = message.read_string()
    dtype = message.read_int32()
    value = message.read_float32()
    assert key == "param-1"
    assert dtype == EnvironmentParametersChannel.EnvironmentDataTypes.FLOAT
    # float32 cannot hold 0.1 exactly, so compare with a tolerance. abs()
    # makes the check two-sided; without it any value below 0.1 would pass.
    assert abs(value - 0.1) < 1e-8

    sender.set_float_parameter("param-1", 0.1)
    sender.set_float_parameter("param-2", 0.1)
    sender.set_float_parameter("param-3", 0.1)
    data = UnityEnvironment._generate_side_channel_data(
        {sender.channel_id: sender})
    UnityEnvironment._parse_side_channel_message(
        {receiver.channel_id: receiver}, data)
    assert len(receiver.get_and_clear_received_messages()) == 3

    # Try to send data to the EnvironmentParametersChannel itself:
    # registering the sender (instead of the raw-bytes receiver) as the
    # handler is expected to raise. Only the parse call sits inside
    # pytest.raises, so a failure during setup cannot satisfy the check.
    sender.set_float_parameter("param-1", 0.1)
    data = UnityEnvironment._generate_side_channel_data(
        {sender.channel_id: sender})
    with pytest.raises(UnityCommunicationException):
        UnityEnvironment._parse_side_channel_message(
            {receiver.channel_id: sender}, data)
def test_engine_configuration():
    """
    Check that EngineConfigurationChannel emits one message per setting,
    encodes a time_scale update readable as (int32, float32), rejects a
    height without a width, and raises when the sender is registered as the
    handler for its own data.
    """
    sender = EngineConfigurationChannel()
    # We use a raw bytes channel to interpret the data
    receiver = RawBytesChannel(sender.channel_id)

    config = EngineConfig.default_config()
    sender.set_configuration(config)
    data = UnityEnvironment._generate_side_channel_data(
        {sender.channel_id: sender})
    UnityEnvironment._parse_side_channel_message(
        {receiver.channel_id: receiver}, data)
    received_data = receiver.get_and_clear_received_messages()
    assert len(received_data) == 5  # 5 different messages one for each setting

    sent_time_scale = 4.5
    sender.set_configuration_parameters(time_scale=sent_time_scale)
    data = UnityEnvironment._generate_side_channel_data(
        {sender.channel_id: sender})
    UnityEnvironment._parse_side_channel_message(
        {receiver.channel_id: receiver}, data)
    message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
    # Skip the leading int32 (presumably the setting's type tag; confirm
    # against EngineConfigurationChannel).
    message.read_int32()
    time_scale = message.read_float32()
    # 4.5 is exactly representable as float32, so exact equality is safe.
    assert time_scale == sent_time_scale

    # Supplying a height without a width is expected to be rejected.
    with pytest.raises(UnitySideChannelException):
        sender.set_configuration_parameters(width=None, height=42)

    # Try to send data to the EngineConfigurationChannel itself: registering
    # the sender (instead of the raw-bytes receiver) as the handler is
    # expected to raise. Only the parse call sits inside pytest.raises, so a
    # failure during setup cannot satisfy the check.
    sender.set_configuration_parameters(time_scale=sent_time_scale)
    data = UnityEnvironment._generate_side_channel_data(
        {sender.channel_id: sender})
    with pytest.raises(UnityCommunicationException):
        UnityEnvironment._parse_side_channel_message(
            {receiver.channel_id: sender}, data)
def test_stats_channel():
    """
    StatsSideChannel should store a received stat under its key as a
    (value, aggregation method) pair.
    """
    receiver = StatsSideChannel()
    message = OutgoingMessage()
    message.write_string("stats-1")
    message.write_float32(42.0)
    message.write_int32(1)  # corresponds to StatsAggregationMethod.MOST_RECENT
    receiver.on_message_received(IncomingMessage(message.buffer))
    stats = receiver.get_and_reset_stats()
    assert len(stats) == 1
    val, method = stats["stats-1"][0]
    # abs() makes the tolerance check two-sided; without it any value below
    # 42.0 (e.g. 0.0) would pass.
    assert abs(val - 42.0) < 1e-8
    assert method == StatsAggregationMethod.MOST_RECENT
def on_message_received(self, msg: IncomingMessage) -> None:
    """Read one int32 from ``msg`` and append it to ``self.list_int``."""
    # Read first, then store: the same order as the original statement pair.
    received = msg.read_int32()
    self.list_int.append(received)