def initialize_from_bucket(self):
    """Initialize DatasetInfo from GCS bucket info files.

    Lists the info files published on GCS for `self.full_name`, downloads
    them into a scratch directory and loads them with
    `self.read_from_directory`. Returns without doing anything when no info
    files are listed.
    """
    # In order to support Colab, we use the HTTP GCS API to access the metadata
    # files. They are copied locally and then loaded.
    data_files = gcs_utils.gcs_dataset_info_files(self.full_name)
    if not data_files:
        # Nothing to load: return before creating any temporary directory.
        return
    logging.info("Loading info from GCS for %s", self.full_name)
    # The scratch directory is removed once loading finishes, so repeated
    # calls do not accumulate temporary directories.
    # NOTE(review): assumes read_from_directory reads every file eagerly and
    # keeps no reference to tmp_dir after returning — confirm.
    with tempfile.TemporaryDirectory(suffix="tfds") as tmp_dir:
        for fname in data_files:
            out_fname = os.path.join(tmp_dir, os.path.basename(fname))
            gcs_utils.download_gcs_file(fname, out_fname)
        self.read_from_directory(tmp_dir)
def initialize_from_bucket(self):
    """Initialize DatasetInfo from GCS bucket info files.

    Lists the info files published on GCS for `self.full_name`, copies them
    into a scratch directory and loads them with `self.read_from_directory`.
    Returns without doing anything when no info files are listed.
    """
    # In order to support Colab, we use the HTTP GCS API to access the metadata
    # files. They are copied locally and then loaded.
    data_files = gcs_utils.gcs_dataset_info_files(self.full_name)
    if not data_files:
        # Nothing to load: return before creating any temporary directory.
        return
    logging.info("Load pre-computed DatasetInfo (eg: splits, num examples,...) "
                 "from GCS: %s", self.full_name)
    # The scratch directory is removed once loading finishes, so repeated
    # calls do not accumulate temporary directories.
    # NOTE(review): assumes read_from_directory reads every file eagerly and
    # keeps no reference to tmp_dir after returning — confirm.
    with tempfile.TemporaryDirectory(suffix="tfds") as tmp_dir:
        for fname in data_files:
            out_fname = os.path.join(tmp_dir, os.path.basename(fname))
            tf.io.gfile.copy(gcs_utils.gcs_path(fname), out_fname)
        self.read_from_directory(tmp_dir)
def test_download_dataset(self):
    """Checks the MNIST 2.0.0 info-file listing and the full dataset download."""
    expected_info_files = [
        'dataset_info/mnist/2.0.0/dataset_info.json',
        'dataset_info/mnist/2.0.0/image.image.json',
    ]
    expected_local_files = [
        'mnist-test.tfrecord-00000-of-00001',
        'mnist-train.tfrecord-00000-of-00001',
        'dataset_info.json',
        'image.image.json',
    ]
    with self.gcs_access():
        listed_info_files = gcs_utils.gcs_dataset_info_files('mnist/2.0.0')
        self.assertCountEqual(listed_info_files, expected_info_files)

        # The download must be inspected before the temporary directory is
        # deleted, so the assertion stays inside the `with`.
        with tempfile.TemporaryDirectory() as download_dir:
            gcs_utils.download_gcs_dataset('mnist/2.0.0', download_dir)
            self.assertCountEqual(os.listdir(download_dir), expected_local_files)
def test_download_dataset(self):
    """Checks the MNIST 2.0.0 info-file listing and the full dataset download."""
    info_file_urls = (
        'gs://tfds-data/dataset_info/mnist/2.0.0/dataset_info.json',
        'gs://tfds-data/dataset_info/mnist/2.0.0/image.image.json',
    )
    expected_local_files = [
        'mnist-test.tfrecord-00000-of-00001',
        'mnist-train.tfrecord-00000-of-00001',
        'dataset_info.json',
        'image.image.json',
    ]
    with self.gcs_access():
        listed_info_files = gcs_utils.gcs_dataset_info_files('mnist/2.0.0')
        # The listing yields path objects, so compare against paths, not strings.
        expected_info_paths = [tfds.core.as_path(url) for url in info_file_urls]
        self.assertCountEqual(listed_info_files, expected_info_paths)

        # The download must be inspected before the temporary directory is
        # deleted, so the assertion stays inside the `with`.
        with tempfile.TemporaryDirectory() as download_dir:
            gcs_utils.download_gcs_dataset(
                dataset_name='mnist/2.0.0', local_dataset_dir=download_dir)
            self.assertCountEqual(os.listdir(download_dir), expected_local_files)