Ejemplo n.º 1
0
 def predict_input_fn():
     """Build the prediction input pipeline: zipped source/target records, batch size 1."""
     source_records = tf.data.TFRecordDataset(list(test_source_files))
     target_records = tf.data.TFRecordDataset(list(test_target_files))
     # Wrap the raw record streams in the project dataset abstraction.
     wrapped = dataset_factory(source_records, target_records, hparams)
     zipped = wrapped.prepare_and_zip()
     batched = zipped.group_by_batch(batch_size=1)
     return batched.move_mel_to_source().dataset
Ejemplo n.º 2
0
    def eval_input_fn():
        """Build the evaluation pipeline: shuffled file pairs, length-filtered, repeated, batch size 1."""
        # Shuffle source/target files together so the pairing is preserved.
        paired_files = list(zip(eval_source_files, eval_target_files))
        shuffle(paired_files)
        source_paths = [src for src, _ in paired_files]
        target_paths = [tgt for _, tgt in paired_files]

        source = tf.data.TFRecordDataset(source_paths)
        target = tf.data.TFRecordDataset(target_paths)
        zipped = dataset_factory(source, target, hparams).prepare_and_zip()
        filtered = zipped.filter_by_max_output_length()
        return filtered.repeat().group_by_batch(batch_size=1).dataset
Ejemplo n.º 3
0
    def _config_ds_iterator(self):
        """Build the input dataset on the CPU and cache it plus its iterator on self.

        Side effects: sets ``self._dataset_obj`` and ``self._ds_iterator``.
        """
        with tf.device('/cpu:0'):
            # Shard the tfrecord files across the configured number of workers.
            tfrecord_files = dataset_utils.get_tfrecord_files(
                self._num_workers, FLAGS.file_pattern)

            # Look up the dataset constructor for the configured dataset name.
            # NOTE(review): original comment said this selects the preprocessing
            # func — presumably the dataset object bundles preprocessing; verify.
            make_dataset = dataset_factory(
                FLAGS.dataset_name).get_dataset_obj()
            self._dataset_obj = make_dataset(tfrecord_files, None, None)
            self._ds_iterator = self._dataset_obj.get_dataset()
Ejemplo n.º 4
0
def _get_dataset(tfrecords_dir='tfrecords/', preprocessing_type='caffe', dataset_class='pascalvoc', coco_year='2017',
                pascalvoc_year='2007', pascalvoc_class='trainval', num_pascalvoc_tfrecords=2,
                data_root_path=None):
    if dataset_class == 'pascalvoc':
        file_pattern = 'pascalvoc_{}_{}_%02d.tfrecord'.format(pascalvoc_year, pascalvoc_class)
        file_names = [os.path.join(tfrecords_dir, file_pattern % i) for i in range(num_pascalvoc_tfrecords)]
        # print(file_names)
        dataset_configs = { 'tf_record_list': file_names,
                            'min_size': configs['image_min_size'],
                            'max_size': configs['image_max_size'],
                            'preprocessing_type': configs['preprocessing_type'],
                            'caffe_pixel_means': configs['bgr_pixel_means'],
                            'data_argumentation': True
                            }
        # def dataset_factory(dataset_class, mode, configs):
        dataset = dataset_factory('pascalvoc', 'trainval', dataset_configs)
        return dataset