コード例 #1
0
def get_experiment_fn(dataset_dir):
    """Build the experiment factory for the dataset stored in `dataset_dir`.

    The label file is read once, when this function is called; only the
    second element of its result (the class count) is kept and captured by
    the returned closure.

    Args:
        dataset_dir: Directory handed to the label-file reader and to both
            input functions.

    Returns:
        A callable `experiment_fn(run_config, hparams)` that creates a
        `tf.contrib.learn.Experiment`.
    """
    _labels, num_classes = dataset_utils.read_label_file(dataset_dir)

    def experiment_fn(run_config, hparams):
        """Create the Experiment; intended to be passed to learn_runner.

        Args:
            run_config: Configuration forwarded to the Estimator.
            hparams: Accepted for signature compatibility; not read here.

        Returns:
            A `tf.contrib.learn.Experiment` wired to the train and
            validation splits.
        """

        def _split_input_fn(split_name, is_training):
            # Bind everything except the arguments the framework supplies.
            return functools.partial(input_fn,
                                     dataset_dir=dataset_dir,
                                     split_name=split_name,
                                     is_training=is_training)

        training_inputs = _split_input_fn('train', True)
        validation_inputs = _split_input_fn('validation', False)

        estimator = tf.estimator.Estimator(
            model_fn=get_model_fn(num_classes),
            config=run_config)

        experiment_kwargs = dict(
            train_input_fn=training_inputs,
            eval_input_fn=validation_inputs,
            train_steps=None,  # None means: train forever
            eval_steps=VALIDATION_STEPS)
        return tf.contrib.learn.Experiment(estimator, **experiment_kwargs)

    return experiment_fn
コード例 #2
0
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading ImageNet.

  Args:
    split_name: A train/test split name; must be a key of `_SPLITS_TO_SIZES`.
    dataset_dir: The base directory of the dataset sources. The label file is
      also read from this directory.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted. Falls back to `_FILE_PATTERN` when empty or None.
    reader: The TensorFlow reader type. Defaults to `tf.TFRecordReader`.

  Returns:
    A `slim.dataset.Dataset` describing the requested split.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in _SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  # Raw features stored in each serialized example, with their defaults.
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature(
          (), tf.string, default_value='jpeg'),
      'image/class/label': tf.FixedLenFeature(
          [], dtype=tf.int64, default_value=-1),
      'image/class/text': tf.FixedLenFeature(
          [], dtype=tf.string, default_value=''),
  }

  # Items exposed to consumers and the handlers that decode them.
  items_to_handlers = {
      'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
      'label_text': slim.tfexample_decoder.Tensor('image/class/text'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  # The label file is read unconditionally (no existence check is made here),
  # so any error raised by the reader propagates to the caller.
  labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=_SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
コード例 #3
0
def main(_):
    """Run the network over one split and save the stacked probabilities.

    Reads `<dataset_dir>/<split>.tfrecord`, feeds batches through the network
    selected by `FLAGS.model`, accumulates the results into an array with one
    row per sample and `num_classes + 1` columns, and finally hands that array
    to `save_probabilities`.

    Args:
        _: Unused positional argument (the signature expected by the caller).
    """
    print("dataset dir", FLAGS.dataset_dir)
    # Only needed by the optional print_results debugging call below; the read
    # is kept so a missing label file is still reported up front.
    class_names = dataset_utils.read_label_file(FLAGS.dataset_dir)
    filename = os.path.join(FLAGS.dataset_dir, '%s.tfrecord' % FLAGS.split)

    splits_to_sizes, num_classes, image_shape, items_to_descriptions = (
        dataset_factory.dataset_config(FLAGS.dataset_name))

    output_data = np.zeros((splits_to_sizes[FLAGS.split], num_classes + 1))
    print(output_data.shape)

    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(FLAGS.model,
                                                 num_classes=num_classes,
                                                 is_training=False)
        processed_images, raw_images, labels = create_inputs(
            filename, FLAGS.batch_size, network_fn.default_image_size)
        probabilities = pass_network(processed_images, FLAGS.model, network_fn)

        sess = tf.Session()
        try:
            init_ops(sess)

            # Start input enqueue threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            plt.figure()
            np.set_printoptions(precision=3, suppress=True)

            step = 0
            try:
                while not coord.should_stop():
                    [np_images, np_labels, np_probabilities
                     ] = sess.run([raw_images, labels, probabilities])

                    # print_results(np_images, np_labels, np_probabilities, class_names)
                    output_data = stack_results(output_data, np_labels,
                                                np_probabilities, step,
                                                FLAGS.batch_size)
                    step += 1
                    if step % 100 == 0:
                        print("step:", step)

            except tf.errors.OutOfRangeError:
                # Raised by the input pipeline once the data is exhausted.
                print('Done evaluating for %d steps.' % step)
            except my_exceptions.GeneralError as ge:
                print('Done evaluating for %d steps.' % step, ge)
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()

            # Wait for threads to finish.
            coord.join(threads)
        finally:
            # Release the session even when an unexpected error escapes the
            # loop above; previously it was only closed on the normal path.
            sess.close()

    save_probabilities(output_data, FLAGS.output_dir, FLAGS.split)
コード例 #4
0
    def getSplit(self, split_name, nValidations=None):
        """Build the slim `Dataset` for `split_name`.

        Args:
            split_name: Name of the split; must be in `self.SPLITS_NAMES`.
            nValidations: Number of validation samples. When falsy, the value
                stored in `self.nValidations` is used instead.

        Returns:
            A `slim.dataset.Dataset` for the requested split.

        Raises:
            ValueError: if the validation count is unknown, if `split_name`
                is not recognized, or if no label file exists in
                `self.datasetDir`.

        Side effects:
            May set `self.nTrains`; sets `self.labels_to_names` and
            `self.nClasses` from the label file.
        """
        if not nValidations:
            if not self.nValidations:
                raise ValueError("nValidations is not known")
            nValidations = self.nValidations

        if not self.nTrains:
            photo_filenames, _ = self._get_filenames_and_classes()
            # Use the resolved count: the argument may have been supplied
            # while self.nValidations is still unset, which previously made
            # this subtraction fail (or use a stale value).
            self.nTrains = len(photo_filenames) - nValidations

        if split_name not in self.SPLITS_NAMES:
            raise ValueError('split name %s was not recognized.' % split_name)
        file_pattern = self._FILE_PATTERN
        file_pattern = os.path.join(self.datasetDir, file_pattern % split_name)

        reader = tf.TFRecordReader

        # Raw features stored in each serialized example, with their defaults.
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
            'image/class/label': tf.FixedLenFeature(
                [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        }

        items_to_handlers = {
            'image': slim.tfexample_decoder.Image(),
            'label': slim.tfexample_decoder.Tensor('image/class/label'),
        }

        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

        if not dataset_utils.has_labels(self.datasetDir):
            raise ValueError("Can't find label file in the data directory: " + self.datasetDir)
        labels_to_names = dataset_utils.read_label_file(self.datasetDir)
        self.labels_to_names = labels_to_names
        self.nClasses = len(labels_to_names)

        # NOTE(review): num_samples is self.nTrains for every split, including
        # non-training ones - confirm this is intended.
        return slim.dataset.Dataset(
            data_sources=file_pattern,
            reader=reader,
            decoder=decoder,
            num_samples=self.nTrains,
            items_to_descriptions=self._ITEMS_TO_DESCRIPTIONS,
            num_classes=len(labels_to_names),
            labels_to_names=labels_to_names)
コード例 #5
0
ファイル: dataset_factory.py プロジェクト: hutt94/paper_code
def get_dataset(split_name, dataset_dir):
    """Return the slim `Dataset` for the split stored as `<split_name>.tfrecord`.

    Args:
        split_name: Split name; used as the TFRecord file stem and passed to
            `get_sample_num` to obtain the sample count.
        dataset_dir: The directory where the dataset files are stored.

    Returns:
        A `slim.dataset.Dataset` instance. `labels_to_names` is None when the
        directory has no label file.
    """
    record_path = os.path.join(dataset_dir, split_name + '.tfrecord')

    # Raw features stored in each serialized example, with their defaults.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    # Items exposed to consumers and the handlers that decode them.
    handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, handlers)

    if dataset_utils.has_labels(dataset_dir):
        label_names = dataset_utils.read_label_file(dataset_dir)
    else:
        label_names = None

    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=tf.TFRecordReader,
        decoder=example_decoder,
        num_samples=get_sample_num(split_name),
        items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
        num_classes=_NUM_CLASSES,
        labels_to_names=label_names)
コード例 #6
0
ファイル: evaluate.py プロジェクト: rkrishnan2012/WCCI_Paper
def evaluate(model_dir, dataset_dir):
    """Predict on the validation split and save each image by predicted class.

    Every prediction's `features` array is written as a PNG under
    `../Validations/<model name>/<predicted class>/<n>.png`, where `<n>` is a
    per-class counter starting at 1.

    Args:
        model_dir: Directory of the trained model; its last path component
            names the output folder.
        dataset_dir: Directory holding the dataset and its label file.
    """
    # Session configuration.
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        intra_op_parallelism_threads=0,  # Autocompute how many threads to run
        gpu_options=tf.GPUOptions(force_gpu_compatible=True))

    config = tf.contrib.learn.RunConfig(session_config=sess_config,
                                        model_dir=model_dir)

    eval_input_fn = functools.partial(input_fn,
                                      dataset_dir=dataset_dir,
                                      split_name='validation',
                                      is_training=False)

    # Get the number of classes from the label file
    labels_to_class_names, num_classes = read_label_file(dataset_dir)

    classifier = tf.estimator.Estimator(model_fn=get_model_fn(num_classes),
                                        config=config)

    # .predict() returns an iterator of dicts;
    predictions = classifier.predict(input_fn=eval_input_fn)

    # normpath drops a trailing separator; without it basename('run/') is ''
    # and the outputs of different models would share one directory.
    model_name = os.path.basename(os.path.normpath(model_dir))

    # Next image index per predicted class (1-based).
    num_food_image = {}

    for pred in predictions:
        predicted_class = labels_to_class_names[int(pred['classes'])]
        food_dir = '../Validations/%s/%s' % (model_name, predicted_class)

        if not os.path.exists(food_dir):
            os.makedirs(food_dir)

        image_index = num_food_image.get(predicted_class, 1)
        file_name = os.path.join(food_dir, '%s.png' % image_index)
        num_food_image[predicted_class] = image_index + 1

        # NOTE(review): scipy.misc.imsave is not available in recent SciPy
        # releases - confirm the pinned SciPy version before upgrading.
        scipy.misc.imsave(file_name, pred['features'])
コード例 #7
0
def get_dataset(dataset_name, split_name, dataset_dir):
    """Return the slim `Dataset` for one split of a named dataset.

    Args:
        dataset_name: Name passed to `dataset_config` and inserted into
            `FILE_PATTERN`.
        split_name: Split to load; must be a key of the configured split sizes.
        dataset_dir: Directory containing the TFRecord file and, optionally,
            the label file.

    Returns:
        A `slim.dataset.Dataset` instance. `labels_to_names` is None when the
        directory has no label file.

    Raises:
        ValueError: if `split_name` is not one of the configured splits.
    """
    config = dataset_config(dataset_name)
    splits_to_sizes, num_classes, example_shape, items_to_descriptions = config

    if split_name not in splits_to_sizes:
        raise ValueError('split name %s was not recognized.' % split_name)

    record_name = FILE_PATTERN % (dataset_name, split_name)
    data_file = os.path.join(dataset_dir, record_name)

    # Raw features stored in each serialized example, with their defaults.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    # The image handler is given the configured example shape.
    handlers = {
        'image': slim.tfexample_decoder.Image(shape=example_shape),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, handlers)

    if dataset_utils.has_labels(dataset_dir):
        label_names = dataset_utils.read_label_file(dataset_dir)
    else:
        label_names = None

    return slim.dataset.Dataset(
        data_sources=data_file,
        reader=tf.TFRecordReader,
        decoder=example_decoder,
        num_samples=splits_to_sizes[split_name],
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
        labels_to_names=label_names)
コード例 #8
0
def main(_):
    """Restore a checkpoint and print predictions batch by batch.

    Builds the input pipeline and network, restores the model variables from
    `FLAGS.checkpoint_path`, then repeatedly evaluates a batch and passes it
    to `print_results` until the input pipeline raises `OutOfRangeError`.

    Args:
        _: Unused positional argument (the signature expected by the caller).
    """
    class_names = dataset_utils.read_label_file(DATASET_PATH_PATTERN %
                                                FLAGS.dataset_name)

    with tf.Graph().as_default():
        processed_images, raw_images, labels = inputs(
            batch_size=FLAGS.batch_size)

        probabilities = pass_network(processed_images, FLAGS.model_name)

        sess = tf.Session()
        try:
            init_fn = slim.assign_from_checkpoint_fn(
                FLAGS.checkpoint_path,
                slim.get_model_variables(name_to_name_scope[FLAGS.model_name]))
            init_fn(sess)
            init_local = tf.local_variables_initializer()
            sess.run(init_local)

            # Start input enqueue threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            plt.figure()
            np.set_printoptions(precision=3, suppress=True)

            step = 0
            try:
                while not coord.should_stop():
                    [np_images, np_labels, np_probabilities
                     ] = sess.run([raw_images, labels, probabilities])
                    # Count a step only once its batch was actually produced,
                    # so the final report is not one too high.
                    step += 1

                    print_results(np_images, np_labels, np_probabilities,
                                  class_names)

            except tf.errors.OutOfRangeError:
                print('Done evaluating for %d steps.' % step)
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()

            # Wait for threads to finish.
            coord.join(threads)
        finally:
            # Release the session even when an unexpected error escapes the
            # loop above; previously it was only closed on the normal path.
            sess.close()
コード例 #9
0
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Return the slim `Dataset` describing one split of the dataset.

    Args:
        split_name: Split to load; must be a key of `SPLITS_TO_SIZES`.
        dataset_dir: Base directory of the dataset files and label file.
        file_pattern: Pattern with a '%s' slot for the split name. Falls back
            to `_FILE_PATTERN` when empty or None.
        reader: TensorFlow reader type. Defaults to `tf.TFRecordReader`.

    Returns:
        A `slim.dataset.Dataset` instance. `labels_to_names` is None when the
        directory has no label file.

    Raises:
        ValueError: if `split_name` is not a known split.
    """
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    pattern = file_pattern if file_pattern else _FILE_PATTERN
    data_sources = os.path.join(dataset_dir, pattern % split_name)

    # None is accepted so that a dataset factory can rely on the default.
    record_reader = tf.TFRecordReader if reader is None else reader

    # Raw features stored in each serialized example, with their defaults.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, handlers)

    if dataset_utils.has_labels(dataset_dir):
        label_names = dataset_utils.read_label_file(dataset_dir)
    else:
        label_names = None

    return slim.dataset.Dataset(
        data_sources=data_sources,
        reader=record_reader,
        decoder=example_decoder,
        num_samples=SPLITS_TO_SIZES[split_name],
        items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
        num_classes=_NUM_CLASSES,
        labels_to_names=label_names)
コード例 #10
0
ファイル: pascalvoc_common.py プロジェクト: Zehaos/misc
def get_split(split_name, dataset_dir, file_pattern, reader,
              split_to_sizes, items_to_descriptions, num_classes):
    """Build the slim `Dataset` used to read a Pascal VOC split.

    Args:
      split_name: A train/test split name; must be a key of `split_to_sizes`.
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        It is assumed that the pattern contains a '%s' string so that the split
        name can be inserted.
      reader: The TensorFlow reader type, or None for `tf.TFRecordReader`.
      split_to_sizes: Mapping from split name to its number of samples.
      items_to_descriptions: Item descriptions forwarded to the `Dataset`.
      num_classes: Class count forwarded to the `Dataset`.

    Returns:
      A `slim.dataset.Dataset` instance. `labels_to_names` is None when the
      directory has no label file.

    Raises:
        ValueError: if `split_name` is not a valid train/test split.
    """
    if split_name not in split_to_sizes:
        raise ValueError('split name %s was not recognized.' % split_name)
    data_sources = os.path.join(dataset_dir, file_pattern % split_name)

    # None is accepted so that a dataset factory can rely on the default.
    record_reader = tf.TFRecordReader if reader is None else reader

    # Features in Pascal VOC TFRecords: fixed-length image metadata first,
    # then the variable-length per-object annotations.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
    }
    for corner in ('xmin', 'ymin', 'xmax', 'ymax'):
        feature_spec['image/object/bbox/' + corner] = tf.VarLenFeature(
            dtype=tf.float32)
    for attribute in ('label', 'difficult', 'truncated'):
        feature_spec['image/object/bbox/' + attribute] = tf.VarLenFeature(
            dtype=tf.int64)

    handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
    }
    for attribute in ('label', 'difficult', 'truncated'):
        handlers['object/' + attribute] = slim.tfexample_decoder.Tensor(
            'image/object/bbox/' + attribute)

    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, handlers)

    # No label file is created here when one is missing.
    if dataset_utils.has_labels(dataset_dir):
        label_names = dataset_utils.read_label_file(dataset_dir)
    else:
        label_names = None

    return slim.dataset.Dataset(
        data_sources=data_sources,
        reader=record_reader,
        decoder=example_decoder,
        num_samples=split_to_sizes[split_name],
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
        labels_to_names=label_names)
コード例 #11
0
# Root directory of the source images, and the list file passed together with
# it to the label reader.
origin_image_path = '/comvol/nfs/datasets/medicine/NIH-CXR8/images/images'
image_label_list = 'data/list/pneumonia/image_label.txt'

# The label lookup comes from read_label_file; this replaces an earlier
# approach that parsed a "label:name" labels.txt file by hand.
labels_to_name = read_label_file(origin_image_path, image_label_list)

# File-name template of the TFRecord files; '%s' is a placeholder
# (presumably for the split name - confirm against the code that fills it).
file_pattern = 'chest14_%s_*.tfrecord'

# Human-readable description of each item; required by the Dataset class.
items_to_descriptions = dict(
    image='A chest image that is used in binary classfication',
    label='A label that is as such -- 0: no certain illness, 1:have certain illness',
)

#================= TRAINING INFORMATION ==================
#State the number of epochs to train
# num_epochs = 100
コード例 #12
0
ファイル: imagenet.py プロジェクト: MinhaoTang/tpu-demos
def get_split(split_name, dataset_dir, labels_dir=None, file_pattern=None):
    """Retrieve an InputData object with the parameters for reading ImageNet.

    The input files are resolved from `file_pattern` in one of four ways:
    a `[first,last,format]` range is expanded into numbered names, a path
    ending in '.list' is read as one file path per line, a path without '*'
    is taken as a single file, and anything else is globbed.

    Args:
        split_name: A train/test split name; must be a key of
            `_SPLITS_TO_SIZES`.
        dataset_dir: The base directory of the dataset sources.
        labels_dir: The folder where the labels file is located, and where it
            is written if missing. Defaults to `dataset_dir`.
        file_pattern: The file pattern to use when matching the dataset
            sources. It is assumed that the pattern contains a '%s' string so
            that the split name can be inserted. Defaults to `_FILE_PATTERN`.

    Returns:
        An `InputData` object.

    Raises:
        ValueError: if `split_name` is not a valid train/test split.
    """
    if split_name not in _SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)
    labels_dir = labels_dir if labels_dir else dataset_dir
    pattern = file_pattern if file_pattern else _FILE_PATTERN
    pattern = pattern % split_name

    # Allow for filename expansion without using Glob().
    # Example: 'train-[0,1023,05d]-of-01024' generates
    #   train-00000-of-01024, train-00001-of-01024, ..., train-01023-of-01024
    expansion = re.match(r'(.*)\[(\d+),(\d+),([a-zA-Z0-9]+)\](.*)', pattern)
    if expansion:
        prefix, first, last, spec, suffix = expansion.groups()
        number_format = '%' + spec
        files = [
            os.path.join(dataset_dir, prefix + number_format % index + suffix)
            for index in range(int(first), int(last) + 1)
        ]
    else:
        path = os.path.join(dataset_dir, pattern)
        if path.endswith('.list'):
            # A '.list' file names the input files, one per line; blank
            # lines are ignored.
            with gfile.Open(path, 'r') as list_file:
                stripped = (line.strip() for line in list_file)
                files = [entry for entry in stripped if entry]
        elif '*' not in path:
            # Without a '*' the path is assumed to be a single input file.
            # Only '*' is treated as a glob marker here.
            files = [path]
        else:
            # Otherwise we assume it is a glob-able path.
            files = gfile.Glob(path)

    # Raw features stored in each serialized example.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': tf.FixedLenFeature(
            [], dtype=tf.int64, default_value=-1),
        'image/class/text': tf.FixedLenFeature(
            [], dtype=tf.string, default_value=''),
    }
    for corner in ('xmin', 'ymin', 'xmax', 'ymax'):
        feature_spec['image/object/bbox/' + corner] = tf.VarLenFeature(
            dtype=tf.float32)
    feature_spec['image/object/class/label'] = tf.VarLenFeature(dtype=tf.int64)

    handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
        'label_text': slim.tfexample_decoder.Tensor('image/class/text'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor(
            'image/object/class/label'),
    }
    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, handlers)

    if not dataset_utils.has_labels(labels_dir):
        # No label file yet: generate the names and persist them.
        label_names = create_readable_names_for_imagenet_labels()
        dataset_utils.write_label_file(label_names, labels_dir)
    else:
        label_names = dataset_utils.read_label_file(labels_dir)

    return InputData(data_sources=files,
                     decoder=example_decoder,
                     num_samples=_SPLITS_TO_SIZES[split_name],
                     items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
                     num_classes=_NUM_CLASSES,
                     labels_to_names=label_names)
コード例 #13
0
def main():
    """Convert the train and validation image lists into TFRecord shards.

    Validates the required flags, exits early when the output files already
    exist, then reads the image/label pairs named by `FLAGS.train_list` and
    `FLAGS.val_list` and converts each set with `_convert_dataset`.

    Returns:
        None.

    Raises:
        ValueError: if `FLAGS.tfrecord_filename` or `FLAGS.dataset_dir` is
            empty.
    """
    # ---- Checks ----
    # Check if there is a tfrecord_filename entered
    if not FLAGS.tfrecord_filename:
        raise ValueError(
            'tfrecord_filename is empty. Please state a tfrecord_filename argument.'
        )

    # Check if there is a dataset directory entered
    if not FLAGS.dataset_dir:
        raise ValueError(
            'dataset_dir is empty. Please state a dataset_dir argument.')

    # If the TFRecord files already exist in the directory, then exit without
    # creating the files again.
    if _dataset_exists(dataset_dir=FLAGS.dataset_dir,
                       _NUM_SHARDS=FLAGS.num_shards,
                       output_filename=FLAGS.tfrecord_filename):
        # Single-argument print(...) is valid on both Python 2 and Python 3.
        print('Dataset files already exist. Exiting without re-creating them.')
        return None
    # ---- End of checks ----

    # The train and validation sets come from two separate list files rather
    # than from a random split of one list.
    train_image, train_label = _get_image_label(
        read_label_file(FLAGS.dataset_dir, FLAGS.train_list))
    logging.debug("train_image: %s, train_label: %s", train_image[:10],
                  train_label[:10])
    val_image, val_label = _get_image_label(
        read_label_file(FLAGS.dataset_dir, FLAGS.val_list))
    logging.debug("val_image: %s, val_label: %s", val_image[:10],
                  val_label[:10])

    # The seed is still set although no shuffling happens in this function -
    # presumably so randomness inside _convert_dataset stays reproducible;
    # verify before removing.
    random.seed(FLAGS.random_seed)

    # Convert the training and validation sets.
    _convert_dataset('train',
                     train_image,
                     train_label,
                     dataset_dir=FLAGS.dataset_dir,
                     write_dir=FLAGS.write_dir,
                     tfrecord_filename=FLAGS.tfrecord_filename,
                     _NUM_SHARDS=FLAGS.num_shards)
    _convert_dataset('validation',
                     val_image,
                     val_label,
                     dataset_dir=FLAGS.dataset_dir,
                     write_dir=FLAGS.write_dir,
                     tfrecord_filename=FLAGS.tfrecord_filename,
                     _NUM_SHARDS=FLAGS.num_shards)

    # No labels file is written by this function.

    print('\nFinished converting the %s dataset!' % (FLAGS.tfrecord_filename))
コード例 #14
0
ファイル: flowers.py プロジェクト: Ryota7101/cat_class
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
コード例 #15
0
ファイル: eval.py プロジェクト: NoicFank/ChestRayXNet
def get_split(split_name, dataset_dir, file_pattern, file_pattern_for_counting):
    '''
    Build the slim Dataset for the training or validation split.

    The sample count is obtained by iterating over every record of every file
    in dataset_dir whose name starts with
    "<file_pattern_for_counting>_<split_name>".

    INPUTS:
    - split_name(str): 'train' or 'validation'.
    - dataset_dir(str): directory where the tfrecord files are located.
    - file_pattern(str): file-name pattern with a '%s' slot for the split name.
    - file_pattern_for_counting(str): file-name prefix identifying the tfrecord
      files to count.

    OUTPUTS:
    - dataset (Dataset): a slim Dataset object describing the split.

    RAISES:
    - ValueError: if split_name is neither 'train' nor 'validation'.
    '''
    # Reject unknown splits before touching the file system.
    if split_name not in ('train', 'validation'):
        raise ValueError('The split_name %s is not recognized. Please input either train or validation as the split_name' % (split_name))

    # Full path pattern used to locate the tfrecord files.
    file_pattern_path = os.path.join(dataset_dir, file_pattern % (split_name))

    # Count the examples across all matching shards.
    counting_prefix = file_pattern_for_counting + '_' + split_name
    num_samples = 0
    for entry in os.listdir(dataset_dir):
        if not entry.startswith(counting_prefix):
            continue
        shard_path = os.path.join(dataset_dir, entry)
        num_samples += sum(
            1 for _ in tf.python_io.tf_record_iterator(shard_path))

    # Raw features stored in each serialized example, with their defaults.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    # Items exposed to consumers and the handlers that decode them.
    handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, handlers)

    # NOTE(review): image_set_dir is not defined in this function; it is
    # presumably a module-level name - confirm it exists where this runs.
    label_lookup = read_label_file(
        image_set_dir, 'data/list/pneumonia/image_label.txt')

    descriptions = dict(
        image='A chest image that is used in binary classfication',
        label='A label that is as such -- 0: no certain illness, 1:have certain illness',
    )

    # NOTE(review): the keyword below is `labels_to_name` (singular) - confirm
    # that consumers of the Dataset read the attribute under that name.
    return slim.dataset.Dataset(
        data_sources=file_pattern_path,
        decoder=example_decoder,
        reader=tf.TFRecordReader,
        num_readers=4,
        num_samples=num_samples,
        num_classes=2,
        labels_to_name=label_lookup,
        items_to_descriptions=descriptions)