예제 #1
0
 def _request_scalar_data(self, experiment_id, read_time):
     """Stream the scalar time series of a single experiment.

     Args:
       experiment_id: ID of the experiment whose data is requested.
       read_time: Timestamp written into the request's `read_timestamp`
         field via `util.set_timestamp`.

     Yields:
       One JSON-serializable dict per streamed response, holding the
       run name, the tag name, the base64-encoded serialized tag
       metadata, and a `points` dict of parallel `steps`, `wall_times`
       (float seconds) and `values` lists.
     """
     req = export_service_pb2.StreamExperimentDataRequest()
     req.experiment_id = experiment_id
     util.set_timestamp(req.read_timestamp, read_time)
     # Errors are deliberately not handled here. Every experiment is
     # expected to exist (the read timestamp gives a consistent view) and
     # to belong to the caller (only the caller's own experiment IDs are
     # queried), so a non-transient failure would be an internal error.
     # Transient failures cannot be resumed efficiently either, since the
     # server offers no pagination.
     responses = self._api.StreamExperimentData(
         req, metadata=grpc_util.version_metadata())
     for resp in responses:
         serialized_metadata = resp.tag_metadata.SerializeToString()
         encoded_metadata = base64.b64encode(serialized_metadata).decode(
             "ascii")
         # Convert each protobuf timestamp into float seconds.
         seconds = [
             stamp.ToNanoseconds() / 1e9 for stamp in resp.points.wall_times
         ]
         point_data = {
             u"steps": list(resp.points.steps),
             u"wall_times": seconds,
             u"values": list(resp.points.values),
         }
         yield {
             u"run": resp.run_name,
             u"tag": resp.tag_name,
             u"summary_metadata": encoded_metadata,
             u"points": point_data,
         }
예제 #2
0
    def _request_json_data(self, experiment_id, read_time):
        """Stream an experiment's data as (JSON record, file name) pairs.

        Every record carries the run name, the tag name and the
        base64-encoded serialized tag metadata. Its `points` entry is
        produced by the helper matching the payload of the response:
          - `_process_scalar_points` for scalar points,
          - `_process_tensor_points` for tensors,
          - `_process_blob_sequence_points` for blob sequences.

        Per the original documentation, handling blob sequences has the
        side effect of downloading the blob contents and writing them to
        files in a subdirectory of the experiment directory.

        Args:
          experiment_id: The id of the experiment to request data for.
          read_time: A fixed timestamp from which to export data, as float
            seconds since epoch (like `time.time()`).

        Yields:
          (JSON-serializable data, destination file name) tuples. Responses
          that carry none of the three payload kinds yield nothing.
        """
        req = export_service_pb2.StreamExperimentDataRequest()
        req.experiment_id = experiment_id
        util.set_timestamp(req.read_timestamp, read_time)
        # Errors are deliberately not handled here. Every experiment is
        # expected to exist (the read timestamp gives a consistent view) and
        # to belong to the caller (only the caller's own experiment IDs are
        # queried), so a non-transient failure would be an internal error.
        # Transient failures cannot be resumed efficiently either, since the
        # server offers no pagination.
        responses = self._api.StreamExperimentData(
            req, metadata=grpc_util.version_metadata()
        )
        for resp in responses:
            serialized_metadata = resp.tag_metadata.SerializeToString()
            encoded_metadata = base64.b64encode(serialized_metadata).decode(
                "ascii"
            )
            record = {
                u"run": resp.run_name,
                u"tag": resp.tag_name,
                u"summary_metadata": encoded_metadata,
            }
            if resp.HasField("points"):
                record[u"points"] = self._process_scalar_points(resp.points)
                destination = _FILENAME_SCALARS
            elif resp.HasField("tensors"):
                record[u"points"] = self._process_tensor_points(
                    resp.tensors, experiment_id
                )
                destination = _FILENAME_TENSORS
            elif resp.HasField("blob_sequences"):
                record[u"points"] = self._process_blob_sequence_points(
                    resp.blob_sequences, experiment_id
                )
                destination = _FILENAME_BLOB_SEQUENCES
            else:
                destination = None
            if not destination:
                # Nothing recognised to export for this response.
                continue
            yield record, destination
예제 #3
0
    def get_scalars(
        self,
        runs_filter=None,
        tags_filter=None,
        pivot=False,
        include_wall_time=False,
    ):
        """Download the experiment's scalar points into a pandas DataFrame.

        Args:
          runs_filter: Not supported yet; must be `None`.
          tags_filter: Not supported yet; must be `None`.
          pivot: When truthy, the frame is passed through
            `self._pivot_dataframe` before being returned.
          include_wall_time: When truthy, a `wall_time` column (float
            seconds) is added after the `value` column.

        Returns:
          A DataFrame with the columns `run`, `tag`, `step` and `value`
          (plus `wall_time` when requested), or the pivoted form of it.

        Raises:
          NotImplementedError: If `runs_filter` or `tags_filter` is given.
        """
        # NOTE(#3650): pandas is imported first so that a missing pandas
        # installation surfaces as an error before any rpc is issued.
        pandas = import_pandas()
        if runs_filter is not None:
            raise NotImplementedError(
                "runs_filter support for get_scalars() is not implemented yet."
            )
        if tags_filter is not None:
            raise NotImplementedError(
                "tags_filter support for get_scalars() is not implemented yet."
            )

        req = export_service_pb2.StreamExperimentDataRequest()
        req.experiment_id = self._experiment_id
        util.set_timestamp(req.read_timestamp, time.time())
        # TODO(cais, wchargin): Use another rpc to check for staleness and avoid
        # a new StreamExperimentData rpc request if data is not stale.
        responses = self._api_client.StreamExperimentData(
            req, metadata=grpc_util.version_metadata()
        )

        run_col = []
        tag_col = []
        step_col = []
        wall_time_col = []
        value_col = []
        for resp in responses:
            # TODO(cais, wchargin): Display progress bar during data loading.
            point_count = len(resp.points.values)
            # Run and tag names are repeated once per point of the response.
            run_col += [resp.run_name] * point_count
            tag_col += [resp.tag_name] * point_count
            step_col += list(resp.points.steps)
            wall_time_col += [
                stamp.ToNanoseconds() / 1e9 for stamp in resp.points.wall_times
            ]
            value_col += list(resp.points.values)

        columns = {
            "run": run_col,
            "tag": tag_col,
            "step": step_col,
            "value": value_col,
        }
        if include_wall_time:
            columns["wall_time"] = wall_time_col
        frame = pandas.DataFrame(columns)
        return self._pivot_dataframe(frame) if pivot else frame
예제 #4
0
    def get_scalars(self, runs_filter=None, tags_filter=None, pivot=None):
        """Download the experiment's scalar points into a pandas DataFrame.

        Args:
          runs_filter: Not supported yet; must be `None`.
          tags_filter: Not supported yet; must be `None`.
          pivot: Whether to pass the frame through `self._pivot_dataframe`
            before returning it. `None` is treated as `True`.

        Returns:
          A DataFrame with the columns `run`, `tag`, `step`, `wall_time`
          (float seconds) and `value`, or the pivoted form of it.

        Raises:
          NotImplementedError: If `runs_filter` or `tags_filter` is given.
        """
        if runs_filter is not None:
            raise NotImplementedError(
                "runs_filter support for get_scalars() is not implemented yet."
            )
        if tags_filter is not None:
            raise NotImplementedError(
                "tags_filter support for get_scalars() is not implemented yet."
            )
        if pivot is None:
            pivot = True

        req = export_service_pb2.StreamExperimentDataRequest()
        req.experiment_id = self._experiment_id
        util.set_timestamp(req.read_timestamp, time.time())
        # TODO(cais, wchargin): Use another rpc to check for staleness and avoid
        # a new StreamExperimentData rpc request if data is not stale.
        responses = self._api_client.StreamExperimentData(
            req, metadata=grpc_util.version_metadata())

        run_col = []
        tag_col = []
        step_col = []
        wall_time_col = []
        value_col = []
        for resp in responses:
            # TODO(cais, wchargin): Display progress bar during data loading.
            point_count = len(resp.points.values)
            # Run and tag names are repeated once per point of the response.
            run_col += [resp.run_name] * point_count
            tag_col += [resp.tag_name] * point_count
            step_col += list(resp.points.steps)
            wall_time_col += [
                stamp.ToNanoseconds() / 1e9 for stamp in resp.points.wall_times
            ]
            value_col += list(resp.points.values)

        columns = {
            "run": run_col,
            "tag": tag_col,
            "step": step_col,
            "wall_time": wall_time_col,
            "value": value_col,
        }
        frame = pandas.DataFrame(columns)
        if not pivot:
            return frame
        return self._pivot_dataframe(frame)
예제 #5
0
    def test_e2e_success_case_with_only_scalar_data(self):
        """Exports three mocked experiments that contain only scalar points.

        Drives the `export()` generator one experiment at a time and, after
        each step, checks the files present in the output directory and the
        RPCs issued against the mocked API client. Finally spot-checks the
        contents of one scalars file, one blob-sequences file and one
        metadata file.
        """
        mock_api_client = self._create_mock_api_client()
        # NOTE(review): this return value is superseded further down, where
        # `StreamExperiments` is replaced by a Mock wrapping
        # `stream_experiments`.
        mock_api_client.StreamExperiments.return_value = iter(
            [_make_experiments_response(["789"])])

        def stream_experiments(request, **kwargs):
            # Fake StreamExperiments rpc: yields two response batches, the
            # first with experiments "123" and "456", the second with a
            # fully described experiment "789".
            del request  # unused
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="123")
            response.experiments.add(experiment_id="456")
            yield response

            response = export_service_pb2.StreamExperimentsResponse()
            experiment = response.experiments.add()
            experiment.experiment_id = "789"
            experiment.name = "bert"
            experiment.description = "ernie"
            util.set_timestamp(experiment.create_time, 981173106)
            util.set_timestamp(experiment.update_time, 1015218367)
            yield response

        def stream_experiment_data(request, **kwargs):
            # Fake StreamExperimentData rpc: yields one response per
            # (run, tag) pair, i.e. four responses, each with ten scalar
            # points whose value is twice the step.
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
            for run in ("train", "test"):
                for tag in ("accuracy", "loss"):
                    response = export_service_pb2.StreamExperimentDataResponse(
                    )
                    response.run_name = run
                    response.tag_name = tag
                    display_name = "%s:%s" % (request.experiment_id, tag)
                    response.tag_metadata.CopyFrom(
                        test_util.scalar_metadata(display_name))
                    for step in range(10):
                        response.points.steps.append(step)
                        response.points.values.append(2.0 * step)
                        response.points.wall_times.add(seconds=1571084520 +
                                                       step,
                                                       nanos=862939144)
                    yield response

        # Wrapping the fakes in Mocks keeps their behavior while recording
        # the calls for the assertions below.
        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        # The same instant expressed as float seconds (passed to export())
        # and as a timestamp proto (expected in the outgoing requests).
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        # Before the generator is advanced, the output directory exists but
        # is empty and no rpc has been made.
        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append(os.path.join("experiment_123", "metadata.json"))
        expected_files.append(os.path.join("experiment_123", "scalars.json"))
        expected_files.append(os.path.join("experiment_123", "tensors.json"))
        # blob_sequences.json should exist and be empty.
        expected_files.append(
            os.path.join("experiment_123", "blob_sequences.json"))
        self.assertCountEqual(expected_files, _outdir_files(outdir))

        # Check that the tensors and blob_sequences data files are empty, because
        # there are no tensors or blob sequences.
        with open(os.path.join(outdir, "experiment_123",
                               "tensors.json")) as infile:
            self.assertEqual(infile.read(), "")
        with open(os.path.join(outdir, "experiment_123",
                               "blob_sequences.json")) as infile:
            self.assertEqual(infile.read(), "")

        # The experiment listing request should carry the read timestamp, the
        # maximum int64 limit, and a mask selecting the four metadata fields.
        expected_eids_request = export_service_pb2.StreamExperimentsRequest()
        expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
        expected_eids_request.limit = 2**63 - 1
        expected_eids_request.experiments_mask.create_time = True
        expected_eids_request.experiments_mask.update_time = True
        expected_eids_request.experiments_mask.name = True
        expected_eids_request.experiments_mask.description = True
        mock_api_client.StreamExperiments.assert_called_once_with(
            expected_eids_request, metadata=grpc_util.version_metadata())

        # This expected data request is reused below with only its
        # experiment ID changed.
        expected_data_request = export_service_pb2.StreamExperimentDataRequest(
        )
        expected_data_request.experiment_id = "123"
        expected_data_request.read_timestamp.CopyFrom(start_time_pb)
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The next iteration should just request data for the next experiment.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "456")

        expected_files.append(os.path.join("experiment_456", "metadata.json"))
        expected_files.append(os.path.join("experiment_456", "scalars.json"))
        expected_files.append(os.path.join("experiment_456", "tensors.json"))
        # blob_sequences.json should exist and be empty.
        expected_files.append(
            os.path.join("experiment_456", "blob_sequences.json"))
        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "456"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # Again, request data for the next experiment; this experiment ID
        # was in the second response batch in the list of IDs.
        expected_files.append(os.path.join("experiment_789", "metadata.json"))
        expected_files.append(os.path.join("experiment_789", "scalars.json"))
        expected_files.append(os.path.join("experiment_789", "tensors.json"))
        # blob_sequences.json should exist and be empty.
        expected_files.append(
            os.path.join("experiment_789", "blob_sequences.json"))
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "789")

        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "789"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The final continuation shouldn't need to send any RPCs.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(list(generator), [])

        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # Spot-check one of the scalar data files. The file holds one JSON
        # document per line; index 2 is the third (run, tag) pair produced
        # by the fake, i.e. ("test", "accuracy").
        with open(os.path.join(outdir, "experiment_456",
                               "scalars.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 4)
        datum = jsons[2]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "accuracy")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata("456:accuracy")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        expected_steps = [x for x in range(10)]
        expected_values = [2.0 * x for x in range(10)]
        expected_wall_times = [1571084520.862939144 + x for x in range(10)]
        self.assertEqual(points.pop("steps"), expected_steps)
        self.assertEqual(points.pop("values"), expected_values)
        self.assertEqual(points.pop("wall_times"), expected_wall_times)
        # Popping every expected key and then comparing against {} verifies
        # that no unexpected keys were written.
        self.assertEqual(points, {})
        self.assertEqual(datum, {})

        # Check that one of the blob_sequences data files is empty, because
        # there are no blob sequences in this experiment.
        with open(os.path.join(outdir, "experiment_456",
                               "blob_sequences.json")) as infile:
            self.assertEqual(infile.read(), "")

        # Spot-check one of the metadata files.
        with open(os.path.join(outdir, "experiment_789",
                               "metadata.json")) as infile:
            metadata = json.load(infile)
        self.assertEqual(
            metadata,
            {
                "name": "bert",
                "description": "ernie",
                "create_time": "2001-02-03T04:05:06Z",
                "update_time": "2002-03-04T05:06:07Z",
            },
        )
예제 #6
0
    def test_e2e_success_case_with_only_tensors_data(self):
        """Exports one mocked experiment that contains only tensor points.

        Checks the files written for the experiment (JSON data files plus
        one .npz file per streamed response), the RPCs issued against the
        mocked API client, two of the JSON records in tensors.json, and the
        arrays stored in the .npz files.
        """
        mock_api_client = self._create_mock_api_client()

        def stream_experiments(request, **kwargs):
            # Fake StreamExperiments rpc: yields a single experiment, "123".
            del request  # unused
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="123")
            yield response

        def stream_experiment_data(request, **kwargs):
            # Fake StreamExperimentData rpc: yields one response per
            # (run, tag) pair, i.e. six responses, each with two tensor
            # points. The wall-time nanos differ between the two runs, and
            # the "text/test" tag carries string tensors instead of numeric
            # ones.
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
            for run in ("train_1", "train_2"):
                for tag in ("dense_1/kernel", "dense_1/bias", "text/test"):
                    response = export_service_pb2.StreamExperimentDataResponse(
                    )
                    response.run_name = run
                    response.tag_name = tag
                    display_name = "%s:%s" % (request.experiment_id, tag)
                    response.tag_metadata.CopyFrom(
                        test_util.scalar_metadata(display_name))
                    for step in range(2):
                        response.tensors.steps.append(step)
                        response.tensors.wall_times.add(
                            seconds=1571084520 + step,
                            nanos=862939144 if run == "train_1" else 962939144,
                        )
                        if tag != "text/test":
                            response.tensors.values.append(
                                tensor_util.make_tensor_proto(
                                    np.ones([3, 2]) * step))
                        else:
                            response.tensors.values.append(
                                tensor_util.make_tensor_proto(
                                    np.full([3], "a" * (step + 1))))
                    yield response

        # Wrapping the fakes in Mocks keeps their behavior while recording
        # the calls for the assertions below.
        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        # The same instant expressed as float seconds (passed to export())
        # and as a timestamp proto (expected in the outgoing requests).
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        # Before the generator is advanced, the output directory exists but
        # is empty and no rpc has been made.
        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append(os.path.join("experiment_123", "metadata.json"))
        # scalars.json should exist and be empty.
        expected_files.append(os.path.join("experiment_123", "scalars.json"))
        expected_files.append(os.path.join("experiment_123", "tensors.json"))
        # blob_sequences.json should exist and be empty.
        expected_files.append(
            os.path.join("experiment_123", "blob_sequences.json"))
        # One .npz file is expected per streamed response. The three tags of
        # a run share the same wall-time-based base name, so the expected
        # names of the second and third carry a "_1" / "_2" suffix.
        expected_files.append(
            os.path.join("experiment_123", "tensors", "1571084520.862939.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.862939_1.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.862939_2.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors", "1571084520.962939.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.962939_1.npz"))
        expected_files.append(
            os.path.join("experiment_123", "tensors",
                         "1571084520.962939_2.npz"))
        self.assertCountEqual(expected_files, _outdir_files(outdir))

        # Check that the scalars and blob_sequences data files are empty, because
        # there are no scalars or blob sequences.
        with open(os.path.join(outdir, "experiment_123",
                               "scalars.json")) as infile:
            self.assertEqual(infile.read(), "")
        with open(os.path.join(outdir, "experiment_123",
                               "blob_sequences.json")) as infile:
            self.assertEqual(infile.read(), "")

        # The experiment listing request should carry the read timestamp, the
        # maximum int64 limit, and a mask selecting the four metadata fields.
        expected_eids_request = export_service_pb2.StreamExperimentsRequest()
        expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
        expected_eids_request.limit = 2**63 - 1
        expected_eids_request.experiments_mask.create_time = True
        expected_eids_request.experiments_mask.update_time = True
        expected_eids_request.experiments_mask.name = True
        expected_eids_request.experiments_mask.description = True
        mock_api_client.StreamExperiments.assert_called_once_with(
            expected_eids_request, metadata=grpc_util.version_metadata())

        expected_data_request = export_service_pb2.StreamExperimentDataRequest(
        )
        expected_data_request.experiment_id = "123"
        expected_data_request.read_timestamp.CopyFrom(start_time_pb)
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request, metadata=grpc_util.version_metadata())

        # The final StreamExperiments continuation shouldn't need to send any
        # RPCs.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(list(generator), [])

        # Check tensor data. The file holds one JSON document per line, one
        # for each of the six streamed responses.
        with open(os.path.join(outdir, "experiment_123",
                               "tensors.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 6)

        # First record: run "train_1", tag "dense_1/kernel".
        datum = jsons[0]
        self.assertEqual(datum.pop("run"), "train_1")
        self.assertEqual(datum.pop("tag"), "dense_1/kernel")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata(
            "123:dense_1/kernel")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(points.pop("steps"), [0, 1])
        self.assertEqual(
            points.pop("tensors_file_path"),
            os.path.join("tensors", "1571084520.862939.npz"),
        )
        # NOTE(review): only `datum` is checked for leftover keys here; any
        # remaining keys in `points` are not asserted on.
        self.assertEqual(datum, {})

        # Fifth record: run "train_2", tag "dense_1/bias".
        datum = jsons[4]
        self.assertEqual(datum.pop("run"), "train_2")
        self.assertEqual(datum.pop("tag"), "dense_1/bias")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata(
            "123:dense_1/bias")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(points.pop("steps"), [0, 1])
        self.assertEqual(
            points.pop("tensors_file_path"),
            os.path.join("tensors", "1571084520.962939_1.npz"),
        )
        self.assertEqual(datum, {})

        # Load and check the tensor data from the saved .npz files. These
        # four files belong to the numeric tags, so each should hold the two
        # [3, 2] arrays filled with the step number.
        for filename in (
                "1571084520.862939.npz",
                "1571084520.862939_1.npz",
                "1571084520.962939.npz",
                "1571084520.962939_1.npz",
        ):
            tensors = np.load(
                os.path.join(outdir, "experiment_123", "tensors", filename))
            tensors = [tensors[key] for key in tensors.keys()]
            self.assertLen(tensors, 2)
            np.testing.assert_array_equal(tensors[0], 0 * np.ones([3, 2]))
            np.testing.assert_array_equal(tensors[1], 1 * np.ones([3, 2]))

        # The "_2" files belong to the "text/test" tag and should hold the
        # string tensors as byte-string arrays.
        for filename in (
                "1571084520.862939_2.npz",
                "1571084520.962939_2.npz",
        ):
            tensors = np.load(
                os.path.join(outdir, "experiment_123", "tensors", filename))
            tensors = [tensors[key] for key in tensors.keys()]
            self.assertLen(tensors, 2)
            np.testing.assert_array_equal(tensors[0],
                                          np.array(["a", "a", "a"], "|S"))
            np.testing.assert_array_equal(tensors[1],
                                          np.array(["aa", "aa", "aa"], "|S"))
예제 #7
0
    def test_e2e_success_case(self):
        """Exports three mocked experiments and checks files and RPCs.

        Drives the `export()` generator one experiment at a time and, after
        each step, checks that a `scalars_<id>.json` file has appeared in
        the output directory and that the expected RPCs were issued against
        the mocked API client. Finally spot-checks the contents of one of
        the written files.
        """
        mock_api_client = self._create_mock_api_client()
        # NOTE(review): this return value is superseded further down, where
        # `StreamExperiments` is replaced by a Mock wrapping
        # `stream_experiments`.
        mock_api_client.StreamExperiments.return_value = iter([
            export_service_pb2.StreamExperimentsResponse(
                experiment_ids=["789"]),
        ])

        def stream_experiments(request):
            # Fake StreamExperiments rpc: yields the experiment IDs in two
            # batches.
            del request  # unused
            yield export_service_pb2.StreamExperimentsResponse(
                experiment_ids=["123", "456"])
            yield export_service_pb2.StreamExperimentsResponse(
                experiment_ids=["789"])

        def stream_experiment_data(request):
            # Fake StreamExperimentData rpc: yields one response per
            # (run, tag) pair, i.e. four responses, each with ten scalar
            # points whose value is twice the step.
            for run in ("train", "test"):
                for tag in ("accuracy", "loss"):
                    response = export_service_pb2.StreamExperimentDataResponse(
                    )
                    response.run_name = run
                    response.tag_name = tag
                    display_name = "%s:%s" % (request.experiment_id, tag)
                    response.tag_metadata.CopyFrom(
                        test_util.scalar_metadata(display_name))
                    for step in range(10):
                        response.points.steps.append(step)
                        response.points.values.append(2.0 * step)
                        response.points.wall_times.add(seconds=1571084520 +
                                                       step,
                                                       nanos=862939144)
                    yield response

        # Wrapping the fakes in Mocks keeps their behavior while recording
        # the calls for the assertions below.
        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        # The same instant expressed as float seconds (passed to export())
        # and as a timestamp proto (expected in the outgoing requests).
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        # Before the generator is advanced, the output directory exists but
        # is empty and no rpc has been made.
        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append("scalars_123.json")
        self.assertCountEqual(expected_files, os.listdir(outdir))

        # The experiment listing request should carry the read timestamp and
        # the maximum int64 limit.
        expected_eids_request = export_service_pb2.StreamExperimentsRequest()
        expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
        expected_eids_request.limit = 2**63 - 1
        mock_api_client.StreamExperiments.assert_called_once_with(
            expected_eids_request)

        # This expected data request is reused below with only its
        # experiment ID changed.
        expected_data_request = export_service_pb2.StreamExperimentDataRequest(
        )
        expected_data_request.experiment_id = "123"
        expected_data_request.read_timestamp.CopyFrom(start_time_pb)
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request)

        # The next iteration should just request data for the next experiment.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "456")

        expected_files.append("scalars_456.json")
        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "456"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request)

        # Again, request data for the next experiment; this experiment ID
        # was in the second response batch in the list of IDs.
        expected_files.append("scalars_789.json")
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(next(generator), "789")

        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        expected_data_request.experiment_id = "789"
        mock_api_client.StreamExperimentData.assert_called_once_with(
            expected_data_request)

        # The final continuation shouldn't need to send any RPCs.
        mock_api_client.StreamExperiments.reset_mock()
        mock_api_client.StreamExperimentData.reset_mock()
        self.assertEqual(list(generator), [])

        self.assertCountEqual(expected_files, os.listdir(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # Spot-check one of the files. The file holds one JSON document per
        # line; index 2 is the third (run, tag) pair produced by the fake,
        # i.e. ("test", "accuracy").
        with open(os.path.join(outdir, "scalars_456.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 4)
        datum = jsons[2]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "accuracy")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = test_util.scalar_metadata("456:accuracy")
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        expected_steps = [x for x in range(10)]
        expected_values = [2.0 * x for x in range(10)]
        expected_wall_times = [1571084520.862939144 + x for x in range(10)]
        self.assertEqual(points.pop("steps"), expected_steps)
        self.assertEqual(points.pop("values"), expected_values)
        self.assertEqual(points.pop("wall_times"), expected_wall_times)
        # Popping every expected key and then comparing against {} verifies
        # that no unexpected keys were written.
        self.assertEqual(points, {})
        self.assertEqual(datum, {})