Example #1
0
def main():
    """Train the CNN-BLSTM PPG classifier with periodic dev evaluation.

    Builds padded, endlessly repeating train/dev pipelines from
    ``train_generator`` / ``test_generator`` (selected at run time through a
    string-handle placeholder), the classifier, accuracy/loss/image summaries
    and an Adadelta optimizer.

    Training resumes from the step after the last checkpoint found in
    ``args.output_model_path`` and runs up to ``args.steps``:

    * every ``args.ckpt_every`` steps one dev batch is evaluated (no parameter
      update) and a checkpoint is written;
    * otherwise one training step is run.

    On Ctrl-C, or any other exit from the loop, the latest step is
    checkpointed if it has not been saved yet.  The session is always closed.
    """
    args = get_arguments()

    logdir = args.log_dir
    model_dir = args.output_model_path
    # Checkpoints are restored from, and later written to, the same directory.
    restore_dir = args.output_model_path

    train_dir = os.path.join(logdir, STARTED_DATESTRING, 'train')
    dev_dir = os.path.join(logdir, STARTED_DATESTRING, 'dev')

    # Datasets.  Each element is (MFCC frames [T, MFCC_DIM],
    # PPG targets [T, PPG_DIM], scalar int32).  The scalar is used below as
    # the per-utterance valid length.  Batches are zero-padded to the longest
    # sequence and repeat forever.
    element_types = (tf.float32, tf.float32, tf.int32)
    element_shapes = ([None, MFCC_DIM], [None, PPG_DIM], [])

    train_set = tf.data.Dataset.from_generator(
        train_generator,
        output_types=element_types,
        output_shapes=element_shapes)
    train_set = train_set.padded_batch(
        args.batch_size, padded_shapes=element_shapes).repeat()
    train_iterator = train_set.make_initializable_iterator()

    test_set = tf.data.Dataset.from_generator(
        test_generator,
        output_types=element_types,
        output_shapes=element_shapes)
    test_set = test_set.padded_batch(
        args.batch_size, padded_shapes=element_shapes).repeat()
    test_iterator = test_set.make_initializable_iterator()

    # The string handle fed at run time selects which iterator feeds the
    # graph, so one model graph serves both training and dev evaluation.
    dataset_handle = tf.placeholder(tf.string, shape=[])
    dataset_iter = tf.data.Iterator.from_string_handle(dataset_handle,
                                                       train_set.output_types,
                                                       train_set.output_shapes)
    mfccs, ppg_targets, lengths = dataset_iter.get_next()

    classifier = CNNBLSTMCalssifier(out_dims=PPG_DIM,
                                    n_cnn=3,
                                    cnn_hidden=256,
                                    cnn_kernel=3,
                                    n_blstm=2,
                                    lstm_hidden=128)
    results_dict = classifier(mfccs, ppg_targets, lengths)
    predicted = tf.nn.softmax(results_dict['logits'])

    # Frame accuracy over valid frames only.  Padded frames are masked out of
    # the numerator, and the denominator is the number of valid frames.
    mask = tf.sequence_mask(lengths, dtype=tf.float32)
    correct = tf.cast(
        tf.equal(tf.argmax(predicted, axis=-1),
                 tf.argmax(ppg_targets, axis=-1)), tf.float32)
    accuracy = (tf.reduce_sum(correct * mask) /
                tf.reduce_sum(tf.cast(lengths, dtype=tf.float32)))
    tf.summary.scalar('accuracy', accuracy)
    # Show the first utterance of the batch as a (class x time) image.
    tf.summary.image('predicted',
                     tf.expand_dims(tf.transpose(predicted, [0, 2, 1]),
                                    axis=-1),
                     max_outputs=1)
    tf.summary.image('groundtruth',
                     tf.expand_dims(tf.cast(
                         tf.transpose(ppg_targets, [0, 2, 1]), tf.float32),
                                    axis=-1),
                     max_outputs=1)

    loss = results_dict['cross_entropy']
    learning_rate_pl = tf.placeholder(tf.float32, None, 'learning_rate')
    tf.summary.scalar('cross_entropy', loss)
    tf.summary.scalar('learning_rate', learning_rate_pl)
    optimizer = tf.train.AdadeltaOptimizer(learning_rate=learning_rate_pl)
    optim = optimizer.minimize(loss)
    # Also run any ops registered in UPDATE_OPS together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    optim = tf.group([optim, update_ops])

    # Set up logging for TensorBoard.
    train_writer = tf.summary.FileWriter(train_dir)
    train_writer.add_graph(tf.get_default_graph())
    dev_writer = tf.summary.FileWriter(dev_dir)
    summaries = tf.summary.merge_all()
    saver = tf.train.Saver(max_to_keep=args.max_ckpts)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    init = tf.global_variables_initializer()
    # The context manager closes the session on every exit path, including a
    # failed checkpoint restore or an error raised during training.
    with tf.Session(config=config) as sess:
        sess.run([train_iterator.initializer, test_iterator.initializer])
        train_handle, test_handle = sess.run(
            [train_iterator.string_handle(),
             test_iterator.string_handle()])
        sess.run(init)

        # Try to load a saved model.
        try:
            saved_global_step = load_model(saver, sess, restore_dir)
            if saved_global_step is None:
                # No checkpoint found: start training at step 0.
                saved_global_step = -1
        except Exception:
            print("Something went wrong while restoring checkpoint. "
                  "We will terminate training to avoid accidentally overwriting "
                  "the previous model.")
            raise

        last_saved_step = saved_global_step
        step = None
        try:
            for step in range(saved_global_step + 1, args.steps):
                # Step-wise decay: halve the rate after 400k, 600k and 800k steps.
                if step <= int(4e5):
                    lr = args.lr
                elif step <= int(6e5):
                    lr = 0.5 * args.lr
                elif step <= int(8e5):
                    lr = 0.25 * args.lr
                else:
                    lr = 0.125 * args.lr
                start_time = time.time()
                if step % args.ckpt_every == 0:
                    # Evaluate one dev batch (no optimizer step), then checkpoint.
                    summary, loss_value = sess.run(
                        [summaries, loss],
                        feed_dict={
                            dataset_handle: test_handle,
                            learning_rate_pl: lr
                        })
                    dev_writer.add_summary(summary, step)
                    duration = time.time() - start_time
                    print(
                        'step {:d} - eval loss = {:.3f}, ({:.3f} sec/step)'.format(
                            step, loss_value, duration))
                    save_model(saver, sess, model_dir, step)
                    last_saved_step = step
                else:
                    summary, loss_value, _ = sess.run(
                        [summaries, loss, optim],
                        feed_dict={
                            dataset_handle: train_handle,
                            learning_rate_pl: lr
                        })
                    train_writer.add_summary(summary, step)
                    if step % 10 == 0:
                        duration = time.time() - start_time
                        print(
                            'step {:d} - training loss = {:.3f}, ({:.3f} sec/step)'
                            .format(step, loss_value, duration))
        except KeyboardInterrupt:
            # Introduce a line break after ^C is displayed so save message
            # is on its own line.
            print()
        finally:
            # ``step`` stays None when the range is empty (the restored step is
            # already at or past args.steps).  Comparing None with an int would
            # raise TypeError and hide the real outcome.
            if step is not None and step > last_saved_step:
                save_model(saver, sess, model_dir, step)
            # Flush pending TensorBoard events.
            train_writer.close()
            dev_writer.close()
def main():
    """Extract PPGs for every wav path listed in ``meta_path``.

    Each non-empty line of ``meta_path`` is treated as a wav file path.

    For every file:

    * The unnormalized MFCCs are fed through the CNN-BLSTM classifier restored
      from ``ckpt_path``.  The resulting PPG matrix is saved to
      ``<ppg_dir>/<name>.npy``.
    * As a sanity check, the normalized dB mel and spectrogram features are
      converted back to audio and written to
      ``<rec_wav_dir>/<name>_mel_rec.wav`` and ``<name>_spec_rec.wav``.

    A file that fails any step is recorded and reported at the end instead of
    aborting the whole run.
    """
    print('start...')
    # Read the list eagerly so the file is closed right away.  Blank lines
    # (e.g. a trailing newline) are skipped rather than reported as bad files.
    with open(meta_path, 'r') as meta_file:
        wav_paths = [line.strip() for line in meta_file if line.strip()]

    # NN -> PPG: build the inference graph (batch of variable-length MFCCs).
    mfcc_pl = tf.placeholder(dtype=tf.float32,
                             shape=[None, None, MFCC_DIM],
                             name='mfcc_pl')
    classifier = CNNBLSTMCalssifier(out_dims=PPG_DIM,
                                    n_cnn=3,
                                    cnn_hidden=256,
                                    cnn_kernel=3,
                                    n_blstm=2,
                                    lstm_hidden=128)
    predicted_ppgs = tf.nn.softmax(classifier(inputs=mfcc_pl)['logits'])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    cnt = 0
    bad_list = []
    # The context manager closes the session even if restoring fails.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print('Restoring model from {}'.format(ckpt_path))
        saver.restore(sess, ckpt_path)

        for wav_f in tqdm(wav_paths):
            try:
                print('process:', wav_f)
                # Extract acoustic features.
                wav_arr = load_wav(wav_f)
                mfcc_feats = wav2unnormalized_mfcc(wav_arr)
                # Add a batch axis of size 1 for the network, then drop it.
                ppgs = np.squeeze(sess.run(
                    predicted_ppgs,
                    feed_dict={mfcc_pl: np.expand_dims(mfcc_feats, axis=0)}))
                mel_feats = wav2normalized_db_mel(wav_arr)
                spec_feats = wav2normalized_db_spec(wav_arr)

                # Basename up to its first '.', e.g. '/x/0_sample_spec.wav' -> '0_sample_spec'.
                fname = wav_f.split('/')[-1].split('.')[0]
                save_mel_rec_name = fname + '_mel_rec.wav'
                save_spec_rec_name = fname + '_spec_rec.wav'
                # All feature streams must be frame-aligned.  Explicit raises
                # (unlike assert) are not stripped under ``python -O``.
                if ppgs.shape[0] != mfcc_feats.shape[0]:
                    raise ValueError('PPG/MFCC frame mismatch: {} vs {}'.format(
                        ppgs.shape[0], mfcc_feats.shape[0]))
                if not (mfcc_feats.shape[0] == mel_feats.shape[0] ==
                        spec_feats.shape[0]):
                    raise ValueError(
                        'MFCC/mel/spec frame mismatch: {} / {} / {}'.format(
                            mfcc_feats.shape[0], mel_feats.shape[0],
                            spec_feats.shape[0]))
                write_wav(os.path.join(rec_wav_dir, save_mel_rec_name),
                          normalized_db_mel2wav(mel_feats))
                write_wav(os.path.join(rec_wav_dir, save_spec_rec_name),
                          normalized_db_spec2wav(spec_feats))
                check_ppg(ppgs)

                # Save the PPG features.
                np.save(os.path.join(ppg_dir, fname + '.npy'), ppgs)
                cnt += 1
            except Exception as e:
                # Best-effort batch job: record the failure and continue.
                bad_list.append(wav_f)
                print(str(e))

    print('good:', cnt)
    print('bad:', len(bad_list))
    print(bad_list)
Example #3
0
def main():
    """Extract and save MFCC, PPG, mel and spectrogram features for LJSpeech-style metadata.

    ``meta_path`` is read in LJSpeech format: the utterance id is the first
    '|'-separated field of each line.  The id list is passed through
    ``PPG_get_restore``, and its result is what gets processed.

    For each id:

    * ``<wav_dir>/<id>.wav`` is loaded.
    * The MFCCs go through the restored CNN-BLSTM classifier to obtain PPGs.
    * The mel and spectrogram features are converted back to audio under
      ``rec_wav_dir`` as a sanity check.
    * The four feature matrices are saved as ``<id>.npy`` in ``mfcc_dir``,
      ``ppg_dir``, ``mel_dir`` and ``spec_dir``.

    Successful ids are appended to ``f_good_meta``.  Failures are collected
    and reported at the end.
    """
    # This part handles datasets in LJSpeech format.  The file is closed right
    # after reading, and blank lines are skipped rather than reported as bad ids.
    with open(meta_path, 'r') as meta_file:
        fnames = [line.strip().split('|')[0]
                  for line in meta_file if line.strip()]

    # NOTE(review): ppg_dir is passed twice here, while the features saved below
    # go to mfcc_dir, ppg_dir, mel_dir and spec_dir.  Confirm whether the
    # second argument was meant to be mfcc_dir.
    fnames = PPG_get_restore(fnames, ppg_dir, ppg_dir, mel_dir, spec_dir)

    # NN -> PPG: build the inference graph (batch of variable-length MFCCs).
    mfcc_pl = tf.placeholder(dtype=tf.float32, shape=[None, None, MFCC_DIM], name='mfcc_pl')
    classifier = CNNBLSTMCalssifier(out_dims=PPG_DIM, n_cnn=3, cnn_hidden=256, cnn_kernel=3, n_blstm=2, lstm_hidden=128)
    predicted_ppgs = tf.nn.softmax(classifier(inputs=mfcc_pl)['logits'])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    cnt = 0
    bad_list = []
    # The context manager closes the session even if restoring fails.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print('Restoring model from {}'.format(ckpt_path))
        saver.restore(sess, ckpt_path)

        for fname in tqdm(fnames):
            try:
                # Extract acoustic features.
                wav_f = os.path.join(wav_dir, fname + '.wav')
                wav_arr = load_wav(wav_f)
                mfcc_feats = wav2unnormalized_mfcc(wav_arr)
                # Add a batch axis of size 1 for the network, then drop it.
                ppgs = np.squeeze(sess.run(
                    predicted_ppgs,
                    feed_dict={mfcc_pl: np.expand_dims(mfcc_feats, axis=0)}))
                mel_feats = wav2normalized_db_mel(wav_arr)
                spec_feats = wav2normalized_db_spec(wav_arr)

                # Verify the extracted features: all streams must be frame-aligned.
                # Explicit raises (unlike assert) survive ``python -O``.
                save_name = fname + '.npy'
                save_mel_rec_name = fname + '_mel_rec.wav'
                save_spec_rec_name = fname + '_spec_rec.wav'
                if ppgs.shape[0] != mfcc_feats.shape[0]:
                    raise ValueError('PPG/MFCC frame mismatch: {} vs {}'.format(
                        ppgs.shape[0], mfcc_feats.shape[0]))
                if not (mfcc_feats.shape[0] == mel_feats.shape[0] ==
                        spec_feats.shape[0]):
                    raise ValueError(
                        'MFCC/mel/spec frame mismatch: {} / {} / {}'.format(
                            mfcc_feats.shape[0], mel_feats.shape[0],
                            spec_feats.shape[0]))
                write_wav(os.path.join(rec_wav_dir, save_mel_rec_name), normalized_db_mel2wav(mel_feats))
                write_wav(os.path.join(rec_wav_dir, save_spec_rec_name), normalized_db_spec2wav(spec_feats))
                check_ppg(ppgs)

                # Save the acoustic features.
                np.save(os.path.join(mfcc_dir, save_name), mfcc_feats)
                np.save(os.path.join(ppg_dir, save_name), ppgs)
                np.save(os.path.join(mel_dir, save_name), mel_feats)
                np.save(os.path.join(spec_dir, save_name), spec_feats)

                f_good_meta.write(fname + '\n')
                cnt += 1
            except Exception as e:
                # Best-effort batch job: record the failure and continue.
                bad_list.append(fname)
                print(str(e))

    print('good:', cnt)
    print('bad:', len(bad_list))
    print(bad_list)
Example #4
0
def main():
    """Generate MFCC, linear and PPG features for every wav in LJSpeech-1.1/wavs_16000.

    For each ``*.wav`` / ``*.WAV`` file:

    * MFCCs (``wav2mfcc``) and linear features (``wav2linear_for_ppg_cbhg``)
      are computed.
    * The MFCCs are run through the CNN-BLSTM classifier restored from
      ``ckpt_path`` to obtain PPGs.
    * The three streams are checked to have equal frame counts.
    * Each stream is saved as ``<name>.npy`` under its
      ``*_from_generate_batch`` directory, which is created if missing.

    Raises:
        ValueError: if the wav directory does not exist.
    """
    data_dir = 'LJSpeech-1.1'
    wav_dir = os.path.join(data_dir, 'wavs_16000')
    ppg_dir = os.path.join(data_dir, 'ppg_from_generate_batch')
    mfcc_dir = os.path.join(data_dir, 'mfcc_from_generate_batch')
    linear_dir = os.path.join(data_dir, 'linear_from_generate_batch')

    # model_checkpoint_path: "tacotron_model.ckpt-103000"
    ckpt_path = 'LibriSpeech_ckpt_model_zhaoxt_dir/vqvae.ckpt-233000'

    if not os.path.isdir(wav_dir):
        raise ValueError('wav directory not exists!')
    for label, out_dir in (('MFCC', mfcc_dir), ('Linear', linear_dir),
                           ('PPG', ppg_dir)):
        if not os.path.isdir(out_dir):
            print('{} save directory not exists! Create it as {}'.format(
                label, out_dir))
            os.makedirs(out_dir)

    # List the wav directory once, so the four path lists are guaranteed to
    # refer to the same files in the same order.
    wav_names = [f for f in os.listdir(wav_dir) if f.endswith(('.wav', '.WAV'))]
    # Output name: basename up to its first '.'.
    stems = [f.split('.')[0] for f in wav_names]
    wav_list = [os.path.join(wav_dir, f) for f in wav_names]
    mfcc_list = [os.path.join(mfcc_dir, s + '.npy') for s in stems]
    linear_list = [os.path.join(linear_dir, s + '.npy') for s in stems]
    ppg_list = [os.path.join(ppg_dir, s + '.npy') for s in stems]

    # Set up the inference graph (batch of variable-length MFCC sequences).
    mfcc_pl = tf.placeholder(dtype=tf.float32,
                             shape=[None, None, MFCC_DIM],
                             name='mfcc_pl')
    classifier = CNNBLSTMCalssifier(out_dims=PPG_DIM,
                                    n_cnn=3,
                                    cnn_hidden=256,
                                    cnn_kernel=3,
                                    n_blstm=2,
                                    lstm_hidden=128)
    predicted_ppgs = tf.nn.softmax(classifier(inputs=mfcc_pl)['logits'])

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The context manager closes the session even if a file fails mid-way.
    with tf.Session(config=config) as sess:
        start_time = time.time()
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print('Restoring model from {}'.format(ckpt_path))
        saver.restore(sess, ckpt_path)
        for wav_f, mfcc_f, linear_f, ppg_f in tqdm(
                zip(wav_list, mfcc_list, linear_list, ppg_list)):
            wav_arr = load_wav(wav_f)
            mfcc = wav2mfcc(wav_arr)
            linear = wav2linear_for_ppg_cbhg(wav_arr)
            # Add a batch axis of size 1 for the network, then drop it.
            ppgs = np.squeeze(
                sess.run(predicted_ppgs,
                         feed_dict={mfcc_pl: np.expand_dims(mfcc, axis=0)}))

            # All three feature streams must be frame-aligned.
            assert mfcc.shape[0] == ppgs.shape[0] and linear.shape[0] == mfcc.shape[0]
            np.save(mfcc_f, mfcc)
            np.save(linear_f, linear)
            np.save(ppg_f, ppgs)
        duration = time.time() - start_time
        print("PPGs file generated in {:.3f} seconds".format(duration))
Example #5
0
def main():
    """Train the CNN-BLSTM PPG classifier on AiShell1 with periodic dev evaluation.

    Configuration comes from module-level names (``logdir``, ``BATCH_SIZE``,
    ``STEPS``, ``LEARNING_RATE``, ``CKPT_EVERY``, ``MAX_TO_SAVE``,
    ``restore_dir``, ``model_dir``, ...).

    Training resumes after the last checkpoint in ``restore_dir``:

    * every ``CKPT_EVERY`` steps one dev batch is evaluated (no parameter
      update) and a checkpoint is written;
    * otherwise one training step is run.

    On Ctrl-C, or any other exit from the loop, the latest step is
    checkpointed if it has not been saved yet.  The session is always closed.
    """
    train_dir = os.path.join(logdir, STARTED_DATESTRING, 'train')
    dev_dir = os.path.join(logdir, STARTED_DATESTRING, 'dev')

    # Datasets.  Each element is (MFCC frames [T, MFCC_DIM],
    # PPG targets [T, AiShell1_PPG_DIM], scalar int32).  The scalar is used
    # below as the per-utterance valid length.
    element_types = (tf.float32, tf.float32, tf.int32)
    element_shapes = ([None, MFCC_DIM], [None, AiShell1_PPG_DIM], [])

    train_set = tf.data.Dataset.from_generator(
        train_generator,
        output_types=element_types,
        output_shapes=element_shapes)
    # Pad each batch to its longest sequence.  repeat() with no count makes
    # get_next wrap around endlessly instead of running out.
    train_set = train_set.padded_batch(
        BATCH_SIZE, padded_shapes=element_shapes).repeat()
    train_iterator = train_set.make_initializable_iterator()

    test_set = tf.data.Dataset.from_generator(
        test_generator,
        output_types=element_types,
        output_shapes=element_shapes)
    test_set = test_set.padded_batch(
        BATCH_SIZE, padded_shapes=element_shapes).repeat()
    test_iterator = test_set.make_initializable_iterator()

    # The string handle fed through feed_dict at run time selects which
    # iterator's next() feeds the graph.
    dataset_handle = tf.placeholder(tf.string, shape=[])
    dataset_iter = tf.data.Iterator.from_string_handle(dataset_handle,
                                                       train_set.output_types,
                                                       train_set.output_shapes)
    mfccs, ppg_targets, lengths = dataset_iter.get_next()

    # Build the model: inputs, labels, lengths.
    classifier = CNNBLSTMCalssifier(out_dims=AiShell1_PPG_DIM,
                                    n_cnn=3,
                                    cnn_hidden=256,
                                    cnn_kernel=3,
                                    n_blstm=2,
                                    lstm_hidden=128)
    results_dict = classifier(mfccs, ppg_targets, lengths)
    predicted = tf.nn.softmax(results_dict['logits'])

    # ``lengths`` has shape (batch,) and gives the number of MFCC frames per
    # utterance.  Utterances are padded to the longest one, so the mask
    # (shape (batch, max(lengths))) removes padded frames, whose all-zero
    # predictions and targets would otherwise compare as matches.
    mask = tf.sequence_mask(lengths, dtype=tf.float32)
    # Per-frame hit: argmax of prediction equals argmax of target (bool -> float).
    correct = tf.cast(
        tf.equal(tf.argmax(predicted, axis=-1),
                 tf.argmax(ppg_targets, axis=-1)), tf.float32)
    accuracy = (tf.reduce_sum(correct * mask) /
                tf.reduce_sum(tf.cast(lengths, dtype=tf.float32)))

    tf.summary.scalar('accuracy', accuracy)
    tf.summary.image('predicted',
                     tf.expand_dims(tf.transpose(predicted, [0, 2, 1]),
                                    axis=-1),
                     max_outputs=1)
    tf.summary.image('groundtruth',
                     tf.expand_dims(tf.cast(
                         tf.transpose(ppg_targets, [0, 2, 1]), tf.float32),
                                    axis=-1),
                     max_outputs=1)

    loss = results_dict['cross_entropy']
    learning_rate_pl = tf.placeholder(tf.float32, None, 'learning_rate')
    tf.summary.scalar('cross_entropy', loss)
    tf.summary.scalar('learning_rate', learning_rate_pl)
    optimizer = tf.train.AdadeltaOptimizer(learning_rate=learning_rate_pl)
    optim = optimizer.minimize(loss)
    # Also run any ops registered in UPDATE_OPS together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    optim = tf.group([optim, update_ops])

    # Set up logging for TensorBoard.
    train_writer = tf.summary.FileWriter(train_dir)
    train_writer.add_graph(tf.get_default_graph())
    dev_writer = tf.summary.FileWriter(dev_dir)
    # Running this single op evaluates every summary defined above.
    summaries = tf.summary.merge_all()
    saver = tf.train.Saver(max_to_keep=MAX_TO_SAVE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    init = tf.global_variables_initializer()
    # The context manager closes the session on every exit path.
    with tf.Session(config=config) as sess:
        sess.run([train_iterator.initializer, test_iterator.initializer])
        train_handle, test_handle = sess.run(
            [train_iterator.string_handle(),
             test_iterator.string_handle()])
        sess.run(init)

        # Try to load a saved model.
        try:
            saved_global_step = load_model(saver, sess, restore_dir)
            if saved_global_step is None:
                # No checkpoint found: training starts at step 1.
                saved_global_step = 0
        except Exception:
            print("Something went wrong while restoring checkpoint. "
                  "We will terminate training to avoid accidentally overwriting "
                  "the previous model.")
            raise

        last_saved_step = saved_global_step
        step = None
        try:
            for step in range(saved_global_step + 1, STEPS):
                # Step-wise decay: halve the rate after 400k, 600k and 800k steps.
                if step <= int(4e5):
                    lr = LEARNING_RATE
                elif step <= int(6e5):
                    lr = 0.5 * LEARNING_RATE
                elif step <= int(8e5):
                    lr = 0.25 * LEARNING_RATE
                else:
                    lr = 0.125 * LEARNING_RATE
                start_time = time.time()
                if step % CKPT_EVERY == 0:
                    # Evaluate one dev batch (no optimizer step), then checkpoint.
                    summary, loss_value = sess.run(
                        [summaries, loss],
                        feed_dict={
                            dataset_handle: test_handle,
                            learning_rate_pl: lr
                        })
                    dev_writer.add_summary(summary, step)
                    duration = time.time() - start_time
                    print(
                        'step {:d} - eval loss = {:.3f}, ({:.3f} sec/step)'.format(
                            step, loss_value, duration))
                    save_model(saver, sess, model_dir, step)
                    last_saved_step = step
                else:
                    summary, loss_value, _ = sess.run(
                        [summaries, loss, optim],
                        feed_dict={
                            dataset_handle: train_handle,
                            learning_rate_pl: lr
                        })
                    train_writer.add_summary(summary, step)
                    if step % 10 == 0:
                        duration = time.time() - start_time
                        print(
                            'step {:d} - training loss = {:.3f}, ({:.3f} sec/step)'
                            .format(step, loss_value, duration))
        except KeyboardInterrupt:
            # Introduce a line break after ^C is displayed so save message
            # is on its own line.
            print()
        finally:
            # ``step`` stays None when the range is empty (the restored step is
            # already at or past STEPS).  Comparing None with an int would
            # raise TypeError and hide the real outcome.
            if step is not None and step > last_saved_step:
                save_model(saver, sess, model_dir, step)
            # Flush pending TensorBoard events.
            train_writer.close()
            dev_writer.close()
Example #6
0
def main():
    """Report frame accuracy of a restored PPG classifier on the test set.

    Restores ``Model_Path + 'vqvae.ckpt-233000'``.  It then evaluates up to 800
    single-utterance batches from ``test_generator``, stopping early if the
    test set runs out.  Finally it prints the maximum, minimum and mean
    per-utterance frame accuracy.
    """
    # Maximum number of utterances (batches of size 1) to evaluate.
    max_batches = 800

    element_shapes = ([None, MFCC_DIM], [None, PPG_DIM], [])
    test_set = tf.data.Dataset.from_generator(
        test_generator,
        output_types=(tf.float32, tf.float32, tf.int32),
        output_shapes=element_shapes)

    # Batch size 1 and no .repeat(): each test utterance is visited at most
    # once, and the iterator raises OutOfRangeError when the set is exhausted.
    test_set = test_set.padded_batch(1, padded_shapes=element_shapes)
    test_iterator = test_set.make_initializable_iterator()
    mfccs, ppg_targets, lengths = test_iterator.get_next()

    classifier = CNNBLSTMCalssifier(out_dims=PPG_DIM,
                                    n_cnn=3,
                                    cnn_hidden=256,
                                    cnn_kernel=3,
                                    n_blstm=2,
                                    lstm_hidden=128)
    results_dict = classifier(mfccs, ppg_targets, lengths)
    predicted = tf.nn.softmax(results_dict['logits'])

    # Frame accuracy over valid (unpadded) frames: per-frame argmax match
    # (bool -> float), masked by sequence length, divided by the frame count.
    mask = tf.sequence_mask(lengths, dtype=tf.float32)
    correct = tf.cast(
        tf.equal(tf.argmax(predicted, axis=-1),
                 tf.argmax(ppg_targets, axis=-1)), tf.float32)
    accuracy = (tf.reduce_sum(correct * mask) /
                tf.reduce_sum(tf.cast(lengths, dtype=tf.float32)))

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    # NOTE(review): ``config`` is not defined in this function.  It must come
    # from module scope.
    with tf.Session(config=config) as sess:
        sess.run(test_iterator.initializer)
        sess.run(init)
        saver.restore(sess, Model_Path + 'vqvae.ckpt-233000')
        print("start")

        # Run ``accuracy`` exactly once per batch.  The previous version
        # called sess.run(accuracy) twice per iteration, so the running sum
        # was taken over a different batch than the one used for max/min.
        accuracies = []
        for _ in range(max_batches):
            try:
                accuracies.append(sess.run(accuracy))
            except tf.errors.OutOfRangeError:
                # Test set exhausted (it is not repeated).
                break

    if not accuracies:
        print("no test batches evaluated")
        return
    print("max: " + str(max(accuracies)))
    print("min: " + str(min(accuracies)))
    print("average:" + str(sum(accuracies) / len(accuracies)))