def freeze_graph_tf1(json_file, frozen_graph_dir):
    """ freeze a TensorFlow model trained in Python to pb format, which gathers
        the graph definition and weights together into a single file
    """
    # restore the best model
    _, model, _, checkpointer = build_model_from_jsonfile(json_file, pre_run=False)
    model.deploy()
    checkpointer.restore_from_best()
    model.save_weights(os.path.join(frozen_graph_dir, "model.h5"))

    # reload the weights in a graph-mode (non-eager) session so the graph can be frozen
    tf.keras.backend.clear_session()
    tf.compat.v1.disable_eager_execution()
    _, model, _, _ = build_model_from_jsonfile(json_file, pre_run=False)
    model.deploy()
    model.load_weights(os.path.join(frozen_graph_dir, "model.h5"))
    session = tf.compat.v1.keras.backend.get_session()

    # freeze the encoder and decoder graphs and write each to its own .pb file
    frozen_graph_enc = freeze_session(
        session, output_names=[out.op.name for out in model.deploy_encoder.outputs])
    tf.compat.v1.train.write_graph(
        frozen_graph_enc, frozen_graph_dir, "encoder.pb", as_text=False)
    frozen_graph_dec = freeze_session(
        session, output_names=[out.op.name for out in model.deploy_decoder.outputs])
    tf.compat.v1.train.write_graph(
        frozen_graph_dec, frozen_graph_dir, "decoder.pb", as_text=False)
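# A minimal sketch of how the frozen encoder.pb written above could be loaded
# back for TF1-style inference. It uses only standard TensorFlow calls; the
# load_frozen_graph helper name is ours, and pb_path is whatever directory was
# passed as frozen_graph_dir:
import os
import tensorflow as tf

def load_frozen_graph(pb_path):
    """Parse a frozen .pb file and import it into a fresh tf.Graph."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return graph

# e.g. encoder_graph = load_frozen_graph(os.path.join(frozen_graph_dir, "encoder.pb"))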
def decode(jsonfile):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    lm_model = None  # default when no language model is configured
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()
    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1))
def decode(jsonfile, n=1, log_file=None):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()
    if log_file is None or not os.path.exists(log_file):
        raise IOError('The log file does not exist!')

    # parse the per-epoch CTC accuracy out of the training log
    checkpoint_wer_dict = {}
    for line in open(log_file):
        if 'epoch:' in line:
            print(line)
            splits = line.strip().split('\t')
            epoch = int(splits[0].split(' ')[-1])
            ctc_acc = float(splits[-1].split(' ')[-1])
            checkpoint_wer_dict[epoch] = ctc_acc
    # sort epochs by accuracy (descending) and keep the n best checkpoints
    checkpoint_wer_dict = {
        k: v
        for k, v in sorted(checkpoint_wer_dict.items(),
                           key=lambda item: item[1],
                           reverse=True)
    }
    ckpt_index_list = list(checkpoint_wer_dict.keys())[0:n]
    print('best_wer_checkpoint: ')
    print(ckpt_index_list)

    # restore the variables of each selected checkpoint
    ckpt_v_list = []
    for idx in ckpt_index_list:
        ckpt_path = p.ckpt + 'ckpt-' + str(idx + 1)
        checkpointer.restore(ckpt_path)  # current variables will be updated
        var_list = []
        for i in model.trainable_variables:
            v = tf.constant(i.value())
            var_list.append(v)
        ckpt_v_list.append(var_list)
    # compute the element-wise average and assign it to the current variables
    for i in range(len(model.trainable_variables)):
        v = [
            tf.expand_dims(ckpt_v_list[j][i], 0)
            for j in range(len(ckpt_v_list))
        ]
        v = tf.reduce_mean(tf.concat(v, axis=0), axis=0)
        model.trainable_variables[i].assign(v)

    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1), dataset_builder)
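# Hypothetical invocation of the n-best averaged decoding above: rank the
# checkpoints by the CTC accuracy recorded in the training log, average the
# top 5, then decode. Both paths below are placeholders, not real files:
# decode("examples/asr/aishell/configs/mtl_transformer.json", n=5,
#        log_file="examples/asr/aishell/train.log")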
def restore_model(json_file, frozen_graph_dir):
    """ restore the best model """
    _, model, _, checkpointer = build_model_from_jsonfile(json_file, pre_run=False)
    model.deploy()
    checkpointer.restore_from_best()
    model.save_weights(os.path.join(frozen_graph_dir, "model.h5"))
def __init__(self, model, data_descriptions=None, config=None):
    super().__init__(model, None, None)
    self.model = model
    self.hparams = register_and_parse_hparams(self.default_config, config, cls=self.__class__)

    # optionally build and restore an RNN language model for use during decoding
    lm_model = None
    if self.hparams.lm_type == "rnn":
        from athena.main import build_model_from_jsonfile
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            self.hparams.lm_path)
        lm_checkpointer.restore_from_best()

    if self.hparams.decoder_type == "beam_search_decoder":
        self.decoder = BeamSearchDecoder.build_decoder(
            self.hparams,
            self.model.num_class,
            self.model.sos,
            self.model.eos,
            self.model.time_propagate,
            lm_model=lm_model)
    elif self.hparams.decoder_type == "wfst_decoder":
        self.decoder = WFSTDecoder(
            self.hparams.wfst_graph,
            acoustic_scale=self.hparams.acoustic_scale,
            max_active=self.hparams.max_active,
            min_active=self.hparams.min_active,
            beam=self.hparams.wfst_beam,
            max_seq_len=self.hparams.max_seq_len,
            sos=self.model.sos,
            eos=self.model.eos)
    else:
        raise ValueError("This decoder type is not supported")
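# A hedged example of the kind of config dict this constructor expects. The
# field names mirror the hparams read above; the values and file paths are
# illustrative placeholders only:
beam_search_config = {
    "decoder_type": "beam_search_decoder",
    "lm_type": "rnn",                            # triggers the RNN LM branch above
    "lm_path": "examples/asr/hkust/rnnlm.json",  # hypothetical LM config path
}
wfst_config = {
    "decoder_type": "wfst_decoder",
    "wfst_graph": "graph/TLG.fst",               # hypothetical compiled WFST graph
    "acoustic_scale": 10.0,
    "max_active": 80,
    "min_active": 0,
    "wfst_beam": 30.0,
    "max_seq_len": 100,
}
# e.g. solver = DecoderSolver(model, config=beam_search_config)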
def decode(jsonfile):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer, dataset_builder = build_model_from_jsonfile(
        jsonfile, 0)
    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config)
    dataset_builder = dataset_builder.load_csv(
        p.test_csv).compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1))
def synthesize(jsonfile):
    """ entry point for speech synthesis, do some preparation work """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    avg_num = 1 if 'model_avg_num' not in p.decode_config else p.decode_config['model_avg_num']
    if avg_num > 0:
        checkpointer.compute_nbest_avg(avg_num)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.testset_config)
    solver = SynthesisSolver(model, dataset_builder, config=p.decode_config)
    solver.synthesize(dataset_builder.as_dataset(batch_size=1))
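# Checkpoint averaging here is driven by the config rather than a training
# log. A hedged, illustrative decode_config fragment (field name taken from
# the code above, value hypothetical):
decode_config = {"model_avg_num": 5}  # average the 5 best checkpoints
# If "model_avg_num" is absent, avg_num defaults to 1 (restore the single
# best checkpoint); a value of 0 or less skips compute_nbest_avg entirely.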
def decode(jsonfile, rank_size=1, rank=0):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    avg_num = 1 if 'model_avg_num' not in p.decode_config else p.decode_config[
        'model_avg_num']
    checkpointer.compute_nbest_avg(avg_num)
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    # shard the test set across ranks for data-parallel decoding
    dataset_builder.shard(rank_size, rank)
    logging.info("shard result: %d" % len(dataset_builder))
    solver.decode(dataset_builder.as_dataset(batch_size=1), rank_size=rank_size)
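# A minimal sketch of sharded multi-process decoding. Horovod is an assumption
# here (any launcher that supplies rank_size and rank would do), and the config
# path is a hypothetical placeholder:
import horovod.tensorflow as hvd

hvd.init()
decode("examples/asr/hkust/configs/mtl_transformer.json",
       rank_size=hvd.size(), rank=hvd.rank())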
def inference(jsonfile, rank_size=1, rank=0):
    """ entry point for model inference, do some preparation work """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    avg_num = 1 if 'model_avg_num' not in p.inference_config else p.inference_config['model_avg_num']
    if avg_num > 0:
        checkpointer.compute_nbest_avg(avg_num)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.testset_config)
    dataset_builder.shard(rank_size, rank)
    logging.info("shard result: %d" % len(dataset_builder))
    inference_solver = SOLVERS[p.solver_type]
    solver = inference_solver(model, dataset_builder, config=p.inference_config)
    solver.inference(dataset_builder.as_dataset(batch_size=1), rank_size=rank_size)
def freeze_saved_model(json_file, frozen_graph_dir):
    """ freeze a TensorFlow model trained in Python to the SavedModel format """
    _, model, _, checkpointer = build_model_from_jsonfile(json_file, pre_run=False)

    def inference(x):
        samples = {"input": x}
        outputs = model.synthesize(samples)
        return outputs[0]

    # attach a concrete-signature tf.function so it is serialized with the model
    model.deploy_function = tf.function(
        inference,
        input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.int32)])
    tf.saved_model.save(obj=model, export_dir=frozen_graph_dir)
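# A minimal sketch of consuming the SavedModel exported above. tf.saved_model
# restores tf.functions attached to the saved object, so deploy_function is
# callable after loading. The export path and token ids below are hypothetical;
# the input just has to match the [None, None] int32 signature:
import tensorflow as tf

export_dir = "saved_model_dir"  # the frozen_graph_dir used above (placeholder)
loaded = tf.saved_model.load(export_dir)
token_ids = tf.constant([[7, 42, 3, 1]], dtype=tf.int32)  # hypothetical token ids
outputs = loaded.deploy_function(token_ids)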
if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) < 2:
        logging.warning('Usage: python {} config_json_file'.format(sys.argv[0]))
        sys.exit()
    tf.random.set_seed(1)

    json_file = sys.argv[1]
    config = None
    with open(json_file) as f:
        config = json.load(f)
    p = parse_config(config)
    BaseSolver.initialize_devices(p.solver_gpu)

    sess = K.get_session()
    tf.compat.v1.disable_v2_behavior()
    _, model, _, checkpointer = build_model_from_jsonfile(json_file, pre_run=False)
    model.load_weights('./model.h5')
    frozen_graph = freeze_session(
        sess, output_names=[out.op.name for out in model.net.outputs])
    # For Transformer:
    # frozen_graph = freeze_session(sess, output_names=[out.op.name for out in model.encoder_pb.outputs])
    # frozen_graph = freeze_session(sess, output_names=[out.op.name for out in model.decoder_pb.outputs])
    with tf.compat.v1.gfile.GFile("./model.pb", "wb") as out_f:
        out_f.write(frozen_graph.SerializeToString())
    print("Done!")