Example #1
def convert_cifar(dataset_dir, validation_ratio, training_pattern,
                  test_pattern, label_pattern, label_key, labelname_key):
    print('Read cifar dataset and convert it to .tfrecord format')
    # check paths
    if not tf.gfile.Exists(dataset_dir):
        raise PathError('dataset dir [%s] does not exist' % dataset_dir)
    output_dir = '%s/tfrecord' % dataset_dir
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    # write the labels file
    class_names = _read_labelfile(dataset_dir, label_pattern, labelname_key)
    label_filename = os.path.join(output_dir, 'labels.txt')
    _write_labels(class_names, label_filename)
    num_classes = len(class_names)

    # process the training and validation data
    images, labels = _collect_data(dataset_dir, training_pattern, label_key)
    training_images, training_labels, validation_images, validation_labels \
        = _shuffle_and_split_data(images, labels, validation_ratio)

    training_record_name = os.path.join(output_dir, 'train.tfrecord')
    _write_tfrecord(training_images, training_labels, training_record_name)

    validation_record_name = os.path.join(output_dir, 'validation.tfrecord')
    _write_tfrecord(validation_images, validation_labels,
                    validation_record_name)

    # process the test data
    test_images, test_labels = _collect_data(dataset_dir, test_pattern,
                                             label_key)
    test_record_name = os.path.join(output_dir, 'test.tfrecord')
    _write_tfrecord(test_images, test_labels, test_record_name)

    print('Finished converting the cifar%d dataset!' % num_classes)
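The helpers referenced above (PathError, _write_labels, _shuffle_and_split_data,
_write_tfrecord) are not shown on this page. A minimal sketch of each, assuming a
uniform random shuffle and one tf.train.Example per image whose flattened uint8
pixel row is stored as raw bytes:

import numpy as np
import tensorflow as tf

class PathError(Exception):
    """Raised when an expected file or directory is missing."""

def _write_labels(class_names, label_filename):
    # Write one 'index:name' line per class.
    with tf.gfile.Open(label_filename, 'w') as f:
        for index, name in class_names.items():
            f.write('%d:%s\n' % (index, name))

def _shuffle_and_split_data(images, labels, validation_ratio):
    # Shuffle images and labels together, then hold out the first
    # validation_ratio fraction as the validation split.
    indices = np.random.permutation(images.shape[0])
    images, labels = images[indices], labels[indices]
    num_validation = int(images.shape[0] * validation_ratio)
    return (images[num_validation:], labels[num_validation:],
            images[:num_validation], labels[:num_validation])

def _write_tfrecord(images, labels, record_name):
    # Serialize each (image row, label) pair as a tf.train.Example.
    with tf.python_io.TFRecordWriter(record_name) as writer:
        for image, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                'image': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label)])),
            }))
            writer.write(example.SerializeToString())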
Example #2
def _read_labelfile(dataset_dir, label_pattern, labelname_key):
    label_files = fnmatch.filter(os.listdir(dataset_dir), label_pattern)
    if len(label_files) == 0:
        raise PathError('no label file in %s' % dataset_dir)
    label_filename = os.path.join(dataset_dir, label_files[0])
    with tf.gfile.Open(label_filename, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
        labels_bytes = data[labelname_key]
        labels_str = [label.decode('utf-8') for label in labels_bytes]
        return dict(enumerate(labels_str))
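For the CIFAR-10 python pickle distribution, the metadata file is batches.meta
and the class names sit under the key b'label_names'; a hypothetical call (the
path is a placeholder):

class_names = _read_labelfile('/path/to/cifar-10', 'batches.meta', b'label_names')
# {0: 'airplane', 1: 'automobile', ..., 9: 'truck'}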
Example #3
def convert_data():
  if FLAGS.dataset == 'cifar10':
    dataset_dir = '/home/cideep/Work/tensorflow/datasets/cifar-10'
    convert_cifar10(dataset_dir, FLAGS.validation_ratio)
  elif FLAGS.dataset == 'cifar100':
    dataset_dir = '/home/cideep/Work/tensorflow/datasets/cifar-100'
    convert_cifar100(dataset_dir, FLAGS.validation_ratio)
  # elif FLAGS.dataset == 'voc2012':
  #   convert_voc2012(FLAGS.input_dir, FLAGS.validation_ratio)
  else:
    raise PathError('%s: unsupported dataset' % FLAGS.dataset)
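The FLAGS object and the two conversion wrappers are defined elsewhere; a sketch
of what convert_data assumes, with convert_cifar10/convert_cifar100 written as
thin wrappers around convert_cifar (Example #1) using the file patterns and
pickle keys of the official CIFAR python distributions:

import tensorflow as tf

tf.app.flags.DEFINE_string('dataset', 'cifar10', 'dataset to convert')
tf.app.flags.DEFINE_float('validation_ratio', 0.1,
                          'fraction of training data held out for validation')
FLAGS = tf.app.flags.FLAGS

def convert_cifar10(dataset_dir, validation_ratio):
    convert_cifar(dataset_dir, validation_ratio, 'data_batch*', 'test_batch',
                  'batches.meta', b'labels', b'label_names')

def convert_cifar100(dataset_dir, validation_ratio):
    convert_cifar(dataset_dir, validation_ratio, 'train', 'test',
                  'meta', b'fine_labels', b'fine_label_names')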
Example #4
def _collect_data(dataset_dir, file_pattern, label_key):
    src_files = fnmatch.filter(os.listdir(dataset_dir), file_pattern)
    if len(src_files) == 0:
        raise PathError('no data file such as %s in %s' %
                        (file_pattern, dataset_dir))

    total_images = np.array([])
    total_labels = np.array([])

    for filename in src_files:
        filename = os.path.join(dataset_dir, filename)
        print('   collect_data/srcfile name:', filename)
        with tf.gfile.Open(filename, 'rb') as f:
            data = pickle.load(f, encoding='bytes')
        images = data[b'data']
        labels = np.asarray(data[label_key])
        print('   collect_data/image_shape:', images.shape, 'label_shape:',
              labels.shape)

        if total_images.size == 0:
            total_images = images
        else:
            total_images = np.concatenate((total_images, images), axis=0)

        if total_labels.size == 0:
            total_labels = labels
        else:
            total_labels = np.concatenate((total_labels, labels), axis=0)

    if (total_images.shape[0] == 0
            or total_images.shape[0] != total_labels.shape[0]):
        raise PathError('invalid image or label size: %d != %d'
                        % (total_images.shape[0], total_labels.shape[0]))
    print('collect_data/total_image_shape:', total_images.shape,
          'total_label_shape:', total_labels.shape)
    return total_images, total_labels
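A hypothetical call for the CIFAR-10 python distribution, whose five training
pickles are named data_batch_1 .. data_batch_5 and store labels under the key
b'labels' (the path is a placeholder):

images, labels = _collect_data('/path/to/cifar-10', 'data_batch*', b'labels')
# images: (50000, 3072) uint8, labels: (50000,)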
Example #5
def main(_):
  if not tf.gfile.Exists(FLAGS.dataset_path):
    raise PathError('dataset dir [%s] does not exist' % FLAGS.dataset_path)

  image_dir = os.path.join(FLAGS.dataset_path, FLAGS.src_image_dir)
  image_list = list_images(image_dir)
  print('%d images are listed' % len(image_list))

  # annotations = DataFrame(category, filename, xmin, xmax, ymin, ymax)
  src_annot_path = os.path.join(FLAGS.dataset_path, FLAGS.src_annot_dir)
  annotations = read_annotations(src_annot_path, image_list)

  dst_annot_path = os.path.join(FLAGS.dataset_path, FLAGS.dst_annot_dir)
  if not tf.gfile.Exists(dst_annot_path):
    tf.gfile.MakeDirs(dst_annot_path)
  dst_annot_name = os.path.join(dst_annot_path, '%s.csv' % FLAGS.dst_annot_dir)
  annotations.to_csv(dst_annot_name, sep='\t')
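The script presumably dispatches through the TF 1.x app runner; a minimal sketch
of the assumed entry point:

if __name__ == '__main__':
  tf.app.run()  # parses the flags, then calls main()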