Example #1
    def initialize_datasets(self):

        # Number of dataset elements (batches) to prefetch; also reused as
        # num_parallel_batches in make_dataset (other values tried: 2, 4, 8, 100)
        self.prefetch_count = 50
        
        # Specify which transformations to use for data augmentation
        self.transformations = get_transformations(self.rotate, self.flip)

        # Define iterators for training datasets
        self.training_datasets = [self.make_dataset(transformation=t) for t in self.transformations]

        # Define iterators for validation datasets
        self.validation_datasets = [self.make_dataset(training=False, transformation=t) for t in self.transformations]
        
        # Create early stopping batch from validation dataset
        if self.use_hires:
            filenames = 'hires_validation-*.tfrecords'
        else:
            filenames = 'validation-*.tfrecords'
        efiles = tf.data.Dataset.list_files(os.path.join(self.data_dir, filenames))
        self.edataset = tf.data.TFRecordDataset(efiles)
        self.edataset = self.edataset.map(lambda x: _parse_data(x, res=self.default_res))
        # Shuffle buffer sized to the early-stopping batch; repeats indefinitely
        self.edataset = self.edataset.apply(tf.contrib.data.shuffle_and_repeat(self.stopping_size))
        self.edataset = self.edataset.batch(self.stopping_size)
        self.edataset = self.edataset.make_one_shot_iterator()
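
Both examples call a module-level _parse_data helper that is not shown on this page. A minimal sketch of what such a parser could look like, assuming each TFRecord stores a single serialized image under an 'image' feature; the feature name, dtype, shape, and the way transformation is applied are all assumptions, not the original code:

    import tensorflow as tf

    def _parse_data(example_proto, res=256, transformation=None):
        # Hypothetical schema: one float32 image serialized as raw bytes
        features = {'image': tf.FixedLenFeature([], tf.string)}
        parsed = tf.parse_single_example(example_proto, features)
        image = tf.decode_raw(parsed['image'], tf.float32)
        image = tf.reshape(image, [res, res, 1])
        if transformation is not None:
            image = transformation(image)  # e.g. a rotate/flip op from get_transformations
        return image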
Example #2
    def make_dataset(self, training=True, transformation=None):

        if training:
            if self.data_files == 0:
                if self.use_hires:
                    filenames = 'hires_training-*.tfrecords'
                else:
                    filenames = 'training-*.tfrecords'
                files = tf.data.Dataset.list_files(os.path.join(self.data_dir, filenames), shuffle=True)
            else:
                if self.use_hires:
                    filenames = ['hires_training-{}.tfrecords'.format(n) for n in range(self.data_files)]
                else:
                    filenames = ['training-{}.tfrecords'.format(n) for n in range(self.data_files)]
                filenames = [os.path.join(self.data_dir, f) for f in filenames]
                files = tf.data.Dataset.from_tensor_slices(filenames)
        else:
            if self.use_hires:
                filenames = 'hires_validation-*.tfrecords'
            else:
                filenames = 'validation-*.tfrecords'
            files = tf.data.Dataset.list_files(os.path.join(self.data_dir, filenames), shuffle=True)

        def tfrecord_dataset(filename):
            # 4 MiB read buffer per file
            buffer_size = 4 * 1024 * 1024
            return tf.data.TFRecordDataset(filename, buffer_size=buffer_size)
            
        dataset = files.apply(tf.contrib.data.parallel_interleave(
            tfrecord_dataset, cycle_length=8, sloppy=True))  # cycle_length = number of input files read from in parallel
        dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(10000))  # buffer_size only affects the 'randomness' of shuffling
        dataset = dataset.apply(
            tf.contrib.data.map_and_batch(lambda x: _parse_data(x, res=self.default_res, transformation=transformation),
                                          self.batch_size, num_parallel_batches=self.prefetch_count))
        # On GPU, tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=self.prefetch_count)
        # can be applied here instead of the plain prefetch below
        dataset = dataset.prefetch(self.prefetch_count)  # prefetch is counted in elements of the current dataset, i.e. batches
        dataset = dataset.make_one_shot_iterator()
        return dataset
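
Note that tf.contrib.data was removed in TensorFlow 2.x. As a rough sketch (not the original code), the same pipeline maps onto current tf.data APIs roughly as follows; parameter values mirror the example above:

    import tensorflow as tf

    def make_dataset_v2(file_pattern, batch_size, parse_fn, prefetch_count=50):
        files = tf.data.Dataset.list_files(file_pattern, shuffle=True)
        # parallel_interleave(..., sloppy=True) -> interleave with deterministic=False
        dataset = files.interleave(
            lambda f: tf.data.TFRecordDataset(f, buffer_size=4 * 1024 * 1024),
            cycle_length=8, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        # shuffle_and_repeat -> chained shuffle().repeat()
        dataset = dataset.shuffle(10000).repeat()
        # map_and_batch -> separate map() and batch()
        dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(batch_size)
        return dataset.prefetch(prefetch_count)  # still counted in batches

In eager TF 2.x the returned dataset is iterated directly (for batch in dataset: ...), so the make_one_shot_iterator() call is no longer needed.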