Code example #1
def run(config_map,
        tf_file_reader_class=tf.data.TFRecordDataset,
        file_reader=tf.python_io.tf_record_iterator):
    """Load model params, save config file and start trainer.

  Args:
    config_map: Dictionary mapping configuration name to Config object.
    tf_file_reader_class: The tf.data.Dataset class to use for reading files.
    file_reader: The Python reader to use for reading files.

  Raises:
    ValueError: if required flags are missing or invalid.
  """
    if not FLAGS.run_dir:
        raise ValueError('Invalid run directory: %s' % FLAGS.run_dir)
    run_dir = os.path.expanduser(FLAGS.run_dir)
    train_dir = os.path.join(run_dir, 'train')

    if FLAGS.mode not in ['train', 'eval']:
        raise ValueError('Invalid mode: %s' % FLAGS.mode)

    if FLAGS.config not in config_map:
        raise ValueError('Invalid config: %s' % FLAGS.config)
    config = config_map[FLAGS.config]
    if FLAGS.hparams:
        config.hparams.parse(FLAGS.hparams)
    config_update_map = {}
    if FLAGS.examples_path:
        config_update_map['%s_examples_path' %
                          FLAGS.mode] = FLAGS.examples_path
    config = configs.update_config(config, config_update_map)
    if FLAGS.num_sync_workers:
        config.hparams.batch_size //= FLAGS.num_sync_workers

    dataset = data.get_dataset(config,
                               tf_file_reader_class=tf_file_reader_class,
                               num_threads=FLAGS.num_data_threads,
                               is_training=True)

    if FLAGS.mode == 'eval':
        num_batches = FLAGS.eval_num_batches or data.count_examples(
            config.eval_examples_path, config.note_sequence_converter,
            file_reader) // config.hparams.batch_size
        eval_dir = os.path.join(run_dir, 'eval' + FLAGS.eval_dir_suffix)
        evaluate(train_dir,
                 eval_dir,
                 config=config,
                 dataset=dataset,
                 num_batches=num_batches,
                 master=FLAGS.master)
    elif FLAGS.mode == 'train':
        train(train_dir,
              config=config,
              dataset=dataset,
              checkpoints_to_keep=FLAGS.checkpoints_to_keep,
              num_steps=FLAGS.num_steps,
              master=FLAGS.master,
              num_sync_workers=FLAGS.num_sync_workers,
              num_ps_tasks=FLAGS.num_ps_tasks,
              task=FLAGS.task)
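
A function like this is typically exposed as a console entry point. Below is a minimal sketch of that wiring, continuing the module above; it assumes a --log verbosity flag and that the configs module exposes a CONFIG_MAP dict mapping names to Config objects (neither is shown in the excerpt).

def main(unused_argv):
  # Hypothetical --log flag controlling TF logging verbosity.
  tf.logging.set_verbosity(FLAGS.log)
  # CONFIG_MAP assumed: {'config_name': Config(...), ...}
  run(configs.CONFIG_MAP)


def console_entry_point():
  tf.app.run(main)


if __name__ == '__main__':
  console_entry_point()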
Code example #2
File: music_vae_train.py  Project: vanton/magenta
def run(config_map,
        tf_file_reader=tf.data.TFRecordDataset,
        file_reader=tf.python_io.tf_record_iterator):
  """Load model params, save config file and start trainer.

  Args:
    config_map: Dictionary mapping configuration name to Config object.
    tf_file_reader: The tf.data.Dataset class to use for reading files.
    file_reader: The Python reader to use for reading files.

  Raises:
    ValueError: if required flags are missing or invalid.
  """
  if not FLAGS.run_dir:
    raise ValueError('Invalid run directory: %s' % FLAGS.run_dir)
  run_dir = os.path.expanduser(FLAGS.run_dir)
  train_dir = os.path.join(run_dir, 'train')

  if FLAGS.mode not in ['train', 'eval']:
    raise ValueError('Invalid mode: %s' % FLAGS.mode)

  if FLAGS.config not in config_map:
    raise ValueError('Invalid config: %s' % FLAGS.config)
  config = config_map[FLAGS.config]
  if FLAGS.hparams:
    config.hparams.parse(FLAGS.hparams)
  config_update_map = {}
  if FLAGS.examples_path:
    config_update_map['%s_examples_path' % FLAGS.mode] = os.path.expanduser(
        FLAGS.examples_path)
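  # --tfds_name selects a TensorFlow Datasets source; it supersedes the
  # file-based examples paths, so the two flags are mutually exclusive.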
  if FLAGS.tfds_name:
    if FLAGS.examples_path:
      raise ValueError(
          'At most one of --examples_path and --tfds_name can be set.')
    config_update_map['tfds_name'] = FLAGS.tfds_name
    config_update_map['eval_examples_path'] = None
    config_update_map['train_examples_path'] = None
  config = configs.update_config(config, config_update_map)
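  # Each synchronized worker consumes an equal share of the global batch.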
  if FLAGS.num_sync_workers:
    config.hparams.batch_size //= FLAGS.num_sync_workers

  if FLAGS.mode == 'train':
    is_training = True
  elif FLAGS.mode == 'eval':
    is_training = False
  else:
    raise ValueError('Invalid mode: {}'.format(FLAGS.mode))

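  # Wrapping the dataset in a function lets the trainer build the input
  # pipeline lazily, inside its own graph, instead of receiving a live dataset.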
  def dataset_fn():
    return data.get_dataset(
        config,
        tf_file_reader=tf_file_reader,
        is_training=is_training,
        cache_dataset=FLAGS.cache_dataset)

  if is_training:
    train(
        train_dir,
        config=config,
        dataset_fn=dataset_fn,
        checkpoints_to_keep=FLAGS.checkpoints_to_keep,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        num_steps=FLAGS.num_steps,
        master=FLAGS.master,
        num_sync_workers=FLAGS.num_sync_workers,
        num_ps_tasks=FLAGS.num_ps_tasks,
        task=FLAGS.task)
  else:
    num_batches = FLAGS.eval_num_batches or data.count_examples(
        config.eval_examples_path,
        config.tfds_name,
        config.data_converter,
        file_reader) // config.hparams.batch_size
    eval_dir = os.path.join(run_dir, 'eval' + FLAGS.eval_dir_suffix)
    evaluate(
        train_dir,
        eval_dir,
        config=config,
        dataset_fn=dataset_fn,
        num_batches=num_batches,
        master=FLAGS.master)
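
Compared with example #1, this version supports reading from TensorFlow Datasets (--tfds_name) and defers dataset construction to dataset_fn. The mutual-exclusion check on --examples_path and --tfds_name can also be expressed declaratively when the flags are defined; a minimal sketch using absl.flags (the flag names mirror the example, the help strings are assumed):

from absl import flags

flags.DEFINE_string('examples_path', None,
                    'Path to a TFRecord file of input examples.')
flags.DEFINE_string('tfds_name', None,
                    'Name of a TensorFlow Datasets dataset to use.')
flags.mark_flags_as_mutual_exclusive(['examples_path', 'tfds_name'])

The declarative form fails at argument-parsing time, while the inline check above keeps the validation next to the code that consumes the flags.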
Code example #3
def run(config_map,
        tf_file_reader=tf.data.TFRecordDataset,
        file_reader=tf.python_io.tf_record_iterator):
  """Load model params, save config file and start trainer.

  Args:
    config_map: Dictionary mapping configuration name to Config object.
    tf_file_reader: The tf.data.Dataset class to use for reading files.
    file_reader: The Python reader to use for reading files.

  Raises:
    ValueError: if required flags are missing or invalid.
  """
  if not FLAGS.run_dir:
    raise ValueError('Invalid run directory: %s' % FLAGS.run_dir)
  run_dir = os.path.expanduser(FLAGS.run_dir)
  train_dir = os.path.join(run_dir, 'train')

  if FLAGS.mode not in ['train', 'eval']:
    raise ValueError('Invalid mode: %s' % FLAGS.mode)

  if FLAGS.config not in config_map:
    raise ValueError('Invalid config: %s' % FLAGS.config)
  config = config_map[FLAGS.config]
  if FLAGS.hparams:
    config.hparams.parse(FLAGS.hparams)
  config_update_map = {}
  if FLAGS.examples_path:
    config_update_map['%s_examples_path' % FLAGS.mode] = os.path.expanduser(
        FLAGS.examples_path)
  config = configs.update_config(config, config_update_map)
  if FLAGS.num_sync_workers:
    config.hparams.batch_size //= FLAGS.num_sync_workers

  if FLAGS.mode == 'train':
    is_training = True
  elif FLAGS.mode == 'eval':
    is_training = False
  else:
    raise ValueError('Invalid mode: {}'.format(FLAGS.mode))

  dataset = data.get_dataset(
      config,
      tf_file_reader=tf_file_reader,
      num_threads=FLAGS.num_data_threads,
      prefetch_size=FLAGS.prefetch_size,
      is_training=is_training)

  if is_training:
    train(
        train_dir,
        config=config,
        dataset=dataset,
        checkpoints_to_keep=FLAGS.checkpoints_to_keep,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        num_steps=FLAGS.num_steps,
        master=FLAGS.master,
        num_sync_workers=FLAGS.num_sync_workers,
        num_ps_tasks=FLAGS.num_ps_tasks,
        task=FLAGS.task)
  else:
    num_batches = FLAGS.eval_num_batches or data.count_examples(
        config.eval_examples_path,
        config.data_converter,
        file_reader) // config.hparams.batch_size
    eval_dir = os.path.join(run_dir, 'eval' + FLAGS.eval_dir_suffix)
    evaluate(
        train_dir,
        eval_dir,
        config=config,
        dataset=dataset,
        num_batches=num_batches,
        master=FLAGS.master)
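
Unlike example #2, this version builds the dataset eagerly and exposes explicit num_threads and prefetch_size knobs. The two integer divisions are easy to gloss over; a worked example with illustrative numbers (all values hypothetical):

batch_size = 512        # config.hparams.batch_size before the split
num_sync_workers = 4    # FLAGS.num_sync_workers
per_worker_batch = batch_size // num_sync_workers  # 512 // 4 == 128

num_examples = 10000    # what data.count_examples might return
num_batches = num_examples // per_worker_batch     # 10000 // 128 == 78

Because // floors, the last partial batch of the eval set is silently dropped unless --eval_num_batches overrides the computed count.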