def main(_):
    domain_str = aFLAGS.domains
    global num_domains, layer_indexes, domain_weight, do_sent_pair
    num_domains = len(domain_str.split(","))
    layer_indexes_str = aFLAGS.layer_indexes
    layer_indexes = [int(x) for x in layer_indexes_str.split(",")]
    domain_weight = aFLAGS.domain_weight
    do_sent_pair = aFLAGS.do_sent_pair

    app = Application()
    # On PAI, read from ODPS tables; elsewhere, read from local CSV files.
    if "PAI" in tf.__version__:
        train_reader = OdpsTableReader(input_glob=app.train_input_fp,
                                       is_training=True,
                                       input_schema=app.input_schema,
                                       batch_size=app.train_batch_size)
        eval_reader = OdpsTableReader(input_glob=app.eval_input_fp,
                                      is_training=False,
                                      input_schema=app.input_schema,
                                      batch_size=app.eval_batch_size)
    else:
        train_reader = CSVReader(input_glob=app.train_input_fp,
                                 is_training=True,
                                 input_schema=app.input_schema,
                                 batch_size=app.train_batch_size)
        eval_reader = CSVReader(input_glob=app.eval_input_fp,
                                is_training=False,
                                input_schema=app.input_schema,
                                batch_size=app.eval_batch_size)
    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
def main(_):
    app = Application()
    if FLAGS.mode == "train_and_evaluate_on_the_fly":
        if "PAI" in tf.__version__:
            train_reader = OdpsTableReader(input_glob=app.train_input_fp,
                                           is_training=True,
                                           input_schema=app.input_schema,
                                           batch_size=app.train_batch_size)
            eval_reader = OdpsTableReader(input_glob=app.eval_input_fp,
                                          is_training=False,
                                          input_schema=app.input_schema,
                                          batch_size=app.eval_batch_size)
        else:
            train_reader = CSVReader(input_glob=app.train_input_fp,
                                     is_training=True,
                                     input_schema=app.input_schema,
                                     batch_size=app.train_batch_size)
            eval_reader = CSVReader(input_glob=app.eval_input_fp,
                                    is_training=False,
                                    input_schema=app.input_schema,
                                    batch_size=app.eval_batch_size)
        app.run_train_and_evaluate(train_reader=train_reader,
                                   eval_reader=eval_reader)
    elif FLAGS.mode == "predict_on_the_fly":
        if "PAI" in tf.__version__:
            pred_reader = OdpsTableReader(input_glob=app.predict_input_fp,
                                          is_training=False,
                                          input_schema=app.input_schema,
                                          batch_size=app.predict_batch_size)
            pred_writer = OdpsTableWriter(output_glob=app.predict_output_fp,
                                          output_schema=app.output_schema,
                                          slice_id=0,
                                          input_queue=None)
        else:
            pred_reader = CSVReader(input_glob=app.predict_input_fp,
                                    is_training=False,
                                    input_schema=app.input_schema,
                                    batch_size=app.predict_batch_size)
            pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                                    output_schema=app.output_schema)
        app.run_predict(reader=pred_reader,
                        writer=pred_writer,
                        checkpoint_path=app.predict_checkpoint_path)
def main(_):
    app = Application()
    if FLAGS.usePAI:
        train_reader = OdpsTableReader(input_glob=app.train_input_fp,
                                       is_training=True,
                                       input_schema=app.input_schema,
                                       batch_size=app.train_batch_size)
        eval_reader = OdpsTableReader(input_glob=app.eval_input_fp,
                                      is_training=False,
                                      input_schema=app.input_schema,
                                      batch_size=app.eval_batch_size)
    else:
        train_reader = CSVReader(input_glob=app.train_input_fp,
                                 is_training=True,
                                 input_schema=app.input_schema,
                                 batch_size=app.train_batch_size)
        eval_reader = CSVReader(input_glob=app.eval_input_fp,
                                is_training=False,
                                input_schema=app.input_schema,
                                batch_size=app.eval_batch_size)
    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
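# The PAI-vs-local branch above is repeated verbatim in every entry point.
# A small helper, sketched below, could centralize that choice. The helper is
# hypothetical (not part of the original scripts) and assumes OdpsTableReader
# and CSVReader accept the same constructor arguments used above.
def build_reader(input_glob, is_training, input_schema, batch_size):
    """Return an OdpsTableReader on PAI and a CSVReader elsewhere."""
    reader_cls = OdpsTableReader if "PAI" in tf.__version__ else CSVReader
    return reader_cls(input_glob=input_glob,
                      is_training=is_training,
                      input_schema=input_schema,
                      batch_size=batch_size)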
def main(_):
    # load domain and class dicts
    domains = aFLAGS.domains.split(",")
    classes = aFLAGS.classes.split(",")

    if aFLAGS.do_sent_pair:
        app = SentPairClsApplication()
    else:
        app = SentClsApplication()

    # create empty centroid buckets, keyed by "<domain>\t<class>"
    domain_class_embeddings = dict()
    for domain_name in domains:
        for class_name in classes:
            key_name = domain_name + "\t" + class_name
            domain_class_embeddings[key_name] = list()

    # reader for the training data
    if FLAGS.usePAI:
        predict_reader = OdpsTableReader(input_glob=app.predict_input_fp,
                                         is_training=False,
                                         input_schema=app.input_schema,
                                         batch_size=app.predict_batch_size)
    else:
        predict_reader = CSVReader(input_glob=app.predict_input_fp,
                                   is_training=False,
                                   input_schema=app.input_schema,
                                   batch_size=app.predict_batch_size)

    # run inference over the training data, bucketing each example's pooled
    # output embedding under its (domain, label) key
    temp_output_data = list()
    for output in app.run_predict(reader=predict_reader,
                                  checkpoint_path=app.predict_checkpoint_path):
        current_size = len(output["pool_output"])
        for i in range(current_size):
            pool_output = output["pool_output"][i]
            domain = output["domain"][i].decode('utf-8')
            label = output["label"][i].decode('utf-8')
            key_name = domain + "\t" + label
            domain_class_embeddings[key_name].append(pool_output)
            if aFLAGS.do_sent_pair:
                text1 = output["text1"][i].decode('utf-8')
                text2 = output["text2"][i].decode('utf-8')
                temp_output_data.append((text1, text2, domain, label, pool_output))
            else:
                text = output["text"][i].decode('utf-8')
                temp_output_data.append((text, domain, label, pool_output))

    # compute one centroid per (domain, class) pair
    centroid_embeddings = dict()
    for key_name in domain_class_embeddings:
        domain_class_data_embeddings = np.array(domain_class_embeddings[key_name])
        centroid_embeddings[key_name] = np.mean(domain_class_data_embeddings,
                                                axis=0)

    # write the reweighted examples consumed by meta fine-tuning
    if FLAGS.usePAI:
        # write ODPS tables
        records = []
        if aFLAGS.do_sent_pair:
            for text1, text2, domain, label, embeddings in temp_output_data:
                weight = compute_weight(domain, label, embeddings,
                                        centroid_embeddings)
                records.append((text1, text2, str(domains.index(domain)), label,
                                np.around(weight, decimals=5)))
        else:
            for text, domain, label, embeddings in temp_output_data:
                weight = compute_weight(domain, label, embeddings,
                                        centroid_embeddings)
                records.append((text, str(domains.index(domain)), label,
                                np.around(weight, decimals=5)))
        with tf.python_io.TableWriter(FLAGS.outputs) as writer:
            indices = list(range(5)) if aFLAGS.do_sent_pair else list(range(4))
            writer.write(records, indices)
    else:
        # write to a local file
        with open(FLAGS.outputs, 'w+') as f:
            if aFLAGS.do_sent_pair:
                for text1, text2, domain, label, embeddings in temp_output_data:
                    weight = compute_weight(domain, label, embeddings,
                                            centroid_embeddings)
                    f.write(text1 + '\t' + text2 + '\t' +
                            str(domains.index(domain)) + '\t' + label + '\t' +
                            str(np.around(weight, decimals=5)) + '\n')
            else:
                for text, domain, label, embeddings in temp_output_data:
                    weight = compute_weight(domain, label, embeddings,
                                            centroid_embeddings)
                    f.write(text + '\t' + str(domains.index(domain)) + '\t' +
                            label + '\t' +
                            str(np.around(weight, decimals=5)) + '\n')
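# compute_weight is called above but not defined in this snippet. Below is a
# minimal sketch of one plausible implementation (hypothetical: the project's
# actual weighting scheme lives elsewhere). It assumes the weight rewards
# examples whose pooled embedding lies close to the same-class centroids of
# the *other* domains, using the "<domain>\t<label>" keys built above.
def compute_weight(domain, label, embedding, centroid_embeddings):
    """Average cosine similarity to same-class centroids of other domains."""
    sims = []
    for key_name, centroid in centroid_embeddings.items():
        other_domain, other_label = key_name.split("\t")
        if other_label == label and other_domain != domain:
            cos = np.dot(embedding, centroid) / (
                np.linalg.norm(embedding) * np.linalg.norm(centroid) + 1e-8)
            sims.append(cos)
    # fall back to a neutral weight when no cross-domain centroid exists
    return float(np.mean(sims)) if sims else 1.0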
def run(mode):
    if FLAGS.config is None:
        config_json = {
            "model_type": _APP_FLAGS.model_type,
            "vocab_size": _APP_FLAGS.vocab_size,
            "hidden_size": _APP_FLAGS.hidden_size,
            "intermediate_size": _APP_FLAGS.intermediate_size,
            "num_hidden_layers": _APP_FLAGS.num_hidden_layers,
            "max_position_embeddings": 512,
            "num_attention_heads": _APP_FLAGS.num_attention_heads,
            "type_vocab_size": 2
        }
        if not tf.gfile.Exists(_APP_FLAGS.model_dir):
            tf.gfile.MkDir(_APP_FLAGS.model_dir)

        # Pretrain from scratch: materialize the config, vocab, and the
        # optional sentencepiece model in model_dir.
        if _APP_FLAGS.pretrain_model_name_or_path is None:
            if not tf.gfile.Exists(_APP_FLAGS.model_dir + "/config.json"):
                with tf.gfile.GFile(_APP_FLAGS.model_dir + "/config.json",
                                    mode='w') as f:
                    json.dump(config_json, f)
            shutil.copy2(_APP_FLAGS.vocab_fp, _APP_FLAGS.model_dir)
            if _APP_FLAGS.spm_model_fp is not None:
                shutil.copy2(_APP_FLAGS.spm_model_fp, _APP_FLAGS.model_dir)

        config = PretrainConfig()
        if _APP_FLAGS.do_multitaks_pretrain:
            app = PretrainMultitask(user_defined_config=config)
        else:
            app = Pretrain(user_defined_config=config)
    else:
        if _APP_FLAGS.do_multitaks_pretrain:
            app = PretrainMultitask()
        else:
            app = Pretrain()

    if "train" in mode:
        if _APP_FLAGS.data_reader == 'tfrecord':
            train_reader = BundleTFRecordReader(
                input_glob=app.train_input_fp,
                is_training=True,
                shuffle_buffer_size=1024,
                worker_hosts=FLAGS.worker_hosts,
                task_index=FLAGS.task_index,
                input_schema=app.input_schema,
                batch_size=app.train_batch_size)
        elif _APP_FLAGS.data_reader == 'odps':
            tf.logging.info("***********Reading Odps Table *************")
            worker_id = FLAGS.task_index
            num_workers = len(FLAGS.worker_hosts.split(","))
            train_reader = OdpsTableReader(
                input_glob=app.train_input_fp,
                is_training=True,
                shuffle_buffer_size=1024,
                input_schema=app.input_schema,
                slice_id=worker_id,
                slice_count=num_workers,
                batch_size=app.train_batch_size)

    if mode == "train_and_evaluate":
        if _APP_FLAGS.data_reader == 'tfrecord':
            eval_reader = BundleTFRecordReader(
                input_glob=app.eval_input_fp,
                is_training=False,
                shuffle_buffer_size=1024,
                worker_hosts=FLAGS.worker_hosts,
                task_index=FLAGS.task_index,
                input_schema=app.input_schema,
                batch_size=app.eval_batch_size)
        elif _APP_FLAGS.data_reader == 'odps':
            # read the evaluation table (not the training table) here
            eval_reader = OdpsTableReader(
                input_glob=app.eval_input_fp,
                is_training=False,
                shuffle_buffer_size=1024,
                input_schema=app.input_schema,
                slice_id=worker_id,
                slice_count=num_workers,
                batch_size=app.eval_batch_size)
        app.run_train_and_evaluate(train_reader=train_reader,
                                   eval_reader=eval_reader)
    elif mode == "train":
        app.run_train(reader=train_reader)
    elif mode == "evaluate":
        if _APP_FLAGS.data_reader == 'tfrecord':
            eval_reader = BundleTFRecordReader(
                input_glob=app.eval_input_fp,
                is_training=False,
                shuffle_buffer_size=1024,
                worker_hosts=FLAGS.worker_hosts,
                task_index=FLAGS.task_index,
                input_schema=app.input_schema,
                batch_size=app.eval_batch_size)
        elif _APP_FLAGS.data_reader == 'odps':
            # worker_id/num_workers are not set by the "train" branch in pure
            # evaluate mode, so derive them here as well
            worker_id = FLAGS.task_index
            num_workers = len(FLAGS.worker_hosts.split(","))
            eval_reader = OdpsTableReader(
                input_glob=app.eval_input_fp,
                is_training=False,
                shuffle_buffer_size=1024,
                input_schema=app.input_schema,
                slice_id=worker_id,
                slice_count=num_workers,
                batch_size=app.eval_batch_size)

        # Collect every checkpoint step recorded in the `checkpoint` file,
        # then evaluate each checkpoint and write MLM metrics as summaries.
        ckpts = set()
        with tf.gfile.GFile(os.path.join(app.config.model_dir, "checkpoint"),
                            mode='r') as reader:
            for line in reader:
                line = line.strip()
                line = line.replace("oss://", "")
                ckpts.add(
                    int(
                        line.split(":")[1].strip().replace("\"", "")
                        .split("/")[-1].replace("model.ckpt-", "")))
        # skip the untrained step-0 checkpoint if present (discard, unlike
        # remove, does not raise when it is absent)
        ckpts.discard(0)

        writer = tf.summary.FileWriter(
            os.path.join(app.config.model_dir, "eval_output"))
        for ckpt in sorted(ckpts):
            checkpoint_path = os.path.join(app.config.model_dir,
                                           "model.ckpt-" + str(ckpt))
            tf.logging.info("checkpoint_path is {}".format(checkpoint_path))
            ret_metrics = app.run_evaluate(reader=eval_reader,
                                           checkpoint_path=checkpoint_path)
            global_step = ret_metrics['global_step']
            eval_masked_lm_accuracy = tf.Summary()
            eval_masked_lm_accuracy.value.add(
                tag='masked_lm_valid_accuracy',
                simple_value=ret_metrics['eval_masked_lm_accuracy'])
            eval_masked_lm_loss = tf.Summary()
            eval_masked_lm_loss.value.add(
                tag='masked_lm_valid_loss',
                simple_value=ret_metrics['eval_masked_lm_loss'])
            writer.add_summary(eval_masked_lm_accuracy, global_step)
            writer.add_summary(eval_masked_lm_loss, global_step)
        writer.close()
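# For reference, the parsing loop above expects the standard TensorFlow
# `checkpoint` protocol text file, e.g. (paths are illustrative):
#
#   model_checkpoint_path: "model.ckpt-2000"
#   all_model_checkpoint_paths: "model.ckpt-1000"
#   all_model_checkpoint_paths: "model.ckpt-2000"
#
# Each line is split on ":", quotes are stripped, and the step number after
# "model.ckpt-" is collected, so the loop above would evaluate steps 1000
# and 2000 in order. The summaries written to <model_dir>/eval_output can
# then be inspected with TensorBoard pointed at that directory.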