Example 1
 def step(self, action):
     action = np.clip(action, self.action_space.low, self.action_space.high)
     if self._flatten_actions:
         action = spaces.unflatten(self.env.action_space, action)
     obs, reward, done, info = self.env.step(action)
     if self._flatten_obs:
         obs = spaces.flatten(self.env.observation_space, obs)
     return obs, reward, done, info
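For context, both branches above lean on gym's flatten/unflatten helpers, which round-trip a structured sample through a flat 1-D vector. A minimal, self-contained sketch (the space layout is illustrative):

import numpy as np
from gym import spaces

# A structured space: one discrete choice plus a 2-D continuous part.
space = spaces.Dict({
    'gear': spaces.Discrete(3),
    'steer': spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
})

sample = space.sample()
flat = spaces.flatten(space, sample)      # 1-D array: one-hot 'gear' + 'steer' values
restored = spaces.unflatten(space, flat)  # back to the original dict structure
assert space.contains(restored)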
Example 2
    def test_flatten_unflatten(self, observation_space, ordered_values):
        """
        test flatten and unflatten functions directly
        """
        original = observation_space.sample()

        flattened = flatten(observation_space, original)
        unflattened = unflatten(observation_space, flattened)

        self._check_observations(original, flattened, unflattened,
                                 ordered_values)
Example 3
    def test_flattened_environment(self, observation_space, ordered_values):
        """
        make sure that flattened observations occur in the order expected
        """
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(env)
        flattened = wrapped_env.reset()

        unflattened = unflatten(env.observation_space, flattened)
        original = env.observation

        self._check_observations(original, flattened, unflattened,
                                 ordered_values)
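FakeEnvironment is a helper from the surrounding test suite; a minimal stand-in, assuming only that reset() samples an observation and also exposes it as env.observation (which the test reads back):

import gym

class FakeEnvironment(gym.Env):
    """Minimal stand-in: reset() samples and remembers an observation."""
    def __init__(self, observation_space):
        self.observation_space = observation_space
        self.action_space = gym.spaces.Discrete(1)
        self.observation = None

    def reset(self):
        self.observation = self.observation_space.sample()
        return self.observation

    def step(self, action):
        return self.observation, 0.0, True, {}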
Example 4
    def get_qtable(self, values_fmt='{:.2g}'):
        # Format states
        if hasattr(self.env, 'format_state'):
            unflatten_f = lambda x: spaces.unflatten(self.env.observation_space, x)
            states = map(self.env.format_state, map(unflatten_f, self.q_table.keys()))
        else:
            states = ['state {}'.format(i) for i in range(len(self.q_table))]

        # Format actions
        actions = map(self.env.format_action, range(self.env.action_space.n))

        # Create, format and render DataFrame
        df = pd.DataFrame(self.q_table.values(), index=list(states), columns=list(actions))
        df = df.applymap(values_fmt.format)

        return df
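format_state and format_action are optional hooks on the wrapped environment, not part of the gym API; a hypothetical pair for a Dict observation space might look like:

import gym

class LabelledEnv(gym.Env):  # hypothetical environment exposing the hooks
    def format_state(self, state):
        # 'position' and 'velocity' keys are illustrative
        return 'pos={} vel={}'.format(state['position'], state['velocity'])

    def format_action(self, action):
        return ('left', 'right')[action]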
Example 5
 def transform(self, attr: AttributationLike) -> AttributationLike:
     obs_space = self.obs_space
     if self.obs_image_channel_dim is not None:
         attr = np.sum(attr, axis=self.obs_image_channel_dim)
         obs_space = remove_channel_dim_from_image_space(obs_space)
     attr = flatten(obs_space, attr)  # flatten with the (possibly channel-stripped) space
     if self.mode == AttributationNormalizationMode.ALL:
         scaling_factor = self._calculate_safe_scaling_factor(np.abs(attr))
     elif self.mode == AttributationNormalizationMode.POSITIVE:
         attr = (attr > 0) * attr
         scaling_factor = self._calculate_safe_scaling_factor(attr)
     elif self.mode == AttributationNormalizationMode.NEGATIVE:
         attr = (attr < 0) * attr
         scaling_factor = -self._calculate_safe_scaling_factor(np.abs(attr))
     elif self.mode == AttributationNormalizationMode.ABSOLUTE_VALUE:
         attr = np.abs(attr)
         scaling_factor = self._calculate_safe_scaling_factor(attr)
     else:
         raise EnumValueNotFound(self.mode, AttributationNormalizationMode)
     attr_norm = self._scale(attr, scaling_factor)
     return unflatten(obs_space, attr_norm)
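The (attr > 0) * attr masking idiom above keeps same-sign entries and zeroes the rest without changing the array's shape; for instance:

import numpy as np

attr = np.array([-0.5, 0.2, 0.8])
print((attr > 0) * attr)  # negatives become zero, shape preserved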
Example 6
    def get_observation(self, env: EnvType, task: SubTaskType, *args: Any,
                        **kwargs: Any) -> Any:
        # If the task is completed, we needn't (perhaps can't) find the expert
        # action from the (current) terminal state.
        if task.is_done():
            return self._zeroed_observation

        action, expert_was_successful = task.query_expert(**self.expert_args)

        if isinstance(action, int):
            assert isinstance(self.action_space, gym.spaces.Discrete)
            unflattened_action = action
        else:
            # Assume we receive a gym-flattened numpy action
            unflattened_action = gyms.unflatten(self.action_space, action)

        unflattened_torch = su.torch_point(
            self.unflattened_observation_space,
            (unflattened_action, expert_was_successful),
        )

        flattened_torch = su.flatten(self.unflattened_observation_space,
                                     unflattened_torch)
        return flattened_torch.cpu().numpy()
Example 7
 def action(self, action):
     if not self.CONTINUOUS:
         # Discrete case: pair each action component with its label.
         return {key: value for key, value in zip(self.labels, action)}
     else:
         return unflatten(self.env.action_space, action)
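As a self-contained sketch of the continuous branch (wrapper name assumed), gym's ActionWrapper routes every step's action through action(), so the agent can emit flat vectors while the inner env receives structured ones; this assumes gym's flatten_space helper (gym >= 0.16):

from gym import ActionWrapper, spaces

class UnflattenAction(ActionWrapper):  # hypothetical name
    def __init__(self, env):
        super().__init__(env)
        # Advertise a flat Box to the agent; unflatten before the inner env sees it.
        self.action_space = spaces.flatten_space(env.action_space)

    def action(self, action):
        return spaces.unflatten(self.env.action_space, action)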
Example 8
    def inspect_memory(self, top_n=10, max_col=80):
        # Functions to encode/decode states
        encode_state = lambda s: tuple(spaces.flatten(self.env.observation_space, s))
        decode_state = lambda s: spaces.unflatten(self.env.observation_space, s)

        # Function to create barchart from counter
        def count_barchart(counter, ax, xlabel=None, normalize=True):
            # Sort and extract key, counts
            sorted_tuples = counter.most_common()
            sorted_keys = [key for key, count in sorted_tuples]
            sorted_counts = [count for key, count in sorted_tuples]

            # Normalize counts
            if normalize:
                total = sum(counter.values())  # use the counter being plotted
                sorted_counts = [c / total for c in sorted_counts]

            # Plotting
            x_indexes = range(len(sorted_counts))
            ax.bar(x_indexes, sorted_counts)
            ax.set_xticks(x_indexes)
            ax.set_xticklabels(sorted_keys)
            ax.set_ylabel('proportion')
            if xlabel is not None:
                ax.set_xlabel(xlabel)
            ax.set_title('Replay Memory')

        # Function to print top states from counter
        def top_states(counter):
            for i, (state, count) in enumerate(counter.most_common(top_n), 1):
                state_label = str(decode_state(state))
                state_label = state_label.replace('\n', ' ')
                if len(state_label) > max_col:
                    state_label = state_label[:max_col] + '..'
                print('{:>2}) Count: {} state: {}'.format(
                    i, count, state_label))

        # Count statistics
        counters = defaultdict(Counter)
        for state, action, reward, next_state, done in self.memory:
            counters['state'][encode_state(state)] += 1
            counters['action'][action] += 1
            counters['reward'][reward] += 1
            counters['next_state'][encode_state(next_state)] += 1
            counters['done'][done] += 1

        # Plot reward/action
        fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
        count_barchart(counters['reward'], ax1, 'rewards')
        count_barchart(counters['action'], ax2, 'actions')
        plt.show()

        # Print top states
        print('Top state:')
        top_states(counters['state'])
        print()

        print('Top next_state:')
        top_states(counters['next_state'])
        print()

        # Done signal
        print('Proportion of done: {:.2f}%'.format(
            100 * counters['done'][True] / sum(counters['done'].values())))