예제 #1
0
 def test_eq(self):
   """Instances built from identical arguments compare equal; others don't."""
   args = (77, 1234.5, 6, b"\x12", "one", "two")
   first = provider.BlobSequenceTimeSeries(*args)
   twin = provider.BlobSequenceTimeSeries(*args)
   different = provider.BlobSequenceTimeSeries(
       66, 4321.0, 7, b"\x7F", "hmm", "hum"
   )
   self.assertEqual(first, twin)
   self.assertNotEqual(first, different)
   # Comparison against an unrelated type must not report equality.
   self.assertNotEqual(first, object())
예제 #2
0
 def test_hash(self):
   """Equal instances hash equal; distinct ones should hash differently."""
   args = (77, 1234.5, 6, b"\x12", "one", "two")
   first = provider.BlobSequenceTimeSeries(*args)
   twin = provider.BlobSequenceTimeSeries(*args)
   different = provider.BlobSequenceTimeSeries(
       66, 4321.0, 7, b"\x7F", "hmm", "hum"
   )
   self.assertEqual(hash(first), hash(twin))
   # Differing hashes for unequal objects are not mandated by the
   # `__hash__` contract, but a collision here would be suspicious enough
   # to deserve a closer look.
   self.assertNotEqual(hash(first), hash(different))
예제 #3
0
 def list_blob_sequences(
     self, ctx, experiment_id, plugin_name, run_tag_filter=None
 ):
     """Lists blob sequence time series by calling `ListBlobSequences`.

     Args:
       ctx: Request context; not read by this method.
       experiment_id: ID of the experiment to query.
       plugin_name: Only time series owned by this plugin are requested.
       run_tag_filter: Optional run/tag filter, copied into the request by
         `_populate_rtf`.

     Returns:
       A dict mapping each run name in the RPC response to a dict mapping
       tag name to a `provider.BlobSequenceTimeSeries`.

     Raises:
       Whatever `_translate_grpc_error` converts RPC failures into.
     """

     def to_time_series(metadata):
         # Flattens one response metadata message into the provider type.
         summary = metadata.summary_metadata
         return provider.BlobSequenceTimeSeries(
             max_step=metadata.max_step,
             max_wall_time=metadata.max_wall_time,
             max_length=metadata.max_length,
             plugin_content=summary.plugin_data.content,
             description=summary.summary_description,
             display_name=summary.display_name,
         )

     with timing.log_latency("build request"):
         request = data_provider_pb2.ListBlobSequencesRequest()
         request.experiment_id = experiment_id
         request.plugin_filter.plugin_name = plugin_name
         _populate_rtf(run_tag_filter, request.run_tag_filter)
     with timing.log_latency("_stub.ListBlobSequences"):
         with _translate_grpc_error():
             response = self._stub.ListBlobSequences(request)
     with timing.log_latency("build result"):
         return {
             run_entry.run_name: {
                 tag_entry.tag_name: to_time_series(tag_entry.metadata)
                 for tag_entry in run_entry.tags
             }
             for run_entry in response.runs
         }
예제 #4
0
    def list_blob_sequences(self,
                            experiment_id,
                            plugin_name,
                            run_tag_filter=None):
        """Lists graph blob sequences available in the multiplexer.

        Only the graphs plugin is supported: any other `plugin_name` logs a
        warning and yields an empty dict.

        Args:
          experiment_id: Experiment ID, checked by
            `self._validate_experiment_id`.
          plugin_name: Plugin whose blob data is requested.
          run_tag_filter: Optional `provider.RunTagFilter`; `None` is
            replaced by a filter with `runs=None, tags=None`.

        Returns:
          A `collections.defaultdict(dict)` mapping each matching run that
          has graph data to a dict whose single key is `None` and whose
          value is a placeholder `provider.BlobSequenceTimeSeries` with
          zeroed step/wall-time and no content, description, or display
          name.
        """
        self._validate_experiment_id(experiment_id)
        if run_tag_filter is None:
            run_tag_filter = provider.RunTagFilter(runs=None, tags=None)

        # TODO(davidsoergel, wchargin): consider images, etc.
        # Note this plugin_name can really just be 'graphs' for now; the
        # v2 cases are not handled yet.
        if plugin_name != graphs_metadata.PLUGIN_NAME:
            # `Logger.warn` is a deprecated alias of `Logger.warning` and no
            # longer exists in Python 3.13+.
            logger.warning("Directory has no blob data for plugin %r",
                           plugin_name)
            return {}

        result = collections.defaultdict(dict)
        for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
            # Graph data is looked up per run rather than per tag here, so
            # each run is filtered and reported under a `None` tag.
            tag = None
            if not self._test_run_tag(run_tag_filter, run, tag):
                continue
            if not run_info[plugin_event_accumulator.GRAPH]:
                continue
            # NOTE(review): other providers in this listing construct
            # `BlobSequenceTimeSeries` with `max_length=`; confirm that the
            # `provider` version paired with this code accepts
            # `latest_max_index=`.
            result[run][tag] = provider.BlobSequenceTimeSeries(
                max_step=0,
                max_wall_time=0,
                latest_max_index=0,  # Graphs are always one blob at a time
                plugin_content=None,
                description=None,
                display_name=None,
            )
        return result
예제 #5
0
 def list_blob_sequences(
     self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
 ):
     """Lists blob sequence time series summarized from multiplexer tensors.

     Args:
       ctx: Request context, checked by `self._validate_context`.
       experiment_id: Experiment ID, checked by
         `self._validate_experiment_id`.
       plugin_name: Plugin whose time series are listed.
       run_tag_filter: Optional filter forwarded to `self._index`.

     Returns:
       A dict mapping every run in the index to a dict mapping tag name to
       a `provider.BlobSequenceTimeSeries`. The max step, wall time, and
       length are `None` for a series that has no tensor events.
     """
     self._validate_context(ctx)
     self._validate_experiment_id(experiment_id)

     def larger(best, candidate):
         # Running maximum where `None` means "nothing seen yet"; on a tie
         # the earlier value is kept.
         return candidate if best is None or best < candidate else best

     index = self._index(
         plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
     )
     listing = {}
     for run, tag_to_metadata in index.items():
         per_tag = {}
         listing[run] = per_tag
         for tag, metadata in tag_to_metadata.items():
             max_step = max_wall_time = max_length = None
             for event in self._multiplexer.Tensors(run, tag):
                 max_step = larger(max_step, event.step)
                 max_wall_time = larger(max_wall_time, event.wall_time)
                 max_length = larger(
                     max_length, _tensor_size(event.tensor_proto)
                 )
             per_tag[tag] = provider.BlobSequenceTimeSeries(
                 max_step=max_step,
                 max_wall_time=max_wall_time,
                 max_length=max_length,
                 plugin_content=metadata.plugin_data.content,
                 description=metadata.summary_description,
                 display_name=metadata.display_name,
             )
     return listing
예제 #6
0
 def test_repr(self):
   """The repr of a time series embeds the repr of every field."""
   series = provider.BlobSequenceTimeSeries(
       max_step=77,
       max_wall_time=1234.5,
       latest_max_index=6,
       plugin_content=b"AB\xCD\xEF!\x00",
       description="test test",
       display_name="one two",
   )
   rendered = repr(series)
   for field in (
       "max_step",
       "max_wall_time",
       "latest_max_index",
       "plugin_content",
       "description",
       "display_name",
   ):
     self.assertIn(repr(getattr(series, field)), rendered)
예제 #7
0
 def list_blob_sequences(
     self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
 ):
     """Returns a canned listing with one run and one tag.

     The run is named `"<experiment_id>/test"` and the tag is named
     `"input.<plugin_name>"`. `ctx` and `run_tag_filter` are ignored.
     """
     self._validate_eid(experiment_id)
     only_run = "%s/test" % experiment_id
     only_tag = "input.%s" % plugin_name
     series = provider.BlobSequenceTimeSeries(
         max_step=0,
         max_wall_time=0.0,
         max_length=2,
         plugin_content=b"",
         description="Greetings via %s" % self._name,
         display_name="input",
     )
     return {only_run: {only_tag: series}}
예제 #8
0
 def _blob_sequence_time_series(
     self,
     max_step,
     max_wall_time,
     max_length,
     plugin_content,
     description,
     display_name,
 ):
     """Builds a `provider.BlobSequenceTimeSeries`, passing every field by name.

     Forwarding the fields as keyword arguments keeps callers independent of
     the constructor's positional parameter order.
     """
     fields = {
         "max_step": max_step,
         "max_wall_time": max_wall_time,
         "max_length": max_length,
         "plugin_content": plugin_content,
         "description": description,
         "display_name": display_name,
     }
     return provider.BlobSequenceTimeSeries(**fields)
예제 #9
0
    def list_blob_sequences(
        self, experiment_id, plugin_name, run_tag_filter=None
    ):
        """Lists blob sequence time series for a plugin from the multiplexer.

        Only tags whose summary metadata has data class
        `DATA_CLASS_BLOB_SEQUENCE` and that pass `run_tag_filter` are
        included. A run appears in the result only if at least one of its
        tags is included.

        Args:
          experiment_id: Experiment ID, checked by
            `self._validate_experiment_id`.
          plugin_name: Plugin whose run/tag content is listed.
          run_tag_filter: Optional `provider.RunTagFilter`; `None` is
            replaced by a filter with `runs=None, tags=None`.

        Returns:
          A dict mapping run name to a dict mapping tag name to a
          `provider.BlobSequenceTimeSeries`. The max step, wall time, and
          length are `None` for a series that has no tensor events.
        """
        self._validate_experiment_id(experiment_id)
        if run_tag_filter is None:
            run_tag_filter = provider.RunTagFilter(runs=None, tags=None)

        def larger(best, candidate):
            # Running maximum where `None` means "nothing seen yet"; on a
            # tie the earlier value is kept.
            return candidate if best is None or best < candidate else best

        listing = {}
        run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
        for run, tag_to_content in six.iteritems(run_tag_content):
            matched = {}
            for tag in tag_to_content:
                if not self._test_run_tag(run_tag_filter, run, tag):
                    continue
                metadata = self._multiplexer.SummaryMetadata(run, tag)
                if metadata.data_class != summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
                    continue
                max_step = max_wall_time = max_length = None
                for event in self._multiplexer.Tensors(run, tag):
                    max_step = larger(max_step, event.step)
                    max_wall_time = larger(max_wall_time, event.wall_time)
                    max_length = larger(
                        max_length, _tensor_size(event.tensor_proto)
                    )
                matched[tag] = provider.BlobSequenceTimeSeries(
                    max_step=max_step,
                    max_wall_time=max_wall_time,
                    max_length=max_length,
                    plugin_content=metadata.plugin_data.content,
                    description=metadata.summary_description,
                    display_name=metadata.display_name,
                )
            # Runs with no qualifying tags are omitted entirely.
            if matched:
                listing[run] = matched
        return listing
예제 #10
0
    def test_list_blob_sequences(self):
        """Response protos are converted and the request is built as expected."""
        response = data_provider_pb2.ListBlobSequencesResponse()
        train_run = response.runs.add(run_name="train")
        image_tag = train_run.tags.add(tag_name="input_image")
        meta = image_tag.metadata
        meta.max_step = 7
        meta.max_wall_time = 7.77
        meta.max_length = 3
        summary = meta.summary_metadata
        summary.plugin_data.content = b"PNG"
        summary.display_name = "Input image"
        summary.summary_description = "img"
        self.stub.ListBlobSequences.return_value = response

        actual = self.provider.list_blob_sequences(
            self.ctx,
            experiment_id="123",
            plugin_name="images",
            run_tag_filter=provider.RunTagFilter(runs=["val", "train"]),
        )
        expected_series = provider.BlobSequenceTimeSeries(
            max_step=7,
            max_wall_time=7.77,
            max_length=3,
            plugin_content=b"PNG",
            description="img",
            display_name="Input image",
        )
        self.assertEqual(actual, {"train": {"input_image": expected_series}})

        expected_request = data_provider_pb2.ListBlobSequencesRequest()
        expected_request.experiment_id = "123"
        expected_request.plugin_filter.plugin_name = "images"
        # Run names are expected in sorted order, not the filter's order.
        expected_request.run_tag_filter.runs.names.extend(["train", "val"])
        self.stub.ListBlobSequences.assert_called_once_with(expected_request)