def _download_blob(self, blob_id, experiment_id):
    """Download the blob via rpc.

    The blob is streamed chunk by chunk into a file under the experiment's
    blobs directory. If the rpc fails, the partially written file is
    removed so that the output directory never contains a truncated blob
    that is not referenced by the exported metadata.

    Args:
      blob_id: Id of the blob.
      experiment_id: Id of the experiment that the blob belongs to.

    Returns:
      If the blob is downloaded successfully:
        The path of the downloaded blob file relative to the experiment
        directory.
      Else:
        `None`.

    Raises:
      FileExistsError: If the destination blob file already exists (the
        file is opened in exclusive-creation mode).
    """
    # TODO(cais): Deduplicate with internal method perhaps.
    experiment_dir = _experiment_directory(self._outdir, experiment_id)
    request = export_service_pb2.StreamBlobDataRequest(blob_id=blob_id)
    blob_abspath = os.path.join(
        experiment_dir,
        _DIRNAME_BLOBS,
        _FILENAME_BLOBS_PREFIX + blob_id + _FILENAME_BLOBS_SUFFIX,
    )
    # The `try` wraps the `with` so that the file is already closed by the
    # time the error handler runs; removing a still-open file fails on
    # some platforms.
    try:
        with open(blob_abspath, "xb") as f:
            for response in self._api.StreamBlobData(
                request, metadata=grpc_util.version_metadata()
            ):
                # TODO(cais, soergel): validate the various response fields
                f.write(response.data)
    except grpc.RpcError as rpc_error:
        logger.error(
            "Omitting blob (id: %s) due to download failure: %s",
            blob_id,
            rpc_error,
        )
        # `open()` cannot raise `grpc.RpcError`, so reaching this handler
        # means the file was created by this call; it is safe to delete.
        try:
            os.remove(blob_abspath)
        except OSError as os_error:
            # Best effort: the download failure was already reported.
            logger.warning(
                "Failed to remove partially downloaded blob file %s: %s",
                blob_abspath,
                os_error,
            )
        return None
    return os.path.relpath(blob_abspath, experiment_dir)
def test_e2e_success_case_with_blob_sequence_data(self):
    """Covers exporting of complete and incomplete blob sequences as well
    as rpc error during blob streaming.

    Two experiments ("123" and "456") are exported. Each has a "train" run
    with one finished blob followed by one unfinalized blob, and a "test"
    run whose only blob sequence entry is a hole (no blob). The blob
    download succeeds for experiment "123" and raises a grpc error for
    experiment "456".
    """
    mock_api_client = self._create_mock_api_client()

    def stream_experiments(request, **kwargs):
        del request  # unused
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        response = export_service_pb2.StreamExperimentsResponse()
        response.experiments.add(experiment_id="123")
        yield response
        response = export_service_pb2.StreamExperimentsResponse()
        response.experiments.add(experiment_id="456")
        yield response

    def stream_experiment_data(request, **kwargs):
        del request  # unused
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        tag = "__default_graph__"
        for run in ("train", "test"):
            response = export_service_pb2.StreamExperimentDataResponse()
            response.run_name = run
            response.tag_name = tag
            response.tag_metadata.CopyFrom(
                summary_pb2.SummaryMetadata(
                    data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE
                )
            )
            for step in range(1):
                response.blob_sequences.steps.append(step)
                response.blob_sequences.wall_times.add(
                    seconds=1571084520 + step, nanos=862939144
                )
                blob_sequence = blob_pb2.BlobSequence()
                if run == "train":
                    # A finished blob.
                    blob = blob_pb2.Blob(
                        blob_id="%s_blob" % run,
                        state=blob_pb2.BlobState.BLOB_STATE_CURRENT,
                    )
                    blob_sequence.entries.append(
                        blob_pb2.BlobSequenceEntry(blob=blob)
                    )
                    # An unfinished blob.
                    blob = blob_pb2.Blob(
                        state=blob_pb2.BlobState.BLOB_STATE_UNFINALIZED,
                    )
                    blob_sequence.entries.append(
                        blob_pb2.BlobSequenceEntry(blob=blob)
                    )
                elif run == "test":
                    blob_sequence.entries.append(
                        # `blob` unspecified: a hole in the blob sequence.
                        blob_pb2.BlobSequenceEntry()
                    )
                response.blob_sequences.values.append(blob_sequence)
            yield response

    mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
    mock_api_client.StreamExperimentData = mock.Mock(
        wraps=stream_experiment_data
    )
    mock_api_client.StreamBlobData.side_effect = [
        # First call: a successful two-chunk download.
        iter(
            [
                export_service_pb2.StreamBlobDataResponse(
                    data=b"4321",
                    offset=0,
                    final_chunk=False,
                ),
                export_service_pb2.StreamBlobDataResponse(
                    data=b"8765",
                    offset=4,
                    final_chunk=True,
                ),
            ]
        ),
        # Second call: raise error from `StreamBlobData` to test the
        # grpc-error condition.
        test_util.grpc_error(grpc.StatusCode.INTERNAL, "Error for testing"),
    ]

    expected_summary_metadata = summary_pb2.SummaryMetadata(
        data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE
    )

    def read_blob_sequences(experiment_dirname):
        # Parse the line-delimited JSON records of blob_sequences.json.
        path = os.path.join(outdir, experiment_dirname, "blob_sequences.json")
        with open(path) as infile:
            return [json.loads(line) for line in infile]

    def assert_blob_sequence_datum(datum, expected_run, expected_paths):
        # Check one exported record field by field, popping each key so
        # that any unexpected extra key is caught by the `{}` comparisons.
        self.assertEqual(datum.pop("run"), expected_run)
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata"))
        )
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        self.assertEqual(points.pop("blob_file_paths"), expected_paths)
        self.assertEqual(points, {})

    outdir = os.path.join(self.get_temp_dir(), "outdir")
    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
    start_time = 1571084846.25

    generator = exporter.export(read_time=start_time)
    expected_files = []
    self.assertTrue(os.path.isdir(outdir))
    self.assertCountEqual(expected_files, _outdir_files(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # The first iteration should request the list of experiments and
    # data for one of them.
    self.assertEqual(next(generator), "123")

    expected_files.append(os.path.join("experiment_123", "metadata.json"))
    # scalars.json and tensors.json should exist and be empty.
    expected_files.append(os.path.join("experiment_123", "scalars.json"))
    expected_files.append(os.path.join("experiment_123", "tensors.json"))
    expected_files.append(
        os.path.join("experiment_123", "blob_sequences.json")
    )
    expected_files.append(
        os.path.join("experiment_123", "blobs", "blob_train_blob.bin")
    )
    # Only the finished "train" blob is expected on disk: the second
    # "train" blob is unfinalized and the "test" run's only entry is a
    # hole, so neither should produce a file under blobs/.
    self.assertCountEqual(expected_files, _outdir_files(outdir))

    # Check that the scalars and tensors data files are empty, because
    # there are no scalars or tensors.
    with open(
        os.path.join(outdir, "experiment_123", "scalars.json")
    ) as infile:
        self.assertEqual(infile.read(), "")
    with open(
        os.path.join(outdir, "experiment_123", "tensors.json")
    ) as infile:
        self.assertEqual(infile.read(), "")

    # Check the blob_sequences.json file.
    jsons = read_blob_sequences("experiment_123")
    self.assertLen(jsons, 2)
    # The 1st blob is finished; the 2nd is unfinished.
    assert_blob_sequence_datum(
        jsons[0], "train", [["blobs/blob_train_blob.bin", None]]
    )
    # `None` blob file path stands for the hole in the blob sequence.
    assert_blob_sequence_datum(jsons[1], "test", [[None]])

    # Check the BLOB files.
    with open(
        os.path.join(
            outdir, "experiment_123", "blobs", "blob_train_blob.bin"
        ),
        "rb",
    ) as f:
        self.assertEqual(f.read(), b"43218765")

    # Check call to StreamBlobData.
    expected_blob_data_request = export_service_pb2.StreamBlobDataRequest(
        blob_id="train_blob"
    )
    mock_api_client.StreamBlobData.assert_called_once_with(
        expected_blob_data_request, metadata=grpc_util.version_metadata()
    )

    # Test the case where blob streaming errors out.
    self.assertEqual(next(generator), "456")

    # Check the blob_sequences.json file.
    jsons = read_blob_sequences("experiment_456")
    self.assertLen(jsons, 2)
    # The 1st `None` represents the blob that experienced error during
    # downloading and hence is missing; the 2nd is the unfinished blob.
    assert_blob_sequence_datum(jsons[0], "train", [[None, None]])
    # `None` stands for the hole in the blob sequence, as in the first
    # experiment; no download is involved for this run.
    assert_blob_sequence_datum(jsons[1], "test", [[None]])