def test_eq(self):
    """`BlobReference`s are equal iff all their fields are equal."""
    ref_a = provider.BlobReference(url="foo", blob_key="baz")
    ref_b = provider.BlobReference(url="foo", blob_key="baz")
    ref_c = provider.BlobReference(url="foo", blob_key="qux")
    self.assertEqual(ref_a, ref_b)
    self.assertNotEqual(ref_a, ref_c)
    # Comparing against an unrelated type must not raise, just be unequal.
    self.assertNotEqual(ref_a, object())
def test_hash(self):
    """Equal `BlobReference`s must hash equal; unequal ones should differ."""
    ref_a = provider.BlobReference(url="foo", blob_key="baz")
    ref_b = provider.BlobReference(url="foo", blob_key="baz")
    ref_c = provider.BlobReference(url="foo", blob_key="qux")
    self.assertEqual(hash(ref_a), hash(ref_b))
    # Hash inequality is not promised by the `__hash__` contract, but a
    # collision on these inputs would be surprising enough to warrant
    # some scrutiny if this assertion ever fails.
    self.assertNotEqual(hash(ref_a), hash(ref_c))
def test_read_blob_sequences(self):
    """Blob-sequence responses are converted into `BlobSequenceDatum`s."""
    response = data_provider_pb2.ReadBlobSequencesResponse()
    run_proto = response.runs.add(run_name="test")
    tag_proto = run_proto.tags.add(tag_name="input_image")
    tag_proto.data.step.extend([0, 1])
    tag_proto.data.wall_time.extend([1234.0, 1235.0])
    first_sequence = tag_proto.data.values.add()
    first_sequence.blob_refs.add(blob_key="step0img0")
    first_sequence.blob_refs.add(blob_key="step0img1")
    second_sequence = tag_proto.data.values.add()
    second_sequence.blob_refs.add(blob_key="step1img0")
    self.stub.ReadBlobSequences.return_value = response

    actual = self.provider.read_blob_sequences(
        self.ctx,
        experiment_id="123",
        plugin_name="images",
        run_tag_filter=provider.RunTagFilter(runs=["test", "nope"]),
        downsample=4,
    )

    expected = {
        "test": {
            "input_image": [
                provider.BlobSequenceDatum(
                    step=0,
                    wall_time=1234.0,
                    values=(
                        provider.BlobReference(blob_key="step0img0"),
                        provider.BlobReference(blob_key="step0img1"),
                    ),
                ),
                provider.BlobSequenceDatum(
                    step=1,
                    wall_time=1235.0,
                    values=(provider.BlobReference(blob_key="step1img0"),),
                ),
            ],
        },
    }
    self.assertEqual(actual, expected)

    expected_request = data_provider_pb2.ReadBlobSequencesRequest()
    expected_request.experiment_id = "123"
    expected_request.plugin_filter.plugin_name = "images"
    expected_request.run_tag_filter.runs.names.extend(["nope", "test"])  # sorted
    expected_request.downsample.num_points = 4
    self.stub.ReadBlobSequences.assert_called_once_with(expected_request)
def read_blob_sequences(
    self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
    """Reads debugger blob references for the requested runs and tags.

    Args:
      experiment_id: Unused.
      plugin_name: Must equal `PLUGIN_NAME`; any other value is rejected.
      downsample: Unused.
      run_tag_filter: A `provider.RunTagFilter`. Unlike most data
        providers, both `runs` and `tags` are required here.

    Returns:
      A nested dict mapping run name -> tag name -> a single-element list
      containing a `provider.BlobReference` whose key is "<tag>.<run>".
      Runs not known to the multiplexer are silently omitted, as are tags
      that do not belong to the debugger plugin.

    Raises:
      ValueError: If `plugin_name` is unsupported, or if `run_tag_filter`
        (or its `runs` or `tags` field) is not specified.
    """
    del experiment_id, downsample  # Unused.
    if plugin_name != PLUGIN_NAME:
        raise ValueError("Unsupported plugin_name: %s" % plugin_name)
    # Bug fix: with the default `run_tag_filter=None`, the attribute
    # accesses below used to raise an opaque `AttributeError` instead of
    # the intended `ValueError`.
    if run_tag_filter is None or run_tag_filter.runs is None:
        raise ValueError(
            "run_tag_filter.runs is expected to be specified, but is not."
        )
    if run_tag_filter.tags is None:
        raise ValueError(
            "run_tag_filter.tags is expected to be specified, but is not."
        )
    output = {}
    existing_runs = self._multiplexer.Runs()
    for run in run_tag_filter.runs:
        if run not in existing_runs:
            continue
        output[run] = {}
        for tag in run_tag_filter.tags:
            # Only tags owned by the debugger plugin yield blob references.
            if (
                tag.startswith(EXECUTION_DIGESTS_BLOB_TAG_PREFIX)
                or tag.startswith(SOURCE_FILE_BLOB_TAG_PREFIX)
                or tag == SOURCE_FILE_LIST_BLOB_TAG
            ):
                output[run][tag] = [
                    provider.BlobReference(blob_key="%s.%s" % (tag, run))
                ]
    return output
def _convert_blob_sequence_event(experiment_id, plugin_name, run, tag, event):
    """Helper for `read_blob_sequences`.

    Builds a `provider.BlobSequenceDatum` with one encoded blob key per
    element of the event's tensor.
    """
    blob_count = _tensor_size(event.tensor_proto)
    references = []
    for index in range(blob_count):
        encoded_key = _encode_blob_key(
            experiment_id,
            plugin_name,
            run,
            tag,
            event.step,
            index,
        )
        references.append(provider.BlobReference(encoded_key))
    return provider.BlobSequenceDatum(
        wall_time=event.wall_time,
        step=event.step,
        values=tuple(references),
    )
def read_blob_sequences(self, experiment_id, plugin_name, downsample=None,
                        run_tag_filter=None):
    """Reads graph blobs for all runs matching `run_tag_filter`.

    Args:
      experiment_id: Experiment ID to validate against this provider.
      plugin_name: Only `graphs_metadata.PLUGIN_NAME` is supported; any
        other plugin yields an empty result (with a warning).
      downsample: Unused.
      run_tag_filter: Optional `provider.RunTagFilter`; graphs are stored
        per-run, so the tag is always `None`.

    Returns:
      A nested dict mapping run name -> `None` (tag) -> a list with one
      `provider.BlobSequenceDatum` referencing the run's graph blob.
    """
    self._validate_experiment_id(experiment_id)
    # TODO(davidsoergel, wchargin): consider images, etc.
    # Note this plugin_name can really just be 'graphs' for now; the
    # v2 cases are not handled yet.
    if plugin_name != graphs_metadata.PLUGIN_NAME:
        # Bug fix: `warning`, not the deprecated `warn` alias.
        logger.warning("Directory has no blob data for plugin %r", plugin_name)
        return {}
    result = collections.defaultdict(
        lambda: collections.defaultdict(list))
    for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
        tag = None  # Graphs are stored at the run level, not under a tag.
        if not self._test_run_tag(run_tag_filter, run, tag):
            continue
        if not run_info[plugin_event_accumulator.GRAPH]:
            continue
        time_series = result[run][tag]
        wall_time = 0.0  # dummy value for graph
        step = 0  # dummy value for graph
        index = 0  # dummy value for graph
        # In some situations these blobs may have directly accessible URLs.
        # But, for now, we assume they don't.
        graph_url = None
        graph_blob_key = _encode_blob_key(
            experiment_id, plugin_name, run, tag, step, index)
        blob_ref = provider.BlobReference(graph_blob_key, graph_url)
        datum = provider.BlobSequenceDatum(
            wall_time=wall_time,
            step=step,
            values=(blob_ref,),
        )
        time_series.append(datum)
    return result
def _convert_blob_references(prefix, references):
    """Encode all blob keys in a list of blob references.

    Args:
      prefix: The prefix of the sub-provider that generated the sub-key, or
        `None` if this was generated by the unprefixed provider.
      references: A list of `provider.BlobReference`s emitted by a
        sub-provider.

    Returns:
      A new list of `provider.BlobReference`s whose blob keys have been
      encoded per `_encode_blob_key`.
    """
    encoded = []
    for reference in references:
        new_key = _encode_blob_key(prefix, reference.blob_key)
        encoded.append(
            provider.BlobReference(blob_key=new_key, url=reference.url)
        )
    return encoded
def read_blob_sequences(
    self,
    ctx,
    experiment_id,
    plugin_name,
    downsample=None,
    run_tag_filter=None,
):
    """Fetches blob sequences over gRPC and converts them to provider types.

    Builds a `ReadBlobSequencesRequest`, issues it via the stub (translating
    gRPC errors), and reshapes the response into a nested dict mapping
    run name -> tag name -> list of `provider.BlobSequenceDatum`.
    """
    with timing.log_latency("build request"):
        request = data_provider_pb2.ReadBlobSequencesRequest()
        request.experiment_id = experiment_id
        request.plugin_filter.plugin_name = plugin_name
        _populate_rtf(run_tag_filter, request.run_tag_filter)
        request.downsample.num_points = downsample
    with timing.log_latency("_stub.ReadBlobSequences"):
        with _translate_grpc_error():
            response = self._stub.ReadBlobSequences(request)
    with timing.log_latency("build result"):
        result = {}
        for run_entry in response.runs:
            tag_map = result.setdefault(run_entry.run_name, {})
            for tag_entry in run_entry.tags:
                data = tag_entry.data
                series = []
                for step, wall_time, blob_sequence in zip(
                    data.step, data.wall_time, data.values
                ):
                    # An empty proto URL field means "no URL" -> `None`.
                    refs = tuple(
                        provider.BlobReference(
                            blob_key=ref.blob_key, url=ref.url or None
                        )
                        for ref in blob_sequence.blob_refs
                    )
                    series.append(
                        provider.BlobSequenceDatum(
                            step=step, wall_time=wall_time, values=refs
                        )
                    )
                tag_map[tag_entry.tag_name] = series
    return result
def test_repr(self):
    """The repr of a `BlobReference` shows the reprs of both its fields."""
    reference = provider.BlobReference(url="foo", blob_key="baz")
    rendered = repr(reference)
    self.assertIn(repr(reference.url), rendered)
    self.assertIn(repr(reference.blob_key), rendered)
def _make_blob_reference(self, text):
    """Wraps `text` into a `BlobReference` keyed under this provider's name.

    The blob key is the URL-safe base64 encoding of "<name>:<text>".
    """
    raw_key = "%s:%s" % (self._name, text)
    blob_key = (
        base64.urlsafe_b64encode(raw_key.encode("utf-8")).decode("ascii")
    )
    return provider.BlobReference(blob_key)