コード例 #1
0
    def test_read_scalars(self):
        """`read_scalars` decodes the stubbed response and sends the expected request."""
        steps = [0, 1, 2, 4]
        wall_times = [1234.0, 1235.0, 1236.0, 1237.0]
        values = [0.25, 0.50, 0.75, 1.00]

        # Canned RPC response: one run ("test") with one tag ("accuracy").
        response = data_provider_pb2.ReadScalarsResponse()
        run_proto = response.runs.add(run_name="test")
        tag_proto = run_proto.tags.add(tag_name="accuracy")
        tag_proto.data.step.extend(steps)
        tag_proto.data.wall_time.extend(wall_times)
        tag_proto.data.value.extend(values)
        self.stub.ReadScalars.return_value = response

        run_tag_filter = provider.RunTagFilter(runs=["test", "nope"])
        actual = self.provider.read_scalars(
            self.ctx,
            experiment_id="123",
            plugin_name="scalars",
            run_tag_filter=run_tag_filter,
            downsample=4,
        )

        # The columnar response data should come back as one datum per point.
        expected_series = [
            provider.ScalarDatum(step=step, wall_time=wall_time, value=value)
            for (step, wall_time, value) in zip(steps, wall_times, values)
        ]
        expected = {"test": {"accuracy": expected_series}}
        self.assertEqual(actual, expected)

        # The stub must have been called exactly once with this request.
        expected_request = data_provider_pb2.ReadScalarsRequest()
        expected_request.experiment_id = "123"
        expected_request.plugin_filter.plugin_name = "scalars"
        expected_request.run_tag_filter.runs.names.extend(
            ["nope", "test"]  # sorted
        )
        expected_request.downsample.num_points = 4
        self.stub.ReadScalars.assert_called_once_with(expected_request)
コード例 #2
0
 def test_eq(self):
   """Equal fields compare equal; different fields or foreign objects do not."""
   fields = dict(step=12, wall_time=0.25, value=1.25)
   datum = provider.ScalarDatum(**fields)
   twin = provider.ScalarDatum(**fields)
   different = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5)
   self.assertEqual(datum, twin)
   # Neither a datum with other fields nor an unrelated object is equal.
   for candidate in (different, object()):
     self.assertNotEqual(datum, candidate)
コード例 #3
0
 def read_scalars(
     self,
     ctx,
     *,
     experiment_id,
     plugin_name,
     downsample=None,
     run_tag_filter=None,
 ):
     """Return a fixed two-point scalar series derived from the arguments.

     The only run is `"<experiment_id>/train"` and the only tag is
     `"loss.<plugin_name>"`. `ctx` and `downsample` are accepted but not
     used by this implementation.

     Args:
       ctx: Unused.
       experiment_id: Checked with `self._validate_eid`; also determines
         the run name and the second datum's value (its length).
       plugin_name: Determines the tag name and the first datum's value
         (its length).
       downsample: Unused.
       run_tag_filter: Optional filter with `runs` and `tags` attributes;
         a value of `None` for either attribute means "no restriction".
         Defaults to `provider.RunTagFilter()`.

     Returns:
       A `{run: {tag: [provider.ScalarDatum, ...]}}` dict, or `{}` when
       the filter excludes the run or the tag.
     """
     self._validate_eid(experiment_id)
     selector = (
         provider.RunTagFilter() if run_tag_filter is None else run_tag_filter
     )
     run_name = "%s/train" % experiment_id
     tag_name = "loss.%s" % plugin_name

     # A `None` run/tag set matches everything; otherwise require membership.
     if (selector.runs is not None and run_name not in selector.runs) or (
         selector.tags is not None and tag_name not in selector.tags
     ):
         return {}

     first_point = provider.ScalarDatum(
         step=0, wall_time=0.0, value=float(len(plugin_name))
     )
     second_point = provider.ScalarDatum(
         step=1, wall_time=0.5, value=float(len(experiment_id))
     )
     return {run_name: {tag_name: [first_point, second_point]}}
コード例 #4
0
 def test_hash(self):
   """Equal data hash equally; distinct data are expected to hash differently."""
   fields = dict(step=12, wall_time=0.25, value=1.25)
   datum = provider.ScalarDatum(**fields)
   twin = provider.ScalarDatum(**fields)
   different = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5)
   datum_hash = hash(datum)
   self.assertEqual(datum_hash, hash(twin))
   # Strictly speaking the `__hash__` contract does not demand the next
   # check, but it _should_ hold; a failure here would at least warrant
   # some scrutiny.
   self.assertNotEqual(hash(datum), hash(different))
コード例 #5
0
 def read_scalars(
     self,
     ctx,
     *,
     experiment_id,
     plugin_name,
     downsample=None,
     run_tag_filter=None,
 ):
     """Fetch scalar series through the `ReadScalars` RPC.

     Each of the three phases (building the request, calling the stub,
     converting the response) runs under its own `timing.log_latency`
     scope. `ctx` is accepted but not used by this implementation.

     Args:
       ctx: Unused.
       experiment_id: Copied into the request's `experiment_id`.
       plugin_name: Copied into the request's `plugin_filter.plugin_name`.
       downsample: Assigned to the request's `downsample.num_points`.
       run_tag_filter: Passed to `_populate_rtf` to fill the request's
         `run_tag_filter`.

     Returns:
       A dict mapping run name to a dict mapping tag name to a list of
       `provider.ScalarDatum`, one per point in the response.
     """
     with timing.log_latency("build request"):
         request = data_provider_pb2.ReadScalarsRequest()
         request.experiment_id = experiment_id
         request.plugin_filter.plugin_name = plugin_name
         _populate_rtf(run_tag_filter, request.run_tag_filter)
         request.downsample.num_points = downsample
     with timing.log_latency("_stub.ReadScalars"), _translate_grpc_error():
         response = self._stub.ReadScalars(request)
     with timing.log_latency("build result"):
         scalars_by_run = {}
         for run_entry in response.runs:
             # The response stores each series column-wise (parallel step /
             # wall_time / value sequences); zip them back into points.
             scalars_by_run[run_entry.run_name] = {
                 tag_entry.tag_name: [
                     provider.ScalarDatum(
                         step=step, wall_time=wall_time, value=value
                     )
                     for (step, wall_time, value) in zip(
                         tag_entry.data.step,
                         tag_entry.data.wall_time,
                         tag_entry.data.value,
                     )
                 ]
                 for tag_entry in run_entry.tags
             }
         return scalars_by_run
コード例 #6
0
def _convert_scalar_event(event):
    """Helper for `read_scalars`: build a `provider.ScalarDatum` from `event`."""
    step = event.step
    wall_time = event.wall_time
    # Convert the event's tensor proto and take `.item()` of the result.
    scalar = tensor_util.make_ndarray(event.tensor_proto).item()
    return provider.ScalarDatum(step=step, wall_time=wall_time, value=scalar)
コード例 #7
0
 def test_repr(self):
   """The datum's repr contains the repr of each of its fields."""
   datum = provider.ScalarDatum(step=123, wall_time=234.5, value=-0.125)
   text = repr(datum)
   for field in (datum.step, datum.wall_time, datum.value):
     self.assertIn(repr(field), text)
コード例 #8
0
 def _convert_scalar_event(self, event):
     """Build a `provider.ScalarDatum` from the step, wall time and tensor of `event`."""
     step = event.step
     wall_time = event.wall_time
     # Convert the event's tensor proto and take `.item()` of the result.
     scalar = tensor_util.make_ndarray(event.tensor_proto).item()
     return provider.ScalarDatum(step=step, wall_time=wall_time, value=scalar)