def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint, train_file, dev_file, checkpoint_dir, is_debug, **kargs): graph = tf.Graph() with graph.as_default(): import json # config = model_config_parser(FLAGS) if FLAGS.if_shard == "0": train_size = FLAGS.train_size epoch = int(FLAGS.epoch / worker_count) elif FLAGS.if_shard == "1": train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch else: train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch multi_task_config = Bunch( json.load(tf.gfile.Open(FLAGS.multi_task_config))) num_train_steps = int(train_size / FLAGS.batch_size * epoch) num_warmup_steps = int(num_train_steps * 0.1) num_storage_steps = int(train_size / FLAGS.batch_size) num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size) if is_debug == "0": num_storage_steps = 190 num_eval_steps = 100 num_train_steps = 200 print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}". format(num_train_steps, num_eval_steps, num_storage_steps)) print(" model type {}".format(FLAGS.model_type)) print(num_train_steps, num_warmup_steps, "=============") opt_config = Bunch({ "init_lr": FLAGS.init_lr, "num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "worker_count": worker_count, "opt_type": FLAGS.opt_type, "is_chief": is_chief, "train_op": kargs.get("train_op", "adam"), "decay": kargs.get("decay", "no"), "warmup": kargs.get("warmup", "no"), "grad_clip": kargs.get("grad_clip", "global_norm"), "clip_norm": kargs.get("clip_norm", 1.0) }) anneal_config = Bunch({ "initial_value": 1.0, "num_train_steps": num_train_steps }) model_io_config = Bunch({"fix_lm": False}) if FLAGS.opt_type == "hvd" and hvd: checkpoint_dir = checkpoint_dir if task_index == 0 else None else: checkpoint_dir = checkpoint_dir print("==checkpoint_dir==", checkpoint_dir, is_chief) model_config_dict = {} num_labels_dict = {} init_checkpoint_dict = {} load_pretrained_dict = {} exclude_scope_dict = {} not_storage_params_dict = {} target_dict = {} task_type_dict = {} model_type_lst = [] label_dict = {} for task_type in FLAGS.multi_task_type.split(","): print("==task type==", task_type) model_config_dict[task_type] = model_config_parser( Bunch(multi_task_config[task_type])) num_labels_dict[task_type] = multi_task_config[task_type][ "num_labels"] init_checkpoint_dict[task_type] = os.path.join( FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"]) load_pretrained_dict[task_type] = multi_task_config[task_type][ "load_pretrained"] exclude_scope_dict[task_type] = multi_task_config[task_type][ "exclude_scope"] not_storage_params_dict[task_type] = multi_task_config[task_type][ "not_storage_params"] target_dict[task_type] = multi_task_config[task_type]["target"] task_type_dict[task_type] = multi_task_config[task_type][ "task_type"] label_dict[task_type] = json.load( open( os.path.join(FLAGS.buckets, multi_task_config[task_type]["label_id"]))) model_train_fn = multitask_model_fn( model_config_dict, num_labels_dict, task_type_dict, init_checkpoint_dict, load_pretrained_dict=load_pretrained_dict, opt_config=opt_config, model_io_config=model_io_config, exclude_scope_dict=exclude_scope_dict, not_storage_params_dict=not_storage_params_dict, target_dict=target_dict, output_type="sess", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, anneal_config=anneal_config, task_layer_reuse=None, model_type_lst=model_type_lst, **kargs) eval_model_fn = {} for task_type in FLAGS.multi_task_type.split(","): eval_task_type_dict = {} 
model_config_dict[task_type] = model_config_parser( Bunch(multi_task_config[task_type])) num_labels_dict[task_type] = multi_task_config[task_type][ "num_labels"] init_checkpoint_dict[task_type] = os.path.join( FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"]) load_pretrained_dict[task_type] = multi_task_config[task_type][ "load_pretrained"] exclude_scope_dict[task_type] = multi_task_config[task_type][ "exclude_scope"] not_storage_params_dict[task_type] = multi_task_config[task_type][ "not_storage_params"] target_dict[task_type] = multi_task_config[task_type]["target"] eval_task_type_dict[task_type] = multi_task_config[task_type][ "task_type"] eval_model_fn[task_type] = multitask_model_fn( model_config_dict, num_labels_dict, eval_task_type_dict, init_checkpoint_dict, load_pretrained_dict=load_pretrained_dict, opt_config=opt_config, model_io_config=model_io_config, exclude_scope_dict=exclude_scope_dict, not_storage_params_dict=not_storage_params_dict, target_dict=target_dict, output_type="sess", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, anneal_config=anneal_config, task_layer_reuse=True, model_type_lst=model_type_lst, multi_task_config=multi_task_config, **kargs) print("==succeeded in building model==") def eval_metric_fn(features, eval_op_dict, task_type): logits = eval_op_dict["logits"][task_type] print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) accuracy = correct = tf.equal( tf.cast(pred_label, tf.int32), tf.cast(features["{}_label_ids".format(task_type)], tf.int32)) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) return { "accuracy": accuracy, "loss": eval_op_dict["loss"][task_type], "pred_label": pred_label, "label_ids": features["{}_label_ids".format(task_type)] } def train_metric_fn(features, train_op_dict): return train_op_dict name_to_features = data_interface(FLAGS, multi_task_config, FLAGS.multi_task_type.split(",")) def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example. """ example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. 
for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.to_int32(t) example[name] = t return example def _decode_batch_record(record, name_to_features): example = tf.parse_example(record, name_to_features) return example params = Bunch({}) params.epoch = epoch params.batch_size = FLAGS.batch_size if kargs.get("parse_type", "parse_single") == "parse_single": train_file_lst = [ multi_task_config[task_type]["train_result_file"] for task_type in FLAGS.multi_task_type.split(",") ] print(train_file_lst) train_features = tf_data_utils.train_input_fn( train_file_lst, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features_dict = {} for task_type in FLAGS.multi_task_type.split(","): name_to_features = data_interface( FLAGS, {task_type: multi_task_config[task_type]}) eval_features_dict[task_type] = tf_data_utils.eval_input_fn( multi_task_config[task_type]["dev_result_file"], _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) elif kargs.get("parse_type", "parse_single") == "parse_batch": train_file_lst = [ multi_task_config[task_type]["train_result_file"] for task_type in FLAGS.multi_task_type.split(",") ] train_file_path_lst = [ os.path.join(FLAGS.buckets, train_file) for train_file in train_file_lst ] train_features = tf_data_utils.train_batch_input_fn( train_file_path_lst, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features_dict = {} for task_type in FLAGS.multi_task_type.split(","): name_to_features = data_interface( FLAGS, {task_type: multi_task_config[task_type]}, [task_type_dict]) dev_file_path = os.path.join( FLAGS.buckets, multi_task_config[task_type]["dev_result_file"]) eval_features_dict[ task_type] = tf_data_utils.eval_batch_input_fn( dev_file_path, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) train_op_dict = model_train_fn(train_features, [], tf.estimator.ModeKeys.TRAIN) train_dict = train_metric_fn(train_features, train_op_dict["train"]) eval_dict = {} for task_type in eval_features_dict: eval_features = eval_features_dict[task_type] eval_op_dict = eval_model_fn[task_type](eval_features, [], tf.estimator.ModeKeys.EVAL) eval_dict_tmp = eval_metric_fn(eval_features, eval_op_dict["eval"], task_type) eval_dict[task_type] = eval_dict_tmp print("==succeeded in building data and model==") print(train_op_dict) def task_eval(eval_dict, sess, eval_total_dict): eval_result = sess.run(eval_dict) for key in eval_result: if key not in eval_total_dict: if key in ["pred_label", "label_ids"]: eval_total_dict[key] = [] eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] = 0.0 eval_total_dict[key] += eval_result[key] else: if key in ["pred_label", "label_ids"]: eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] += eval_result[key] def task_metric(eval_dict, label_dict): label_id = eval_dict["label_ids"] pred_label = eval_dict["pred_label"] label_dict_id = sorted(list(label_dict["id2label"].keys())) print(len(label_id), len(pred_label), len(set(label_id))) accuracy = accuracy_score(label_id, pred_label) print("==accuracy==", accuracy) if len(label_dict["id2label"]) < 10: result = classification_report(label_id, pred_label, target_names=[ label_dict["id2label"][key] for key in 
                        label_dict_id],
                    digits=4)
                print(result, task_index)
                # NOTE: the original assigned to `eval_total_dict`, which is not defined in
                # this scope (it only exists inside eval_fn below). `eval_dict` here is the
                # per-task dict passed in by the caller, so store the report on it instead.
                eval_dict["classification_report"] = result
                print("==classification report==")

        def eval_fn(eval_dict, sess):
            i = 0
            total_accuracy = 0
            eval_total_dict = {}
            for task_type in eval_dict:
                eval_total_dict[task_type] = {}
            while True:
                try:
                    for task_type in eval_dict:
                        task_eval(eval_dict[task_type], sess,
                                  eval_total_dict[task_type])
                    i += 1
                    if np.mod(i, num_eval_steps) == 0:
                        break
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            for task_type in eval_total_dict:
                task_metric(eval_total_dict[task_type], label_dict[task_type])
            return eval_total_dict

        def train_fn(train_op_dict, sess):
            i = 0
            cnt = 0
            loss_dict = {}
            monitoring_train = []
            monitoring_eval = []
            while True:
                try:
                    [train_result] = sess.run([train_op_dict])
                    for key in train_result:
                        if key == "train_op":
                            continue
                        if key == "loss":
                            # Per-task losses come back as a dict keyed by task type.
                            # The original indexed loss_dict[task_type]["loss"] before ever
                            # initializing it; accumulate under flat float keys instead so the
                            # periodic averaging below keeps working.
                            for task_type in train_result[key]:
                                task_key = "{}_loss".format(task_type)
                                if task_key in loss_dict:
                                    loss_dict[task_key] += train_result[key][task_type]
                                else:
                                    loss_dict[task_key] = train_result[key][task_type]
                        else:
                            try:
                                if np.isnan(train_result[key]):
                                    print(train_result[key], "get nan loss")
                                    break
                                if key in loss_dict:
                                    loss_dict[key] += train_result[key]
                                else:
                                    loss_dict[key] = train_result[key]
                            except:
                                continue
                    i += 1
                    cnt += 1
                    if np.mod(i, num_storage_steps) == 0:
                        string = ""
                        for key in loss_dict:
                            tmp = key + " " + str(loss_dict[key] / cnt) + "\t"
                            string += tmp
                        print(string)
                        monitoring_train.append(loss_dict)
                        eval_finial_dict = eval_fn(eval_dict, sess)
                        monitoring_eval.append(eval_finial_dict)
                        for key in loss_dict:
                            loss_dict[key] = 0.0
                        cnt = 0
                    if is_debug == "0":
                        if i == num_train_steps:
                            break
                except tf.errors.OutOfRangeError:
                    print("==Succeeded in training model==")
                    break
            return {"eval": monitoring_eval, "train": monitoring_train}

        print("start training")

        hooks = []
        hooks.extend(train_op_dict["hooks"])

        if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            print("==create monitored training session==", FLAGS.opt_type,
                  is_chief)
            sess = tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                config=kargs.get("sess_config", sess_config),
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)
        elif FLAGS.opt_type == "pai_soar" and pai:
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            sess = tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                config=kargs.get("sess_config", sess_config),
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)
        elif FLAGS.opt_type == "hvd" and hvd:
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            sess_config.gpu_options.allow_growth = False
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
            sess = tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=hooks,
                config=sess_config,
                save_checkpoint_steps=num_storage_steps)
        else:
            print("==single sess==")
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            sess = tf.train.MonitoredTrainingSession(
                config=sess_config,
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)

        print("==begin to train and eval==")
        monitoring_info = train_fn(train_dict, sess)

        if task_index == 0:
            start_time = time.time()
            print("===========begin to eval============")
            eval_finial_dict = eval_fn(eval_dict, sess)
            end_time = time.time()
            print("==total forward time==", end_time - start_time)
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint, train_file, dev_file, checkpoint_dir, is_debug, **kargs): graph = tf.Graph() with graph.as_default(): import json # config = json.load(open(FLAGS.config_file, "r")) # config = Bunch(config) # config.use_one_hot_embeddings = True # config.scope = "bert" # config.dropout_prob = 0.1 # config.label_type = "single_label" # config.model = FLAGS.model_type config = model_config_parser(FLAGS) if FLAGS.if_shard == "0": train_size = FLAGS.train_size epoch = int(FLAGS.epoch / worker_count) elif FLAGS.if_shard == "1": print("==number of gpus==", kargs.get('num_gpus', 1)) train_size = int(FLAGS.train_size/worker_count/kargs.get('num_gpus', 1)) # train_size = int(FLAGS.train_size) epoch = FLAGS.epoch else: train_size = int(FLAGS.train_size/worker_count) epoch = FLAGS.epoch init_lr = config.init_lr label_dict = json.load(tf.gfile.Open(FLAGS.label_id)) warmup_ratio = config.get('warmup', 0.1) num_train_steps = int( train_size / FLAGS.batch_size * epoch) if config.get('ln_type', 'postln') == 'postln': num_warmup_steps = int(num_train_steps * warmup_ratio) elif config.get('ln_type', 'preln') == 'postln': num_warmup_steps = 0 else: num_warmup_steps = int(num_train_steps * warmup_ratio) print('==num warmup steps==', num_warmup_steps) num_storage_steps = min([int(train_size / FLAGS.batch_size), 10000 ]) if num_storage_steps <= 100: num_storage_steps = 500 num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size) if is_debug == "0": num_storage_steps = 2 num_eval_steps = 10 num_train_steps = 10 print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".format(num_train_steps, num_eval_steps, num_storage_steps)) print(" model type {}".format(FLAGS.model_type)) print(num_train_steps, num_warmup_steps, "=============", kargs.get('num_gpus', 1), '==number of gpus==') # if worker_count*kargs.get("num_gpus", 1) >= 2: # clip_norm_scale = 1.0 # lr_scale = 0.8 # else: # clip_norm_scale = 1.0 # lr_scale = 1.0 # lr = init_lr*worker_count*kargs.get("num_gpus", 1)*lr_scale # if lr >= 1e-3: # lr = 1e-3 lr = init_lr opt_config = Bunch({"init_lr":lr, "num_train_steps":num_train_steps, "num_warmup_steps":num_warmup_steps, "worker_count":worker_count, "gpu_count":worker_count*kargs.get("num_gpus", 1), "opt_type":FLAGS.opt_type, "is_chief":is_chief, "train_op":kargs.get("train_op", "adam"), "decay":kargs.get("decay", "no"), "warmup":kargs.get("warmup", "no"), "clip_norm":config.get("clip_norm", 1.0), "grad_clip":config.get("grad_clip", "global_norm"), "epoch":FLAGS.epoch, "strategy":FLAGS.distribution_strategy}) anneal_config = Bunch({ "initial_value":1.0, "num_train_steps":num_train_steps }) model_io_config = Bunch({"fix_lm":False}) model_io_fn = model_io.ModelIO(model_io_config) num_classes = FLAGS.num_classes if FLAGS.opt_type == "hvd" and hvd: checkpoint_dir = checkpoint_dir if task_index == 0 else None elif FLAGS.opt_type == "all_reduce": checkpoint_dir = checkpoint_dir elif FLAGS.opt_type == "collective_reduce": checkpoint_dir = checkpoint_dir if task_index == 0 else None elif FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync": checkpoint_dir = checkpoint_dir if task_index == 0 else None print("==checkpoint_dir==", checkpoint_dir, is_chief) # if kargs.get("rule_model", "normal") == "rule": # model_fn_interface = rule_model_fn_builder # print("==apply rule model==") # else: # model_fn_interface = model_fn_builder # print("==apply normal model==") model_fn_builder = model_fn_interface(FLAGS) model_fn = model_fn_builder(config, 
num_classes, init_checkpoint, model_reuse=None, load_pretrained=FLAGS.load_pretrained, model_io_config=model_io_config, opt_config=opt_config, model_io_fn=model_io_fn, exclude_scope=FLAGS.exclude_scope, not_storage_params=[], target=kargs.get("input_target", ""), output_type="estimator", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, task_index=task_index, anneal_config=anneal_config, **kargs) name_to_features = data_interface(FLAGS) def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example. """ example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.to_int32(t) example[name] = t return example def _decode_batch_record(record, name_to_features): example = tf.parse_example(record, name_to_features) # for name in list(example.keys()): # t = example[name] # if t.dtype == tf.int64: # t = tf.to_int32(t) # example[name] = t return example params = Bunch({}) params.epoch = FLAGS.epoch params.batch_size = FLAGS.batch_size if kargs.get("run_config", None): if kargs.get("parse_type", "parse_single") == "parse_single": train_features = lambda: tf_data_utils.all_reduce_train_input_fn(train_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features = lambda: tf_data_utils.all_reduce_eval_input_fn(dev_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) elif kargs.get("parse_type", "parse_single") == "parse_batch": print("==apply parse example==") train_features = lambda: tf_data_utils.all_reduce_train_batch_input_fn(train_file, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features = lambda: tf_data_utils.all_reduce_eval_batch_input_fn(dev_file, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) elif kargs.get("parse_type", "parse_single") == "parse_batch_multi_task": data_prior = [float(item) for item in FLAGS.data_prior.split(',')] train_features = lambda: tf_data_utils.all_reduce_multitask_train_batch_input_fn_sample( train_file, _decode_record, name_to_features, params, data_prior=data_prior, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features = lambda: tf_data_utils.all_reduce_eval_batch_input_fn(dev_file, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) else: train_features = lambda: tf_data_utils.train_input_fn(train_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features = lambda: tf_data_utils.eval_input_fn(dev_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) train_hooks = [] eval_hooks = [] sess_config = tf.ConfigProto(allow_soft_placement=False, log_device_placement=False) if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync": print("==no need for hook==") elif FLAGS.opt_type == "pai_soar" and pai: print("no need for hook") elif FLAGS.opt_type == "hvd" and hvd: sess_config.gpu_options.allow_growth = True sess_config.gpu_options.visible_device_list = str(hvd.local_rank()) print("==no need fo 
hook==") else: print("==no need for hooks==") if kargs.get("run_config", None): run_config = kargs.get("run_config", None) run_config = run_config.replace(save_checkpoints_steps=num_storage_steps) print("==run config==", run_config.save_checkpoints_steps) else: run_config = tf.estimator.RunConfig(model_dir=checkpoint_dir, save_checkpoints_steps=num_storage_steps, session_config=sess_config) if kargs.get("profiler", "profiler") == "profiler": if checkpoint_dir: hooks = tf.train.ProfilerHook( save_steps=100, save_secs=None, output_dir=os.path.join(checkpoint_dir, "profiler"), ) train_hooks.append(hooks) print("==add profiler hooks==") model_estimator = tf.estimator.Estimator( model_fn=model_fn, model_dir=checkpoint_dir, config=run_config) train_being_time = time.time() tf.logging.info("==training distribution_strategy=={}".format(kargs.get("distribution_strategy", "MirroredStrategy"))) if kargs.get("distribution_strategy", "MirroredStrategy") == "MirroredStrategy": print("==apply single machine multi-card training==") train_spec = tf.estimator.TrainSpec(input_fn=train_features, max_steps=num_train_steps) eval_spec = tf.estimator.EvalSpec(input_fn=eval_features, steps=num_eval_steps) model_estimator.train(input_fn=train_features, max_steps=num_train_steps, hooks=train_hooks) # tf.estimator.train(model_estimator, train_spec) train_end_time = time.time() print("==training time==", train_end_time - train_being_time) tf.logging.info("==training time=={}".format(train_end_time - train_being_time)) eval_results = model_estimator.evaluate(input_fn=eval_features, steps=num_eval_steps) print(eval_results) elif kargs.get("distribution_strategy", "MirroredStrategy") in ["ParameterServerStrategy", "CollectiveAllReduceStrategy"]: print("==apply multi-machine machine multi-card training==") try: print(os.environ['TF_CONFIG'], "==tf_run_config==") except: print("==not tf config==") train_spec = tf.estimator.TrainSpec(input_fn=train_features, max_steps=num_train_steps) eval_spec = tf.estimator.EvalSpec(input_fn=eval_features, steps=num_eval_steps) # tf.estimator.train(model_estimator, train_spec) # tf 1.12 doesn't need evaluate tf.estimator.train_and_evaluate(model_estimator, train_spec, eval_spec)
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint, train_file, dev_file, checkpoint_dir, is_debug, **kargs): graph = tf.Graph() with graph.as_default(): import json # config = model_config_parser(FLAGS) print(FLAGS.train_size) if FLAGS.if_shard == "0": train_size = FLAGS.train_size epoch = int(FLAGS.epoch / worker_count) elif FLAGS.if_shard == "1": train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch else: train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch multi_task_config = Bunch( json.load(tf.gfile.Open(FLAGS.multi_task_config))) num_train_steps = int(train_size / FLAGS.batch_size * epoch) num_warmup_steps = int(num_train_steps * 0.1) num_storage_steps = int(train_size / FLAGS.batch_size) num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size) if is_debug == "0": num_storage_steps = 190 num_eval_steps = 100 num_train_steps = 200 print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}". format(num_train_steps, num_eval_steps, num_storage_steps)) print(" model type {}".format(FLAGS.model_type)) print(num_train_steps, num_warmup_steps, "=============") opt_config = Bunch({ "init_lr": kargs.get("init_lr", 5e-5) / worker_count, "num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "worker_count": worker_count, "opt_type": FLAGS.opt_type, "is_chief": is_chief, "train_op": kargs.get("train_op", "adam"), "decay": kargs.get("decay", "no"), "warmup": kargs.get("warmup", "no"), "grad_clip": kargs.get("grad_clip", "global_norm"), "clip_norm": kargs.get("clip_norm", 1.0) }) anneal_config = Bunch({ "initial_value": 1.0, "num_train_steps": num_train_steps }) model_io_config = Bunch({"fix_lm": False}) if FLAGS.opt_type == "hvd" and hvd: checkpoint_dir = checkpoint_dir if task_index == 0 else None else: checkpoint_dir = checkpoint_dir print("==checkpoint_dir==", checkpoint_dir, is_chief) model_config_dict = {} num_labels_dict = {} init_checkpoint_dict = {} load_pretrained_dict = {} exclude_scope_dict = {} not_storage_params_dict = {} target_dict = {} task_type_dict = {} model_type_lst = [] label_dict = {} for task_type in FLAGS.multi_task_type.split(","): print("==task type==", task_type) model_config_dict[task_type] = model_config_parser( Bunch(multi_task_config[task_type])) num_labels_dict[task_type] = multi_task_config[task_type][ "num_labels"] init_checkpoint_dict[task_type] = os.path.join( FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"]) load_pretrained_dict[task_type] = multi_task_config[task_type][ "load_pretrained"] exclude_scope_dict[task_type] = multi_task_config[task_type][ "exclude_scope"] not_storage_params_dict[task_type] = multi_task_config[task_type][ "not_storage_params"] target_dict[task_type] = multi_task_config[task_type]["target"] task_type_dict[task_type] = multi_task_config[task_type][ "task_type"] label_dict[task_type] = json.load( tf.gfile.Open( os.path.join(FLAGS.buckets, multi_task_config[task_type]["label_id"]))) model_fn = multitask_model_fn( model_config_dict, num_labels_dict, task_type_dict, init_checkpoint_dict, load_pretrained_dict=load_pretrained_dict, opt_config=opt_config, model_io_config=model_io_config, exclude_scope_dict=exclude_scope_dict, not_storage_params_dict=not_storage_params_dict, target_dict=target_dict, output_type="estimator", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, anneal_config=anneal_config, task_layer_reuse=None, model_type_lst=model_type_lst, **kargs) print("==succeeded in building model==") 
name_to_features = data_interface(FLAGS, multi_task_config, FLAGS.multi_task_type.split(",")) def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example. """ example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.to_int32(t) example[name] = t return example def _decode_batch_record(record, name_to_features): example = tf.parse_example(record, name_to_features) return example params = Bunch({}) params.epoch = epoch params.batch_size = FLAGS.batch_size if kargs.get("parse_type", "parse_single") == "parse_single": train_file_lst = [ multi_task_config[task_type]["train_result_file"] for task_type in FLAGS.multi_task_type.split(",") ] print(train_file_lst) train_features = lambda: tf_data_utils.all_reduce_multitask_train_input_fn( train_file_lst, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) elif kargs.get("parse_type", "parse_single") == "parse_batch": train_file_lst = [ multi_task_config[task_type]["train_result_file"] for task_type in FLAGS.multi_task_type.split(",") ] train_file_path_lst = [ os.path.join(FLAGS.buckets, train_file) for train_file in train_file_lst ] print(train_file_path_lst) print("==apply train batch==") train_features = lambda: tf_data_utils.all_reduce_train_batch_input_fn( train_file_path_lst, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) print("==succeeded in building data and model==") print("start training") train_hooks = [] sess_config = tf.ConfigProto(allow_soft_placement=False, log_device_placement=False) if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync": print("==no need for hook==") elif FLAGS.opt_type == "pai_soar" and pai: print("no need for hook") elif FLAGS.opt_type == "hvd" and hvd: sess_config.gpu_options.allow_growth = True sess_config.gpu_options.visible_device_list = str(hvd.local_rank()) print("==no need fo hook==") else: print("==no need for hooks==") if kargs.get("run_config", None): run_config = kargs.get("run_config", None) run_config = run_config.replace( save_checkpoints_steps=num_storage_steps) print("==run config==", run_config.save_checkpoints_steps) else: run_config = tf.estimator.RunConfig( model_dir=checkpoint_dir, save_checkpoints_steps=num_storage_steps, session_config=sess_config) if kargs.get("profiler", "profiler") == "profiler": if checkpoint_dir: hooks = tf.train.ProfilerHook( save_steps=100, save_secs=None, output_dir=os.path.join(checkpoint_dir, "profiler"), ) train_hooks.append(hooks) print("==add profiler hooks==") model_estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) train_being_time = time.time() tf.logging.info("==training distribution_strategy=={}".format( kargs.get("distribution_strategy", "MirroredStrategy"))) if kargs.get("distribution_strategy", "MirroredStrategy") == "MirroredStrategy": print("==apply single machine multi-card training==") model_estimator.train(input_fn=train_features, max_steps=num_train_steps, hooks=train_hooks) train_end_time = time.time() print("==training time==", train_end_time - train_being_time) tf.logging.info("==training time=={}".format(train_end_time - train_being_time)) elif kargs.get("distribution_strategy", "MirroredStrategy") in [ "ParameterServerStrategy", "CollectiveAllReduceStrategy" ]: 
print("==apply multi-machine machine multi-card training==") try: print(os.environ['TF_CONFIG'], "==tf_run_config==") except: print("==not tf config==") train_spec = tf.estimator.TrainSpec(input_fn=train_features, max_steps=num_train_steps) eval_spec = tf.estimator.EvalSpec(input_fn=eval_features, steps=num_eval_steps) tf.estimator.train_and_evaluate(model_estimator, train_spec, eval_spec) train_end_time = time.time() print("==training time==", train_end_time - train_being_time)
def eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint, train_file, dev_file, checkpoint_dir, is_debug, **kargs): graph = tf.Graph() with graph.as_default(): import json # config = model_config_parser(FLAGS) if FLAGS.if_shard == "0": train_size = FLAGS.train_size epoch = int(FLAGS.epoch / worker_count) elif FLAGS.if_shard == "1": train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch else: train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch multi_task_config = Bunch( json.load(tf.gfile.Open(FLAGS.multi_task_config))) num_train_steps = int(train_size / FLAGS.batch_size * epoch) num_warmup_steps = int(num_train_steps * 0.1) num_storage_steps = int(train_size / FLAGS.batch_size) num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size) if is_debug == "0": num_storage_steps = 190 num_eval_steps = 100 num_train_steps = 200 print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}". format(num_train_steps, num_eval_steps, num_storage_steps)) print(" model type {}".format(FLAGS.model_type)) print(num_train_steps, num_warmup_steps, "=============") opt_config = Bunch({ "init_lr": kargs.get("init_lr", 1e-5) / worker_count, "num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "worker_count": worker_count, "opt_type": FLAGS.opt_type, "is_chief": is_chief, "train_op": kargs.get("train_op", "adam"), "decay": kargs.get("decay", "no"), "warmup": kargs.get("warmup", "no"), "grad_clip": kargs.get("grad_clip", "global_norm"), "clip_norm": kargs.get("clip_norm", 1.0) }) anneal_config = Bunch({ "initial_value": 1.0, "num_train_steps": num_train_steps }) model_io_config = Bunch({"fix_lm": False}) if FLAGS.opt_type == "hvd" and hvd: checkpoint_dir = checkpoint_dir if task_index == 0 else None else: checkpoint_dir = checkpoint_dir print("==checkpoint_dir==", checkpoint_dir, is_chief) model_config_dict = {} num_labels_dict = {} init_checkpoint_dict = {} load_pretrained_dict = {} exclude_scope_dict = {} not_storage_params_dict = {} target_dict = {} task_type_dict = {} model_type_lst = [] label_dict = {} eval_model_fn = {} for task_type in FLAGS.multi_task_type.split(","): eval_task_type_dict = {} model_config_dict[task_type] = model_config_parser( Bunch(multi_task_config[task_type])) num_labels_dict[task_type] = multi_task_config[task_type][ "num_labels"] init_checkpoint_dict[task_type] = os.path.join( FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"]) print( init_checkpoint_dict[task_type], task_type, "===", os.path.join(FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"])) load_pretrained_dict[task_type] = multi_task_config[task_type][ "load_pretrained"] exclude_scope_dict[task_type] = multi_task_config[task_type][ "exclude_scope"] not_storage_params_dict[task_type] = multi_task_config[task_type][ "not_storage_params"] target_dict[task_type] = multi_task_config[task_type]["target"] eval_task_type_dict[task_type] = multi_task_config[task_type][ "task_type"] label_dict[task_type] = json.load( tf.gfile.Open( os.path.join(FLAGS.buckets, multi_task_config[task_type]["label_id"]))) eval_model_fn[task_type] = multitask_model_fn( model_config_dict, num_labels_dict, eval_task_type_dict, init_checkpoint_dict, load_pretrained_dict=load_pretrained_dict, opt_config=opt_config, model_io_config=model_io_config, exclude_scope_dict=exclude_scope_dict, not_storage_params_dict=not_storage_params_dict, target_dict=target_dict, output_type="sess", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, 
anneal_config=anneal_config, task_layer_reuse=False, model_type_lst=model_type_lst, **kargs) print(init_checkpoint_dict, "==init_checkpoint==") print("==succeeded in building model==") def eval_metric_fn(features, eval_op_dict, task_type): logits = eval_op_dict["logits"][task_type] print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) accuracy = correct = tf.equal( tf.cast(pred_label, tf.int32), tf.cast(features["{}_label_ids".format(task_type)], tf.int32)) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) return { "accuracy": accuracy, "loss": eval_op_dict["loss"][task_type], "pred_label": pred_label, "label_ids": features["{}_label_ids".format(task_type)] } name_to_features = data_interface(FLAGS, multi_task_config, FLAGS.multi_task_type.split(",")) def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example. """ example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.to_int32(t) example[name] = t return example def _decode_batch_record(record, name_to_features): example = tf.parse_example(record, name_to_features) return example params = Bunch({}) params.epoch = 0 params.batch_size = FLAGS.batch_size if kargs.get("parse_type", "parse_single") == "parse_single": eval_features_dict = {} for task_type in FLAGS.multi_task_type.split(","): name_to_features = data_interface( FLAGS, {task_type: multi_task_config[task_type]}, [task_type]) eval_features_dict[task_type] = tf_data_utils.eval_input_fn( multi_task_config[task_type]["dev_result_file"], _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) elif kargs.get("parse_type", "parse_single") == "parse_batch": eval_features_dict = {} for task_type in FLAGS.multi_task_type.split(","): name_to_features = data_interface( FLAGS, {task_type: multi_task_config[task_type]}, [task_type]) dev_file_path = os.path.join( FLAGS.buckets, multi_task_config[task_type]["test_result_file"]) eval_features_dict[ task_type] = tf_data_utils.eval_batch_input_fn( dev_file_path, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_dict = {} for task_type in eval_features_dict: eval_features = eval_features_dict[task_type] eval_op_dict = eval_model_fn[task_type](eval_features, [], tf.estimator.ModeKeys.EVAL) eval_dict_tmp = eval_metric_fn(eval_features, eval_op_dict["eval"], task_type) eval_dict[task_type] = eval_dict_tmp print(eval_dict) print("==succeeded in building data and model==") def task_eval(eval_dict, sess, eval_total_dict): eval_result = sess.run(eval_dict) for key in eval_result: if key not in eval_total_dict: if key in ["pred_label", "label_ids"]: eval_total_dict[key] = [] eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] = 0.0 eval_total_dict[key] += eval_result[key] else: if key in ["pred_label", "label_ids"]: eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] += eval_result[key] def task_metric(eval_dict, label_dict, eval_total_dict): label_id = eval_dict["label_ids"] pred_label = eval_dict["pred_label"] label_dict_id = sorted(list(label_dict["id2label"].keys())) print(len(label_id), len(pred_label), 
                  len(set(label_id)))
            accuracy = accuracy_score(label_id, pred_label)
            print("==accuracy==", accuracy)
            if len(label_dict["id2label"]) < 10:
                result = classification_report(
                    label_id,
                    pred_label,
                    target_names=[
                        label_dict["id2label"][key] for key in label_dict_id
                    ],
                    digits=4)
                print(result, task_index)
                eval_total_dict["classification_report"] = result
                print("==classification report==")

        def eval_fn(eval_dict, sess):
            i = 0
            total_accuracy = 0
            eval_total_dict = {}
            for task_type in eval_dict:
                eval_total_dict[task_type] = {}
            while True:
                try:
                    for task_type in eval_dict:
                        task_eval(eval_dict[task_type], sess,
                                  eval_total_dict[task_type])
                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            for task_type in eval_total_dict:
                task_metric(eval_total_dict[task_type], label_dict[task_type],
                            eval_total_dict[task_type])
            return eval_total_dict

        print("start evaluating")

        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)
        sess = tf.Session(config=sess_config)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        print("==begin to train and eval==")
        start_time = time.time()
        eval_finial_dict = eval_fn(eval_dict, sess)
        end_time = time.time()
        print("==forward time==", end_time - start_time)
        return eval_finial_dict
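
# Illustrative helper (assumption, not part of the original code): eval_fn above returns a
# dict keyed by task type, where each entry accumulates "accuracy" and "loss" as running sums
# over batches, "pred_label" / "label_ids" as flat lists, and, for small label sets, a
# "classification_report" string. This sketch shows one way a caller might print a summary.
def _summarize_eval_result(eval_finial_dict):
    for task_type, task_result in eval_finial_dict.items():
        print(task_type,
              "accuracy (sum over batches):", task_result.get("accuracy"),
              "loss (sum over batches):", task_result.get("loss"),
              "num predictions:", len(task_result.get("pred_label", [])))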
def eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint, train_file, dev_file, checkpoint_dir, is_debug, **kargs): graph = tf.Graph() with graph.as_default(): import json # config = json.load(open(FLAGS.config_file, "r")) # config = Bunch(config) # config.use_one_hot_embeddings = True # config.scope = "bert" # config.dropout_prob = 0.1 # config.label_type = "single_label" # config.model_type = FLAGS.model_type config = model_config_parser(FLAGS) if FLAGS.if_shard == "0": train_size = FLAGS.train_size epoch = int(FLAGS.epoch / worker_count) elif FLAGS.if_shard == "1": train_size = int(FLAGS.train_size/worker_count) epoch = FLAGS.epoch else: train_size = int(FLAGS.train_size/worker_count) epoch = FLAGS.epoch init_lr = 2e-5 label_dict = json.load(open(FLAGS.label_id)) num_train_steps = int( train_size / FLAGS.batch_size * epoch) num_warmup_steps = int(num_train_steps * 0.1) num_storage_steps = int(train_size / FLAGS.batch_size) num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size) if is_debug == "0": num_storage_steps = 2 num_eval_steps = 10 num_train_steps = 10 print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".format(num_train_steps, num_eval_steps, num_storage_steps)) print(" model type {}".format(FLAGS.model_type)) print(num_train_steps, num_warmup_steps, "=============") opt_config = Bunch({"init_lr":init_lr/worker_count, "num_train_steps":num_train_steps, "num_warmup_steps":num_warmup_steps, "worker_count":worker_count, "opt_type":FLAGS.opt_type, "is_chief":is_chief, "train_op":kargs.get("train_op", "adam")}) anneal_config = Bunch({ "initial_value":1.0, "num_train_steps":num_train_steps }) model_io_config = Bunch({"fix_lm":False}) model_io_fn = model_io.ModelIO(model_io_config) num_classes = FLAGS.num_classes if FLAGS.opt_type == "hvd" and hvd: checkpoint_dir = checkpoint_dir if task_index == 0 else None else: checkpoint_dir = checkpoint_dir print("==checkpoint_dir==", checkpoint_dir, is_chief) # if kargs.get("rule_model", "rule"): # model_fn_interface = rule_model_fn_builder # print("==apply rule model==") # else: # model_fn_interface = model_fn_builder # print("==apply normal model==") model_fn_builder = model_fn_interface(FLAGS) model_fn = model_fn_builder(config, num_classes, init_checkpoint, model_reuse=None, load_pretrained=FLAGS.load_pretrained, model_io_config=model_io_config, opt_config=opt_config, model_io_fn=model_io_fn, exclude_scope="", not_storage_params=[], target=kargs.get("input_target", ""), output_type="estimator", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, task_index=task_index, anneal_config=anneal_config, **kargs) # name_to_features = { # "input_ids": # tf.FixedLenFeature([FLAGS.max_length], tf.int64), # "input_mask": # tf.FixedLenFeature([FLAGS.max_length], tf.int64), # "segment_ids": # tf.FixedLenFeature([FLAGS.max_length], tf.int64), # "label_ids": # tf.FixedLenFeature([], tf.int64), # } name_to_features = data_interface(FLAGS) def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example. """ example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. 
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        def _decode_batch_record(record, name_to_features):
            example = tf.parse_example(record, name_to_features)
            # for name in list(example.keys()):
            #     t = example[name]
            #     if t.dtype == tf.int64:
            #         t = tf.to_int32(t)
            #     example[name] = t
            return example

        params = Bunch({})
        params.epoch = 0
        params.batch_size = FLAGS.batch_size

        if kargs.get("run_config", None):
            if kargs.get("parse_type", "parse_single") == "parse_single":
                train_features = lambda: tf_data_utils.all_reduce_train_input_fn(
                    train_file, _decode_record, name_to_features, params,
                    if_shard=FLAGS.if_shard, worker_count=worker_count,
                    task_index=task_index)
                eval_features = lambda: tf_data_utils.all_reduce_eval_input_fn(
                    dev_file, _decode_record, name_to_features, params,
                    if_shard=FLAGS.if_shard, worker_count=worker_count,
                    task_index=task_index)
            elif kargs.get("parse_type", "parse_single") == "parse_batch":
                print("==apply parse example==")
                train_features = lambda: tf_data_utils.all_reduce_train_batch_input_fn(
                    train_file, _decode_batch_record, name_to_features, params,
                    if_shard=FLAGS.if_shard, worker_count=worker_count,
                    task_index=task_index)
                eval_features = lambda: tf_data_utils.all_reduce_eval_batch_input_fn(
                    dev_file, _decode_batch_record, name_to_features, params,
                    if_shard=FLAGS.if_shard, worker_count=worker_count,
                    task_index=task_index)
        else:
            train_features = lambda: tf_data_utils.train_input_fn(
                train_file, _decode_record, name_to_features, params,
                if_shard=FLAGS.if_shard, worker_count=worker_count,
                task_index=task_index)
            eval_features = lambda: tf_data_utils.eval_input_fn(
                dev_file, _decode_record, name_to_features, params,
                if_shard=FLAGS.if_shard, worker_count=worker_count,
                task_index=task_index)

        train_hooks = []
        eval_hooks = []

        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)

        if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            print("==no need for hook==")
        elif FLAGS.opt_type == "pai_soar" and pai:
            print("no need for hook")
        elif FLAGS.opt_type == "hvd" and hvd:
            sess_config.gpu_options.allow_growth = True
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
            print("==no need for hook==")
        else:
            print("==no need for hooks==")

        if kargs.get("run_config", None):
            run_config = kargs.get("run_config", None)
            run_config = run_config.replace(
                save_checkpoints_steps=num_storage_steps)
            print("==run config==", run_config.save_checkpoints_steps)
        else:
            run_config = tf.estimator.RunConfig(
                model_dir=checkpoint_dir,
                save_checkpoints_steps=num_storage_steps,
                session_config=sess_config)

        if kargs.get("profiler", "profiler") == "profiler":
            hooks = tf.train.ProfilerHook(
                save_steps=100,
                save_secs=None,
                output_dir=os.path.join(checkpoint_dir, "profiler"),
            )
            train_hooks.append(hooks)
            print("==add profiler hooks==")

        model_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                                 config=run_config)

        eval_results = model_estimator.evaluate(input_fn=eval_features,
                                                steps=num_eval_steps)
        print(eval_results)
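
# Illustrative helper (assumption, not part of the original code): Estimator.evaluate above
# returns a plain dict of metric name -> value, which always includes "loss" and "global_step"
# alongside whatever the model_fn registers in eval_metric_ops. One way a caller could persist
# that dict next to the checkpoints:
def _dump_eval_results(eval_results, output_dir):
    import json
    import os
    with tf.gfile.Open(os.path.join(output_dir, "eval_results.json"), "w") as fwobj:
        fwobj.write(json.dumps({key: float(value) for key, value in eval_results.items()}))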
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint, train_file, dev_file, checkpoint_dir, is_debug, **kargs): graph = tf.Graph() with graph.as_default(): import json # config = json.load(open(FLAGS.config_file, "r")) # config = Bunch(config) # config.use_one_hot_embeddings = True # config.scope = "bert" # config.dropout_prob = 0.1 # config.label_type = "single_label" # config.model = FLAGS.model_type config = model_config_parser(FLAGS) # print(config, "==model config==") if FLAGS.if_shard == "0": train_size = FLAGS.train_size epoch = int(FLAGS.epoch / worker_count) elif FLAGS.if_shard == "1": train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch else: train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch init_lr = config.init_lr label_dict = json.load(tf.gfile.Open(FLAGS.label_id)) num_train_steps = int(train_size / FLAGS.batch_size * epoch) num_warmup_steps = int(num_train_steps * 0.1) num_storage_steps = int(train_size / FLAGS.batch_size) num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size) if is_debug == "0": num_storage_steps = 190 num_eval_steps = 100 num_train_steps = 200 print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}". format(num_train_steps, num_eval_steps, num_storage_steps)) print(" model type {}".format(FLAGS.model_type)) print(num_train_steps, num_warmup_steps, "=============") opt_config = Bunch({ "init_lr": init_lr / worker_count, "num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "worker_count": worker_count, "opt_type": FLAGS.opt_type, "is_chief": is_chief, "train_op": kargs.get("train_op", "adam"), "decay": kargs.get("decay", "no"), "warmup": kargs.get("warmup", "no"), "grad_clip": config.get("grad_clip", "global_norm"), "clip_norm": config.get("clip_norm", 1.0) }) anneal_config = Bunch({ "initial_value": 1.0, "num_train_steps": num_train_steps }) model_io_config = Bunch({"fix_lm": False}) num_classes = FLAGS.num_classes if FLAGS.opt_type == "hvd" and hvd: checkpoint_dir = checkpoint_dir if task_index == 0 else None else: checkpoint_dir = checkpoint_dir print("==checkpoint_dir==", checkpoint_dir, is_chief) # if kargs.get("rule_model", "rule"): # model_fn_interface = rule_model_fn_builder # print("==apply rule model==") # else: # model_fn_interface = model_fn_builder # print("==apply normal model==") model_fn_builder = model_fn_interface(FLAGS) model_train_fn = model_fn_builder( config, num_classes, init_checkpoint, model_reuse=None, load_pretrained=FLAGS.load_pretrained, opt_config=opt_config, model_io_config=model_io_config, exclude_scope="", not_storage_params=[], target=kargs.get("input_target", ""), output_type="sess", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, task_index=task_index, anneal_config=anneal_config, **kargs) model_eval_fn = model_fn_builder(config, num_classes, init_checkpoint, model_reuse=True, load_pretrained=FLAGS.load_pretrained, opt_config=opt_config, model_io_config=model_io_config, exclude_scope="", not_storage_params=[], target=kargs.get("input_target", ""), output_type="sess", checkpoint_dir=checkpoint_dir, num_storage_steps=num_storage_steps, task_index=task_index, anneal_config=anneal_config, **kargs) print("==succeeded in building model==") def eval_metric_fn(features, eval_op_dict): logits = eval_op_dict["logits"] print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) accuracy = correct = tf.equal( tf.cast(pred_label, 
tf.int32), tf.cast(features["label_ids"], tf.int32)) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) return { "accuracy": accuracy, "loss": eval_op_dict["loss"], "pred_label": pred_label, "label_ids": features["label_ids"] } def train_metric_fn(features, train_op_dict): logits = train_op_dict["logits"] print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) accuracy = correct = tf.equal( tf.cast(pred_label, tf.int32), tf.cast(features["label_ids"], tf.int32)) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) train_op_dict["accuracy"] = accuracy # train_op_dict.pop("logits") # return {"accuracy":accuracy, "loss":train_op_dict["loss"], # "train_op":train_op_dict["train_op"]} return train_op_dict # name_to_features = { # "input_ids": # tf.FixedLenFeature([FLAGS.max_length], tf.int64), # "input_mask": # tf.FixedLenFeature([FLAGS.max_length], tf.int64), # "segment_ids": # tf.FixedLenFeature([FLAGS.max_length], tf.int64), # "label_ids": # tf.FixedLenFeature([], tf.int64), # } name_to_features = data_interface(FLAGS) def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example. """ example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.to_int32(t) example[name] = t return example def _decode_batch_record(record, name_to_features): example = tf.parse_example(record, name_to_features) # for name in list(example.keys()): # t = example[name] # if t.dtype == tf.int64: # t = tf.to_int32(t) # example[name] = t return example params = Bunch({}) params.epoch = epoch params.batch_size = FLAGS.batch_size print("==train_file==", train_file, params) if kargs.get("parse_type", "parse_single") == "parse_single": train_features = tf_data_utils.train_input_fn( train_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features = tf_data_utils.eval_input_fn( dev_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) elif kargs.get("parse_type", "parse_single") == "parse_batch": train_features = tf_data_utils.train_batch_input_fn( train_file, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features = tf_data_utils.eval_batch_input_fn( dev_file, _decode_batch_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) train_op_dict = model_train_fn(train_features, [], tf.estimator.ModeKeys.TRAIN) eval_op_dict = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL) eval_dict = eval_metric_fn(eval_features, eval_op_dict["eval"]) train_dict = train_metric_fn(train_features, train_op_dict["train"]) print("==succeeded in building data and model==") print(train_op_dict) def eval_fn(eval_dict, sess): i = 0 total_accuracy = 0 eval_total_dict = {} while True: try: eval_result = sess.run(eval_dict) for key in eval_result: if key not in eval_total_dict: if key in ["pred_label", "label_ids"]: eval_total_dict[key] = [] eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] = 0.0 eval_total_dict[key] += eval_result[key] else: if key in ["pred_label", "label_ids"]: 
eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] += eval_result[key] i += 1 if np.mod(i, num_eval_steps) == 0: break except tf.errors.OutOfRangeError: print("End of dataset") break label_id = eval_total_dict["label_ids"] pred_label = eval_total_dict["pred_label"] label_dict_id = sorted(list(label_dict["id2label"].keys())) print(len(label_id), len(pred_label), len(set(label_id))) accuracy = accuracy_score(label_id, pred_label) print("==accuracy==", accuracy) if len(label_dict["id2label"]) < 10: result = classification_report(label_id, pred_label, target_names=[ label_dict["id2label"][key] for key in label_dict_id ], digits=4) print(result, task_index) eval_total_dict["classification_report"] = result print("==classification report==") return eval_total_dict def train_fn(train_op_dict, sess): i = 0 cnt = 0 loss_dict = {} monitoring_train = [] monitoring_eval = [] while True: try: [train_result] = sess.run([train_op_dict]) for key in train_result: if key == "train_op": continue else: try: if np.isnan(train_result[key]): print(train_loss, "get nan loss") break else: if key in loss_dict: loss_dict[key] += train_result[key] else: loss_dict[key] = train_result[key] except: # if key == "student_logit": # print(train_result[key]) continue # print(pkl, "==pkl==") # if pkl: # pkl.dump(train_result, open("/data/xuht/distillation.pkl", "wb")) i += 1 cnt += 1 if np.mod(i, num_storage_steps) == 0: string = "" for key in loss_dict: tmp = key + " " + str(loss_dict[key] / cnt) + "\t" string += tmp print(string) monitoring_train.append(loss_dict) eval_finial_dict = eval_fn(eval_dict, sess) monitoring_eval.append(eval_finial_dict) for key in loss_dict: loss_dict[key] = 0.0 cnt = 0 if is_debug == "0": if i == num_train_steps: break except tf.errors.OutOfRangeError: print("==Succeeded in training model==") break return {"eval": monitoring_eval, "train": monitoring_train} print("===========begin to train============") # sess_config = tf.ConfigProto(allow_soft_placement=False, # log_device_placement=False) # # sess_config.gpu_options.visible_device_list = str(task_index) # print(sess_config.gpu_options.visible_device_list, task_index, "==============") print("start training") hooks = [] hooks.extend(train_op_dict["hooks"]) if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync": sess_config = tf.ConfigProto(allow_soft_placement=False, log_device_placement=False) print("==create monitored training session==", FLAGS.opt_type, is_chief) sess = tf.train.MonitoredTrainingSession( master=target, is_chief=is_chief, config=kargs.get("sess_config", sess_config), hooks=hooks, checkpoint_dir=checkpoint_dir, save_checkpoint_steps=num_storage_steps) elif FLAGS.opt_type == "pai_soar" and pai: sess_config = tf.ConfigProto(allow_soft_placement=False, log_device_placement=False) sess = tf.train.MonitoredTrainingSession( master=target, is_chief=is_chief, config=kargs.get("sess_config", sess_config), hooks=hooks, checkpoint_dir=checkpoint_dir, save_checkpoint_steps=num_storage_steps) elif FLAGS.opt_type == "hvd" and hvd: sess_config = tf.ConfigProto(allow_soft_placement=False, log_device_placement=False) sess_config.gpu_options.allow_growth = False sess_config.gpu_options.visible_device_list = str(hvd.local_rank()) sess = tf.train.MonitoredTrainingSession( checkpoint_dir=checkpoint_dir, hooks=hooks, config=sess_config, save_checkpoint_steps=num_storage_steps) else: print("==single sess==") sess_config = tf.ConfigProto(allow_soft_placement=False, log_device_placement=False) sess = 
tf.train.MonitoredTrainingSession(
                config=sess_config,
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)

        print("==begin to train and eval==")
        # step = sess.run(tf.train.get_global_step())
        # print(step, task_index, "==task_index, global_step==")
        monitoring_info = train_fn(train_dict, sess)

        # for i in range(10):
        #     l = sess.run(train_features)
        #     print(l, task_index)

        if task_index == 0:
            start_time = time.time()
            print("===========begin to eval============")
            eval_finial_dict = eval_fn(eval_dict, sess)
            end_time = time.time()
            print("==total forward time==", end_time - start_time)