class SimpleActor(nn.Module, Actor):
    """Actor that encodes observations with a ``NetworkBody`` and produces
    actions with an ``ActionModel``.

    Besides the two sub-networks, the constructor registers a handful of
    non-trainable constant parameters (version number, action sizes, memory
    size). ``forward`` returns them next to the action outputs so that they
    end up in the exported ONNX graph.
    """

    @staticmethod
    def _constant(values: list) -> torch.nn.Parameter:
        """Wrap ``values`` in a float tensor parameter that is never trained."""
        return torch.nn.Parameter(torch.Tensor(values), requires_grad=False)

    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """Build the observation encoder, the action head and the export
        constants.

        :param observation_specs: Forwarded to ``NetworkBody``.
        :param network_settings: Forwarded to ``NetworkBody``; its ``memory``
            and ``hidden_units`` fields also determine ``encoding_size``.
        :param action_spec: Describes the continuous size and discrete
            branches; forwarded to ``ActionModel``.
        :param conditional_sigma: Forwarded to ``ActionModel``.
        :param tanh_squash: Forwarded to ``ActionModel``.
        """
        super().__init__()
        self.action_spec = action_spec

        # NOTE: attribute assignment order determines parameter/state_dict
        # order, so keep the sequence below stable.
        self.version_number = self._constant([2.0])
        self.is_continuous_int_deprecated = self._constant(
            [int(self.action_spec.is_continuous())]
        )
        self.continuous_act_size_vector = self._constant(
            [int(self.action_spec.continuous_size)]
        )
        # TODO: export list of branch sizes instead of sum
        self.discrete_act_size_vector = self._constant(
            [sum(self.action_spec.discrete_branches)]
        )
        self.act_size_vector_deprecated = self._constant(
            [
                self.action_spec.continuous_size
                + sum(self.action_spec.discrete_branches)
            ]
        )

        self.network_body = NetworkBody(observation_specs, network_settings)
        if network_settings.memory is not None:
            # NOTE(review): presumably half of the memory holds the encoding
            # and the other half the remaining recurrent state - confirm
            # against NetworkBody.
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units
        self.memory_size_vector = self._constant(
            [int(self.network_body.memory_size)]
        )

        self.action_model = ActionModel(
            self.encoding_size,
            action_spec,
            conditional_sigma=conditional_sigma,
            tanh_squash=tanh_squash,
        )

    @property
    def memory_size(self) -> int:
        """Memory size reported by the underlying ``NetworkBody``."""
        return self.network_body.memory_size

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """Delegate the normalization update to the ``NetworkBody``."""
        self.network_body.update_normalization(buffer)

    def get_action_and_stats(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        """Encode ``inputs`` and run the action model on the encoding.

        :param inputs: Observation tensors passed to the ``NetworkBody``.
        :param masks: Passed to the action model together with the encoding.
        :param memories: Passed to the ``NetworkBody``.
        :param sequence_length: Passed to the ``NetworkBody``.
        :return: ``(action, log_probs, entropies, memories)`` where
            ``memories`` is the second value returned by the ``NetworkBody``.
        """
        encoding, memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        return action, log_probs, entropies, memories

    def get_stats(
        self,
        inputs: List[torch.Tensor],
        actions: AgentAction,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """Evaluate already-taken ``actions`` against the current policy.

        :param inputs: Observation tensors passed to the ``NetworkBody``.
        :param actions: Actions handed to ``ActionModel.evaluate``.
        :param masks: Handed to ``ActionModel.evaluate``.
        :param memories: Passed to the ``NetworkBody``.
        :param sequence_length: Passed to the ``NetworkBody``.
        :return: ``(log_probs, entropies)`` from ``ActionModel.evaluate``.
        """
        # The memories returned by the body are not needed for evaluation.
        encoding, _ = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
        return log_probs, entropies

    def _regroup_export_inputs(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        var_len_inputs: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        """Turn the export-time inputs into one tensor per body processor.

        Walks ``network_body.processors`` in order: a ``VectorInput`` receives
        the next column slice of ``vec_inputs[0]`` (width taken from
        ``network_body.embedding_sizes``), an ``EntityEmbedding`` receives the
        next entry of ``var_len_inputs``, and anything else is treated as a
        visual processor and receives the next entry of ``vis_inputs``.
        """
        concatenated_vec_obs = vec_inputs[0]
        network_inputs: List[torch.Tensor] = []
        offset = 0
        vis_index = 0
        var_len_index = 0
        for index, processor in enumerate(self.network_body.processors):
            if isinstance(processor, VectorInput):
                width = self.network_body.embedding_sizes[index]
                network_inputs.append(
                    concatenated_vec_obs[:, offset : offset + width]
                )
                offset += width
            elif isinstance(processor, EntityEmbedding):
                network_inputs.append(var_len_inputs[var_len_index])
                var_len_index += 1
            else:  # visual input
                network_inputs.append(vis_inputs[vis_index])
                vis_index += 1
        return network_inputs

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        var_len_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[int, torch.Tensor], ...]:
        """
        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.

        At this moment, torch.onnx.export() doesn't accept None as tensor to be exported,
        so the size of return tuple varies with action spec.

        The returned tuple always starts with ``version_number`` and
        ``memory_size_vector``. It is followed by the continuous action output
        and its size vector when ``continuous_size > 0``, then the discrete
        action output and its size vector when ``discrete_size > 0``, and
        finally the three deprecated entries when the action spec is not
        hybrid.
        """
        inputs = self._regroup_export_inputs(vec_inputs, vis_inputs, var_len_inputs)

        # The updated memories are not part of this export signature, so the
        # second return value of the body is intentionally discarded.
        encoding, _ = self.network_body(inputs, memories=memories, sequence_length=1)

        (
            cont_action_out,
            disc_action_out,
            action_out_deprecated,
        ) = self.action_model.get_action_out(encoding, masks)

        export_out = [self.version_number, self.memory_size_vector]
        if self.action_spec.continuous_size > 0:
            export_out += [cont_action_out, self.continuous_act_size_vector]
        if self.action_spec.discrete_size > 0:
            export_out += [disc_action_out, self.discrete_act_size_vector]
        # Only export deprecated nodes with non-hybrid action spec
        if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0:
            export_out += [
                action_out_deprecated,
                self.is_continuous_int_deprecated,
                self.act_size_vector_deprecated,
            ]
        return tuple(export_out)
class SimpleActor(nn.Module, Actor):
    """Actor that encodes observations with a ``NetworkBody`` and produces
    actions with an ``ActionModel``.

    Besides the two sub-networks, the constructor registers a handful of
    non-trainable constant parameters (version number, action sizes, memory
    size). ``forward`` returns some of them next to the action outputs so
    that they end up in the exported ONNX graph.
    """

    MODEL_EXPORT_VERSION = 3  # Corresponds to ModelApiVersion.MLAgents2_0

    @staticmethod
    def _constant(values: list) -> torch.nn.Parameter:
        """Wrap ``values`` in a float tensor parameter that is never trained."""
        return torch.nn.Parameter(torch.Tensor(values), requires_grad=False)

    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """Build the observation encoder, the action head and the export
        constants.

        :param observation_specs: Forwarded to ``NetworkBody``.
        :param network_settings: Forwarded to ``NetworkBody``; its ``memory``
            and ``hidden_units`` fields also determine ``encoding_size``.
        :param action_spec: Describes the continuous size and discrete
            branches; forwarded to ``ActionModel``.
        :param conditional_sigma: Forwarded to ``ActionModel``.
        :param tanh_squash: Forwarded to ``ActionModel``.
        """
        super().__init__()
        self.action_spec = action_spec

        # NOTE: attribute assignment order determines parameter/state_dict
        # order, so keep the sequence below stable.
        self.version_number = self._constant([self.MODEL_EXPORT_VERSION])
        self.is_continuous_int_deprecated = self._constant(
            [int(self.action_spec.is_continuous())]
        )
        self.continuous_act_size_vector = self._constant(
            [int(self.action_spec.continuous_size)]
        )
        # The branch sizes are wrapped in an outer list, so the whole sequence
        # of branches is stored rather than their sum.
        self.discrete_act_size_vector = self._constant(
            [self.action_spec.discrete_branches]
        )
        self.act_size_vector_deprecated = self._constant(
            [
                self.action_spec.continuous_size
                + sum(self.action_spec.discrete_branches)
            ]
        )

        self.network_body = NetworkBody(observation_specs, network_settings)
        if network_settings.memory is not None:
            # NOTE(review): presumably half of the memory holds the encoding
            # and the other half the remaining recurrent state - confirm
            # against NetworkBody.
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units
        self.memory_size_vector = self._constant(
            [int(self.network_body.memory_size)]
        )

        self.action_model = ActionModel(
            self.encoding_size,
            action_spec,
            conditional_sigma=conditional_sigma,
            tanh_squash=tanh_squash,
        )

    @property
    def memory_size(self) -> int:
        """Memory size reported by the underlying ``NetworkBody``."""
        return self.network_body.memory_size

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """Delegate the normalization update to the ``NetworkBody``."""
        self.network_body.update_normalization(buffer)

    def get_action_and_stats(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        """Encode ``inputs`` and run the action model on the encoding.

        :param inputs: Observation tensors passed to the ``NetworkBody``.
        :param masks: Passed to the action model together with the encoding.
        :param memories: Passed to the ``NetworkBody``.
        :param sequence_length: Passed to the ``NetworkBody``.
        :return: ``(action, log_probs, entropies, memories)`` where
            ``memories`` is the second value returned by the ``NetworkBody``.
        """
        encoding, memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        return action, log_probs, entropies, memories

    def get_stats(
        self,
        inputs: List[torch.Tensor],
        actions: AgentAction,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """Evaluate already-taken ``actions`` against the current policy.

        :param inputs: Observation tensors passed to the ``NetworkBody``.
        :param actions: Actions handed to ``ActionModel.evaluate``.
        :param masks: Handed to ``ActionModel.evaluate``.
        :param memories: Passed to the ``NetworkBody``.
        :param sequence_length: Passed to the ``NetworkBody``.
        :return: ``(log_probs, entropies)`` from ``ActionModel.evaluate``.
        """
        # The memories returned by the body are not needed for evaluation.
        encoding, _ = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
        return log_probs, entropies

    def forward(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[int, torch.Tensor], ...]:
        """
        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.

        At this moment, torch.onnx.export() doesn't accept None as tensor to be exported,
        so the size of return tuple varies with action spec.

        The returned tuple always starts with ``version_number`` and
        ``memory_size_vector``. It is followed by the continuous action output
        and its size vector when ``continuous_size > 0``, then the discrete
        action output and its size vector when ``discrete_size > 0``, and
        finally the updated memories when the body's ``memory_size > 0``.
        """
        encoding, memories_out = self.network_body(
            inputs, memories=memories, sequence_length=1
        )

        (
            cont_action_out,
            disc_action_out,
            # Still returned by get_action_out, but no longer exported here.
            action_out_deprecated,
        ) = self.action_model.get_action_out(encoding, masks)

        export_out = [self.version_number, self.memory_size_vector]
        if self.action_spec.continuous_size > 0:
            export_out += [cont_action_out, self.continuous_act_size_vector]
        if self.action_spec.discrete_size > 0:
            export_out += [disc_action_out, self.discrete_act_size_vector]
        if self.network_body.memory_size > 0:
            export_out += [memories_out]
        return tuple(export_out)