コード例 #1
0
def test_filter():
    """Check that ObservationSqueezeFilter removes singleton axes.

    Without an axis argument, every size-1 axis is dropped: (20, 30, 1, 3) ->
    (20, 30, 3) and (1, 30, 1, 3) -> (30, 3). With an explicit axis (2 here),
    that axis is dropped. The input observation must keep its shape.
    """
    # one filter squeezes all singleton axes, the other only axis 2
    squeeze_filter = InputFilter()
    squeeze_filter.add_observation_filter('observation', 'squeeze',
                                          ObservationSqueezeFilter())
    squeeze_filter_with_axis = InputFilter()
    squeeze_filter_with_axis.add_observation_filter(
        'observation', 'squeeze', ObservationSqueezeFilter(2))

    observation = np.random.rand(20, 30, 1, 3)
    env_response = EnvResponse(next_state={'observation': observation},
                               reward=0,
                               game_over=False)

    result = squeeze_filter.filter(env_response)[0]
    result_with_axis = squeeze_filter_with_axis.filter(env_response)[0]
    unfiltered_observation_shape = env_response.next_state['observation'].shape
    filtered_observation_shape = result.next_state['observation'].shape
    filtered_observation_with_axis_shape = result_with_axis.next_state[
        'observation'].shape

    # make sure the original observation is unchanged. Compare against the
    # literal input shape: env_response may hold the very same array object as
    # `observation`, so comparing with `observation.shape` could never fail.
    assert unfiltered_observation_shape == (20, 30, 1, 3)

    # make sure the filtering is done correctly
    assert filtered_observation_shape == (20, 30, 3)
    assert filtered_observation_with_axis_shape == (20, 30, 3)

    # with no axis given, a leading singleton axis is squeezed as well
    observation = np.random.rand(1, 30, 1, 3)
    env_response = EnvResponse(next_state={'observation': observation},
                               reward=0,
                               game_over=False)

    result = squeeze_filter.filter(env_response)[0]
    assert result.next_state['observation'].shape == (30, 3)
def test_filter():
    """Rescale observations by a constant factor and verify the results.

    An RGB 20x30x3 observation is halved to 10x15x3 (input left intact), and a
    grayscale 20x30 all-ones observation is doubled to an all-ones 40x60 one.
    """
    # shrink an RGB observation by a factor of 0.5
    rgb_response = EnvResponse(
        next_state={'observation': np.ones([20, 30, 3])},
        reward=0,
        game_over=False)
    halving_filter = InputFilter()
    halving_filter.add_observation_filter(
        'observation', 'rescale', ObservationRescaleSizeByFactorFilter(0.5))

    shrunk = halving_filter.filter(rgb_response)[0].next_state['observation']

    # the input observation keeps its shape
    assert rgb_response.next_state['observation'].shape == (20, 30, 3)
    # spatial axes are halved, channels untouched
    assert shrunk.shape == (10, 15, 3)

    # enlarge a grayscale observation by a factor of 2
    gray_response = EnvResponse(next_state={'observation': np.ones([20, 30])},
                                reward=0,
                                game_over=False)
    doubling_filter = InputFilter()
    doubling_filter.add_observation_filter(
        'observation', 'rescale', ObservationRescaleSizeByFactorFilter(2))
    enlarged = doubling_filter.filter(gray_response)[0].next_state['observation']

    assert enlarged.shape == (40, 60)
    # shapes already match, so element-wise equality is the whole check
    assert np.array_equal(enlarged, np.ones([40, 60]))
def test_filter(env_response):
    """Crop an observation with ObservationCropFilter and check shape and content.

    `env_response` is a fixture defined elsewhere. Judging by the expected
    shapes, its observation is at least 5 along axis 0, with axes 1 and 2 of
    size 20 and 30 — TODO confirm against the fixture.
    """
    # (crop_low, crop_high, expected output shape, matching slice of the input)
    crop_cases = [
        (np.array([0, 5, 10]), np.array([5, 10, 20]), (5, 5, 10),
         (slice(0, 5), slice(5, 10), slice(10, 20))),
        # -1 as an upper bound keeps the axis up to its end
        (np.array([0, 0, 0]), np.array([5, -1, -1]), (5, 20, 30),
         (slice(0, 5), slice(None), slice(None))),
    ]

    for low, high, expected_shape, region in crop_cases:
        cropping = InputFilter()
        cropping.add_observation_filter('observation', 'crop', ObservationCropFilter(low, high))

        cropped = cropping.filter(env_response)[0].next_state['observation']
        source = env_response.next_state['observation']

        # validate the shape of the cropped observation
        assert cropped.shape == expected_shape
        # validate that it holds exactly the selected region of the input
        assert np.all(cropped == source[region])
def test_filter():
    """Rescale observations to a fixed target size with ObservationRescaleToSizeFilter.

    Covers shrinking an RGB image, enlarging a grayscale image, and rejecting
    a channels-first target space with ValueError.
    """
    def rescale_to(target_shape):
        # an InputFilter that rescales 'observation' into an image space of target_shape
        rescaler = InputFilter()
        rescaler.add_observation_filter(
            'observation', 'rescale',
            ObservationRescaleToSizeFilter(
                ImageObservationSpace(np.array(target_shape), high=255)))
        return rescaler

    # shrink an RGB observation
    rgb = EnvResponse(next_state={'observation': np.ones([20, 30, 3])},
                      reward=0,
                      game_over=False)
    shrunk = rescale_to([10, 20, 3]).filter(rgb)[0].next_state['observation']

    # the input observation keeps its shape
    assert rgb.next_state['observation'].shape == (20, 30, 3)
    assert shrunk.shape == (10, 20, 3)
    assert np.all(shrunk == np.ones([10, 20, 3]))

    # enlarge a grayscale observation
    gray = EnvResponse(next_state={'observation': np.ones([20, 30])},
                       reward=0,
                       game_over=False)
    enlarged = rescale_to([40, 60]).filter(gray)[0].next_state['observation']

    assert enlarged.shape == (40, 60)
    assert np.all(enlarged == np.ones([40, 60]))

    # TODO: the filter does not validate these cases yet. Once it does, also
    # check that (a) a target that changes the channel count (e.g. [10, 20, 1])
    # raises ValueError, and (b) filtering an observation whose rank differs
    # from the target space raises ValueError.

    # a channels-first target space is rejected
    with pytest.raises(ValueError):
        ObservationRescaleToSizeFilter(
            ImageObservationSpace(np.array([3, 10, 20]), high=255))
コード例 #5
0
def test_filter():
    """Check that RewardRescaleFilter scales rewards by its factor.

    With a factor of 1/10, a reward of 100 becomes 10 and -50 becomes -5. The
    reward on the input EnvResponse must not be modified.
    """
    rescale_filter = InputFilter(reward_filters=OrderedDict([('rescale', RewardRescaleFilter(1/10.))]))
    env_response = EnvResponse(next_state={'observation': np.zeros(10)}, reward=100, game_over=False)
    result = rescale_filter.filter(env_response)[0]
    unfiltered_reward = env_response.reward
    filtered_reward = result.reward

    # validate that the reward was rescaled correctly
    assert filtered_reward == 10

    # make sure the original reward is unchanged
    assert unfiltered_reward == 100

    # negative rewards keep their sign
    env_response = EnvResponse(next_state={'observation': np.zeros(10)}, reward=-50, game_over=False)
    result = rescale_filter.filter(env_response)[0]
    assert result.reward == -5
def test_filter():
    """Reduce a 3-element vector observation by measurement name.

    Keeping ["a"] must leave one element; discarding ["a"] must leave two. In
    both cases the input observation keeps its (3,) shape.
    """
    observation_space = VectorObservationSpace(
        3, measurements_names=['a', 'b', 'c'])
    env_response = EnvResponse(next_state={'observation': np.ones([3])},
                               reward=0,
                               game_over=False)

    def reduce_observation(method):
        # filter env_response's observation with a fresh name-based reduction filter
        reducer = InputFilter()
        reducer.add_observation_filter(
            'observation', 'reduce',
            ObservationReductionBySubPartsNameFilter(["a"], method))
        # hand the named observation space to the filter before filtering
        reducer.get_filtered_observation_space('observation',
                                               observation_space)
        return reducer.filter(env_response)[0].next_state['observation']

    methods = ObservationReductionBySubPartsNameFilter.ReductionMethod
    for method, expected_shape in ((methods.Keep, (1, )),
                                   (methods.Discard, (2, ))):
        reduced = reduce_observation(method)

        # the input observation is left intact
        assert env_response.next_state['observation'].shape == (3, )
        # only the kept / non-discarded measurements remain
        assert reduced.shape == expected_shape
コード例 #7
0
    def load_csv(self, csv_dataset: CsvDataset, input_filter: InputFilter) -> None:
        """
        Restore the replay buffer contents from a csv file.
        The csv file is assumed to include a list of transitions.

        Rows are grouped by their 'episode_id' column, in order of first appearance.
        Each pair of consecutive rows within an episode becomes one transition.
        Its state / next_state are built from the columns whose names start with
        'state_feature', and its action, reward and 'all_action_probabilities'
        come from the current row. Episodes with fewer than 2 rows are skipped.

        :param csv_dataset: A construct which holds the dataset parameters
        :param input_filter: A filter used to filter the CSV data before feeding it to the memory.
        """
        self.assert_not_frozen()

        df = pd.read_csv(csv_dataset.filepath)
        if len(df) > self.max_size[1]:
            screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is "
                           "bigger than the max size of the replay buffer ({}). The excessive transitions will "
                           "not be stored.".format(len(df), self.max_size[1]))

        state_columns = [col for col in df.columns if col.startswith('state_feature')]

        # A single groupby pass splits the frame into episodes. Re-masking the whole
        # frame once per episode would cost O(rows * episodes). sort=False keeps
        # episodes in order of first appearance, the same order as Series.unique().
        episodes_by_id = df.groupby('episode_id', sort=False)
        num_episodes = episodes_by_id.ngroups
        progress_bar = ProgressBar(num_episodes)

        # Progress is reported by position, not by episode id. The bar is sized by
        # episode count, and ids are not guaranteed to be 0..num_episodes-1.
        for episode_index, (_, df_episode_transitions) in enumerate(episodes_by_id):
            progress_bar.update(episode_index)
            input_filter.reset()

            if len(df_episode_transitions) < 2:
                # we have to have at least 2 rows in each episode for creating a transition
                continue

            episode = Episode()
            rows = [row for _, row in df_episode_transitions.iterrows()]
            transitions = []
            # pair each row with its successor: (r0, r1), (r1, r2), ...
            for current_transition, next_transition in zip(rows, rows[1:]):
                state = np.array([current_transition[col] for col in state_columns])
                next_state = np.array([next_transition[col] for col in state_columns])

                transitions.append(
                    Transition(state={'observation': state},
                               action=int(current_transition['action']), reward=current_transition['reward'],
                               next_state={'observation': next_state}, game_over=False,
                               info={'all_action_probabilities':
                                         ast.literal_eval(current_transition['all_action_probabilities'])}),
                    )

            transitions = input_filter.filter(transitions, deep_copy=False)
            for t in transitions:
                episode.insert(t)

            # Set the last transition to end the episode
            if csv_dataset.is_episodic:
                episode.get_last_transition().game_over = True

            self.store_episode(episode)

        # close the progress bar
        progress_bar.update(num_episodes)
        progress_bar.close()
コード例 #8
0
def test_filter_stacking():
    """Chain observation and reward filters and check the combined result.

    Input: a 210x160 all-ones observation with a reward of 100.
    Observation chain: rescale to 110x84, crop rows 16:100 / cols 0:84 down to
    84x84, then stack 4 observations along the last axis -> 84x84x4.
    Reward chain: clip to [-1, 1], so 100 becomes 1.
    """
    env_response = EnvResponse({'observation': np.ones([210, 160])}, reward=100, game_over=False)

    rescale = ObservationRescaleToSizeFilter(
        output_observation_space=ImageObservationSpace(np.array([110, 84]), high=255),
    )
    crop = ObservationCropFilter(crop_low=np.array([16, 0]), crop_high=np.array([100, 84]))
    clip = RewardClippingFilter(clipping_low=-1, clipping_high=1)
    stack = ObservationStackingFilter(stack_size=4, stacking_axis=-1)

    # filters run in insertion order, hence OrderedDict
    observation_chain = OrderedDict([
        ("filter1", rescale),
        ("filter2", crop),
        ("output_filter", stack),
    ])
    reward_chain = OrderedDict([("filter3", clip)])

    input_filter = InputFilter(observation_filters={"observation": observation_chain},
                               reward_filters=reward_chain)

    result = input_filter.filter(env_response)[0]
    stacked = np.array(result.next_state['observation'])

    assert stacked.shape == (84, 84, 4)
    assert np.all(stacked == np.ones([84, 84, 4]))
    assert result.reward == 1
def test_filter():
    """Convert a float observation in [0, 255) to uint8 with ObservationToUInt8Filter.

    With input range [0, 255], the filtered values must match a plain
    ``astype('uint8')`` cast of the input. The input observation must be left
    unchanged in both dtype and values.
    """
    uint8_filter = InputFilter()
    uint8_filter.add_observation_filter(
        'observation', 'to_uint8',
        ObservationToUInt8Filter(input_low=0, input_high=255))

    observation = np.random.rand(20, 30, 3) * 255.0
    # snapshot taken before filtering, so in-place modification can be detected
    original_observation = observation.copy()
    env_response = EnvResponse(next_state={'observation': observation},
                               reward=0,
                               game_over=False)

    result = uint8_filter.filter(env_response)[0]
    unfiltered_observation = env_response.next_state['observation']
    filtered_observation = result.next_state['observation']

    # make sure the original observation is unchanged (dtype and values)
    assert unfiltered_observation.dtype == 'float64'
    assert np.array_equal(unfiltered_observation, original_observation)

    # make sure the filtering is done correctly
    assert filtered_observation.dtype == 'uint8'
    assert np.all(filtered_observation == original_observation.astype('uint8'))