def test_e2e_success_case(self):
    """End-to-end export of three experiments that contain only scalar data.

    The fake `StreamExperiments` returns the experiment IDs in two response
    batches: "123" and "456" in the first, and "789" (with name,
    description and timestamps) in the second. The test advances the export
    generator one experiment at a time. At each step it checks which files
    exist under the output directory and which RPCs were issued. It then
    spot-checks one scalars file and one metadata file.

    NOTE(review): another method named `test_e2e_success_case` follows this
    one in the file. If both are defined in the same class, the later
    definition replaces this one and this test never runs. Confirm which
    version should be kept.
    """
    mock_api_client = self._create_mock_api_client()

    def stream_experiments(request, **kwargs):
        del request  # unused
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        # First batch: two experiments with only their IDs set.
        response = export_service_pb2.StreamExperimentsResponse()
        response.experiments.add(experiment_id="123")
        response.experiments.add(experiment_id="456")
        yield response
        # Second batch: one experiment with full metadata.
        response = export_service_pb2.StreamExperimentsResponse()
        experiment = response.experiments.add()
        experiment.experiment_id = "789"
        experiment.name = "bert"
        experiment.description = "ernie"
        util.set_timestamp(experiment.create_time, 981173106)
        util.set_timestamp(experiment.update_time, 1015218367)
        yield response

    def stream_experiment_data(request, **kwargs):
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        # Yields (train, accuracy), (train, loss), (test, accuracy),
        # (test, loss), each with 10 points: value = 2 * step.
        for run in ("train", "test"):
            for tag in ("accuracy", "loss"):
                response = export_service_pb2.StreamExperimentDataResponse()
                response.run_name = run
                response.tag_name = tag
                display_name = "%s:%s" % (request.experiment_id, tag)
                response.tag_metadata.CopyFrom(
                    test_util.scalar_metadata(display_name))
                for step in range(10):
                    response.points.steps.append(step)
                    response.points.values.append(2.0 * step)
                    response.points.wall_times.add(
                        seconds=1571084520 + step, nanos=862939144)
                yield response

    # Replace the client's RPC attributes with Mocks that delegate to the
    # fakes above, so that calls can be asserted on. Any `return_value`
    # configured on the previous attributes would be discarded here, so none
    # is set.
    mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
    mock_api_client.StreamExperimentData = mock.Mock(
        wraps=stream_experiment_data)

    outdir = os.path.join(self.get_temp_dir(), "outdir")
    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
    start_time = 1571084846.25
    start_time_pb = test_util.timestamp_pb(1571084846250000000)

    # Before the generator is advanced, the output directory exists but is
    # empty and no RPCs have been sent.
    generator = exporter.export(read_time=start_time)
    expected_files = []
    self.assertTrue(os.path.isdir(outdir))
    self.assertCountEqual(expected_files, _outdir_files(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # The first iteration should request the list of experiments and
    # data for one of them.
    self.assertEqual(next(generator), "123")
    expected_files.append(os.path.join("experiment_123", "metadata.json"))
    expected_files.append(os.path.join("experiment_123", "scalars.json"))
    self.assertCountEqual(expected_files, _outdir_files(outdir))

    expected_eids_request = export_service_pb2.StreamExperimentsRequest()
    expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
    expected_eids_request.limit = 2**63 - 1  # largest signed 64-bit value
    expected_eids_request.experiments_mask.create_time = True
    expected_eids_request.experiments_mask.update_time = True
    expected_eids_request.experiments_mask.name = True
    expected_eids_request.experiments_mask.description = True
    mock_api_client.StreamExperiments.assert_called_once_with(
        expected_eids_request, metadata=grpc_util.version_metadata())

    expected_data_request = export_service_pb2.StreamExperimentDataRequest()
    expected_data_request.experiment_id = "123"
    expected_data_request.read_timestamp.CopyFrom(start_time_pb)
    mock_api_client.StreamExperimentData.assert_called_once_with(
        expected_data_request, metadata=grpc_util.version_metadata())

    # The next iteration should just request data for the next experiment.
    mock_api_client.StreamExperiments.reset_mock()
    mock_api_client.StreamExperimentData.reset_mock()
    self.assertEqual(next(generator), "456")
    expected_files.append(os.path.join("experiment_456", "metadata.json"))
    expected_files.append(os.path.join("experiment_456", "scalars.json"))
    self.assertCountEqual(expected_files, _outdir_files(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    expected_data_request.experiment_id = "456"
    mock_api_client.StreamExperimentData.assert_called_once_with(
        expected_data_request, metadata=grpc_util.version_metadata())

    # Again, request data for the next experiment; this experiment ID
    # was in the second response batch in the list of IDs.
    mock_api_client.StreamExperiments.reset_mock()
    mock_api_client.StreamExperimentData.reset_mock()
    self.assertEqual(next(generator), "789")
    expected_files.append(os.path.join("experiment_789", "metadata.json"))
    expected_files.append(os.path.join("experiment_789", "scalars.json"))
    self.assertCountEqual(expected_files, _outdir_files(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    expected_data_request.experiment_id = "789"
    mock_api_client.StreamExperimentData.assert_called_once_with(
        expected_data_request, metadata=grpc_util.version_metadata())

    # The final continuation shouldn't need to send any RPCs.
    mock_api_client.StreamExperiments.reset_mock()
    mock_api_client.StreamExperimentData.reset_mock()
    self.assertEqual(list(generator), [])
    self.assertCountEqual(expected_files, _outdir_files(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # Spot-check one of the scalar data files. The fake data stream yields
    # train/accuracy, train/loss, test/accuracy, test/loss, so record 2 is
    # test/accuracy.
    with open(os.path.join(outdir, "experiment_456", "scalars.json")) as infile:
        jsons = [json.loads(line) for line in infile]
    self.assertLen(jsons, 4)
    datum = jsons[2]
    self.assertEqual(datum.pop("run"), "test")
    self.assertEqual(datum.pop("tag"), "accuracy")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    expected_summary_metadata = test_util.scalar_metadata("456:accuracy")
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    expected_steps = list(range(10))
    expected_values = [2.0 * x for x in range(10)]
    expected_wall_times = [1571084520.862939144 + x for x in range(10)]
    self.assertEqual(points.pop("steps"), expected_steps)
    self.assertEqual(points.pop("values"), expected_values)
    self.assertEqual(points.pop("wall_times"), expected_wall_times)
    # Every key in the record must have been checked above.
    self.assertEqual(points, {})
    self.assertEqual(datum, {})

    # Spot-check one of the metadata files.
    with open(os.path.join(outdir, "experiment_789", "metadata.json")) as infile:
        metadata = json.load(infile)
    self.assertEqual(
        metadata,
        {
            "name": "bert",
            "description": "ernie",
            "create_time": "2001-02-03T04:05:06Z",
            "update_time": "2002-03-04T05:06:07Z",
        },
    )
def test_e2e_success_case(self):
    """End-to-end export of three scalar-only experiments (flat file layout).

    This variant populates the `experiment_ids` field of
    `StreamExperimentsResponse`, in two batches: ["123", "456"] and then
    ["789"]. It expects one flat `scalars_<id>.json` file per experiment
    directly under the output directory.

    NOTE(review): this method has the same name as an earlier
    `test_e2e_success_case` in the file. If both are in the same class,
    this later definition replaces the earlier one, so only this version
    runs. The two expect different output layouts (`scalars_<id>.json`
    vs. `experiment_<id>/scalars.json`). Confirm which one matches the
    current exporter, then remove or rename the other.
    """
    mock_api_client = self._create_mock_api_client()

    def stream_experiments(request, **kwargs):
        del request  # unused
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        yield export_service_pb2.StreamExperimentsResponse(
            experiment_ids=["123", "456"])
        yield export_service_pb2.StreamExperimentsResponse(
            experiment_ids=["789"])

    def stream_experiment_data(request, **kwargs):
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        # Yields (train, accuracy), (train, loss), (test, accuracy),
        # (test, loss), each with 10 points: value = 2 * step.
        for run in ("train", "test"):
            for tag in ("accuracy", "loss"):
                response = export_service_pb2.StreamExperimentDataResponse()
                response.run_name = run
                response.tag_name = tag
                display_name = "%s:%s" % (request.experiment_id, tag)
                response.tag_metadata.CopyFrom(
                    test_util.scalar_metadata(display_name))
                for step in range(10):
                    response.points.steps.append(step)
                    response.points.values.append(2.0 * step)
                    response.points.wall_times.add(
                        seconds=1571084520 + step, nanos=862939144)
                yield response

    # Replace the client's RPC attributes with Mocks that delegate to the
    # fakes above, so that calls can be asserted on. Any `return_value`
    # configured on the previous attributes would be discarded here, so none
    # is set.
    mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
    mock_api_client.StreamExperimentData = mock.Mock(
        wraps=stream_experiment_data)

    outdir = os.path.join(self.get_temp_dir(), "outdir")
    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
    start_time = 1571084846.25
    start_time_pb = test_util.timestamp_pb(1571084846250000000)

    # Before the generator is advanced, the output directory exists but is
    # empty and no RPCs have been sent.
    generator = exporter.export(read_time=start_time)
    expected_files = []
    self.assertTrue(os.path.isdir(outdir))
    self.assertCountEqual(expected_files, os.listdir(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # The first iteration should request the list of experiments and
    # data for one of them.
    self.assertEqual(next(generator), "123")
    expected_files.append("scalars_123.json")
    self.assertCountEqual(expected_files, os.listdir(outdir))

    expected_eids_request = export_service_pb2.StreamExperimentsRequest()
    expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
    expected_eids_request.limit = 2**63 - 1  # largest signed 64-bit value
    mock_api_client.StreamExperiments.assert_called_once_with(
        expected_eids_request, metadata=grpc_util.version_metadata())

    expected_data_request = export_service_pb2.StreamExperimentDataRequest()
    expected_data_request.experiment_id = "123"
    expected_data_request.read_timestamp.CopyFrom(start_time_pb)
    mock_api_client.StreamExperimentData.assert_called_once_with(
        expected_data_request, metadata=grpc_util.version_metadata())

    # The next iteration should just request data for the next experiment.
    mock_api_client.StreamExperiments.reset_mock()
    mock_api_client.StreamExperimentData.reset_mock()
    self.assertEqual(next(generator), "456")
    expected_files.append("scalars_456.json")
    self.assertCountEqual(expected_files, os.listdir(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    expected_data_request.experiment_id = "456"
    mock_api_client.StreamExperimentData.assert_called_once_with(
        expected_data_request, metadata=grpc_util.version_metadata())

    # Again, request data for the next experiment; this experiment ID
    # was in the second response batch in the list of IDs.
    mock_api_client.StreamExperiments.reset_mock()
    mock_api_client.StreamExperimentData.reset_mock()
    self.assertEqual(next(generator), "789")
    expected_files.append("scalars_789.json")
    self.assertCountEqual(expected_files, os.listdir(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    expected_data_request.experiment_id = "789"
    mock_api_client.StreamExperimentData.assert_called_once_with(
        expected_data_request, metadata=grpc_util.version_metadata())

    # The final continuation shouldn't need to send any RPCs.
    mock_api_client.StreamExperiments.reset_mock()
    mock_api_client.StreamExperimentData.reset_mock()
    self.assertEqual(list(generator), [])
    self.assertCountEqual(expected_files, os.listdir(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # Spot-check one of the files. The fake data stream yields
    # train/accuracy, train/loss, test/accuracy, test/loss, so record 2 is
    # test/accuracy.
    with open(os.path.join(outdir, "scalars_456.json")) as infile:
        jsons = [json.loads(line) for line in infile]
    self.assertLen(jsons, 4)
    datum = jsons[2]
    self.assertEqual(datum.pop("run"), "test")
    self.assertEqual(datum.pop("tag"), "accuracy")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    expected_summary_metadata = test_util.scalar_metadata("456:accuracy")
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    expected_steps = list(range(10))
    expected_values = [2.0 * x for x in range(10)]
    expected_wall_times = [1571084520.862939144 + x for x in range(10)]
    self.assertEqual(points.pop("steps"), expected_steps)
    self.assertEqual(points.pop("values"), expected_values)
    self.assertEqual(points.pop("wall_times"), expected_wall_times)
    # Every key in the record must have been checked above.
    self.assertEqual(points, {})
    self.assertEqual(datum, {})
def test_e2e_success_case_with_blob_sequence_data(self):
    """Covers exporting of complete and incomplete blob sequences as well as
    an RPC error during blob streaming.

    Both experiments ("123" and "456") have one blob-sequence time series
    per run, each with a single step:

    * "train": a finalized blob ("train_blob") followed by an unfinalized
      blob that has no blob ID;
    * "test": a single entry with no blob at all (a hole in the sequence).

    The first `StreamBlobData` call (experiment "123") succeeds in two
    chunks. The second call (experiment "456") raises an INTERNAL gRPC
    error. Every blob that is not written to disk is expected to appear as
    a `None` file path in blob_sequences.json.
    """
    mock_api_client = self._create_mock_api_client()

    def stream_experiments(request, **kwargs):
        del request  # unused
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        response = export_service_pb2.StreamExperimentsResponse()
        response.experiments.add(experiment_id="123")
        yield response
        response = export_service_pb2.StreamExperimentsResponse()
        response.experiments.add(experiment_id="456")
        yield response

    def stream_experiment_data(request, **kwargs):
        del request  # unused: both experiments get identical data
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        tag = "__default_graph__"
        for run in ("train", "test"):
            response = export_service_pb2.StreamExperimentDataResponse()
            response.run_name = run
            response.tag_name = tag
            response.tag_metadata.CopyFrom(
                summary_pb2.SummaryMetadata(
                    data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE))
            for step in range(1):
                response.blob_sequences.steps.append(step)
                response.blob_sequences.wall_times.add(
                    seconds=1571084520 + step, nanos=862939144)
                blob_sequence = blob_pb2.BlobSequence()
                if run == "train":
                    # A finalized blob.
                    blob = blob_pb2.Blob(
                        blob_id="%s_blob" % run,
                        state=blob_pb2.BlobState.BLOB_STATE_CURRENT,
                    )
                    blob_sequence.entries.append(
                        blob_pb2.BlobSequenceEntry(blob=blob))
                    # An unfinalized blob (no blob ID).
                    blob = blob_pb2.Blob(
                        state=blob_pb2.BlobState.BLOB_STATE_UNFINALIZED,
                    )
                    blob_sequence.entries.append(
                        blob_pb2.BlobSequenceEntry(blob=blob))
                elif run == "test":
                    blob_sequence.entries.append(
                        # `blob` unspecified: a hole in the blob sequence.
                        blob_pb2.BlobSequenceEntry())
                response.blob_sequences.values.append(blob_sequence)
            yield response

    mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
    mock_api_client.StreamExperimentData = mock.Mock(
        wraps=stream_experiment_data)
    mock_api_client.StreamBlobData.side_effect = [
        # First call ("train_blob" of experiment "123"): two chunks.
        iter([
            export_service_pb2.StreamBlobDataResponse(
                data=b"4321",
                offset=0,
                final_chunk=False,
            ),
            export_service_pb2.StreamBlobDataResponse(
                data=b"8765",
                offset=4,
                final_chunk=True,
            ),
        ]),
        # Second call ("train_blob" of experiment "456"): raise an error
        # from `StreamBlobData` to test the gRPC-error condition.
        test_util.grpc_error(grpc.StatusCode.INTERNAL, "Error for testing"),
    ]

    outdir = os.path.join(self.get_temp_dir(), "outdir")
    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
    start_time = 1571084846.25
    start_time_pb = test_util.timestamp_pb(1571084846250000000)

    generator = exporter.export(read_time=start_time)
    expected_files = []
    self.assertTrue(os.path.isdir(outdir))
    self.assertCountEqual(expected_files, _outdir_files(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # The first iteration should request the list of experiments and
    # data for one of them.
    self.assertEqual(next(generator), "123")
    expected_files.append(os.path.join("experiment_123", "metadata.json"))
    # scalars.json should exist and be empty.
    expected_files.append(os.path.join("experiment_123", "scalars.json"))
    expected_files.append(
        os.path.join("experiment_123", "blob_sequences.json"))
    # Only the finalized "train_blob" is written. The unfinalized blob has
    # no blob ID and the "test" entry is a hole, so neither produces a file.
    expected_files.append(
        os.path.join("experiment_123", "blobs", "blob_train_blob.bin"))
    self.assertCountEqual(expected_files, _outdir_files(outdir))

    # Check that the scalars data file is empty, because there are no
    # scalars.
    with open(os.path.join(outdir, "experiment_123", "scalars.json")) as infile:
        self.assertEqual(infile.read(), "")

    # Check the blob_sequences.json file.
    blob_sequences_path = os.path.join(
        outdir, "experiment_123", "blob_sequences.json")
    with open(blob_sequences_path) as infile:
        jsons = [json.loads(line) for line in infile]
    self.assertLen(jsons, 2)

    datum = jsons[0]
    self.assertEqual(datum.pop("run"), "train")
    self.assertEqual(datum.pop("tag"), "__default_graph__")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    expected_summary_metadata = summary_pb2.SummaryMetadata(
        data_class=summary_pb2.DATA_CLASS_BLOB_SEQUENCE)
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    self.assertEqual(datum, {})
    self.assertEqual(points.pop("steps"), [0])
    self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
    # The 1st blob is finalized and downloaded; the 2nd is unfinalized, so
    # it has no file.
    self.assertEqual(points.pop("blob_file_paths"),
                     [["blobs/blob_train_blob.bin", None]])
    self.assertEqual(points, {})

    datum = jsons[1]
    self.assertEqual(datum.pop("run"), "test")
    self.assertEqual(datum.pop("tag"), "__default_graph__")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    self.assertEqual(datum, {})
    self.assertEqual(points.pop("steps"), [0])
    self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
    # The only entry is a hole (no blob), exported as a `None` file path.
    self.assertEqual(points.pop("blob_file_paths"), [[None]])
    self.assertEqual(points, {})

    # Check the BLOB files: the two streamed chunks are concatenated.
    with open(
            os.path.join(outdir, "experiment_123", "blobs",
                         "blob_train_blob.bin"),
            "rb",
    ) as f:
        self.assertEqual(f.read(), b"43218765")

    # Check call to StreamBlobData.
    expected_blob_data_request = export_service_pb2.StreamBlobDataRequest(
        blob_id="train_blob")
    mock_api_client.StreamBlobData.assert_called_once_with(
        expected_blob_data_request, metadata=grpc_util.version_metadata())

    # Test the case where blob streaming errors out.
    self.assertEqual(next(generator), "456")

    # Check the blob_sequences.json file.
    blob_sequences_path = os.path.join(
        outdir, "experiment_456", "blob_sequences.json")
    with open(blob_sequences_path) as infile:
        jsons = [json.loads(line) for line in infile]
    self.assertLen(jsons, 2)

    datum = jsons[0]
    self.assertEqual(datum.pop("run"), "train")
    self.assertEqual(datum.pop("tag"), "__default_graph__")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    self.assertEqual(datum, {})
    self.assertEqual(points.pop("steps"), [0])
    self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
    # The 1st `None` is "train_blob", whose download failed with the RPC
    # error; the 2nd is the unfinalized blob.
    self.assertEqual(points.pop("blob_file_paths"), [[None, None]])
    self.assertEqual(points, {})

    datum = jsons[1]
    self.assertEqual(datum.pop("run"), "test")
    self.assertEqual(datum.pop("tag"), "__default_graph__")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    self.assertEqual(datum, {})
    self.assertEqual(points.pop("steps"), [0])
    self.assertEqual(points.pop("wall_times"), [1571084520.862939144])
    # As in experiment "123", the only entry is a hole (no blob).
    self.assertEqual(points.pop("blob_file_paths"), [[None]])
    self.assertEqual(points, {})
def test_e2e_success_case_with_only_tensors_data(self):
    """End-to-end export of a single experiment that has only tensor data.

    Runs "train_1" and "train_2" each have three tags with two steps:

    * "dense_1/kernel" and "dense_1/bias": step `s` holds
      `np.ones([3, 2]) * s`;
    * "text/test": step `s` holds a length-3 array of the string
      `"a" * (s + 1)`.

    The two runs use different wall-time nanoseconds. The expected `.npz`
    names below are `1571084520.862939*` for "train_1" and
    `1571084520.962939*` for "train_2", with `_1`/`_2` suffixes for the
    2nd/3rd tag of each run.
    """
    mock_api_client = self._create_mock_api_client()

    def stream_experiments(request, **kwargs):
        del request  # unused
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        response = export_service_pb2.StreamExperimentsResponse()
        response.experiments.add(experiment_id="123")
        yield response

    def stream_experiment_data(request, **kwargs):
        self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
        for run in ("train_1", "train_2"):
            for tag in ("dense_1/kernel", "dense_1/bias", "text/test"):
                response = export_service_pb2.StreamExperimentDataResponse()
                response.run_name = run
                response.tag_name = tag
                display_name = "%s:%s" % (request.experiment_id, tag)
                response.tag_metadata.CopyFrom(
                    test_util.scalar_metadata(display_name))
                for step in range(2):
                    response.tensors.steps.append(step)
                    response.tensors.wall_times.add(
                        seconds=1571084520 + step,
                        nanos=862939144 if run == "train_1" else 962939144,
                    )
                    if tag != "text/test":
                        response.tensors.values.append(
                            tensor_util.make_tensor_proto(
                                np.ones([3, 2]) * step))
                    else:
                        response.tensors.values.append(
                            tensor_util.make_tensor_proto(
                                np.full([3], "a" * (step + 1))))
                yield response

    mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
    mock_api_client.StreamExperimentData = mock.Mock(
        wraps=stream_experiment_data)

    outdir = os.path.join(self.get_temp_dir(), "outdir")
    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
    start_time = 1571084846.25
    start_time_pb = test_util.timestamp_pb(1571084846250000000)

    generator = exporter.export(read_time=start_time)
    expected_files = []
    self.assertTrue(os.path.isdir(outdir))
    self.assertCountEqual(expected_files, _outdir_files(outdir))
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # The first iteration should request the list of experiments and
    # data for one of them.
    self.assertEqual(next(generator), "123")
    expected_files.append(os.path.join("experiment_123", "metadata.json"))
    # scalars.json should exist and be empty.
    expected_files.append(os.path.join("experiment_123", "scalars.json"))
    expected_files.append(os.path.join("experiment_123", "tensors.json"))
    # blob_sequences.json should exist and be empty.
    expected_files.append(
        os.path.join("experiment_123", "blob_sequences.json"))
    for npz_name in (
            "1571084520.862939.npz",
            "1571084520.862939_1.npz",
            "1571084520.862939_2.npz",
            "1571084520.962939.npz",
            "1571084520.962939_1.npz",
            "1571084520.962939_2.npz",
    ):
        expected_files.append(
            os.path.join("experiment_123", "tensors", npz_name))
    self.assertCountEqual(expected_files, _outdir_files(outdir))

    # Check that the scalars and blob_sequences data files are empty,
    # because there are no scalars or blob sequences.
    with open(os.path.join(outdir, "experiment_123", "scalars.json")) as infile:
        self.assertEqual(infile.read(), "")
    blob_sequences_path = os.path.join(
        outdir, "experiment_123", "blob_sequences.json")
    with open(blob_sequences_path) as infile:
        self.assertEqual(infile.read(), "")

    expected_eids_request = export_service_pb2.StreamExperimentsRequest()
    expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
    expected_eids_request.limit = 2**63 - 1  # largest signed 64-bit value
    expected_eids_request.experiments_mask.create_time = True
    expected_eids_request.experiments_mask.update_time = True
    expected_eids_request.experiments_mask.name = True
    expected_eids_request.experiments_mask.description = True
    mock_api_client.StreamExperiments.assert_called_once_with(
        expected_eids_request, metadata=grpc_util.version_metadata())

    expected_data_request = export_service_pb2.StreamExperimentDataRequest()
    expected_data_request.experiment_id = "123"
    expected_data_request.read_timestamp.CopyFrom(start_time_pb)
    mock_api_client.StreamExperimentData.assert_called_once_with(
        expected_data_request, metadata=grpc_util.version_metadata())

    # The final StreamExperiments continuation shouldn't need to send any
    # RPCs.
    mock_api_client.StreamExperiments.reset_mock()
    mock_api_client.StreamExperimentData.reset_mock()
    self.assertEqual(list(generator), [])
    mock_api_client.StreamExperiments.assert_not_called()
    mock_api_client.StreamExperimentData.assert_not_called()

    # Check tensor data. Records follow the fake stream order:
    # train_1 × (kernel, bias, text), then train_2 × (kernel, bias, text).
    with open(os.path.join(outdir, "experiment_123", "tensors.json")) as infile:
        jsons = [json.loads(line) for line in infile]
    self.assertLen(jsons, 6)

    datum = jsons[0]
    self.assertEqual(datum.pop("run"), "train_1")
    self.assertEqual(datum.pop("tag"), "dense_1/kernel")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    expected_summary_metadata = test_util.scalar_metadata(
        "123:dense_1/kernel")
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    self.assertEqual(points.pop("steps"), [0, 1])
    self.assertEqual(
        points.pop("tensors_file_path"),
        os.path.join("tensors", "1571084520.862939.npz"),
    )
    self.assertEqual(datum, {})

    datum = jsons[4]
    self.assertEqual(datum.pop("run"), "train_2")
    self.assertEqual(datum.pop("tag"), "dense_1/bias")
    summary_metadata = summary_pb2.SummaryMetadata.FromString(
        base64.b64decode(datum.pop("summary_metadata")))
    expected_summary_metadata = test_util.scalar_metadata("123:dense_1/bias")
    self.assertEqual(summary_metadata, expected_summary_metadata)
    points = datum.pop("points")
    self.assertEqual(points.pop("steps"), [0, 1])
    self.assertEqual(
        points.pop("tensors_file_path"),
        os.path.join("tensors", "1571084520.962939_1.npz"),
    )
    self.assertEqual(datum, {})

    # Load and check the tensor data from the saved .npz files. The
    # NpzFile is used as a context manager so its file handle is closed.
    tensors_dir = os.path.join(outdir, "experiment_123", "tensors")
    for filename in (
            "1571084520.862939.npz",
            "1571084520.862939_1.npz",
            "1571084520.962939.npz",
            "1571084520.962939_1.npz",
    ):
        with np.load(os.path.join(tensors_dir, filename)) as npz:
            tensors = [npz[key] for key in npz.files]
        self.assertLen(tensors, 2)
        np.testing.assert_array_equal(tensors[0], 0 * np.ones([3, 2]))
        np.testing.assert_array_equal(tensors[1], 1 * np.ones([3, 2]))

    for filename in (
            "1571084520.862939_2.npz",
            "1571084520.962939_2.npz",
    ):
        with np.load(os.path.join(tensors_dir, filename)) as npz:
            tensors = [npz[key] for key in npz.files]
        self.assertLen(tensors, 2)
        np.testing.assert_array_equal(tensors[0],
                                      np.array(["a", "a", "a"], "|S"))
        np.testing.assert_array_equal(tensors[1],
                                      np.array(["aa", "aa", "aa"], "|S"))