def run(config_map,
        tf_file_reader_class=tf.data.TFRecordDataset,
        file_reader=tf.python_io.tf_record_iterator):
    """Validate flags, build the dataset and start training or evaluation.

    Args:
      config_map: Dictionary mapping configuration name to Config object.
      tf_file_reader_class: The tf.data.Dataset class to use for reading files.
      file_reader: The Python reader to use for reading files.

    Raises:
      ValueError: if required flags are missing or invalid.
    """
    if not FLAGS.run_dir:
        raise ValueError('Invalid run directory: %s' % FLAGS.run_dir)
    run_dir = os.path.expanduser(FLAGS.run_dir)
    train_dir = os.path.join(run_dir, 'train')

    if FLAGS.mode not in ['train', 'eval']:
        raise ValueError('Invalid mode: %s' % FLAGS.mode)
    # Safe to reduce to a boolean: the mode was validated just above.
    is_training = FLAGS.mode == 'train'

    if FLAGS.config not in config_map:
        raise ValueError('Invalid config: %s' % FLAGS.config)
    config = config_map[FLAGS.config]
    if FLAGS.hparams:
        # The return value is ignored; parse() is presumably applying the
        # overrides to config.hparams in place -- TODO confirm.
        config.hparams.parse(FLAGS.hparams)

    config_update_map = {}
    if FLAGS.examples_path:
        # Overrides only the examples path of the active mode
        # ('train_examples_path' or 'eval_examples_path'). '~' is expanded the
        # same way as for run_dir.
        config_update_map['%s_examples_path' % FLAGS.mode] = os.path.expanduser(
            FLAGS.examples_path)
    config = configs.update_config(config, config_update_map)

    if FLAGS.num_sync_workers:
        # Floor-divide the configured batch size by the number of sync workers.
        config.hparams.batch_size //= FLAGS.num_sync_workers

    # Build the dataset for the mode actually requested; previously the
    # evaluation path also received is_training=True.
    dataset = data.get_dataset(
        config,
        tf_file_reader_class=tf_file_reader_class,
        num_threads=FLAGS.num_data_threads,
        is_training=is_training)

    if is_training:
        train(
            train_dir,
            config=config,
            dataset=dataset,
            checkpoints_to_keep=FLAGS.checkpoints_to_keep,
            num_steps=FLAGS.num_steps,
            master=FLAGS.master,
            num_sync_workers=FLAGS.num_sync_workers,
            num_ps_tasks=FLAGS.num_ps_tasks,
            task=FLAGS.task)
    else:
        # An explicit, non-zero --eval_num_batches wins; otherwise the batch
        # count is the result of data.count_examples floor-divided by the
        # batch size.
        num_batches = FLAGS.eval_num_batches or data.count_examples(
            config.eval_examples_path,
            config.note_sequence_converter,
            file_reader) // config.hparams.batch_size
        eval_dir = os.path.join(run_dir, 'eval' + FLAGS.eval_dir_suffix)
        evaluate(
            train_dir,
            eval_dir,
            config=config,
            dataset=dataset,
            num_batches=num_batches,
            master=FLAGS.master)
def run(config_map,
        tf_file_reader=tf.data.TFRecordDataset,
        file_reader=tf.python_io.tf_record_iterator):
    """Validate flags, configure the data source and start training or evaluation.

    Args:
      config_map: Dictionary mapping configuration name to Config object.
      tf_file_reader: The tf.data.Dataset class to use for reading files.
      file_reader: The Python reader to use for reading files.

    Raises:
      ValueError: if required flags are missing or invalid, or if both
        --examples_path and --tfds_name are set.
    """
    # Validate every flag up front so nothing is modified when a flag is bad.
    if not FLAGS.run_dir:
        raise ValueError('Invalid run directory: %s' % FLAGS.run_dir)
    if FLAGS.mode not in ['train', 'eval']:
        raise ValueError('Invalid mode: %s' % FLAGS.mode)
    if FLAGS.config not in config_map:
        raise ValueError('Invalid config: %s' % FLAGS.config)
    if FLAGS.tfds_name and FLAGS.examples_path:
        raise ValueError(
            'At most one of --examples_path and --tfds_name can be set.')

    run_dir = os.path.expanduser(FLAGS.run_dir)
    train_dir = os.path.join(run_dir, 'train')
    # Safe to reduce to a boolean: the mode was validated above.
    is_training = FLAGS.mode == 'train'

    config = config_map[FLAGS.config]
    if FLAGS.hparams:
        # The return value is ignored; parse() is presumably applying the
        # overrides to config.hparams in place -- TODO confirm.
        config.hparams.parse(FLAGS.hparams)

    config_update_map = {}
    if FLAGS.examples_path:
        # Overrides only the examples path of the active mode.
        config_update_map['%s_examples_path' % FLAGS.mode] = os.path.expanduser(
            FLAGS.examples_path)
    if FLAGS.tfds_name:
        # Selecting a TFDS dataset clears both file-based example paths.
        config_update_map['tfds_name'] = FLAGS.tfds_name
        config_update_map['eval_examples_path'] = None
        config_update_map['train_examples_path'] = None
    config = configs.update_config(config, config_update_map)

    if FLAGS.num_sync_workers:
        # Floor-divide the configured batch size by the number of sync workers.
        config.hparams.batch_size //= FLAGS.num_sync_workers

    def dataset_fn():
        # Passed as a callable, so the dataset is only built when train() or
        # evaluate() invokes it.
        return data.get_dataset(
            config,
            tf_file_reader=tf_file_reader,
            is_training=is_training,
            cache_dataset=FLAGS.cache_dataset)

    if is_training:
        train(
            train_dir,
            config=config,
            dataset_fn=dataset_fn,
            checkpoints_to_keep=FLAGS.checkpoints_to_keep,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
            num_steps=FLAGS.num_steps,
            master=FLAGS.master,
            num_sync_workers=FLAGS.num_sync_workers,
            num_ps_tasks=FLAGS.num_ps_tasks,
            task=FLAGS.task)
    else:
        # An explicit, non-zero --eval_num_batches wins; otherwise the batch
        # count is the result of data.count_examples floor-divided by the
        # batch size.
        num_batches = FLAGS.eval_num_batches or data.count_examples(
            config.eval_examples_path,
            config.tfds_name,
            config.data_converter,
            file_reader) // config.hparams.batch_size
        eval_dir = os.path.join(run_dir, 'eval' + FLAGS.eval_dir_suffix)
        evaluate(
            train_dir,
            eval_dir,
            config=config,
            dataset_fn=dataset_fn,
            num_batches=num_batches,
            master=FLAGS.master)
def run(config_map,
        tf_file_reader=tf.data.TFRecordDataset,
        file_reader=tf.python_io.tf_record_iterator):
    """Validate flags, build the dataset and start training or evaluation.

    Args:
      config_map: Dictionary mapping configuration name to Config object.
      tf_file_reader: The tf.data.Dataset class to use for reading files.
      file_reader: The Python reader to use for reading files.

    Raises:
      ValueError: if required flags are missing or invalid.
    """
    if not FLAGS.run_dir:
        raise ValueError('Invalid run directory: %s' % FLAGS.run_dir)
    run_dir = os.path.expanduser(FLAGS.run_dir)
    train_dir = os.path.join(run_dir, 'train')

    if FLAGS.mode not in ['train', 'eval']:
        raise ValueError('Invalid mode: %s' % FLAGS.mode)
    # The mode is known to be 'train' or 'eval' here, so a second
    # "invalid mode" branch would be unreachable.
    is_training = FLAGS.mode == 'train'

    if FLAGS.config not in config_map:
        raise ValueError('Invalid config: %s' % FLAGS.config)
    config = config_map[FLAGS.config]
    if FLAGS.hparams:
        # The return value is ignored; parse() is presumably applying the
        # overrides to config.hparams in place -- TODO confirm.
        config.hparams.parse(FLAGS.hparams)

    config_update_map = {}
    if FLAGS.examples_path:
        # Overrides only the examples path of the active mode.
        config_update_map['%s_examples_path' % FLAGS.mode] = os.path.expanduser(
            FLAGS.examples_path)
    config = configs.update_config(config, config_update_map)

    if FLAGS.num_sync_workers:
        # Floor-divide the configured batch size by the number of sync workers.
        config.hparams.batch_size //= FLAGS.num_sync_workers

    dataset = data.get_dataset(
        config,
        tf_file_reader=tf_file_reader,
        num_threads=FLAGS.num_data_threads,
        prefetch_size=FLAGS.prefetch_size,
        is_training=is_training)

    if is_training:
        train(
            train_dir,
            config=config,
            dataset=dataset,
            checkpoints_to_keep=FLAGS.checkpoints_to_keep,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
            num_steps=FLAGS.num_steps,
            master=FLAGS.master,
            num_sync_workers=FLAGS.num_sync_workers,
            num_ps_tasks=FLAGS.num_ps_tasks,
            task=FLAGS.task)
    else:
        # An explicit, non-zero --eval_num_batches wins; otherwise the batch
        # count is the result of data.count_examples floor-divided by the
        # batch size.
        num_batches = FLAGS.eval_num_batches or data.count_examples(
            config.eval_examples_path,
            config.data_converter,
            file_reader) // config.hparams.batch_size
        eval_dir = os.path.join(run_dir, 'eval' + FLAGS.eval_dir_suffix)
        evaluate(
            train_dir,
            eval_dir,
            config=config,
            dataset=dataset,
            num_batches=num_batches,
            master=FLAGS.master)