Exemplo n.º 1
0
def freeze_graph_tf1(json_file, frozen_graph_dir):
    """Export the best checkpoint as frozen TF1 graphs in pb format.

    Graph definition and weights are gathered into single files: the
    function writes ``encoder.pb`` and ``decoder.pb`` into
    ``frozen_graph_dir``, next to an intermediate ``model.h5`` weight file.

    Args:
        json_file: config path handed to ``build_model_from_jsonfile``.
        frozen_graph_dir: directory receiving ``model.h5`` and both pb files.
    """
    # Step 1: restore the best checkpoint and dump its weights to disk.
    _, net, _, checkpointer = build_model_from_jsonfile(json_file,
                                                        pre_run=False)
    net.deploy()
    checkpointer.restore_from_best()
    weights_path = os.path.join(frozen_graph_dir, "model.h5")
    net.save_weights(weights_path)

    # Step 2: clear the Keras session, turn eager execution off and rebuild
    # the model from the weights dumped above.
    tf.keras.backend.clear_session()
    tf.compat.v1.disable_eager_execution()

    _, net, _, _ = build_model_from_jsonfile(json_file, pre_run=False)
    net.deploy()
    net.load_weights(weights_path)

    # Step 3: freeze the encoder first, then the decoder, from one session.
    session = tf.compat.v1.keras.backend.get_session()
    exports = (("deploy_encoder", "encoder.pb"),
               ("deploy_decoder", "decoder.pb"))
    for attr_name, pb_name in exports:
        output_names = [
            out.op.name for out in getattr(net, attr_name).outputs
        ]
        frozen_graph = freeze_session(session, output_names=output_names)
        tf.compat.v1.train.write_graph(frozen_graph,
                                       frozen_graph_dir,
                                       pb_name,
                                       as_text=False)
Exemplo n.º 2
0
def decode(jsonfile):
    """Entry point for model decoding; does some preparation work.

    Builds the model described by ``jsonfile``, optionally builds and
    restores an RNN language model (when ``decode_config['lm_type']`` is
    ``"rnn"``), restores the best acoustic-model checkpoint and decodes the
    test set with batch size 1.

    Args:
        jsonfile: config path handed to ``build_model_from_jsonfile``.

    Raises:
        AssertionError: if the parsed config has no ``testset_config``.
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)

    # Bind lm_model unconditionally: without this default the DecoderSolver
    # call below raised NameError whenever no RNN LM was configured.
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()

    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1))
Exemplo n.º 3
0
def decode(jsonfile, n=1, log_file=None):
    """Entry point for decoding with checkpoint averaging.

    The training log is scanned for lines containing ``'epoch:'``. Each such
    line is split on tabs: the epoch number is the last space-separated
    token of the first field, and the score is the last space-separated
    token of the last field. The ``n`` epochs with the highest score are
    selected, the checkpoint ``ckpt-<epoch + 1>`` of each is restored, and
    the element-wise mean of the trainable variables is assigned back to the
    model before decoding the test set with batch size 1.

    Args:
        jsonfile: config path handed to ``build_model_from_jsonfile``.
        n: number of best-scoring checkpoints to average.
        log_file: path of the training log to read the scores from.

    Raises:
        IOError: if ``log_file`` is ``None`` or does not exist.
        ValueError: if no checkpoint could be selected from the log.
        AssertionError: if the parsed config has no ``testset_config``.
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()

    # os.path.exists(None) raises TypeError, so the default argument must be
    # rejected explicitly to get the intended IOError.
    if log_file is None or not os.path.exists(log_file):
        raise IOError('The file is not exist!')

    # epoch -> score parsed from the log (the value is read as `ctc_acc`,
    # and a higher value is treated as better).
    epoch_scores = {}
    with open(log_file) as log:  # close the handle deterministically
        for line in log:
            if 'epoch:' in line:
                print(line)
                splits = line.strip().split('\t')
                epoch = int(splits[0].split(' ')[-1])
                ctc_acc = float(splits[-1].split(' ')[-1])
                epoch_scores[epoch] = ctc_acc

    # Best-scoring epochs first; the sort is stable, so ties keep log order.
    ckpt_index_list = sorted(epoch_scores,
                             key=epoch_scores.get,
                             reverse=True)[:n]
    print('best_wer_checkpoint: ')
    print(ckpt_index_list)
    if not ckpt_index_list:
        raise ValueError(
            'No checkpoint selected for averaging from {!r} (n={})'.format(
                log_file, n))

    # Restore each selected checkpoint and snapshot its variable values;
    # restoring overwrites the model's current variables.
    ckpt_v_list = []
    for idx in ckpt_index_list:
        ckpt_path = p.ckpt + 'ckpt-' + str(idx + 1)
        checkpointer.restore(ckpt_path)
        ckpt_v_list.append(
            [tf.constant(var.value()) for var in model.trainable_variables])

    # Average the snapshots per variable and assign the mean to the model.
    for i, variable in enumerate(model.trainable_variables):
        stacked = tf.stack([values[i] for values in ckpt_v_list], axis=0)
        variable.assign(tf.reduce_mean(stacked, axis=0))

    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1), dataset_builder)
Exemplo n.º 4
0
def restore_model(json_file, frozen_graph_dir):
    """Restore the best model and save its weights as ``model.h5``.

    Args:
        json_file: config path handed to ``build_model_from_jsonfile``.
        frozen_graph_dir: directory in which ``model.h5`` is written.
    """
    built = build_model_from_jsonfile(json_file, pre_run=False)
    _, net, _, ckpt = built
    net.deploy()
    ckpt.restore_from_best()
    weights_path = os.path.join(frozen_graph_dir, "model.h5")
    net.save_weights(weights_path)
Exemplo n.º 5
0
 def __init__(self, model, data_descriptions=None, config=None):
     """Parse the hparams, optionally load an RNN LM and build the decoder.

     Args:
         model: model to decode with; its ``sos`` and ``eos`` are passed to
             the decoder (the beam search decoder also receives
             ``num_class`` and ``time_propagate``).
         data_descriptions: accepted but not used in this constructor.
         config: overrides parsed against ``self.default_config``.

     Raises:
         ValueError: if ``decoder_type`` is neither ``"beam_search_decoder"``
             nor ``"wfst_decoder"``.
     """
     super().__init__(model, None, None)
     self.model = model
     self.hparams = register_and_parse_hparams(
         self.default_config, config, cls=self.__class__)
     hparams = self.hparams

     # The language model is only built when an RNN LM is requested.
     rnn_lm = None
     if hparams.lm_type == "rnn":
         # Function-scope import (presumably to avoid a circular import
         # with athena.main -- TODO confirm).
         from athena.main import build_model_from_jsonfile
         _, rnn_lm, _, lm_ckpt = build_model_from_jsonfile(hparams.lm_path)
         lm_ckpt.restore_from_best()

     decoder_type = hparams.decoder_type
     if decoder_type == "beam_search_decoder":
         self.decoder = BeamSearchDecoder.build_decoder(
             hparams,
             self.model.num_class,
             self.model.sos,
             self.model.eos,
             self.model.time_propagate,
             lm_model=rnn_lm)
         return
     if decoder_type == "wfst_decoder":
         # The WFST decoder is built without the language model.
         self.decoder = WFSTDecoder(
             hparams.wfst_graph,
             acoustic_scale=hparams.acoustic_scale,
             max_active=hparams.max_active,
             min_active=hparams.min_active,
             beam=hparams.wfst_beam,
             max_seq_len=hparams.max_seq_len,
             sos=self.model.sos,
             eos=self.model.eos)
         return
     raise ValueError("This decoder type is not supported")
Exemplo n.º 6
0
def decode(jsonfile):
    """Entry point for model decoding; does some preparation work.

    Restores the best checkpoint, loads the test csv into the dataset
    builder returned with the model, and decodes it with batch size 1.

    Args:
        jsonfile: config path handed to ``build_model_from_jsonfile``.
    """
    built = build_model_from_jsonfile(jsonfile, 0)
    p, model, _, checkpointer, dataset_builder = built
    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config)

    # Load the test csv, then compute cmvn if necessary.
    loaded_builder = dataset_builder.load_csv(p.test_csv)
    dataset_builder = loaded_builder.compute_cmvn_if_necessary(True)
    test_set = dataset_builder.as_dataset(batch_size=1)
    solver.decode(test_set)
Exemplo n.º 7
0
def synthesize(jsonfile):
    """Entry point for speech synthesis; does some preparation work.

    Averages the n-best checkpoints when ``decode_config['model_avg_num']``
    (default 1) is positive, then synthesizes the test set with batch
    size 1.

    Args:
        jsonfile: config path handed to ``build_model_from_jsonfile``.

    Raises:
        AssertionError: if the parsed config has no ``testset_config``.
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)

    # Number of checkpoints to average; falls back to 1 when not configured.
    if 'model_avg_num' in p.decode_config:
        avg_num = p.decode_config['model_avg_num']
    else:
        avg_num = 1
    if avg_num > 0:
        checkpointer.compute_nbest_avg(avg_num)

    assert p.testset_config is not None
    builder_cls = SUPPORTED_DATASET_BUILDER[p.dataset_builder]
    dataset_builder = builder_cls(p.testset_config)
    solver = SynthesisSolver(model, dataset_builder, config=p.decode_config)
    test_set = dataset_builder.as_dataset(batch_size=1)
    solver.synthesize(test_set)
Exemplo n.º 8
0
def decode(jsonfile, rank_size=1, rank=0):
    """Entry point for sharded model decoding; does some preparation work.

    Averages the n-best checkpoints (``decode_config['model_avg_num']``,
    default 1), optionally builds and restores an RNN language model, then
    decodes this rank's shard of the test set with batch size 1.

    Args:
        jsonfile: config path handed to ``build_model_from_jsonfile``.
        rank_size: passed to ``dataset_builder.shard`` and ``solver.decode``.
        rank: passed to ``dataset_builder.shard``.

    Raises:
        AssertionError: if the parsed config has no ``testset_config``.
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)

    # Number of checkpoints to average; falls back to 1 when not configured.
    if 'model_avg_num' in p.decode_config:
        avg_num = p.decode_config['model_avg_num']
    else:
        avg_num = 1
    checkpointer.compute_nbest_avg(avg_num)

    lm_model = None
    uses_rnn_lm = ('lm_type' in p.decode_config
                   and p.decode_config['lm_type'] == "rnn")
    if uses_rnn_lm:
        lm_built = build_model_from_jsonfile(p.decode_config['lm_path'])
        _, lm_model, _, lm_checkpointer = lm_built
        lm_checkpointer.restore_from_best()

    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    builder_cls = SUPPORTED_DATASET_BUILDER[p.dataset_builder]
    dataset_builder = builder_cls(p.testset_config)
    dataset_builder.shard(rank_size, rank)
    logging.info("shard result: %d" % len(dataset_builder))
    test_set = dataset_builder.as_dataset(batch_size=1)
    solver.decode(test_set, rank_size=rank_size)
Exemplo n.º 9
0
def inference(jsonfile, rank_size=1, rank=0):
    """Entry point for sharded model inference; does some preparation work.

    Averages the n-best checkpoints when
    ``inference_config['model_avg_num']`` (default 1) is positive, then runs
    the solver selected by ``p.solver_type`` on this rank's shard of the
    test set with batch size 1.

    Args:
        jsonfile: config path handed to ``build_model_from_jsonfile``.
        rank_size: passed to ``dataset_builder.shard`` and
            ``solver.inference``.
        rank: passed to ``dataset_builder.shard``.

    Raises:
        AssertionError: if the parsed config has no ``testset_config``.
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)

    # Number of checkpoints to average; falls back to 1 when not configured.
    if 'model_avg_num' in p.inference_config:
        avg_num = p.inference_config['model_avg_num']
    else:
        avg_num = 1
    if avg_num > 0:
        checkpointer.compute_nbest_avg(avg_num)

    assert p.testset_config is not None
    builder_cls = SUPPORTED_DATASET_BUILDER[p.dataset_builder]
    dataset_builder = builder_cls(p.testset_config)
    dataset_builder.shard(rank_size, rank)
    logging.info("shard result: %d" % len(dataset_builder))

    solver_cls = SOLVERS[p.solver_type]
    solver = solver_cls(model, dataset_builder, config=p.inference_config)
    test_set = dataset_builder.as_dataset(batch_size=1)
    solver.inference(test_set, rank_size=rank_size)
Exemplo n.º 10
0
def freeze_saved_model(json_file, frozen_graph_dir):
    """Export the model built from ``json_file`` as a TensorFlow SavedModel.

    A ``tf.function`` wrapping ``model.synthesize`` is attached to the model
    as ``deploy_function`` before the model is saved.

    Args:
        json_file: config path handed to ``build_model_from_jsonfile``.
        frozen_graph_dir: export directory given to ``tf.saved_model.save``.
    """
    # NOTE(review): the checkpointer returned here is never used, so this
    # function performs no restore itself -- confirm the weights are loaded
    # elsewhere before exporting.
    _, model, _, _ = build_model_from_jsonfile(json_file, pre_run=False)

    def inference(x):
        # Only the first element of the synthesize outputs is exported.
        return model.synthesize({"input": x})[0]

    # The exported function takes a single rank-2 int32 tensor.
    signature = [tf.TensorSpec(shape=[None, None], dtype=tf.int32)]
    model.deploy_function = tf.function(inference, input_signature=signature)
    tf.saved_model.save(obj=model, export_dir=frozen_graph_dir)
Exemplo n.º 11
0

if __name__ == "__main__":
    # Script entry: build the model from the json config given on the
    # command line, load ./model.h5 and write a frozen graph to ./model.pb.
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) < 2:
        usage = 'Usage: python {} config_json_file'.format(sys.argv[0])
        logging.warning(usage)
        sys.exit()
    tf.random.set_seed(1)

    json_file = sys.argv[1]
    with open(json_file) as config_file:
        config = json.load(config_file)
    p = parse_config(config)
    BaseSolver.initialize_devices(p.solver_gpu)

    # NOTE(review): the session is fetched before v2 behavior is disabled;
    # the statement order is kept exactly as written -- confirm it is
    # intentional.
    sess = K.get_session()
    tf.compat.v1.disable_v2_behavior()
    _, model, _, checkpointer = build_model_from_jsonfile(json_file,
                                                          pre_run=False)

    model.load_weights('./model.h5')
    # For Transformer, freeze model.encoder_pb.outputs or
    # model.decoder_pb.outputs instead of model.net.outputs.
    output_names = [out.op.name for out in model.net.outputs]
    frozen_graph = freeze_session(sess, output_names=output_names)
    with tf.compat.v1.gfile.GFile("./model.pb", "wb") as pb_file:
        pb_file.write(frozen_graph.SerializeToString())
    print("Done!")