def _get_mnist(directory):
  """Download all MNIST files to directory unless they are there."""
  for filename in [
      _MNIST_TRAIN_DATA_FILENAME, _MNIST_TRAIN_LABELS_FILENAME,
      _MNIST_TEST_DATA_FILENAME, _MNIST_TEST_LABELS_FILENAME
  ]:
    generator_utils.maybe_download(directory, filename, _MNIST_URL + filename)
def _maybe_download_corpus(tmp_dir):
  """Download and unpack the corpus.

  Args:
    tmp_dir: directory containing dataset.
  """
  corpus_url = ("http://www.statmt.org/lm-benchmark/"
                "1-billion-word-language-modeling-benchmark-r13output.tar.gz")
  corpus_filename = os.path.basename(corpus_url)
  corpus_filepath = os.path.join(tmp_dir, corpus_filename)
  if not os.path.exists(corpus_filepath):
    generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url)
    with tarfile.open(corpus_filepath, "r:gz") as corpus_tar:
      corpus_tar.extractall(tmp_dir)
def generator(self, data_dir, tmp_dir, train):
  """Generate examples."""
  # Download and extract
  compressed_filename = os.path.basename(self.URL)
  download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
                                                 self.URL)
  imdb_dir = os.path.join(tmp_dir, "aclImdb")
  if not tf.gfile.Exists(imdb_dir):
    with tarfile.open(download_path, "r:gz") as tar:
      tar.extractall(tmp_dir)

  # Generate vocab
  encoder = generator_utils.get_or_generate_vocab_inner(
      data_dir, self.vocab_file, self.targeted_vocab_size,
      self.doc_generator(imdb_dir, "train"))

  # Generate examples
  dataset = "train" if train else "test"
  for doc, label in self.doc_generator(imdb_dir, dataset, include_label=True):
    yield {
        "inputs": encoder.encode(doc) + [EOS],
        "targets": [int(label)],
    }
def _maybe_download_corpus(tmp_dir):
  """Download corpus if necessary.

  Args:
    tmp_dir: directory containing dataset.

  Returns:
    filepath of the downloaded corpus file.
  """
  corpus_url = ("https://dumps.wikimedia.org/enwiki/20170620/"
                "enwiki-20170620-pages-articles-multistream.xml.bz2")
  corpus_filename = os.path.basename(corpus_url)
  corpus_filepath = os.path.join(tmp_dir, corpus_filename)
  if not tf.gfile.Exists(corpus_filepath):
    generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url)
  return corpus_filepath
def _get_mscoco(directory):
  """Download and extract MSCOCO datasets to directory unless it is there."""
  for url in _MSCOCO_URLS:
    filename = os.path.basename(url)
    download_url = os.path.join(_MSCOCO_ROOT_URL, url)
    path = generator_utils.maybe_download(directory, filename, download_url)
    # Slice off the suffix; str.strip(".zip") would strip characters, not
    # the ".zip" extension (see the note below).
    unzip_dir = os.path.join(directory, filename[:-len(".zip")])
    if not tf.gfile.Exists(unzip_dir):
      with zipfile.ZipFile(path, "r") as corpus_zip:
        corpus_zip.extractall(directory)
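# Note on the fix above: str.strip(".zip") treats its argument as a set of
# characters ('.', 'z', 'i', 'p') to remove from both ends of the string,
# not as a suffix. It happens to work for names like "train2014.zip" but
# would mangle others. A minimal, self-contained demonstration (plain
# Python, standard library only):
assert "train2014.zip".strip(".zip") == "train2014"     # correct only by luck
assert "zip_images.zip".strip(".zip") == "_images"      # leading chars lost
assert "zip_images.zip"[:-len(".zip")] == "zip_images"  # suffix-safe slicing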
def _download_and_parse_dataset(tmp_dir, train):
  """Downloads and prepares the dataset to be parsed by the data_generator."""
  file_path = generator_utils.maybe_download(tmp_dir, _SNLI_ZIP, _SNLI_URL)
  with zipfile.ZipFile(file_path, "r") as zip_ref:
    zip_ref.extractall(tmp_dir)

  file_name = "train" if train else "dev"
  dataset_file_path = os.path.join(tmp_dir, _SNLI_DATA_PATH % file_name)
  _parse_dataset(dataset_file_path, tmp_dir, train)
def _original_vocab(tmp_dir):
  """Returns a set containing the original vocabulary.

  This is important for comparing with published results.

  Args:
    tmp_dir: directory containing dataset.

  Returns:
    a set of strings
  """
  vocab_url = ("http://download.tensorflow.org/models/LM_LSTM_CNN/"
               "vocab-2016-09-10.txt")
  vocab_filename = os.path.basename(vocab_url + ".en")
  vocab_filepath = os.path.join(tmp_dir, vocab_filename)
  if not os.path.exists(vocab_filepath):
    generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url)
  return set([
      text_encoder.native_to_unicode(l.strip())
      for l in tf.gfile.Open(vocab_filepath)
  ])
def testMaybeDownload(self):
  tmp_dir = self.get_temp_dir()
  (_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
  tmp_file_name = os.path.basename(tmp_file_path)

  # Download the Google index page to a new file with an ".http" suffix
  # next to the temporary file.
  res_path = generator_utils.maybe_download(tmp_dir, tmp_file_name + ".http",
                                            "http://google.com")
  self.assertEqual(res_path, tmp_file_path + ".http")

  # Clean up.
  os.remove(tmp_file_path + ".http")
  os.remove(tmp_file_path)
def generate_data(self, data_dir, tmp_dir, task_id=-1):
  try:
    # Download source data if download_url specified
    h5_filepath = generator_utils.maybe_download(tmp_dir, self.h5_file,
                                                 self.download_url)
  except NotImplementedError:
    # Otherwise, look for it locally
    h5_filepath = os.path.join(tmp_dir, self.h5_file)

  with h5py.File(h5_filepath, "r") as h5_file:
    num_train_examples = h5_file["train_in"].len()
    num_dev_examples = h5_file["valid_in"].len()
    num_test_examples = h5_file["test_in"].len()

  # Collect all_filepaths to later shuffle
  all_filepaths = []
  # Collect created shard processes to start and join
  processes = []
  datasets = [(self.training_filepaths, self.num_shards, "train",
               num_train_examples),
              (self.dev_filepaths, 10, "valid", num_dev_examples),
              (self.test_filepaths, 10, "test", num_test_examples)]
  for fname_fn, nshards, key_prefix, num_examples in datasets:
    outfiles = fname_fn(data_dir, nshards, shuffled=False)
    all_filepaths.extend(outfiles)
    for start_idx, end_idx, outfile in generate_shard_args(
        outfiles, num_examples):
      p = mp.Process(
          target=generate_dataset,
          args=(h5_filepath, key_prefix, [outfile], self.chunk_size,
                start_idx, end_idx))
      processes.append(p)

  # 1 per training shard + 10 for dev + 10 for test
  assert len(processes) == self.num_shards + 20

  # Start and wait for processes in batches
  num_batches = int(
      math.ceil(float(len(processes)) / MAX_CONCURRENT_PROCESSES))
  for i in xrange(num_batches):
    start = i * MAX_CONCURRENT_PROCESSES
    end = start + MAX_CONCURRENT_PROCESSES
    current = processes[start:end]
    for p in current:
      p.start()
    for p in current:
      p.join()

  # Shuffle
  generator_utils.shuffle_dataset(all_filepaths)
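# `generate_shard_args` is a helper defined elsewhere in this module; its
# body is not shown here. For orientation only, a plausible sketch assuming
# it evenly partitions the example index range [0, num_examples) across the
# output files. The tuple shape matches the call site above, but the body is
# an assumption, not the repository's actual implementation:
def generate_shard_args_sketch(outfiles, num_examples):
  """ASSUMED sketch of generate_shard_args: yields (start, end, outfile)."""
  num_shards = len(outfiles)
  per_shard = num_examples // num_shards
  start_idxs = [i * per_shard for i in range(num_shards)]
  end_idxs = start_idxs[1:] + [num_examples]  # last shard takes the remainder
  return zip(start_idxs, end_idxs, outfiles)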
def generate_data(self, data_dir, tmp_dir, task_id=-1):
  list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/"
              "street/python/fsns_urls.txt")
  fsns_urls = generator_utils.maybe_download(tmp_dir, "fsns_urls.txt",
                                             list_url)
  fsns_files = [
      f.strip() for f in open(fsns_urls, "r") if f.startswith("http://")
  ]
  for url in fsns_files:
    if "/train/train" in url:
      generator_utils.maybe_download(
          data_dir, "image_fsns-train" + url[-len("-00100-of-00512"):], url)
    elif "/validation/validation" in url:
      generator_utils.maybe_download(
          data_dir, "image_fsns-dev" + url[-len("-00100-of-00512"):], url)
    elif "charset" in url:
      generator_utils.maybe_download(data_dir, "charset_size134.txt", url)
def generator(self, data_dir, tmp_dir, train):
  filename = os.path.basename(PTB_URL)
  compressed_filepath = generator_utils.maybe_download(
      tmp_dir, filename, PTB_URL)
  ptb_files = []
  ptb_char_files = []
  with tarfile.open(compressed_filepath, "r:gz") as tgz:
    files = []
    # Selecting only relevant files.
    for m in tgz.getmembers():
      if "ptb" in m.name and ".txt" in m.name:
        if "char" in m.name:
          ptb_char_files += [m.name]
        else:
          ptb_files += [m.name]
        files += [m]
    tgz.extractall(tmp_dir, members=files)

  if self.is_character_level:
    files = ptb_char_files
  else:
    files = ptb_files

  train_file, valid_file = None, None
  for filename in files:
    if "train" in filename:
      train_file = os.path.join(tmp_dir, filename)
    elif "valid" in filename:
      valid_file = os.path.join(tmp_dir, filename)

  assert train_file, "Training file not found"
  assert valid_file, "Validation file not found"

  if self.is_character_level:
    encoder = text_encoder.ByteTextEncoder()
  else:
    encoder = _get_token_encoder(data_dir, self.vocab_file, train_file)

  if train:
    return self._generator(train_file, encoder)
  return self._generator(valid_file, encoder)
def _compile_data(tmp_dir, datasets, filename):
  """Concatenate all `datasets` and save to `filename`."""
  filename = os.path.join(tmp_dir, filename)
  with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_resfile:
    with tf.gfile.GFile(filename + ".lang2", mode="w") as lang2_resfile:
      for dataset in datasets:
        url = dataset[0]
        compressed_filename = os.path.basename(url)
        compressed_filepath = os.path.join(tmp_dir, compressed_filename)
        generator_utils.maybe_download(tmp_dir, compressed_filename, url)

        if dataset[1][0] == "tsv":
          _, src_column, trg_column, glob_pattern = dataset[1]
          filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
          if not filenames:
            # For *.tgz and *.tar.gz files, we read compressed.
            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
            with tarfile.open(compressed_filepath, mode) as corpus_tar:
              corpus_tar.extractall(tmp_dir)
            filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
          for tsv_filename in filenames:
            if tsv_filename.endswith(".gz"):
              # Slice off the suffix; str.strip(".gz") would strip
              # characters, not the ".gz" extension.
              new_filename = tsv_filename[:-len(".gz")]
              try:
                generator_utils.gunzip_file(tsv_filename, new_filename)
              except PermissionError:
                tsvdir = os.path.dirname(tsv_filename)
                os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE)
                generator_utils.gunzip_file(tsv_filename, new_filename)
              tsv_filename = new_filename
            with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
              for line in tsv_file:
                if line and "\t" in line:
                  parts = line.split("\t")
                  source, target = parts[src_column], parts[trg_column]
                  lang1_resfile.write(source.strip() + "\n")
                  lang2_resfile.write(target.strip() + "\n")
        else:
          lang1_filename, lang2_filename = dataset[1]
          lang1_filepath = os.path.join(tmp_dir, lang1_filename)
          lang2_filepath = os.path.join(tmp_dir, lang2_filename)
          is_sgm = (lang1_filename.endswith("sgm") and
                    lang2_filename.endswith("sgm"))

          if not (os.path.exists(lang1_filepath) and
                  os.path.exists(lang2_filepath)):
            # For .tar.gz and .tgz files, we read compressed.
            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
            with tarfile.open(compressed_filepath, mode) as corpus_tar:
              corpus_tar.extractall(tmp_dir)
          if lang1_filepath.endswith(".gz"):
            new_filepath = lang1_filepath[:-len(".gz")]
            generator_utils.gunzip_file(lang1_filepath, new_filepath)
            lang1_filepath = new_filepath
          if lang2_filepath.endswith(".gz"):
            new_filepath = lang2_filepath[:-len(".gz")]
            generator_utils.gunzip_file(lang2_filepath, new_filepath)
            lang2_filepath = new_filepath

          with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file:
            with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file:
              line1, line2 = lang1_file.readline(), lang2_file.readline()
              while line1 or line2:
                line1res = _preprocess_sgm(line1, is_sgm)
                line2res = _preprocess_sgm(line2, is_sgm)
                if line1res or line2res:
                  lang1_resfile.write(line1res.strip() + "\n")
                  lang2_resfile.write(line2res.strip() + "\n")
                line1, line2 = lang1_file.readline(), lang2_file.readline()

  return filename
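# The expected shape of `datasets` is implicit in the two branches of
# _compile_data above. An illustrative example reconstructed from the code;
# the URLs and file names here are hypothetical placeholders, not entries
# from the repository:
datasets_example = [
    # TSV form: ("tsv", source column, target column, glob for data files).
    ("http://example.com/corpus.tar.gz",
     ("tsv", 0, 1, "corpus/*.tsv")),
    # Parallel-file form: (lang1 filename, lang2 filename); files may be
    # .sgm (triggers SGML preprocessing) or .gz (gunzipped on the fly).
    ("http://example.com/parallel.tgz",
     ("news.en.sgm", "news.de.sgm")),
]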
def _get_cifar10(directory):
  """Download and extract CIFAR to directory unless it is there."""
  filename = os.path.basename(_CIFAR10_URL)
  path = generator_utils.maybe_download(directory, filename, _CIFAR10_URL)
  # Use a context manager so the archive is closed after extraction.
  with tarfile.open(path, "r:gz") as cifar_tar:
    cifar_tar.extractall(directory)
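# All of the helpers above share one download-then-extract pattern around
# generator_utils.maybe_download, which skips the fetch when the file is
# already present and returns its local path (see testMaybeDownload). A
# minimal consolidated sketch of that pattern; the helper name and the
# extension dispatch below are illustrative, not part of the repository:
import os
import tarfile
import zipfile


def _maybe_download_and_extract_sketch(directory, url):
  """HYPOTHETICAL helper: download `url` into `directory` and unpack it."""
  filename = os.path.basename(url)
  path = generator_utils.maybe_download(directory, filename, url)
  if filename.endswith((".tar.gz", ".tgz")):
    with tarfile.open(path, "r:gz") as archive:
      archive.extractall(directory)
  elif filename.endswith(".zip"):
    with zipfile.ZipFile(path, "r") as archive:
      archive.extractall(directory)
  return path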