# NOTE: these functions assume the surrounding project's module-level imports
# (e.g. time, logging, tensorflow as tf, and the local config, data, modeling,
# Batcher, SummarizationModel, BeamSearchDecoder, model_pools, processors);
# the exact import paths are not shown in this excerpt.
def decode_Beam(FLAGS):
    # If in decode mode, set batch_size = beam_size.
    # Reason: in decode mode, we decode one example at a time. On each step,
    # we have beam_size-many hypotheses in the beam, so we need to make a
    # batch of these hypotheses.
    # if FLAGS.mode == 'decode':
    #     FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode.
    # if FLAGS.single_pass and FLAGS.mode != 'decode':
    #     raise Exception("The single_pass flag should only be True in decode mode")

    vocab_in, vocab_out = data.load_dict_data(FLAGS)

    FLAGS_batcher = config.retype_FLAGS()
    FLAGS_decode = FLAGS_batcher._asdict()
    FLAGS_decode["max_dec_steps"] = 1
    FLAGS_decode["mode"] = "decode"
    FLAGS_decode = config.generate_nametuple(FLAGS_decode)

    # The model is configured with max_dec_steps=1 because we only ever run
    # one step of the decoder at a time (to do beam search). Note that the
    # batcher is initialized with max_dec_steps equal to e.g. 100 because the
    # batches need to contain the full summaries.
    batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS_batcher,
                      data_file=FLAGS.test_name)
    model = SummarizationModel(FLAGS_decode, vocab_in, vocab_out, batcher)
    decoder = BeamSearchDecoder(model, batcher, vocab_out)
    decoder.decode()
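
# As the comment above notes, decoding runs the decoder one step at a time:
# the beam_size current hypotheses are packed into a single batch, the model
# scores the next token for each, and the best candidates form the next beam.
# A toy sketch of that loop (standalone and illustrative only; none of these
# names come from the real BeamSearchDecoder):
def toy_beam_search(score_step, start_id, stop_id, beam_size=4, max_steps=10):
    # score_step(batch_of_prefixes) returns one {token_id: log_prob} dict per
    # prefix, standing in for a single decoder step over beam_size hypotheses.
    beam = [([start_id], 0.0)]  # (token sequence, cumulative log prob)
    for _ in range(max_steps):
        batch = [tokens for tokens, _ in beam]  # batch_size == beam_size
        candidates = []
        for (tokens, logp), next_scores in zip(beam, score_step(batch)):
            if tokens[-1] == stop_id:
                candidates.append((tokens, logp))  # finished: carry over
                continue
            for tok, lp in next_scores.items():
                candidates.append((tokens + [tok], logp + lp))
        beam = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(tokens[-1] == stop_id for tokens, _ in beam):
            break
    return beam[0]

# Example with a stub scorer that always prefers token 2, then stop (id 1):
#   import math
#   best = toy_beam_search(
#       lambda batch: [{2: math.log(0.6), 1: math.log(0.4)} for _ in batch],
#       start_id=0, stop_id=1)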
def create_train_eval_model(FLAGS):
    Classify_model = model_pools["tagging_model"]
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    # Load the custom processor for the task name.
    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    processor = processors[task_name]()

    train_batcher = Batcher(processor, FLAGS)

    # Create the training model.
    Bert_model = Classify_model(bert_config, train_batcher, FLAGS)
    Bert_model.build_graph()
    Bert_model.create_or_load_recent_model()

    # Create the validation model with the same flags, but in "dev" mode.
    FLAGS_eval = FLAGS._asdict()
    FLAGS_eval["mode"] = "dev"
    FLAGS_eval = config.generate_nametuple(FLAGS_eval)

    validate_batcher = Batcher(processor, FLAGS_eval)
    validate_model = Classify_model(bert_config, validate_batcher, FLAGS_eval)
    validate_model.build_graph()
    validate_model.create_or_load_recent_model()

    return Bert_model, validate_model
def create_training_model(FLAGS, vocab_in, vocab_out=None):
    batcher_train = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS,
                            data_file=FLAGS.train_name)
    train_model = SummarizationModel(FLAGS, vocab_in, vocab_out, batcher_train)
    logging.info("Building graph...")
    train_model.build_graph()

    # Create the dev model.
    # tf.flags objects cannot be deep-copied, so we convert the flags into a
    # namedtuple instead. TODO: find a cleaner way to do this.
    FLAGS_eval = FLAGS._asdict()
    FLAGS_eval["mode"] = "eval"
    FLAGS_eval = config.generate_nametuple(FLAGS_eval)

    # variable_scope.get_variable_scope().reuse_variables()
    batcher_dev = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS,
                          data_file=FLAGS.dev_name)
    dev_model = SummarizationModel(FLAGS_eval, vocab_in, vocab_out, batcher_dev)
    dev_model.build_graph()

    train_model.create_or_load_recent_model()
    return train_model, dev_model
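
# The _asdict() -> mutate -> generate_nametuple() dance above recurs
# throughout this file. Assuming config.generate_nametuple simply rebuilds the
# same namedtuple type, namedtuple's built-in _replace does the same thing in
# one step. Illustrative sketch only (Flags is a stand-in, not the real type):
def _namedtuple_override_demo():
    from collections import namedtuple
    Flags = namedtuple("Flags", ["mode", "max_dec_steps"])
    flags = Flags(mode="train", max_dec_steps=100)

    # Pattern used in this file: dict out, mutate, namedtuple back in.
    overrides = flags._asdict()
    overrides["mode"] = "eval"
    flags_eval = Flags(**overrides)

    # Equivalent one-liner via the built-in _replace.
    assert flags._replace(mode="eval") == flags_eval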
def create_decode_model(FLAGS, vocab_in, vocab_out):
    batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS,
                      data_file=FLAGS.qq_name)
    import eval

    FLAGS_decode = config.retype_FLAGS()._asdict()
    FLAGS_decode["max_dec_steps"] = 1
    FLAGS_decode["mode"] = "decode"
    FLAGS_decode = config.generate_nametuple(FLAGS_decode)

    model = SummarizationModel(FLAGS_decode, vocab_in, vocab_out, batcher)
    # model.graph.as_default()
    decoder = eval.EvalDecoder(model, batcher, vocab_out)
    return decoder
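
# Hypothetical usage, mirroring decode_multi below (pair_wise_decode is the
# same EvalDecoder entry point that decode_multi calls):
#   FLAGS = config.retype_FLAGS()
#   vocab_in, vocab_out = data.load_dict_data(FLAGS)
#   decoder = create_decode_model(FLAGS, vocab_in, vocab_out)
#   decoder.pair_wise_decode()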
def main(_):
    FLAGS = config.retype_FLAGS()
    Classify_model = model_pools["tagging_classify_model"]
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    task_name1 = "ner"
    task_name2 = "qicm"

    FLAGS_tagging = FLAGS._asdict()
    FLAGS_tagging["train_batch_size"] = int(FLAGS_tagging["train_batch_size"] / 2)
    FLAGS_tagging["mode"] = "dev"
    FLAGS_tagging = config.generate_nametuple(FLAGS_tagging)

    FLAGS_classify = FLAGS._asdict()
    FLAGS_classify["train_batch_size"] = int(FLAGS_classify["train_batch_size"] / 2)
    FLAGS_classify["mode"] = "dev"
    # Pad so the two halves sum to the full batch size (handles odd sizes).
    FLAGS_classify["train_batch_size"] += (FLAGS.train_batch_size
                                           - FLAGS_tagging.train_batch_size
                                           - FLAGS_classify["train_batch_size"])
    FLAGS_classify["train_file"] = FLAGS_classify["train_file_multi"]
    FLAGS_classify["dev_file"] = FLAGS_classify["dev_file_multi"]
    FLAGS_classify["test_file"] = FLAGS_classify["test_file_multi"]
    FLAGS_classify = config.generate_nametuple(FLAGS_classify)

    processor_tagging = processors[task_name1]()
    processor_classify = processors[task_name2]()
    tagging_batcher = Batcher(processor_tagging, FLAGS_tagging)
    classify_batcher = Batcher(processor_classify, FLAGS_classify)

    # Create the multi-task model and run a dev pass, one merged batch
    # (tagging + classification) per step.
    Bert_model = Classify_model(bert_config, tagging_batcher, classify_batcher, FLAGS)

    for step in range(0, Bert_model.num_train_steps):
        tagging_batch = Bert_model.tagging_batcher.next_batch()
        classify_batch = Bert_model.classify_batcher.next_batch()
        batch = Bert_model.classify_batcher.merge_multi_task(tagging_batch, classify_batch)
        results = Bert_model.run_dev_step(batch)
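
# The split-and-pad arithmetic used above guarantees the two per-task batch
# sizes always sum to the configured total, even when it is odd. Standalone
# sketch of the same arithmetic (illustrative only):
def _split_batch_size(total):
    tagging = int(total / 2)
    classify = int(total / 2)
    # Pad so the halves sum to the full batch size (handles odd totals).
    classify += total - tagging - classify
    return tagging, classify

# _split_batch_size(32) == (16, 16);  _split_batch_size(33) == (16, 17)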
def decode_multi(FLAGS):
    vocab_in, vocab_out = data.load_dict_data(FLAGS)
    batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS,
                      data_file=FLAGS.test_name, shuffle=False)
    import eval

    FLAGS_decode = config.retype_FLAGS()._asdict()
    FLAGS_decode["max_dec_steps"] = 1
    FLAGS_decode = config.generate_nametuple(FLAGS_decode)

    model = SummarizationModel(FLAGS_decode, vocab_in, vocab_out, batcher)
    decoder = eval.EvalDecoder(model, batcher, vocab_out)

    # Time the full decode pass.
    time_start = time.time()
    decoder.pair_wise_decode()
    time_end = time.time()
    print(time_end - time_start)
def create_train_eval_model(FLAGS):
    Classify_model = model_pools["tagging_classify_model"]
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    # Load the custom processors for the two task names.
    task_name1 = "ner"
    task_name2 = "qicm"

    FLAGS_tagging = FLAGS._asdict()
    FLAGS_tagging["train_batch_size"] = int(FLAGS_tagging["train_batch_size"] / 2)
    FLAGS_tagging = config.generate_nametuple(FLAGS_tagging)

    FLAGS_classify = FLAGS._asdict()
    FLAGS_classify["train_batch_size"] = int(FLAGS_classify["train_batch_size"] / 2)
    # Pad so the two halves sum to the full batch size (handles odd sizes).
    FLAGS_classify["train_batch_size"] += (FLAGS.train_batch_size
                                           - FLAGS_tagging.train_batch_size
                                           - FLAGS_classify["train_batch_size"])
    FLAGS_classify["train_file"] = FLAGS_classify["train_file_multi"]
    FLAGS_classify["dev_file"] = FLAGS_classify["dev_file_multi"]
    FLAGS_classify["test_file"] = FLAGS_classify["test_file_multi"]
    FLAGS_classify = config.generate_nametuple(FLAGS_classify)

    processor_tagging = processors[task_name1]()
    processor_classify = processors[task_name2]()
    tagging_batcher = Batcher(processor_tagging, FLAGS_tagging)
    classify_batcher = Batcher(processor_classify, FLAGS_classify)

    # Create the training model.
    Bert_model = Classify_model(bert_config, tagging_batcher, classify_batcher, FLAGS)
    Bert_model.build_graph()
    Bert_model.create_or_load_recent_model()

    # Rebuild the per-task flags in "dev" mode for the validation model.
    FLAGS_tagging = FLAGS._asdict()
    FLAGS_tagging["train_batch_size"] = int(FLAGS_tagging["train_batch_size"] / 2)
    FLAGS_tagging["mode"] = "dev"
    FLAGS_tagging = config.generate_nametuple(FLAGS_tagging)

    FLAGS_classify = FLAGS._asdict()
    FLAGS_classify["train_batch_size"] = int(FLAGS_classify["train_batch_size"] / 2)
    FLAGS_classify["mode"] = "dev"
    FLAGS_classify["train_file"] = FLAGS_classify["train_file_multi"]
    FLAGS_classify["dev_file"] = FLAGS_classify["dev_file_multi"]
    FLAGS_classify["test_file"] = FLAGS_classify["test_file_multi"]
    FLAGS_classify = config.generate_nametuple(FLAGS_classify)

    FLAGS_eval = FLAGS._asdict()
    FLAGS_eval["mode"] = "dev"
    FLAGS_eval = config.generate_nametuple(FLAGS_eval)

    processor_tagging = processors[task_name1]()
    processor_classify = processors[task_name2]()
    tagging_batcher = Batcher(processor_tagging, FLAGS_tagging)
    classify_batcher = Batcher(processor_classify, FLAGS_classify)

    validate_model = Classify_model(bert_config, tagging_batcher, classify_batcher, FLAGS_eval)
    validate_model.build_graph()
    validate_model.create_or_load_recent_model()

    return Bert_model, validate_model
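
# The per-task flag blocks in create_train_eval_model are identical except for
# the mode and the *_multi file overrides. A hypothetical helper (not part of
# the original code) could factor them out:
def make_task_flags(FLAGS, use_multi_files=False, mode=None):
    task_flags = FLAGS._asdict()
    # Each task gets half of the configured batch.
    task_flags["train_batch_size"] = int(task_flags["train_batch_size"] / 2)
    if mode is not None:
        task_flags["mode"] = mode
    if use_multi_files:
        task_flags["train_file"] = task_flags["train_file_multi"]
        task_flags["dev_file"] = task_flags["dev_file_multi"]
        task_flags["test_file"] = task_flags["test_file_multi"]
    return config.generate_nametuple(task_flags)

# e.g.:
#   FLAGS_tagging = make_task_flags(FLAGS, mode="dev")
#   FLAGS_classify = make_task_flags(FLAGS, use_multi_files=True, mode="dev")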