def get_action_dist(
        action_space: gym.Space,
        config: ModelConfigDict,
        dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
        framework: str = "tf",
        **kwargs) -> Tuple[type, int]:
    """Returns a distribution class and size for the given action space.

    Args:
        action_space (Space): Action space of the target gym env.
        config (Optional[dict]): Optional model config.
        dist_type (Optional[Union[str, Type[ActionDistribution]]]):
            Identifier of the action distribution (str) interpreted as a
            hint or the actual ActionDistribution class to use.
        framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
        kwargs (dict): Optional kwargs to pass on to the Distribution's
            constructor.

    Returns:
        Tuple:
            - dist_class (ActionDistribution): Python class of the
              distribution.
            - dist_dim (int): The size of the input vector to the
              distribution.
    """
    dist_cls = None
    config = config or MODEL_DEFAULTS

    # Custom distribution given.
    if config.get("custom_action_dist"):
        custom_action_config = config.copy()
        action_dist_name = custom_action_config.pop("custom_action_dist")
        logger.debug(
            "Using custom action distribution {}".format(action_dist_name))
        dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
        return ModelCatalog._get_multi_action_distribution(
            dist_cls, action_space, custom_action_config, framework)
    # Dist_type is given directly as a class.
    elif (type(dist_type) is type
          and issubclass(dist_type, ActionDistribution)
          and dist_type not in (MultiActionDistribution,
                                TorchMultiActionDistribution)):
        dist_cls = dist_type
    # Box space -> DiagGaussian OR Deterministic.
    elif isinstance(action_space, Box):
        # Integer Box -> one Categorical per dimension.
        if action_space.dtype.name.startswith("int"):
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            dist_cls = (TorchMultiCategorical
                        if framework == "torch" else MultiCategorical)
            num_cats = int(np.prod(action_space.shape))
            return (
                partial(
                    dist_cls,
                    input_lens=[high_ - low_ + 1 for _ in range(num_cats)],
                    action_space=action_space,
                ),
                num_cats * (high_ - low_ + 1),
            )
        else:
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a custom action distribution, "
                    "using a Tuple action space, or the multi-agent API.")
            # TODO(sven): Check for bounds and return SquashedNormal, etc..
            if dist_type is None:
                return (
                    partial(
                        TorchDiagGaussian
                        if framework == "torch" else DiagGaussian,
                        action_space=action_space,
                    ),
                    DiagGaussian.required_model_output_shape(
                        action_space, config),
                )
            elif dist_type == "deterministic":
                dist_cls = (TorchDeterministic
                            if framework == "torch" else Deterministic)
    # Discrete Space -> Categorical.
    elif isinstance(action_space, Discrete):
        dist_cls = (TorchCategorical if framework == "torch" else
                    JAXCategorical if framework == "jax" else Categorical)
    # Tuple/Dict Spaces -> MultiAction.
    elif (dist_type in (MultiActionDistribution,
                        TorchMultiActionDistribution)
          or isinstance(action_space, (Tuple, Dict))):
        return ModelCatalog._get_multi_action_distribution(
            (MultiActionDistribution
             if framework == "tf" else TorchMultiActionDistribution),
            action_space,
            config,
            framework,
        )
    # Simplex -> Dirichlet.
    elif isinstance(action_space, Simplex):
        if framework == "torch":
            # TODO(sven): implement
            raise NotImplementedError(
                "Simplex action spaces not supported for torch.")
        dist_cls = Dirichlet
    # MultiDiscrete -> MultiCategorical.
    elif isinstance(action_space, MultiDiscrete):
        dist_cls = (TorchMultiCategorical
                    if framework == "torch" else MultiCategorical)
        return partial(dist_cls, input_lens=action_space.nvec), int(
            sum(action_space.nvec))
    # Unknown type -> Error.
    else:
        raise NotImplementedError("Unsupported args: {} {}".format(
            action_space, dist_type))

    return dist_cls, dist_cls.required_model_output_shape(
        action_space, config)
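
# A minimal, hypothetical smoke test for the space-to-distribution mapping
# above. It assumes this module's imports (gym.spaces, MODEL_DEFAULTS, the
# distribution classes) resolve as in RLlib; the spaces and expected sizes
# below are illustrative, not exhaustive.
def _demo_get_action_dist():
    from gym.spaces import Box, Discrete, MultiDiscrete

    # Discrete(4) -> Categorical over 4 logits.
    _, dist_dim = get_action_dist(Discrete(4), config=None)
    assert dist_dim == 4

    # 1-D float Box of shape (3,) -> DiagGaussian; the model must emit a
    # mean and a log-std per action dimension, hence 2 * 3 = 6 outputs.
    box = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
    _, dist_dim = get_action_dist(box, config=None)
    assert dist_dim == 6

    # MultiDiscrete([3, 5]) -> MultiCategorical over 3 + 5 = 8 logits.
    _, dist_dim = get_action_dist(MultiDiscrete([3, 5]), config=None)
    assert dist_dim == 8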
def visualize_adversaries(rllib_config, checkpoint, grid_size, num_rollouts,
                          outdir):
    env, agent, multiagent, use_lstm, policy_agent_mapping, state_init, \
        action_init = instantiate_rollout(rllib_config, checkpoint)

    # Figure out how many adversaries you have and initialize their grids.
    num_adversaries = env.num_adversaries
    adversary_grid_dict = {}
    kl_grid = np.zeros((num_adversaries, num_adversaries))
    for i in range(num_adversaries):
        adversary_str = 'adversary' + str(i)
        # Each adversary grid is a map of agent action versus observation
        # dimension.
        adversary_grid = np.zeros(
            (grid_size - 1, grid_size - 1,
             env.observation_space.low.shape[0],
             env.adv_action_space.low.shape[0])).astype(int)
        strength_grid = np.linspace(env.adv_action_space.low,
                                    env.adv_action_space.high, grid_size).T
        obs_grid = np.linspace(env.observation_space.low,
                               env.observation_space.high, grid_size).T
        adversary_grid_dict[adversary_str] = {
            'grid': adversary_grid,
            'action_bins': strength_grid,
            'obs_bins': obs_grid,
            'action_list': []
        }

    total_steps = 0
    # env.should_render = True

    # Actually do the rollout.
    for r_itr in range(num_rollouts):
        print('On iteration {}'.format(r_itr))
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]])
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]])
        obs = env.reset()
        prev_rewards = collections.defaultdict(lambda: 0.)
        done = False
        reward_total = 0.0
        step_num = 0
        while not done:
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            obs = multi_obs['agent'] * env.obs_norm
            if isinstance(env.adv_observation_space, dict):
                multi_obs = {
                    'adversary{}'.format(i): {
                        'obs': obs,
                        'is_active': np.array([1])
                    }
                    for i in range(env.num_adversaries)
                }
            else:
                multi_obs = {
                    'adversary{}'.format(i): obs
                    for i in range(env.num_adversaries)
                }
            multi_obs.update({'agent': obs})

            action_dict = {}
            action_dist_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id))
                    policy = agent.get_policy(policy_id)
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        prev_action = _flatten_action(prev_actions[agent_id])
                        a_action, p_state, _ = agent.compute_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_action,
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                        agent_states[agent_id] = p_state
                        if isinstance(a_obs, dict):
                            flat_obs = np.concatenate(
                                [val for val in a_obs.values()])[np.newaxis, :]
                        else:
                            flat_obs = _flatten_action(a_obs)[np.newaxis, :]
                        logits, _ = policy.model.from_batch({
                            "obs": flat_obs,
                            "prev_action": prev_action
                        })
                    else:
                        if isinstance(a_obs, dict):
                            flat_obs = np.concatenate(
                                [val for val in a_obs.values()])[np.newaxis, :]
                        else:
                            flat_obs = _flatten_action(a_obs)[np.newaxis, :]
                        logits, _ = policy.model.from_batch({"obs": flat_obs})
                        prev_action = _flatten_action(prev_actions[agent_id])
                        flat_action = _flatten_action(a_obs)
                        a_action = agent.compute_action(
                            flat_action,
                            prev_action=prev_action,
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)

                    # Handle the tuple-action case.
                    if len(a_action) > 1:
                        if isinstance(a_action[0], np.ndarray):
                            a_action[0] = a_action[0].flatten()
                    action_dict[agent_id] = a_action
                    action_dist_dict[agent_id] = DiagGaussian(logits, None)
                    prev_action = _flatten_action(a_action)  # tuple actions
                    prev_actions[agent_id] = prev_action

                    # Now store the agent action in the corresponding grid.
                    if agent_id != 'agent':
                        action_bins = adversary_grid_dict[agent_id][
                            'action_bins']
                        obs_bins = adversary_grid_dict[agent_id]['obs_bins']
                        heat_map = adversary_grid_dict[agent_id]['grid']
                        for action_loop_index, action in enumerate(a_action):
                            adversary_grid_dict[agent_id][
                                'action_list'].append(a_action[0])
                            action_index = np.digitize(
                                action, action_bins[action_loop_index, :]) - 1
                            # digitize maps the right-most bin edge one past
                            # the last cell, so fold it back in.
                            if action_index == heat_map.shape[0]:
                                action_index -= 1
                            for obs_loop_index, obs_elem in enumerate(obs):
                                obs_index = np.digitize(
                                    obs_elem,
                                    obs_bins[obs_loop_index, :]) - 1
                                if obs_index == heat_map.shape[1]:
                                    obs_index -= 1
                                heat_map[action_index, obs_index,
                                         obs_loop_index,
                                         action_loop_index] += 1

            # Now iterate through the adversaries and accumulate the pairwise
            # KL divergences between their action distributions.
            for agent_id in multi_obs.keys():
                if agent_id != 'agent':
                    curr_id = int(agent_id.split('adversary')[1])
                    your_action_dist = action_dist_dict[agent_id]
                    for i in range(num_adversaries):
                        # The KL divergence of a distribution with itself is
                        # zero, so skip that case.
                        if i == curr_id:
                            pass
                        else:
                            other_action_dist = action_dist_dict[
                                'adversary{}'.format(i)]
                            kl_diff = your_action_dist.kl(other_action_dist)
                            kl_grid[curr_id, i] += kl_diff

            action = action_dict
            action = action if multiagent else action[_DUMMY_AGENT_ID]

            # We turn the adversaries off, so only send in the pendulum keys.
            new_dict = {}
            new_dict.update({'agent': action['agent']})
            next_obs, reward, done, info = env.step(new_dict)
            if isinstance(done, dict):
                done = done['__all__']
            step_num += 1
            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward

            # We only want the robot reward, not the adversary reward.
            reward_total += info['agent']['agent_reward']
            obs = next_obs
        total_steps += step_num

    file_path = os.path.dirname(os.path.abspath(__file__))
    output_file_path = os.path.join(file_path, outdir)
    if not os.path.exists(output_file_path):
        try:
            os.makedirs(output_file_path)
        except OSError as exc:
            if exc.errno != errno.EEXIST:
                raise

    # Plot a histogram and per-dimension heatmaps of each adversary's actions.
    for adversary, adv_dict in adversary_grid_dict.items():
        heat_map = adv_dict['grid']
        action_bins = adv_dict['action_bins']
        obs_bins = adv_dict['obs_bins']
        action_list = adv_dict['action_list']

        plt.figure()
        sns.distplot(action_list)
        output_str = '{}/{}'.format(outdir,
                                    adversary + 'action_histogram.png')
        plt.savefig(output_str)

        xtitles = ['x', 'xdot', 'theta', 'thetadot']
        ytitles = ['ax', 'ay']
        for obs_idx in range(heat_map.shape[-2]):
            for a_idx in range(heat_map.shape[-1]):
                plt.figure()
                # Increasing the row index implies moving down on the y axis.
                sns.heatmap(
                    heat_map[:, :, obs_idx, a_idx],
                    yticklabels=np.round(action_bins[a_idx], 1),
                    xticklabels=np.round(obs_bins[obs_idx], 1))
                plt.ylabel(ytitles[a_idx])
                plt.xlabel(xtitles[obs_idx])
                output_str = '{}/{}'.format(
                    outdir, adversary + 'action_heatmap_{}_{}.png'.format(
                        xtitles[obs_idx], ytitles[a_idx]))
                plt.savefig(output_str)

    # Plot the per-step average KL difference between the adversaries.
    plt.figure()
    sns.heatmap(kl_grid / total_steps)
    output_str = '{}/{}'.format(outdir, 'kl_heatmap.png')
    plt.savefig(output_str)
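
# A standalone sketch of the binning trick used above, with illustrative bin
# edges and values. np.digitize returns len(bins) for a value equal to the
# right-most edge, so after the `- 1` shift the maximum value would land one
# past the last of the grid_size - 1 cells; the clamp folds it back in.
def _demo_digitize_clamp():
    grid_size = 5
    bins = np.linspace(-1.0, 1.0, grid_size)  # 5 edges -> 4 cells

    for value in (-1.0, 0.1, 1.0):
        index = np.digitize(value, bins) - 1
        # Only the right-most edge (here, 1.0) triggers the clamp.
        if index == grid_size - 1:
            index -= 1
        assert 0 <= index < grid_size - 1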