Example #1
File: loaders.py Project: imito/odin
def get_dataset(clazz, ext='', override=False):
   # ====== all path ====== #
   name = clazz.get_name(ext) + '.zip'
   path = base64.decodebytes(DataLoader.ORIGIN).decode() + name
   zip_path = clazz.get_zip_path(ext)
   out_path = clazz.get_ds_path(ext)
   # ====== check out_path ====== #
   if os.path.isfile(out_path):
     raise RuntimeError("Found a file at path: %s, we need a folder "
                        "to unzip downloaded files." % out_path)
   elif os.path.isdir(out_path):
     if override or len(os.listdir(out_path)) == 0:
       shutil.rmtree(out_path)
     else:
       return Dataset(out_path, read_only=True)
   # ====== download the file ====== #
   if os.path.exists(zip_path) and override:
     os.remove(zip_path)
   if not os.path.exists(zip_path):
     get_file(name, path, DataLoader.BASE_DIR)
   # ====== unzip dataset ====== #
   unzip_aes(in_path=zip_path, out_path=out_path)
   ds = Dataset(out_path, read_only=True)
   if os.path.exists(zip_path):
     os.remove(zip_path)
   return ds
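
A minimal usage sketch for the loader above, assuming `get_dataset` is exposed on a concrete `DataLoader` subclass; the `MNISTLoader` name below is hypothetical and not confirmed by the source:

# `MNISTLoader` is a hypothetical DataLoader subclass used only for illustration
ds = MNISTLoader.get_dataset(ext='', override=False)  # downloads and unzips on first call
print(ds.path)  # the folder the archive was extracted into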
Example #2
def get_dataset(clazz, ext='', override=False):
     # ====== all path ====== #
     name = clazz.get_name(ext) + '.zip'
     path = base64.decodebytes(DataLoader.ORIGIN).decode() + name
     zip_path = clazz.get_zip_path(ext)
     out_path = clazz.get_ds_path(ext)
     # ====== check out_path ====== #
     if os.path.isfile(out_path):
         raise RuntimeError("Found a file at path: %s, we need a folder "
                            "to unzip downloaded files." % out_path)
     elif os.path.isdir(out_path):
         if override or len(os.listdir(out_path)) == 0:
             shutil.rmtree(out_path)
         else:
             return Dataset(out_path, read_only=True)
     # ====== download the file ====== #
     if os.path.exists(zip_path) and override:
         os.remove(zip_path)
     if not os.path.exists(zip_path):
         get_file(name, path, DataLoader.BASE_DIR)
     # ====== unzip dataset ====== #
     unzip_aes(in_path=zip_path, out_path=out_path)
     ds = Dataset(out_path, read_only=True)
     if os.path.exists(zip_path):
         os.remove(zip_path)
     return ds
Example #3
def load(clazz):
     """ Return
 records: list of all paths to recorded audio files
 metadata: numpy.ndarray
 """
     dat_path = get_datasetpath(name='FSDD', override=False)
     tmp_path = dat_path + '_tmp'
     zip_path = dat_path + '.zip'
     # ====== download zip dataset ====== #
     if not os.path.exists(dat_path) or \
     len(os.listdir(dat_path)) != 1501:
         if not os.path.exists(zip_path):
             get_file(fname='FSDD.zip',
                      origin=FSDD.LINK,
                      outdir=get_datasetpath())
         if os.path.exists(tmp_path):
             shutil.rmtree(tmp_path)
         unzip_folder(zip_path=zip_path, out_path=tmp_path, remove_zip=True)
         tmp_path = os.path.join(tmp_path, os.listdir(tmp_path)[0])
         # ====== get all records ====== #
         record_path = os.path.join(tmp_path, 'recordings')
         all_records = [
             os.path.join(record_path, i) for i in os.listdir(record_path)
         ]
         for f in all_records:
             name = os.path.basename(f)
             shutil.copy2(src=f, dst=os.path.join(dat_path, name))
         # ====== copy the metadata ====== #
         meta_path = os.path.join(tmp_path, 'metadata.py')
         import imp
         meta = imp.load_source('metadata', meta_path).metadata
         assert len(set(len(i)
                        for i in meta.values())) == 1, "Invalid metadata"
         rows = []
         for name, info in meta.items():
             info = sorted(info.items(), key=lambda x: x[0])
             header = ['name'] + [i[0] for i in info]
             rows.append([name] + [i[1] for i in info])
         with open(os.path.join(dat_path, 'metadata.csv'), 'w') as f:
             for r in [header] + rows:
                 f.write(','.join(r) + '\n')
     # ====== clean ====== #
     if os.path.exists(tmp_path):
         shutil.rmtree(tmp_path)
     # ====== return dataset ====== #
     all_files = [
         os.path.join(dat_path, i) for i in os.listdir(dat_path)
         if '.wav' in i
     ]
     meta = np.genfromtxt(os.path.join(dat_path, 'metadata.csv'),
                          dtype=str,
                          delimiter=',')
     return all_files, meta
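
A short sketch of consuming the two returned values; calling `load` through the `FSDD` class is an assumption based on the signature:

records, meta = FSDD.load()  # assumed classmethod-style call
header, rows = meta[0], meta[1:]  # genfromtxt keeps the CSV header as the first row
print(len(records), list(header))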
Example #4
def get_gene_id2name():
  r""" Return the mapping from gene identifier to gene symbol (i.e. name)
  for PBMC 8k data
  """
  from odin.utils import get_file
  from sisua.data.path import DOWNLOAD_DIR
  url = base64.decodebytes(
      b'aHR0cHM6Ly9haS1kYXRhc2V0cy5zMy5hbWF6b25hd3MuY29tL2dlbmVfaWQybmFtZS5wa2w=\n'
  )
  url = str(url, 'utf-8')
  get_file('gene_id2name.pkl', url, DOWNLOAD_DIR)
  with open(os.path.join(DOWNLOAD_DIR, 'gene_id2name.pkl'), 'rb') as f:
    return pickle.load(f)
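
A hedged usage sketch; the exact key format of the returned dictionary is not specified beyond "gene identifier to gene symbol":

id2name = get_gene_id2name()  # downloads gene_id2name.pkl on first call
print(len(id2name))  # number of gene identifiers in the mapping
# symbol = id2name[gene_id]  # look up the symbol for a known identifier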
Example #5
def load_cifar100(path='https://s3.amazonaws.com/ai-datasets/cifar100.zip'):
    """
    path : str
        local path or url to hdf5 datafile
    """
    datapath = get_file('cifar100', path)
    return _load_data_from_path(datapath)
Example #6
File: loaders.py Project: imito/odin
def load_lre_list():
  """ The header include following column:
  * name: LDC2017E22/data/ara-acm/ar-20031215-034005_0-a.sph
  * lre: {'train17', 'eval15', 'train15', 'dev17', 'eval17'}
  * language: {'ara-arb', 'ara-ary', 'ara-apc', 'ara-arz', 'ara-acm',
               'eng-gbr', 'eng-usg', 'eng-sas',
               'fre-hat', 'fre-waf',
               'zho-wuu', 'zho-cdo', 'zho-cmn', 'zho-yue', 'zho-nan',
               'spa-lac', 'spa-eur', 'spa-car',
               'qsl-pol', 'qsl-rus',
               'por-brz'}
  * corpus: {'pcm', 'alaw', 'babel', 'ulaw', 'vast', 'mls14'}
  * duration: {'3', '30', '5', '15', '10', '20', '1000', '25'}

  Note
  ----
  Suggested naming scheme:
    `lre/lang/corpus/dur/base_name`
  """
  link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL2xyZV9saXN0LnR4dA==\n'
  link = str(base64.decodebytes(link), 'utf-8')
  path = get_file(fname=os.path.basename(link),
                  origin=link,
                  outdir=get_datasetpath(root='~'))
  return np.genfromtxt(fname=path, dtype=str, delimiter=' ',
                       skip_header=1)
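
A small sketch of filtering the returned array, assuming the five columns appear in the order listed in the docstring (name, lre, language, corpus, duration):

rows = load_lre_list()
train17 = rows[rows[:, 1] == 'train17']  # column 1 assumed to hold the `lre` field
print(train17.shape)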
Example #7
def load_lre_list():
    """ The header include following column:
  * name: LDC2017E22/data/ara-acm/ar-20031215-034005_0-a.sph
  * lre: {'train17', 'eval15', 'train15', 'dev17', 'eval17'}
  * language: {'ara-arb', 'ara-ary', 'ara-apc', 'ara-arz', 'ara-acm',
               'eng-gbr', 'eng-usg', 'eng-sas',
               'fre-hat', 'fre-waf',
               'zho-wuu', 'zho-cdo', 'zho-cmn', 'zho-yue', 'zho-nan',
               'spa-lac', 'spa-eur', 'spa-car',
               'qsl-pol', 'qsl-rus',
               'por-brz'}
  * corpus: {'pcm', 'alaw', 'babel', 'ulaw', 'vast', 'mls14'}
  * duration: {'3', '30', '5', '15', '10', '20', '1000', '25'}

  Note
  ----
  Suggested naming scheme:
    `lre/lang/corpus/dur/base_name`
  """
    link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL2xyZV9saXN0LnR4dA==\n'
    link = str(base64.decodebytes(link), 'utf-8')
    path = get_file(fname=os.path.basename(link),
                    origin=link,
                    outdir=get_datasetpath(root='~'))
    return np.genfromtxt(fname=path, dtype=str, delimiter=' ', skip_header=1)
Example #8
def load_mnist(path='https://s3.amazonaws.com/ai-datasets/MNIST.zip'):
    """
    path : str
        local path or url to hdf5 datafile
    """
    datapath = get_file('MNIST', path)
    return _load_data_from_path(datapath)
Example #9
File: loaders.py Project: imito/odin
def load(clazz):
   """ Return
   records: list of all paths to recorded audio files
   metadata: numpy.ndarray
   """
   dat_path = get_datasetpath(name='FSDD', override=False)
   tmp_path = dat_path + '_tmp'
   zip_path = dat_path + '.zip'
   # ====== download zip dataset ====== #
   if not os.path.exists(dat_path) or \
   len(os.listdir(dat_path)) != 1501:
     if not os.path.exists(zip_path):
       get_file(fname='FSDD.zip', origin=FSDD.LINK, outdir=get_datasetpath())
     if os.path.exists(tmp_path):
       shutil.rmtree(tmp_path)
     unzip_folder(zip_path=zip_path, out_path=tmp_path, remove_zip=True)
     tmp_path = os.path.join(tmp_path, os.listdir(tmp_path)[0])
     # ====== get all records ====== #
     record_path = os.path.join(tmp_path, 'recordings')
     all_records = [os.path.join(record_path, i)
                    for i in os.listdir(record_path)]
     for f in all_records:
       name = os.path.basename(f)
       shutil.copy2(src=f, dst=os.path.join(dat_path, name))
     # ====== copy the metadata ====== #
     meta_path = os.path.join(tmp_path, 'metadata.py')
     import imp
     meta = imp.load_source('metadata', meta_path).metadata
     assert len(set(len(i) for i in meta.values())) == 1, "Invalid metadata"
     rows = []
     for name, info in meta.items():
       info = sorted(info.items(), key=lambda x: x[0])
       header = ['name'] + [i[0] for i in info]
       rows.append([name] + [i[1] for i in info])
     with open(os.path.join(dat_path, 'metadata.csv'), 'w') as f:
       for r in [header] + rows:
         f.write(','.join(r) + '\n')
   # ====== clean ====== #
   if os.path.exists(tmp_path):
     shutil.rmtree(tmp_path)
   # ====== return dataset ====== #
   all_files = [os.path.join(dat_path, i)
                for i in os.listdir(dat_path)
                if '.wav' in i]
   meta = np.genfromtxt(os.path.join(dat_path, 'metadata.csv'),
                        dtype=str, delimiter=',')
   return all_files, meta
Example #10
def load_mspec_test():
    """
    path : str
        local path or url to hdf5 datafile
    """
    path = 'https://s3.amazonaws.com/ai-datasets/mspec_test.zip'
    datapath = get_file('mspec_test', path)
    return _load_data_from_path(datapath)
Example #11
def load_swb1_aligment(nb_senones=2304):
    support_nb_senones = (2304, )
    if nb_senones not in support_nb_senones:
        raise ValueError('We only support the following numbers of senones: %s' %
                         support_nb_senones)
    fname = "swb1_%d" % nb_senones
    url = "https://s3.amazonaws.com/ai-datasets/" + fname
    alignment = get_file(fname, url)
    return MmapDict(alignment, read_only=True)
Example #12
File: loaders.py Project: imito/odin
def load_lre_sad():
  """
  key: 'LDC2017E23/data/eval/lre17_lqoyrygc.sph'
  value: [(1.99, 3.38), (8.78, 16.41)] (in second)
  """
  link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL2xyZV9zYWQ=\n'
  link = str(base64.decodebytes(link), 'utf-8')
  path = get_file(fname=os.path.basename(link),
                  origin=link,
                  outdir=get_datasetpath(root='~'))
  return MmapDict(path=path, read_only=True)
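
A usage sketch based on the key/value format documented above; the key shown is the docstring's own example:

sad = load_lre_sad()
segments = sad['LDC2017E23/data/eval/lre17_lqoyrygc.sph']  # e.g. [(1.99, 3.38), (8.78, 16.41)]
print(segments)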
Example #13
def load_lre_sad():
    """
  key: 'LDC2017E23/data/eval/lre17_lqoyrygc.sph'
  value: [(1.99, 3.38), (8.78, 16.41)] (in second)
  """
    link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL2xyZV9zYWQ=\n'
    link = str(base64.decodebytes(link), 'utf-8')
    path = get_file(fname=os.path.basename(link),
                    origin=link,
                    outdir=get_datasetpath(root='~'))
    return MmapDict(path=path, read_only=True)
Example #14
File: loaders.py Project: imito/odin
def load_voxceleb_list():
  link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL3ZveGNlbGViX2xpc3RzLnppcA==\n'
  link = str(base64.decodebytes(link), 'utf-8')
  ds_path = get_datasetpath(name='voxceleb_lists', root='~',
                            is_folder=False, override=False)
  if not os.path.exists(ds_path):
    path = get_file(fname=os.path.basename(link),
                    origin=link,
                    outdir=get_datasetpath(root='~'))
    unzip_folder(zip_path=path, out_path=os.path.dirname(path),
                 remove_zip=True)
  return Dataset(ds_path, read_only=True)
Example #15
def load_glove(ndim=100):
    """ Automaticall load a MmapDict which contains the mapping
        (word -> [vector])
    where vector is the embedding vector with given `ndim`.
    """
    ndim = int(ndim)
    if ndim not in (50, 100, 200, 300):
        raise ValueError('Only support 50, 100, 200, 300 dimensions.')
    fname = 'glove.6B.%dd' % ndim
    link = 'https://s3.amazonaws.com/ai-datasets/%s' % fname
    embedding = get_file(fname, link)
    return MmapDict(embedding, read_only=True)
Example #16
def load_glove(ndim=100):
    """ Automaticall load a MmapDict which contains the mapping
      (word -> [vector])
  where vector is the embedding vector with given `ndim`.
  """
    ndim = int(ndim)
    if ndim not in (50, 100, 200, 300):
        raise ValueError('Only support 50, 100, 200, 300 dimensions.')
    link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL2dsb3ZlLjZCLiVkZA==\n'
    link = str(base64.decodebytes(link) % ndim, 'utf-8')
    fname = os.path.basename(link)
    embedding = get_file(fname, link, outdir=get_datasetpath(root='~'))
    return MmapDict(embedding, read_only=True)
Example #17
File: loaders.py Project: imito/odin
def load_glove(ndim=100):
  """ Automaticall load a MmapDict which contains the mapping
      (word -> [vector])
  where vector is the embedding vector with given `ndim`.
  """
  ndim = int(ndim)
  if ndim not in (50, 100, 200, 300):
    raise ValueError('Only support 50, 100, 200, 300 dimensions.')
  link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL2dsb3ZlLjZCLiVkZA==\n'
  link = str(base64.decodebytes(link) % ndim, 'utf-8')
  fname = os.path.basename(link)
  embedding = get_file(fname, link, outdir=get_datasetpath(root='~'))
  return MmapDict(embedding, read_only=True)
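
A hedged lookup sketch, assuming the returned MmapDict supports the usual mapping protocol and uses plain word strings as keys:

emb = load_glove(ndim=100)
if 'language' in emb:  # whether this word is in the vocabulary is an assumption
    vec = emb['language']
    print(len(vec))  # expected to equal ndim (100)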
Example #18
File: loaders.py Project: imito/odin
def load_sre_list():
  link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL1NSRV9GSUxFUy56aXA=\n'
  link = str(base64.decodebytes(link), 'utf-8')
  ds_path = get_datasetpath(name='SRE_FILES', root='~',
                            is_folder=False, override=False)
  if os.path.exists(ds_path) and len(os.listdir(ds_path)) != 24:
    shutil.rmtree(ds_path)
  if not os.path.exists(ds_path):
    path = get_file(fname=os.path.basename(link),
                    origin=link,
                    outdir=get_datasetpath(root='~'))
    unzip_folder(zip_path=path, out_path=ds_path, remove_zip=True)
  return Dataset(ds_path, read_only=True)
Example #19
  def load_fsdd(self):
    r""" Free Spoken Digit Dataset
      A simple audio/speech dataset consisting of recordings of spoken digits
      in wav files at 8kHz. The recordings are trimmed so that they have near
      minimal silence at the beginnings and ends.

    Sample rate: 8,000

    Reference:
      Link: https://github.com/Jakobovski/free-spoken-digit-dataset
    """
    LINK = "https://github.com/Jakobovski/free-spoken-digit-dataset/archive/v1.0.8.zip"
    MD5 = "471b0df71a914629e2993300c1ccf33f"
    save_path = os.path.join(self.save_path, 'FSDD')
    if not os.path.exists(save_path):
      os.mkdir(save_path)
    # ====== download zip dataset ====== #
    if md5_checksum(''.join(sorted(os.listdir(save_path)))) != MD5:
      zip_path = get_file(fname='FSDD.zip',
                          origin=LINK,
                          outdir=save_path,
                          verbose=True)
      try:
        with ZipFile(zip_path, mode='r', compression=ZIP_DEFLATED) as zf:
          wav_files = [name for name in zf.namelist() if '.wav' == name[-4:]]
          for name in wav_files:
            data = zf.read(name)
            name = os.path.basename(name)
            with open(os.path.join(save_path, name), 'wb') as f:
              f.write(data)
      finally:
        os.remove(zip_path)
    # ====== get all records ====== #
    all_name = os.listdir(save_path)
    all_files = sorted([os.path.join(save_path, name) for name in all_name])
    all_speakers = list(set(i.split('_')[1] for i in all_name))
    # ====== splitting train, test ====== #
    rand = np.random.RandomState(seed=self.seed)
    rand.shuffle(all_speakers)
    train_spk = all_speakers[:-1]
    test_spk = all_speakers[-1:]
    train_files = [
        i for i in all_files if os.path.basename(i).split('_')[1] in train_spk
    ]
    test_files = [
        i for i in all_files if os.path.basename(i).split('_')[1] in test_spk
    ]
    rand.shuffle(train_files)
    rand.shuffle(test_files)
    return train_files, test_files
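
A usage sketch mirroring the call pattern shown in the `load_command` docstring of the next example, adapted to FSDD's 8 kHz sample rate; the constructor arguments are assumptions:

ds = AudioFeatureLoader(sample_rate=8000,
                        frame_length=int(0.025 * 8000),
                        frame_step=int(0.005 * 8000))
train_files, test_files = ds.load_fsdd()
train = ds.create_dataset(train_files, max_length=40, return_path=True)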
Example #20
  def load_command(self):
    r""" Warden P. Speech Commands: A public dataset for single-word speech
      recognition, 2017. Available from
      http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz

    Sample rate: 16,000

    Example:
      ds = AudioFeatureLoader(sample_rate=16000,
                              frame_length=int(0.025 * 16000),
                              frame_step=int(0.005 * 16000))
      train, valid, test = ds.load_command()
      train = ds.create_dataset(train, max_length=40, return_path=True)
    """
    LINK = "http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz"
    MD5 = "a08eb256cea8cbb427c6c0035fffd881"
    save_path = os.path.join(self.save_path, 'speech_commands')
    if not os.path.exists(save_path):
      os.mkdir(save_path)
    audio_path = os.path.join(save_path, 'audio')
    audio_files = sorted(
        get_all_files(audio_path, filter_func=lambda x: '.wav' == x[-4:]))
    md5 = md5_checksum(''.join([os.path.basename(i) for i in audio_files]))
    # ====== Download and extract the data ====== #
    if md5 != MD5:
      zip_path = get_file(fname='speech_commands_v0.01.tar.gz',
                          origin=LINK,
                          outdir=save_path,
                          verbose=True)
      with tarfile.open(zip_path, 'r:gz') as tar:
        tar.extractall(audio_path)
    # ====== processing the audio file list ====== #
    audio_files = [i for i in audio_files if '_background_noise_' not in i]
    with open(os.path.join(audio_path, 'validation_list.txt'), 'r') as f:
      valid_list = {i.strip(): 1 for i in f}
    with open(os.path.join(audio_path, 'testing_list.txt'), 'r') as f:
      test_list = {i.strip(): 1 for i in f}
    train_files = []
    valid_files = []
    test_files = []
    for f in audio_files:
      name = '/'.join(f.split('/')[-2:])
      if name in valid_list:
        valid_files.append(f)
      elif name in test_list:
        test_files.append(f)
      else:
        train_files.append(f)
    return train_files, valid_files, test_files
Example #21
def load_parameters(clazz):
   # ====== all path ====== #
   name = clazz.__name__ + '.zip'
   path = os.path.join(base64.decodebytes(Model.ORIGIN).decode(), name)
   param_path = get_datasetpath(name=clazz.__name__, override=False)
   zip_path = os.path.join(Model.BASE_DIR, name)
   # ====== get params files ====== #
   if not os.path.exists(param_path) or \
   len(os.listdir(param_path)) == 0:
     get_file(name, origin=path, outdir=Model.BASE_DIR)
     zf = ZipFile(zip_path, mode='r', compression=ZIP_DEFLATED)
     zf.extractall(path=Model.BASE_DIR)
     zf.close()
     # check that the archive was properly unzipped
     if not os.path.exists(param_path) or \
     len(os.listdir(param_path)) == 0:
       raise RuntimeError("Zip file at path:%s is not proper unzipped, "
           "cannot find downloaded parameters at path: %s" %
           (zip_path, param_path))
     else:
       os.remove(zip_path)
   # ====== create and return the params dataset ====== #
   ds = Dataset(param_path, read_only=True)
   return ds
Example #22
def load_voxceleb_list():
    link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL3ZveGNlbGViX2xpc3RzLnppcA==\n'
    link = str(base64.decodebytes(link), 'utf-8')
    ds_path = get_datasetpath(name='voxceleb_lists',
                              root='~',
                              is_folder=False,
                              override=False)
    if not os.path.exists(ds_path):
        path = get_file(fname=os.path.basename(link),
                        origin=link,
                        outdir=get_datasetpath(root='~'))
        unzip_folder(zip_path=path,
                     out_path=os.path.dirname(path),
                     remove_zip=True)
    return Dataset(ds_path, read_only=True)
Example #23
def load_sre_list():
    link = b'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tL2FpLWRhdGFzZXRzL1NSRV9GSUxFUy56aXA=\n'
    link = str(base64.decodebytes(link), 'utf-8')
    ds_path = get_datasetpath(name='SRE_FILES',
                              root='~',
                              is_folder=False,
                              override=False)
    if os.path.exists(ds_path) and len(os.listdir(ds_path)) != 24:
        shutil.rmtree(ds_path)
    if not os.path.exists(ds_path):
        path = get_file(fname=os.path.basename(link),
                        origin=link,
                        outdir=get_datasetpath(root='~'))
        unzip_folder(zip_path=path, out_path=ds_path, remove_zip=True)
    return Dataset(ds_path, read_only=True)
Example #24
def load_commands_wav():
    path = 'https://s3.amazonaws.com/ai-datasets/commands_wav.zip'
    datapath = get_file('commands_wav.zip', path)
    try:
        outpath = datapath.replace('.zip', '')
        if os.path.exists(outpath):
            shutil.rmtree(outpath)
        zf = ZipFile(datapath, mode='r', compression=ZIP_DEFLATED)
        zf.extractall(path=outpath + '/../')
        zf.close()
    except Exception:
        # remove downloaded zip files
        os.remove(datapath)
        import traceback
        traceback.print_exc()
    return outpath
Example #25
def load_imdb(nb_words=None, maxlen=None):
    """ The preprocessed imdb dataset with following configuraiton:
     - nb_words=88587
     - length=2494
     - NO skip for any top popular word
     - Word_IDX=1 for beginning of sequences
     - Word_IDX=2 for ignored word (OOV)
     - Other word start from 3
     - padding='pre' with value=0
    """
    path = 'https://s3.amazonaws.com/ai-datasets/imdb.zip'
    datapath = get_file('imdb', path)
    ds = _load_data_from_path(datapath)
    X_train, y_train, X_test, y_test = \
        ds['X_train'], ds['y_train'], ds['X_test'], ds['y_test']
    # create new data with new configuration
    if maxlen is not None or nb_words is not None:
        nb_words = max(min(88587, nb_words), 3)
        path = ds.path + '_tmp'
        if os.path.exists(path):
            shutil.rmtree(path)
        ds = Dataset(path)
        # preprocess data
        if maxlen is not None:
            # for X_train
            _X, _y = [], []
            for i, j in zip(X_train[:], y_train[:]):
                if i[-maxlen] == 0 or i[-maxlen] == 1:
                    _X.append([k if k < nb_words else 2 for k in i[-maxlen:]])
                    _y.append(j)
            X_train = np.array(_X, dtype=X_train.dtype)
            y_train = np.array(_y, dtype=y_train.dtype)
            # for X_test
            _X, _y = [], []
            for i, j in zip(X_test[:], y_test[:]):
                if i[-maxlen] == 0 or i[-maxlen] == 1:
                    _X.append([k if k < nb_words else 2 for k in i[-maxlen:]])
                    _y.append(j)
            X_test = np.array(_X, dtype=X_test.dtype)
            y_test = np.array(_y, dtype=y_test.dtype)
        ds['X_train'] = X_train
        ds['y_train'] = y_train
        ds['X_test'] = X_test
        ds['y_test'] = y_test
        ds.flush()
    return ds
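
A short usage sketch; the dataset keys come directly from the function body above:

ds = load_imdb(nb_words=10000, maxlen=200)
X_train, y_train = ds['X_train'], ds['y_train']
X_test, y_test = ds['X_test'], ds['y_test']
print(X_train.shape, X_test.shape)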
Example #26
def load_20newsgroup():
    link = 'https://s3.amazonaws.com/ai-datasets/news20'
    dataset = get_file('news20', link)
    return MmapDict(dataset, read_only=True)
Example #27
def load_iris():
    path = "https://s3.amazonaws.com/ai-datasets/iris.zip"
    datapath = get_file('iris', path)
    return _load_data_from_path(datapath)
Example #28
def load_digit_audio():
    path = 'https://s3.amazonaws.com/ai-datasets/digit.zip'
    name = 'digit'
    datapath = get_file(name, path)
    return _load_data_from_path(datapath)