Example 1
  def model_fn(features, labels, mode, params):
    #### Training or Evaluation
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    #### Get loss from inputs
    outputs = function_builder.get_qa_outputs(FLAGS, features, is_training)

    #### Check model parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('#params: {}'.format(num_params))

    scaffold_fn = None

    #### Prediction mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      if FLAGS.init_checkpoint:
        tf.logging.info("init_checkpoint not being used in predict mode.")

      predictions = {
          "unique_ids": features["unique_ids"],
          "start_top_index": outputs["start_top_index"],
          "start_top_log_probs": outputs["start_top_log_probs"],
          "end_top_index": outputs["end_top_index"],
          "end_top_log_probs": outputs["end_top_log_probs"],
          "cls_logits": outputs["cls_logits"]
      }

      if FLAGS.use_tpu:
        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
      else:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions)
      return output_spec

    #### Compute loss
    seq_length = tf.shape(features["input_ids"])[1]
    def compute_loss(log_probs, positions):
      one_hot_positions = tf.one_hot(
          positions, depth=seq_length, dtype=tf.float32)

      loss = - tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
      loss = tf.reduce_mean(loss)
      return loss

    start_loss = compute_loss(
        outputs["start_log_probs"], features["start_positions"])
    end_loss = compute_loss(
        outputs["end_log_probs"], features["end_positions"])

    total_loss = (start_loss + end_loss) * 0.5

    cls_logits = outputs["cls_logits"]
    is_impossible = tf.reshape(features["is_impossible"], [-1])
    regression_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=is_impossible, logits=cls_logits)
    regression_loss = tf.reduce_mean(regression_loss)

    # note(zhiliny): by default multiply the loss by 0.5 so that the scale is
    # comparable to start_loss and end_loss
    total_loss += regression_loss * 0.5

    #### Configuring the optimizer
    train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)

    monitor_dict = {}
    monitor_dict["lr"] = learning_rate

    #### load pretrained models
    scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

    #### Constructing training TPUEstimatorSpec with new cache.
    if FLAGS.use_tpu:
      host_call = function_builder.construct_scalar_host_call(
          monitor_dict=monitor_dict,
          model_dir=FLAGS.model_dir,
          prefix="train/",
          reduce_fn=tf.reduce_mean)

      train_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op, host_call=host_call,
          scaffold_fn=scaffold_fn)
    else:
      train_spec = tf.estimator.EstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op)

    return train_spec
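
For orientation, here is a minimal sketch of how a model_fn like this is typically wired into an estimator and trained. The flag names save_steps, train_batch_size, and train_steps, as well as train_input_fn, are assumptions used only for illustration; they are not defined in the example above.

import tensorflow as tf

# Hypothetical wiring; the surrounding script is assumed to define FLAGS
# and an input_fn that yields the feature dict used by model_fn
# (input_ids, start_positions, end_positions, is_impossible, unique_ids, ...).
run_config = tf.contrib.tpu.RunConfig(
    model_dir=FLAGS.model_dir,
    save_checkpoints_steps=FLAGS.save_steps)  # assumed flag

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=FLAGS.use_tpu,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=FLAGS.train_batch_size)  # assumed flag

estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)  # assumed flag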
Example 2
def main(_):

    model_name = 'xlnet'
    kwargs = dict(training=False, logits=True)
    batch_size = FLAGS.batch_size
    max_seq_length = FLAGS.max_seq_length

    run_metadata = tf.RunMetadata()
    with tf.Session() as sess:
        if FLAGS.decompose:
            logger.info('running in decompose mode')
            max_first_length = FLAGS.max_first_length
            kwargs['fake_cache_first'] = FLAGS.cache_segment == 1
            kwargs['fake_cache_second'] = FLAGS.cache_segment == 2
            outputs = get_decomposed_qa_outputs(FLAGS, features, False)
        else:
            logger.info('running in normal mode')
            outputs = get_qa_outputs(FLAGS, features, False)

        inputs_dict, logits_ph = model.core_graph(config, **kwargs)

        sess.run(tf.global_variables_initializer())
        opt_builder = tf.profiler.ProfileOptionBuilder
        # saver = tf.train.Saver()
        # saver.save(sess, 'data/sbert', write_meta_graph=False)
        if FLAGS.print_parameters:
            tf.profiler.profile(
                sess.graph,
                options=opt_builder.trainable_variables_parameter())

        if not FLAGS.not_profile_flops:
            prof_options = opt_builder.float_operation()
            prof_options['hide_name_regexes'] = ['.*/Initializer/.*']
            tfprof_node = tf.profiler.profile(sess.graph, options=prof_options)
            profile_metric(model_name, tfprof_node, metric='total_float_ops',
                           metric_name='flops')

        if FLAGS.profile_memory:
            # Full tracing is needed so memory stats are recorded in run_metadata.
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        else:
            options = None
            run_metadata = None
        _ = sess.run([logits_ph], feed_dict=inputs_dict,
                     options=options,
                     run_metadata=run_metadata)

        if FLAGS.profile_memory:
            opts = opt_builder(
                opt_builder.time_and_memory()).build()

            tfprof_node = tf.profiler.profile(
                tf.get_default_graph(),
                run_meta=run_metadata,
                cmd='scope',
                options=opts)

            profile_metric(model_name, tfprof_node,
                           metric='total_requested_bytes', metric_name='mem')

        if FLAGS.profile_time:
            # warm up two rounds
            logger.info("warm up for two rounds...")

            for _ in range(2):
                sess.run([logits_ph], feed_dict=inputs_dict)

            logger.info("start running 10 rounds...")
            start_time = time.time()
            # bench 10 rounds, take avg
            for _ in range(10):
                sess.run([logits_ph], feed_dict=inputs_dict)
            end_time = time.time()
            print('infer_time: {:.4f} s'.format((end_time - start_time) / 10))
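
The profile_metric helper called above is not defined in this example. A minimal sketch of what such a helper might look like, assuming it only reads an aggregate field off the GraphNodeProto returned by tf.profiler.profile and logs it (total_float_ops and total_requested_bytes are real fields of that proto); the actual implementation in the source project may differ.

def profile_metric(model_name, tfprof_node, metric, metric_name):
    # Hypothetical helper: read an aggregate value (e.g. total_float_ops,
    # total_requested_bytes) from the profiler result and log it.
    value = getattr(tfprof_node, metric)
    logger.info('{} {}: {}'.format(model_name, metric_name, value))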