def main(_):
    """Transcribe a single wave file with a trained WaveNet checkpoint.

    Builds the inference graph, restores weights from FLAGS.ckpt_path, runs
    CTC beam-search decoding on FLAGS.input_path and logs the transcript.
    Returns 0 on success, -1 when the checkpoint is missing.
    """
    # Every TF checkpoint has an '.index' companion file; use it as an
    # existence probe for the whole checkpoint.
    if not os.path.exists(FLAGS.ckpt_path + '.index'):
        glog.error('%s was not found.' % FLAGS.ckpt_path)
        return -1

    # The .json beside the checkpoint carries vocabulary/channel metadata.
    utils.load(FLAGS.ckpt_path + '.json')
    vocab = tf.constant(utils.Data.vocabulary)
    wave_input = tf.placeholder(tf.float32, [1, None, utils.Data.num_channel])
    seq_len = tf.placeholder(tf.int32, [None])

    logits = wavenet.bulid_wavenet(wave_input,
                                   len(utils.Data.vocabulary),
                                   is_training=False)
    # ctc_beam_search_decoder expects time-major logits: [time, batch, classes].
    decoded, _ = tf.nn.ctc_beam_search_decoder(
        tf.transpose(logits, perm=[1, 0, 2]),
        seq_len,
        merge_repeated=False)
    outputs = tf.gather(vocab, tf.sparse.to_dense(decoded[0]))
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, FLAGS.ckpt_path)
        wave = utils.read_wave(FLAGS.input_path)
        feed = {wave_input: [wave], seq_len: [wave.shape[0]]}
        output = utils.cvt_np2string(sess.run(outputs, feed_dict=feed))[0]
        glog.info('%s: %s.', FLAGS.input_path, output)
    return 0
def main(_):
  """Decode one wave file using pretrained weights plus an optional
  fine-tuned checkpoint, inferring sequence length from non-zero frames."""
  class_names = tf.constant(utils.Data.class_names)
  inputs = tf.placeholder(tf.float32, [1, None, utils.Data.channels])
  # A frame counts as valid when its channel sum is non-zero.
  frame_is_valid = tf.not_equal(tf.reduce_sum(inputs, axis=2), 0.)
  seq_len = tf.reduce_sum(tf.cast(frame_is_valid, tf.int32), axis=1)

  logits = wavenet.bulid_wavenet(inputs, len(utils.Data.class_names), is_training=False)
  # Beam search wants time-major logits: [time, batch, classes].
  decodes, _ = tf.nn.ctc_beam_search_decoder(tf.transpose(logits, perm=[1, 0, 2]), seq_len, merge_repeated=False)
  # Decoded ids are offset by +1 before the lookup — presumably index 0 of
  # class_names is reserved (blank/padding); TODO confirm against utils.Data.
  shifted_ids = tf.sparse.to_dense(decodes[0]) + 1
  outputs = tf.gather(class_names, shifted_ids)
  restore = utils.restore_from_pretrain(FLAGS.pretrain_dir)
  save = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(restore)
    # A fine-tuned checkpoint, when present, overrides the pretrained weights.
    if os.path.exists(FLAGS.checkpoint_dir) and os.listdir(FLAGS.checkpoint_dir):
      save.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
    output = utils.cvt_np2string(sess.run(outputs, feed_dict={inputs: [utils.read_wave(FLAGS.input_path)]}))[0]
    glog.info('%s: %s.', FLAGS.input_path, output)
# Example #3
# 0
def main(_):
    """Train the WaveNet CTC model, logging stats and snapshotting checkpoints.

    Reads dataset/optimizer settings from FLAGS, accumulates loss and
    tp/pred/pos counts between display points, and saves a checkpoint every
    FLAGS.snapshot steps.  Runs until externally interrupted.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.device)
    utils.load(FLAGS.config_path)
    global_step = tf.train.get_or_create_global_step()
    train_dataset = dataset.create(FLAGS.dataset_path,
                                   FLAGS.batch_size,
                                   repeat=True)

    # bug tensorflow!!!  the  train_dataset[0].shape[0] != FLAGS.batch_size once in a while
    # waves = tf.reshape(tf.sparse.to_dense(train_dataset[0]), shape=[FLAGS.batch_size, -1, utils.Data.num_channel])
    waves = tf.sparse.to_dense(train_dataset[0])
    # Use the runtime batch size rather than FLAGS.batch_size (see note above).
    waves = tf.reshape(waves, [tf.shape(waves)[0], -1, utils.Data.num_channel])

    labels = tf.cast(train_dataset[1], tf.int32)
    sequence_length = tf.cast(train_dataset[2], tf.int32)
    logits = wavenet.bulid_wavenet(waves,
                                   len(utils.Data.vocabulary),
                                   is_training=True)
    loss = tf.reduce_mean(
        tf.nn.ctc_loss(labels, logits, sequence_length, time_major=False))

    # Decode predictions back to characters for on-the-fly evaluation.
    vocabulary = tf.constant(utils.Data.vocabulary)
    decodes, _ = tf.nn.ctc_beam_search_decoder(tf.transpose(logits, [1, 0, 2]),
                                               sequence_length,
                                               merge_repeated=False)
    outputs = tf.gather(vocabulary, tf.sparse.to_dense(decodes[0]))
    labels = tf.gather(vocabulary, tf.sparse.to_dense(labels))

    # Run collected update ops (e.g. batch-norm moving averages) before each
    # optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimize = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate).minimize(
                loss=loss, global_step=global_step)

    save = tf.train.Saver(max_to_keep=1000)
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        # train_dataset[-1] is presumably the dataset iterator initializer op
        # — TODO confirm against dataset.create.
        sess.run(train_dataset[-1])
        # if os.path.exists(FLAGS.pretrain_dir) and len(os.listdir(FLAGS.pretrain_dir)) > 0:
        #   save.restore(sess, tf.train.latest_checkpoint(FLAGS.pretrain_dir))
        ckpt_dir = os.path.split(FLAGS.ckpt_path)[0]
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        if len(os.listdir(ckpt_dir)) > 0:
            # Resume from the newest checkpoint in the directory.
            save.restore(sess, tf.train.latest_checkpoint(ckpt_dir))

        # Running sums between display points: loss, true positives,
        # predicted-token count and ground-truth (positive) count.
        losses, tps, preds, poses = 0, 0, 0, 0
        while True:
            # uid (train_dataset[3]) is fetched but unused here.
            gp, ll, uid, ot, ls, _ = sess.run(
                (global_step, labels, train_dataset[3], outputs, loss,
                 optimize))
            tp, pred, pos = utils.evalutes(utils.cvt_np2string(ot),
                                           utils.cvt_np2string(ll))
            tps += tp
            losses += ls
            preds += pred
            poses += pos
            if gp % FLAGS.display == 0:
                # NOTE(review): the format string labels say pos then pred,
                # but the arguments are preds then poses — the two columns
                # appear swapped in the log output; confirm intended order.
                glog.info(
                    "Step %d: loss=%f, tp=%d, pos=%d, pred=%d, f1=%f." %
                    (gp, losses if gp == 0 else
                     (losses / FLAGS.display), tps, preds, poses, 2 * tps /
                     (preds + poses + 1e-10)))
                losses, tps, preds, poses = 0, 0, 0, 0
            if (gp + 1) % FLAGS.snapshot == 0 and gp != 0:
                save.save(sess, FLAGS.ckpt_path, global_step=global_step)
# Example #4
# 0
def train(train_loader, scheduler, model, loss_fn, val_loader, writer=None):
    """Train with CTC loss under pattern-based pruning.

    On the first epoch (when cfg.sparse_mode == 'find_retrain') it clusters
    weight patterns with k-means and stores them in cfg.fd_rtn_pattern_set;
    during training the model is re-pruned every 10 steps, validated after
    each epoch, and best/last weights are saved under cfg.workdir.

    Args:
        train_loader: yields dicts with 'wave', 'text', 'length_wave' and
            'length_text' entries.
        scheduler: optimizer-like object exposing zero_grad()/step().
        model: network producing [batch, classes, time] logits.
        loss_fn: CTC-style loss(log_probs, targets, input_len, target_len).
        val_loader: loader passed to validate().
        writer: summary writer; must not be None (used unconditionally).
    """
    decoder_vocabulary = utils.Data.decoder_vocabulary
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True
    )
    train_loss_list = list()
    val_loss_list = list()

    best_loss = float('inf')
    for epoch in range(cfg.epochs):
        print(f'Training epoch {epoch}')
        _loss = 0.0
        step_cnt = 0

        if epoch == 0:
            if cfg.sparse_mode == 'find_retrain':
                print("find_pattern_start")
                cfg.fd_rtn_pattern_set = dict()
                name_list = list()
                para_list = list()
                for name, para in model.named_parameters():
                    name_list.append(name)
                    para_list.append(para)

                cnt = 0

                patterns_dir = os.path.join(cfg.workdir, 'patterns')
                if not os.path.exists(patterns_dir):
                    os.mkdir(patterns_dir)

                weights_dir = os.path.join(cfg.workdir, 'weights_save')
                if not os.path.exists(weights_dir):
                    os.mkdir(weights_dir)

                if cfg.layer_or_model_wise == "l":
                    # Per-layer pattern search: skip batch-norm and bias params.
                    for i, name in enumerate(name_list):
                        if name.split(".")[-2] != "bn" \
                            and name.split(".")[-2] != "bn2" \
                            and name.split(".")[-2] != "bn3" \
                            and name.split(".")[-1] != "bias":
                            raw_w = para_list[i]
                            print(name, raw_w.size())
                            if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                                cfg.fd_rtn_pattern_set[name], _ = find_top_k_by_kmeans(
                                    raw_w, cfg.pattern_num, cfg.pattern_shape, cfg.pattern_nnz, stride=cfg.pattern_shape)
                            elif raw_w.size(0) == 128 and raw_w.size(1) == 40:
                                cfg.fd_rtn_pattern_set[name], _ = find_top_k_by_kmeans(
                                    raw_w, cfg.pattern_num, cfg.pattern_shape, cfg.pattern_nnz, stride=cfg.pattern_shape)
                            elif raw_w.size(0) == 28 and raw_w.size(1) == 128:
                                # Pad the 28-row output layer to 32 rows so it
                                # tiles evenly into the pattern shape.
                                raw_w_pad = torch.cat([raw_w, torch.zeros(4, raw_w.size(1), raw_w.size(2)).cuda()], 0)
                                cfg.fd_rtn_pattern_set[name], _ = find_top_k_by_kmeans(
                                    raw_w_pad, cfg.pattern_num, cfg.pattern_shape, cfg.pattern_nnz, stride=cfg.pattern_shape)
                            print(name, cfg.fd_rtn_pattern_set[name].size())
                            # Dump patterns for the offline/hardware toolchain.
                            pattern_save = np.array(cfg.fd_rtn_pattern_set[name].cpu()).transpose((0, 1, 3, 2))
                            np.savetxt(os.path.join(patterns_dir, name + '.txt'), pattern_save.flatten())
                elif cfg.layer_or_model_wise == "m":
                    # Model-wise: concatenate all square 128x128 weights and
                    # learn one shared pattern set for the whole model.
                    for i, name in enumerate(name_list):
                        if name.split(".")[-2] != "bn" \
                            and name.split(".")[-2] != "bn2" \
                            and name.split(".")[-2] != "bn3" \
                            and name.split(".")[-1] != "bias":
                            raw_w = para_list[i]
                            if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                                if cnt == 0:
                                    raw_w_all = raw_w
                                else:
                                    raw_w_all = torch.cat([raw_w_all, raw_w], 2)
                                cnt += 1
                    cfg.fd_rtn_pattern_set['all'], _ = find_top_k_by_kmeans(
                        raw_w_all, cfg.pattern_num, cfg.pattern_shape, cfg.pattern_nnz, stride=cfg.pattern_shape)
                print("find_pattern_end")

        _tp, _pred, _pos = 0, 0, 0
        for data in train_loader:
            wave = data['wave'].cuda()  # [1, 128, 109]
            if step_cnt % 10 == 0:
                # Re-apply the sparsity mask periodically so pruned weights
                # revived by optimizer updates are zeroed again.
                model = pruning(model, cfg.sparse_mode)
                model.train()
            if epoch == 0 and step_cnt == 0:
                # Baseline validation before any training step.
                loss_val = validate(val_loader, model, loss_fn)
                writer.add_scalar('val/loss', loss_val, epoch)
                best_loss = loss_val
                not_better_cnt = 0
                torch.save(model.state_dict(), cfg.workdir+'/weights/best.pth')
                print("saved", cfg.workdir+'/weights/best.pth', not_better_cnt)
                val_loss_list.append(float(loss_val))
                model.train()

            logits = model(wave)
            # Zero out frames beyond each sample's true length.
            mask = torch.zeros_like(logits)
            for n in range(len(data['length_wave'])):
                # BUGFIX: mask sample n only.  The original wrote
                # `mask[:, :, :len]`, which unmasked EVERY sample up to the
                # batch-max length, so per-sample masking never happened.
                mask[n, :, :data['length_wave'][n]] = 1
            logits *= mask

            # CTC expects time-major [time, batch, classes] log-probabilities.
            logits = logits.permute(2, 0, 1)
            logits = F.log_softmax(logits, dim=2)
            if data['text'].size(0) == cfg.batch_size:
                # Concatenate the per-sample targets (unused below, but kept
                # to preserve behavior of the original loop).
                for i in range(cfg.batch_size):
                    if i == 0:
                        text = data['text'][i][0:data['length_text'][i]].cuda()
                    else:
                        text = torch.cat([text,
                                    data['text'][i][0: data['length_text'][i]].cuda()])
            else:
                continue  # skip ragged final batch

            # Per-sample CTC loss, averaged over the batch.
            loss = 0.0
            for i in range(logits.size(1)):
                loss += loss_fn(logits[:data['length_wave'][i], i:i+1, :], data['text'][i][0:data['length_text'][i]].cuda(), data['length_wave'][i], data['length_text'][i])
            loss /= logits.size(1)
            scheduler.zero_grad()
            loss.backward()
            scheduler.step()
            _loss += loss.data

            if epoch == 0 and step_cnt == 10:
                writer.add_scalar('train/loss', _loss, epoch)
                train_loss_list.append(float(_loss))

            if step_cnt % int(1200/cfg.batch_size) == 10:
                print("Epoch", epoch,
                        ", train step", step_cnt, "/", len(train_loader),
                        ", loss: ", round(float(_loss.data/step_cnt), 5))
                torch.save(model.state_dict(), cfg.workdir+'/weights/last.pth')

                if float(_loss.data/step_cnt) < 0.7:
                    # TODO get the correct evaluate results
                    for i in range(logits.size(1)):
                        logit = logits[:data['length_wave'][i], i:i+1, :]
                        beam_results, beam_scores, timesteps, out_lens = decoder.decode(logit.permute(1, 0, 2))

                        voc = np.tile(vocabulary, (cfg.batch_size, 1))
                        pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                        text_np = np.take(voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))

                        pred = [pred]
                        text_np = [text_np]

                        tp, pred, pos = utils.evalutes(utils.cvt_np2string(pred), utils.cvt_np2string(text_np))
                        _tp += tp
                        _pred += pred
                        _pos += pos
                    f1 = 2 * _tp / (_pred + _pos + 1e-10)

                    print("             Train tp:", _tp, ",pred:", _pred, ",pos:", _pos, ",f1:", f1)
            step_cnt += 1

        _loss /= len(train_loader)
        writer.add_scalar('train/loss', _loss, epoch)
        train_loss_list.append(float(_loss))
        torch.cuda.empty_cache()

        # Prune once more before validating so the reported metrics match the
        # sparse model that gets saved.
        model = pruning(model, cfg.sparse_mode)
        sparsity = cal_sparsity(model)
        print(sparsity)
        loss_val = validate(val_loader, model, loss_fn)
        writer.add_scalar('val/loss', loss_val, epoch)
        val_loss_list.append(float(loss_val))

        model.train()

        if loss_val < best_loss:
            not_better_cnt = 0
            torch.save(model.state_dict(), cfg.workdir+f'/weights/best.pth')
            print("saved", cfg.workdir+f'/weights/best.pth', not_better_cnt)
            best_loss = loss_val
        else:
            not_better_cnt += 1

        if not_better_cnt > 3:
            # Plateau for >3 epochs: dump the loss curves to the spreadsheet.
            write_excel(os.path.join(cfg.work_root, cfg.save_excel),
                            cfg.exp_name, train_loss_list, val_loss_list)
# Example #5
# 0
def test_acc_cmodel(val_loader, model, loss_fn):
    """Compare the live PyTorch model against dumped C-model outputs.

    For every validation batch it decodes (a) the live model's logits and
    (b) logits previously dumped by the C implementation to text files, then
    prints loss and tp/pred/pos/F1 for both so the two can be compared.

    NOTE(review): C-model logits are read from a hard-coded absolute path
    keyed by a running file counter, so this only works on the machine that
    produced those dumps, and assumes val_loader iterates in the same order.

    Returns:
        (f1, mean_loss, tps, preds, poses) for the live model only.
    """
    decoder_vocabulary = utils.Data.decoder_vocabulary
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        #"_abcdefghijklmopqrstuvwxyz_",
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True
    )
    model.eval()
    _loss = 0.0   # running loss of the live model
    _loss2 = 0.0  # running loss of the C-model logits
    step_cnt = 0
    tps, preds, poses = 0, 0, 0
    tps2, preds2, poses2 = 0, 0, 0
    f_cnt = 0  # index of the next C-model dump file
    with torch.no_grad():
        for data in val_loader:
            wave = data['wave'].cuda()  # [1, 128, 109]
            # if 1:
            #     print(data['wave'].size())
            #     np.savetxt("/zhzhao/dataset/VCTK/c_model_input_txt/"+str(f_cnt)+".txt", data['wave'].flatten())
            #     print(f_cnt)
            logits = model(wave)
            # C-model output dumped as flat text, reshaped to [1, 28, 720].
            logits_cmodel = torch.from_numpy(np.loadtxt("/zhzhao/dataset/VCTK/c_model_output_txt/"+str(f_cnt)+".txt").reshape((1, 28, 720)))
            # print(logits_cmodel.reshape((28, 1, 720)))
            f_cnt += 1
            # To time-major [time, batch, classes] for CTC.
            logits = logits.permute(2, 0, 1)
            logits_cmodel = logits_cmodel.permute(2, 0, 1)
            print(logits)
            print(logits_cmodel)
            # logits_cmodel = logits_cmodel.permute(2, 0, 1)
            logits = F.log_softmax(logits, dim=2)
            logits_cmodel = F.log_softmax(logits_cmodel, dim=2)
            # Flatten target texts; batch size is fixed to 1 here, other
            # batch sizes are skipped entirely.
            if data['text'].size(0) == 1:
                for i in range(1):
                    if i == 0:
                        text = data['text'][i][0:data['length_text'][i]].cuda()
                        # print(data['text'].size())
                        # print(data['length_text'][i])
                    else:
                        text = torch.cat([text, 
                                    data['text'][i][0: data['length_text'][i]].cuda()])
            else:
                continue
            loss = loss_fn(logits, text, data['length_wave'], data['length_text'])
            loss2 = loss_fn(logits_cmodel, text, data['length_wave'], data['length_text'])
            _loss += loss.data
            _loss2 += loss2.data
            # print(_loss)
            # print(_loss2)
            # Decode and score the live model's predictions.
            for i in range(logits.size(1)):
                logit = logits[:data['length_wave'][i], i:i+1, :]
                beam_results, beam_scores, timesteps, out_lens = decoder.decode(logit.permute(1, 0, 2))

                voc = np.tile(vocabulary, (1, 1))
                pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                text_np = np.take(voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))

                pred = [pred]
                text_np = [text_np]
                # print(utils.cvt_np2string(pred))
                # print(utils.cvt_np2string(text_np))

                tp, pred, pos = utils.evalutes(utils.cvt_np2string(pred), utils.cvt_np2string(text_np))
                tps += tp
                preds += pred
                poses += pos
            f1 = 2 * tps / (preds + poses + 1e-10)

            # Decode and score the C-model's predictions the same way.
            for i in range(logits_cmodel.size(1)):
                logit = logits_cmodel[:data['length_wave'][i], i:i+1, :]
                beam_results, beam_scores, timesteps, out_lens = decoder.decode(logit.permute(1, 0, 2))

                voc = np.tile(vocabulary, (1, 1))
                pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                text_np = np.take(voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))

                pred = [pred]
                text_np = [text_np]
                # print(utils.cvt_np2string(pred))
                # print(utils.cvt_np2string(text_np))

                tp2, pred2, pos2 = utils.evalutes(utils.cvt_np2string(pred), utils.cvt_np2string(text_np))
                tps2 += tp2
                preds2 += pred2
                poses2 += pos2
            f12 = 2 * tps2 / (preds2 + poses2 + 1e-10)
            
            step_cnt += 1
            print("Val step", step_cnt, "/", len(val_loader),
                    ", loss: ", round(float(_loss.data/step_cnt), 5))
            print("Val tps:", tps, ",preds:", preds, ",poses:", poses, ",f1:", f1)
            print("C Model Val step", step_cnt, "/", len(val_loader),
                    ", loss: ", round(float(_loss2.data/step_cnt), 5))
            print("C Model Val tps:", tps2, ",preds:", preds2, ",poses:", poses2, ",f1:", f12)
            print(" ")
            # if f_cnt > 6 :
            #     break

    # NOTE(review): if every batch was skipped above, f1 is unbound here.
    return f1, _loss/len(val_loader), tps, preds, poses
# Example #6
# 0
def train(train_loader, scheduler, model, loss_fn, val_loader, writer=None):
    """Train with CTC loss and periodic pruning; early-exit on plateau.

    Prunes the model every 10 steps, validates after each epoch, saves
    best/last weights under cfg.workdir, and calls exit() after two
    consecutive epochs without validation improvement.

    Args:
        train_loader: yields dicts with 'wave', 'text', 'length_wave' and
            'length_text' entries.
        scheduler: optimizer-like object exposing zero_grad()/step().
        model: network producing [batch, classes, time] logits.
        loss_fn: CTC-style loss(log_probs, targets, input_len, target_len).
        val_loader: loader passed to validate().
        writer: summary writer; must not be None (used unconditionally).
    """
    decoder_vocabulary = utils.Data.decoder_vocabulary
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True
    )
    train_loss_list = list()
    val_loss_list = list()

    best_loss = float('inf')
    for epoch in range(cfg.epochs):
        print(f'Training epoch {epoch}')
        _loss = 0.0
        step_cnt = 0

        _tp, _pred, _pos = 0, 0, 0
        for data in train_loader:
            wave = data['wave'].cuda()  # [1, 128, 109]
            if step_cnt % 10 == 0:
                # Re-apply the sparsity mask periodically so pruned weights
                # revived by optimizer updates are zeroed again.
                model = pruning(model, cfg.sparse_mode)
                model.train()
            if epoch == 0 and step_cnt == 0:
                # Baseline validation before any training step.
                loss_val = validate(val_loader, model, loss_fn)
                writer.add_scalar('val/loss', loss_val, epoch)
                best_loss = loss_val
                not_better_cnt = 0
                torch.save(model.state_dict(), cfg.workdir+'/weights/best.pth')
                print("saved", cfg.workdir+'/weights/best.pth', not_better_cnt)
                val_loss_list.append(float(loss_val))
                model.train()

            logits = model(wave)
            # Zero out frames beyond each sample's true length.
            mask = torch.zeros_like(logits)
            for n in range(len(data['length_wave'])):
                # BUGFIX: mask sample n only.  The original wrote
                # `mask[:, :, :len]`, which unmasked EVERY sample up to the
                # batch-max length, so per-sample masking never happened.
                mask[n, :, :data['length_wave'][n]] = 1
            logits *= mask

            # CTC expects time-major [time, batch, classes] log-probabilities.
            logits = logits.permute(2, 0, 1)
            logits = F.log_softmax(logits, dim=2)
            if data['text'].size(0) == cfg.batch_size:
                # Concatenate per-sample targets (kept for parity with the
                # original loop, though the per-sample loss below re-slices).
                for i in range(cfg.batch_size):
                    if i == 0:
                        text = data['text'][i][0:data['length_text'][i]].cuda()
                    else:
                        text = torch.cat([text,
                                    data['text'][i][0: data['length_text'][i]].cuda()])
            else:
                continue  # skip ragged final batch

            # Per-sample CTC loss, averaged over the batch.
            loss = 0.0
            for i in range(logits.size(1)):
                loss += loss_fn(logits[:data['length_wave'][i], i:i+1, :], data['text'][i][0:data['length_text'][i]].cuda(), data['length_wave'][i], data['length_text'][i])
            loss /= logits.size(1)
            scheduler.zero_grad()
            loss.backward()
            scheduler.step()
            _loss += loss.data

            if epoch == 0 and step_cnt == 10:
                writer.add_scalar('train/loss', _loss, epoch)
                train_loss_list.append(float(_loss))

            if step_cnt % int(12000/cfg.batch_size) == 10:
                print("Epoch", epoch,
                        ", train step", step_cnt, "/", len(train_loader),
                        ", loss: ", round(float(_loss.data/step_cnt), 5))
                torch.save(model.state_dict(), cfg.workdir+'/weights/last.pth')

                if float(_loss.data/step_cnt) < 0.7:
                    # TODO get the correct evaluate results
                    for i in range(logits.size(1)):
                        logit = logits[:data['length_wave'][i], i:i+1, :]
                        beam_results, beam_scores, timesteps, out_lens = decoder.decode(logit.permute(1, 0, 2))

                        voc = np.tile(vocabulary, (cfg.batch_size, 1))
                        pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                        text_np = np.take(voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))

                        pred = [pred]
                        text_np = [text_np]

                        tp, pred, pos = utils.evalutes(utils.cvt_np2string(pred), utils.cvt_np2string(text_np))
                        _tp += tp
                        _pred += pred
                        _pos += pos
                    f1 = 2 * _tp / (_pred + _pos + 1e-10)

                    print("             Train tp:", _tp, ",pred:", _pred, ",pos:", _pos, ",f1:", f1)
            step_cnt += 1

        _loss /= len(train_loader)
        writer.add_scalar('train/loss', _loss, epoch)
        train_loss_list.append(float(_loss))
        torch.cuda.empty_cache()

        # Prune once more before validating so the reported metrics match
        # the sparse model that gets saved.
        model = pruning(model, cfg.sparse_mode)
        sparsity = cal_sparsity(model)
        print(sparsity)
        loss_val = validate(val_loader, model, loss_fn)
        writer.add_scalar('val/loss', loss_val, epoch)
        val_loss_list.append(float(loss_val))

        model.train()

        if loss_val < best_loss:
            not_better_cnt = 0
            torch.save(model.state_dict(), cfg.workdir+'/weights/best.pth')
            print("saved", cfg.workdir+'/weights/best.pth', not_better_cnt)
            best_loss = loss_val
        else:
            not_better_cnt += 1

        if not_better_cnt > 1:
            # Plateau: dump loss curves to the spreadsheet and stop.
            write_excel(os.path.join(cfg.work_root, cfg.save_excel),
                            cfg.exp_name, train_loss_list, val_loss_list)
            exit()
# Example #7
# 0
def test_acc(val_loader, model, loss_fn):
    """Evaluate the model over val_loader with CTC beam-search decoding.

    Accumulates loss and tp/pred/pos counts over all batches, prints a
    summary, and returns (f1, mean_loss, tps, preds, poses).

    NOTE(review): if every batch is skipped (text batch dim != cfg.batch_size)
    `f1` is never assigned and the final print raises NameError; likewise
    `_loss.data/step_cnt` divides by zero when step_cnt stays 0.
    """
    decoder_vocabulary = utils.Data.decoder_vocabulary
    vocabulary = utils.Data.vocabulary
    decoder = CTCBeamDecoder(
        decoder_vocabulary,
        #"_abcdefghijklmopqrstuvwxyz_",
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=27,
        log_probs_input=True
    )
    model.eval()
    _loss = 0.0
    step_cnt = 0
    tps, preds, poses = 0, 0, 0
    with torch.no_grad():
        for data in val_loader:
            wave = data['wave'].cuda()  # [1, 128, 109]
            logits = model(wave)
            # To time-major [time, batch, classes] log-probs for CTC.
            logits = logits.permute(2, 0, 1)
            logits = F.log_softmax(logits, dim=2)
            # Flatten the per-sample targets into one 1-D tensor.
            if data['text'].size(0) == cfg.batch_size:
                for i in range(cfg.batch_size):
                    if i == 0:
                        text = data['text'][i][0:data['length_text'][i]].cuda()
                        # print(data['text'].size())
                        # print(data['length_text'][i])
                    else:
                        text = torch.cat([text, 
                                    data['text'][i][0: data['length_text'][i]].cuda()])
            else:
                continue
            loss = loss_fn(logits, text, data['length_wave'], data['length_text'])
            _loss += loss.data
            # Decode each sample and accumulate tp/pred/pos counts.
            for i in range(logits.size(1)):
                logit = logits[:data['length_wave'][i], i:i+1, :]
                beam_results, beam_scores, timesteps, out_lens = decoder.decode(logit.permute(1, 0, 2))

                # np.take flattens voc, so the tile only repeats vocabulary;
                # indices are taken modulo-free from the flat array.
                voc = np.tile(vocabulary, (cfg.batch_size, 1))
                pred = np.take(voc, beam_results[0][0][:out_lens[0][0]].data.numpy())
                text_np = np.take(voc, data['text'][i][0:data['length_text'][i]].cpu().numpy().astype(int))

                pred = [pred]
                text_np = [text_np]
                # print(utils.cvt_np2string(pred))
                # print(utils.cvt_np2string(text_np))

                tp, pred, pos = utils.evalutes(utils.cvt_np2string(pred), utils.cvt_np2string(text_np))
                tps += tp
                preds += pred
                poses += pos
            f1 = 2 * tps / (preds + poses + 1e-10)
            
            step_cnt += 1
            # if cnt % 10 == 0:
    print("Val step", step_cnt, "/", len(val_loader),
            ", loss: ", round(float(_loss.data/step_cnt), 5))
    print("Val tps:", tps, ",preds:", preds, ",poses:", poses, ",f1:", f1)

    return f1, _loss/len(val_loader), tps, preds, poses
def main(_):
    """Continuously evaluate the newest checkpoint in FLAGS.ckpt_dir.

    Polls the checkpoint directory forever: restores the checkpoint with the
    highest step suffix, decodes the whole test dataset with CTC beam search,
    and records tp/pred/pos/f1 results in <ckpt_dir>/evalute.json.
    """
    utils.load(FLAGS.config_path)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.device)
    # with tf.device(FLAGS.device):
    test_dataset = dataset.create(FLAGS.dataset_path,
                                  repeat=False,
                                  batch_size=1)
    waves = tf.reshape(tf.sparse.to_dense(test_dataset[0]),
                       shape=[1, -1, utils.Data.num_channel])
    labels = tf.sparse.to_dense(test_dataset[1])
    sequence_length = tf.cast(test_dataset[2], tf.int32)
    vocabulary = tf.constant(utils.Data.vocabulary)
    labels = tf.gather(vocabulary, labels)
    logits = wavenet.bulid_wavenet(waves, len(utils.Data.vocabulary))
    # Beam search expects time-major logits: [time, batch, classes].
    decodes, _ = tf.nn.ctc_beam_search_decoder(tf.transpose(logits,
                                                            perm=[1, 0, 2]),
                                               sequence_length,
                                               merge_repeated=False)
    outputs = tf.gather(vocabulary, tf.sparse.to_dense(decodes[0]))
    save = tf.train.Saver()

    # Previously computed results, keyed by checkpoint id ("name-STEP").
    evalutes = {}
    if os.path.exists(FLAGS.ckpt_dir + '/evalute.json'):
        evalutes = json.load(
            open(FLAGS.ckpt_dir + '/evalute.json', encoding='utf-8'))

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        status = 0
        while True:
            filepaths = glob.glob(FLAGS.ckpt_dir + '/*.index')
            filepaths.sort()
            filepaths.reverse()
            # NOTE(review): this assignment is dead — the loop below
            # immediately rebinds filepath.
            filepath = filepaths[0]
            # Pick the checkpoint with the highest numeric step suffix.
            max_uid = 0
            for filepath in filepaths:
                model_path = os.path.splitext(filepath)[0]
                uid = os.path.split(model_path)[-1]
                if max_uid <= int(uid.split("-")[1]):
                    max_uid = int(uid.split("-")[1])
                    max_uid_full = uid
                    max_model_path = model_path
                    # print(max_uid)
            status = 2
            sess.run(tf.global_variables_initializer())
            # Re-initialize the dataset iterator for a fresh pass.
            sess.run(test_dataset[-1])
            save.restore(sess, max_model_path)
            #   sa print(tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            #  ve.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            evalutes[max_uid_full] = {}
            tps, preds, poses, count = 0, 0, 0, 0
            while True:
                try:
                    count += 1
                    y, y_ = sess.run((labels, outputs))
                    y = utils.cvt_np2string(y)
                    y_ = utils.cvt_np2string(y_)
                    tp, pred, pos = utils.evalutes(y_, y)
                    tps += tp
                    preds += pred
                    poses += pos
                #  if count % 1000 == 0:
                #    glog.info('processed %d: tp=%d, pred=%d, pos=%d.' % (count, tps, preds, poses))
                # NOTE(review): bare except is used to detect end-of-dataset
                # (tf.errors.OutOfRangeError) but also swallows every other
                # failure, including KeyboardInterrupt — narrow if possible.
                except:
                    #  if count % 1000 != 0:
                    #    glog.info('processed %d: tp=%d, pred=%d, pos=%d.' % (count, tps, preds, poses))
                    break

            evalutes[max_uid_full]['tp'] = tps
            evalutes[max_uid_full]['pred'] = preds
            evalutes[max_uid_full]['pos'] = poses
            evalutes[max_uid_full]['f1'] = 2 * tps / (preds + poses + 1e-20)
            json.dump(
                evalutes,
                open(FLAGS.ckpt_dir + '/evalute.json',
                     mode='w',
                     encoding='utf-8'))
            evalute = evalutes[max_uid_full]
            glog.info('Evalute %s: tp=%d, pred=%d, pos=%d, f1=%f.' %
                      (max_uid_full, evalute['tp'], evalute['pred'],
                       evalute['pos'], evalute['f1']))
            # Sleep between polling rounds from the second iteration on.
            if status == 1:
                time.sleep(60)
            status = 1