Example #1
def _get_mnist(directory):
  """Download all MNIST files to directory unless they are there."""
  for filename in [
      _MNIST_TRAIN_DATA_FILENAME, _MNIST_TRAIN_LABELS_FILENAME,
      _MNIST_TEST_DATA_FILENAME, _MNIST_TEST_LABELS_FILENAME
  ]:
    generator_utils.maybe_download(directory, filename, _MNIST_URL + filename)
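Every example in this listing centers on generator_utils.maybe_download(directory, filename, uri) from tensor2tensor's data_generators package: the call downloads uri into directory/filename only when that file is not already there, and it returns the path to the local copy. The sketch below shows the call in isolation; the URL and the helper name are hypothetical and not taken from any of the examples.

import os

from tensor2tensor.data_generators import generator_utils

# Hypothetical URL, used only for illustration.
_EXAMPLE_URL = "http://example.com/data/corpus.txt"


def _maybe_fetch_corpus(tmp_dir):
  """Download the example corpus into tmp_dir unless it is already there."""
  filename = os.path.basename(_EXAMPLE_URL)
  # maybe_download returns the destination path, downloading only if needed.
  return generator_utils.maybe_download(tmp_dir, filename, _EXAMPLE_URL)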
Example #2
def _maybe_download_corpus(tmp_dir):
    """Download and unpack the corpus.

  Args:
    tmp_dir: directory containing dataset.
  """
    corpus_url = (
        "http://www.statmt.org/lm-benchmark/"
        "1-billion-word-language-modeling-benchmark-r13output.tar.gz")
    corpus_filename = os.path.basename(corpus_url)
    corpus_filepath = os.path.join(tmp_dir, corpus_filename)
    if not os.path.exists(corpus_filepath):
        generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url)
        with tarfile.open(corpus_filepath, "r:gz") as corpus_tar:
            corpus_tar.extractall(tmp_dir)
Example #3
    def generator(self, data_dir, tmp_dir, train):
        """Generate examples."""
        # Download and extract
        compressed_filename = os.path.basename(self.URL)
        download_path = generator_utils.maybe_download(tmp_dir,
                                                       compressed_filename,
                                                       self.URL)
        imdb_dir = os.path.join(tmp_dir, "aclImdb")
        if not tf.gfile.Exists(imdb_dir):
            with tarfile.open(download_path, "r:gz") as tar:
                tar.extractall(tmp_dir)

        # Generate vocab
        encoder = generator_utils.get_or_generate_vocab_inner(
            data_dir, self.vocab_file, self.targeted_vocab_size,
            self.doc_generator(imdb_dir, "train"))

        # Generate examples
        dataset = "train" if train else "test"
        for doc, label in self.doc_generator(imdb_dir,
                                             dataset,
                                             include_label=True):
            yield {
                "inputs": encoder.encode(doc) + [EOS],
                "targets": [int(label)],
            }
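In Example #3, get_or_generate_vocab_inner either reuses an existing vocabulary file under data_dir or builds a subword vocabulary of approximately targeted_vocab_size from the training documents; the returned encoder then converts each review into the "inputs" token IDs, with the label kept as the single "targets" value.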
Example #4
def _maybe_download_corpus(tmp_dir):
    """Download corpus if necessary.

  Args:
    tmp_dir: directory containing dataset.

  Returns:
    filepath of the downloaded corpus file.
  """
    corpus_url = ("https://dumps.wikimedia.org/enwiki/20170620/"
                  "enwiki-20170620-pages-articles-multistream.xml.bz2")
    corpus_filename = os.path.basename(corpus_url)
    corpus_filepath = os.path.join(tmp_dir, corpus_filename)
    if not tf.gfile.Exists(corpus_filepath):
        generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url)
    return corpus_filepath
Example #5
def _get_mscoco(directory):
  """Download and extract MSCOCO datasets to directory unless it is there."""
  for url in _MSCOCO_URLS:
    filename = os.path.basename(url)
    download_url = os.path.join(_MSCOCO_ROOT_URL, url)
    path = generator_utils.maybe_download(directory, filename, download_url)
    # Slice off ".zip" (str.strip removes a set of characters, not a suffix).
    unzip_dir = os.path.join(directory, filename[:-len(".zip")])
    if not tf.gfile.Exists(unzip_dir):
      with zipfile.ZipFile(path, "r") as coco_zip:
        coco_zip.extractall(directory)
Example #6
def _download_and_parse_dataset(tmp_dir, train):
  """Downloads and prepairs the dataset to be parsed by the data_generator."""
  file_path = generator_utils.maybe_download(tmp_dir, _SNLI_ZIP, _SNLI_URL)
  with zipfile.ZipFile(file_path, 'r') as zip_ref:
    zip_ref.extractall(tmp_dir)

  file_name = 'train' if train else 'dev'
  dataset_file_path = os.path.join(tmp_dir, _SNLI_DATA_PATH % file_name)
  _parse_dataset(dataset_file_path, tmp_dir, train)
Example #7
def _original_vocab(tmp_dir):
    """Returns a set containing the original vocabulary.

  This is important for comparing with published results.

  Args:
    tmp_dir: directory containing dataset.

  Returns:
    a set of strings
  """
    vocab_url = ("http://download.tensorflow.org/models/LM_LSTM_CNN/"
                 "vocab-2016-09-10.txt")
    vocab_filename = os.path.basename(vocab_url + ".en")
    vocab_filepath = os.path.join(tmp_dir, vocab_filename)
    if not os.path.exists(vocab_filepath):
        generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url)
    return set([
        text_encoder.native_to_unicode(l.strip())
        for l in tf.gfile.Open(vocab_filepath)
    ])
Example #8
    def testMaybeDownload(self):
        tmp_dir = self.get_temp_dir()
        (_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
        tmp_file_name = os.path.basename(tmp_file_path)

        # Download the Google index page to the temporary file with an ".http" suffix.
        res_path = generator_utils.maybe_download(tmp_dir,
                                                  tmp_file_name + ".http",
                                                  "http://google.com")
        self.assertEqual(res_path, tmp_file_path + ".http")

        # Clean up.
        os.remove(tmp_file_path + ".http")
        os.remove(tmp_file_path)
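The test above pins down the return value of maybe_download: it is the full destination path, i.e. the given directory joined with the given filename. That is why Examples #3, #9, #10, and #11 can pass the result straight to tarfile.open, h5py.File, or open without re-deriving the path.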
Example #9
    def generate_data(self, data_dir, tmp_dir, task_id=-1):
        try:
            # Download source data if download_url specified
            h5_filepath = generator_utils.maybe_download(
                tmp_dir, self.h5_file, self.download_url)
        except NotImplementedError:
            # Otherwise, look for it locally
            h5_filepath = os.path.join(tmp_dir, self.h5_file)

        with h5py.File(h5_filepath, "r") as h5_file:
            num_train_examples = h5_file["train_in"].len()
            num_dev_examples = h5_file["valid_in"].len()
            num_test_examples = h5_file["test_in"].len()

        # Collect all_filepaths to later shuffle
        all_filepaths = []
        # Collect created shard processes to start and join
        processes = []

        datasets = [(self.training_filepaths, self.num_shards, "train",
                     num_train_examples),
                    (self.dev_filepaths, 10, "valid", num_dev_examples),
                    (self.test_filepaths, 10, "test", num_test_examples)]
        for fname_fn, nshards, key_prefix, num_examples in datasets:
            outfiles = fname_fn(data_dir, nshards, shuffled=False)
            all_filepaths.extend(outfiles)
            for start_idx, end_idx, outfile in generate_shard_args(
                    outfiles, num_examples):
                p = mp.Process(target=generate_dataset,
                               args=(h5_filepath, key_prefix, [outfile],
                                     self.chunk_size, start_idx, end_idx))
                processes.append(p)

        # 1 per training shard + 10 for dev + 10 for test
        assert len(processes) == self.num_shards + 20

        # Start and wait for processes in batches
        num_batches = int(
            math.ceil(float(len(processes)) / MAX_CONCURRENT_PROCESSES))
        for i in xrange(num_batches):
            start = i * MAX_CONCURRENT_PROCESSES
            end = start + MAX_CONCURRENT_PROCESSES
            current = processes[start:end]
            for p in current:
                p.start()
            for p in current:
                p.join()

        # Shuffle
        generator_utils.shuffle_dataset(all_filepaths)
Example #10
  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/"
                "street/python/fsns_urls.txt")
    fsns_urls = generator_utils.maybe_download(tmp_dir, "fsns_urls.txt",
                                               list_url)
    fsns_files = [
        f.strip() for f in open(fsns_urls, "r") if f.startswith("http://")
    ]
    for url in fsns_files:
      if "/train/train" in url:
        generator_utils.maybe_download(
            data_dir, "image_fsns-train" + url[-len("-00100-of-00512"):], url)
      elif "/validation/validation" in url:
        generator_utils.maybe_download(
            data_dir, "image_fsns-dev" + url[-len("-00100-of-00512"):], url)
      elif "charset" in url:
        generator_utils.maybe_download(data_dir, "charset_size134.txt", url)
Example #11
  def generator(self, data_dir, tmp_dir, train):
    filename = os.path.basename(PTB_URL)
    compressed_filepath = generator_utils.maybe_download(
        tmp_dir, filename, PTB_URL)
    ptb_files = []
    ptb_char_files = []
    with tarfile.open(compressed_filepath, "r:gz") as tgz:
      files = []
      # Selecting only relevant files.
      for m in tgz.getmembers():
        if "ptb" in m.name and ".txt" in m.name:
          if "char" in m.name:
            ptb_char_files += [m.name]
          else:
            ptb_files += [m.name]
          files += [m]

      tgz.extractall(tmp_dir, members=files)

    if self.is_character_level:
      files = ptb_char_files
    else:
      files = ptb_files

    train_file, valid_file = None, None
    for filename in files:
      if "train" in filename:
        train_file = os.path.join(tmp_dir, filename)
      elif "valid" in filename:
        valid_file = os.path.join(tmp_dir, filename)

    assert train_file, "Training file not found"
    assert valid_file, "Validation file not found"

    if self.is_character_level:
      encoder = text_encoder.ByteTextEncoder()
    else:
      encoder = _get_token_encoder(data_dir, self.vocab_file, train_file)

    if train:
      return self._generator(train_file, encoder)
    return self._generator(valid_file, encoder)
Example #12
def _compile_data(tmp_dir, datasets, filename):
    """Concatenate all `datasets` and save to `filename`."""
    filename = os.path.join(tmp_dir, filename)
    with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_resfile:
        with tf.gfile.GFile(filename + ".lang2", mode="w") as lang2_resfile:
            for dataset in datasets:
                url = dataset[0]
                compressed_filename = os.path.basename(url)
                compressed_filepath = os.path.join(tmp_dir,
                                                   compressed_filename)
                generator_utils.maybe_download(tmp_dir, compressed_filename,
                                               url)

                if dataset[1][0] == 'tsv':
                    _, src_column, trg_column, glob_pattern = dataset[1]
                    filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
                    if not filenames:
                        mode = "r:gz" if compressed_filepath.endswith(
                            "gz") else "r"  # *.tgz *.tar.gz
                        with tarfile.open(compressed_filepath,
                                          mode) as corpus_tar:
                            corpus_tar.extractall(tmp_dir)
                        filenames = glob.glob(
                            os.path.join(tmp_dir, glob_pattern))
                    for tsv_filename in filenames:
                        if tsv_filename.endswith(".gz"):
                            # Slice off ".gz" (str.strip removes characters, not a suffix).
                            new_filename = tsv_filename[:-len(".gz")]
                            try:
                                generator_utils.gunzip_file(
                                    tsv_filename, new_filename)
                            except PermissionError:
                                tsvdir = os.path.dirname(tsv_filename)
                                os.chmod(
                                    tsvdir,
                                    os.stat(tsvdir).st_mode | stat.S_IWRITE)
                                generator_utils.gunzip_file(
                                    tsv_filename, new_filename)
                            tsv_filename = new_filename
                        with tf.gfile.GFile(tsv_filename,
                                            mode="r") as tsv_file:
                            for line in tsv_file:
                                if line and "\t" in line:
                                    parts = line.split("\t")
                                    source, target = parts[src_column], parts[
                                        trg_column]
                                    lang1_resfile.write(source.strip() + "\n")
                                    lang2_resfile.write(target.strip() + "\n")
                else:
                    lang1_filename, lang2_filename = dataset[1]
                    lang1_filepath = os.path.join(tmp_dir, lang1_filename)
                    lang2_filepath = os.path.join(tmp_dir, lang2_filename)
                    is_sgm = (lang1_filename.endswith("sgm")
                              and lang2_filename.endswith("sgm"))

                    if not (os.path.exists(lang1_filepath)
                            and os.path.exists(lang2_filepath)):
                        # For .tar.gz and .tgz files, we read compressed.
                        mode = "r:gz" if compressed_filepath.endswith(
                            "gz") else "r"
                        with tarfile.open(compressed_filepath,
                                          mode) as corpus_tar:
                            corpus_tar.extractall(tmp_dir)
                    if lang1_filepath.endswith(".gz"):
                        new_filepath = lang1_filepath[:-len(".gz")]
                        generator_utils.gunzip_file(lang1_filepath,
                                                    new_filepath)
                        lang1_filepath = new_filepath
                    if lang2_filepath.endswith(".gz"):
                        new_filepath = lang2_filepath[:-len(".gz")]
                        generator_utils.gunzip_file(lang2_filepath,
                                                    new_filepath)
                        lang2_filepath = new_filepath
                    with tf.gfile.GFile(lang1_filepath,
                                        mode="r") as lang1_file:
                        with tf.gfile.GFile(lang2_filepath,
                                            mode="r") as lang2_file:
                            line1, line2 = lang1_file.readline(
                            ), lang2_file.readline()
                            while line1 or line2:
                                line1res = _preprocess_sgm(line1, is_sgm)
                                line2res = _preprocess_sgm(line2, is_sgm)
                                if line1res or line2res:
                                    lang1_resfile.write(line1res.strip() +
                                                        "\n")
                                    lang2_resfile.write(line2res.strip() +
                                                        "\n")
                                line1, line2 = lang1_file.readline(
                                ), lang2_file.readline()

    return filename
Example #13
def _get_cifar10(directory):
  """Download and extract CIFAR to directory unless it is there."""
  filename = os.path.basename(_CIFAR10_URL)
  path = generator_utils.maybe_download(directory, filename, _CIFAR10_URL)
  tarfile.open(path, "r:gz").extractall(directory)