def model_fn(features, labels, mode, params):
    """Estimator `model_fn` for SQuAD-style span extraction with an answerability head.

    Builds the QA graph with `function_builder.get_qa_outputs`. In PREDICT
    mode it returns a spec holding the top-k start/end indices and
    log-probabilities plus `cls_logits`. In every other mode it returns a spec
    holding the combined loss and the train op. The combined loss is the mean
    of the start and end span losses plus 0.5 times the sigmoid answerability
    loss.

    Args:
        features: dict of input tensors. Always reads "input_ids" (only its
            second dimension, used as the one-hot depth) and, in PREDICT mode,
            "unique_ids". Outside PREDICT mode it also reads "start_positions",
            "end_positions" and "is_impossible".
        labels: unused; present only because the Estimator API requires it.
        mode: a `tf.estimator.ModeKeys` value.
        params: unused; present only because the Estimator API requires it.

    Returns:
        `tf.contrib.tpu.TPUEstimatorSpec` when `FLAGS.use_tpu` is set,
        otherwise `tf.estimator.EstimatorSpec`.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    outputs = function_builder.get_qa_outputs(FLAGS, features, training)

    # Report the size of the model once its variables have been created.
    param_count = sum(np.prod(var.shape) for var in tf.trainable_variables())
    tf.logging.info('#params: {}'.format(param_count))

    # ---- Prediction: no loss and no checkpoint warm-start.
    if mode == tf.estimator.ModeKeys.PREDICT:
        if FLAGS.init_checkpoint:
            tf.logging.info("init_checkpoint not being used in predict mode.")
        predictions = {"unique_ids": features["unique_ids"]}
        for key in ("start_top_index", "start_top_log_probs",
                    "end_top_index", "end_top_log_probs", "cls_logits"):
            predictions[key] = outputs[key]
        if not FLAGS.use_tpu:
            return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=None)

    # NOTE(review): every mode other than PREDICT (including EVAL, if it is
    # ever used) falls through to here and builds a train op — confirm that
    # is intended.

    # ---- Span loss: one-hot the gold index over the (dynamic) sequence axis.
    seq_length = tf.shape(features["input_ids"])[1]

    def span_nll(log_probs, positions):
        """Batch mean of -sum(one_hot(positions) * log_probs) over the last axis.

        Assuming `log_probs` holds per-position log-probabilities, this is
        the mean negative log-likelihood of the gold positions.
        """
        targets = tf.one_hot(positions, depth=seq_length, dtype=tf.float32)
        per_example = -tf.reduce_sum(targets * log_probs, axis=-1)
        return tf.reduce_mean(per_example)

    start_nll = span_nll(outputs["start_log_probs"], features["start_positions"])
    end_nll = span_nll(outputs["end_log_probs"], features["end_positions"])
    total_loss = (start_nll + end_nll) * 0.5

    # ---- Answerability loss on the CLS logit.
    impossible_labels = tf.reshape(features["is_impossible"], [-1])
    cls_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=impossible_labels, logits=outputs["cls_logits"])
    # note(zhiliny): by default multiply the loss by 0.5 so that the scale is
    # comparable to start_loss and end_loss
    total_loss += tf.reduce_mean(cls_xent) * 0.5

    # ---- Optimizer, then warm-start from the pretrained checkpoint.
    train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)
    scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

    if not FLAGS.use_tpu:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=total_loss, train_op=train_op)

    # On TPU, scalar summaries (here: the learning rate) go through a host call.
    host_call = function_builder.construct_scalar_host_call(
        monitor_dict={"lr": learning_rate},
        model_dir=FLAGS.model_dir,
        prefix="train/",
        reduce_fn=tf.reduce_mean)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        host_call=host_call,
        scaffold_fn=scaffold_fn)
def main(_):
    """Build the XLNet QA inference graph and profile it.

    Measurements are toggled by flags:
      * FLAGS.print_parameters: print trainable-parameter counts.
      * not FLAGS.not_profile_flops: report total float ops via `profile_metric`.
      * FLAGS.profile_memory: trace one run and report total requested bytes.
      * FLAGS.profile_time: after 2 warm-up runs, print the mean latency of
        10 runs.
    FLAGS.decompose selects `get_decomposed_qa_outputs` and the fake-cache
    options; otherwise `get_qa_outputs` is used.
    """
    model_name = 'xlnet'
    kwargs = dict(training=False, logits=True)

    with tf.Session() as sess:
        # NOTE(review): `features`, `model` and `config` are not defined in this
        # function; they must be module-level globals for this to run — TODO
        # confirm.
        if FLAGS.decompose:
            logger.info('running in decompose mode')
            kwargs['fake_cache_first'] = FLAGS.cache_segment == 1
            kwargs['fake_cache_second'] = FLAGS.cache_segment == 2
            # Return value unused; presumably called to add its ops to the graph.
            get_decomposed_qa_outputs(FLAGS, features, False)
        else:
            logger.info('running in normal mode')
            get_qa_outputs(FLAGS, features, False)
        # NOTE(review): the graph that is actually run below comes from
        # `core_graph`, yet the parameter/FLOP profiles walk all of sess.graph,
        # which also contains the ops built just above — verify the numbers are
        # not double-counted.
        inputs_dict, logits_ph = model.core_graph(config, **kwargs)
        sess.run(tf.global_variables_initializer())

        opt_builder = tf.profiler.ProfileOptionBuilder

        if FLAGS.print_parameters:
            tf.profiler.profile(
                sess.graph, options=opt_builder.trainable_variables_parameter())

        if not FLAGS.not_profile_flops:
            prof_options = opt_builder.float_operation()
            # Exclude variable-initializer ops; they are not inference cost.
            prof_options['hide_name_regexes'] = ['.*/Initializer/.*']
            tfprof_node = tf.profiler.profile(sess.graph, options=prof_options)
            profile_metric(model_name, tfprof_node,
                           metric='total_float_ops', metric_name='flops')

        # A full trace is only collected when memory profiling is requested.
        if FLAGS.profile_memory:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
        else:
            run_options = None
            run_metadata = None
        sess.run([logits_ph], feed_dict=inputs_dict,
                 options=run_options, run_metadata=run_metadata)

        if FLAGS.profile_memory:
            opts = opt_builder(opt_builder.time_and_memory()).build()
            tfprof_node = tf.profiler.profile(
                tf.get_default_graph(),
                run_meta=run_metadata,
                cmd='scope',
                options=opts)
            profile_metric(model_name, tfprof_node,
                           metric='total_requested_bytes', metric_name='mem')

        if FLAGS.profile_time:
            logger.info("warm up for two rounds...")
            for _ in range(2):
                sess.run([logits_ph], feed_dict=inputs_dict)

            logger.info("start running 10 rounds...")
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            for _ in range(10):
                sess.run([logits_ph], feed_dict=inputs_dict)
            elapsed = time.perf_counter() - start
            print('infer_time: {:.4f} s'.format(elapsed / 10))