def test_eq(self):
    # Series built from identical field values compare equal; differing
    # values, or an unrelated object, must not.
    fields = (77, 1234.5, b"\x12", "one", "two")
    original = provider.ScalarTimeSeries(*fields)
    duplicate = provider.ScalarTimeSeries(*fields)
    other = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
    self.assertEqual(original, duplicate)
    self.assertNotEqual(original, other)
    self.assertNotEqual(original, object())
def test_list_scalars(self):
    # Builds a canned ListScalarsResponse, checks the provider's decoded
    # result, then checks the exact request sent to the stub.
    def add_tag(run, tag_name, max_step, max_wall_time):
        tag = run.tags.add(tag_name=tag_name)
        tag.metadata.max_step = max_step
        tag.metadata.max_wall_time = max_wall_time
        return tag

    def series(max_step, max_wall_time, plugin_content=b"",
               description="", display_name=""):
        return provider.ScalarTimeSeries(
            max_step=max_step,
            max_wall_time=max_wall_time,
            plugin_content=plugin_content,
            description=description,
            display_name=display_name,
        )

    response = data_provider_pb2.ListScalarsResponse()
    val_run = response.runs.add(run_name="val")
    val_accuracy = add_tag(val_run, "accuracy", 7, 7.77)
    summary = val_accuracy.metadata.summary_metadata
    summary.plugin_data.content = b"magic"
    summary.display_name = "Accuracy"
    summary.summary_description = "hey"
    add_tag(val_run, "xent", 8, 8.88)
    test_run = response.runs.add(run_name="test")
    add_tag(test_run, "accuracy", 9, 9.99)
    self.stub.ListScalars.return_value = response

    actual = self.provider.list_scalars(
        self.ctx,
        experiment_id="123",
        plugin_name="scalars",
        run_tag_filter=provider.RunTagFilter(tags=["xent", "accuracy"]),
    )
    expected = {
        "val": {
            "accuracy": series(
                7,
                7.77,
                plugin_content=b"magic",
                description="hey",
                display_name="Accuracy",
            ),
            "xent": series(8, 8.88),
        },
        "test": {
            "accuracy": series(9, 9.99),
        },
    }
    self.assertEqual(actual, expected)

    expected_request = data_provider_pb2.ListScalarsRequest()
    expected_request.experiment_id = "123"
    expected_request.plugin_filter.plugin_name = "scalars"
    # Tag names are expected in sorted order, unlike the filter's input.
    expected_request.run_tag_filter.tags.names.extend(["accuracy", "xent"])
    self.stub.ListScalars.assert_called_once_with(expected_request)
def test_hash(self):
    # Equal series must hash equally.
    fields = (77, 1234.5, b"\x12", "one", "two")
    original = provider.ScalarTimeSeries(*fields)
    duplicate = provider.ScalarTimeSeries(*fields)
    other = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
    self.assertEqual(hash(original), hash(duplicate))
    # The `__hash__` contract does not strictly demand the check below,
    # but a collision between these clearly different series would still
    # deserve a closer look.
    self.assertNotEqual(hash(original), hash(other))
def list_scalars(
    self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
):
    """Fetch scalar time series metadata through the `ListScalars` RPC.

    Args:
      ctx: Request context; not consulted by this method body.
      experiment_id: Copied into the request's `experiment_id`.
      plugin_name: Copied into the request's plugin filter.
      run_tag_filter: Optional filter handed to `_populate_rtf`, which
        fills the request's `run_tag_filter` (presumably a no-op for
        `None` — confirm in `_populate_rtf`).

    Returns:
      A dict mapping each run name in the response to a dict mapping tag
      name to `provider.ScalarTimeSeries`. Max step, max wall time,
      plugin content, description and display name are taken from each
      tag's metadata.

    RPC errors are raised inside `_translate_grpc_error()`, which
    presumably converts them into data-provider errors.
    """

    def to_series(metadata):
        # Convert one tag's metadata message into a ScalarTimeSeries.
        summary = metadata.summary_metadata
        return provider.ScalarTimeSeries(
            max_step=metadata.max_step,
            max_wall_time=metadata.max_wall_time,
            plugin_content=summary.plugin_data.content,
            description=summary.summary_description,
            display_name=summary.display_name,
        )

    with timing.log_latency("build request"):
        request = data_provider_pb2.ListScalarsRequest()
        request.experiment_id = experiment_id
        request.plugin_filter.plugin_name = plugin_name
        _populate_rtf(run_tag_filter, request.run_tag_filter)
    with timing.log_latency("_stub.ListScalars"):
        with _translate_grpc_error():
            response = self._stub.ListScalars(request)
    with timing.log_latency("build result"):
        runs = {
            run_entry.run_name: {
                tag_entry.tag_name: to_series(tag_entry.metadata)
                for tag_entry in run_entry.tags
            }
            for run_entry in response.runs
        }
    return runs
def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
    """List scalar time series for `plugin_name` from the multiplexer.

    Args:
      experiment_id: Ignored.
      plugin_name: Plugin whose run/tag content is read via
        `self._multiplexer.PluginRunToTagToContent`.
      run_tag_filter: Optional `provider.RunTagFilter`. `None` is replaced
        by a filter with `runs=None, tags=None`. Each (run, tag) pair is
        checked with `self._test_run_tag`.

    Returns:
      A dict mapping run to a dict mapping tag to
      `provider.ScalarTimeSeries`. A run appears only if at least one of
      its tags passes the filter. `max_step`/`max_wall_time` are the
      largest step and wall time over the tag's tensor events, each
      tracked independently, or `None` if there are no events.
    """
    del experiment_id  # ignored for now

    def running_max(current, candidate):
        # Keep the larger value; `None` means nothing has been seen yet.
        if current is None or current < candidate:
            return candidate
        return current

    content_by_run = self._multiplexer.PluginRunToTagToContent(plugin_name)
    result = {}
    if run_tag_filter is None:
        run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
    for run, tag_to_content in six.iteritems(content_by_run):
        series_by_tag = {}
        for tag in tag_to_content:
            if not self._test_run_tag(run_tag_filter, run, tag):
                continue
            latest_step = None
            latest_wall_time = None
            for event in self._multiplexer.Tensors(run, tag):
                latest_step = running_max(latest_step, event.step)
                latest_wall_time = running_max(latest_wall_time, event.wall_time)
            metadata = self._multiplexer.SummaryMetadata(run, tag)
            series_by_tag[tag] = provider.ScalarTimeSeries(
                max_step=latest_step,
                max_wall_time=latest_wall_time,
                plugin_content=metadata.plugin_data.content,
                description=metadata.summary_description,
                display_name=metadata.display_name,
            )
        # Runs with no tag passing the filter are omitted entirely.
        if series_by_tag:
            result[run] = series_by_tag
    return result
def _scalar_time_series(
    self, max_step, max_wall_time, plugin_content, description, display_name
):
    """Build a `provider.ScalarTimeSeries`, passing every field by keyword.

    Callers use positional arguments here, and this helper forwards them
    as keywords so the constructor's field order does not matter.
    `self` is unused.
    """
    fields = dict(
        max_step=max_step,
        max_wall_time=max_wall_time,
        plugin_content=plugin_content,
        description=description,
        display_name=display_name,
    )
    return provider.ScalarTimeSeries(**fields)
def test_repr(self):
    # The repr must embed the repr of every field, including plugin content
    # containing non-ASCII and NUL bytes.
    series = provider.ScalarTimeSeries(
        max_step=77,
        max_wall_time=1234.5,
        plugin_content=b"AB\xCD\xEF!\x00",
        description="test test",
        display_name="one two",
    )
    rendered = repr(series)
    for attribute in (
        "max_step",
        "max_wall_time",
        "plugin_content",
        "description",
        "display_name",
    ):
        self.assertIn(repr(getattr(series, attribute)), rendered)
def list_scalars(
    self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
):
    """Return one canned scalar series for `experiment_id`.

    The experiment ID is first passed to `self._validate_eid`, which
    presumably raises for unknown IDs (verify). The result has a single run,
    "<experiment_id>/train", holding a single tag, "loss.<plugin_name>",
    whose description names this provider (`self._name`). `ctx` and
    `run_tag_filter` are not consulted.
    """
    self._validate_eid(experiment_id)
    run_key = "%s/train" % experiment_id
    tag_key = "loss.%s" % plugin_name
    loss_series = provider.ScalarTimeSeries(
        max_step=2,
        max_wall_time=0.5,
        plugin_content=b"",
        description="Hello from %s" % self._name,
        display_name="loss",
    )
    return {run_key: {tag_key: loss_series}}