def test_float_properties():
    """Check that float properties set on one channel show up on another.

    Properties are set on a source FloatPropertiesChannel, moved through a
    pair of SideChannelManager instances, and then read back from the
    destination channel.
    """
    source = FloatPropertiesChannel()
    sink = FloatPropertiesChannel()

    def _relay():
        # Generate the source's pending side channel data and hand it to the
        # destination through a manager that only knows about the sink.
        payload = SideChannelManager([source]).generate_side_channel_messages()
        SideChannelManager([sink]).process_side_channel_message(payload)

    # First round: only "prop1" exists, "prop2" has never been set.
    source.set_property("prop1", 1.0)
    _relay()
    assert sink.get_property("prop1") == 1.0
    assert sink.get_property("prop2") is None

    # Second round: "prop2" is added and "prop1" must keep its value.
    source.set_property("prop2", 2.0)
    _relay()
    assert sink.get_property("prop1") == 1.0
    assert sink.get_property("prop2") == 2.0

    assert len(sink.list_properties()) == 2
    for expected_name in ("prop1", "prop2"):
        assert expected_name in sink.list_properties()

    # The source still reports the value it was given.
    assert source.get_property("prop1") == 1.0

    # Both sides expose the same dictionary snapshot.
    assert sink.get_property_dict_copy() == {"prop1": 1.0, "prop2": 2.0}
    assert sink.get_property_dict_copy() == source.get_property_dict_copy()
def test_int_channel():
    """Check that integers sent on one IntChannel reach another in order."""
    outgoing = IntChannel()
    incoming = IntChannel()
    expected_values = (5, 6)

    for number in expected_values:
        outgoing.send_int(number)

    payload = SideChannelManager([outgoing]).generate_side_channel_messages()
    SideChannelManager([incoming]).process_side_channel_message(payload)

    # Index explicitly so a missing element fails the same way as before.
    for position, expected in enumerate(expected_values):
        assert incoming.list_int[position] == expected
def test_raw_bytes():
    """Check that raw byte messages round-trip and are cleared once read."""
    channel_guid = uuid.uuid4()
    outgoing = RawBytesChannel(channel_guid)
    incoming = RawBytesChannel(channel_guid)

    for chunk in (b"foo", b"bar"):
        outgoing.send_raw_data(chunk)

    payload = SideChannelManager([outgoing]).generate_side_channel_messages()
    SideChannelManager([incoming]).process_side_channel_message(payload)

    received = incoming.get_and_clear_received_messages()
    assert len(received) == 2
    assert received[0].decode("ascii") == "foo"
    assert received[1].decode("ascii") == "bar"

    # A second read must come back empty: the first call cleared the queue.
    leftover = incoming.get_and_clear_received_messages()
    assert len(leftover) == 0
def test_environment_parameters():
    """Check the data produced by EnvironmentParametersChannel.

    A RawBytesChannel sharing the sender's channel id is used as the receiver
    so the raw payload can be decoded field by field. The test also checks
    that feeding the sender's own output back into it raises
    UnityCommunicationException.
    """
    sender = EnvironmentParametersChannel()
    # We use a raw bytes channel to interpret the data
    receiver = RawBytesChannel(sender.channel_id)

    sender.set_float_parameter("param-1", 0.1)
    data = SideChannelManager([sender]).generate_side_channel_messages()
    SideChannelManager([receiver]).process_side_channel_message(data)

    message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
    key = message.read_string()
    dtype = message.read_int32()
    value = message.read_float32()
    assert key == "param-1"
    assert dtype == EnvironmentParametersChannel.EnvironmentDataTypes.FLOAT
    # The value is read back as a float32, so compare with a tolerance.
    # abs() is required: without it any value at or below 0.1 (including
    # 0.0 or a negative number) would satisfy the inequality.
    assert abs(value - 0.1) < 1e-8

    # Three parameters set before a single generate call yield three messages.
    sender.set_float_parameter("param-1", 0.1)
    sender.set_float_parameter("param-2", 0.1)
    sender.set_float_parameter("param-3", 0.1)
    data = SideChannelManager([sender]).generate_side_channel_messages()
    SideChannelManager([receiver]).process_side_channel_message(data)
    assert len(receiver.get_and_clear_received_messages()) == 3

    with pytest.raises(UnityCommunicationException):
        # try to send data to the EnvironmentParametersChannel
        sender.set_float_parameter("param-1", 0.1)
        data = SideChannelManager([sender]).generate_side_channel_messages()
        SideChannelManager([sender]).process_side_channel_message(data)
def test_engine_configuration():
    """Check the messages produced by EngineConfigurationChannel.

    A RawBytesChannel sharing the sender's channel id is used as the receiver
    so that the raw payload can be inspected.
    """
    sender = EngineConfigurationChannel()
    # A raw bytes channel on the same channel id lets us interpret the data.
    receiver = RawBytesChannel(sender.channel_id)

    def _relay(destination):
        # Generate the sender's pending data and deliver it to `destination`.
        payload = SideChannelManager([sender]).generate_side_channel_messages()
        SideChannelManager([destination]).process_side_channel_message(payload)

    # Sending the default configuration.
    default_config = EngineConfig.default_config()
    sender.set_configuration(default_config)
    _relay(receiver)
    received_data = receiver.get_and_clear_received_messages()
    assert len(received_data) == 5  # 5 different messages one for each setting

    # Sending a single parameter: the time scale.
    sent_time_scale = 4.5
    sender.set_configuration_parameters(time_scale=sent_time_scale)
    _relay(receiver)
    first_message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
    # The leading int32 is read and discarded; the float32 after it is the value.
    first_message.read_int32()
    received_time_scale = first_message.read_float32()
    assert received_time_scale == sent_time_scale

    # Giving only one of width/height is expected to be rejected.
    with pytest.raises(UnitySideChannelException):
        sender.set_configuration_parameters(width=None, height=42)

    with pytest.raises(UnityCommunicationException):
        # try to send data to the EngineConfigurationChannel
        sender.set_configuration_parameters(time_scale=sent_time_scale)
        _relay(sender)
def __init__(
    self,
    file_name: Optional[str] = None,
    worker_id: int = 0,
    base_port: Optional[int] = None,
    seed: int = 0,
    no_graphics: bool = False,
    timeout_wait: int = 60,
    additional_args: Optional[List[str]] = None,
    side_channels: Optional[List[SideChannel]] = None,
    log_folder: Optional[str] = None,
):
    """
    Starts a new unity environment and establishes a connection with the environment.
    Notice: Currently communication between Unity and Python takes place over an open socket without authentication.
    Ensure that the network where training takes place is secure.

    :string file_name: Name of Unity environment binary.
    :int worker_id: Offset from base_port. Used for training multiple environments simultaneously.
    :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
    If no environment is specified (i.e. file_name is None), the DEFAULT_EDITOR_PORT will be used.
    :int seed: Seed passed to Unity in the RL initialization parameters.
    :bool no_graphics: Whether to run the Unity simulator in no-graphics mode
    :int timeout_wait: Time (in seconds) to wait for connection from environment.
    :list additional_args: Additional Unity command line arguments
    :list side_channels: Additional side channel for non-RL communication with Unity
    :str log_folder: Optional folder to write the Unity Player log file into.  Requires absolute path.
    :raises UnityEnvironmentException: If file_name is None and worker_id is not 0, or
    (re-raised) if launching the executable fails.
    :raises UnityTimeOutException: Re-raised if sending the initialization parameters times out.
    """
    # If the environment name is None, a new environment will not be launched
    # and the communicator will directly try to connect to an existing unity environment.
    # If the worker-id is not 0 and the environment name is None, an error is thrown.
    # This is checked first, before the atexit hook is registered and before
    # the communicator is created, so that invalid arguments do not leave
    # anything behind that would need to be closed.
    if file_name is None and worker_id != 0:
        raise UnityEnvironmentException(
            "If the environment name is None, "
            "the worker-id must be 0 in order to connect with the Editor."
        )

    atexit.register(self._close)
    self._additional_args = additional_args or []
    self._no_graphics = no_graphics
    # If base port is not specified, use BASE_ENVIRONMENT_PORT if we have
    # an environment, otherwise DEFAULT_EDITOR_PORT
    if base_port is None:
        base_port = (
            self.BASE_ENVIRONMENT_PORT if file_name else self.DEFAULT_EDITOR_PORT
        )
    self._port = base_port + worker_id
    self._buffer_size = 12000
    # If true, this means the environment was successfully loaded
    self._loaded = False
    # The process that is started. If None, no process was started
    self._proc1 = None
    self._timeout_wait: int = timeout_wait
    self._communicator = self._get_communicator(worker_id, base_port, timeout_wait)
    self._worker_id = worker_id
    self._side_channel_manager = SideChannelManager(side_channels)
    self._log_folder = log_folder

    if file_name is not None:
        try:
            self._proc1 = env_utils.launch_executable(
                file_name, self._executable_args()
            )
        except UnityEnvironmentException:
            # Close what was opened so far before propagating the failure.
            self._close(0)
            raise
    else:
        logger.info(
            f"Listening on port {self._port}. "
            f"Start training by pressing the Play button in the Unity Editor."
        )
    self._loaded = True

    rl_init_parameters_in = UnityRLInitializationInputProto(
        seed=seed,
        communication_version=self.API_VERSION,
        package_version=mlagents_envs.__version__,
        capabilities=UnityEnvironment._get_capabilities_proto(),
    )
    try:
        aca_output = self._send_academy_parameters(rl_init_parameters_in)
        aca_params = aca_output.rl_initialization_output
    except UnityTimeOutException:
        self._close(0)
        raise

    # Close the environment before reporting an incompatible communication version.
    if not UnityEnvironment._check_communication_compatibility(
        aca_params.communication_version,
        UnityEnvironment.API_VERSION,
        aca_params.package_version,
    ):
        self._close(0)
        UnityEnvironment._raise_version_exception(aca_params.communication_version)

    UnityEnvironment._warn_csharp_base_capabilities(
        aca_params.capabilities,
        aca_params.package_version,
        UnityEnvironment.API_VERSION,
    )

    # Per-behavior bookkeeping, keyed by string (presumably the behavior name).
    self._env_state: Dict[str, Tuple[DecisionSteps, TerminalSteps]] = {}
    self._env_specs: Dict[str, BehaviorSpec] = {}
    self._env_actions: Dict[str, np.ndarray] = {}
    self._is_first_message = True
    self._update_behavior_specs(aca_output)