(assignment_map, initialized_variable_names ) = modeling.get_assignment_map_from_checkpoint( tvars, init_checkpoint) tf.train.init_from_checkpoint(init_checkpoint, assignment_map) tmp_g = tf.get_default_graph().as_graph_def() with tf.Session(config=config) as sess: tf.logging.info('load parameters from checkpoint...') sess.run(tf.global_variables_initializer()) tf.logging.info('freeze...') tmp_g = tf.graph_util.convert_variables_to_constants( sess, tmp_g, [n.name[:-2] for n in output_tensors]) tmp_file = os.path.join(params.ckpt_dir, 'export_model') tf.logging.info('write graph to a tmp file: %s' % tmp_file) with tf.gfile.GFile(tmp_file, 'wb') as f: f.write(tmp_g.SerializeToString()) return tmp_file if __name__ == "__main__": if FLAGS.model_dir: base_dir, dir_name = os.path.split(FLAGS.model_dir) else: base_dir, dir_name = None, None params = Params() params.assign_problem(FLAGS.problem, base_dir=base_dir, dir_name=dir_name) optimize_graph(params) params.to_json()
def main(_):
    """Build the multi-task BERT estimator and run the requested schedule.

    Reads ``FLAGS.model_dir``, ``FLAGS.problem``, ``FLAGS.gpu``,
    ``FLAGS.schedule`` and, for the ``eval`` schedule, ``FLAGS.eval_scheme``.

    Schedules:
      * ``train``   -- train up to ``params.train_steps`` with a
                       ``RestoreCheckpointHook``, then evaluate once on the
                       eval input.
      * ``eval``    -- call ``metrics.<eval_scheme>_evaluate`` and print its
                       result.
      * ``predict`` -- predict on one hard-coded sample text and print every
                       prediction.

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run`` -- confirm with the caller).

    Raises:
        ValueError: if ``FLAGS.schedule`` is not one of the schedules above.
    """
    # Fail fast on a mistyped schedule instead of building the whole
    # model/estimator (and sleeping) only to silently do nothing.
    valid_schedules = ('train', 'eval', 'predict')
    if FLAGS.schedule not in valid_schedules:
        raise ValueError(
            'Unknown schedule %r; expected one of: %s'
            % (FLAGS.schedule, ', '.join(valid_schedules)))

    if not os.path.exists('tmp'):
        os.mkdir('tmp')

    # With no model_dir, None is passed through and the checkpoint location
    # is left to Params.assign_problem (presumably a default -- confirm).
    if FLAGS.model_dir:
        base_dir, dir_name = os.path.split(FLAGS.model_dir)
    else:
        base_dir, dir_name = None, None

    num_gpus = int(FLAGS.gpu)
    params = Params()
    params.assign_problem(FLAGS.problem, gpu=num_gpus,
                          base_dir=base_dir, dir_name=dir_name)
    tf.logging.info('Checkpoint dir: %s', params.ckpt_dir)
    # NOTE(review): fixed pause before building the model; its purpose is not
    # visible here (perhaps to let the user read the log line) -- confirm.
    time.sleep(3)

    model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=False)

    # Mirror the model over `num_gpus` devices, all-reducing with NCCL.
    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=num_gpus,
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossTowerOps(
            'nccl', num_packs=num_gpus))
    run_config = tf.estimator.RunConfig(
        train_distribute=dist_strategy,
        eval_distribute=dist_strategy,
        log_step_count_steps=params.log_every_n_steps)
    estimator = Estimator(model_fn, model_dir=params.ckpt_dir,
                          params=params, config=run_config)

    if FLAGS.schedule == 'train':
        train_hook = RestoreCheckpointHook(params)

        def train_input_fn():
            return train_eval_input_fn(params)

        def eval_input_fn():
            return train_eval_input_fn(params, mode='eval')

        estimator.train(train_input_fn, max_steps=params.train_steps,
                        hooks=[train_hook])
        # Evaluate once after training finishes.
        estimator.evaluate(input_fn=eval_input_fn)
    elif FLAGS.schedule == 'eval':
        # Resolve the evaluation routine by name, e.g. `<scheme>_evaluate`.
        evaluate_func = getattr(metrics, FLAGS.eval_scheme + '_evaluate')
        print(evaluate_func(FLAGS.problem, estimator, params))
    elif FLAGS.schedule == 'predict':
        # Hard-coded demo input: a single Chinese restaurant review.
        sample_text = '''兰心餐厅\n作为一个无辣不欢的妹子,对上海菜的偏清淡偏甜真的是各种吃不惯。 每次出门和闺蜜越饭局都是避开本帮菜。后来听很多朋友说上海有几家特别正宗味道做 的很好的餐厅于是这周末和闺蜜们准备一起去尝一尝正宗的本帮菜。\n进贤路是我在上 海比较喜欢的一条街啦,这家餐厅就开在这条路上。已经开了三十多年的老餐厅了,地 方很小,就五六张桌子。但是翻桌率比较快。二楼之前的居民间也改成了餐厅,但是在 上海的名气却非常大。烧的就是家常菜,普通到和家里烧的一样,生意非常好,外面排 队的比里面吃的人还要多。'''

        def pred_input_fn():
            return predict_input_fn([sample_text], params, mode='predict')

        for prediction in estimator.predict(input_fn=pred_input_fn):
            print(prediction)