def _get_data_prep_params_from_flags():
    """Collect data-prep inputs, outputs and beam params from flags.

    Two modes are supported:
      * Glob mode (`FLAGS.train_input_glob` set): the train, validation and
        test globs are each resolved separately. Each must yield exactly one
        output path, which is then suffixed with `.train`, `.validation` or
        `.test`.
      * TFDS mode (otherwise): a single `utils.get_beam_params_from_flags()`
        call provides everything; `FLAGS.tfds_dataset` must be set.

    Returns:
      A tuple `(prep_params, input_filenames_list, output_filenames,
      run_data_prep)`. `run_data_prep` is False only when
      `utils.validate_inputs` raised `ValueError`, `FLAGS.skip_existing_error`
      is set, and `_remove_existing_outputs` left no outputs. The returned
      `output_filenames` is always the full, unfiltered list.

    Raises:
      ValueError: On missing or inconsistent flags, when a glob yields more
        than one output, when there are not exactly 3 outputs, or when
        validation fails and `FLAGS.skip_existing_error` is unset.
    """
    if not FLAGS.output_filename:
        raise ValueError('Must provide output filename.')
    if not FLAGS.comma_escape_char:
        raise ValueError('`FLAGS.comma_escape_char` must be provided.')

    if not FLAGS.train_input_glob:
        # TFDS mode: one call yields every split at once.
        if not FLAGS.tfds_dataset:
            raise ValueError(
                'Must supply TFDS dataset name if not globs provided.')
        inputs, outputs, params = utils.get_beam_params_from_flags()
    else:
        # Glob mode: all three split globs are required.
        if not FLAGS.validation_input_glob:
            raise ValueError(
                'If using globs, must supply `validation_input_glob.`')
        if not FLAGS.test_input_glob:
            raise ValueError('If using globs, must supply `test_input_glob.`')
        split_globs = (
            ('train', FLAGS.train_input_glob),
            ('validation', FLAGS.validation_input_glob),
            ('test', FLAGS.test_input_glob),
        )
        inputs, outputs = [], []
        for split, glob_pattern in split_globs:
            # The helper takes no arguments; the current split's glob is
            # handed over through `FLAGS.input_glob` (presumably read inside
            # `utils.get_beam_params_from_flags` -- confirm).
            FLAGS.input_glob = glob_pattern
            split_inputs, split_outputs, params = (
                utils.get_beam_params_from_flags())
            if len(split_outputs) != 1:
                raise ValueError(f'`cur_outputs` too long: {split_outputs}')
            inputs.extend(split_inputs)
            outputs.append(f'{split_outputs[0]}.{split}')

    if len(outputs) != 3:
        raise ValueError(
            f'Data prep output must be 3 files: {outputs}')

    run_data_prep = True
    try:
        # Check that inputs and flags are formatted correctly.
        utils.validate_inputs(inputs, outputs,
                              params['embedding_modules'],
                              params['embedding_names'],
                              params['module_output_keys'])
    except ValueError:
        if not FLAGS.skip_existing_error:
            raise
        # Skip data prep only when nothing is left to produce after dropping
        # existing outputs; the full expected locations are still returned.
        _, remaining_outputs = _remove_existing_outputs(inputs, outputs)
        run_data_prep = bool(remaining_outputs)

    return params, inputs, outputs, run_data_prep
Example #2
0
    def test_read_flags_and_create_pipeline(self, data_prep_behavior):
        """Flags read back by `get_beam_params_from_flags` drive the pipeline.

        Populates the flags for `data_prep_behavior` and reads the inputs,
        outputs and beam params back. It then checks that the non-TFLite
        defaults are in effect (no `module_call_fn` / `setup_fn` override).
        Finally it wires the first input/output pair into `data_prep_pipeline`
        on a fresh `beam.Pipeline()`. The pipeline itself is not run here.
        """
        test_glob = os.path.join(absltest.get_default_test_srcdir(), TEST_DIR,
                                 '*')
        FLAGS.input_glob = test_glob
        tfrecord_path = os.path.join(absltest.get_default_test_tmpdir(),
                                     f'{data_prep_behavior}.tfrecord')
        FLAGS.output_filename = tfrecord_path
        FLAGS.data_prep_behavior = data_prep_behavior
        FLAGS.embedding_names = ['em1', 'em2']
        FLAGS.embedding_modules = ['dummy_mod_loc']
        FLAGS.module_output_keys = ['k1', 'k2']
        FLAGS.sample_rate = 5
        FLAGS.audio_key = 'audio_key'
        FLAGS.label_key = 'label_key'

        inputs, outputs, params = (
            audio_to_embeddings_beam_utils.get_beam_params_from_flags())

        # Defaults apply unless TFLite models are used, so neither override
        # should be present.
        for override_key in ('module_call_fn', 'setup_fn'):
            self.assertNotIn(override_key, params)

        # Wire the arguments through to make sure they are all accepted.
        audio_to_embeddings_beam_utils.data_prep_pipeline(
            root=beam.Pipeline(),
            input_filenames_or_glob=inputs[0],
            output_filename=outputs[0],
            data_prep_behavior=FLAGS.data_prep_behavior,
            beam_params=params,
            suffix='s')
Example #3
0
def main(_):
    """Build the flag-configured data-prep Beam pipeline.

    The steps are:
      1. Resolve inputs, outputs and beam params via
         `utils.get_beam_params_from_flags`.
      2. Validate them with `utils.validate_inputs`.
      3. Add one `utils.data_prep_pipeline` per (input, output) pair inside a
         single `beam.Pipeline` context, tagged with the pair's index as the
         suffix.
    """
    inputs, outputs, params = utils.get_beam_params_from_flags()

    # Fail early if inputs and flags are malformed.
    utils.validate_inputs(
        input_filenames_list=inputs,
        output_filenames=outputs,
        embedding_modules=params['embedding_modules'],
        embedding_names=params['embedding_names'],
        module_output_keys=params['module_output_keys'])
    logging.info('main: input_filenames_list: %s', inputs)
    logging.info('main: output_filenames: %s', outputs)
    logging.info('main: beam_params: %s', params)

    # No custom options by default; plug custom Beam options in here.
    pipeline_options = None

    logging.info('Starting to create flume pipeline...')
    with beam.Pipeline(pipeline_options) as root:
        pairs = zip(inputs, outputs)
        for idx, (source, destination) in enumerate(pairs):
            utils.data_prep_pipeline(
                root=root,
                input_filenames_or_glob=source,
                output_filename=destination,
                data_prep_behavior=FLAGS.data_prep_behavior,
                beam_params=params,
                suffix=str(idx))
def main(unused_argv):
  """Run embedding data prep with Beam, then score sklearn models on it.

  Steps:
    1. Resolve data-prep inputs, outputs and beam params from flags. This uses
       either explicit train/validation/test globs, or a single
       `data_prep_utils.get_beam_params_from_flags()` call.
    2. Validate them. If validation raises `ValueError` and
       `FLAGS.skip_existing_error` is set, the data-prep pipeline is skipped.
    3. Build sklearn experiment params over prefix globs of the three outputs.
    4. Run data prep (unless skipped). Then run a second Beam pipeline that
       scores each experiment and writes the formatted lines to
       `FLAGS.results_output_file` as a single shard.

  Args:
    unused_argv: Unused.

  Raises:
    ValueError: If glob flags are incomplete, a glob yields more than one
      output, no inputs are found, data prep does not yield exactly 3 outputs,
      or validation fails while `FLAGS.skip_existing_error` is unset.
  """
  # Data prep setup.
  run_data_prep = True
  if FLAGS.train_input_glob:
    # Use explicit checks rather than `assert`, which is stripped under -O.
    if not FLAGS.validation_input_glob:
      raise ValueError('If using globs, must supply `validation_input_glob`.')
    if not FLAGS.test_input_glob:
      raise ValueError('If using globs, must supply `test_input_glob`.')
    input_filenames_list, output_filenames = [], []
    for input_glob, split_name in [(FLAGS.train_input_glob, 'train'),
                                   (FLAGS.validation_input_glob, 'validation'),
                                   (FLAGS.test_input_glob, 'test')]:
      # The helper takes no arguments; the current split's glob is handed over
      # through `FLAGS.input_glob`.
      FLAGS.input_glob = input_glob
      cur_inputs, cur_outputs, beam_params = (
          data_prep_utils.get_beam_params_from_flags())
      if len(cur_outputs) != 1:
        raise ValueError(
            f'Expected 1 output for the {split_name} glob, got: {cur_outputs}')
      input_filenames_list.extend(cur_inputs)
      # Only `FLAGS.input_glob` changes between iterations, so each call
      # presumably returns the same output path. Suffix it with the split
      # name so train/eval/test don't write to (and later glob) the same
      # files.
      output_filenames.append(f'{cur_outputs[0]}.{split_name}')
  else:
    input_filenames_list, output_filenames, beam_params = (
        data_prep_utils.get_beam_params_from_flags())

  if not input_filenames_list:
    raise ValueError('Data prep found no input files.')
  # The sklearn step below needs exactly a train, eval and test output.
  if len(output_filenames) != 3:
    raise ValueError(f'Data prep output must be 3 files: {output_filenames}')

  try:
    # Check that inputs and flags are formatted correctly.
    data_prep_utils.validate_inputs(
        input_filenames_list, output_filenames,
        beam_params['embedding_modules'], beam_params['embedding_names'],
        beam_params['module_output_keys'])
  except ValueError as e:
    if not FLAGS.skip_existing_error:
      raise
    # NOTE(review): any validation ValueError skips data prep here. The
    # `validate_flags` call below then requires the outputs to already exist.
    logging.warning('Skipping data prep; input validation failed: %s', e)
    run_data_prep = False
  logging.info('beam_params: %s', beam_params)

  # Generate sklearn eval experiment parameters based on data prep flags.
  # Turn each output path into a prefix glob (`<path>*`).
  train_glob, eval_glob, test_glob = [f'{x}*' for x in output_filenames]
  sklearn_results_output_file = FLAGS.results_output_file
  exp_params = sklearn_utils.experiment_params(
      embedding_list=beam_params['embedding_names'],
      speaker_id_name=FLAGS.speaker_id_key,
      label_name=FLAGS.label_key,
      label_list=FLAGS.label_list,
      train_glob=train_glob,
      eval_glob=eval_glob,
      test_glob=test_glob,
      save_model_dir=None,
      save_predictions_dir=None,
      eval_metric=FLAGS.eval_metric,
  )
  logging.info('exp_params: %s', exp_params)

  # Make and run beam pipeline. If you have custom beam options, add them here.
  beam_options = None

  if run_data_prep:
    logging.info('Data prep on: %s, %s...', input_filenames_list,
                 output_filenames)
    with beam.Pipeline(beam_options) as root:
      for i, (input_filenames_or_glob, output_filename) in enumerate(
          zip(input_filenames_list, output_filenames)):
        data_prep_utils.make_beam_pipeline(
            root,
            input_filenames=input_filenames_or_glob,
            output_filename=output_filename,
            suffix=str(i),
            **beam_params)

  # Check that previous beam pipeline wrote outputs.
  sklearn_utils.validate_flags(train_glob, eval_glob, test_glob,
                               sklearn_results_output_file)
  logging.info('Eval sklearn...')
  with beam.Pipeline(beam_options) as root:
    _ = (
        root
        | 'MakeCollection' >> beam.Create(exp_params)
        | 'CalcScores' >> beam.Map(
            lambda d: (d, sklearn_utils.train_and_get_score(**d)))
        | 'FormatText' >> beam.Map(sklearn_utils.format_text_line)
        | 'Reshuffle' >> beam.Reshuffle()
        | 'WriteOutput' >> beam.io.WriteToText(
            sklearn_results_output_file, num_shards=1))