def optimize_graph(params: Params):
    """Freeze the multi-task BERT prediction graph into a single file.

    Builds the predict-mode graph on CPU under an XLA jit scope, restores
    weights from ``params.ckpt_dir``, converts all variables to constants
    and serializes the frozen GraphDef to ``<ckpt_dir>/export_model``.

    Args:
        params: Params object; must expose ``ckpt_dir`` and ``max_seq_len``.

    Returns:
        str: path of the written frozen-graph file.
    """
    # Build and freeze on CPU only; soft placement keeps any GPU-pinned
    # ops from failing.
    config = tf.ConfigProto(
        device_count={'GPU': 0},
        allow_soft_placement=True)
    init_checkpoint = params.ckpt_dir

    tf.logging.info('build graph...')
    # input placeholders, not sure if they are friendly to XLA
    input_ids = tf.placeholder(
        tf.int32, (None, params.max_seq_len), 'input_ids')
    input_mask = tf.placeholder(
        tf.int32, (None, params.max_seq_len), 'input_mask')
    input_type_ids = tf.placeholder(
        tf.int32, (None, params.max_seq_len), 'segment_ids')

    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    with jit_scope():
        # Dict literal instead of piecewise assignment.
        features = {
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': input_type_ids,
        }
        model = BertMultiTask(params)
        hidden_feature = model.body(features, tf.estimator.ModeKeys.PREDICT)
        pred = model.top(features, hidden_feature,
                         tf.estimator.ModeKeys.PREDICT)

        # One output tensor per problem head.
        output_tensors = list(pred.values())
        tvars = tf.trainable_variables()

        # Only the assignment map is needed; the list of initialized
        # variable names returned alongside it was unused.
        assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
            tvars, init_checkpoint)

        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tmp_g = tf.get_default_graph().as_graph_def()

    with tf.Session(config=config) as sess:
        tf.logging.info('load parameters from checkpoint...')
        sess.run(tf.global_variables_initializer())
        tf.logging.info('freeze...')
        # Tensor names look like 'op_name:0'; strip the ':0' output index
        # to obtain the node names convert_variables_to_constants expects.
        tmp_g = tf.graph_util.convert_variables_to_constants(
            sess, tmp_g, [n.name[:-2] for n in output_tensors])

    tmp_file = os.path.join(params.ckpt_dir, 'export_model')
    tf.logging.info('write graph to a tmp file: %s' % tmp_file)
    with tf.gfile.GFile(tmp_file, 'wb') as f:
        f.write(tmp_g.SerializeToString())
    return tmp_file
def train_problem(params, problem, gpu=4, base='baseline'):
    """Train the multi-task BERT model on ``problem`` and return the Estimator.

    Args:
        params: Params object to configure for the given problem.
        problem: problem-spec string understood by ``params.assign_problem``.
        gpu: number of GPUs to mirror training across.
        base: sub-directory name under ``tmp/`` used as the base dir.

    Returns:
        The tf.estimator.Estimator after ``params.train_steps`` steps.
    """
    tf.keras.backend.clear_session()

    # makedirs(..., exist_ok=True) avoids the check-then-create race of
    # the original os.path.exists + os.mkdir pair.
    os.makedirs('tmp', exist_ok=True)
    base = os.path.join('tmp', base)

    params.assign_problem(problem, gpu=int(gpu), base_dir=base)
    create_path(params.ckpt_dir)
    tf.logging.info('Checkpoint dir: %s' % params.ckpt_dir)
    # presumably gives the log line time to flush / the dir to settle —
    # TODO confirm whether this sleep is still needed
    time.sleep(3)

    model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=False)

    # NCCL all-reduce across `gpu` towers (local name fixes the original
    # 'dist_trategy' misspelling).
    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=int(gpu),
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossTowerOps(
            'nccl', num_packs=int(gpu)))

    run_config = tf.estimator.RunConfig(
        train_distribute=dist_strategy,
        eval_distribute=dist_strategy,
        log_step_count_steps=params.log_every_n_steps)

    estimator = Estimator(
        model_fn,
        model_dir=params.ckpt_dir,
        params=params,
        config=run_config)
    # Restores pretrained weights at the start of training.
    train_hook = RestoreCheckpointHook(params)

    def train_input_fn():
        return train_eval_input_fn(params)

    estimator.train(
        train_input_fn, max_steps=params.train_steps, hooks=[train_hook])
    return estimator
def main(_):
    """Entry point: configure Params from FLAGS, build an Estimator, and run
    the schedule selected by FLAGS.schedule ('train', 'eval' or 'predict')."""
    if not os.path.exists('tmp'):
        os.mkdir('tmp')
    # Split a user-supplied model dir into base dir + run name; otherwise
    # let Params choose its defaults.
    if FLAGS.model_dir:
        base_dir, dir_name = os.path.split(FLAGS.model_dir)
    else:
        base_dir, dir_name = None, None
    params = Params()
    params.assign_problem(FLAGS.problem, gpu=int(FLAGS.gpu),
                          base_dir=base_dir, dir_name=dir_name)
    tf.logging.info('Checkpoint dir: %s' % params.ckpt_dir)
    time.sleep(3)
    model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=False)
    # NCCL all-reduce mirrored training across FLAGS.gpu devices.
    # NOTE(review): local name keeps the original 'dist_trategy' spelling.
    dist_trategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=int(FLAGS.gpu),
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossTowerOps(
            'nccl', num_packs=int(FLAGS.gpu)))
    run_config = tf.estimator.RunConfig(
        train_distribute=dist_trategy,
        eval_distribute=dist_trategy,
        log_step_count_steps=params.log_every_n_steps)
    # ws = make_warm_start_setting(params)
    estimator = Estimator(
        model_fn,
        model_dir=params.ckpt_dir,
        params=params,
        config=run_config)
    if FLAGS.schedule == 'train':
        # Train to params.train_steps, then run one evaluation pass.
        train_hook = RestoreCheckpointHook(params)

        def train_input_fn():
            return train_eval_input_fn(params)
        estimator.train(
            train_input_fn,
            max_steps=params.train_steps,
            hooks=[train_hook])

        def input_fn():
            return train_eval_input_fn(params, mode='eval')
        estimator.evaluate(input_fn=input_fn)
    elif FLAGS.schedule == 'eval':
        # Dispatch to metrics.<FLAGS.eval_scheme>_evaluate by name.
        evaluate_func = getattr(metrics, FLAGS.eval_scheme + '_evaluate')
        print(evaluate_func(FLAGS.problem, estimator, params))
    elif FLAGS.schedule == 'predict':
        # Demo prediction on a hard-coded Chinese restaurant review.
        def input_fn():
            return predict_input_fn([
                '''兰心餐厅\n作为一个无辣不欢的妹子,对上海菜的偏清淡偏甜真的是各种吃不惯。 每次出门和闺蜜越饭局都是避开本帮菜。后来听很多朋友说上海有几家特别正宗味道做 的很好的餐厅于是这周末和闺蜜们准备一起去尝一尝正宗的本帮菜。\n进贤路是我在上 海比较喜欢的一条街啦,这家餐厅就开在这条路上。已经开了三十多年的老餐厅了,地 方很小,就五六张桌子。但是翻桌率比较快。二楼之前的居民间也改成了餐厅,但是在 上海的名气却非常大。烧的就是家常菜,普通到和家里烧的一样,生意非常好,外面排 队的比里面吃的人还要多。'''
            ], params, mode='predict')
        pred = estimator.predict(input_fn=input_fn)
        for p in pred:
            print(p)
def optimize_graph(params):
    """Freeze and optimize the multi-task BERT prediction graph.

    Like the plain freeze pass, but additionally runs the Graph Transform
    Tool over the frozen GraphDef (constant folding, batch-norm folding,
    duplicate-node merging, dead-node stripping) before writing it out to
    ``<ckpt_dir>/export_model``.

    NOTE(review): this redefines ``optimize_graph`` from earlier in the
    file and shadows it; keep only one version.

    Args:
        params: Params object; must expose ``ckpt_dir``, ``max_seq_len``,
            ``share_top`` and ``problem_list``.

    Returns:
        str: path of the written frozen-graph file.
    """
    # Build and freeze on CPU only.
    config = tf.ConfigProto(
        device_count={'GPU': 0},
        allow_soft_placement=True)
    init_checkpoint = params.ckpt_dir

    tf.logging.info('build graph...')
    # input placeholders, not sure if they are friendly to XLA
    input_ids = tf.placeholder(
        tf.int32, (None, params.max_seq_len), 'input_ids')
    input_mask = tf.placeholder(
        tf.int32, (None, params.max_seq_len), 'input_mask')
    input_type_ids = tf.placeholder(
        tf.int32, (None, params.max_seq_len), 'segment_ids')

    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    with jit_scope():
        features = {
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': input_type_ids,
        }
        model = BertMultiTask(params)
        hidden_feature = model.body(features, tf.estimator.ModeKeys.PREDICT)
        pred = model.top(features, hidden_feature,
                         tf.estimator.ModeKeys.PREDICT)

        output_tensors = list(pred.values())
        tvars = tf.trainable_variables()

        # Only the assignment map is needed; the list of initialized
        # variable names returned alongside it was unused.
        assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
            tvars, init_checkpoint)

        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tmp_g = tf.get_default_graph().as_graph_def()

    input_node_names = ['input_ids', 'input_mask', 'segment_ids']
    # One '<top>_top/<top>_predict' output node per problem head.
    output_node_names = [
        '%s_top/%s_predict' % (params.share_top[problem],
                               params.share_top[problem])
        for problem in params.problem_list
    ]
    # Graph Transform Tool passes; quantization is intentionally disabled.
    transforms = [
        'remove_nodes(op=Identity)',
        'fold_constants(ignore_errors=true)',
        'fold_batch_norms',
        # 'quantize_weights',
        # 'quantize_nodes',
        'merge_duplicate_nodes',
        'strip_unused_nodes',
        'sort_by_execution_order'
    ]

    with tf.Session(config=config) as sess:
        tf.logging.info('load parameters from checkpoint...')
        sess.run(tf.global_variables_initializer())
        tf.logging.info('freeze...')
        # Tensor names look like 'op_name:0'; strip the ':0' output index.
        tmp_g = tf.graph_util.convert_variables_to_constants(
            sess, tmp_g, [n.name[:-2] for n in output_tensors])
        tmp_g = TransformGraph(
            tmp_g, input_node_names, output_node_names, transforms)

    tmp_file = os.path.join(params.ckpt_dir, 'export_model')
    tf.logging.info('write graph to a tmp file: %s' % tmp_file)
    with tf.gfile.GFile(tmp_file, 'wb') as f:
        f.write(tmp_g.SerializeToString())
    return tmp_file