Beispiel #1
0
 def _get_distribution_strategy(self, use_ds=True):
     """Returns the distribution strategy matching this object's hardware.

     When `self.tpu` is truthy, a 'tpu' strategy is requested with `self.tpu`
     as the TPU address. Otherwise a 'mirrored' strategy (or 'off' when
     `use_ds` is falsy) is requested for `self.num_gpus` GPUs.

     Args:
       use_ds: Whether to request the 'mirrored' strategy on the non-TPU
         path. Ignored when `self.tpu` is set.

     Returns:
       The value returned by `distribution_utils.get_distribution_strategy`.
     """
     if not self.tpu:
         strategy_name = 'mirrored' if use_ds else 'off'
         return distribution_utils.get_distribution_strategy(
             distribution_strategy=strategy_name, num_gpus=self.num_gpus)

     return distribution_utils.get_distribution_strategy(
         distribution_strategy='tpu', tpu_address=self.tpu)
Beispiel #2
0
  def _run_bert_classifier(self, callbacks=None, use_ds=True):
    """Runs a BERT classification job configured from flags and attributes.

    Reads the input metadata JSON and the BERT config named by flags, derives
    the step counts, picks a distribution strategy, and hands everything to
    `run_classifier.run_bert_classifier`.

    Args:
      callbacks: Forwarded to `run_classifier.run_bert_classifier` as
        `custom_callbacks`.
      use_ds: On the non-TPU path, selects the 'mirrored' strategy when
        truthy and 'off' otherwise. Ignored when `self.tpu` is set.
    """
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
      raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    # Truthy instance attributes take precedence over the flag / metadata
    # derived values.
    epochs = self.num_epochs or FLAGS.num_train_epochs
    steps_per_epoch = self.num_steps_per_epoch or int(
        input_meta_data['train_data_size'] / FLAGS.train_batch_size)

    # Warm up over the first 10% of all training steps.
    total_train_steps = epochs * steps_per_epoch
    warmup_steps = int(total_train_steps * 0.1)

    # Round up so a final partial batch is still evaluated.
    eval_batches = input_meta_data['eval_data_size'] / FLAGS.eval_batch_size
    eval_steps = int(math.ceil(eval_batches))

    if self.tpu:
      strategy_kwargs = {
          'distribution_strategy': 'tpu',
          'tpu_address': self.tpu,
      }
    else:
      strategy_kwargs = {
          'distribution_strategy': 'mirrored' if use_ds else 'off',
          'num_gpus': self.num_gpus,
      }
    strategy = distribution_utils.get_distribution_strategy(**strategy_kwargs)

    steps_per_loop = 1

    seq_len = input_meta_data['max_seq_length']
    train_input_fn = run_classifier.get_dataset_fn(
        FLAGS.train_data_path, seq_len, FLAGS.train_batch_size,
        is_training=True)
    eval_input_fn = run_classifier.get_dataset_fn(
        FLAGS.eval_data_path, seq_len, FLAGS.eval_batch_size,
        is_training=False)

    run_classifier.run_bert_classifier(
        strategy, bert_config, input_meta_data, FLAGS.model_dir,
        epochs, steps_per_epoch, steps_per_loop, eval_steps, warmup_steps,
        FLAGS.learning_rate, FLAGS.init_checkpoint,
        train_input_fn, eval_input_fn,
        custom_callbacks=callbacks)
Beispiel #3
0
def main(_):
    """Script entry point: builds inputs and a strategy, then runs BERT.

    Loads the input metadata JSON named by `FLAGS.input_meta_data_path`,
    falls back to '/tmp/bert20/' when `FLAGS.model_dir` is unset, creates the
    distribution strategy and the train/eval input functions from flags, and
    calls `run_bert`.

    Args:
      _: Unused positional argument.
    """
    # This script must be run under TF 2.x.
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy_kwargs = {
        'distribution_strategy': FLAGS.distribution_strategy,
        'num_gpus': FLAGS.num_gpus,
        'tpu_address': FLAGS.tpu,
    }
    strategy = distribution_utils.get_distribution_strategy(**strategy_kwargs)

    seq_len = input_meta_data['max_seq_length']
    train_input_fn = get_dataset_fn(
        FLAGS.train_data_path, seq_len, FLAGS.train_batch_size,
        is_training=True)
    eval_input_fn = get_dataset_fn(
        FLAGS.eval_data_path, seq_len, FLAGS.eval_batch_size,
        is_training=False)

    run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)
Beispiel #4
0
 def __init__(self, strategy_type=None, strategy_config=None):
     """Configures the cluster and stores the resulting strategy.

     Args:
       strategy_type: Passed to
         `distribution_utils.get_distribution_strategy` as
         `distribution_strategy`.
       strategy_config: Object providing `worker_hosts`, `task_index`,
         `num_gpus`, `all_reduce_alg`, `num_packs` and `tpu` attributes.
         NOTE: it is dereferenced unconditionally, so leaving it at the
         default of None raises AttributeError.
     """
     cfg = strategy_config
     worker_count = distribution_utils.configure_cluster(
         cfg.worker_hosts, cfg.task_index)

     strategy_kwargs = dict(
         distribution_strategy=strategy_type,
         num_gpus=cfg.num_gpus,
         num_workers=worker_count,
         all_reduce_alg=cfg.all_reduce_alg,
         num_packs=cfg.num_packs,
         tpu_address=cfg.tpu,
     )
     self._strategy = distribution_utils.get_distribution_strategy(
         **strategy_kwargs)
Beispiel #5
0
def main(_):
    """Script entry point: builds a strategy from flags and runs pretraining.

    Falls back to '/tmp/bert20/' when `FLAGS.model_dir` is unset, creates the
    distribution strategy, prints its replica count when a strategy was
    returned, and calls `run_bert_pretrain`.

    Args:
      _: Unused positional argument.
    """
    # This script must be run under TF 2.x.
    assert tf.version.VERSION.startswith('2.')

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy_kwargs = {
        'distribution_strategy': FLAGS.distribution_strategy,
        'num_gpus': FLAGS.num_gpus,
        'tpu_address': FLAGS.tpu,
    }
    strategy = distribution_utils.get_distribution_strategy(**strategy_kwargs)

    # A falsy result means there is no strategy to report on.
    if strategy:
        replica_count = strategy.num_replicas_in_sync
        print('***** Number of cores used : ', replica_count)

    run_bert_pretrain(strategy)
Beispiel #6
0
def main(_):
    """Script entry point: exports, trains and/or predicts per `FLAGS.mode`.

    Loads the input metadata JSON named by `FLAGS.input_meta_data_path`. In
    'export_only' mode it calls `export_squad` and returns. Otherwise it
    builds a distribution strategy from flags, then calls `train_squad` for
    the 'train' / 'train_and_predict' modes and `predict_squad` for the
    'predict' / 'train_and_predict' modes.

    Args:
      _: Unused positional argument.
    """
    # This script must be run under TF 2.x.
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    # Exporting is handled before any distribution strategy is created.
    if FLAGS.mode == 'export_only':
        export_squad(FLAGS.model_export_path, input_meta_data)
        return

    strategy_kwargs = {
        'distribution_strategy': FLAGS.distribution_strategy,
        'num_gpus': FLAGS.num_gpus,
        'tpu_address': FLAGS.tpu,
    }
    strategy = distribution_utils.get_distribution_strategy(**strategy_kwargs)

    training_modes = ('train', 'train_and_predict')
    prediction_modes = ('predict', 'train_and_predict')

    if FLAGS.mode in training_modes:
        train_squad(strategy, input_meta_data)
    if FLAGS.mode in prediction_modes:
        predict_squad(strategy, input_meta_data)
Beispiel #7
0
def run_mnist(flags_obj):
    """Run MNIST training and eval loop.

    Builds a `tf.estimator.Estimator` around `model_fn`, alternates training
    and evaluation until the requested number of epochs is reached or the
    stop threshold is passed, and optionally exports the trained model.

    Args:
      flags_obj: An object containing parsed flag values.
    """
    model_helpers.apply_clean(flags_obj)
    model_function = model_fn

    session_config = tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        all_reduce_alg=flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    # Without an explicit data format, pick one based on whether this
    # TensorFlow build has CUDA support.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    mnist_classifier = tf.estimator.Estimator(model_fn=model_function,
                                              model_dir=flags_obj.model_dir,
                                              config=run_config,
                                              params={
                                                  'data_format': data_format,
                                              })

    # Set up training and evaluation input functions.
    def train_input_fn():
        """Prepare data for training."""

        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes use less memory. MNIST is a small
        # enough dataset that we can easily shuffle the full epoch.
        ds = dataset.train(flags_obj.data_dir)
        ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

        # Iterate through the dataset a set number (`epochs_between_evals`) of
        # times during each training session.
        ds = ds.repeat(flags_obj.epochs_between_evals)
        return ds

    def eval_input_fn():
        """Prepare data for evaluation."""
        # Return the dataset itself, as train_input_fn does: the Estimator
        # creates the iterator. `Dataset.make_one_shot_iterator()` is
        # deprecated and is not available on TF 2.x datasets.
        return dataset.test(flags_obj.data_dir).batch(flags_obj.batch_size)

    # Set up the training hooks requested via flags.
    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    # Train and evaluate model.
    for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
        mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        # Stop early once the accuracy passes the configured threshold.
        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    # Export the model
    if flags_obj.export_dir is not None:
        image = tf.compat.v1.placeholder(tf.float32, [None, 28, 28])
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'image':
            image,
        })
        mnist_classifier.export_savedmodel(flags_obj.export_dir,
                                           input_fn,
                                           strip_default_attrs=True)