def _parse_ckpt_nums(ckpt_spec):
  """Split a comma-separated checkpoint list into clean entries.

  Args:
    ckpt_spec: String such as "1000, 2000,3000" (the value of FLAGS.eval_ckpt).

  Returns:
    List of checkpoint-number strings with surrounding whitespace removed and
    empty entries dropped, in the original order.
  """
  return [num.strip() for num in ckpt_spec.split(",") if num.strip()]


def _pad_examples_for_tpu(examples, batch_size):
  """Pad `examples` in place with PaddingInputExample up to a batch multiple.

  TPU requires a fixed batch size for all batches, so the number of examples
  must be a multiple of the batch size, or else examples will get dropped.
  The fake examples are ignored later on: per the original design they do
  NOT count towards the eval metric (they get a per-instance weight of 0.0),
  and prediction output is cut off before them.

  Args:
    examples: List of examples; extended in place.
    batch_size: Batch size the list length must become a multiple of.
  """
  while len(examples) % batch_size != 0:
    examples.append(run_classifier.PaddingInputExample())


def _get_eval_set_examples(processor):
  """Return the dev examples if FLAGS.eval_set == 'dev', else the test ones."""
  if FLAGS.eval_set == 'dev':
    return processor.get_dev_examples(FLAGS.data_dir)
  return processor.get_test_examples(FLAGS.data_dir)


def _prediction_to_json(prediction, slot_list, tokenizer):
  """Make one raw estimator prediction dict JSON-serializable.

  The dict is modified in place and returned. The "guid" bytes value is
  decoded and split on "-". For every slot, the class/start/end
  prediction and label values become Python ints, and "input_ids_<slot>"
  becomes a list of ints. Two strings are added:
  "slot_prediction_<slot>" and "slot_groundtruth_<slot>". They join the
  tokens of the predicted and gold spans; both span ends are inclusive.

  Args:
    prediction: Dict produced by estimator.predict for one example.
    slot_list: Slot names whose per-slot keys should be converted.
    tokenizer: Tokenizer used to map input ids back to tokens.

  Returns:
    The same `prediction` dict after conversion.
  """
  # Str feature is encoded as bytes, which is not JSON serializable.
  # Hence convert to str.
  prediction["guid"] = prediction["guid"].decode("utf-8").split("-")
  for slot in slot_list:
    # TF uses int64, which is not JSON serializable. Hence convert to int.
    for prefix in ('class_prediction', 'class_label_id', 'start_prediction',
                   'start_pos', 'end_prediction', 'end_pos'):
      key = '%s_%s' % (prefix, slot)
      prediction[key] = int(prediction[key])
    start_pd = prediction['start_prediction_%s' % slot]
    start_gt = prediction['start_pos_%s' % slot]
    # Fixed: this used to read 'start_prediction_%s', which clobbered the
    # end prediction and reduced every predicted span to a single token.
    end_pd = prediction['end_prediction_%s' % slot]
    end_gt = prediction['end_pos_%s' % slot]

    prediction["input_ids_%s" % slot] = list(
        map(int, prediction["input_ids_%s" % slot].tolist()))
    input_tokens = tokenizer.convert_ids_to_tokens(
        prediction["input_ids_%s" % slot])
    prediction["slot_prediction_%s" % slot] = ' '.join(
        input_tokens[start_pd:end_pd + 1])
    prediction["slot_groundtruth_%s" % slot] = ' '.join(
        input_tokens[start_gt:end_gt + 1])
  return prediction


def main(_):
  """Train, evaluate and/or predict with the BERT-based model.

  Which phases run is controlled by FLAGS.do_train, FLAGS.do_eval and
  FLAGS.do_predict.
  - Training writes checkpoints to FLAGS.output_dir.
  - Evaluation and prediction run once per checkpoint step listed
    (comma-separated) in FLAGS.eval_ckpt, loading
    FLAGS.output_dir/model.ckpt-<step>.
  - Evaluation appends its results to eval_res.<eval_set>.json.
  - Prediction writes pred_res.<eval_set>.<step>.json.

  Raises:
    ValueError: If no phase is enabled, FLAGS.max_seq_length exceeds the
      model's max_position_embeddings, or FLAGS.task_name is unknown.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "dstc2_clean": Dstc2Processor,
      "woz2": Woz2Processor,
      "sim-m": SimMProcessor,
      "sim-r": SimRProcessor,
  }

  # Validate the casing flag that the tokenizer below actually uses
  # (previously hard-coded to True, which rejected cased checkpoints).
  tokenization.validate_case_matches_checkpoint(
      do_lower_case=FLAGS.do_lower_case,
      init_checkpoint=FLAGS.init_checkpoint)

  if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict):
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict` must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()
  slot_list = processor.slot_list
  class_types = processor.class_types
  num_class_labels = len(class_types)
  # NOTE(review): woz2 / dstc2_clean use one class label fewer than
  # len(class_types); presumably one class type never occurs in those
  # datasets -- confirm against the processors.
  if task_name in ['woz2', 'dstc2_clean']:
    num_class_labels -= 1

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  # keep_checkpoint_max=None: presumably to keep every checkpoint so any step
  # listed in FLAGS.eval_ckpt is still available after training.
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=None,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      slot_list=slot_list,
      num_class_labels=num_class_labels,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    file_based_convert_examples_to_features(
        train_examples, slot_list, class_types, FLAGS.max_seq_length,
        tokenizer, train_file)
    tf.logging.info("***** Running training *****")
    tf.logging.info(" Num examples = %d", len(train_examples))
    tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info(" Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        slot_list=slot_list)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_eval:
    eval_examples = _get_eval_set_examples(processor)
    num_actual_eval_examples = len(eval_examples)
    if FLAGS.use_tpu:
      _pad_examples_for_tpu(eval_examples, FLAGS.eval_batch_size)

    eval_file = os.path.join(FLAGS.output_dir,
                             "eval.%s.tf_record" % FLAGS.eval_set)
    file_based_convert_examples_to_features(
        eval_examples, slot_list, class_types, FLAGS.max_seq_length,
        tokenizer, eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info(" Num examples = %d (%d actual, %d padding)",
                    len(eval_examples), num_actual_eval_examples,
                    len(eval_examples) - num_actual_eval_examples)
    tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)

    # None tells the estimator to run through the entire set; on TPU the
    # number of steps must be given explicitly.
    eval_steps = None
    if FLAGS.use_tpu:
      assert len(eval_examples) % FLAGS.eval_batch_size == 0
      eval_steps = len(eval_examples) // FLAGS.eval_batch_size

    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=bool(FLAGS.use_tpu),
        slot_list=slot_list)

    # Results accumulate across runs: existing results are loaded and the
    # new ones appended.
    output_eval_file = os.path.join(FLAGS.output_dir,
                                    "eval_res.%s.json" % FLAGS.eval_set)
    if tf.gfile.Exists(output_eval_file):
      with tf.gfile.GFile(output_eval_file) as f:
        eval_result = json.load(f)
    else:
      eval_result = []

    for ckpt_num in _parse_ckpt_nums(FLAGS.eval_ckpt):
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=eval_steps,
          checkpoint_path=os.path.join(FLAGS.output_dir,
                                       "model.ckpt-%s" % ckpt_num))
      eval_result.append({k: float(v) for k, v in result.items()})

      tf.logging.info("***** Eval results for %s set *****", FLAGS.eval_set)
      for key in sorted(result):
        tf.logging.info("%s = %s", key, str(result[key]))

    if eval_result:
      with tf.gfile.GFile(output_eval_file, "w") as f:
        json.dump(eval_result, f, indent=2)

  if FLAGS.do_predict:
    predict_examples = _get_eval_set_examples(processor)
    num_actual_predict_examples = len(predict_examples)
    if FLAGS.use_tpu:
      _pad_examples_for_tpu(predict_examples, FLAGS.predict_batch_size)

    predict_file = os.path.join(FLAGS.output_dir,
                                "pred.%s.tf_record" % FLAGS.eval_set)
    file_based_convert_examples_to_features(
        predict_examples, slot_list, class_types, FLAGS.max_seq_length,
        tokenizer, predict_file)

    tf.logging.info("***** Running prediction *****")
    tf.logging.info(" Num examples = %d (%d actual, %d padding)",
                    len(predict_examples), num_actual_predict_examples,
                    len(predict_examples) - num_actual_predict_examples)
    tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=bool(FLAGS.use_tpu),
        slot_list=slot_list)

    for ckpt_num in _parse_ckpt_nums(FLAGS.eval_ckpt):
      result = estimator.predict(
          input_fn=predict_input_fn,
          checkpoint_path=os.path.join(FLAGS.output_dir,
                                       "model.ckpt-%s" % ckpt_num))
      output_predict_file = os.path.join(
          FLAGS.output_dir,
          "pred_res.%s.%08d.json" % (FLAGS.eval_set, int(ckpt_num)))
      tf.logging.info("***** Predict results for %s set *****",
                      FLAGS.eval_set)

      # Collect everything first so a failure inside the (lazy) predict
      # generator does not leave a truncated output file behind.
      list_prediction = []
      for i, prediction in enumerate(result):
        # Stop BEFORE converting TPU padding examples; previously the first
        # padding example was appended to the output before breaking.
        if i >= num_actual_predict_examples:
          break
        list_prediction.append(
            _prediction_to_json(prediction, slot_list, tokenizer))

      with tf.gfile.GFile(output_predict_file, "w") as f:
        json.dump(list_prediction, f, indent=2)
      assert len(list_prediction) == num_actual_predict_examples
("How can i recover old gmail account ?", "How can i delete my old gmail account ?"), ("How can i recover old gmail account ?", "How can i access my old gmail account ?")] # In[ ]: print("******* Predictions on Custom Data ********") # create `InputExample` for custom examples predict_examples = processor.get_predict_examples(sent_pairs) num_predict_examples = len(predict_examples) # For TPU, We will append `PaddingExample` for maintaining batch size if USE_TPU: while (len(predict_examples) % EVAL_BATCH_SIZE != 0): predict_examples.append(run_classifier.PaddingInputExample()) # Converting to features predict_features = run_classifier.convert_examples_to_features( predict_examples, label_list, MAX_SEQ_LENGTH, tokenizer) print(' Num examples = {}'.format(num_predict_examples)) print(' Batch size = {}'.format(PREDICT_BATCH_SIZE)) # Input function for prediction predict_input_fn = run_classifier.input_fn_builder(predict_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False) result = list(estimator.predict(input_fn=predict_input_fn)) print(result)