def model_fn(features, labels, mode, params):
    """Estimator model_fn for a BERT + BiLSTM-CRF sequence labeler.

    Relies on closure variables from the enclosing builder (not visible in
    this chunk): bert_config, num_labels, init_checkpoint, learning_rate,
    num_train_steps, num_warmup_steps, args, and the create_model /
    modeling / optimization helpers.

    Args:
        features: dict of input tensors — "input_ids", "input_mask",
            "segment_ids", "label_ids".
        labels: unused; labels arrive via features["label_ids"].
        mode: a tf.estimator.ModeKeys value.
        params: unused estimator params dict.

    Returns:
        tf.estimator.EstimatorSpec for TRAIN, EVAL, or PREDICT.
    """
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    print('shape of input_ids', input_ids.shape)
    # label_mask = features["label_mask"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # Build the model: input_ids are the token-id inputs, label_ids the
    # tag-id targets.  The False positional arg disables one-hot embeddings.
    total_loss, logits, trans, pred_ids = create_model(
        bert_config, is_training, input_ids, input_mask, segment_ids,
        label_ids, num_labels, False, args.dropout_rate, args.lstm_size,
        args.cell, args.num_layers)
    tvars = tf.trainable_variables()
    # Load pretrained BERT weights into every variable whose name matches
    # the checkpoint (init_from_checkpoint rewrites their initializers).
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = \
            modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(
            total_loss, learning_rate, num_train_steps, num_warmup_steps,
            False)
        # Log loss and global step every `save_summary_steps` iterations.
        hook_dict = {}
        hook_dict['loss'] = total_loss
        hook_dict['global_steps'] = tf.train.get_or_create_global_step()
        logging_hook = tf.train.LoggingTensorHook(
            hook_dict, every_n_iter=args.save_summary_steps)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
        # NOTE(review): MSE over integer tag ids is a crude proxy for NER
        # quality; token-level precision/recall/F1 would be more informative.
        def metric_fn(label_ids, pred_ids):
            return {
                "eval_loss":
                    tf.metrics.mean_squared_error(labels=label_ids,
                                                  predictions=pred_ids),
            }
        eval_metrics = metric_fn(label_ids, pred_ids)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metric_ops=eval_metrics
        )
    else:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=pred_ids
        )
    return output_spec
def main():
    """Train/test driver for the adversarial BiLSTM-CRF NER generator.

    Loads the tag vocabulary, builds the generator graph, restores BERT
    weights from `init_checkpoint`, then either runs cross-entropy training
    (args.mode == 'train') or evaluates a saved checkpoint
    (args.mode == 'test').  Relies on module-level names not visible in this
    chunk: args, batch_size, params, filter_sizes, num_filters,
    init_checkpoint, model_path, Generator_BiLSTM_CRF, and the read/batch
    helper functions.
    """
    # Load the tag vocabulary; `length` is the number of distinct tags.
    ap = []
    with open('../../../medical_char_data_cleaned/vocab.tags.txt', 'r') as fin:
        for line in fin:
            ap.append(line.strip())
        fin.close()  # NOTE(review): redundant — the `with` block already closes the file
    length = len(ap)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.625)
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=gpu_options))
    generator = Generator_BiLSTM_CRF(0.5, 1, batch_size, params, filter_sizes,
                                     num_filters, 0.75, length)
    generator.build_graph()
    tvars = tf.trainable_variables()
    # Map pretrained BERT variables onto this graph; init_from_checkpoint
    # rewrites their initializers, so it must precede the init ops below.
    (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    # Initialize variables last, so the checkpoint overrides take effect.
    sess.run(generator.init_op)
    sess.run(generator.table_op)
    sess.run(generator.init_op_1)
    saver = tf.train.Saver(tf.global_variables())
    # Assemble corpus paths and materialize all batch generators up front.
    train_path = os.path.join('.', args.train_data, 'train_data1')
    train_unlabel_path = os.path.join('.', args.train_data_unlabel,
                                      'train_unlabel')
    train_unlabel_path_1 = os.path.join('.', args.train_data_unlabel,
                                        'train_unlabel1')
    test_path = os.path.join('.', args.test_data, 'test_data1')
    sub_test_path = os.path.join('.', args.sub_test_data, 'sub_test_data')
    train_data = read_corpus(train_path)
    train_data_unlabel = read_corpus_unlabel(train_unlabel_path)
    train_data_unlabel_1 = read_corpus_unlabel(train_unlabel_path_1)
    test_data = read_corpus(test_path)
    test_size = len(test_data)
    sub_test_data = read_corpus(sub_test_path)
    batches_labeled = batch_yield(train_data, batch_size, shuffle=True)
    batches_labeled = list(batches_labeled)
    # Ceiling division: number of labeled batches per epoch.
    num_batches = (len(train_data) + batch_size - 1) // batch_size
    batches_unlabeled = batch_yield_for_unla_da(train_data_unlabel,
                                                batch_size, shuffle=True)
    batches_unlabeled = list(batches_unlabeled)
    batches_labeled_for_dis = batch_yield_for_discri(train_data, batch_size,
                                                     shuffle=True)
    batches_labeled_for_dis = list(batches_labeled_for_dis)
    batches_unlabeled_for_dis = batch_yield_for_discri_unlabeled(
        train_data_unlabel, batch_size, shuffle=True)
    batches_unlabeled_for_dis = list(batches_unlabeled_for_dis)
    dev = batch_yield(test_data, batch_size, shuffle=True)
    num_batches_unlabel = (len(train_data_unlabel) + batch_size - 1) // batch_size
    num_batches_1 = min(len(batches_labeled_for_dis),
                        len(batches_unlabeled_for_dis))
    index = 0
    if args.mode == 'train':
        for epoch_total in range(30):
            print('epoch_total and index are {} and {}'.format(
                epoch_total + 1, index))
            # Evaluate before training to record the starting accuracy.
            medi_lis = get_metrics(sess, generator, dev, test_size,
                                   batch_size, flag=0)
            for ele in medi_lis:
                print('实体识别的', ele)
            print('the whole epoch training accuracy finished!!!!!!!!!!!!')
            # Supervised cross-entropy phase (label=0 selects that mode in
            # run_one_epoch).  The adversarial discriminator/generator phases
            # that followed here are currently disabled (commented out in the
            # original source).
            for i, (words, labels) in enumerate(batches_labeled):
                run_one_epoch(sess, words, labels, tags=[], dev=test_data,
                              epoch=epoch_total, gen=generator,
                              num_batches=num_batches, batch=i, label=0,
                              it=0, iteration=0, saver=saver)
            dev1 = batch_yield(test_data, batch_size, shuffle=True)
            medi_lis_from_cross_entropy_training = get_metrics(sess, generator,
                                                               dev1, test_size,
                                                               batch_size,
                                                               flag=0)
            for ele in medi_lis_from_cross_entropy_training:
                print('第一次', ele)
            print(
                'the accuray after cross entropy training finished!!!!!!!!!!!!!!!!!!1'
            )
    if args.mode == 'test':
        sub_dev = batch_yield_for_discri_unlabeled(sub_test_data, batch_size,
                                                   shuffle=True)
        # NOTE(review): ckpt_file is computed but never passed to the
        # generator/saver — confirm the restore actually happens elsewhere.
        ckpt_file = tf.train.latest_checkpoint(model_path)
        # NOTE(review): this constructor call passes one fewer positional
        # argument than the training-mode call above (no `1` after 0.5);
        # confirm Generator_BiLSTM_CRF's signature accepts both forms.
        generator = Generator_BiLSTM_CRF(0.5, batch_size, params,
                                         filter_sizes, num_filters, 0.75,
                                         length, is_training=False)
        generator.build_graph()
        generator.test(sess, sub_dev, test_size, 20)
def train(self):
    """Run the full training loop.

    Builds the train/dev data pipelines (BERT-tokenized or plain, depending
    on ARGS.bert), creates the model, restores the latest checkpoint if one
    exists (otherwise initializes fresh and loads pretrained BERT weights),
    then trains for self.max_epoch epochs, evaluating ORG and PER entities
    and checkpointing after each epoch.
    """
    if ARGS.bert:
        from bert_data_utils import BertDataUtils
        tokenizer = tokenization.FullTokenizer(vocab_file=ARGS.vocab_dir, )
        self.train_data = BertDataUtils(tokenizer, batch_size=1)
        self.dev_data = BertDataUtils(tokenizer, batch_size=20)
        self.dev_batch = self.dev_data.iteration()
    else:
        from data_utils import DataBatch
        self.train_data = DataBatch(data_type='train', batch_size=1)
        data = {
            "batch_size": self.train_data.batch_size,
            "input_size": self.train_data.input_size,
            "vocab": self.train_data.vocab,
            "tag_map": self.train_data.tag_map,
        }
        # Persist the data map so inference can rebuild the same vocab/tags.
        # (with-block replaces manual open/close so the file is closed even
        # if pickling raises.)
        with open("data/data_map.pkl", "wb") as f:
            cPickle.dump(data, f)
        self.vocab = self.train_data.vocab
        self.input_size = len(self.vocab.values()) + 1
        self.dev_data = DataBatch(data_type='dev', batch_size=300)
        self.dev_batch = self.dev_data.iteration()
    self.nums_tags = len(self.train_data.tag_map.keys())
    self.tag_map = self.train_data.tag_map
    self.train_length = len(self.train_data.data)
    print("-" * 50)
    print("train data:\t", self.train_length)
    print("nums of tags:\t", self.nums_tags)
    self.__creat_model()
    with tf.Session() as sess:
        with tf.device("/gpu:0"):
            ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
            if ckpt and tf.train.checkpoint_exists(
                    ckpt.model_checkpoint_path):
                print("restore model")
                self.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                tvars = tf.trainable_variables()
                (assignment_map, initialized_variable_names) = \
                    modeling.get_assignment_map_from_checkpoint(tvars,
                                                                ARGS.init_checkpoint)
                # Bug fix: init_from_checkpoint only rewrites variable
                # *initializers*, so it must run BEFORE
                # global_variables_initializer — in the original order the
                # pretrained BERT weights were never actually loaded.
                tf.train.init_from_checkpoint(ARGS.init_checkpoint,
                                              assignment_map)
                sess.run(tf.global_variables_initializer())
                for var in tvars:
                    init_string = ""
                    if var.name in initialized_variable_names:
                        init_string = ", *INIT_FROM_CKPT*"
                    # Bug fix: print() does not %-interpolate trailing
                    # arguments the way logging does; format explicitly.
                    print(" name = %s, shape = %s%s" %
                          (var.name, var.shape, init_string))
            for i in range(self.max_epoch):
                print("-" * 50)
                print("epoch {}".format(i))
                steps = 0
                for batch in self.train_data.get_batch():
                    steps += 1
                    if ARGS.bert:
                        global_steps, loss, logits, acc, length = self.bert_step(
                            sess, batch)
                    else:
                        global_steps, loss, logits, acc, length = self.step(
                            sess, batch)
                    # Log every step (the original `steps % 1 == 0` guard
                    # was a no-op, so this preserves behavior).
                    print("[->] step {}/{}\tloss {:.2f}\tacc {:.2f}".format(
                        steps, len(self.train_data.batch_data), loss, acc))
                if ARGS.bert:
                    self.bert_evaluate(sess, "ORG")
                    self.bert_evaluate(sess, "PER")
                else:
                    self.evaluate(sess, "ORG")
                    self.evaluate(sess, "PER")
                self.saver.save(sess, self.checkpoint_path)
def optimize_bert_graph(args, logger=None):
    """Build a BERT pooling graph, freeze it, and write it to a .pb file.

    Constructs an inference-only BertModel with the configured pooling
    strategy, restores weights from the (possibly fine-tuned) checkpoint,
    optimizes the graph for inference, freezes variables to constants, and
    serializes the result to `<args.model_pb_dir>/bert_model.pb`.

    Args:
        args: namespace with model_pb_dir, model_dir, config_name,
            tuned_model_dir, bert_model_dir, ckpt_name, max_seq_len, xla,
            fp16, pooling_layer, pooling_strategy, verbose.
        logger: optional logger; a default colored one is created if absent.

    Returns:
        Path to the frozen .pb file, or None if optimization failed.
    """
    if not logger:
        logger = set_logger(colored('GRAPHOPT', 'cyan'), args.verbose)
    try:
        if not os.path.exists(args.model_pb_dir):
            os.mkdir(args.model_pb_dir)
        pb_file = os.path.join(args.model_pb_dir, 'bert_model.pb')
        if os.path.exists(pb_file):
            # Already optimized — reuse the cached graph.
            return pb_file
        # we don't need GPU for optimizing the graph
        tf = import_tf(verbose=args.verbose)
        from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
        config = tf.ConfigProto(device_count={'GPU': 0},
                                allow_soft_placement=True)
        config_fp = os.path.join(args.model_dir, args.config_name)
        # A fine-tuned checkpoint, when given, overrides the pretrained one.
        init_checkpoint = os.path.join(
            args.tuned_model_dir or args.bert_model_dir, args.ckpt_name)
        if args.fp16:
            logger.warning('fp16 is turned on! '
                           'Note that not all CPU GPU support fast fp16 instructions, '
                           'worst case you will have degraded performance!')
        logger.info('model config: %s' % config_fp)
        logger.info(
            'checkpoint%s: %s' % (
                ' (override by the fine-tuned model)'
                if args.tuned_model_dir else '', init_checkpoint))
        with tf.gfile.GFile(config_fp, 'r') as f:
            bert_config = modeling.BertConfig.from_dict(json.load(f))
        logger.info('build graph...')
        # input placeholders, not sure if they are friendly to XLA
        input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len),
                                   'input_ids')
        input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len),
                                    'input_mask')
        input_type_ids = tf.placeholder(tf.int32, (None, args.max_seq_len),
                                        'input_type_ids')
        jit_scope = tf.contrib.compiler.jit.experimental_jit_scope if args.xla else contextlib.suppress
        with jit_scope():
            input_tensors = [input_ids, input_mask, input_type_ids]
            model = modeling.BertModel(config=bert_config,
                                       is_training=False,
                                       input_ids=input_ids,
                                       input_mask=input_mask,
                                       token_type_ids=input_type_ids,
                                       use_one_hot_embeddings=False)
            tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            # Masking helpers: minus_mask pushes padded positions to -inf
            # before a max; mul_mask zeroes them before a mean.
            minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30
            mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
            masked_reduce_max = lambda x, m: tf.reduce_max(minus_mask(x, m), axis=1)
            masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                    tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)
            with tf.variable_scope("pooling"):
                if len(args.pooling_layer) == 1:
                    encoder_layer = model.all_encoder_layers[args.pooling_layer[0]]
                else:
                    all_layers = [model.all_encoder_layers[l]
                                  for l in args.pooling_layer]
                    encoder_layer = tf.concat(all_layers, -1)
                input_mask = tf.cast(input_mask, tf.float32)
                if args.pooling_strategy == PoolingStrategy.REDUCE_MEAN:
                    pooled = masked_reduce_mean(encoder_layer, input_mask)
                elif args.pooling_strategy == PoolingStrategy.REDUCE_MAX:
                    pooled = masked_reduce_max(encoder_layer, input_mask)
                elif args.pooling_strategy == PoolingStrategy.REDUCE_MEAN_MAX:
                    pooled = tf.concat([masked_reduce_mean(encoder_layer, input_mask),
                                        masked_reduce_max(encoder_layer, input_mask)],
                                       axis=1)
                elif args.pooling_strategy == PoolingStrategy.FIRST_TOKEN or \
                        args.pooling_strategy == PoolingStrategy.CLS_TOKEN:
                    pooled = tf.squeeze(encoder_layer[:, 0:1, :], axis=1)
                elif args.pooling_strategy == PoolingStrategy.LAST_TOKEN or \
                        args.pooling_strategy == PoolingStrategy.SEP_TOKEN:
                    # Gather the hidden state of the last non-padded token
                    # of each sequence.
                    seq_len = tf.cast(tf.reduce_sum(input_mask, axis=1), tf.int32)
                    rng = tf.range(0, tf.shape(seq_len)[0])
                    indexes = tf.stack([rng, seq_len - 1], 1)
                    pooled = tf.gather_nd(encoder_layer, indexes)
                elif args.pooling_strategy == PoolingStrategy.NONE:
                    pooled = mul_mask(encoder_layer, input_mask)
                else:
                    raise NotImplementedError()
            if args.fp16:
                pooled = tf.cast(pooled, tf.float16)
            pooled = tf.identity(pooled, 'final_encodes')
            output_tensors = [pooled]
            tmp_g = tf.get_default_graph().as_graph_def()
        with tf.Session(config=config) as sess:
            logger.info('load parameters from checkpoint...')
            sess.run(tf.global_variables_initializer())
            dtypes = [n.dtype for n in input_tensors]
            logger.info('optimize...')
            tmp_g = optimize_for_inference(
                tmp_g,
                [n.name[:-2] for n in input_tensors],
                [n.name[:-2] for n in output_tensors],
                [dtype.as_datatype_enum for dtype in dtypes],
                False)
            logger.info('freeze...')
            tmp_g = convert_variables_to_constants(
                sess, tmp_g, [n.name[:-2] for n in output_tensors],
                use_fp16=args.fp16)
        logger.info('write graph to a tmp file: %s' % args.model_pb_dir)
        with tf.gfile.GFile(pb_file, 'wb') as f:
            f.write(tmp_g.SerializeToString())
        # Bug fix: the success path previously fell off the end and returned
        # None, while the cached-file path above returned pb_file; callers
        # expecting the path would break on a fresh build.
        return pb_file
    except Exception:
        logger.error('fail to optimize the graph!', exc_info=True)
def model_fn(features, labels, mode, params):
    """Estimator model_fn for a BERT sequence classifier.

    Relies on closure variables from the enclosing builder (not visible in
    this chunk): bert_config, num_labels, init_checkpoint, learning_rate,
    num_train_steps, num_warmup_steps, logger, and the create_model /
    modeling / optimization helpers.

    Args:
        features: dict of input tensors; only "input_ids" is required for
            PREDICT, while TRAIN/EVAL additionally need "input_mask",
            "segment_ids", and "label_ids".
        labels: unused; labels arrive via features["label_ids"].
        mode: a tf.estimator.ModeKeys value.
        params: unused estimator params dict.

    Returns:
        tf.estimator.EstimatorSpec for TRAIN, EVAL, or PREDICT.
    """
    logger.info("*** Features ***")
    for name in sorted(features.keys()):
        logger.info(" name = %s, shape = %s" % (name, features[name].shape))
    input_ids = features["input_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # PREDICT requests may carry only input_ids; the remaining inputs are
    # optional and default to None inside create_model.
    if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
    else:
        input_mask = None
        segment_ids = None
        label_ids = None
    (total_loss, per_example_loss, logits,
     probabilities) = create_model(bert_config, is_training, input_ids, mode,
                                   input_mask, segment_ids, label_ids,
                                   num_labels)
    # resort variable from checkpoint file to init current graph
    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    init_fn = None
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = \
            modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    # Log variable names, marking those restored from the checkpoint.
    logger.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        logger.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                 num_train_steps,
                                                 num_warmup_steps,
                                                 use_tpu=False)
        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 loss=total_loss,
                                                 train_op=train_op)
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, label_ids, logits):
            # Accuracy against the argmax prediction; mean per-example loss.
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            accuracy = tf.metrics.accuracy(label_ids, predictions)
            loss = tf.metrics.mean(per_example_loss)
            return {
                "eval_accuracy": accuracy,
                "eval_loss": loss,
            }
        eval_metrics = metric_fn(per_example_loss, label_ids, logits)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metric_ops=eval_metrics
        )
    else:
        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 predictions=probabilities)
    return output_spec
def model_fn(features, labels, mode, params):
    """Estimator model_fn for a BERT + BiLSTM-CRF sequence labeler (FLAGS variant).

    Relies on closure variables from the enclosing builder (not visible in
    this chunk): bert_config, num_labels, init_checkpoint, learning_rate,
    num_train_steps, num_warmup_steps, FLAGS.

    Args:
        features: dict of input tensors — "input_ids", "input_mask",
            "segment_ids", "label_ids".
        labels: unused; labels arrive via features["label_ids"].
        mode: a tf.estimator.ModeKeys value.
        params: unused estimator params dict.

    Returns:
        tf.estimator.EstimatorSpec for TRAIN, EVAL, or PREDICT.
    """
    logger.info('*** Features ***')
    for name in sorted(features.keys()):
        logger.info(' name = %s, shape = %s' % (name, features[name].shape))
    input_ids = features['input_ids']
    input_mask = features['input_mask']
    segment_ids = features['segment_ids']
    label_ids = features['label_ids']
    print('shape of input ids', input_ids.shape)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # Build the model: input_ids are the token-id inputs, label_ids the
    # tag-id targets.
    total_loss, logits, trans, pred_ids = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        labels=label_ids,
        num_labels=num_labels,
        use_one_hot_embeddings=False,
        dropout_rate=FLAGS.dropout_rate,
        lstm_size=FLAGS.lstm_size,
        cell=FLAGS.cell,
        num_layers=FLAGS.num_layers)
    # tf.trainable_variables() returns only the variables marked trainable
    # (vs. tf.all_variables(), which returns every variable).
    tvars = tf.trainable_variables()
    # Load the pretrained BERT weights into matching variables.
    if init_checkpoint:
        (assigment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
            tvars=tvars, init_checkpoint=init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assigment_map)
    output_spec = None
    # Three cases: train, eval, predict.
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(
            loss=total_loss,
            init_lr=learning_rate,
            num_train_steps=num_train_steps,
            num_warmup_steps=num_warmup_steps,
            use_tpu=False)
        # Log loss and global step every `save_summary_steps` iterations.
        hook_dict = dict()
        hook_dict['loss'] = total_loss
        hook_dict['global_steps'] = tf.train.get_or_create_global_step()
        logging_hook = tf.train.LoggingTensorHook(
            hook_dict, every_n_iter=FLAGS.save_summary_steps)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
        # NOTE(review): MSE over integer tag ids is a crude proxy for NER
        # quality; token-level precision/recall/F1 would be more informative.
        def metric_fn(label_ids, pred_ids):
            return {
                'eval_loss':
                    tf.metrics.mean_squared_error(labels=label_ids,
                                                  predictions=pred_ids)
            }
        eval_metrics = metric_fn(label_ids, pred_ids)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metric_ops=eval_metrics)
    else:
        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 predictions=pred_ids)
    return output_spec
def model_fn(features, labels, mode, params):
    """Estimator model_fn for a BERT + BiLSTM-CRF labeler with tf_metrics eval.

    Relies on closure variables from the enclosing builder (not visible in
    this chunk): bert_config, num_labels, init_checkpoint, learning_rate,
    num_train_steps, num_warmup_steps, args, tf_metrics.

    Args:
        features: dict of input tensors — "input_ids", "input_mask",
            "segment_ids", "label_ids".
        labels: unused; labels arrive via features["label_ids"].
        mode: a tf.estimator.ModeKeys value.
        params: unused estimator params dict.

    Returns:
        tf.estimator.EstimatorSpec for TRAIN, EVAL, or PREDICT.
    """
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    print('shape of input_ids', input_ids.shape)
    # label_mask = features["label_mask"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # Build the model: input_ids are the token-id inputs, label_ids the
    # tag-id targets.
    (total_loss, logits, trans,
     pred_ids) = create_model(bert_config, is_training, input_ids, input_mask,
                              segment_ids, label_ids, num_labels, False,
                              args.dropout_rate, args.lstm_size, args.cell,
                              args.num_layers)
    tvars = tf.trainable_variables()
    # Load the pretrained BERT weights into matching variables.
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = \
            modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                 num_train_steps,
                                                 num_warmup_steps, False)
        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 loss=total_loss,
                                                 train_op=train_op)
    elif mode == tf.estimator.ModeKeys.EVAL:
        # NER evaluation via tf_metrics precision/recall/F1.
        # NOTE(review): metric_fn declares `logits` and `trans` parameters
        # but never uses them — it reads `pred_ids` from the enclosing scope
        # instead; the CRF Viterbi decode the comments mention happens in
        # create_model.
        # NOTE(review): tf.sequence_mask is called with the scalar
        # max_seq_length only, producing a 1-D all-True mask rather than a
        # per-example [batch, seq] mask — confirm the intended weighting.
        def metric_fn(label_ids, logits, trans):
            weight = tf.sequence_mask(args.max_seq_length)
            precision = tf_metrics.precision(label_ids, pred_ids, num_labels,
                                             None, weight)
            recall = tf_metrics.recall(label_ids, pred_ids, num_labels, None,
                                       weight)
            f = tf_metrics.f1(label_ids, pred_ids, num_labels, None, weight)
            return {
                "eval_precision": precision,
                "eval_recall": recall,
                "eval_f": f,
                "eval_loss":
                    tf.metrics.mean_squared_error(labels=label_ids,
                                                  predictions=pred_ids),
            }
        eval_metrics = metric_fn(label_ids, logits, trans)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metric_ops=eval_metrics)
    else:
        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 predictions=pred_ids)
    return output_spec
# Script fragment: build a BertModel over externally defined placeholders
# (input_ids, input_mask come from earlier in the file, as do sess,
# bert_config, and init_checkpoint), restore pretrained weights, and pull
# out the final sequence embeddings.
token_type_ids = tf.placeholder(tf.int32, shape=[20, 128])
model = modeling.BertModel(
    config=bert_config,
    is_training=True,
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=token_type_ids,
    use_one_hot_embeddings=False
)
graph = tf.get_default_graph()
tvars = tf.trainable_variables()
# init_from_checkpoint rewrites variable initializers, so it must precede
# the global_variables_initializer run below.
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
    tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
    # Bug fix: print() does not %-interpolate trailing arguments the way
    # logging does; format explicitly.
    print(" name = %s, shape = %s%s" % (var.name, var.shape, init_string))
sess.run(tf.global_variables_initializer())
embeddings = model.get_sequence_output()
print(embeddings.shape)