def test_actor_critic(lstm, shared):
    """Check output shapes of the critic pass and of action sampling.

    :param lstm: When truthy, the network settings carry default memory
        settings and the inputs are a full sequence; otherwise a single
        observation is used and an empty tensor is passed as memories.
    :param shared: When truthy, a single SharedActorCritic serves as both
        actor and critic; otherwise a SimpleActor and a ValueNetwork are built.
    """
    obs_size = 4
    act_size = 2
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings() if lstm else None, normalize=True
    )
    obs_spec = create_observation_specs_with_shapes([(obs_size,)])
    mask = torch.ones([1, act_size * 2])
    stream_names = [f"stream_name{n}" for n in range(4)]
    action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))

    if shared:
        # NOTE(review): the fifth positional argument is `network_settings`
        # again; it presumably lands in a boolean option such as
        # `conditional_sigma` -- confirm against SharedActorCritic's signature.
        actor = critic = SharedActorCritic(
            obs_spec, network_settings, action_spec, stream_names, network_settings
        )
    else:
        actor = SimpleActor(obs_spec, network_settings, action_spec)
        critic = ValueNetwork(stream_names, obs_spec, network_settings)

    if lstm:
        sequence_length = network_settings.memory.sequence_length
        sample_obs = torch.ones((1, sequence_length, obs_size))
        memories = torch.ones((1, sequence_length, actor.memory_size))
        # Outputs are expected to have one row per step of the sequence.
        expected_rows = sequence_length
    else:
        sample_obs = torch.ones((1, obs_size))
        # memories isn't always set to None, the network should be able to
        # deal with that.
        memories = torch.tensor([])
        expected_rows = 1

    # Test critic pass
    value_out, memories_out = critic.critic_pass([sample_obs], memories=memories)
    for stream in stream_names:
        assert value_out[stream].shape == (expected_rows,)
        if lstm:
            assert memories_out.shape == memories.shape

    # Test get action stats and_value
    action, _log_probs, _entropies, mem_out = actor.get_action_and_stats(
        [sample_obs], memories=memories, masks=mask
    )
    # Expected shapes are derived from the settings above rather than
    # hard-coded, so they stay in sync with the critic assertions.
    assert action.continuous_tensor.shape == (expected_rows, act_size)

    assert len(action.discrete_list) == act_size
    for _disc in action.discrete_list:
        assert _disc.shape == (expected_rows, 1)

    if mem_out is not None:
        assert mem_out.shape == memories.shape
def test_simple_actor(use_discrete):
    """Exercise SimpleActor.get_dists, sample_action and forward.

    :param use_discrete: When truthy, the actor is built from a discrete
        action spec and a mask of ones is supplied; otherwise a continuous
        action spec is used and no mask is passed.
    """
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]
    act_size = [2]

    if use_discrete:
        masks = torch.ones((1, 1))
        action_spec = ActionSpec.create_discrete(tuple(act_size))
    else:
        masks = None
        action_spec = ActionSpec.create_continuous(act_size[0])

    # Expectations that depend only on the action kind are computed once here.
    expected_dist_cls = (
        CategoricalDistInstance if use_discrete else GaussianDistInstance
    )
    sampled_shape = (1, 1) if use_discrete else (1, act_size[0])
    # This is different from the sampled shape for ONNX export
    forward_shape = tuple(act_size) if use_discrete else (act_size[0], 1)

    actor = SimpleActor(obs_shapes, network_settings, action_spec)

    # Test get_dist
    sample_obs = torch.ones((1, obs_size))
    dists, _ = actor.get_dists([sample_obs], [], masks=masks)
    for dist in dists:
        assert isinstance(dist, expected_dist_cls)

    # Test sample_actions
    sampled = actor.sample_action(dists)
    for act in sampled:
        assert act.shape == sampled_shape

    # Test forward
    forward_out = actor.forward([sample_obs], [], masks=masks)
    exported, _ver_num, mem_size, is_cont, act_size_vec = forward_out
    for act in exported:
        assert act.shape == forward_shape

    assert mem_size == 0
    assert is_cont == int(not use_discrete)
    assert act_size_vec == torch.tensor(act_size)
def test_simple_actor(action_type):
    """Exercise SimpleActor.get_dists, sample_action and forward.

    :param action_type: ActionType under test. For ActionType.CONTINUOUS no
        mask is passed; for any other value a mask of ones is supplied.
    """
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]
    act_size = [2]

    is_continuous = action_type == ActionType.CONTINUOUS
    if is_continuous:
        masks = None
        expected_dist_cls = GaussianDistInstance
        sampled_shape = (1, act_size[0])
        # This is different from the sampled shape for ONNX export
        forward_shape = (act_size[0], 1)
    else:
        masks = torch.ones((1, 1))
        expected_dist_cls = CategoricalDistInstance
        sampled_shape = (1, 1)
        forward_shape = (1, 1)

    actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)

    # Test get_dist
    sample_obs = torch.ones((1, obs_size))
    dists, _ = actor.get_dists([sample_obs], [], masks=masks)
    for dist in dists:
        assert isinstance(dist, expected_dist_cls)

    # Test sample_actions
    sampled = actor.sample_action(dists)
    for act in sampled:
        assert act.shape == sampled_shape

    # Test forward
    forward_out = actor.forward([sample_obs], [], masks=masks)
    exported, _probs, _ver_num, mem_size, is_cont, act_size_vec = forward_out
    for act in exported:
        assert act.shape == forward_shape

    # TODO: Once export works properly. fix the shapes here.
    assert mem_size == 0
    assert is_cont == int(is_continuous)
    assert act_size_vec == torch.tensor(act_size)
def __init__(
    self,
    seed: int,
    behavior_spec: BehaviorSpec,
    trainer_settings: TrainerSettings,
    tanh_squash: bool = False,
    separate_critic: bool = True,
    condition_sigma_on_obs: bool = True,
) -> None:
    """
    Policy that uses a multilayer perceptron to map the observations to actions. Could
    also use a CNN to encode visual input prior to the MLP. Supports discrete and
    continuous actions, as well as recurrent networks.
    :param seed: Random seed.
    :param behavior_spec: Assigned BehaviorSpec object.
    :param trainer_settings: Defined training parameters.
    :param tanh_squash: Whether to use a tanh function on the continuous output,
    or a clipped output.
    :param separate_critic: If True, the actor is a SimpleActor and
    `shared_critic` is set to False. Otherwise the actor is a SharedActorCritic
    with one stream per configured reward signal and `shared_critic` is True.
    :param condition_sigma_on_obs: Forwarded to the base class constructor; the
    resulting `self.condition_sigma_on_obs` is passed to the network as
    `conditional_sigma`.
    """
    super().__init__(
        seed, behavior_spec, trainer_settings, tanh_squash, condition_sigma_on_obs
    )
    self.global_step = (
        GlobalSteps()
    )  # could be much simpler if TorchPolicy is nn.Module
    self.grads = None

    self.stats_name_to_update_name = {
        "Losses/Value Loss": "value_loss",
        "Losses/Policy Loss": "policy_loss",
    }

    # Keyword arguments common to both network variants, built once so the
    # two branches cannot drift apart.
    network_kwargs = dict(
        observation_specs=self.behavior_spec.observation_specs,
        network_settings=trainer_settings.network_settings,
        action_spec=behavior_spec.action_spec,
        conditional_sigma=self.condition_sigma_on_obs,
        tanh_squash=tanh_squash,
    )
    if separate_critic:
        self.actor = SimpleActor(**network_kwargs)
    else:
        # One stream per configured reward signal, named by the key's value.
        reward_signal_names = [key.value for key in trainer_settings.reward_signals]
        self.actor = SharedActorCritic(
            stream_names=reward_signal_names, **network_kwargs
        )
    self.shared_critic = not separate_critic

    # Save the m_size needed for export
    self._export_m_size = self.m_size
    # m_size needed for training is determined by network, not trainer settings
    self.m_size = self.actor.memory_size

    self.actor.to(default_device())
    # Clipping is enabled exactly when tanh squashing is not requested.
    self._clip_action = not tanh_squash