def load_data():
    train_filename = base.maybe_download('notMNIST_large.tar.gz', dataset_dir,
                                         source_url)
    test_filename = base.maybe_download('notMNIST_small.tar.gz', dataset_dir,
                                        source_url)
    train_folders = maybe_extract(train_filename)
    test_folders = maybe_extract(test_filename)
    return train_folders, test_folders
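A minimal call sketch for the loader above. dataset_dir and source_url are module-level settings in the original script; the values below are illustrative placeholders, not the project's real configuration.

# Hypothetical module-level configuration assumed by load_data():
dataset_dir = './notMNIST_data'           # placeholder directory
source_url = 'https://example.com/data/'  # placeholder URL

train_folders, test_folders = load_data()
print(len(train_folders), 'train classes;', len(test_folders), 'test classes')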
Example #2
File: mnist.py Project: scottdu/BigDL
def read_data_sets(train_dir, data_type="train"):
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

    if data_type == "train":
        local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                         SOURCE_URL + TRAIN_IMAGES)
        with open(local_file, 'rb') as f:
            train_images = extract_images(f)

        local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                         SOURCE_URL + TRAIN_LABELS)
        with open(local_file, 'rb') as f:
            train_labels = extract_labels(f)
        return train_images, train_labels

    else:
        local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                         SOURCE_URL + TEST_IMAGES)
        with open(local_file, 'rb') as f:
            test_images = extract_images(f)

        local_file = base.maybe_download(TEST_LABELS, train_dir,
                                         SOURCE_URL + TEST_LABELS)
        with open(local_file, 'rb') as f:
            test_labels = extract_labels(f)
        return test_images, test_labels
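A usage sketch for the BigDL-style reader above, assuming SOURCE_URL and the extract_* helpers are defined in the same module; the directory name is illustrative.

train_images, train_labels = read_data_sets("/tmp/mnist", data_type="train")
test_images, test_labels = read_data_sets("/tmp/mnist", data_type="test")
# For standard MNIST this yields 60000 training and 10000 test samples,
# e.g. train_images.shape == (60000, 28, 28, 1)
print(train_images.shape, train_labels.shape)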
Example #3
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32):
  if fake_data:

    def fake():
      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  train_images = extract_images(local_file)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  train_labels = extract_labels(local_file, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  test_images = extract_images(local_file)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  test_labels = extract_labels(local_file, one_hot=one_hot)

  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

  train = DataSet(train_images, train_labels, dtype=dtype)
  validation = DataSet(validation_images, validation_labels, dtype=dtype)
  test = DataSet(test_images, test_labels, dtype=dtype)

  return base.Datasets(train=train, validation=validation, test=test)
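A call sketch for this Datasets-returning variant, assuming the surrounding module's DataSet exposes the usual next_batch API.

mnist = read_data_sets("/tmp/mnist", one_hot=True)
print(mnist.train.num_examples)       # 55000 after the 5000-image validation split
print(mnist.validation.num_examples)  # 5000
batch_xs, batch_ys = mnist.train.next_batch(100)  # assumes DataSet.next_batch exists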
Example #4
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32):
    if fake_data:

        def fake():
            return DataSet([], [],
                           fake_data=True,
                           one_hot=one_hot,
                           dtype=dtype)

        train = fake()
        validation = fake()
        test = fake()
        return base.Datasets(train=train, validation=validation, test=test)

    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    VALIDATION_SIZE = 5000

    # SOURCE_URL is assumed to be defined at module level, as in the other
    # examples; maybe_download requires an explicit source URL.
    local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                     SOURCE_URL + TRAIN_IMAGES)
    train_images = extract_images(local_file)

    local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                     SOURCE_URL + TRAIN_LABELS)
    train_labels = extract_labels(local_file, one_hot=one_hot)

    local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                     SOURCE_URL + TEST_IMAGES)
    test_images = extract_images(local_file)

    local_file = base.maybe_download(TEST_LABELS, train_dir,
                                     SOURCE_URL + TEST_LABELS)
    test_labels = extract_labels(local_file, one_hot=one_hot)

    validation_images = train_images[:VALIDATION_SIZE]
    validation_labels = train_labels[:VALIDATION_SIZE]
    train_images = train_images[VALIDATION_SIZE:]
    train_labels = train_labels[VALIDATION_SIZE:]

    train = DataSet(train_images, train_labels, dtype=dtype)
    validation = DataSet(validation_images, validation_labels, dtype=dtype)
    test = DataSet(test_images, test_labels, dtype=dtype)

    return base.Datasets(train=train, validation=validation, test=test)
Example #5
def download_news20(dest_dir):
    file_name = "news20.tar.gz"
    file_abs_path = base.maybe_download(file_name, dest_dir, NEWS20_URL)
    tar = tarfile.open(file_abs_path, "r:gz")
    extracted_to = os.path.join(dest_dir, "20_newsgroup")
    print("Extracting %s to %s" % (file_abs_path, extracted_to))
    tar.extractall(dest_dir)
    tar.close()
    return extracted_to
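A hedged usage sketch, assuming NEWS20_URL is defined at module level as in the original project.

news20_dir = download_news20("/tmp/news20")  # path is illustrative
print("News20 corpus available under", news20_dir)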
Example #6
def download_glove_w2v(dest_dir):
    file_name = "glove.6B.zip"
    file_abs_path = base.maybe_download(file_name, dest_dir, GLOVE_URL)
    import zipfile
    zip_ref = zipfile.ZipFile(file_abs_path, 'r')
    extracted_to = os.path.join(dest_dir, "glove.6B")
    print("Extracting %s to %s" % (file_abs_path, extracted_to))
    zip_ref.extractall(extracted_to)
    zip_ref.close()
    return extracted_to
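A usage sketch for the GloVe downloader, assuming GLOVE_URL points at the glove.6B archive; the 6B release ships 50d/100d/200d/300d text files.

glove_dir = download_glove_w2v("/tmp/glove")  # path is illustrative
path_100d = os.path.join(glove_dir, "glove.6B.100d.txt")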
Example #7
def download_news20(dest_dir):
    file_name = "20news-19997.tar.gz"
    file_abs_path = base.maybe_download(file_name, dest_dir, NEWS20_URL)
    extracted_to = os.path.join(dest_dir, "20_newsgroups")
    if not os.path.exists(extracted_to):
        print("Extracting %s to %s" % (file_abs_path, extracted_to))
        # Open the archive only when extraction is needed, and close it
        # unconditionally via the context manager.
        with tarfile.open(file_abs_path, "r:gz") as tar:
            tar.extractall(dest_dir)
    return extracted_to
Example #8
def download_glove_w2v(dest_dir):
    file_name = "glove.6B.zip"
    file_abs_path = base.maybe_download(file_name, dest_dir, GLOVE_URL)
    import zipfile
    extracted_to = os.path.join(dest_dir, "glove.6B")
    if not os.path.exists(extracted_to):
        print("Extracting %s to %s" % (file_abs_path, extracted_to))
        # Open the archive only when extraction is needed, and close it
        # unconditionally via the context manager.
        with zipfile.ZipFile(file_abs_path, 'r') as zip_ref:
            zip_ref.extractall(extracted_to)
    return extracted_to
Example #9
def read_data_sets(train_dir, data_type="train"):
    """
    Parse or download mnist data if train_dir is empty.
    :param train_dir: The directory storing the mnist data
    :param data_type: Reading training set or testing set.
           It can be either "train" or "test"
    :return: (ndarray, ndarray) representing (features, labels)
    """
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

    if data_type == "train":
        local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                         SOURCE_URL + TRAIN_IMAGES)
        with open(local_file, 'rb') as f:
            train_images = extract_images(f)

        local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                         SOURCE_URL + TRAIN_LABELS)
        with open(local_file, 'rb') as f:
            train_labels = extract_labels(f)
        return train_images, train_labels

    else:
        local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                         SOURCE_URL + TEST_IMAGES)
        with open(local_file, 'rb') as f:
            test_images = extract_images(f)

        local_file = base.maybe_download(TEST_LABELS, train_dir,
                                         SOURCE_URL + TEST_LABELS)
        with open(local_file, 'rb') as f:
            test_labels = extract_labels(f)
        return test_images, test_labels
Example #10
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=tf.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  if not source_url:  # empty string check
    source_url = DEFAULT_SOURCE_URL

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   source_url + TRAIN_IMAGES)
  with tf.gfile.Open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   source_url + TRAIN_LABELS)
  with tf.gfile.Open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   source_url + TEST_IMAGES)
  with tf.gfile.Open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   source_url + TEST_LABELS)
  with tf.gfile.Open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError('Validation size should be between 0 and {}. Received: {}.'
                     .format(len(train_images), validation_size))

  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)
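This is the most configurable variant; a sketch of how the extra parameters are typically exercised (directory and values are illustrative).

mnist = read_data_sets("/tmp/mnist",
                       one_hot=True,
                       reshape=False,          # keep (N, 28, 28, 1) instead of flattening to (N, 784)
                       validation_size=10000,
                       seed=42)                # reproducible shuffling inside DataSet
print(mnist.train.images.shape)  # (50000, 28, 28, 1) with these settings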
Example #11
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000):
    if fake_data:

        def fake():
            return DataSet([], [],
                           fake_data=True,
                           one_hot=one_hot,
                           dtype=dtype)

        train = fake()
        validation = fake()
        test = fake()
        return base.Datasets(train=train, validation=validation, test=test)

    TRAIN_IMAGES = 'stk_train_img.bin.gz'
    TRAIN_LABELS = 'stk_train_lbl.bin.gz'
    TEST_IMAGES = 'stk_test_img.bin.gz'
    TEST_LABELS = 'stk_test_lbl.bin.gz'

    local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                     SOURCE_URL + TRAIN_IMAGES)
    with open(local_file, 'rb') as f:
        train_images = extract_images(f)

    local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                     SOURCE_URL + TRAIN_LABELS)
    with open(local_file, 'rb') as f:
        train_labels = extract_labels(f, one_hot=one_hot)

    local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                     SOURCE_URL + TEST_IMAGES)
    with open(local_file, 'rb') as f:
        test_images = extract_images(f)

    local_file = base.maybe_download(TEST_LABELS, train_dir,
                                     SOURCE_URL + TEST_LABELS)
    with open(local_file, 'rb') as f:
        test_labels = extract_labels(f, one_hot=one_hot)

    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'.format(
                len(train_images), validation_size))

    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
    validation = DataSet(validation_images,
                         validation_labels,
                         dtype=dtype,
                         reshape=reshape)
    test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
    return base.Datasets(train=train, validation=validation, test=test)
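A call sketch for this custom stk_* variant; the archives are project-specific, so SOURCE_URL must actually host them for the call to succeed.

data = read_data_sets("/tmp/stk_data", one_hot=True, validation_size=1000)
print(data.train.num_examples, data.validation.num_examples, data.test.num_examples)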
Example #12
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000):
  if fake_data:
    def fake():
      return DataSet([], [], [], fake_data=True, one_hot=one_hot, dtype=dtype)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  with open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  with open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  # Hardcoded, machine-specific paths for the barcode data (as in the original project):
  TRAIN_FILEPATH = '/Users/billvarcho/Documents/Research/MNIST/train/DATA/'
  VALIDATION_FILEPATH = '/Users/billvarcho/Documents/Research/MNIST/validation/DATA/'

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  with open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  with open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  TEST_FILEPATH = '/Users/billvarcho/Documents/Research/MNIST/test/DATA/'  # hardcoded, machine-specific
  test_barcodes = extract_barcodes(TEST_FILEPATH)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))


  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  validation_barcodes = extract_barcodes(VALIDATION_FILEPATH)
  # TODO: derive the validation barcodes by slicing the training barcodes,
  # e.g. validation_barcodes = train_barcodes[:validation_size]

  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]
  train_barcodes = extract_barcodes(TRAIN_FILEPATH)

  train = DataSet(train_images, train_labels, train_barcodes,
                  dtype=dtype, reshape=reshape)

  validation = DataSet(validation_images,
                       validation_labels,
                       validation_barcodes,
                       dtype=dtype,
                       reshape=reshape)
  test = DataSet(test_images, test_labels, test_barcodes, dtype=dtype, reshape=reshape)

  return base.Datasets(train=train, validation=validation, test=test)
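A call sketch for the barcode-augmented variant; because extract_barcodes and the hardcoded FILEPATHs above are specific to the original author's machine, this runs only where that data layout exists.

datasets = read_data_sets("/tmp/mnist", one_hot=True)
images, labels, barcodes = datasets.train.next_batch(64)  # assumes the 3-tensor DataSet yields a triple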