Example 1
    def test_e2e_success_case(self):
        mock_api_client = self._create_mock_api_client()
        mock_api_client.StreamExperiments.return_value = iter(
            [_make_experiments_response(["789"])])

        def stream_experiments(request, **kwargs):
            del request  # unused
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="123")
            response.experiments.add(experiment_id="456")
            yield response

            response = export_service_pb2.StreamExperimentsResponse()
            experiment = response.experiments.add()
            experiment.experiment_id = "789"
            experiment.name = "bert"
            experiment.description = "ernie"
            util.set_timestamp(experiment.create_time, 981173106)
            util.set_timestamp(experiment.update_time, 1015218367)
            yield response

        def stream_experiment_data(request, **kwargs):
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
            for run in ("train", "test"):
                for tag in ("accuracy", "loss"):
                    response = export_service_pb2.StreamExperimentDataResponse()
                    response.run_name = run
                    response.tag_name = tag
                    display_name = "%s:%s" % (request.experiment_id, tag)
                    response.tag_metadata.CopyFrom(
                        test_util.scalar_metadata(display_name))
                    for step in range(10):
                        response.points.steps.append(step)
                        response.points.values.append(2.0 * step)
                        response.points.wall_times.add(
                            seconds=1571084520 + step, nanos=862939144)
                    yield response

        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        def outdir_files():
            # Recursively list `outdir`.
            result = []
            for (dirpath, dirnames, filenames) in os.walk(outdir):
                for filename in filenames:
                    fullpath = os.path.join(dirpath, filename)
                    result.append(os.path.relpath(fullpath, outdir))
            return result

        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, outdir_files())
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append(os.path.join("experiment_123", "metadata.json"))
        expected_files.append(os.path.join("experiment_123", "scalars.json"))
        self.assertCountEqual(expected_files, outdir_files())

        expected_eids_request = export_service_pb2.StreamExperimentsRequest()
        expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
        expected_eids_request.limit = 2**63 - 1
        expected_eids_request.experiments_mask.create_time = True
        expected_eids_request.experiments_mask.update_time = True
        expected_eids_request.experiments_mask.name = True
        expected_eids_request.experiments_mask.description = True
        mock_api_client.StreamExperiments.assert_called_once_with(
            expected_eids_request, metadata=grpc_util.version_metadata())

        expected_data_request = export_service_pb2.StreamExperimentDataRequest()
        expected_data_request.experiment_id = "123"
        expected_data_request.read_timestamp.CopyFrom(start_time_pb)
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The next iteration should just request data for the next experiment.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "456")

        expected_files.append(os.path.join("experiment_456", "metadata.json"))
        expected_files.append(os.path.join("experiment_456", "scalars.json"))
        self.assertCountEqual(expected_files, outdir_files())
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "456"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # Again, request data for the next experiment; this experiment ID
        # was in the second response batch in the list of IDs.
        expected_files.append(os.path.join("experiment_789", "metadata.json"))
        expected_files.append(os.path.join("experiment_789", "scalars.json"))
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "789")

        self.assertCountEqual(expected_files, outdir_files())
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "789"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The final continuation shouldn't need to send any RPCs.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(list(generator), [])

        self.assertCountEqual(expected_files, outdir_files())
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # Spot-check one of the scalar data files.
        with open(os.path.join(outdir, "experiment_456",
                               "scalars.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 4)
        datum = jsons[2]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "accuracy")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata("456:accuracy")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        expected_steps = list(range(10))
        expected_values = [2.0 * x for x in range(10)]
        expected_wall_times = [1571084520.862939144 + x for x in range(10)]
        self.assertEqual(points.pop("steps"), expected_steps)
        self.assertEqual(points.pop("values"), expected_values)
        self.assertEqual(points.pop("wall_times"), expected_wall_times)
        self.assertEqual(points, {})
        self.assertEqual(datum, {})

        # Spot-check one of the metadata files.
        with open(os.path.join(outdir, "experiment_789",
                               "metadata.json")) as infile:
            metadata = json.load(infile)
        self.assertEqual(
            metadata,
            {
                "name": "bert",
                "description": "ernie",
                "create_time": "2001-02-03T04:05:06Z",
                "update_time": "2002-03-04T05:06:07Z",
            },
        )
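This example leans on two helpers that the listing does not define. A minimal sketch of what they might look like, assuming the test module's existing imports (mock, export_service_pb2); the spec list passed to mock.Mock is an illustration, not the actual implementation:

def _make_experiments_response(experiment_ids):
    # Hypothetical reconstruction: a StreamExperimentsResponse whose
    # experiments carry only IDs, mirroring the inline construction above.
    response = export_service_pb2.StreamExperimentsResponse()
    for experiment_id in experiment_ids:
        response.experiments.add(experiment_id=experiment_id)
    return response


def _create_mock_api_client():
    # Hypothetical reconstruction: a bare mock standing in for the export
    # service gRPC stub, exposing the streaming RPCs these tests call.
    return mock.Mock(
        spec=["StreamExperiments", "StreamExperimentData", "StreamBlobData"])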
Example 2
    def test_e2e_success_case(self):
        mock_api_client = self._create_mock_api_client()
        mock_api_client.StreamExperiments.return_value = iter([
            export_service_pb2.StreamExperimentsResponse(
                experiment_ids=["789"]),
        ])

        def stream_experiments(request, **kwargs):
            del request  # unused
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
            yield export_service_pb2.StreamExperimentsResponse(
                experiment_ids=["123", "456"])
            yield export_service_pb2.StreamExperimentsResponse(
                experiment_ids=["789"])

        def stream_experiment_data(request, **kwargs):
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
            for run in ("train", "test"):
                for tag in ("accuracy", "loss"):
                    response = export_service_pb2.StreamExperimentDataResponse()
                    response.run_name = run
                    response.tag_name = tag
                    display_name = "%s:%s" % (request.experiment_id, tag)
                    response.tag_metadata.CopyFrom(
                        test_util.scalar_metadata(display_name))
                    for step in range(10):
                        response.points.steps.append(step)
                        response.points.values.append(2.0 * step)
                        response.points.wall_times.add(
                            seconds=1571084520 + step, nanos=862939144)
                    yield response

        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append("scalars_123.json")
        self.assertCountEqual(expected_files, os.listdir(outdir))

        expected_eids_request = export_service_pb2.StreamExperimentsRequest()
        expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
        expected_eids_request.limit = 2**63 - 1
        mock_api_client.StreamExperiments.assert_called_once_with(
            expected_eids_request, metadata=grpc_util.version_metadata())

        expected_data_request = export_service_pb2.StreamExperimentDataRequest()
        expected_data_request.experiment_id = "123"
        expected_data_request.read_timestamp.CopyFrom(start_time_pb)
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The next iteration should just request data for the next experiment.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "456")

        expected_files.append("scalars_456.json")
        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "456"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # Again, request data for the next experiment; this experiment ID
        # was in the second response batch in the list of IDs.
        expected_files.append("scalars_789.json")
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "789")

        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "789"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The final continuation shouldn't need to send any RPCs.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(list(generator), [])

        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # Spot-check one of the files.
        with open(os.path.join(outdir, "scalars_456.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 4)
        datum = jsons[2]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "accuracy")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata("456:accuracy")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        expected_steps = list(range(10))
        expected_values = [2.0 * x for x in range(10)]
        expected_wall_times = [1571084520.862939144 + x for x in range(10)]
        self.assertEqual(points.pop("steps"), expected_steps)
        self.assertEqual(points.pop("values"), expected_values)
        self.assertEqual(points.pop("wall_times"), expected_wall_times)
        self.assertEqual(points, {})
        self.assertEqual(datum, {})
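Both examples turn the read time into a protobuf timestamp via test_util.timestamp_pb(1571084846250000000), which is not shown in this listing. A plausible sketch, assuming the helper accepts an integer count of nanoseconds since the Unix epoch:

from google.protobuf import timestamp_pb2


def timestamp_pb(nanos):
    # Plausible sketch of test_util.timestamp_pb: wrap a nanosecond count in
    # a google.protobuf.Timestamp (1571084846250000000 ns == 1571084846.25 s).
    result = timestamp_pb2.Timestamp()
    result.FromNanoseconds(nanos)
    return result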
Example 3
    def test_e2e_success_case_with_blob_sequence_data(self):
        """Covers exporting of complete and incomplete blob sequences.

        Also covers an RPC error during blob streaming.
        """
        mock_api_client = self._create_mock_api_client()

        def stream_experiments(request, **kwargs):
            del request  # unused
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="123")
            yield response
            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="456")
            yield response

        def stream_experiment_data(request, **kwargs):
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            tag = "__default_graph__"
            for run in ("train", "test"):
                response = export_service_pb2.StreamExperimentDataResponse()
                response.run_name = run
                response.tag_name = tag
                display_name = "%s:%s" % (request.experiment_id, tag)
                response.tag_metadata.CopyFrom(
                    summary_pb2.SummaryMetadata(
                        data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE))
                for step in range(1):
                    response.blob_sequences.steps.append(step)
                    response.blob_sequences.wall_times.add(
                        seconds=1571084520 + step, nanos=862939144)
                    blob_sequence = blob_pb2.BlobSequence()
                    if run == "train":
                        # A finished blob sequence.
                        blob = blob_pb2.Blob(
                            blob_id="%s_blob" % run,
                            state=blob_pb2.BlobState.BLOB_STATE_CURRENT,
                        )
                        blob_sequence.entries.append(
                            blob_pb2.BlobSequenceEntry(blob=blob))
                        # An unfinished blob sequence.
                        blob = blob_pb2.Blob(
                            state=blob_pb2.BlobState.BLOB_STATE_UNFINALIZED)
                        blob_sequence.entries.append(
                            blob_pb2.BlobSequenceEntry(blob=blob))
                    elif run == "test":
                        blob_sequence.entries.append(
                            # `blob` unspecified: a hole in the blob sequence.
                            blob_pb2.BlobSequenceEntry())
                    response.blob_sequences.values.append(blob_sequence)
                yield response

        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)
        mock_api_client.StreamBlobData.side_effect = [
            iter([
                export_service_pb2.StreamBlobDataResponse(
                    data=b"4321",
                    offset=0,
                    final_chunk=False,
                ),
                export_service_pb2.StreamBlobDataResponse(
                    data=b"8765",
                    offset=4,
                    final_chunk=True,
                ),
            ]),
            # Raise error from `StreamBlobData` to test the grpc-error
            # condition.
            test_util.grpc_error(grpc.StatusCode.INTERNAL,
                                 "Error for testing"),
        ]

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append(os.path.join("experiment_123", "metadata.json"))
        # scalars.json should exist and be empty.
        expected_files.append(os.path.join("experiment_123", "scalars.json"))
        expected_files.append(
            os.path.join("experiment_123", "blob_sequences.json"))
        expected_files.append(
            os.path.join("experiment_123", "blobs", "blob_train_blob.bin"))
        # blobs/blob_test_blob.bin should not exist, because it contains
        # an unfinished blob.
        self.assertCountEqual(expected_files, _outdir_files(outdir))

        # Check that the scalars data file is empty, because there are no
        # scalars.
        with open(os.path.join(outdir, "experiment_123",
                               "scalars.json")) as infile:
            self.assertEqual(infile.read(), "")

        # Check the blob_sequences.json file.
        with open(os.path.join(outdir, "experiment_123",
                               "blob_sequences.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 2)

        datum = jsons[0]
        self.assertEqual(datum.pop("run"), "train")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = summary_pb2.SummaryMetadata(
            data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE)
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # The 1st blob is finished; the 2nd is unfinished.
        self.assertEqual(points.pop("blob_file_paths"),
                         [["blobs/blob_train_blob.bin", None]])
        self.assertEqual(points, {})

        datum = jsons[1]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # `None` blob file path indicates an unfinished blob.
        self.assertEqual(points.pop("blob_file_paths"), [[None]])
        self.assertEqual(points, {})

        # Check the BLOB files.
        with open(
                os.path.join(outdir, "experiment_123", "blobs",
                             "blob_train_blob.bin"),
                "rb",
        ) as f:
            self.assertEqual(f.read(), b"43218765")

        # Check call to StreamBlobData.
        expected_blob_data_request = export_service_pb2.StreamBlobDataRequest(
            blob_id="train_blob")
        mock_api_client.StreamBlobData.assert_called_once_with(
            expected_blob_data_request, metadata=grpc_util.version_metadata())

        # Test the case where blob streaming errors out.
        self.assertEqual(next(generator), "456")
        # Check the blob_sequences.json file.
        with open(os.path.join(outdir, "experiment_456",
                               "blob_sequences.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 2)

        datum = jsons[0]
        self.assertEqual(datum.pop("run"), "train")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # The first `None` is the blob whose download hit an RPC error; the
        # second is the unfinished blob.
        self.assertEqual(points.pop("blob_file_paths"), [[None, None]])
        self.assertEqual(points, {})

        datum = jsons[1]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # `None` here represents the hole in the blob sequence (no blob
        # specified in the entry).
        self.assertEqual(points.pop("blob_file_paths"), [[None]])
        self.assertEqual(points, {})
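Examples 3 and 4 call a module-level _outdir_files(outdir) helper whose body the listing omits. A sketch that matches the behavior of the inline outdir_files() closure in Example 1:

def _outdir_files(outdir):
    # Recursively list every file under `outdir`, as paths relative to it,
    # mirroring the inline closure in Example 1.
    result = []
    for dirpath, _, filenames in os.walk(outdir):
        for filename in filenames:
            fullpath = os.path.join(dirpath, filename)
            result.append(os.path.relpath(fullpath, outdir))
    return result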
Example 4
    def test_e2e_success_case_with_only_tensors_data(self):
        mock_api_client = self._create_mock_api_client()

        def stream_experiments(request, **kwargs):
            del request  # unused
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="123")
            yield response

        def stream_experiment_data(request, **kwargs):
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
            for run in ("train_1", "train_2"):
                for tag in ("dense_1/kernel", "dense_1/bias", "text/test"):
                    response = export_service_pb2.StreamExperimentDataResponse()
                    response.run_name = run
                    response.tag_name = tag
                    display_name = "%s:%s" % (request.experiment_id, tag)
                    response.tag_metadata.CopyFrom(
                        test_util.scalar_metadata(display_name))
                    for step in range(2):
                        response.tensors.steps.append(step)
                        response.tensors.wall_times.add(
                            seconds=1571084520 + step,
                            nanos=862939144 if run == "train_1" else 962939144,
                        )
                        if tag != "text/test":
                            response.tensors.values.append(
                                tensor_util.make_tensor_proto(
                                    np.ones([3, 2]) * step))
                        else:
                            response.tensors.values.append(
                                tensor_util.make_tensor_proto(
                                    np.full([3], "a" * (step + 1))))
                    yield response

        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append(os.path.join("experiment_123", "metadata.json"))
        # scalars.json should exist and be empty.
        expected_files.append(os.path.join("experiment_123", "scalars.json"))
        expected_files.append(os.path.join("experiment_123", "tensors.json"))
        # blob_sequences.json should exist and be empty.
        expected_files.append(
            os.path.join("experiment_123", "blob_sequences.json"))
        expected_files.append(
            os.path.join("experiment_123", "tensors", "1571084520.862939.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.862939_1.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.862939_2.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors", "1571084520.962939.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.962939_1.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.962939_2.npz"))
        self.assertCountEqual(expected_files, _outdir_files(outdir))

        # Check that the scalars and blob_sequences data files are empty, because
        # there are no scalars or blob sequences.
        with open(os.path.join(outdir, "experiment_123",
                               "scalars.json")) as infile:
            self.assertEqual(infile.read(), "")
        with open(os.path.join(outdir, "experiment_123",
                               "blob_sequences.json")) as infile:
            self.assertEqual(infile.read(), "")

        expected_eids_request = export_service_pb2.StreamExperimentsRequest()
        expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
        expected_eids_request.limit = 2**63 - 1
        expected_eids_request.experiments_mask.create_time = True
        expected_eids_request.experiments_mask.update_time = True
        expected_eids_request.experiments_mask.name = True
        expected_eids_request.experiments_mask.description = True
        mock_api_client.StreamExperiments.assert_called_once_with(
            expected_eids_request, metadata=grpc_util.version_metadata())

        expected_data_request = export_service_pb2.StreamExperimentDataRequest()
        expected_data_request.experiment_id = "123"
        expected_data_request.read_timestamp.CopyFrom(start_time_pb)
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The final StreamExperiments continuation shouldn't need to send any
        # RPCs.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(list(generator), [])

        # Check tensor data.
        with open(os.path.join(outdir, "experiment_123",
                               "tensors.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 6)

        datum = jsons[0]
        self.assertEqual(datum.pop("run"), "train_1")
        self.assertEqual(datum.pop("tag"), "dense_1/kernel")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata(
            "123:dense_1/kernel")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(points.pop("steps"), [0, 1])
        self.assertEqual(
            points.pop("tensors_file_path"),
            os.path.join("tensors", "1571084520.862939.npz"),
        )
        self.assertEqual(datum, {})

        datum = jsons[4]
        self.assertEqual(datum.pop("run"), "train_2")
        self.assertEqual(datum.pop("tag"), "dense_1/bias")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata(
            "123:dense_1/bias")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(points.pop("steps"), [0, 1])
        self.assertEqual(
            points.pop("tensors_file_path"),
            os.path.join("tensors", "1571084520.962939_1.npz"),
        )
        self.assertEqual(datum, {})

        # Load and check the tensor data from the saved .npz files.
        for filename in (
                "1571084520.862939.npz",
                "1571084520.862939_1.npz",
                "1571084520.962939.npz",
                "1571084520.962939_1.npz",
        ):
            tensors = np.load(
                os.path.join(outdir, "experiment_123", "tensors", filename))
            tensors = [tensors[key] for key in tensors.keys()]
            self.assertLen(tensors, 2)
            np.testing.assert_array_equal(tensors[0], 0 * np.ones([3, 2]))
            np.testing.assert_array_equal(tensors[1], 1 * np.ones([3, 2]))

        for filename in (
                "1571084520.862939_2.npz",
                "1571084520.962939_2.npz",
        ):
            tensors = np.load(
                os.path.join(outdir, "experiment_123", "tensors", filename))
            tensors = [tensors[key] for key in tensors.keys()]
            self.assertLen(tensors, 2)
            np.testing.assert_array_equal(tensors[0],
                                          np.array(["a", "a", "a"], "|S"))
            np.testing.assert_array_equal(tensors[1],
                                          np.array(["aa", "aa", "aa"], "|S"))
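All four examples build expected metadata with test_util.scalar_metadata(display_name), also omitted from the listing. A hedged sketch of one way it could be written; the display_name field and the "scalars" plugin name are assumptions about the real helper:

def scalar_metadata(display_name):
    # Hypothetical sketch: SummaryMetadata for a scalar time series, carrying
    # the given display name and the scalars plugin name.
    metadata = summary_pb2.SummaryMetadata(display_name=display_name)
    metadata.plugin_data.plugin_name = "scalars"
    return metadata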