def test_message_string():
    """Write a string into an OutgoingMessage and read it back unchanged."""
    expected = "mlagents!"
    outgoing = OutgoingMessage()
    outgoing.write_string(expected)

    # Re-parse the raw buffer that was just produced.
    incoming = IncomingMessage(outgoing.buffer)
    assert expected == incoming.read_string()
 def set_float_parameter(self, key: str, value: float) -> None:
     """
     Queue a float-valued environment parameter for the Unity Environment.
     :param key: Name that identifies the parameter.
     :param value: Float value to assign to the parameter.
     """
     float_tag = self.EnvironmentDataTypes.FLOAT
     payload = OutgoingMessage()
     # Message layout: key, data-type tag, then the float value.
     payload.write_string(key)
     payload.write_int32(float_tag)
     payload.write_float32(value)
     super().queue_message_to_send(payload)
 def set_property(self, key: str, value: float) -> None:
     """
     Record a float property locally and queue it for the Unity Environment.
     :param key: Name that identifies the property.
     :param value: Float value of the property.
     """
     # Keep a local copy keyed by name before sending.
     self._float_properties[key] = value
     payload = OutgoingMessage()
     # Message layout: key followed directly by the float value (no type tag).
     payload.write_string(key)
     payload.write_float32(value)
     super().queue_message_to_send(payload)
# Example no. 4
# 0
def test_message_string():
    """Round-trip a string, then check read_string's default handling."""
    expected = "mlagents!"
    outgoing = OutgoingMessage()
    outgoing.write_string(expected)

    incoming = IncomingMessage(outgoing.buffer)
    round_tripped = incoming.read_string()
    assert expected == round_tripped

    # Only one string was written, so further reads are expected to yield
    # the default: "" implicitly, or whatever default_value is supplied.
    assert "" == incoming.read_string()
    assert expected == incoming.read_string(default_value=expected)
# Example no. 5
# 0
def test_stats_channel():
    """Check that StatsSideChannel decodes a single MOST_RECENT stat.

    The message written here is: stat key (string), stat value (float32),
    aggregation method (int32).
    """
    receiver = StatsSideChannel()
    message = OutgoingMessage()
    message.write_string("stats-1")
    message.write_float32(42.0)
    message.write_int32(1)  # corresponds to StatsAggregationMethod.MOST_RECENT

    receiver.on_message_received(IncomingMessage(message.buffer))

    stats = receiver.get_and_reset_stats()

    assert len(stats) == 1
    val, method = stats["stats-1"][0]
    # Compare the magnitude of the difference; a bare `val - 42.0 < 1e-8`
    # would also pass for any value below 42.0.
    assert abs(val - 42.0) < 1e-8
    assert method == StatsAggregationMethod.MOST_RECENT
 def set_gaussian_sampler_parameters(self, key: str, mean: float,
                                     st_dev: float, seed: int) -> None:
     """
     Configure a gaussian sampler for an environment parameter.
     :param key: Name that identifies the parameter.
     :param mean: Mean of the gaussian distribution.
     :param st_dev: Standard deviation of the gaussian distribution.
     :param seed: Seed used to initialize the sampler.
     """
     payload = OutgoingMessage()
     payload.write_string(key)
     # Sampler header: data-type tag, seed, then the sampler kind.
     for header_field in (self.EnvironmentDataTypes.SAMPLER, seed,
                          self.SamplerTypes.GAUSSIAN):
         payload.write_int32(header_field)
     # Distribution arguments, in order.
     for dist_arg in (mean, st_dev):
         payload.write_float32(dist_arg)
     super().queue_message_to_send(payload)
 def set_uniform_sampler_parameters(self, key: str, min_value: float,
                                    max_value: float, seed: int) -> None:
     """
     Configure a uniform sampler for an environment parameter.
     :param key: Name that identifies the parameter.
     :param min_value: Lower bound of the uniform distribution.
     :param max_value: Upper bound of the uniform distribution.
     :param seed: Seed used to initialize the sampler.
     """
     payload = OutgoingMessage()
     payload.write_string(key)
     # Sampler header: data-type tag, seed, then the sampler kind.
     for header_field in (self.EnvironmentDataTypes.SAMPLER, seed,
                          self.SamplerTypes.UNIFORM):
         payload.write_int32(header_field)
     # Distribution bounds: lower first, then upper.
     for bound in (min_value, max_value):
         payload.write_float32(bound)
     super().queue_message_to_send(payload)
 def set_multirangeuniform_sampler_parameters(self, key: str,
                                              intervals: List[Tuple[float,
                                                                    float]],
                                              seed: int) -> None:
     """
     Configure a multi-range uniform sampler for an environment parameter.
     :param key: Name that identifies the parameter.
     :param intervals: (min, max) pairs, one per uniform sub-range.
     :param seed: Seed used to initialize the sampler.
     """
     payload = OutgoingMessage()
     payload.write_string(key)
     # Sampler header: data-type tag, seed, then the sampler kind.
     for header_field in (self.EnvironmentDataTypes.SAMPLER, seed,
                          self.SamplerTypes.MULTIRANGEUNIFORM):
         payload.write_int32(header_field)
     # Flatten the pairs into one list: min0, max0, min1, max1, ...
     flat_bounds = []
     for bounds in intervals:
         flat_bounds.extend(bounds)
     payload.write_float32_list(flat_bounds)
     super().queue_message_to_send(payload)