def test_eq(self):
    """Instances with identical fields are equal; any difference is not."""
    shared_fields = (77, 1234.5, b"\x12", "one", "two")
    first = provider.TensorTimeSeries(*shared_fields)
    twin = provider.TensorTimeSeries(*shared_fields)
    different = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
    self.assertEqual(first, twin)
    self.assertNotEqual(first, different)
    # Comparing against an unrelated type must not report equality.
    self.assertNotEqual(first, object())
def test_list_tensors(self):
    """`list_tensors` sends a sorted request and maps the response.

    The response becomes a nested run -> tag -> `TensorTimeSeries` dict,
    with every summary-metadata field copied over (plugin content,
    description and display name).
    """
    res = data_provider_pb2.ListTensorsResponse()
    run1 = res.runs.add(run_name="val")
    tag11 = run1.tags.add(tag_name="weights")
    tag11.metadata.max_step = 7
    tag11.metadata.max_wall_time = 7.77
    tag11.metadata.summary_metadata.plugin_data.content = b"magic"
    tag11.metadata.summary_metadata.summary_description = "hey"
    # Use a non-default display name so the test verifies that it is
    # plumbed through. Otherwise every expected value is just the proto
    # default "" and a provider that ignored the field would still pass.
    tag11.metadata.summary_metadata.display_name = "Weights"
    tag12 = run1.tags.add(tag_name="other")
    tag12.metadata.max_step = 8
    tag12.metadata.max_wall_time = 8.88
    run2 = res.runs.add(run_name="test")
    tag21 = run2.tags.add(tag_name="weights")
    tag21.metadata.max_step = 9
    tag21.metadata.max_wall_time = 9.99
    self.stub.ListTensors.return_value = res

    actual = self.provider.list_tensors(
        self.ctx,
        experiment_id="123",
        plugin_name="histograms",
        run_tag_filter=provider.RunTagFilter(tags=["weights", "other"]),
    )
    expected = {
        "val": {
            "weights": provider.TensorTimeSeries(
                max_step=7,
                max_wall_time=7.77,
                plugin_content=b"magic",
                description="hey",
                display_name="Weights",
            ),
            "other": provider.TensorTimeSeries(
                max_step=8,
                max_wall_time=8.88,
                plugin_content=b"",
                description="",
                display_name="",
            ),
        },
        "test": {
            "weights": provider.TensorTimeSeries(
                max_step=9,
                max_wall_time=9.99,
                plugin_content=b"",
                description="",
                display_name="",
            ),
        },
    }
    self.assertEqual(actual, expected)

    req = data_provider_pb2.ListTensorsRequest()
    req.experiment_id = "123"
    req.plugin_filter.plugin_name = "histograms"
    req.run_tag_filter.tags.names.extend(["other", "weights"])  # sorted
    self.stub.ListTensors.assert_called_once_with(req)
def test_hash(self):
    """Equal instances hash equally; distinct ones should hash differently."""
    shared_fields = (77, 1234.5, b"\x12", "one", "two")
    lhs = provider.TensorTimeSeries(*shared_fields)
    rhs = provider.TensorTimeSeries(*shared_fields)
    other = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
    self.assertEqual(hash(lhs), hash(rhs))
    # The `__hash__` contract does not strictly require this, but it
    # _should_ hold; a collision here would at least warrant scrutiny.
    self.assertNotEqual(hash(lhs), hash(other))
def list_tensors(
    self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
):
    """List tensor time series for one plugin via the `ListTensors` RPC.

    Args:
      ctx: Request context (not read by this method body).
      experiment_id: ID of the experiment to query.
      plugin_name: Plugin name placed in the request's plugin filter.
      run_tag_filter: Optional run/tag filter, copied into the request by
        `_populate_rtf`.

    Returns:
      A dict mapping each run name in the response to a dict that maps
      tag name to a `provider.TensorTimeSeries`.
    """
    with timing.log_latency("build request"):
        request = data_provider_pb2.ListTensorsRequest()
        request.experiment_id = experiment_id
        request.plugin_filter.plugin_name = plugin_name
        _populate_rtf(run_tag_filter, request.run_tag_filter)

    # `_translate_grpc_error` is entered inside the latency timer, so the
    # timer also covers the error-translation scope.
    with timing.log_latency("_stub.ListTensors"), _translate_grpc_error():
        response = self._stub.ListTensors(request)

    with timing.log_latency("build result"):

        def to_time_series(metadata):
            # Flatten one tag's metadata proto into a TensorTimeSeries.
            summary = metadata.summary_metadata
            return provider.TensorTimeSeries(
                max_step=metadata.max_step,
                max_wall_time=metadata.max_wall_time,
                plugin_content=summary.plugin_data.content,
                description=summary.summary_description,
                display_name=summary.display_name,
            )

        # Later duplicate run or tag names overwrite earlier ones.
        return {
            run.run_name: {
                tag.tag_name: to_time_series(tag.metadata)
                for tag in run.tags
            }
            for run in response.runs
        }
def _tensor_time_series(
    self, max_step, max_wall_time, plugin_content, description, display_name
):
    """Construct a `provider.TensorTimeSeries`, passing every field by keyword.

    Keyword passing keeps call sites independent of the constructor's
    positional parameter order.
    """
    fields = dict(
        max_step=max_step,
        max_wall_time=max_wall_time,
        plugin_content=plugin_content,
        description=description,
        display_name=display_name,
    )
    return provider.TensorTimeSeries(**fields)
def test_repr(self):
    """The repr of a TensorTimeSeries contains the repr of every field."""
    series = provider.TensorTimeSeries(
        max_step=77,
        max_wall_time=1234.5,
        plugin_content=b"AB\xCD\xEF!\x00",
        description="test test",
        display_name="one two",
    )
    rendered = repr(series)
    field_names = (
        "max_step",
        "max_wall_time",
        "plugin_content",
        "description",
        "display_name",
    )
    for field_name in field_names:
        self.assertIn(repr(getattr(series, field_name)), rendered)