Code example #1
File: uploader_test.py Project: doc22940/tensorboard
    def test_unauthorized(self):
        mock_client = _create_mock_client()
        error = test_util.grpc_error(grpc.StatusCode.PERMISSION_DENIED, "nope")
        mock_client.DeleteExperiment.side_effect = error

        with self.assertRaises(uploader_lib.PermissionDeniedError):
            uploader_lib.delete_experiment(mock_client, "123")
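
These tests fabricate gRPC failures with a `test_util.grpc_error` helper whose implementation is not shown here. A minimal sketch of what such a helper might look like (hypothetical implementation; the real one lives in TensorBoard's test utilities):

    import grpc

    def grpc_error(status_code, details_string):
        # Error handlers in the code under test call `code()` and
        # `details()` on the raised error, so the fake implements both.
        class _FakeRpcError(grpc.RpcError):
            def code(self):
                return status_code

            def details(self):
                return details_string

        return _FakeRpcError()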
Code example #2
File: uploader_test.py Project: doc22940/tensorboard
    def test_not_found(self):
        mock_client = _create_mock_client()
        error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
        mock_client.DeleteExperiment.side_effect = error

        with self.assertRaises(uploader_lib.ExperimentNotFoundError):
            uploader_lib.delete_experiment(mock_client, "123")
Code example #3
File: uploader_test.py Project: doc22940/tensorboard
    def test_internal_error(self):
        mock_client = _create_mock_client()
        error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "travesty")
        mock_client.DeleteExperiment.side_effect = error

        with self.assertRaises(grpc.RpcError) as cm:
            uploader_lib.delete_experiment(mock_client, "123")
        msg = str(cm.exception)
        self.assertIn("travesty", msg)
Code example #4
    def test_upload_propagates_experiment_deletion(self):
        logdir = self.get_temp_dir()
        with tb_test_util.FileWriter(logdir) as writer:
            writer.add_test_summary("foo")
        mock_client = _create_mock_client()
        uploader = _create_uploader(mock_client, logdir)
        uploader.create_experiment()
        error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
        mock_client.WriteScalar.side_effect = error
        with self.assertRaises(uploader_lib.ExperimentNotFoundError):
            uploader._upload_once()
Code example #5
    def test_upload_swallows_rpc_failure(self):
        logdir = self.get_temp_dir()
        with tb_test_util.FileWriter(logdir) as writer:
            writer.add_test_summary("foo")
        mock_client = _create_mock_client()
        uploader = _create_uploader(mock_client, logdir)
        uploader.create_experiment()
        error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "Failure")
        mock_client.WriteScalar.side_effect = error
        uploader._upload_once()
        mock_client.WriteScalar.assert_called_once()
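
Examples #4 and #5 together show the error policy of `_upload_once`: a NOT_FOUND from `WriteScalar` means the experiment was deleted server-side and is surfaced as `ExperimentNotFoundError`, while other RPC failures are swallowed so the next upload cycle can retry. A hedged sketch of that handling (attribute and helper names are assumptions, not the actual implementation):

    def _upload_once(self):
        request = self._build_write_scalar_request()  # hypothetical helper
        try:
            self._api.WriteScalar(request)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                # The experiment is gone; retrying is pointless.
                raise ExperimentNotFoundError()
            # Transient failures (e.g. INTERNAL) are logged and swallowed;
            # the next upload cycle will simply try again.
            logger.warning("WriteScalar failed: %s", e)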
Code example #6
    def test_invalid_argument(self):
        mock_client = _create_mock_client()
        error = test_util.grpc_error(
            grpc.StatusCode.INVALID_ARGUMENT, "too many"
        )
        mock_client.UpdateExperiment.side_effect = error

        with self.assertRaises(uploader_lib.InvalidArgumentError) as cm:
            uploader_lib.update_experiment_metadata(mock_client, "123", name="")
        msg = str(cm.exception)
        self.assertIn("too many", msg)
Code example #7
    def test_upload_server_error(self):
        mock_client = _create_mock_client()
        mock_rate_limiter = mock.create_autospec(util.RateLimiter)
        mock_blob_rate_limiter = mock.create_autospec(util.RateLimiter)
        uploader = _create_uploader(
            mock_client,
            "/logs/foo",
            rpc_rate_limiter=mock_rate_limiter,
            blob_rpc_rate_limiter=mock_blob_rate_limiter,
            allowed_plugins=[
                scalars_metadata.PLUGIN_NAME,
                graphs_metadata.PLUGIN_NAME,
            ],
        )
        uploader.create_experiment()

        # Of course a real Event stream will never produce the same Event
        # twice, but in this test context it's fine to reuse this one.
        graph_event = event_pb2.Event(
            graph_def=_create_example_graph_bytes(950))

        mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader)
        mock_logdir_loader.get_run_events.side_effect = [
            {
                "run 1": [graph_event],
            },
            {
                "run 1": [graph_event],
            },
            AbortUploadError,
        ]

        mock_client.WriteBlob.side_effect = [
            [write_service_pb2.WriteBlobResponse()],
            test_util.grpc_error(grpc.StatusCode.INTERNAL, "nope"),
        ]

        # This demonstrates that the INTERNAL error is NOT handled, so the
        # uploader will die if this happens.
        with mock.patch.object(
                uploader, "_logdir_loader",
                mock_logdir_loader), self.assertRaises(grpc.RpcError):
            uploader.start_uploading()
        self.assertEqual(1, mock_client.CreateExperiment.call_count)
        self.assertEqual(2, mock_client.WriteBlob.call_count)
        self.assertEqual(0, mock_rate_limiter.tick.call_count)
        self.assertEqual(2, mock_blob_rate_limiter.tick.call_count)
Code example #8
    def test_upload_same_graph_twice(self):
        mock_client = _create_mock_client()
        mock_rate_limiter = mock.create_autospec(util.RateLimiter)
        mock_blob_rate_limiter = mock.create_autospec(util.RateLimiter)
        uploader = _create_uploader(
            mock_client,
            "/logs/foo",
            rpc_rate_limiter=mock_rate_limiter,
            blob_rpc_rate_limiter=mock_blob_rate_limiter,
            allowed_plugins=[
                scalars_metadata.PLUGIN_NAME,
                graphs_metadata.PLUGIN_NAME,
            ],
        )
        uploader.create_experiment()

        graph_event = event_pb2.Event(
            graph_def=_create_example_graph_bytes(950))

        mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader)
        mock_logdir_loader.get_run_events.side_effect = [
            {
                "run 1": [graph_event],
            },
            {
                "run 1": [graph_event],
            },
            AbortUploadError,
        ]

        mock_client.WriteBlob.side_effect = [
            [write_service_pb2.WriteBlobResponse()],
            test_util.grpc_error(grpc.StatusCode.ALREADY_EXISTS, "nope"),
        ]

        # This demonstrates that the ALREADY_EXISTS error is handled gracefully.
        with mock.patch.object(
                uploader, "_logdir_loader",
                mock_logdir_loader), self.assertRaises(AbortUploadError):
            uploader.start_uploading()
        self.assertEqual(1, mock_client.CreateExperiment.call_count)
        self.assertEqual(2, mock_client.WriteBlob.call_count)
        self.assertEqual(0, mock_rate_limiter.tick.call_count)
        self.assertEqual(2, mock_blob_rate_limiter.tick.call_count)
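
Examples #7 and #8 contrast two `WriteBlob` failure modes: INTERNAL is re-raised and kills the uploader, while ALREADY_EXISTS (the server has already stored this blob) is treated as success. A hedged sketch of that distinction (method and attribute names are assumptions):

    def _write_blob(self, request_iterator):
        try:
            for _ in self._api.WriteBlob(request_iterator):
                pass  # drain the streaming response
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                # The blob was deduplicated server-side; uploading the
                # same graph twice is not an error.
                return
            # Any other failure (e.g. INTERNAL) propagates and aborts
            # the upload loop.
            raise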
Code example #9
    def stream_experiment_data(request, **kwargs):
        del request  # unused
        raise test_util.grpc_error(grpc.StatusCode.INTERNAL, "details string")
Code example #10
    def stream_experiment_data(request, **kwargs):
        raise test_util.grpc_error(grpc.StatusCode.CANCELLED, "details string")
Code example #11
    def test_e2e_success_case_with_blob_sequence_data(self):
        """Covers exporting of complete and incomplete blob sequences

        as well as rpc error during blob streaming.
        """
        mock_api_client = self._create_mock_api_client()

        def stream_experiments(request, **kwargs):
            del request  # unused
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="123")
            yield response
            response = export_service_pb2.StreamExperimentsResponse()
            response.experiments.add(experiment_id="456")
            yield response

        def stream_experiment_data(request, **kwargs):
            self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())

            tag = "__default_graph__"
            for run in ("train", "test"):
                response = export_service_pb2.StreamExperimentDataResponse()
                response.run_name = run
                response.tag_name = tag
                display_name = "%s:%s" % (request.experiment_id, tag)
                response.tag_metadata.CopyFrom(
                    summary_pb2.SummaryMetadata(
                        data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE))
                for step in range(1):
                    response.blob_sequences.steps.append(step)
                    response.blob_sequences.wall_times.add(
                        seconds=1571084520 + step, nanos=862939144)
                    blob_sequence = blob_pb2.BlobSequence()
                    if run == "train":
                        # A finished blob sequence.
                        blob = blob_pb2.Blob(
                            blob_id="%s_blob" % run,
                            state=blob_pb2.BlobState.BLOB_STATE_CURRENT,
                        )
                        blob_sequence.entries.append(
                            blob_pb2.BlobSequenceEntry(blob=blob))
                        # An unfinished blob sequence.
                        blob = blob_pb2.Blob(
                            state=blob_pb2.BlobState.BLOB_STATE_UNFINALIZED)
                        blob_sequence.entries.append(
                            blob_pb2.BlobSequenceEntry(blob=blob))
                    elif run == "test":
                        blob_sequence.entries.append(
                            # `blob` unspecified: a hole in the blob sequence.
                            blob_pb2.BlobSequenceEntry())
                    response.blob_sequences.values.append(blob_sequence)
                yield response

        mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
        mock_api_client.StreamExperimentData = mock.Mock(
            wraps=stream_experiment_data)
        mock_api_client.StreamBlobData.side_effect = [
            iter([
                export_service_pb2.StreamBlobDataResponse(
                    data=b"4321",
                    offset=0,
                    final_chunk=False,
                ),
                export_service_pb2.StreamBlobDataResponse(
                    data=b"8765",
                    offset=4,
                    final_chunk=True,
                ),
            ]),
            # Raise error from `StreamBlobData` to test the grpc-error
            # condition.
            test_util.grpc_error(grpc.StatusCode.INTERNAL,
                                 "Error for testing"),
        ]

        outdir = os.path.join(self.get_temp_dir(), "outdir")
        exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
        start_time = 1571084846.25
        start_time_pb = test_util.timestamp_pb(1571084846250000000)

        generator = exporter.export(read_time=start_time)
        expected_files = []
        self.assertTrue(os.path.isdir(outdir))
        self.assertCountEqual(expected_files, _outdir_files(outdir))
        mock_api_client.StreamExperiments.assert_not_called()
        mock_api_client.StreamExperimentData.assert_not_called()

        # The first iteration should request the list of experiments and
        # data for one of them.
        self.assertEqual(next(generator), "123")
        expected_files.append(os.path.join("experiment_123", "metadata.json"))
        # scalars.json and tensors.json should exist and be empty.
        expected_files.append(os.path.join("experiment_123", "scalars.json"))
        expected_files.append(os.path.join("experiment_123", "tensors.json"))
        expected_files.append(
            os.path.join("experiment_123", "blob_sequences.json"))
        expected_files.append(
            os.path.join("experiment_123", "blobs", "blob_train_blob.bin"))
        # blobs/blob_test_blob.bin should not exist, because it contains
        # an unfinished blob.
        self.assertCountEqual(expected_files, _outdir_files(outdir))

        # Check that the scalars and tensors data files are empty, because
        # there are no scalars or tensors.
        with open(os.path.join(outdir, "experiment_123",
                               "scalars.json")) as infile:
            self.assertEqual(infile.read(), "")
        with open(os.path.join(outdir, "experiment_123",
                               "tensors.json")) as infile:
            self.assertEqual(infile.read(), "")

        # Check the blob_sequences.json file.
        with open(os.path.join(outdir, "experiment_123",
                               "blob_sequences.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 2)

        datum = jsons[0]
        self.assertEqual(datum.pop("run"), "train")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        expected_summary_metadata = summary_pb2.SummaryMetadata(
            data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE)
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # The 1st blob is finished; the 2nd is unfinished.
        self.assertEqual(points.pop("blob_file_paths"),
                         [["blobs/blob_train_blob.bin", None]])
        self.assertEqual(points, {})

        datum = jsons[1]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # `None` blob file path indicates an unfinished blob.
        self.assertEqual(points.pop("blob_file_paths"), [[None]])
        self.assertEqual(points, {})

        # Check the BLOB files.
        with open(
                os.path.join(outdir, "experiment_123", "blobs",
                             "blob_train_blob.bin"),
                "rb",
        ) as f:
            self.assertEqual(f.read(), b"43218765")

        # Check call to StreamBlobData.
        expected_blob_data_request = export_service_pb2.StreamBlobDataRequest(
            blob_id="train_blob")
        mock_api_client.StreamBlobData.assert_called_once_with(
            expected_blob_data_request, metadata=grpc_util.version_metadata())

        # Test the case where blob streaming errors out.
        self.assertEqual(next(generator), "456")
        # Check the blob_sequences.json file.
        with open(os.path.join(outdir, "experiment_456",
                               "blob_sequences.json")) as infile:
            jsons = [json.loads(line) for line in infile]
        self.assertLen(jsons, 2)

        datum = jsons[0]
        self.assertEqual(datum.pop("run"), "train")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # `None` represents the blob that experienced error during downloading
        # and hence is missing.
        self.assertEqual(points.pop("blob_file_paths"), [[None, None]])
        self.assertEqual(points, {})

        datum = jsons[1]
        self.assertEqual(datum.pop("run"), "test")
        self.assertEqual(datum.pop("tag"), "__default_graph__")
        summary_metadata = summary_pb2.SummaryMetadata.FromString(
            base64.b64decode(datum.pop("summary_metadata")))
        self.assertEqual(summary_metadata, expected_summary_metadata)
        points = datum.pop("points")
        self.assertEqual(datum, {})
        self.assertEqual(points.pop("steps"), [0])
        self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
        # `None` represents the blob that experienced error during downloading
        # and hence is missing.
        self.assertEqual(points.pop("blob_file_paths"), [[None]])
        self.assertEqual(points, {})
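
For reference, a caller drives the exporter by iterating the generator returned by `export()`, one experiment ID at a time. A minimal usage sketch (the output directory and `read_time` value are placeholders):

    import time

    exporter = exporter_lib.TensorBoardExporter(api_client, "/tmp/outdir")
    for experiment_id in exporter.export(read_time=time.time()):
        # Each yielded ID corresponds to a fully written
        # experiment_<id>/ directory under the output directory.
        print("Exported experiment", experiment_id)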