Example #1
def get_features(split,
                 total_batch_size,
                 num_gpus,
                 data_dir,
                 num_targets,
                 dataset,
                 validate=False):
    """Reads the input data and distributes it over num_gpus GPUs.

  Each tower of data has 1/FLAGS.num_gpus of the total_batch_size.

  Args:
    split: 'train' or 'test', split of the data to read.
    total_batch_size: total number of data entries over all towers.
    num_gpus: Number of GPUs to distribute the data on.
    data_dir: Directory containing the input data.
    num_targets: Number of objects present in the image.
    dataset: The name of the dataset, either norb or mnist.
    validate: If set, subset training data into training and test.

  Returns:
    A list of batched feature dictionaries.

  Raises:
    ValueError: If dataset is not mnist or norb.
  """

    batch_size = total_batch_size // max(1, num_gpus)
    features = []
    for i in range(num_gpus):
        with tf.device('/gpu:%d' % i):
            if dataset == 'mnist':
                features.append(
                    mnist_input_record.inputs(
                        data_dir=data_dir,
                        batch_size=batch_size,
                        split=split,
                        num_targets=num_targets,
                        validate=validate,
                    ))
            elif dataset == 'norb':
                features.append(
                    norb_input_record.inputs(
                        data_dir=data_dir,
                        batch_size=batch_size,
                        split=split,
                    ))
            elif dataset == 'cifar10':
                data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
                features.append(
                    cifar10_input.inputs(split=split,
                                         data_dir=data_dir,
                                         batch_size=batch_size))
            else:
                raise ValueError(
                    'Unexpected dataset {!r}, must be mnist, norb, or cifar10.'
                    .format(dataset))
    return features
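
For reference, a minimal usage sketch of the call above; the data directory, batch size, and GPU count are illustrative assumptions, not values from the source:

# Hypothetical call: build per-GPU training batches for MNIST.
features = get_features(
    split='train',
    total_batch_size=128,   # split evenly across the towers
    num_gpus=2,             # so each tower gets a batch of 64
    data_dir='/tmp/mnist_data',
    num_targets=1,
    dataset='mnist')
# features[i] is the feature dictionary whose input ops are pinned to '/gpu:i'.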
Example #2
def get_features(split, total_batch_size, num_gpus, data_dir, num_targets,
                 dataset, validate=False):
  """Reads the input data and distributes it over num_gpus GPUs.

  Each tower of data has 1/num_gpus of the total_batch_size.

  Args:
    split: 'train' or 'test', split of the data to read.
    total_batch_size: total number of data entries over all towers.
    num_gpus: Number of GPUs to distribute the data on.
    data_dir: Directory containing the input data.
    num_targets: Number of objects present in the image.
    dataset: The name of the dataset: mnist, norb, or cifar10.
    validate: If set, subset training data into training and test.

  Returns:
    A list of batched feature dictionaries.

  Raises:
    ValueError: If dataset is not mnist, norb, or cifar10.
  """

  batch_size = total_batch_size // max(1, num_gpus)
  features = []
  for i in range(num_gpus):
    with tf.device('/gpu:%d' % i):
      if dataset == 'mnist':
        features.append(
            mnist_input_record.inputs(
                data_dir=data_dir,
                batch_size=batch_size,
                split=split,
                num_targets=num_targets,
                validate=validate,
            ))
      elif dataset == 'norb':
        features.append(
            norb_input_record.inputs(
                data_dir=data_dir, batch_size=batch_size, split=split,
            ))
      elif dataset == 'cifar10':
        data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
        features.append(
            cifar10_input.inputs(
                split=split, data_dir=data_dir, batch_size=batch_size))
      else:
        raise ValueError(
            'Unexpected dataset {!r}, must be mnist, norb, or cifar10.'.format(
                dataset))
  return features
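
Both versions compute the per-tower batch size with integer division, so a total_batch_size that is not a multiple of num_gpus silently drops the remainder. A small sketch with hypothetical values:

# Per-tower split as computed above (values are illustrative).
total_batch_size, num_gpus = 130, 4
batch_size = total_batch_size // max(1, num_gpus)   # 32 per tower
dropped = total_batch_size - batch_size * num_gpus  # 2 entries are never read
# max(1, num_gpus) guards the division when num_gpus == 0,
# but range(0) then makes the returned feature list empty.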
Example #3
def get_features(split,
                 total_batch_size,
                 num_gpus,
                 data_dir,
                 num_targets,
                 dataset,
                 validate=False,
                 evaluate=False,
                 seed=None,
                 shuffled=False,
                 shift=2,
                 pad=0,
                 eval_shard=None):
    """Reads the input data and distributes it over num_gpus GPUs.

    Each tower of data has 1/num_gpus of the total_batch_size.

    Args:
      split: 'train' or 'test', split of the data to read.
      total_batch_size: total number of data entries over all towers.
      num_gpus: Number of GPUs to distribute the data on.
      data_dir: Directory containing the input data.
      num_targets: Number of objects present in the image.
      dataset: The name of the dataset: mnist, norb, cifar10, or affnist.
      validate: If set, subset training data into training and test.
      evaluate: If set, prepare features for test time (e.g. no shuffling).
      seed: If set, specify the seed for shuffling training batches.
      shuffled: If set, use the shuffled version of the .tfrecords dataset.
      shift: Maximum number of pixels to randomly shift each image.
      pad: Number of pixels of padding to add around each image.
      eval_shard: If set, the shard of the evaluation data to read.

    Returns:
      A list of batched feature dictionaries.

    Raises:
      ValueError: If dataset is not mnist, norb, cifar10, or affnist.
    """

    batch_size = total_batch_size // max(1, num_gpus)
    features = []
    for i in range(num_gpus):
        with tf.device('/gpu:%d' % i):
            if dataset == 'mnist':
                features.append(
                    mnist_input_record.inputs(data_dir=data_dir,
                                              batch_size=batch_size,
                                              split=split,
                                              shift=shift,
                                              pad=pad,
                                              shuffled=shuffled,
                                              num_targets=num_targets,
                                              validate=validate,
                                              evaluate=evaluate,
                                              seed=seed))
            elif dataset == 'norb':
                features.append(
                    norb_input_record.inputs(
                        data_dir=data_dir,
                        batch_size=batch_size,
                        split=split,
                    ))
            elif dataset == 'cifar10':
                data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
                features.append(
                    cifar10_input.inputs(split=split,
                                         data_dir=data_dir,
                                         batch_size=batch_size))
            elif dataset == 'affnist':
                features.append(
                    affnist_input_record.inputs(data_dir=data_dir,
                                                batch_size=batch_size,
                                                split=split,
                                                shift=shift,
                                                validate=validate,
                                                evaluate=evaluate,
                                                seed=seed,
                                                eval_shard=eval_shard))
            else:
                raise ValueError('Unexpected dataset {!r}, must be mnist, '
                                 'norb, cifar10, or affnist.'.format(dataset))
    return features
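
A sketch of an evaluation-time call against this extended signature; the dataset path and shard index are assumed for illustration:

# Hypothetical evaluation call exercising the affnist branch.
features = get_features(
    split='test',
    total_batch_size=100,
    num_gpus=1,
    data_dir='/tmp/affnist_data',
    num_targets=1,
    dataset='affnist',
    evaluate=True,   # test-time features, e.g. no shuffling
    eval_shard=0)    # read only the first evaluation shard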