def test_eq(self):
    """`BlobSequenceDatum` equality is structural, not identity-based."""
    datum_a = provider.BlobSequenceDatum(
        step=12, wall_time=0.25, values=("foo", "bar", "baz")
    )
    datum_b = provider.BlobSequenceDatum(
        step=12, wall_time=0.25, values=("foo", "bar", "baz")
    )
    datum_c = provider.BlobSequenceDatum(
        step=23, wall_time=3.25, values=("qux",)
    )
    # Distinct objects with identical field values compare equal.
    self.assertEqual(datum_a, datum_b)
    # Any differing field makes the data unequal.
    self.assertNotEqual(datum_a, datum_c)
    # Comparing against an unrelated type yields inequality, not an error.
    self.assertNotEqual(datum_a, object())
def test_hash(self):
    """`BlobSequenceDatum` hashing is consistent with structural equality."""
    datum_a = provider.BlobSequenceDatum(
        step=12, wall_time=0.25, values=("foo", "bar", "baz")
    )
    datum_b = provider.BlobSequenceDatum(
        step=12, wall_time=0.25, values=("foo", "bar", "baz")
    )
    datum_c = provider.BlobSequenceDatum(
        step=23, wall_time=3.25, values=("qux",)
    )
    # Equal values must hash equally (the `__hash__` contract).
    self.assertEqual(hash(datum_a), hash(datum_b))
    # The next check is technically not required by the `__hash__`
    # contract, but _should_ pass; failure on this assertion would at
    # least warrant some scrutiny.
    self.assertNotEqual(hash(datum_a), hash(datum_c))
def test_read_blob_sequences(self):
    """`read_blob_sequences` converts the gRPC response and request shape."""
    # Build a canned response: one run, one tag, two steps, where the
    # first step has two blobs and the second has one.
    response = data_provider_pb2.ReadBlobSequencesResponse()
    run_proto = response.runs.add(run_name="test")
    tag_proto = run_proto.tags.add(tag_name="input_image")
    tag_proto.data.step.extend([0, 1])
    tag_proto.data.wall_time.extend([1234.0, 1235.0])
    first_sequence = tag_proto.data.values.add()
    first_sequence.blob_refs.add(blob_key="step0img0")
    first_sequence.blob_refs.add(blob_key="step0img1")
    second_sequence = tag_proto.data.values.add()
    second_sequence.blob_refs.add(blob_key="step1img0")
    self.stub.ReadBlobSequences.return_value = response

    actual = self.provider.read_blob_sequences(
        self.ctx,
        experiment_id="123",
        plugin_name="images",
        run_tag_filter=provider.RunTagFilter(runs=["test", "nope"]),
        downsample=4,
    )

    expected = {
        "test": {
            "input_image": [
                provider.BlobSequenceDatum(
                    step=0,
                    wall_time=1234.0,
                    values=(
                        provider.BlobReference(blob_key="step0img0"),
                        provider.BlobReference(blob_key="step0img1"),
                    ),
                ),
                provider.BlobSequenceDatum(
                    step=1,
                    wall_time=1235.0,
                    values=(
                        provider.BlobReference(blob_key="step1img0"),
                    ),
                ),
            ],
        },
    }
    self.assertEqual(actual, expected)

    # The outgoing request should mirror the arguments, with run names
    # sorted.
    expected_request = data_provider_pb2.ReadBlobSequencesRequest()
    expected_request.experiment_id = "123"
    expected_request.plugin_filter.plugin_name = "images"
    expected_request.run_tag_filter.runs.names.extend(["nope", "test"])
    expected_request.downsample.num_points = 4
    self.stub.ReadBlobSequences.assert_called_once_with(expected_request)
def test_repr(self):
    """Every field of a `BlobSequenceDatum` appears in its `repr`."""
    datum = provider.BlobSequenceDatum(
        step=123, wall_time=234.5, values=("foo", "bar", "baz")
    )
    rendered = repr(datum)
    for field in (datum.step, datum.wall_time, datum.values):
        self.assertIn(repr(field), rendered)
def read_blob_sequences(
    self,
    ctx,
    *,
    experiment_id,
    plugin_name,
    downsample=None,
    run_tag_filter=None,
):
    """Return a canned single-run, single-tag blob-sequence response.

    The run/tag names are derived from `experiment_id` and
    `plugin_name`; an empty dict is returned when the filter excludes
    them.
    """
    self._validate_eid(experiment_id)
    filt = (
        run_tag_filter
        if run_tag_filter is not None
        else provider.RunTagFilter()
    )
    run_name = "%s/test" % experiment_id
    tag_name = "input.%s" % plugin_name
    # Honor the run/tag filter: bail out early if either name is
    # excluded.
    if filt.runs is not None and run_name not in filt.runs:
        return {}
    if filt.tags is not None and tag_name not in filt.tags:
        return {}
    datum = provider.BlobSequenceDatum(
        step=0,
        wall_time=0.0,
        values=[
            self._make_blob_reference("experiment: %s" % experiment_id),
            self._make_blob_reference("name: %s" % self._name),
        ],
    )
    return {run_name: {tag_name: [datum]}}
def _convert_blob_sequence_event(experiment_id, plugin_name, run, tag, event):
    """Helper for `read_blob_sequences`.

    Converts one event into a `BlobSequenceDatum` with one
    `BlobReference` per blob in the event's tensor, keyed by
    `_encode_blob_key`.
    """
    blob_count = _tensor_size(event.tensor_proto)
    references = []
    for index in range(blob_count):
        blob_key = _encode_blob_key(
            experiment_id,
            plugin_name,
            run,
            tag,
            event.step,
            index,
        )
        references.append(provider.BlobReference(blob_key))
    return provider.BlobSequenceDatum(
        wall_time=event.wall_time,
        step=event.step,
        values=tuple(references),
    )
def read_blob_sequences(self, *args, experiment_id, **kwargs):
    """Delegate to the sub-provider, rewriting blob references.

    The experiment ID is translated for the sub-provider, and every
    `BlobReference` in the result is re-keyed with the sub-provider's
    prefix (in place, per run).
    """
    (prefix, sub_eid, sub_provider) = self._parse_eid(experiment_id)
    result = sub_provider.read_blob_sequences(
        *args, experiment_id=sub_eid, **kwargs
    )
    for tag_to_data in result.values():
        rewritten = {}
        for (tag, data) in tag_to_data.items():
            rewritten[tag] = [
                provider.BlobSequenceDatum(
                    step=datum.step,
                    wall_time=datum.wall_time,
                    values=_convert_blob_references(prefix, datum.values),
                )
                for datum in data
            ]
        # Replace the series in place so the result keeps its structure.
        tag_to_data.update(rewritten)
    return result
def read_blob_sequences(
        self, experiment_id, plugin_name, downsample=None,
        run_tag_filter=None):
    """Read blob sequences (currently only graphs) from the multiplexer.

    Returns a dict mapping run name to a dict mapping tag name to a
    list of `BlobSequenceDatum`s. Only the graphs plugin is supported;
    any other `plugin_name` yields an empty result.
    """
    self._validate_experiment_id(experiment_id)
    # TODO(davidsoergel, wchargin): consider images, etc.
    # Note this plugin_name can really just be 'graphs' for now; the
    # v2 cases are not handled yet.
    if plugin_name != graphs_metadata.PLUGIN_NAME:
        # Fix: `Logger.warn` is a deprecated alias of `warning` (since
        # Python 3.3); use the supported spelling.
        logger.warning(
            "Directory has no blob data for plugin %r", plugin_name)
        return {}
    result = collections.defaultdict(
        lambda: collections.defaultdict(list))
    for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
        # Graphs are run-level data: there is no tag name.
        tag = None
        if not self._test_run_tag(run_tag_filter, run, tag):
            continue
        if not run_info[plugin_event_accumulator.GRAPH]:
            continue
        time_series = result[run][tag]
        wall_time = 0.0  # dummy value for graph
        step = 0  # dummy value for graph
        index = 0  # dummy value for graph
        # In some situations these blobs may have directly accessible URLs.
        # But, for now, we assume they don't.
        graph_url = None
        graph_blob_key = _encode_blob_key(
            experiment_id, plugin_name, run, tag, step, index)
        blob_ref = provider.BlobReference(graph_blob_key, graph_url)
        datum = provider.BlobSequenceDatum(
            wall_time=wall_time,
            step=step,
            values=(blob_ref, ),
        )
        time_series.append(datum)
    return result
def read_blob_sequences(
    self,
    ctx,
    experiment_id,
    plugin_name,
    downsample=None,
    run_tag_filter=None,
):
    """Issue a `ReadBlobSequences` RPC and convert the response.

    Returns a dict mapping run name to a dict mapping tag name to a
    list of `BlobSequenceDatum`s, with each blob wrapped in a
    `BlobReference`.
    """
    with timing.log_latency("build request"):
        req = data_provider_pb2.ReadBlobSequencesRequest()
        req.experiment_id = experiment_id
        req.plugin_filter.plugin_name = plugin_name
        _populate_rtf(run_tag_filter, req.run_tag_filter)
        req.downsample.num_points = downsample
    with timing.log_latency("_stub.ReadBlobSequences"):
        with _translate_grpc_error():
            res = self._stub.ReadBlobSequences(req)
    with timing.log_latency("build result"):
        result = {}
        for run_entry in res.runs:
            tag_map = {}
            for tag_entry in run_entry.tags:
                data = tag_entry.data
                series = []
                # Steps, wall times, and blob sequences are parallel
                # repeated fields in the proto.
                for (step, wall_time, blob_sequence) in zip(
                    data.step, data.wall_time, data.values
                ):
                    refs = tuple(
                        provider.BlobReference(
                            blob_key=ref.blob_key,
                            # An empty URL in the proto means "no URL".
                            url=ref.url or None,
                        )
                        for ref in blob_sequence.blob_refs
                    )
                    series.append(
                        provider.BlobSequenceDatum(
                            step=step, wall_time=wall_time, values=refs
                        )
                    )
                tag_map[tag_entry.tag_name] = series
            result[run_entry.run_name] = tag_map
        return result