def histogram_plot(value: List[np.ndarray], **kwargs) -> plt.Figure:
    """Plot a single histogram over all observations (default behaviour).

    :param value: A list of observation arrays. The arrays are combined with ``np.stack`` and
        therefore have to share the same shape.
    :param kwargs: Additional plotting relevant arguments (not used by this function).
    :return: The matplotlib figure that contains the histogram.
    """
    unused(kwargs)

    # Collapse all observations into one flat vector so that a single histogram is drawn.
    flat_values = np.stack(value).flatten()

    fig = plt.figure(figsize=(7, 5))
    plt.hist(flat_values)
    return fig
def forward(self, input_dict: Dict[str, Any], state: List, seq_lens: torch.Tensor) -> Tuple[Any, List]:
    """Run the forward pass by delegating to ``self.policy_forward`` with ``self._policy``.

    :param input_dict: Dictionary of input tensors, including "obs", "obs_flat", "prev_action",
        "prev_reward", "is_training".
    :param state: List of state tensors with sizes matching those returned by get_initial_state + the
        batch dimension. Not consumed by this method.
    :param seq_lens: 1d tensor holding input sequence lengths. Not consumed by this method.
    :return: A tuple of network output and state, where the network output tensor is of size
        [BATCH, num_outputs].
    """
    # Both arguments are part of the expected signature but are not needed for the computation.
    for ignored_argument in (state, seq_lens):
        unused(ignored_argument)

    policy_result = self.policy_forward(input_dict, self._policy)
    return policy_result
def create_binary_plot(value: Union[List[Tuple[np.ndarray, int]], List[int], List[float]], **kwargs) -> plt.Figure:
    """Dispatch to the plotting function matching the element type of ``value``.

    Only lists whose first element is a tuple are currently supported; they are forwarded to
    ``create_multi_binary_relative_bar_plot``.

    :param value: Output of a reducer function.
    :param kwargs: Additional plotting relevant arguments (not used by this function).
    :return: plt.figure that contains a bar plot.
    :raises NotImplementedError: If the first element of ``value`` is not a tuple.
    """
    unused(kwargs)

    # Anything other than the tuple representation has no plotting routine yet.
    if not isinstance(value[0], tuple):
        raise NotImplementedError(
            'plotting for this data type not implemented yet')

    # Tuples carry the discrete action events, which are rendered as a relative bar plot.
    return create_multi_binary_relative_bar_plot(value)
def create_violin_distribution(value: List[np.ndarray], **kwargs) -> plt.Figure:
    """Create a simple matplotlib violin plot of ``value``.

    :param value: Output of an event (expected to be a list of numpy vectors).
    :param kwargs: Additional plotting relevant arguments (not used by this function).
    :return: plt.figure that contains the violin plot.
    """
    unused(kwargs)

    # Combine the individual vectors into one 2d array.
    stacked = np.stack(value)

    # The figure width follows half the size of the second array axis, clamped to the range [7, 14].
    width = max(7, stacked.shape[1] // 2)
    width = min(14, width)

    figure = plt.figure(figsize=(width, 7))
    plt.violinplot(stacked, showmeans=True)
    plt.grid(True)
    return figure
def __init__(self, observation_space: spaces.Dict, action_space: spaces.Dict, model_config: Dict,
             maze_model_composer_config: ConfigType, spaces_config_dump_file: str,
             state_dict_dump_file: str):
    """Validate the given spaces and configs, build the model composer and dump the spaces config.

    :param observation_space: The observation space; has to be a gym ``spaces.Dict``.
    :param action_space: The action space; has to be a gym ``spaces.Dict``.
    :param model_config: The model config; its ``'vf_share_layers'`` entry has to be ``False``.
    :param maze_model_composer_config: Config the model composer is instantiated from; its
        ``'shared_embedding_keys'`` entry has to be missing or ``None``.
    :param spaces_config_dump_file: Target handed to ``SpacesConfig.save``.
    :param state_dict_dump_file: Not used by this initializer.
    """
    unused(state_dict_dump_file)

    # Validate the arguments before anything is built.
    assert isinstance(action_space, spaces.Dict), (
        f'The given action_space should be gym.spaces.Dict, but we'
        f' got {type(action_space)}'
    )
    assert isinstance(observation_space, spaces.Dict), (
        f'The original observation space has to be a gym Dict, '
        f'but we got {type(observation_space)}'
    )
    assert model_config.get('vf_share_layers') is False, 'vf_share_layer not implemented for maze models'
    assert maze_model_composer_config.get('shared_embedding_keys', None) is None, (
        'Shared embedding with maze '
        'models is not supported for '
        'rllib trainers'
    )

    # Initialize the model composer for a single step (key 0) with a single agent.
    self.model_composer = Factory(BaseModelComposer).instantiate(
        maze_model_composer_config,
        action_spaces_dict={0: action_space},
        observation_spaces_dict={0: observation_space},
        agent_counts_dict={0: 1},
    )
    composer = self.model_composer

    # Obtain the action order from the distribution mapper (this has to be in the same order as the
    # attribute self.action_heads in the MazeRLlibActionDistribution).
    mapper_action_space = composer.distribution_mapper.action_space
    self.action_keys = list(mapper_action_space.spaces.keys())

    # Initialize the space config and dump it to file.
    spaces_config = SpacesConfig(
        composer.action_spaces_dict,
        composer.observation_spaces_dict,
        composer.agent_counts_dict,
    )
    spaces_config.save(spaces_config_dump_file)

    # Assert that only one network is used for the policy.
    assert len(composer.policy.networks) == 1
def __init__(self, obs_space: Any, action_space: spaces.Space, num_outputs: int, model_config: Dict, name: str,
             maze_model_composer_config: ConfigType, spaces_config_dump_file: str, state_dict_dump_file: str,
             dueling: bool = True, num_atoms: int = 1, **kwargs):
    """Initialise the DQN torch model and the Maze base model, then wire up the advantage/value modules.

    :param obs_space: The observation space; its ``original_space`` attribute is handed to the Maze
        base model.
    :param action_space: The action space; has to be a ``spaces.Discrete``.
    :param num_outputs: Ignored; the number of outputs is taken from ``action_space.n`` instead.
    :param model_config: The model config, passed to both parent initializers.
    :param name: The model name; ``'_maze_wrapper'`` is appended before passing it to ``DQNTorchModel``.
    :param maze_model_composer_config: Model composer config, passed to the Maze base model.
    :param spaces_config_dump_file: Passed to the Maze base model.
    :param state_dict_dump_file: Passed to the Maze base model.
    :param dueling: Has to be ``True``.
    :param num_atoms: Has to be ``1``.
    :param kwargs: Additional keyword arguments forwarded to ``DQNTorchModel.__init__``.
    """
    unused(num_outputs)

    original_obs_space = obs_space.original_space

    assert isinstance(action_space, spaces.Discrete), f'Only discrete spaces supported but got {action_space}'
    # The output count follows from the discrete action space rather than from the argument.
    action_count = action_space.n
    # The Maze base model expects a Dict action space, so the discrete space is wrapped in one.
    dict_action_space = spaces.Dict({'action': action_space})

    assert dueling is True, 'Only dueling==True is supported at this point'
    assert num_atoms == 1, 'Only num_atoms == 1 is suported at this point'

    DQNTorchModel.__init__(
        self,
        obs_space=obs_space,
        action_space=action_space,
        model_config=model_config,
        num_outputs=action_count,
        name=name + '_maze_wrapper',
        dueling=dueling,
        num_atoms=num_atoms,
        **kwargs
    )

    # NOTE(review): this resets the global python, numpy and torch RNGs to a fixed seed on every
    # construction; presumably meant to make the initialisation below reproducible - confirm that
    # this is intended.
    import random
    import numpy as np
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    MazeRLlibBaseModel.__init__(
        self,
        observation_space=original_obs_space,
        action_space=dict_action_space,
        model_config=model_config,
        maze_model_composer_config=maze_model_composer_config,
        spaces_config_dump_file=spaces_config_dump_file,
        state_dict_dump_file=state_dict_dump_file,
    )

    # The first policy network is kept as the private advantage module, while the public
    # ``advantage_module`` attribute is set to an identity.
    policy_networks = self.model_composer.policy.networks
    self._advantage_module: nn.Module = list(policy_networks.values())[0]
    self.advantage_module = nn.Identity()

    # Assert that a critic exists and that it uses exactly one network.
    critic = self.model_composer.critic
    assert critic is not None
    assert len(critic.networks) == 1
    self._value_module: nn.Module = list(critic.networks.values())[0]

    # Init class values
    self._value_module_input = None
    self._model_maze_input: Optional[Dict[str, torch.Tensor]] = None