def setup(self):
    """ Build the metric list and the logging callback used by the tests. """
    # NOTE(review): "metric_2" is listed twice -- confirm this is intentional.
    codes = ("metric_1", "metric_2", "metric_2")
    self.metric_list = MetricList([Metric(code) for code in codes])
    self.logging_callback = LoggingCallback()
def test_add_metriclist(self):
    """ should concatenate correctly two MetricList. """
    metriclist = MetricList(self.metrics)
    added_metrics = [Metric('exploration~exp.last'), Metric('decay.last')]
    metriclist += MetricList(added_metrics)

    # First compare the names against the metrics themselves ...
    check.equal(metriclist.metric_names, self.metrics + added_metrics)

    # ... then against the plain names.
    expected_metric_names = [known.name for known in self.metrics]
    expected_metric_names.extend(['exploration', 'decay'])
    check.equal(metriclist.metric_names, expected_metric_names)
def test_add_metric_codes_list(self):
    """ should add correctly a new list of metric codes. """
    metriclist = MetricList(self.metrics)
    added_metrics = [Metric('exploration~exp.last'), Metric('decay.last')]
    metriclist += ['exploration~exp.last', 'decay.last']

    # First compare the names against the metrics themselves ...
    check.equal(metriclist.metric_names, self.metrics + added_metrics)

    # ... then against the plain names.
    expected_metric_names = [known.name for known in self.metrics]
    expected_metric_names.extend(['exploration', 'decay'])
    check.equal(metriclist.metric_names, expected_metric_names)
def test_get_value_no_prefix_no_logs(self, mocker):
    """ should return N/A when neither a prefix nor logs are given. """
    callback = LoggingCallback()
    metric = Metric('attribute')

    found = callback._get_value(metric)

    check.equal(found, 'N/A')
def test_init_full(self):
    """ should instantiate correctly with a full metric code. """
    metric = Metric('reward~rwd.sum')

    name_msg = f'Metric name should be reward and not {metric.name}'
    check.equal(metric.name, 'reward', name_msg)

    surname_msg = f'Metric surname should be rwd and not {metric.surname}'
    check.equal(metric.surname, 'rwd', surname_msg)

    operator_msg = f'Metric opertor should be sum and not {metric.operator}'
    check.equal(metric.operator, 'sum', operator_msg)
def test_init_no_op(self):
    """ should instantiate correctly without an operator (avg is expected). """
    metric = Metric('reward~rwd')

    name_msg = f'Metric name should be reward and not {metric.name}'
    check.equal(metric.name, 'reward', name_msg)

    surname_msg = f'Metric surname should be rwd and not {metric.surname}'
    check.equal(metric.surname, 'rwd', surname_msg)

    operator_msg = f'Metric opertor should be avg and not {metric.operator}'
    check.equal(metric.operator, 'avg', operator_msg)
def test_add_metric_code(self):
    """ should add correctly a new metric code. """
    metriclist = MetricList(self.metrics)
    added_metric = Metric('exploration~exp.last')
    metriclist += 'exploration~exp.last'

    # First compare the names against the metrics themselves ...
    check.equal(metriclist.metric_names, self.metrics + [added_metric])

    # ... then against the plain names.
    expected_metric_names = [known.name for known in self.metrics]
    expected_metric_names.append('exploration')
    check.equal(metriclist.metric_names, expected_metric_names)
def test_init_no_surname(self):
    """ should instantiate correctly without a surname (name is expected). """
    metric = Metric('reward.sum')

    name_msg = f'Metric name should be reward and not {metric.name}'
    check.equal(metric.name, 'reward', name_msg)

    surname_msg = f'Metric surname should be reward and not {metric.surname}'
    check.equal(metric.surname, 'reward', surname_msg)

    operator_msg = f'Metric opertor should be sum and not {metric.operator}'
    check.equal(metric.operator, 'sum', operator_msg)
def test_get_value_no_prefix_logs(self, mocker):
    """ should return the value found in logs when no prefix is given. """

    def _extract_metric_from_logs(self, metric_name, logs, agent_id):
        # Replacement for the patched method: a plain lookup in the logs.
        return logs[metric_name]

    patched_method = (
        'learnrl.callbacks.logging_callback.LoggingCallback'
        '._extract_metric_from_logs'
    )
    mocker.patch(patched_method, _extract_metric_from_logs)

    callback = LoggingCallback()
    found = callback._get_value(Metric('attribute'), logs={'attribute': 123})

    check.equal(found, 123)
def test_get_value_with_prefix(self, mocker):
    """ should return the attribute value when a prefix is given. """

    def _get_attr_name(*args):
        # Replacement for the patched method: always the same attribute name.
        return 'attribute'

    patched_method = (
        'learnrl.callbacks.logging_callback.LoggingCallback._get_attr_name'
    )
    mocker.patch(patched_method, _get_attr_name)

    callback = LoggingCallback()
    callback.attribute = 123

    found = callback._get_value(Metric('attribute'), prefix='prefix')

    check.equal(found, 123)
def test_repr(self):
    """ should be represented as full metric code with repr. """
    metric = Metric('reward~rwd.sum')
    # check.equal reports both operands on failure, unlike is_true on a
    # pre-computed boolean, and matches the other value checks in this file.
    check.equal(repr(metric), 'reward~rwd.sum')
def setup_metrics(self):
    """ Store the metric codes and the metrics built from them on the instance. """
    codes = ['reward~rwd.sum', 'loss_1', 'loss_2.sum']
    self.metric_codes = codes
    self.metrics = list(map(Metric, codes))
def test_str(self):
    """ should be represented as name with str. """
    metric = Metric('reward~rwd.sum')
    # check.equal reports both operands on failure, unlike is_true on a
    # pre-computed boolean, and matches the other value checks in this file.
    check.equal(str(metric), 'reward')
def test_equal_metric(self):
    """ should be equal to another metric sharing the same name. """
    summed, averaged = Metric('reward~rwd.sum'), Metric('reward~R.avg')
    # Surname and operator differ; only the name is shared.
    same = summed == averaged
    check.is_true(same)
def test_equal_str(self):
    """ should be equal to a string only when it matches the metric name. """
    metric = Metric('reward~rwd.sum')
    check.is_true(metric == 'reward')
    # A near-miss of the name and the surname must both be rejected.
    for other in ('rewards', 'rwd'):
        check.is_false(metric == other)
def test_without_agent(self):
    """ should name attrs correctly when no agent is specified. """
    reward = Metric('reward~rwd')
    attr_name = self.get_attr_name('prefix', reward, agent_id=None)
    check.equal(attr_name, 'prefix_reward')
def test_with_specific_agent(self):
    """ should name attrs correctly with a specific agent. """
    reward = Metric('reward~rwd')
    attr_name = self.get_attr_name('prefix', reward, agent_id=2)
    check.equal(attr_name, 'prefix_agent2_reward')