def main(_):
  if not os.path.exists(FLAGS.ckpt_path + '.index'):
    glog.error('%s was not found.' % FLAGS.ckpt_path)
    return -1
  utils.load(FLAGS.ckpt_path + '.json')

  # Build the inference graph: a single wave in, decoded string out.
  vocabulary = tf.constant(utils.Data.vocabulary)
  inputs = tf.placeholder(tf.float32, [1, None, utils.Data.num_channel])
  sequence_length = tf.placeholder(tf.int32, [None])
  logits = wavenet.bulid_wavenet(inputs, len(utils.Data.vocabulary), is_training=False)
  # ctc_beam_search_decoder expects time-major logits, hence the transpose.
  decodes, _ = tf.nn.ctc_beam_search_decoder(
      tf.transpose(logits, perm=[1, 0, 2]), sequence_length, merge_repeated=False)
  outputs = tf.gather(vocabulary, tf.sparse.to_dense(decodes[0]))

  saver = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, FLAGS.ckpt_path)
    wave = utils.read_wave(FLAGS.input_path)
    output = utils.cvt_np2string(
        sess.run(outputs,
                 feed_dict={inputs: [wave], sequence_length: [wave.shape[0]]}))[0]
    glog.info('%s: %s.', FLAGS.input_path, output)
  return 0
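# A minimal entry point for the inference main() above, as a sketch: the flag
# names come from the code, while the descriptions, defaults, and tf.app.run
# wiring are assumptions about how the script is invoked.
flags = tf.app.flags
flags.DEFINE_string('ckpt_path', '',
                    'Checkpoint prefix to restore; .index is appended for the existence check.')
flags.DEFINE_string('input_path', '', 'Path of the wave file to transcribe.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  tf.app.run(main)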
def main(_):
  class_names = tf.constant(utils.Data.class_names)
  inputs = tf.placeholder(tf.float32, [1, None, utils.Data.channels])
  # Infer the sequence length from the number of frames whose channel sum is
  # non-zero, so no explicit length placeholder is needed.
  seq_len = tf.reduce_sum(
      tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=2), 0.), tf.int32), axis=1)
  logits = wavenet.bulid_wavenet(inputs, len(utils.Data.class_names), is_training=False)
  decodes, _ = tf.nn.ctc_beam_search_decoder(
      tf.transpose(logits, perm=[1, 0, 2]), seq_len, merge_repeated=False)
  # Shift the decoded indices by one before looking them up in class_names.
  outputs = tf.sparse.to_dense(decodes[0]) + 1
  outputs = tf.gather(class_names, outputs)

  restore = utils.restore_from_pretrain(FLAGS.pretrain_dir)
  save = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(restore)
    if os.path.exists(FLAGS.checkpoint_dir) and len(os.listdir(FLAGS.checkpoint_dir)) > 0:
      save.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
    output = utils.cvt_np2string(
        sess.run(outputs, feed_dict={inputs: [utils.read_wave(FLAGS.input_path)]}))[0]
    glog.info('%s: %s.', FLAGS.input_path, output)
def main(_):
  os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.device)
  utils.load(FLAGS.config_path)
  global_step = tf.train.get_or_create_global_step()
  train_dataset = dataset.create(FLAGS.dataset_path, FLAGS.batch_size, repeat=True)
  # TensorFlow quirk: train_dataset[0].shape[0] != FLAGS.batch_size once in a
  # while, so reshape with the runtime batch size instead of
  #   tf.reshape(tf.sparse.to_dense(train_dataset[0]),
  #              shape=[FLAGS.batch_size, -1, utils.Data.num_channel]).
  waves = tf.sparse.to_dense(train_dataset[0])
  waves = tf.reshape(waves, [tf.shape(waves)[0], -1, utils.Data.num_channel])
  labels = tf.cast(train_dataset[1], tf.int32)
  sequence_length = tf.cast(train_dataset[2], tf.int32)

  logits = wavenet.bulid_wavenet(waves, len(utils.Data.vocabulary), is_training=True)
  loss = tf.reduce_mean(tf.nn.ctc_loss(labels, logits, sequence_length, time_major=False))
  vocabulary = tf.constant(utils.Data.vocabulary)
  decodes, _ = tf.nn.ctc_beam_search_decoder(
      tf.transpose(logits, [1, 0, 2]), sequence_length, merge_repeated=False)
  outputs = tf.gather(vocabulary, tf.sparse.to_dense(decodes[0]))
  labels = tf.gather(vocabulary, tf.sparse.to_dense(labels))

  # Make sure batch-norm statistics are updated before each optimizer step.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    optimize = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(
        loss=loss, global_step=global_step)

  save = tf.train.Saver(max_to_keep=1000)
  config = tf.ConfigProto(allow_soft_placement=True)
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_dataset[-1])
    # if os.path.exists(FLAGS.pretrain_dir) and len(os.listdir(FLAGS.pretrain_dir)) > 0:
    #   save.restore(sess, tf.train.latest_checkpoint(FLAGS.pretrain_dir))
    ckpt_dir = os.path.split(FLAGS.ckpt_path)[0]
    if not os.path.exists(ckpt_dir):
      os.makedirs(ckpt_dir)
    if len(os.listdir(ckpt_dir)) > 0:
      save.restore(sess, tf.train.latest_checkpoint(ckpt_dir))

    losses, tps, preds, poses = 0, 0, 0, 0
    while True:
      gp, ll, uid, ot, ls, _ = sess.run(
          (global_step, labels, train_dataset[3], outputs, loss, optimize))
      tp, pred, pos = utils.evalutes(utils.cvt_np2string(ot), utils.cvt_np2string(ll))
      tps += tp
      losses += ls
      preds += pred
      poses += pos
      if gp % FLAGS.display == 0:
        # Label order matches the (tps, preds, poses) argument order.
        glog.info("Step %d: loss=%f, tp=%d, pred=%d, pos=%d, f1=%f." %
                  (gp, losses if gp == 0 else (losses / FLAGS.display),
                   tps, preds, poses, 2 * tps / (preds + poses + 1e-10)))
        losses, tps, preds, poses = 0, 0, 0, 0
      if (gp + 1) % FLAGS.snapshot == 0 and gp != 0:
        save.save(sess, FLAGS.ckpt_path, global_step=global_step)
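# The training main() above references several more flags. A plausible
# definition block is sketched here; the flag names are taken from the code,
# but every type, default, and description is an assumption.
flags = tf.app.flags
flags.DEFINE_string('config_path', '', 'Config JSON consumed by utils.load.')
flags.DEFINE_string('dataset_path', '', 'Dataset file consumed by dataset.create.')
flags.DEFINE_string('ckpt_path', '', 'Checkpoint prefix passed to Saver.save.')
flags.DEFINE_integer('batch_size', 32, 'Training batch size.')
flags.DEFINE_float('learning_rate', 1e-4, 'Adam learning rate.')
flags.DEFINE_integer('device', 0, 'CUDA device index for CUDA_VISIBLE_DEVICES.')
flags.DEFINE_integer('display', 100, 'Steps between logged training metrics.')
flags.DEFINE_integer('snapshot', 1000, 'Steps between checkpoint snapshots.')
FLAGS = flags.FLAGS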
def train(train_loader, scheduler, model, loss_fn, val_loader, writer=None):
    decoder_vocabulary = utils.Data.decoder_vocabulary  # "_abcdefghijklmopqrstuvwxyz_"
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True)
    train_loss_list = list()
    val_loss_list = list()
    best_loss = float('inf')
    for epoch in range(cfg.epochs):
        print(f'Training epoch {epoch}')
        _loss = 0.0
        step_cnt = 0
        # Before the first epoch, cluster the dense weights into a fixed set of
        # sparsity patterns for the 'find_retrain' pruning mode.
        if epoch == 0 and cfg.sparse_mode == 'find_retrain':
            print("find_pattern_start")
            cfg.fd_rtn_pattern_set = dict()
            name_list = list()
            para_list = list()
            for name, para in model.named_parameters():
                name_list.append(name)
                para_list.append(para)
            cnt = 0
            patterns_dir = os.path.join(cfg.workdir, 'patterns')
            if not os.path.exists(patterns_dir):
                os.mkdir(patterns_dir)
            weights_dir = os.path.join(cfg.workdir, 'weights_save')
            if not os.path.exists(weights_dir):
                os.mkdir(weights_dir)
            if cfg.layer_or_model_wise == "l":
                # Layer-wise patterns: one pattern set per conv weight.
                for i, name in enumerate(name_list):
                    if name.split(".")[-2] != "bn" \
                            and name.split(".")[-2] != "bn2" \
                            and name.split(".")[-2] != "bn3" \
                            and name.split(".")[-1] != "bias":
                        raw_w = para_list[i]
                        print(name, raw_w.size())
                        if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                            cfg.fd_rtn_pattern_set[name], _ = find_top_k_by_kmeans(
                                raw_w, cfg.pattern_num, cfg.pattern_shape,
                                cfg.pattern_nnz, stride=cfg.pattern_shape)
                        elif raw_w.size(0) == 128 and raw_w.size(1) == 40:
                            # raw_w_pad = torch.cat([raw_w, torch.zeros(raw_w.size(0), 4, raw_w.size(2)).cuda()], 1)
                            cfg.fd_rtn_pattern_set[name], _ = find_top_k_by_kmeans(
                                raw_w, cfg.pattern_num, cfg.pattern_shape,
                                cfg.pattern_nnz, stride=cfg.pattern_shape)
                        elif raw_w.size(0) == 28 and raw_w.size(1) == 128:
                            # Pad the 28 output channels to 32 so the pattern
                            # shape tiles the weight evenly.
                            raw_w_pad = torch.cat(
                                [raw_w, torch.zeros(4, raw_w.size(1), raw_w.size(2)).cuda()], 0)
                            cfg.fd_rtn_pattern_set[name], _ = find_top_k_by_kmeans(
                                raw_w_pad, cfg.pattern_num, cfg.pattern_shape,
                                cfg.pattern_nnz, stride=cfg.pattern_shape)
                        print(name, cfg.fd_rtn_pattern_set[name].size())
                        pattern_save = np.array(
                            cfg.fd_rtn_pattern_set[name].cpu()).transpose((0, 1, 3, 2))
                        np.savetxt(os.path.join(patterns_dir, name + '.txt'),
                                   pattern_save.flatten())
                        # raw_w_save = np.array(raw_w.cpu().detach())
                        # np.savetxt(os.path.join(weights_dir, name + '.txt'),
                        #            raw_w_save.transpose((2, 1, 0)).flatten())
            elif cfg.layer_or_model_wise == "m":
                # Model-wise patterns: concatenate all 128x128 conv weights and
                # cluster them into a single shared pattern set.
                for i, name in enumerate(name_list):
                    if name.split(".")[-2] != "bn" \
                            and name.split(".")[-2] != "bn2" \
                            and name.split(".")[-2] != "bn3" \
                            and name.split(".")[-1] != "bias":
                        raw_w = para_list[i]
                        if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                            if cnt == 0:
                                raw_w_all = raw_w
                            else:
                                raw_w_all = torch.cat([raw_w_all, raw_w], 2)
                            cnt += 1
                cfg.fd_rtn_pattern_set['all'], _ = find_top_k_by_kmeans(
                    raw_w_all, cfg.pattern_num, cfg.pattern_shape,
                    cfg.pattern_nnz, stride=cfg.pattern_shape)
            print("find_pattern_end")
        _tp, _pred, _pos = 0, 0, 0
        for data in train_loader:
            wave = data['wave'].cuda()  # [batch, 128, length], e.g. [1, 128, 109]
            # Re-apply the pruning mask every 10 steps so retrained weights
            # stay on the chosen sparsity patterns.
            if step_cnt % 10 == 0:
                model = pruning(model, cfg.sparse_mode)
            model.train()
            if epoch == 0 and step_cnt == 0:
                # Record the pre-training validation loss as the baseline.
                loss_val = validate(val_loader, model, loss_fn)
                writer.add_scalar('val/loss', loss_val, epoch)
                best_loss = loss_val
                not_better_cnt = 0
                torch.save(model.state_dict(), cfg.workdir + '/weights/best.pth')
                print("saved", cfg.workdir + '/weights/best.pth', not_better_cnt)
                val_loss_list.append(float(loss_val))
                model.train()
            logits = model(wave)
            # Mask out frames beyond each sample's true length.
            mask = torch.zeros_like(logits)
            for n in range(len(data['length_wave'])):
                mask[n, :, :data['length_wave'][n]] = 1
            logits *= mask
            logits = logits.permute(2, 0, 1)  # -> [time, batch, classes]
            logits = F.log_softmax(logits, dim=2)
            if data['text'].size(0) == cfg.batch_size:
                for i in range(cfg.batch_size):
                    if i == 0:
                        text = data['text'][i][0:data['length_text'][i]].cuda()
                    else:
                        text = torch.cat(
                            [text, data['text'][i][0:data['length_text'][i]].cuda()])
            else:
                continue  # skip the last, incomplete batch
            loss = 0.0
            for i in range(logits.size(1)):
                loss += loss_fn(logits[:data['length_wave'][i], i:i + 1, :],
                                data['text'][i][0:data['length_text'][i]].cuda(),
                                data['length_wave'][i],
                                data['length_text'][i])
            loss /= logits.size(1)
            # `scheduler` is used as the optimizer here (see the usage sketch
            # after this function).
            scheduler.zero_grad()
            loss.backward()
            scheduler.step()
            _loss += loss.data
            if epoch == 0 and step_cnt == 10:
                writer.add_scalar('train/loss', _loss, epoch)
                train_loss_list.append(float(_loss))
            if step_cnt % int(1200 / cfg.batch_size) == 10:
                print("Epoch", epoch, ", train step", step_cnt, "/",
                      len(train_loader), ", loss: ",
                      round(float(_loss.data / step_cnt), 5))
                torch.save(model.state_dict(), cfg.workdir + '/weights/last.pth')
                if float(_loss.data / step_cnt) < 0.7:
                    # TODO: get the correct evaluate results.
                    for i in range(logits.size(1)):
                        logit = logits[:data['length_wave'][i], i:i + 1, :]
                        beam_results, beam_scores, timesteps, out_lens = decoder.decode(
                            logit.permute(1, 0, 2))
                        voc = np.tile(vocabulary, (cfg.batch_size, 1))
                        pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                        text_np = np.take(
                            voc,
                            data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))
                        tp, pred, pos = utils.evalutes(utils.cvt_np2string([pred]),
                                                       utils.cvt_np2string([text_np]))
                        _tp += tp
                        _pred += pred
                        _pos += pos
                    f1 = 2 * _tp / (_pred + _pos + 1e-10)
                    print(" Train tp:", _tp, ",pred:", _pred, ",pos:", _pos, ",f1:", f1)
            step_cnt += 1
        _loss /= len(train_loader)
        writer.add_scalar('train/loss', _loss, epoch)
        train_loss_list.append(float(_loss))
        torch.cuda.empty_cache()
        model = pruning(model, cfg.sparse_mode)
        sparsity = cal_sparsity(model)
        print(sparsity)
        loss_val = validate(val_loader, model, loss_fn)
        writer.add_scalar('val/loss', loss_val, epoch)
        val_loss_list.append(float(loss_val))
        model.train()
        if loss_val < best_loss:
            not_better_cnt = 0
            torch.save(model.state_dict(), cfg.workdir + '/weights/best.pth')
            print("saved", cfg.workdir + '/weights/best.pth', not_better_cnt)
            best_loss = loss_val
        else:
            not_better_cnt += 1
        if not_better_cnt > 3:
            # Early stop: dump the loss curves once validation stops improving.
            write_excel(os.path.join(cfg.work_root, cfg.save_excel),
                        cfg.exp_name, train_loss_list, val_loss_list)
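# How this train() is presumably driven (a sketch; build_model, train_loader,
# and val_loader are hypothetical stand-ins). Note that the second argument is
# named `scheduler` but is stepped immediately after loss.backward(), so it
# must be the optimizer itself, not an LR scheduler.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = build_model().cuda()                       # hypothetical constructor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CTCLoss(blank=27)                     # blank id matches the decoder's blank_id
writer = SummaryWriter(cfg.workdir)
train(train_loader, optimizer, model, loss_fn, val_loader, writer=writer)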
def test_acc_cmodel(val_loader, model, loss_fn):
    decoder_vocabulary = utils.Data.decoder_vocabulary  # "_abcdefghijklmopqrstuvwxyz_"
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True)
    model.eval()
    _loss = 0.0
    _loss2 = 0.0
    step_cnt = 0
    tps, preds, poses = 0, 0, 0
    tps2, preds2, poses2 = 0, 0, 0
    f_cnt = 0
    with torch.no_grad():
        for data in val_loader:
            wave = data['wave'].cuda()  # [1, 128, 109]
            # To dump the inputs for the C model instead:
            # np.savetxt("/zhzhao/dataset/VCTK/c_model_input_txt/" + str(f_cnt) + ".txt",
            #            data['wave'].flatten())
            logits = model(wave)
            # Load the C-model output for the same utterance and score both.
            logits_cmodel = torch.from_numpy(
                np.loadtxt("/zhzhao/dataset/VCTK/c_model_output_txt/" + str(f_cnt) + ".txt")
                .reshape((1, 28, 720)))
            f_cnt += 1
            logits = logits.permute(2, 0, 1)  # -> [time, batch, classes]
            logits_cmodel = logits_cmodel.permute(2, 0, 1)
            print(logits)
            print(logits_cmodel)
            logits = F.log_softmax(logits, dim=2)
            logits_cmodel = F.log_softmax(logits_cmodel, dim=2)
            if data['text'].size(0) == 1:  # the batch size is fixed to 1 here
                text = data['text'][0][0:data['length_text'][0]].cuda()
            else:
                continue
            loss = loss_fn(logits, text, data['length_wave'], data['length_text'])
            loss2 = loss_fn(logits_cmodel, text, data['length_wave'], data['length_text'])
            _loss += loss.data
            _loss2 += loss2.data
            # Score the PyTorch model output.
            for i in range(logits.size(1)):
                logit = logits[:data['length_wave'][i], i:i + 1, :]
                beam_results, beam_scores, timesteps, out_lens = decoder.decode(
                    logit.permute(1, 0, 2))
                voc = np.tile(vocabulary, (1, 1))
                pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                text_np = np.take(
                    voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))
                tp, pred, pos = utils.evalutes(utils.cvt_np2string([pred]),
                                               utils.cvt_np2string([text_np]))
                tps += tp
                preds += pred
                poses += pos
            f1 = 2 * tps / (preds + poses + 1e-10)
            # Score the C-model output the same way.
            for i in range(logits_cmodel.size(1)):
                logit = logits_cmodel[:data['length_wave'][i], i:i + 1, :]
                beam_results, beam_scores, timesteps, out_lens = decoder.decode(
                    logit.permute(1, 0, 2))
                voc = np.tile(vocabulary, (1, 1))
                pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                text_np = np.take(
                    voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))
                tp2, pred2, pos2 = utils.evalutes(utils.cvt_np2string([pred]),
                                                  utils.cvt_np2string([text_np]))
                tps2 += tp2
                preds2 += pred2
                poses2 += pos2
            f12 = 2 * tps2 / (preds2 + poses2 + 1e-10)
            step_cnt += 1
            print("Val step", step_cnt, "/", len(val_loader), ", loss: ",
                  round(float(_loss.data / step_cnt), 5))
            print("Val tps:", tps, ",preds:", preds, ",poses:", poses, ",f1:", f1)
            print("C Model Val step", step_cnt, "/", len(val_loader), ", loss: ",
                  round(float(_loss2.data / step_cnt), 5))
            print("C Model Val tps:", tps2, ",preds:", preds2, ",poses:", poses2, ",f1:", f12)
            print(" ")
    return f1, _loss / len(val_loader), tps, preds, poses
def train(train_loader, scheduler, model, loss_fn, val_loader, writer=None):
    decoder_vocabulary = utils.Data.decoder_vocabulary  # "_abcdefghijklmopqrstuvwxyz_"
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True)
    train_loss_list = list()
    val_loss_list = list()
    best_loss = float('inf')
    for epoch in range(cfg.epochs):
        print(f'Training epoch {epoch}')
        _loss = 0.0
        step_cnt = 0
        _tp, _pred, _pos = 0, 0, 0
        for data in train_loader:
            wave = data['wave'].cuda()  # [batch, 128, length], e.g. [1, 128, 109]
            # Re-apply the pruning mask every 10 steps.
            if step_cnt % 10 == 0:
                model = pruning(model, cfg.sparse_mode)
            model.train()
            if epoch == 0 and step_cnt == 0:
                # Record the pre-training validation loss as the baseline.
                loss_val = validate(val_loader, model, loss_fn)
                writer.add_scalar('val/loss', loss_val, epoch)
                best_loss = loss_val
                not_better_cnt = 0
                torch.save(model.state_dict(), cfg.workdir + '/weights/best.pth')
                print("saved", cfg.workdir + '/weights/best.pth', not_better_cnt)
                val_loss_list.append(float(loss_val))
                model.train()
            # An alternative would be to pack the padded batch:
            # wave = rnn_utils.pack_padded_sequence(wave, data['length_wave'],
            #                                       batch_first=True, enforce_sorted=False)
            logits = model(wave)
            # Mask out frames beyond each sample's true length.
            mask = torch.zeros_like(logits)
            for n in range(len(data['length_wave'])):
                mask[n, :, :data['length_wave'][n]] = 1
            logits *= mask
            logits = logits.permute(2, 0, 1)  # -> [time, batch, classes]
            logits = F.log_softmax(logits, dim=2)
            if data['text'].size(0) == cfg.batch_size:
                for i in range(cfg.batch_size):
                    if i == 0:
                        text = data['text'][i][0:data['length_text'][i]].cuda()
                    else:
                        text = torch.cat(
                            [text, data['text'][i][0:data['length_text'][i]].cuda()])
            else:
                continue  # skip the last, incomplete batch
            loss = 0.0
            for i in range(logits.size(1)):
                loss += loss_fn(logits[:data['length_wave'][i], i:i + 1, :],
                                data['text'][i][0:data['length_text'][i]].cuda(),
                                data['length_wave'][i],
                                data['length_text'][i])
            loss /= logits.size(1)
            # `scheduler` is used as the optimizer here.
            scheduler.zero_grad()
            loss.backward()
            scheduler.step()
            _loss += loss.data
            if epoch == 0 and step_cnt == 10:
                writer.add_scalar('train/loss', _loss, epoch)
                train_loss_list.append(float(_loss))
            if step_cnt % int(12000 / cfg.batch_size) == 10:
                print("Epoch", epoch, ", train step", step_cnt, "/",
                      len(train_loader), ", loss: ",
                      round(float(_loss.data / step_cnt), 5))
                torch.save(model.state_dict(), cfg.workdir + '/weights/last.pth')
                if float(_loss.data / step_cnt) < 0.7:
                    # TODO: get the correct evaluate results.
                    for i in range(logits.size(1)):
                        logit = logits[:data['length_wave'][i], i:i + 1, :]
                        beam_results, beam_scores, timesteps, out_lens = decoder.decode(
                            logit.permute(1, 0, 2))
                        voc = np.tile(vocabulary, (cfg.batch_size, 1))
                        pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                        text_np = np.take(
                            voc,
                            data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))
                        tp, pred, pos = utils.evalutes(utils.cvt_np2string([pred]),
                                                       utils.cvt_np2string([text_np]))
                        _tp += tp
                        _pred += pred
                        _pos += pos
                    f1 = 2 * _tp / (_pred + _pos + 1e-10)
                    print(" Train tp:", _tp, ",pred:", _pred, ",pos:", _pos, ",f1:", f1)
            step_cnt += 1
        _loss /= len(train_loader)
        writer.add_scalar('train/loss', _loss, epoch)
        train_loss_list.append(float(_loss))
        torch.cuda.empty_cache()
        model = pruning(model, cfg.sparse_mode)
        sparsity = cal_sparsity(model)
        print(sparsity)
        loss_val = validate(val_loader, model, loss_fn)
        writer.add_scalar('val/loss', loss_val, epoch)
        val_loss_list.append(float(loss_val))
        model.train()
        if loss_val < best_loss:
            not_better_cnt = 0
            torch.save(model.state_dict(), cfg.workdir + '/weights/best.pth')
            print("saved", cfg.workdir + '/weights/best.pth', not_better_cnt)
            best_loss = loss_val
        else:
            not_better_cnt += 1
        if not_better_cnt > 1:
            # Early stop: dump the loss curves and quit once validation stops
            # improving for two consecutive epochs.
            write_excel(os.path.join(cfg.work_root, cfg.save_excel),
                        cfg.exp_name, train_loss_list, val_loss_list)
            exit()
def test_acc(val_loader, model, loss_fn):
    decoder_vocabulary = utils.Data.decoder_vocabulary  # "_abcdefghijklmopqrstuvwxyz_"
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True)
    model.eval()
    _loss = 0.0
    step_cnt = 0
    tps, preds, poses = 0, 0, 0
    with torch.no_grad():
        for data in val_loader:
            wave = data['wave'].cuda()  # [batch, 128, length]
            logits = model(wave)
            logits = logits.permute(2, 0, 1)  # -> [time, batch, classes]
            logits = F.log_softmax(logits, dim=2)
            if data['text'].size(0) == cfg.batch_size:
                for i in range(cfg.batch_size):
                    if i == 0:
                        text = data['text'][i][0:data['length_text'][i]].cuda()
                    else:
                        text = torch.cat(
                            [text, data['text'][i][0:data['length_text'][i]].cuda()])
            else:
                continue
            loss = loss_fn(logits, text, data['length_wave'], data['length_text'])
            _loss += loss.data
            for i in range(logits.size(1)):
                logit = logits[:data['length_wave'][i], i:i + 1, :]
                beam_results, beam_scores, timesteps, out_lens = decoder.decode(
                    logit.permute(1, 0, 2))
                voc = np.tile(vocabulary, (cfg.batch_size, 1))
                pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                text_np = np.take(
                    voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))
                tp, pred, pos = utils.evalutes(utils.cvt_np2string([pred]),
                                               utils.cvt_np2string([text_np]))
                tps += tp
                preds += pred
                poses += pos
            f1 = 2 * tps / (preds + poses + 1e-10)
            step_cnt += 1
            print("Val step", step_cnt, "/", len(val_loader), ", loss: ",
                  round(float(_loss.data / step_cnt), 5))
            print("Val tps:", tps, ",preds:", preds, ",poses:", poses, ",f1:", f1)
    return f1, _loss / len(val_loader), tps, preds, poses
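# Evaluating a saved checkpoint with test_acc() (a sketch; the loader, model,
# and weight path are assumptions). The loss function must accept
# (log_probs, targets, input_lengths, target_lengths), which matches
# torch.nn.CTCLoss with the decoder's blank id.
import torch.nn as nn

model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'))
f1, val_loss, tps, preds, poses = test_acc(val_loader, model, nn.CTCLoss(blank=27))
print('f1:', f1, 'val loss:', float(val_loss))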
def main(_):
  utils.load(FLAGS.config_path)
  os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.device)
  test_dataset = dataset.create(FLAGS.dataset_path, repeat=False, batch_size=1)
  waves = tf.reshape(tf.sparse.to_dense(test_dataset[0]),
                     shape=[1, -1, utils.Data.num_channel])
  labels = tf.sparse.to_dense(test_dataset[1])
  sequence_length = tf.cast(test_dataset[2], tf.int32)
  vocabulary = tf.constant(utils.Data.vocabulary)
  labels = tf.gather(vocabulary, labels)
  logits = wavenet.bulid_wavenet(waves, len(utils.Data.vocabulary))
  decodes, _ = tf.nn.ctc_beam_search_decoder(
      tf.transpose(logits, perm=[1, 0, 2]), sequence_length, merge_repeated=False)
  outputs = tf.gather(vocabulary, tf.sparse.to_dense(decodes[0]))
  save = tf.train.Saver()

  # Resume the evaluation log if one already exists.
  evalutes = {}
  if os.path.exists(FLAGS.ckpt_dir + '/evalute.json'):
    evalutes = json.load(open(FLAGS.ckpt_dir + '/evalute.json', encoding='utf-8'))

  config = tf.ConfigProto(allow_soft_placement=True)
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    status = 0
    while True:
      # Find the checkpoint with the highest global step in ckpt_dir.
      filepaths = glob.glob(FLAGS.ckpt_dir + '/*.index')
      filepaths.sort()
      filepaths.reverse()
      max_uid = 0
      for filepath in filepaths:
        model_path = os.path.splitext(filepath)[0]
        uid = os.path.split(model_path)[-1]
        if max_uid <= int(uid.split("-")[1]):
          max_uid = int(uid.split("-")[1])
          max_uid_full = uid
          max_model_path = model_path
      status = 2
      sess.run(tf.global_variables_initializer())
      sess.run(test_dataset[-1])
      save.restore(sess, max_model_path)
      print(tf.train.latest_checkpoint(FLAGS.ckpt_dir))
      # save.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))

      evalutes[max_uid_full] = {}
      tps, preds, poses, count = 0, 0, 0, 0
      while True:
        try:
          count += 1
          y, y_ = sess.run((labels, outputs))
          y = utils.cvt_np2string(y)
          y_ = utils.cvt_np2string(y_)
          tp, pred, pos = utils.evalutes(y_, y)
          tps += tp
          preds += pred
          poses += pos
        except tf.errors.OutOfRangeError:
          # The one-pass test dataset is exhausted.
          break
      evalutes[max_uid_full]['tp'] = tps
      evalutes[max_uid_full]['pred'] = preds
      evalutes[max_uid_full]['pos'] = poses
      evalutes[max_uid_full]['f1'] = 2 * tps / (preds + poses + 1e-20)
      json.dump(evalutes,
                open(FLAGS.ckpt_dir + '/evalute.json', mode='w', encoding='utf-8'))
      evalute = evalutes[max_uid_full]
      glog.info('Evalute %s: tp=%d, pred=%d, pos=%d, f1=%f.' %
                (max_uid_full, evalute['tp'], evalute['pred'],
                 evalute['pos'], evalute['f1']))
      if status == 1:
        time.sleep(60)
      status = 1