Example #1
  def testGeneratedShardedFilenamesCommaWithoutShard(self):
      filenames = data.generate_sharded_filenames('/foo/bar,/baz/qux')
      self.assertEqual(
          [
              '/foo/bar',
              '/baz/qux',
          ],
          filenames)
Example #2
 def testGeneratedShardedFilenamesCommaWithShard(self):
     filenames = data.generate_sharded_filenames('/foo/bar@3,/baz/qux@2')
     self.assertEqual([
         '/foo/bar-00000-of-00003',
         '/foo/bar-00001-of-00003',
         '/foo/bar-00002-of-00003',
         '/baz/qux-00000-of-00002',
         '/baz/qux-00001-of-00002',
     ], filenames)
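
The two tests above pin down the expected behavior: a bare path passes through unchanged, while 'path@N' expands into N shard filenames. A minimal sketch consistent with those tests is shown below; the body is an assumption for illustration, not the library's actual implementation of data.generate_sharded_filenames.

def generate_sharded_filenames(spec):
    """Expands a comma-separated spec such as '/foo/bar@3,/baz/qux'."""
    filenames = []
    for entry in spec.split(','):
        if '@' in entry:
            # 'path@N' expands to N filenames like 'path-00000-of-00003'.
            base, num_shards = entry.split('@')
            num_shards = int(num_shards)
            filenames.extend(
                '%s-%05d-of-%05d' % (base, i, num_shards)
                for i in range(num_shards))
        else:
            # A bare path is returned unchanged.
            filenames.append(entry)
    return filenames
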
def pipeline(config_map, dataset_config_map, preprocess_example_fn,
             input_tensors_to_example_fn):
    """Pipeline for dataset creation."""
    tf.flags.mark_flags_as_required(['output_directory'])

    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        FLAGS.pipeline_options.split(','))

    config = config_map[FLAGS.config]
    hparams = config.hparams
    hparams.parse(FLAGS.hparams)

    datasets = dataset_config_map[FLAGS.dataset_config]

    if tf.gfile.Exists(FLAGS.output_directory):
        raise ValueError('Output directory %s already exists!' %
                         FLAGS.output_directory)
    tf.gfile.MakeDirs(FLAGS.output_directory)
    with tf.gfile.Open(os.path.join(FLAGS.output_directory, 'config.txt'),
                       'w') as f:
        f.write('\n\n'.join([
            'min_length: {}'.format(FLAGS.min_length),
            'max_length: {}'.format(FLAGS.max_length),
            'sample_rate: {}'.format(FLAGS.sample_rate),
            'preprocess_examples: {}'.format(FLAGS.preprocess_examples),
            'preprocess_train_example_multiplier: {}'.format(
                FLAGS.preprocess_train_example_multiplier),
            'config: {}'.format(FLAGS.config),
            'hparams: {}'.format(hparams.to_json(sort_keys=True)),
            'dataset_config: {}'.format(FLAGS.dataset_config),
            'datasets: {}'.format(datasets),
        ]))

    with beam.Pipeline(options=pipeline_options) as p:
        for dataset in datasets:
            if isinstance(dataset.path, (list, tuple)):
                # If dataset.path is a list, then it's a list of sources to mix together
                # to form new examples. First, do the mixing, then pass the results to
                # the rest of the pipeline.
                id_exs = []
                sourceid_to_exids = []
                for source_id, stem_path in enumerate(dataset.path):
                    if dataset.num_mixes is None:
                        raise ValueError(
                            'If path is a list, num_mixes must not be None: {}'
                            .format(dataset))
                    stem_p = p | 'tfrecord_list_%s_%d' % (
                        dataset.name, source_id) >> (beam.Create(
                            data.generate_sharded_filenames(stem_path)))

                    # Note that we do not specify a coder when reading here.
                    # This is so that the hashing in key_example below can work directly
                    # on the serialized version instead of having to re-serialize it.
                    # Also, deserializing with a coder and then re-serializing does not
                    # always generate the same hash for the same example (likely due to
                    # the map fields in tf.train.Example). This is important when reading
                    # the same dataset multiple times to mix it with itself.
                    stem_p |= 'read_tfrecord_%s_%d' % (
                        dataset.name, source_id) >> (
                            beam.io.tfrecordio.ReadAllFromTFRecord())
                    stem_p |= 'shuffle_stems_%s_%d' % (
                        dataset.name, source_id) >> (beam.Reshuffle())

                    # Key all examples with a hash.
                    def key_example(ex):
                        return (hashlib.sha256(ex).hexdigest(), ex)

                    stem_p |= 'add_id_key_%s_%d' % (
                        dataset.name, source_id) >> (beam.Map(key_example))
                    id_exs.append(stem_p)

                    # Create a list of source_id to example id.
                    def sourceid_to_exid(id_ex, source_id):
                        return (source_id, id_ex[0])

                    sourceid_to_exids.append(
                        stem_p | 'key_%s_%d' % (dataset.name, source_id) >>
                        (beam.Map(sourceid_to_exid, source_id=source_id)))

                # ('example_hash', serialized_example)
                id_exs = (
                    id_exs
                    | 'id_exs_flatten_%s' % dataset.name >> beam.Flatten()
                    | 'id_exs_distinct_%s' % dataset.name >> beam.Distinct())

                # ('source_id', 'example_hash')
                sourceid_to_exids = (sourceid_to_exids
                                     | 'sourceid_to_exids_flatten_%s' %
                                     dataset.name >> beam.Flatten())

                # Pass the list of source id to example IDs to generate_mixes,
                # which will create mixes by selecting random IDs from each source
                # (with replacement). This is represented as a list of example IDs
                # to Mix IDs.
                # Note: beam.Create([0]) is just a single dummy value to allow the
                # sourceid_to_exids to be passed in as a python list so we can do the
                # sampling with numpy.
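                # Illustrative shape of the result (hashes and IDs below are
                # made up): the singleton output of generate_mixes is a dict
                # such as {'hash_a': [0], 'hash_b': [0, 1], 'hash_c': [1]},
                # mapping each selected example hash to the mix IDs that
                # include it.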
                exid_to_mixids = (
                    p
                    | 'create_dummy_%s' % dataset.name >> beam.Create([0])
                    | 'generate_mixes_%s' % dataset.name >> beam.Map(
                        create_dataset_lib.generate_mixes,
                        num_mixes=dataset.num_mixes,
                        sourceid_to_exids=beam.pvalue.AsList(
                            sourceid_to_exids)))

                # Create a list of (Mix ID, Full Example proto). Note: Examples may be
                # present in more than one mix. Then, group by Mix ID.
                def mixid_to_exs(id_ex, exid_to_mixids):
                    exid, ex = id_ex
                    for mixid in exid_to_mixids[exid]:
                        yield mixid, ex

                mixid_exs = (
                    id_exs
                    | 'mixid_to_exs_%s' % dataset.name >> beam.FlatMap(
                        mixid_to_exs,
                        exid_to_mixids=beam.pvalue.AsSingleton(exid_to_mixids))
                    | 'group_by_key_%s' % dataset.name >> beam.GroupByKey())
                # Take these groups of Examples, mix their audio and sequences to return
                # a single new Example. Then, carry on with the rest of the pipeline
                # like normal.
                split_p = (mixid_exs
                           | 'mix_examples_%s' % dataset.name >> beam.Map(
                               mix_examples, FLAGS.sample_rate,
                               FLAGS.load_audio_with_librosa))
            else:
                if dataset.num_mixes is not None:
                    raise ValueError(
                        'If path is not a list, num_mixes must be None: {}'.
                        format(dataset))
                split_p = p | 'tfrecord_list_%s' % dataset.name >> beam.Create(
                    data.generate_sharded_filenames(dataset.path))
                split_p |= 'read_tfrecord_%s' % dataset.name >> (
                    beam.io.tfrecordio.ReadAllFromTFRecord(
                        coder=beam.coders.ProtoCoder(tf.train.Example)))
            split_p |= 'shuffle_input_%s' % dataset.name >> beam.Reshuffle()
            split_p |= 'split_wav_%s' % dataset.name >> beam.FlatMap(
                split_wav,
                min_length=FLAGS.min_length,
                max_length=FLAGS.max_length,
                sample_rate=FLAGS.sample_rate,
                debug_output_directory=FLAGS.output_directory,
                split_example=dataset.process_for_training,
                load_audio_with_librosa=FLAGS.load_audio_with_librosa)
            if FLAGS.preprocess_examples:
                if dataset.process_for_training:
                    mul_name = 'preprocess_multiply_%dx_%s' % (
                        FLAGS.preprocess_train_example_multiplier,
                        dataset.name)
                    split_p |= mul_name >> beam.FlatMap(
                        multiply_example,
                        FLAGS.preprocess_train_example_multiplier)
                split_p |= 'preprocess_%s' % dataset.name >> beam.Map(
                    preprocess_data, preprocess_example_fn,
                    input_tensors_to_example_fn, hparams,
                    dataset.process_for_training)
            split_p |= 'shuffle_output_%s' % dataset.name >> beam.Reshuffle()
            split_p |= 'write_%s' % dataset.name >> beam.io.WriteToTFRecord(
                os.path.join(FLAGS.output_directory,
                             '%s.tfrecord' % dataset.name),
                coder=beam.coders.ProtoCoder(tf.train.Example))
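
A hypothetical driver for the pipeline above, following the TF1-style flags and Beam setup used in the function. The config map, dataset map, and the two callbacks are placeholder names for illustration, not identifiers from the source module; the real flags (output_directory, config, dataset_config, etc.) would be defined at module level with tf.flags.

def main(argv):
    del argv  # Unused.
    pipeline(
        config_map=MY_CONFIG_MAP,
        dataset_config_map=MY_DATASET_MAP,
        preprocess_example_fn=my_preprocess_example_fn,
        input_tensors_to_example_fn=my_input_tensors_to_example_fn)


if __name__ == '__main__':
    tf.app.run(main)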