import fnmatch
import os
import pickle

import numpy as np
import tensorflow as tf

# FLAGS and PathError are defined elsewhere in the original script;
# a minimal stand-in for PathError is assumed here so the excerpt runs.
FLAGS = tf.app.flags.FLAGS


class PathError(Exception):
    pass


def convert_cifar(dataset_dir, validation_ratio,
                  training_pattern, test_pattern, label_pattern,
                  label_key, labelname_key):
    print('Read cifar dataset and convert it to .tfrecord format')
    # check paths
    if not tf.gfile.Exists(dataset_dir):
        raise PathError('dataset dir [%s] does not exist' % dataset_dir)
    output_dir = '%s/tfrecord' % dataset_dir
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    # write the labels file
    class_names = _read_labelfile(dataset_dir, label_pattern, labelname_key)
    label_filename = os.path.join(output_dir, 'labels.txt')
    _write_labels(class_names, label_filename)
    num_classes = len(class_names)

    # process the training and validation data
    images, labels = _collect_data(dataset_dir, training_pattern, label_key)
    training_images, training_labels, validation_images, validation_labels \
        = _shuffle_and_split_data(images, labels, validation_ratio)
    training_record_name = os.path.join(output_dir, 'train.tfrecord')
    _write_tfrecord(training_images, training_labels, training_record_name)
    validation_record_name = os.path.join(output_dir, 'validation.tfrecord')
    _write_tfrecord(validation_images, validation_labels, validation_record_name)

    # process the test data
    test_images, test_labels = _collect_data(dataset_dir, test_pattern, label_key)
    test_record_name = os.path.join(output_dir, 'test.tfrecord')
    _write_tfrecord(test_images, test_labels, test_record_name)
    print('Finished converting the cifar%d dataset!' % num_classes)
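# _write_labels, _shuffle_and_split_data and _write_tfrecord are not shown in
# this excerpt. The sketches below only illustrate the contract convert_cifar
# relies on; they assume CIFAR's flat (N, 3072) uint8 rows and the TF 1.x
# tf.python_io API, and are not the repo's actual implementations.
def _write_labels(class_names, filename):
    # class_names: dict mapping integer label -> class name string
    with tf.gfile.Open(filename, 'w') as f:
        for label in sorted(class_names):
            f.write('%d:%s\n' % (label, class_names[label]))


def _shuffle_and_split_data(images, labels, validation_ratio):
    # shuffle images and labels together, then carve off the validation split
    indices = np.random.permutation(images.shape[0])
    images, labels = images[indices], labels[indices]
    num_validation = int(images.shape[0] * validation_ratio)
    return (images[num_validation:], labels[num_validation:],
            images[:num_validation], labels[:num_validation])


def _write_tfrecord(images, labels, record_name):
    # serialize each (image, label) pair as a tf.train.Example
    with tf.python_io.TFRecordWriter(record_name) as writer:
        for image, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[image.tobytes()])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(
                    value=[int(label)])),
            }))
            writer.write(example.SerializeToString())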
def _read_labelfile(dataset_dir, label_pattern, labelname_key):
    label_files = fnmatch.filter(os.listdir(dataset_dir), label_pattern)
    if len(label_files) == 0:
        raise PathError('no label file in %s' % dataset_dir)
    label_filename = os.path.join(dataset_dir, label_files[0])
    with tf.gfile.Open(label_filename, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
    labels_bytes = data[labelname_key]
    labels_str = [label.decode('utf-8') for label in labels_bytes]
    return dict(enumerate(labels_str))
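# Usage illustration (not part of the original script): with the extracted
# CIFAR-10 python archive, the meta file is 'batches.meta' and the class
# names sit under the b'label_names' key (CIFAR-100 ships a file named
# 'meta' with b'fine_label_names' instead):
#
#   class_names = _read_labelfile(dataset_dir, '*.meta', b'label_names')
#   # -> {0: 'airplane', 1: 'automobile', ..., 9: 'truck'}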
def convert_data():
    if FLAGS.dataset == 'cifar10':
        dataset_dir = '/home/cideep/Work/tensorflow/datasets/cifar-10'
        convert_cifar10(dataset_dir, FLAGS.validation_ratio)
    elif FLAGS.dataset == 'cifar100':
        dataset_dir = '/home/cideep/Work/tensorflow/datasets/cifar-100'
        convert_cifar100(dataset_dir, FLAGS.validation_ratio)
    # elif FLAGS.dataset == 'voc2012':
    #     convert_voc2012(FLAGS.input_dir, FLAGS.validation_ratio)
    else:
        raise PathError('%s: not a supported dataset' % FLAGS.dataset)
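# convert_cifar10 and convert_cifar100 are assumed to be thin wrappers that
# fill in the file patterns and pickle keys of each archive. The sketches
# below use the standard file names of the python-version CIFAR archives;
# the actual wrappers in the repo may differ.
def convert_cifar10(dataset_dir, validation_ratio):
    convert_cifar(dataset_dir, validation_ratio,
                  training_pattern='data_batch_*', test_pattern='test_batch*',
                  label_pattern='*.meta', label_key=b'labels',
                  labelname_key=b'label_names')


def convert_cifar100(dataset_dir, validation_ratio):
    convert_cifar(dataset_dir, validation_ratio,
                  training_pattern='train*', test_pattern='test*',
                  label_pattern='meta*', label_key=b'fine_labels',
                  labelname_key=b'fine_label_names')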
def _collect_data(dataset_dir, file_pattern, label_key):
    src_files = fnmatch.filter(os.listdir(dataset_dir), file_pattern)
    if len(src_files) == 0:
        raise PathError('no data file such as %s in %s' % (file_pattern, dataset_dir))

    total_images = np.array([])
    total_labels = np.array([])
    for filename in src_files:
        filename = os.path.join(dataset_dir, filename)
        print(' collect_data/srcfile name:', filename)
        with tf.gfile.Open(filename, 'rb') as f:
            data = pickle.load(f, encoding='bytes')
        images = data[b'data']
        labels = np.asarray(data[label_key])
        print(' collect_data/image_shape:', images.shape, 'label_shape:', labels.shape)
        if total_images.size == 0:
            total_images = images
        else:
            total_images = np.concatenate((total_images, images), axis=0)
        if total_labels.size == 0:
            total_labels = labels
        else:
            total_labels = np.concatenate((total_labels, labels), axis=0)

    if total_images.shape[0] == 0 or total_images.shape[0] != total_labels.shape[0]:
        raise PathError('invalid image or label size: %d != %d'
                        % (total_images.shape[0], total_labels.shape[0]))
    print('collect_data/total_image_shape:', total_images.shape,
          'total_label_shape:', total_labels.shape)
    return total_images, total_labels
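# Note on the returned layout: data[b'data'] rows are kept flat here, so
# total_images has shape (N, 3072) with channel-major bytes per row
# (1024 R, then 1024 G, then 1024 B). To view row i as an HWC image:
#
#   image_hwc = total_images[i].reshape(3, 32, 32).transpose(1, 2, 0)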
def main(_):
    if not tf.gfile.Exists(FLAGS.dataset_path):
        raise PathError('dataset dir does not exist')
    image_dir = os.path.join(FLAGS.dataset_path, FLAGS.src_image_dir)
    image_list = list_images(image_dir)
    print('%d images are listed' % len(image_list))

    # annotations = DataFrame(category, filename, xmin, xmax, ymin, ymax)
    src_annot_path = os.path.join(FLAGS.dataset_path, FLAGS.src_annot_dir)
    annotations = read_annotations(src_annot_path, image_list)

    dst_annot_path = os.path.join(FLAGS.dataset_path, FLAGS.dst_annot_dir)
    if not tf.gfile.Exists(dst_annot_path):
        tf.gfile.MakeDirs(dst_annot_path)
    dst_annot_name = os.path.join(dst_annot_path, '%s.csv' % FLAGS.dst_annot_dir)
    annotations.to_csv(dst_annot_name, sep='\t')
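# list_images and read_annotations are defined elsewhere in this repo; the
# sketches below only illustrate the contract main() relies on. They assume
# VOC-style XML annotations, one file per image named after the image stem,
# which is an assumption rather than the repo's actual format.
import xml.etree.ElementTree as ET

import pandas as pd


def list_images(image_dir):
    # collect image filenames under image_dir
    patterns = ('*.jpg', '*.jpeg', '*.png')
    files = os.listdir(image_dir)
    return sorted(f for p in patterns for f in fnmatch.filter(files, p))


def read_annotations(annot_dir, image_list):
    # parse one VOC-style XML per image into rows of
    # (category, filename, xmin, xmax, ymin, ymax)
    rows = []
    for image_name in image_list:
        annot_name = os.path.join(annot_dir,
                                  os.path.splitext(image_name)[0] + '.xml')
        root = ET.parse(annot_name).getroot()
        for obj in root.findall('object'):
            box = obj.find('bndbox')
            rows.append({'category': obj.find('name').text,
                         'filename': image_name,
                         'xmin': int(box.find('xmin').text),
                         'xmax': int(box.find('xmax').text),
                         'ymin': int(box.find('ymin').text),
                         'ymax': int(box.find('ymax').text)})
    return pd.DataFrame(rows, columns=['category', 'filename',
                                       'xmin', 'xmax', 'ymin', 'ymax'])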