# NOTE: assumed imports. tf_data_utils, model_fn_builder, and model_io are
# project-local modules from the surrounding repo; `hvd` (Horovod) and `pai`
# (Alibaba PAI-Soar, import name assumed) are optional backends expected to
# be None when unavailable.
import json

import numpy as np
import tensorflow as tf
from bunch import Bunch
from sklearn.metrics import classification_report

try:
    import horovod.tensorflow as hvd
except ImportError:
    hvd = None

try:
    import paisoar as pai  # Alibaba PAI-Soar backend (import name assumed)
except ImportError:
    pai = None


# Session-based variant: builds the train/eval graphs explicitly and drives
# them through a MonitoredTrainingSession.
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target,
                  init_checkpoint, train_file, dev_file,
                  checkpoint_dir, is_debug):
    graph = tf.Graph()
    with graph.as_default():
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"
        print(config, "==model config==")

        # "0": every worker sees the full data, so split the epochs across
        # workers; "1": the data is sharded, so keep the epoch count.
        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / worker_count)
        elif FLAGS.if_shard == "1":
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch

        init_lr = 2e-5
        label_dict = json.load(open(FLAGS.label_id))

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        num_warmup_steps = int(num_train_steps * 0.1)
        num_storage_steps = int(train_size / FLAGS.batch_size)
        num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)
        if is_debug == "0":
            num_storage_steps = 190
            num_eval_steps = 100
            num_train_steps = 200
        print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".format(
            num_train_steps, num_eval_steps, num_storage_steps))
        print("model type {}".format(FLAGS.model_type))
        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({"init_lr": init_lr / worker_count,
                            "num_train_steps": num_train_steps,
                            "num_warmup_steps": num_warmup_steps,
                            "worker_count": worker_count,
                            "opt_type": FLAGS.opt_type,
                            "is_chief": is_chief,
                            "train_op": "adam"})
        model_io_config = Bunch({"fix_lm": False})
        num_classes = FLAGS.num_classes
        checkpoint_dir = checkpoint_dir  # if task_index == 0 else None
        print("==checkpoint_dir==", checkpoint_dir, is_chief)

        # Two builders over the same variables: the eval graph reuses the
        # training weights (model_reuse=True).
        model_train_fn = model_fn_builder(config, num_classes, init_checkpoint,
                                          model_reuse=None,
                                          load_pretrained=True,
                                          opt_config=opt_config,
                                          model_io_config=model_io_config,
                                          exclude_scope="",
                                          not_storage_params=[],
                                          target="",
                                          output_type="sess",
                                          checkpoint_dir=checkpoint_dir,
                                          num_storage_steps=num_storage_steps,
                                          task_index=task_index)
        model_eval_fn = model_fn_builder(config, num_classes, init_checkpoint,
                                         model_reuse=True,
                                         load_pretrained=True,
                                         opt_config=opt_config,
                                         model_io_config=model_io_config,
                                         exclude_scope="",
                                         not_storage_params=[],
                                         target="",
                                         output_type="sess",
                                         checkpoint_dir=checkpoint_dir,
                                         num_storage_steps=num_storage_steps,
                                         task_index=task_index)
        print("==succeeded in building model==")

        def eval_metric_fn(features, eval_op_dict):
            logits = eval_op_dict["logits"]
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)  # class probabilities (unused here)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            return {"accuracy": accuracy,
                    "loss": eval_op_dict["loss"],
                    "pred_label": pred_label,
                    "label_ids": features["label_ids"]}

        def train_metric_fn(features, train_op_dict):
            logits = train_op_dict["logits"]
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)  # class probabilities (unused here)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            return {"accuracy": accuracy,
                    "loss": train_op_dict["loss"],
                    "train_op": train_op_dict["train_op"]}

        name_to_features = {
            "input_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "input_mask": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example."""
            example = tf.parse_single_example(record, name_to_features)
            # tf.Example only supports tf.int64, but the TPU only supports
            # tf.int32, so cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        params = Bunch({})
        params.epoch = epoch
        params.batch_size = FLAGS.batch_size
        print("==train_file==", train_file, params)
        train_features = tf_data_utils.train_input_fn(
            train_file, _decode_record, name_to_features, params,
            if_shard=FLAGS.if_shard,
            worker_count=worker_count,
            task_index=task_index)
        eval_features = tf_data_utils.eval_input_fn(
            dev_file, _decode_record, name_to_features, params,
            if_shard=FLAGS.if_shard,
            worker_count=worker_count,
            task_index=task_index)

        train_op_dict = model_train_fn(train_features, [], tf.estimator.ModeKeys.TRAIN)
        eval_op_dict = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL)
        eval_dict = eval_metric_fn(eval_features, eval_op_dict["eval"])
        train_dict = train_metric_fn(train_features, train_op_dict["train"])
        print("==succeeded in building data and model==")
        print(train_op_dict)

        def eval_fn(eval_dict, sess):
            i = 0
            eval_total_dict = {}
            while True:
                try:
                    eval_result = sess.run(eval_dict)
                    # Accumulate scalar metrics and collect per-example
                    # predictions/labels across eval batches.
                    for key in eval_result:
                        if key not in eval_total_dict:
                            if key in ["pred_label", "label_ids"]:
                                eval_total_dict[key] = []
                                eval_total_dict[key].extend(eval_result[key])
                            if key in ["accuracy", "loss"]:
                                eval_total_dict[key] = 0.0
                                eval_total_dict[key] += eval_result[key]
                        else:
                            if key in ["pred_label", "label_ids"]:
                                eval_total_dict[key].extend(eval_result[key])
                            if key in ["accuracy", "loss"]:
                                eval_total_dict[key] += eval_result[key]
                    i += 1
                    if np.mod(i, num_eval_steps) == 0:
                        break
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            label_id = eval_total_dict["label_ids"]
            pred_label = eval_total_dict["pred_label"]
            result = classification_report(
                label_id, pred_label,
                target_names=list(label_dict["label2id"].keys()))
            print(result, task_index)
            eval_total_dict["classification_report"] = result
            return eval_total_dict

        def train_fn(train_op_dict, sess):
            i = 0
            cnt = 0
            loss_dict = {}
            monitoring_train = []
            monitoring_eval = []
            while True:
                try:
                    [train_result] = sess.run([train_op_dict])
                    for key in train_result:
                        if key == "train_op":
                            continue
                        if np.isnan(train_result[key]):
                            print(key, train_result[key], "get nan loss")
                            break
                        if key in loss_dict:
                            loss_dict[key] += train_result[key]
                        else:
                            loss_dict[key] = train_result[key]
                    i += 1
                    cnt += 1
                    # Every num_storage_steps steps: report running averages,
                    # run an evaluation pass, and reset the accumulators.
                    if np.mod(i, num_storage_steps) == 0:
                        string = ""
                        for key in loss_dict:
                            string += key + " " + str(loss_dict[key] / cnt) + "\t"
                        print(string)
                        monitoring_train.append(loss_dict)
                        eval_final_dict = eval_fn(eval_dict, sess)
                        monitoring_eval.append(eval_final_dict)
                        for key in loss_dict:
                            loss_dict[key] = 0.0
                        cnt = 0
                    if is_debug == "0":
                        if i == num_train_steps:
                            break
                except tf.errors.OutOfRangeError:
                    print("==Succeeded in training model==")
                    break

        print("===========begin to train============")
        # Session config shared by every MonitoredTrainingSession branch below.
        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)
        # sess_config.gpu_options.visible_device_list = str(task_index)
        # print(sess_config.gpu_options.visible_device_list, task_index, "==============")
        print("start training")
        hooks = []
        hooks.extend(train_op_dict["hooks"])
        if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            print("==create monitored training session==", FLAGS.opt_type, is_chief)
            sess = tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                config=sess_config,
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)
        elif FLAGS.opt_type == "pai_soar" and pai:
            sess = tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                config=sess_config,
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)
        elif FLAGS.opt_type == "hvd" and hvd:
            # Pin each Horovod process to its own GPU.
            sess_config.gpu_options.allow_growth = True
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
            sess = tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=hooks,
                config=sess_config,
                save_checkpoint_steps=num_storage_steps)
        else:
            print("==single sess==")
            sess = tf.train.MonitoredTrainingSession(
                config=sess_config,
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)

        print("==begin to train and eval==")
        train_fn(train_dict, sess)
        # for i in range(10):
        #     l = sess.run(train_features)
        #     print(l, task_index)
        if task_index == 0:
            print("===========begin to eval============")
            eval_final_dict = eval_fn(eval_dict, sess)
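

# --- Usage sketch (illustrative, not from the original source) ---
# How the session-based train_eval_fn above might be driven on a single
# machine: with worker_count=1 / task_index=0 / is_chief=True and a default
# opt_type, execution falls through to the plain "single sess"
# MonitoredTrainingSession branch. `FLAGS.init_checkpoint`, `FLAGS.train_file`,
# `FLAGS.dev_file`, and `FLAGS.model_output` are assumed flag names; the real
# entrypoint in the surrounding repo may wire these differently.
def run_single_worker(FLAGS):
    train_eval_fn(FLAGS,
                  worker_count=1,
                  task_index=0,
                  is_chief=True,
                  target="",  # no master target for local runs
                  init_checkpoint=FLAGS.init_checkpoint,
                  train_file=FLAGS.train_file,
                  dev_file=FLAGS.dev_file,
                  checkpoint_dir=FLAGS.model_output,
                  is_debug="1")  # "0" would cap the run at the debug step counts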
# Estimator-based variant of train_eval_fn (presumably kept in a separate
# module in the source repo; if defined in the same file it would shadow the
# session-based version above). Training and evaluation are delegated to
# tf.estimator.train_and_evaluate.
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target,
                  init_checkpoint, train_file, dev_file,
                  checkpoint_dir, is_debug):
    graph = tf.Graph()
    with graph.as_default():
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"

        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / worker_count)
        elif FLAGS.if_shard == "1":
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch

        init_lr = 2e-5
        label_dict = json.load(open(FLAGS.label_id))

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        num_warmup_steps = int(num_train_steps * 0.1)
        num_storage_steps = int(train_size / FLAGS.batch_size)
        num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)
        if is_debug == "0":
            num_storage_steps = 2
            num_eval_steps = 10
            num_train_steps = 10
        print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".format(
            num_train_steps, num_eval_steps, num_storage_steps))
        print("model type {}".format(FLAGS.model_type))
        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({"init_lr": init_lr / worker_count,
                            "num_train_steps": num_train_steps,
                            "num_warmup_steps": num_warmup_steps,
                            "worker_count": worker_count,
                            "opt_type": FLAGS.opt_type,
                            "is_chief": is_chief,
                            "train_op": "adam"})
        model_io_config = Bunch({"fix_lm": False})
        model_io_fn = model_io.ModelIO(model_io_config)
        num_classes = FLAGS.num_classes
        # Only the chief (task 0) writes checkpoints.
        checkpoint_dir = checkpoint_dir if task_index == 0 else None

        model_fn = model_fn_builder(config, num_classes, init_checkpoint,
                                    model_reuse=None,
                                    load_pretrained=True,
                                    model_io_config=model_io_config,
                                    opt_config=opt_config,
                                    model_io_fn=model_io_fn,
                                    exclude_scope="",
                                    not_storage_params=[],
                                    target="",
                                    output_type="estimator",
                                    checkpoint_dir=checkpoint_dir,
                                    num_storage_steps=num_storage_steps)

        name_to_features = {
            "input_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "input_mask": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example."""
            example = tf.parse_single_example(record, name_to_features)
            # tf.Example only supports tf.int64, but the TPU only supports
            # tf.int32, so cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        params = Bunch({})
        params.epoch = FLAGS.epoch
        params.batch_size = FLAGS.batch_size

        # Estimators expect input_fn callables, hence the lambdas.
        train_features = lambda: tf_data_utils.train_input_fn(
            train_file, _decode_record, name_to_features, params,
            if_shard=FLAGS.if_shard,
            worker_count=worker_count,
            task_index=task_index)
        eval_features = lambda: tf_data_utils.eval_input_fn(
            dev_file, _decode_record, name_to_features, params,
            if_shard=FLAGS.if_shard,
            worker_count=worker_count,
            task_index=task_index)

        print("===========begin to train============")
        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)
        train_hooks = []
        eval_hooks = []
        if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            print("==no need for hook==")
        elif FLAGS.opt_type == "pai_soar" and pai:
            print("==no need for hook==")
        elif FLAGS.opt_type == "hvd" and hvd:
            # Pin each Horovod process to its own GPU.
            sess_config.gpu_options.allow_growth = True
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
            print("==no need for hook==")
        else:
            print("==no need for hook==")

        run_config = tf.estimator.RunConfig(
            model_dir=checkpoint_dir,
            save_checkpoints_steps=num_storage_steps,
            session_config=sess_config)
        model_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                                 config=run_config)
        train_spec = tf.estimator.TrainSpec(input_fn=train_features,
                                            max_steps=num_train_steps)
        eval_spec = tf.estimator.EvalSpec(input_fn=eval_features,
                                          steps=num_eval_steps)
        tf.estimator.train_and_evaluate(model_estimator, train_spec, eval_spec)
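

# --- Data-format sketch (illustrative, not from the original source) ---
# Both variants above parse TFRecords with the same name_to_features schema:
# three int64 lists padded to FLAGS.max_length plus one scalar label. A
# minimal writer producing compatible records could look like this; the
# helper name write_example is hypothetical.
def write_example(writer, input_ids, input_mask, segment_ids, label_id):
    def _int64_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    example = tf.train.Example(features=tf.train.Features(feature={
        "input_ids": _int64_feature(input_ids),
        "input_mask": _int64_feature(input_mask),
        "segment_ids": _int64_feature(segment_ids),
        "label_ids": _int64_feature([label_id]),  # scalar stored as length-1 list
    }))
    writer.write(example.SerializeToString())

# e.g.:
#     with tf.python_io.TFRecordWriter("train.tfrecord") as writer:
#         write_example(writer, ids, mask, segs, label)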