def test_eq(self):
    """Equality of `BlobSequenceTimeSeries` is structural, not identity."""
    same_a = provider.BlobSequenceTimeSeries(
        77, 1234.5, 6, b"\x12", "one", "two"
    )
    same_b = provider.BlobSequenceTimeSeries(
        77, 1234.5, 6, b"\x12", "one", "two"
    )
    different = provider.BlobSequenceTimeSeries(
        66, 4321.0, 7, b"\x7F", "hmm", "hum"
    )
    self.assertEqual(same_a, same_b)
    self.assertNotEqual(same_a, different)
    # Comparing against an unrelated type must be unequal, not an error.
    self.assertNotEqual(same_a, object())
def test_hash(self):
    """Equal `BlobSequenceTimeSeries` values hash equally."""
    same_a = provider.BlobSequenceTimeSeries(
        77, 1234.5, 6, b"\x12", "one", "two"
    )
    same_b = provider.BlobSequenceTimeSeries(
        77, 1234.5, 6, b"\x12", "one", "two"
    )
    different = provider.BlobSequenceTimeSeries(
        66, 4321.0, 7, b"\x7F", "hmm", "hum"
    )
    self.assertEqual(hash(same_a), hash(same_b))
    # The next check is technically not required by the `__hash__`
    # contract, but _should_ pass; failure on this assertion would at
    # least warrant some scrutiny.
    self.assertNotEqual(hash(same_a), hash(different))
def list_blob_sequences(
    self, ctx, experiment_id, plugin_name, run_tag_filter=None
):
    """Lists blob-sequence time series via the remote data-provider stub.

    Builds a `ListBlobSequencesRequest`, issues it over gRPC (translating
    gRPC errors), and converts the response into a nested
    `{run_name: {tag_name: BlobSequenceTimeSeries}}` dict.
    """
    with timing.log_latency("build request"):
        req = data_provider_pb2.ListBlobSequencesRequest()
        req.experiment_id = experiment_id
        req.plugin_filter.plugin_name = plugin_name
        _populate_rtf(run_tag_filter, req.run_tag_filter)
    with timing.log_latency("_stub.ListBlobSequences"):
        with _translate_grpc_error():
            res = self._stub.ListBlobSequences(req)
    with timing.log_latency("build result"):

        def _series(ts):
            # Convert one proto time-series entry into the provider type.
            meta = ts.summary_metadata
            return provider.BlobSequenceTimeSeries(
                max_step=ts.max_step,
                max_wall_time=ts.max_wall_time,
                max_length=ts.max_length,
                plugin_content=meta.plugin_data.content,
                description=meta.summary_description,
                display_name=meta.display_name,
            )

        return {
            run_entry.run_name: {
                tag_entry.tag_name: _series(tag_entry.metadata)
                for tag_entry in run_entry.tags
            }
            for run_entry in res.runs
        }
def list_blob_sequences(self, experiment_id, plugin_name, run_tag_filter=None):
    """Lists graph blobs as blob-sequence time series.

    Only the graphs plugin is supported; any other `plugin_name` yields an
    empty result. Graph data for a run is recorded under the `None` tag.

    Args:
      experiment_id: Experiment ID to validate (single-experiment backend).
      plugin_name: Name of the requesting plugin; must equal
        `graphs_metadata.PLUGIN_NAME` to return any data.
      run_tag_filter: Optional `provider.RunTagFilter`; `None` means no
        filtering.

    Returns:
      A dict mapping run name to `{None: BlobSequenceTimeSeries}` for runs
      that have graph data and pass the filter.
    """
    self._validate_experiment_id(experiment_id)
    if run_tag_filter is None:
        run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
    # TODO(davidsoergel, wchargin): consider images, etc.
    # Note this plugin_name can really just be 'graphs' for now; the
    # v2 cases are not handled yet.
    if plugin_name != graphs_metadata.PLUGIN_NAME:
        # Fix: `logger.warn` is a deprecated alias of `logger.warning`.
        logger.warning("Directory has no blob data for plugin %r", plugin_name)
        return {}
    # `defaultdict(dict)` is the idiomatic form of `defaultdict(lambda: {})`.
    result = collections.defaultdict(dict)
    for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
        # Graph data lives under the `None` tag for each run.
        tag = None
        if not self._test_run_tag(run_tag_filter, run, tag):
            continue
        if not run_info[plugin_event_accumulator.GRAPH]:
            continue
        # NOTE(review): sibling providers construct this type with a
        # `max_length` kwarg rather than `latest_max_index` — confirm which
        # name this snapshot of `provider.BlobSequenceTimeSeries` accepts.
        result[run][tag] = provider.BlobSequenceTimeSeries(
            max_step=0,
            max_wall_time=0,
            latest_max_index=0,  # Graphs are always one blob at a time
            plugin_content=None,
            description=None,
            display_name=None,
        )
    return result
def list_blob_sequences(
    self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
    """Lists blob-sequence time series from the local multiplexer.

    For each (run, tag) in the blob-sequence index, scans all tensor events
    to compute the maximum step, wall time, and tensor length, and packages
    the result as `{run: {tag: BlobSequenceTimeSeries}}`.
    """
    self._validate_context(ctx)
    self._validate_experiment_id(experiment_id)
    index = self._index(
        plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
    )
    result = {}
    for (run, tag_to_metadata) in index.items():
        run_result = {}
        result[run] = run_result
        for (tag, metadata) in tag_to_metadata.items():
            events = list(self._multiplexer.Tensors(run, tag))
            # `default=None` reproduces the original behavior when a tag
            # has no events: all maxima stay `None`.
            run_result[tag] = provider.BlobSequenceTimeSeries(
                max_step=max((e.step for e in events), default=None),
                max_wall_time=max(
                    (e.wall_time for e in events), default=None
                ),
                max_length=max(
                    (_tensor_size(e.tensor_proto) for e in events),
                    default=None,
                ),
                plugin_content=metadata.plugin_data.content,
                description=metadata.summary_description,
                display_name=metadata.display_name,
            )
    return result
def test_repr(self):
    """`repr()` of a `BlobSequenceTimeSeries` includes every field value."""
    series = provider.BlobSequenceTimeSeries(
        max_step=77,
        max_wall_time=1234.5,
        latest_max_index=6,
        plugin_content=b"AB\xCD\xEF!\x00",
        description="test test",
        display_name="one two",
    )
    actual = repr(series)
    field_values = (
        series.max_step,
        series.max_wall_time,
        series.latest_max_index,
        series.plugin_content,
        series.description,
        series.display_name,
    )
    for value in field_values:
        self.assertIn(repr(value), actual)
def list_blob_sequences(
    self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
):
    """Returns a canned single-run, single-tag blob-sequence listing."""
    self._validate_eid(experiment_id)
    series = provider.BlobSequenceTimeSeries(
        max_step=0,
        max_wall_time=0.0,
        max_length=2,
        plugin_content=b"",
        description="Greetings via %s" % self._name,
        display_name="input",
    )
    run = "%s/test" % experiment_id
    tag = "input.%s" % plugin_name
    return {run: {tag: series}}
def _blob_sequence_time_series(
    self,
    max_step,
    max_wall_time,
    max_length,
    plugin_content,
    description,
    display_name,
):
    """Builds a `provider.BlobSequenceTimeSeries` using explicit kwargs."""
    kwargs = dict(
        max_step=max_step,
        max_wall_time=max_wall_time,
        max_length=max_length,
        plugin_content=plugin_content,
        description=description,
        display_name=display_name,
    )
    return provider.BlobSequenceTimeSeries(**kwargs)
def list_blob_sequences(
    self, experiment_id, plugin_name, run_tag_filter=None
):
    """Lists blob-sequence time series for a plugin from the multiplexer.

    For each run/tag registered for `plugin_name` whose summary metadata is
    classed as blob-sequence data, computes the maximum step, wall time,
    and tensor length across all events for that tag.

    Args:
      experiment_id: Experiment ID to validate.
      plugin_name: Plugin whose run/tag content to enumerate.
      run_tag_filter: Optional `provider.RunTagFilter`; `None` means no
        filtering.

    Returns:
      A dict `{run: {tag: BlobSequenceTimeSeries}}` containing only runs
      with at least one qualifying tag.
    """
    self._validate_experiment_id(experiment_id)
    if run_tag_filter is None:
        run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
    result = {}
    run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
    for (run, tag_to_content) in six.iteritems(run_tag_content):
        result_for_run = {}
        for tag in tag_to_content:
            if not self._test_run_tag(run_tag_filter, run, tag):
                continue
            summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
            if (
                summary_metadata.data_class
                != summary_pb2.DATA_CLASS_BLOB_SEQUENCE
            ):
                continue
            max_step = None
            max_wall_time = None
            max_length = None
            for event in self._multiplexer.Tensors(run, tag):
                if max_step is None or max_step < event.step:
                    max_step = event.step
                if max_wall_time is None or max_wall_time < event.wall_time:
                    max_wall_time = event.wall_time
                length = _tensor_size(event.tensor_proto)
                if max_length is None or length > max_length:
                    max_length = length
            result_for_run[tag] = provider.BlobSequenceTimeSeries(
                max_step=max_step,
                max_wall_time=max_wall_time,
                max_length=max_length,
                plugin_content=summary_metadata.plugin_data.content,
                description=summary_metadata.summary_description,
                display_name=summary_metadata.display_name,
            )
        # Fix: the original re-assigned `result[run] = result_for_run` on
        # every qualifying tag inside the inner loop; assign once after the
        # loop. A run appears iff it has at least one qualifying tag, as
        # before.
        if result_for_run:
            result[run] = result_for_run
    return result
def test_list_blob_sequences(self):
    """The provider converts the gRPC response and sorts the run filter."""
    res = data_provider_pb2.ListBlobSequencesResponse()
    train_run = res.runs.add(run_name="train")
    image_tag = train_run.tags.add(tag_name="input_image")
    meta = image_tag.metadata
    meta.max_step = 7
    meta.max_wall_time = 7.77
    meta.max_length = 3
    meta.summary_metadata.plugin_data.content = b"PNG"
    meta.summary_metadata.display_name = "Input image"
    meta.summary_metadata.summary_description = "img"
    self.stub.ListBlobSequences.return_value = res

    actual = self.provider.list_blob_sequences(
        self.ctx,
        experiment_id="123",
        plugin_name="images",
        run_tag_filter=provider.RunTagFilter(runs=["val", "train"]),
    )

    self.assertEqual(
        actual,
        {
            "train": {
                "input_image": provider.BlobSequenceTimeSeries(
                    max_step=7,
                    max_wall_time=7.77,
                    max_length=3,
                    plugin_content=b"PNG",
                    description="img",
                    display_name="Input image",
                ),
            },
        },
    )
    expected_req = data_provider_pb2.ListBlobSequencesRequest()
    expected_req.experiment_id = "123"
    expected_req.plugin_filter.plugin_name = "images"
    # Run names are sorted before being put on the wire.
    expected_req.run_tag_filter.runs.names.extend(["train", "val"])
    self.stub.ListBlobSequences.assert_called_once_with(expected_req)