예제 #1
0
    def setup(self):
        """Prepare shared state: a list of three metrics and a default
        LoggingCallback instance."""
        # Use three distinct names so each metric is a separate entry
        # ("metric_2" was previously listed twice).
        self.metric_list = MetricList(
            [Metric("metric_1"),
             Metric("metric_2"),
             Metric("metric_3")])

        self.logging_callback = LoggingCallback()
예제 #2
0
 def test_get_value_no_prefix_no_logs(self, mocker):
     """ _get_value called with neither a prefix nor logs should give 'N/A'. """
     callback = LoggingCallback()
     result = callback._get_value(Metric('attribute'))
     check.equal(result, 'N/A')
예제 #3
0
 def test_reset_attr_avg(self):
     """ _reset_attr with the 'avg' operator should set the attribute back to
     'N/A' and its companion '<name>_seen' counter back to 0. """
     callback = LoggingCallback()
     callback.reward = 0
     callback._reset_attr('reward', 'avg')
     for attr_name, expected in (('reward', 'N/A'), ('reward_seen', 0)):
         check.equal(getattr(callback, attr_name), expected)
예제 #4
0
    def test_on_step_end(self, mocker):
        """ on_step_end should update the current agent's 'episode' metrics
            and every agent's 'steps_cycle' metrics, resetting neither. """
        mocker.patch(self.logging_callback_path +
                     '._update_metrics_all_agents')
        mocker.patch(self.logging_callback_path + '._update_metrics')

        def check_step_end_calls(callback, expected_agent_id):
            # Per-agent update: 'episode' target, no reset, expected agent.
            call_args, call_kwargs = callback._update_metrics.call_args
            check.equal(call_args[1], 'episode')
            check.is_false(call_kwargs.get('reset'))
            if expected_agent_id is None:
                check.is_none(call_kwargs.get('agent_id'))
            else:
                check.equal(call_kwargs.get('agent_id'), expected_agent_id)

            # All-agents update: 'steps_cycle' target, no reset.
            call_args, call_kwargs = \
                callback._update_metrics_all_agents.call_args
            check.equal(call_args[1], 'steps_cycle')
            check.is_false(call_kwargs.get('reset'))

        # Single agent, no logs: no agent_id should be forwarded.
        single_agent = LoggingCallback()
        single_agent.n_agents = 1
        single_agent.on_step_end(step=7)
        check_step_end_calls(single_agent, expected_agent_id=None)

        # Several agents: the agent_id given in logs should be forwarded.
        multi_agent = LoggingCallback()
        multi_agent.n_agents = 5
        multi_agent.on_step_end(step=7, logs={'agent_id': 3})
        check_step_end_calls(multi_agent, expected_agent_id=3)
예제 #5
0
 def test_on_step_cycle_begin(self, mocker):
     """ on_steps_cycle_begin should reset the 'steps_cycle' metrics of all
     agents. """
     mocker.patch(self.logging_callback_path + '._update_metrics_all_agents')
     callback = LoggingCallback()
     callback.on_steps_cycle_begin(step=7)
     call_args, call_kwargs = callback._update_metrics_all_agents.call_args
     check.equal(call_args[1], 'steps_cycle')
     check.is_true(call_kwargs.get('reset'))
예제 #6
0
 def test_on_episode_end(self, mocker):
     """ on_episode_end should update the 'episodes_cycle' metrics of all
     agents from the 'episode' metrics, without resetting them. """
     mocker.patch(self.logging_callback_path + '._update_metrics_all_agents')
     callback = LoggingCallback()
     callback.on_episode_end(episode=7)
     call_args, call_kwargs = callback._update_metrics_all_agents.call_args
     check.equal(call_args[1], 'episodes_cycle')
     check.is_false(call_kwargs.get('reset'))
     check.equal(call_kwargs.get('source_prefix'), 'episode')
예제 #7
0
    def test_on_run_begin(self, mocker):
        """ on_run_begin should set n_agents to the playground's agent count. """
        class PlaygroundStub:
            """Minimal playground exposing only an ``agents`` list."""
            def __init__(self, n_agents):
                self.agents = [None] * n_agents

        n_agents = 5
        # No patching is needed: on_run_begin is expected to derive n_agents
        # from playground.agents, which the stub below controls.
        logging_callback = LoggingCallback()
        logging_callback.playground = PlaygroundStub(n_agents)
        check.is_none(logging_callback.n_agents)
        logging_callback.on_run_begin()
        check.equal(logging_callback.n_agents, n_agents)
예제 #8
0
def test_logging_episodes_operators_(eps_operator, cycle_operator,
                                     metric_name):
    """Check episode and episodes-cycle values logged by LoggingCallback.

    The reference values come from the keys DummyPlayground puts in ``logs``.
    The operators and metric name are presumably supplied by parametrization.
    """
    print(eps_operator, cycle_operator, '\n')

    operators = {'episode': eps_operator, 'episodes': cycle_operator}
    logging_callback = LoggingCallback(
        detailed_step_metrics=[],
        metrics=[('reward', dict(operators)), ('loss', dict(operators))],
    )

    # Logged position -> logs key holding the reference value for it.
    expected_keys = {
        'episode': f'{metric_name}_episode_{eps_operator}',
        'episodes_cycle':
            f'{metric_name}_{eps_operator}_cycle_{cycle_operator}',
    }

    def check_function(callbacks, logs):
        callback_dict = callbacks.callbacks[0].__dict__

        for position, expected_key in expected_keys.items():
            expected = logs.get(expected_key)
            logged = callback_dict[f'{position}_{metric_name}']
            if expected is None:
                continue
            print(position.capitalize(), metric_name, logged, expected)
            assert logged != 'N/A' and np.isclose(logged, expected), \
                f'Logged {logged} instead of {expected}'

    playground = DummyPlayground()
    playground.run([logging_callback], eps_end_func=check_function, verbose=1)
    print()
예제 #9
0
    def test_get_value_no_prefix_logs(self, mocker):
        """ without a prefix, _get_value should return the value found in logs. """
        expected_value = 123

        def fake_extract(self, metric_name, logs, agent_id):
            # Stand-in for _extract_metric_from_logs: a plain dict lookup.
            return logs[metric_name]

        mocker.patch(
            'learnrl.callbacks.logging_callback.LoggingCallback._extract_metric_from_logs',
            fake_extract)

        callback = LoggingCallback()
        logs = {'attribute': expected_value}

        result = callback._get_value(Metric('attribute'), logs=logs)
        check.equal(result, expected_value)
예제 #10
0
    def test_get_value_with_prefix(self, mocker):
        """ with a prefix, _get_value should return the callback attribute
        whose name _get_attr_name produces. """
        expected_value = 123

        # Force the attribute-name lookup to always yield 'attribute'.
        mocker.patch(
            'learnrl.callbacks.logging_callback.LoggingCallback._get_attr_name',
            lambda *args: 'attribute')

        callback = LoggingCallback()
        setattr(callback, 'attribute', expected_value)

        result = callback._get_value(Metric('attribute'), prefix='prefix')
        check.equal(result, expected_value)
예제 #11
0
def test_logging_steps_operators_(cycle_operator, metric_name):
    """Check the steps-cycle values logged by LoggingCallback.

    The reference values come from the keys DummyPlayground puts in ``logs``.
    Arguments are presumably supplied by parametrization.
    """
    print(cycle_operator, metric_name)

    steps_operators = {'steps': cycle_operator}
    logging_callback = LoggingCallback(metrics=[
        ('reward', dict(steps_operators)),
        ('loss', dict(steps_operators)),
    ])

    expected_key = f'{metric_name}_steps_{cycle_operator}'
    logged_key = f'steps_cycle_{metric_name}'

    def check_function(callbacks, logs):
        callback_dict = callbacks.callbacks[0].__dict__
        expected = logs.get(expected_key)
        logged = callback_dict[logged_key]
        if expected is None:
            return
        print(metric_name, logged, expected)
        assert logged != 'N/A' and np.isclose(logged, expected), \
            f'Logged {logged} instead of {expected}'

    DummyPlayground().run([logging_callback], eps_end_func=check_function,
                          verbose=3)
    print()
예제 #12
0
class TestLoggingCallbackUpdateMetrics:
    """ LoggingCallback._update_metrics and _update_metrics_all_agents """

    logging_callback_path = 'learnrl.callbacks.logging_callback.LoggingCallback'

    @pytest.fixture(autouse=True)
    def setup(self):
        # Three distinct metric names so each metric is a separate entry.
        self.metric_list = MetricList(
            [Metric("metric_1"),
             Metric("metric_2"),
             Metric("metric_3")])

        self.logging_callback = LoggingCallback()

    def _patch_helpers(self, mocker, value):
        """Patch the helpers used by _update_metrics.

        _get_attr_name always returns "target_name" and _get_value always
        returns ``value``; _reset_attr and _update_attr become plain mocks.
        """
        path = self.logging_callback_path
        mocker.patch(path + '._get_attr_name', return_value="target_name")
        mocker.patch(path + '._reset_attr')
        mocker.patch(path + '._get_value', return_value=value)
        mocker.patch(path + '._update_attr')

    def _call_update_metrics(self, reset):
        """Run _update_metrics with fixed sentinel arguments."""
        self.logging_callback._update_metrics(self.metric_list,
                                              'target_prefix',
                                              source_prefix='source_prefix',
                                              logs='logs',
                                              agent_id='agent_id',
                                              reset=reset)

    def _check_lookup_calls(self):
        """Check the sentinels reached _get_attr_name and _get_value."""
        for args, _ in self.logging_callback._get_attr_name.call_args_list:
            check.equal(args[0], 'target_prefix')
            check.equal(args[2], 'agent_id')

        for args, _ in self.logging_callback._get_value.call_args_list:
            check.equal(args[1], 'source_prefix')
            check.equal(args[2], 'agent_id')
            check.equal(args[3], 'logs')

    def test_na_value(self, mocker):
        """ should not update attr if value is N/A. """
        self._patch_helpers(mocker, value='N/A')
        self._call_update_metrics(reset=False)
        self._check_lookup_calls()

        check.is_false(self.logging_callback._reset_attr.called)
        check.is_false(self.logging_callback._update_attr.called)

    def test_update(self, mocker):
        """ should update attr if value is not N/A. """
        self._patch_helpers(mocker, value="value")
        self._call_update_metrics(reset=False)
        self._check_lookup_calls()

        check.is_false(self.logging_callback._reset_attr.called)

        # Without this, the loop below would pass vacuously if
        # _update_attr were never called.
        check.is_true(self.logging_callback._update_attr.called)
        for args, _ in self.logging_callback._update_attr.call_args_list:
            check.equal(args[0], 'target_name')
            check.equal(args[1], 'value')

    def test_reset(self, mocker):
        """ should reset attr if reset is True. """
        self._patch_helpers(mocker, value="value")
        self._call_update_metrics(reset=True)
        self._check_lookup_calls()

        # Without this, the loop below would pass vacuously if
        # _reset_attr were never called.
        check.is_true(self.logging_callback._reset_attr.called)
        for args, _ in self.logging_callback._reset_attr.call_args_list:
            check.equal(args[0], 'target_name')

        check.is_false(self.logging_callback._update_attr.called)

    def test_all_agents(self, mocker):
        """ should update all agents independently. """
        mocker.patch(self.logging_callback_path + '._update_metrics')

        n_agents = 5
        self.logging_callback.n_agents = n_agents
        self.logging_callback._update_metrics_all_agents(
            self.metric_list, 'target_prefix')

        # Indexing (rather than iterating) fails loudly if fewer than
        # n_agents calls were made.
        for i in range(n_agents):
            _, kwargs = self.logging_callback._update_metrics.call_args_list[i]
            check.equal(kwargs.get('agent_id'), i)
예제 #13
0
 def test_update_attr_raise(self):
     """ should raise ValueError if the operator is unknown. """
     logging_callback = LoggingCallback()
     logging_callback.reward = 'N/A'
     # Keep only the call under test inside pytest.raises, so a ValueError
     # raised during setup cannot make this test pass by accident.
     with pytest.raises(ValueError, match=r"Unknowed operator.*"):
         logging_callback._update_attr('reward', 1, 'x')
예제 #14
0
 def test_update_attr_last(self):
     """ the 'last' operator should overwrite the attribute with the newest
     value. """
     callback = LoggingCallback()
     callback.reward = 2
     newest = 1
     callback._update_attr('reward', newest, 'last')
     check.equal(callback.reward, newest)
예제 #15
0
    def test_update_attr_avg(self):
        """ should update arguments correctly with avg operator. """
        # Mean 0 over 2 seen values plus a new value 1: the test expects
        # (0 * 2 + 1) / 3 and a seen-counter of 3.
        logging_callback = LoggingCallback()
        logging_callback.reward = 0
        logging_callback.reward_seen = 2
        logging_callback._update_attr('reward', 1, 'avg')
        # Tolerant float comparison: the exact bits of 1/3 depend on how
        # the running mean is computed.
        check.equal(logging_callback.reward, pytest.approx(1 / 3))
        check.equal(logging_callback.reward_seen, 3)

        # With N/A: the first value becomes the mean.
        logging_callback = LoggingCallback()
        logging_callback.reward = 'N/A'
        logging_callback.reward_seen = 0
        logging_callback._update_attr('reward', 1, 'avg')
        check.equal(logging_callback.reward, 1)
        check.equal(logging_callback.reward_seen, 1)
예제 #16
0
 def test_init(self):
     """ constructing a LoggingCallback with default arguments should not
     raise. """
     _ = LoggingCallback()