コード例 #1
0
 def list_blob_sequences(
     self, ctx, experiment_id, plugin_name, run_tag_filter=None
 ):
     """List blob-sequence time series via the `ListBlobSequences` stub RPC.

     Args:
       ctx: Request context; not used by this implementation.
       experiment_id: ID of the experiment to query; copied into the request.
       plugin_name: Plugin name placed in the request's plugin filter.
       run_tag_filter: Optional run/tag filter, copied into the request's
         `run_tag_filter` field by `_populate_rtf`.

     Returns:
       A dict mapping each run name in the response to a dict mapping tag
       name to a `provider.BlobSequenceTimeSeries`.
     """
     with timing.log_latency("build request"):
         request = data_provider_pb2.ListBlobSequencesRequest()
         request.experiment_id = experiment_id
         request.plugin_filter.plugin_name = plugin_name
         _populate_rtf(run_tag_filter, request.run_tag_filter)

     with timing.log_latency("_stub.ListBlobSequences"):
         # NOTE(review): `_translate_grpc_error` presumably maps gRPC errors
         # raised by the stub to provider-level errors — confirm.
         with _translate_grpc_error():
             response = self._stub.ListBlobSequences(request)

     with timing.log_latency("build result"):

         def to_series(series_metadata):
             """Convert one tag's proto metadata into a provider value."""
             summary = series_metadata.summary_metadata
             return provider.BlobSequenceTimeSeries(
                 max_step=series_metadata.max_step,
                 max_wall_time=series_metadata.max_wall_time,
                 max_length=series_metadata.max_length,
                 plugin_content=summary.plugin_data.content,
                 description=summary.summary_description,
                 display_name=summary.display_name,
             )

         # Later duplicate run/tag names overwrite earlier ones, as with
         # plain dict assignment.
         return {
             run_entry.run_name: {
                 tag_entry.tag_name: to_series(tag_entry.metadata)
                 for tag_entry in run_entry.tags
             }
             for run_entry in response.runs
         }
コード例 #2
0
    def test_list_blob_sequences(self):
        """Checks response conversion and the request sent to the stub."""
        # Canned server response: one run ("train") with one blob-sequence tag.
        response = data_provider_pb2.ListBlobSequencesResponse()
        train_run = response.runs.add(run_name="train")
        image_tag = train_run.tags.add(tag_name="input_image")
        metadata = image_tag.metadata
        metadata.max_step = 7
        metadata.max_wall_time = 7.77
        metadata.max_length = 3
        summary = metadata.summary_metadata
        summary.plugin_data.content = b"PNG"
        summary.display_name = "Input image"
        summary.summary_description = "img"
        self.stub.ListBlobSequences.return_value = response

        result = self.provider.list_blob_sequences(
            self.ctx,
            experiment_id="123",
            plugin_name="images",
            run_tag_filter=provider.RunTagFilter(runs=["val", "train"]),
        )

        expected_series = provider.BlobSequenceTimeSeries(
            max_step=7,
            max_wall_time=7.77,
            max_length=3,
            plugin_content=b"PNG",
            description="img",
            display_name="Input image",
        )
        self.assertEqual(result, {"train": {"input_image": expected_series}})

        # Run names were passed as ["val", "train"] but are expected in the
        # request in sorted order.
        expected_request = data_provider_pb2.ListBlobSequencesRequest()
        expected_request.experiment_id = "123"
        expected_request.plugin_filter.plugin_name = "images"
        expected_request.run_tag_filter.runs.names.extend(["train", "val"])
        self.stub.ListBlobSequences.assert_called_once_with(expected_request)