Ejemplo n.º 1
0
 def test_eq(self):
     """Instances with identical field values compare equal.

     Instances whose fields differ, or objects of an unrelated type, must
     compare unequal.
     """
     shared_args = (77, 1234.5, b"\x12", "one", "two")
     first = provider.TensorTimeSeries(*shared_args)
     twin = provider.TensorTimeSeries(*shared_args)
     other = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
     self.assertEqual(first, twin)
     self.assertNotEqual(first, other)
     # Comparing against a foreign type must not report equality.
     self.assertNotEqual(first, object())
Ejemplo n.º 2
0
    def test_list_tensors(self):
        """`list_tensors` maps a `ListTensorsResponse` into nested dicts.

        Checks that every `TensorTimeSeries` field is copied from the
        response's tag metadata, including `display_name`. Checks that tags
        whose summary metadata is left unset yield empty defaults. Checks that
        the outgoing request carries the experiment ID, the plugin filter, and
        the sorted tag names.
        """
        res = data_provider_pb2.ListTensorsResponse()
        run1 = res.runs.add(run_name="val")
        tag11 = run1.tags.add(tag_name="weights")
        tag11.metadata.max_step = 7
        tag11.metadata.max_wall_time = 7.77
        tag11.metadata.summary_metadata.plugin_data.content = b"magic"
        tag11.metadata.summary_metadata.summary_description = "hey"
        # Use a non-empty display name so the test verifies the field is
        # actually propagated, not merely that it matches the "" default.
        tag11.metadata.summary_metadata.display_name = "Weights"
        # `tag12` and `tag21` deliberately leave summary metadata unset.
        tag12 = run1.tags.add(tag_name="other")
        tag12.metadata.max_step = 8
        tag12.metadata.max_wall_time = 8.88
        run2 = res.runs.add(run_name="test")
        tag21 = run2.tags.add(tag_name="weights")
        tag21.metadata.max_step = 9
        tag21.metadata.max_wall_time = 9.99
        self.stub.ListTensors.return_value = res

        actual = self.provider.list_tensors(
            self.ctx,
            experiment_id="123",
            plugin_name="histograms",
            run_tag_filter=provider.RunTagFilter(tags=["weights", "other"]),
        )
        expected = {
            "val": {
                "weights": provider.TensorTimeSeries(
                    max_step=7,
                    max_wall_time=7.77,
                    plugin_content=b"magic",
                    description="hey",
                    display_name="Weights",
                ),
                "other": provider.TensorTimeSeries(
                    max_step=8,
                    max_wall_time=8.88,
                    plugin_content=b"",
                    description="",
                    display_name="",
                ),
            },
            "test": {
                "weights": provider.TensorTimeSeries(
                    max_step=9,
                    max_wall_time=9.99,
                    plugin_content=b"",
                    description="",
                    display_name="",
                ),
            },
        }
        self.assertEqual(actual, expected)

        req = data_provider_pb2.ListTensorsRequest()
        req.experiment_id = "123"
        req.plugin_filter.plugin_name = "histograms"
        # The filter was passed as ["weights", "other"]. The request is
        # expected to carry the tag names in sorted order.
        req.run_tag_filter.tags.names.extend(["other", "weights"])
        self.stub.ListTensors.assert_called_once_with(req)
Ejemplo n.º 3
0
 def test_hash(self):
     """Equal instances share a hash; unequal ones should (ideally) differ."""
     shared_args = (77, 1234.5, b"\x12", "one", "two")
     original = provider.TensorTimeSeries(*shared_args)
     duplicate = provider.TensorTimeSeries(*shared_args)
     distinct = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
     original_hash = hash(original)
     self.assertEqual(original_hash, hash(duplicate))
     # The `__hash__` contract does not require unequal objects to hash
     # differently. A collision for these clearly distinct values would
     # still be suspicious enough to investigate.
     self.assertNotEqual(original_hash, hash(distinct))
Ejemplo n.º 4
0
 def list_tensors(self, ctx, *, experiment_id, plugin_name,
                  run_tag_filter=None):
     """Lists tensor time series for an experiment via the `ListTensors` RPC.

     Args:
       ctx: Request context. It is accepted for interface compatibility but
         is not read by this implementation.
       experiment_id: Experiment ID copied into the request.
       plugin_name: Plugin name sent as the request's plugin filter.
       run_tag_filter: Optional run/tag filter, copied into the request's
         `run_tag_filter` by `_populate_rtf`.

     Returns:
       A dict mapping each run name in the response to a dict mapping tag
       name to a `provider.TensorTimeSeries`. Each series is built from that
       tag's `metadata` (max step, max wall time, and the summary metadata's
       plugin content, description, and display name).

     Errors raised by the stub call go through `_translate_grpc_error()`.
     """
     with timing.log_latency("build request"):
         req = data_provider_pb2.ListTensorsRequest()
         req.experiment_id = experiment_id
         req.plugin_filter.plugin_name = plugin_name
         _populate_rtf(run_tag_filter, req.run_tag_filter)
     with timing.log_latency("_stub.ListTensors"):
         with _translate_grpc_error():
             res = self._stub.ListTensors(req)
     with timing.log_latency("build result"):

         def to_series(metadata):
             # Convert one tag's metadata proto into a TensorTimeSeries.
             summary = metadata.summary_metadata
             return provider.TensorTimeSeries(
                 max_step=metadata.max_step,
                 max_wall_time=metadata.max_wall_time,
                 plugin_content=summary.plugin_data.content,
                 description=summary.summary_description,
                 display_name=summary.display_name,
             )

         return {
             run_entry.run_name: {
                 tag_entry.tag_name: to_series(tag_entry.metadata)
                 for tag_entry in run_entry.tags
             }
             for run_entry in res.runs
         }
Ejemplo n.º 5
0
 def _tensor_time_series(self, max_step, max_wall_time, plugin_content,
                         description, display_name):
     """Builds a `provider.TensorTimeSeries`, passing every field by keyword.

     Keyword passing keeps callers independent of the constructor's
     positional field order.
     """
     fields = dict(
         max_step=max_step,
         max_wall_time=max_wall_time,
         plugin_content=plugin_content,
         description=description,
         display_name=display_name,
     )
     return provider.TensorTimeSeries(**fields)
Ejemplo n.º 6
0
 def test_repr(self):
     """`repr()` of a series contains the `repr()` of every field value."""
     series = provider.TensorTimeSeries(
         max_step=77,
         max_wall_time=1234.5,
         plugin_content=b"AB\xCD\xEF!\x00",
         description="test test",
         display_name="one two",
     )
     rendered = repr(series)
     # Check the fields in declaration order; stop at the first missing one.
     for field_name in ("max_step", "max_wall_time", "plugin_content",
                        "description", "display_name"):
         self.assertIn(repr(getattr(series, field_name)), rendered)