import os

import tensorflow as tf

# The input pipeline modules below are project-local; these import paths are
# assumptions inferred from the module names used in get_features().
from input_data.affnist import affnist_input_record
from input_data.cifar10 import cifar10_input
from input_data.mnist import mnist_input_record
from input_data.norb import norb_input_record


def get_features(split, total_batch_size, num_gpus, data_dir, num_targets,
                 dataset, validate=False, evaluate=False, seed=None,
                 shuffled=False, shift=2, pad=0, eval_shard=None):
  """Reads the input data and distributes it over num_gpus GPUs.

  Each tower of data has 1/num_gpus of the total_batch_size.

  Args:
    split: 'train' or 'test', split of the data to read.
    total_batch_size: Total number of data entries over all towers.
    num_gpus: Number of GPUs to distribute the data on.
    data_dir: Directory containing the input data.
    num_targets: Number of objects present in the image.
    dataset: The name of the dataset: 'mnist', 'norb', 'cifar10', or
      'affnist'.
    validate: If set, subset training data into training and test.
    evaluate: If set, prepare features for test time (e.g. no shuffling).
    seed: If set, specify the seed for shuffling training batches.
    shuffled: If set, use the shuffled version of the .tfrecords dataset.
    shift: Pixel shift forwarded to the mnist and affnist input pipelines.
    pad: Pixel padding forwarded to the mnist input pipeline.
    eval_shard: If set, the evaluation data shard to read (affnist only).

  Returns:
    A list of batched feature dictionaries, one per GPU tower.

  Raises:
    ValueError: If dataset is not mnist, norb, cifar10, or affnist.
  """
  batch_size = total_batch_size // max(1, num_gpus)
  features = []
  for i in range(num_gpus):
    with tf.device('/gpu:%d' % i):
      if dataset == 'mnist':
        features.append(
            mnist_input_record.inputs(
                data_dir=data_dir,
                batch_size=batch_size,
                split=split,
                shift=shift,
                pad=pad,
                shuffled=shuffled,
                num_targets=num_targets,
                validate=validate,
                evaluate=evaluate,
                seed=seed))
      elif dataset == 'norb':
        features.append(
            norb_input_record.inputs(
                data_dir=data_dir, batch_size=batch_size, split=split))
      elif dataset == 'cifar10':
        data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
        features.append(
            cifar10_input.inputs(
                split=split, data_dir=data_dir, batch_size=batch_size))
      elif dataset == 'affnist':
        features.append(
            affnist_input_record.inputs(
                data_dir=data_dir,
                batch_size=batch_size,
                split=split,
                shift=shift,
                validate=validate,
                evaluate=evaluate,
                seed=seed,
                eval_shard=eval_shard))
      else:
        raise ValueError('Unexpected dataset {!r}, must be mnist, '
                         'norb, cifar10, or affnist.'.format(dataset))
  return features
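
# A minimal usage sketch, not part of the original file: split a global batch
# of 128 MNIST examples across 2 GPU towers. The data_dir path is a
# hypothetical placeholder, and the loop below only inspects the returned
# feature dictionaries; it assumes the MNIST .tfrecords files already exist
# under data_dir in the layout mnist_input_record expects.
if __name__ == '__main__':
  feature_towers = get_features(
      split='train',
      total_batch_size=128,
      num_gpus=2,
      data_dir='/tmp/mnist_data',  # hypothetical path
      num_targets=1,
      dataset='mnist')
  # Each of the 2 towers carries batch_size = 128 // 2 = 64 examples; the
  # input ops for tower i were placed on /gpu:i inside get_features.
  for tower_id, batched_features in enumerate(feature_towers):
    print('tower %d feature keys: %s' % (tower_id, sorted(batched_features)))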