def main(_):
    """Re-save pretrained weights as a regular checkpoint next to its config.

    Loads the configuration from ``FLAGS.config_path``, builds the network in
    inference mode, runs the restore op obtained from
    ``utils.restore_from_pretrain(FLAGS.input_dir)`` and writes the resulting
    variables to ``FLAGS.output_path``. The configuration file is then copied
    to ``FLAGS.output_path + '.json'``.

    Args:
        _: Unused positional argument supplied by the app runner.
    """
    utils.load(FLAGS.config_path)

    # The return value is not needed here; the call is kept because it
    # creates the global-step variable in the graph before the Saver is built.
    tf.train.get_or_create_global_step()

    wave_input = tf.placeholder(
        tf.float32, [1, None, utils.Data.num_channel])
    num_classes = len(utils.Data.vocabulary)
    wavenet.bulid_wavenet(wave_input, num_classes, is_training=False)

    pretrain_restore = utils.restore_from_pretrain(FLAGS.input_dir)
    checkpoint_writer = tf.train.Saver()

    with tf.Session() as session:
        # Initialise everything first, then overwrite with pretrained values.
        session.run(tf.global_variables_initializer())
        session.run(pretrain_restore)
        checkpoint_writer.save(session, FLAGS.output_path)

    config_copy_path = FLAGS.output_path + '.json'
    shutil.copy(FLAGS.config_path, config_copy_path)
def main(_):
    """Build the CTC training graph and restore any available weights.

    Creates placeholders for inputs, labels and the training switch, builds
    the network, the CTC loss, a beam-search decoder and an Adam training op,
    then opens a session, runs the pretrain restore op and, when
    ``FLAGS.checkpoint_dir`` contains a checkpoint, restores it.

    Args:
        _: Unused positional argument supplied by the app runner.
    """
    inputs = tf.placeholder(shape=[None, None, 20], dtype=tf.float32)
    labels = tf.placeholder(shape=[None, None], dtype=tf.int64)
    is_training = tf.placeholder(shape=[], dtype=tf.bool)

    # Sequence length = number of steps along axis 1 whose channel sum is
    # non-zero (assumes padded frames are all-zero -- TODO confirm).
    frame_is_used = tf.not_equal(tf.reduce_sum(inputs, axis=2), 0.)
    seq_len = tf.reduce_sum(tf.cast(frame_is_used, tf.int32), axis=1)

    global_step = tf.train.get_or_create_global_step()
    logits = wavenet.bulid_wavenet(inputs, len(utils.class_names), is_training)

    # The logits are transposed with perm=[1, 0, 2] before being handed to the
    # (time-major) beam-search decoder below, so they are batch-major here;
    # ctc_loss must therefore be told time_major=False.
    # NOTE(review): the TF1 ctc_loss documentation describes `labels` as an
    # int32 SparseTensor, while a dense int64 placeholder is passed here --
    # confirm against the TensorFlow version in use.
    loss = tf.nn.ctc_loss(labels=labels, inputs=logits,
                          sequence_length=seq_len, time_major=False)
    outputs, _ = tf.nn.ctc_beam_search_decoder(
        tf.transpose(logits, perm=[1, 0, 2]), seq_len, merge_repeated=False)

    # Run the collected update ops together with every optimisation step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimize = tf.train.AdamOptimizer(learning_rate=0.01).minimize(
            loss=loss, global_step=global_step)

    restore_op = utils.restore_from_pretrain(FLAGS.pretrain_dir)
    save = tf.train.Saver()
    train_dataset = dataset.create(FLAGS.train_dir)
    test_dataset = dataset.create(FLAGS.test_dir)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(restore_op)

        # latest_checkpoint returns None when the directory holds no
        # checkpoint; a missing directory is treated the same way instead of
        # letting os.listdir raise.
        latest = None
        if os.path.isdir(FLAGS.checkpoint_dir):
            latest = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if latest is not None:
            save.restore(sess, latest)
def main(_):
    """Transcribe ``FLAGS.input_path`` with the checkpoint at ``FLAGS.ckpt_path``.

    The configuration is read from ``FLAGS.ckpt_path + '.json'``. The network
    is built in inference mode, the checkpoint is restored and the
    beam-search decoding of the input wave is logged.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        0 on success, -1 when ``FLAGS.ckpt_path + '.index'`` does not exist.
    """
    ckpt_path = FLAGS.ckpt_path
    index_file = ckpt_path + '.index'
    if not os.path.exists(index_file):
        glog.error('%s was not found.' % ckpt_path)
        return -1

    utils.load(ckpt_path + '.json')

    vocabulary = tf.constant(utils.Data.vocabulary)
    inputs = tf.placeholder(tf.float32, [1, None, utils.Data.num_channel])
    sequence_length = tf.placeholder(tf.int32, [None])
    logits = wavenet.bulid_wavenet(
        inputs, len(utils.Data.vocabulary), is_training=False)

    # The decoder is fed the logits with the first two axes swapped.
    swapped_logits = tf.transpose(logits, perm=[1, 0, 2])
    decodes, _ = tf.nn.ctc_beam_search_decoder(
        swapped_logits, sequence_length, merge_repeated=False)
    best_path = tf.sparse.to_dense(decodes[0])
    outputs = tf.gather(vocabulary, best_path)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_path)

        wave = utils.read_wave(FLAGS.input_path)
        feed = {inputs: [wave], sequence_length: [wave.shape[0]]}
        decoded = sess.run(outputs, feed_dict=feed)
        output = utils.cvt_np2string(decoded)[0]
        glog.info('%s: %s.', FLAGS.input_path, output)
    return 0
def main(_):
    """Transcribe ``FLAGS.input_path`` and log the decoded text.

    Builds the network in inference mode, runs the pretrain restore op and
    then, if ``FLAGS.checkpoint_dir`` contains a checkpoint, restores it on
    top. The wave read from ``FLAGS.input_path`` is decoded with CTC beam
    search and the first decoded string is logged.

    Args:
        _: Unused positional argument supplied by the app runner.
    """
    class_names = tf.constant(utils.Data.class_names)
    inputs = tf.placeholder(tf.float32, [1, None, utils.Data.channels])

    # Sequence length = number of steps along axis 1 whose channel sum is
    # non-zero (assumes padded frames are all-zero -- TODO confirm).
    frame_is_used = tf.not_equal(tf.reduce_sum(inputs, axis=2), 0.)
    seq_len = tf.reduce_sum(tf.cast(frame_is_used, tf.int32), axis=1)

    logits = wavenet.bulid_wavenet(
        inputs, len(utils.Data.class_names), is_training=False)
    decodes, _ = tf.nn.ctc_beam_search_decoder(
        tf.transpose(logits, perm=[1, 0, 2]), seq_len, merge_repeated=False)

    # Decoded ids are shifted by one before the class-name lookup.
    outputs = tf.sparse.to_dense(decodes[0]) + 1
    outputs = tf.gather(class_names, outputs)

    restore = utils.restore_from_pretrain(FLAGS.pretrain_dir)
    save = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(restore)

        # latest_checkpoint returns None when the directory has no
        # checkpoint (for example, it only holds unrelated files); passing
        # None on to Saver.restore would raise, so skip the restore then.
        latest = None
        if os.path.exists(FLAGS.checkpoint_dir):
            latest = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if latest is not None:
            save.restore(sess, latest)

        wave = utils.read_wave(FLAGS.input_path)
        decoded = sess.run(outputs, feed_dict={inputs: [wave]})
        output = utils.cvt_np2string(decoded)[0]
        glog.info('%s: %s.', FLAGS.input_path, output)
def main(_):
    """Train the network with CTC loss, logging metrics and saving snapshots.

    Loads the configuration, builds the training graph on top of the dataset
    returned by ``dataset.create``, resumes from the latest checkpoint found
    in the directory of ``FLAGS.ckpt_path`` (if any) and then trains forever.
    Every ``FLAGS.display`` steps the accumulated loss and tp/pos/pred/f1
    counters are logged and reset; checkpoints are written according to
    ``FLAGS.snapshot``.

    Args:
        _: Unused positional argument supplied by the app runner.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.device)
    utils.load(FLAGS.config_path)
    global_step = tf.train.get_or_create_global_step()
    train_dataset = dataset.create(
        FLAGS.dataset_path, FLAGS.batch_size, repeat=True)

    # The batch dimension is taken from the tensor at run time instead of
    # being hard-coded to FLAGS.batch_size: the original author observed that
    # train_dataset[0].shape[0] != FLAGS.batch_size once in a while.
    waves = tf.sparse.to_dense(train_dataset[0])
    waves = tf.reshape(waves, [tf.shape(waves)[0], -1, utils.Data.num_channel])
    labels = tf.cast(train_dataset[1], tf.int32)
    sequence_length = tf.cast(train_dataset[2], tf.int32)

    logits = wavenet.bulid_wavenet(
        waves, len(utils.Data.vocabulary), is_training=True)
    loss = tf.reduce_mean(
        tf.nn.ctc_loss(labels, logits, sequence_length, time_major=False))

    # Map both the decoded ids and the reference label ids through the
    # vocabulary so predictions and references can be compared.
    vocabulary = tf.constant(utils.Data.vocabulary)
    decodes, _ = tf.nn.ctc_beam_search_decoder(
        tf.transpose(logits, [1, 0, 2]), sequence_length, merge_repeated=False)
    outputs = tf.gather(vocabulary, tf.sparse.to_dense(decodes[0]))
    labels = tf.gather(vocabulary, tf.sparse.to_dense(labels))

    # Run the collected update ops together with every optimisation step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimize = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate).minimize(
                loss=loss, global_step=global_step)

    save = tf.train.Saver(max_to_keep=1000)
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        # The last element returned by dataset.create is run once up front
        # (presumably the iterator initializer -- verify in dataset.create).
        sess.run(train_dataset[-1])

        # Restoring from FLAGS.pretrain_dir is currently disabled:
        # if os.path.exists(FLAGS.pretrain_dir) and len(os.listdir(FLAGS.pretrain_dir)) > 0:
        #     save.restore(sess, tf.train.latest_checkpoint(FLAGS.pretrain_dir))

        # A bare file name has an empty directory part; fall back to the
        # current directory so os.makedirs('') is never attempted.
        ckpt_dir = os.path.split(FLAGS.ckpt_path)[0] or '.'
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        # latest_checkpoint returns None when the directory has no
        # checkpoint; Saver.restore must not be called with None.
        latest = tf.train.latest_checkpoint(ckpt_dir)
        if latest is not None:
            save.restore(sess, latest)

        losses, tps, preds, poses = 0, 0, 0, 0
        while True:
            gp, ll, uid, ot, ls, _ = sess.run(
                (global_step, labels, train_dataset[3], outputs, loss,
                 optimize))
            tp, pred, pos = utils.evalutes(
                utils.cvt_np2string(ot), utils.cvt_np2string(ll))
            tps += tp
            losses += ls
            preds += pred
            poses += pos

            if gp % FLAGS.display == 0:
                # At step 0 only a single loss value has been accumulated,
                # so it is reported as-is rather than averaged.
                shown_loss = losses if gp == 0 else (losses / FLAGS.display)
                # Argument order follows the format string: tp, pos, pred.
                glog.info(
                    "Step %d: loss=%f, tp=%d, pos=%d, pred=%d, f1=%f." %
                    (gp, shown_loss, tps, poses, preds,
                     2 * tps / (preds + poses + 1e-10)))
                losses, tps, preds, poses = 0, 0, 0, 0

            if (gp + 1) % FLAGS.snapshot == 0 and gp != 0:
                save.save(sess, FLAGS.ckpt_path, global_step=global_step)
def main(_):
    """Poll ``FLAGS.ckpt_dir`` and evaluate its newest checkpoint.

    Builds the inference graph over the dataset returned by
    ``dataset.create`` and then loops forever: the ``*.index`` file with the
    highest numeric step suffix is located, its checkpoint is restored, the
    whole dataset is decoded, and the tp/pred/pos/f1 figures are stored under
    the checkpoint name in ``FLAGS.ckpt_dir + '/evalute.json'`` and logged.
    When there is no checkpoint yet, or the newest one was already evaluated
    by this process, the loop sleeps for 60 seconds before polling again.

    Args:
        _: Unused positional argument supplied by the app runner.
    """
    utils.load(FLAGS.config_path)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.device)

    test_dataset = dataset.create(
        FLAGS.dataset_path, repeat=False, batch_size=1)
    waves = tf.reshape(tf.sparse.to_dense(test_dataset[0]),
                       shape=[1, -1, utils.Data.num_channel])
    labels = tf.sparse.to_dense(test_dataset[1])
    sequence_length = tf.cast(test_dataset[2], tf.int32)
    vocabulary = tf.constant(utils.Data.vocabulary)
    labels = tf.gather(vocabulary, labels)
    logits = wavenet.bulid_wavenet(waves, len(utils.Data.vocabulary))
    decodes, _ = tf.nn.ctc_beam_search_decoder(
        tf.transpose(logits, perm=[1, 0, 2]), sequence_length,
        merge_repeated=False)
    outputs = tf.gather(vocabulary, tf.sparse.to_dense(decodes[0]))
    save = tf.train.Saver()

    # Results from earlier runs are kept and extended.
    results_path = FLAGS.ckpt_dir + '/evalute.json'
    evalutes = {}
    if os.path.exists(results_path):
        with open(results_path, encoding='utf-8') as results_file:
            evalutes = json.load(results_file)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        last_evaluated_uid = None
        while True:
            # Pick the checkpoint with the highest step. Checkpoint names
            # end in '-<step>', so the step is the text after the last dash.
            # On equal steps the last name in reverse-sorted order wins.
            max_uid, max_uid_full, max_model_path = -1, None, None
            filepaths = sorted(
                glob.glob(FLAGS.ckpt_dir + '/*.index'), reverse=True)
            for filepath in filepaths:
                model_path = os.path.splitext(filepath)[0]
                uid = os.path.split(model_path)[-1]
                try:
                    step = int(uid.rsplit("-", 1)[1])
                except (IndexError, ValueError):
                    # Not of the form '<prefix>-<step>'; ignore it.
                    continue
                if max_uid <= step:
                    max_uid = step
                    max_uid_full = uid
                    max_model_path = model_path

            # Nothing to do yet: wait instead of crashing on an empty
            # directory or re-evaluating the same checkpoint in a busy loop.
            if max_uid_full is None or max_uid_full == last_evaluated_uid:
                time.sleep(60)
                continue

            sess.run(tf.global_variables_initializer())
            # The last element returned by dataset.create is re-run before
            # each pass (presumably the iterator initializer -- verify in
            # dataset.create).
            sess.run(test_dataset[-1])
            save.restore(sess, max_model_path)

            tps, preds, poses = 0, 0, 0
            while True:
                try:
                    y, y_ = sess.run((labels, outputs))
                except tf.errors.OutOfRangeError:
                    # Expected to signal that the non-repeating dataset is
                    # exhausted. Any other error now propagates instead of
                    # silently producing truncated metrics.
                    break
                tp, pred, pos = utils.evalutes(
                    utils.cvt_np2string(y_), utils.cvt_np2string(y))
                tps += tp
                preds += pred
                poses += pos

            evalute = {
                'tp': tps,
                'pred': preds,
                'pos': poses,
                'f1': 2 * tps / (preds + poses + 1e-20),
            }
            evalutes[max_uid_full] = evalute
            with open(results_path, mode='w', encoding='utf-8') as results_file:
                json.dump(evalutes, results_file)

            glog.info('Evalute %s: tp=%d, pred=%d, pos=%d, f1=%f.' %
                      (max_uid_full, evalute['tp'], evalute['pred'],
                       evalute['pos'], evalute['f1']))
            last_evaluated_uid = max_uid_full