Example #1
def train_model(data_path, config):
    # reader, bgen and nnet are project-local modules imported elsewhere.
    x, y = reader.read_dataset(data_path)
    # config arrives as a bracketed string such as "[128 64]"; strip the
    # brackets and parse the space-separated hidden-layer sizes.
    config = [int(i) for i in config[1:-1].split(' ')]
    bg = bgen.BatchGenerator(x, y, 200)  # mini-batches of 200 samples
    nn = nnet.NNet(x.shape[1], y.shape[1], config)
    nn.minibatch_train(bg, 5)
    return nn
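
bgen.BatchGenerator is project-local and not shown on this page. A minimal sketch of a generator compatible with the constructor above, assuming minibatch_train simply iterates over (x, y) batches (the project's implementation may differ):

import numpy as np

class BatchGenerator:
    """Sketch, not the project's bgen: yields shuffled (x, y) mini-batches."""

    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __iter__(self):
        order = np.random.permutation(len(self.x))  # reshuffle every pass
        for start in range(0, len(self.x), self.batch_size):
            batch = order[start:start + self.batch_size]
            yield self.x[batch], self.y[batch]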
Example #2
def run(filename, seconds=10, mutation_rate=0.2, selector=selection.elite):
    dataset, volume = reader.read_dataset(filename)
    pop = population.gen_population(len(dataset), 4)

    alpha = None
    i = 0
    target_time = time.time() + seconds
    while time.time() < target_time:
        parents, fitness = zip(
            *selector(pop, population.fitness(pop, dataset, volume), items=2))

        max_fit = (max(fitness), parents[fitness.index(max(fitness))])

        print('gen:', i, 'fitness:', max_fit[0], max_fit[1])
        i += 1

        if alpha is None or max_fit[0] > alpha[0]:
            alpha = max_fit

        population.mutate(parents, mutation_rate)
        pop = population.cross_pop(parents)

    # Write the best bag to the output file.
    with open("output.txt", "w") as output_file:
        for item in population.bag(alpha[1], dataset, volume):
            print(item, file=output_file)

    print('golden individual:', alpha)
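
The selector contract implied by the zip(*...) unpacking above is selector(pop, fitnesses, items) returning (individual, fitness) pairs. A minimal sketch of a compatible elitist selector (hypothetical; the project's selection.elite may differ):

def elite(pop, fitnesses, items=2):
    # Rank individuals by fitness, best first, and keep the top `items` pairs.
    ranked = sorted(zip(pop, fitnesses), key=lambda pair: pair[1], reverse=True)
    return ranked[:items]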
Example #3
    parser.add_argument('--batchsize',
                        '-b',
                        default=13,
                        type=int,
                        help='Number of records in a batch')
    parser.add_argument('--trainsize',
                        '-t',
                        default=390,
                        type=int,
                        help='Number of training records in whole dataset')
    args = parser.parse_args()
    # Logging for progress display
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                        level=logging.INFO)
    # Set the number of training records to a constant multiple of the batch size
    train_size = args.trainsize
    # Read the dataset
    (train_data, test_data,
     info) = reader.read_dataset(train_size,
                                 normalize=args.normalize,
                                 scale=args.scale)
    num_inputs = info["SHAPE_TRAIN_X"][1]  # number of input-layer units
    # Batch size for one step of stochastic gradient descent
    batch_size = args.batchsize
    # Number of training epochs
    num_epoch = args.epoch
    # Select the optimizer
    optimizer = select_optimizer(args.optimizer)
    # Run the model
    execute(args.network, train_data, test_data, batch_size, num_inputs,
            optimizer, num_epoch, args.gpu, args.chart)
Example #4
def main():
    # tensorflow input and output
    keep_probability = tf.placeholder(tf.float32, name="keep_probability")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 1],
                                name="annotation")

    pred_annotation, logits = inference(image, keep_probability)

    # Summary
    print('====================================================')
    tf.summary.image("input_image", image, max_outputs=4)
    tf.summary.image("ground_truth",
                     tf.cast(annotation * 255, tf.uint8),
                     max_outputs=4)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation * 255, tf.uint8),
                     max_outputs=4)
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy")))
    tf.summary.scalar("train_entropy", loss)

    trainable_var = tf.trainable_variables()
    if args.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("> [FCN] Setting up summary op...")
    summary_op = tf.summary.merge_all()

    # Validation summary
    val_summary = tf.summary.scalar("validation_entropy", loss)

    # Read data
    print("> [FCN] Setting up image reader...")
    train_records, valid_records = read_dataset(args.data_dir)
    print('> [FCN] Train len:', len(train_records))
    print('> [FCN] Val len:', len(valid_records))

    t = timer.Timer()  # Qhan's timer

    if args.mode != 'test':
        print("> [FCN] Setting up dataset reader")
        image_options = {
            'resize': True,
            'resize_height': IMAGE_HEIGHT,
            'resize_width': IMAGE_WIDTH
        }
        if args.mode == 'train':
            t.tic()
            train_dataset_reader = dataset.BatchDatset(train_records,
                                                       image_options,
                                                       mode='train')
            load_time = t.toc()
            print('> [FCN] Train data set loaded. %.4f ms' % (load_time))
        t.tic()
        validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                        image_options,
                                                        mode='val')
        load_time = t.toc()
        print('> [FCN] Validation data set loaded. %.4f ms' % (load_time))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90,
                                allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    # Initialize model
    print("> [FCN] Setting up Saver...", flush=True)
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(args.logs_dir, sess.graph)

    print("> [FCN] Initialize variables... ", flush=True, end='')
    t.tic()
    sess.run(tf.global_variables_initializer())
    print('%.4f ms' % (t.toc()))

    t.tic()
    ckpt = tf.train.get_checkpoint_state(args.logs_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("> [FCN] Model restored..." + ckpt.model_checkpoint_path +
              ', %.4f ms' % (t.toc()))

    print('==================================================== [%s]' %
          args.mode)

    if args.mode == 'train':
        np.random.seed(1028)
        start = args.start_iter
        end = start + args.iter + 1
        for itr in range(start, end):

            # Read batch data
            train_images, train_annotations = train_dataset_reader.next_batch(
                args.batch_size)
            images = np.zeros_like(train_images)
            annotations = np.zeros_like(train_annotations)

            # Data augmentation
            for i, (im, ann) in enumerate(zip(train_images,
                                              train_annotations)):
                flip_prob = np.random.random()
                aug_type = np.random.randint(0, 3)
                randoms = np.random.random(2)
                images[i] = augment(im, flip_prob, aug_type, randoms)
                annotations[i] = augment(ann, flip_prob, aug_type, randoms)

            t.tic()
            feed_dict = {
                image: images,
                annotation: annotations,
                keep_probability: 0.85
            }
            sess.run(train_op, feed_dict=feed_dict)
            train_time = t.toc()

            if itr % 10 == 0 and itr > 10:
                train_loss, summary_str = sess.run([loss, summary_op],
                                                   feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, itr)
                print("[%6d], Train_loss: %g, %.4f ms" %
                      (itr, train_loss, train_time),
                      flush=True)

            if itr % 100 == 0 and itr != 0:
                valid_images, valid_annotations = validation_dataset_reader.next_batch(
                    args.batch_size * 2)
                val_feed_dict = {
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                }
                t.tic()
                val_loss, val_str = sess.run([loss, val_summary],
                                             feed_dict=val_feed_dict)
                val_time = t.toc()
                summary_writer.add_summary(val_str, itr)
                print("[%6d], Validation_loss: %g, %.4f ms" %
                      (itr, val_loss, val_time))

            if itr % 1000 == 0 and itr != 0:
                saver.save(sess, args.logs_dir + "model.ckpt", itr)

    elif args.mode == 'visualize':
        for itr in range(20):
            valid_images, valid_annotations = validation_dataset_reader.get_random_batch(
                1)
            t.tic()
            pred = sess.run(pred_annotation,
                            feed_dict={
                                image: valid_images,
                                keep_probability: 1.0
                            })
            val_time = t.toc()

            valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)

            utils.save_image(valid_images[0].astype(np.uint8),
                             args.res_dir,
                             name="inp_" + str(itr))
            utils.save_image(valid_annotations[0].astype(np.uint8),
                             args.res_dir,
                             name="gt_" + str(itr))
            utils.save_image(pred[0].astype(np.uint8),
                             args.res_dir,
                             name="pred_" + str(itr))
            print("> [FCN] Saved image: %d, %.4f ms" % (itr, val_time))

    elif args.mode == 'test':
        testlist = args.testlist
        images, names, (H, W) = read_test_data(testlist, IMAGE_HEIGHT,
                                               IMAGE_WIDTH)
        for i, (im, name) in enumerate(zip(images, names)):

            t.tic()
            pred = sess.run(pred_annotation,
                            feed_dict={
                                image: im.reshape((1, ) + im.shape),
                                keep_probability: 1.0
                            })
            test_time = t.toc()

            pred = pred.reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
            if args.video:
                save_video_image(im, pred,
                                 args.res_dir + '/pred_%05d' % (i) + '.png', H,
                                 W)
            else:
                misc.imsave(args.res_dir + '/inp_%d' % (i) + '.png',
                            im.astype(np.uint8))
                misc.imsave(args.res_dir + '/pred_%d' % (i) + '.png',
                            pred.astype(np.uint8))
            print('> [FCN] Img: %d,' % (i) + ' Name: ' + name +
                  ', %.4f ms' % test_time)

    else:
        pass
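
The augment helper called in the training loop above is not shown. Here is a sketch consistent with those calls (flip_prob as a coin flip, aug_type in {0, 1, 2}, two random draws, identical arguments for image and annotation so the pair stays aligned); the project's augment may well differ:

import numpy as np

def augment(img, flip_prob, aug_type, randoms):
    # Horizontal flip, decided once per sample.
    if flip_prob > 0.5:
        img = img[:, ::-1]
    if aug_type == 0:
        # Horizontal translation in [-10, 10) pixels (wraps around).
        shift = int(randoms[0] * 20) - 10
        img = np.roll(img, shift, axis=1)
    elif aug_type == 1:
        # Central crop of up to 10%, then nearest-neighbour resize back;
        # nearest keeps annotation values valid class ids.
        h, w = img.shape[:2]
        m = int(min(h, w) * 0.1 * randoms[1])
        cropped = img[m:h - m, m:w - m]
        ys = np.linspace(0, cropped.shape[0] - 1, h).astype(int)
        xs = np.linspace(0, cropped.shape[1] - 1, w).astype(int)
        img = cropped[ys][:, xs]
    # aug_type == 2: leave the sample unchanged.
    return img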
Example #5
    def run(self, TrainingModel):

        graph = tf.Graph()
        with graph.as_default(), tf.Session() as sess:

            self.data = read_dataset(self.config)

            if config.mode == 'train':
                print('building training model....')
                with tf.variable_scope("model"):
                    self.train_model = TrainingModel(self.config,
                                                     self.data.batch_input_queue(),
                                                     is_train=True)
                    self.train_model.config.show()
                print('building valid model....')
                with tf.variable_scope("model", reuse=True):
                    self.valid_model = TrainingModel(self.config,
                                                     self.data.valid_queue(),
                                                     is_train=False)
            else:
                with tf.variable_scope("model", reuse=False):
                    self.valid_model = TrainingModel(self.config,
                                                     self.data.valid_queue(),
                                                     is_train=False)
            saver = tf.train.Saver()

            # restore from stored models
            files = glob(path_join(self.config.model_path, '*.ckpt.*'))

            if len(files) > 0:

                saver.restore(sess, path_join(self.config.model_path,
                                              self.config.model_name))
                print('Model restored from: ' + self.config.model_path)
            else:
                print("Model doesn't exist.\nInitializing........")
                sess.run(tf.global_variables_initializer())

            sess.run(tf.local_variables_initializer())
            graph.finalize()

            st_time = time.time()
            best_miss, best_false = 1, 1  # defaults when best.pkl is absent
            if os.path.exists(path_join(self.config.model_path, 'best.pkl')):
                with open(path_join(self.config.model_path, 'best.pkl'),
                          'rb') as f:
                    best_miss, best_false = pickle.load(f)
                    print('best miss', best_miss, 'best false', best_false)
            else:
                print('no best record found')

            check_dir(self.config.model_path)

            if self.config.mode == 'train':
                accu_loss = 0
                epoch_step = config.tfrecord_size * self.data.train_file_size // config.batch_size

                if self.config.reset_global:
                    sess.run(self.train_model.reset_global_step)

                def handler_stop_signals(signum, frame):
                    global run
                    run = False
                    if not DEBUG:
                        print(
                            'training shut down, total step %s, the model will be saved in %s' % (
                                step, self.config.model_path))
                        saver.save(sess, save_path=(
                            path_join(self.config.model_path, 'latest.ckpt')))
                        print('best miss rate: %f\tbest false rate: %f' % (
                            best_miss, best_false))
                    sys.exit(0)

                signal.signal(signal.SIGINT, handler_stop_signals)
                signal.signal(signal.SIGTERM, handler_stop_signals)

                best_list = []
                best_threshold = 0.08
                best_count = 0
                # (miss,false,step,best_count)

                last_time = time.time()

                try:
                    sess.run([self.data.noise_stage_op,
                              self.data.noise_filequeue_enqueue_op,
                              self.train_model.stage_op,
                              self.train_model.input_filequeue_enqueue_op,
                              self.valid_model.stage_op,
                              self.valid_model.input_filequeue_enqueue_op])

                    va = tf.trainable_variables()
                    for i in va:
                        print(i.name)
                    while self.epoch < self.config.max_epoch:

                        _, _, _, _, _, l, lr, step, grads = sess.run(
                            [self.train_model.train_op,
                             self.data.noise_stage_op,
                             self.data.noise_filequeue_enqueue_op,
                             self.train_model.stage_op,
                             self.train_model.input_filequeue_enqueue_op,
                             self.train_model.loss,
                             self.train_model.learning_rate,
                             self.train_model.global_step,
                             self.train_model.grads
                             ])
                        epoch = step // epoch_step
                        accu_loss += l
                        if epoch > self.epoch:
                            self.epoch = epoch
                            print('accumulated loss', accu_loss)
                            saver.save(sess, save_path=(
                                path_join(self.config.model_path,
                                          'latest.ckpt')))
                            print('latest.ckpt saved in %s' % (
                                path_join(self.config.model_path,
                                          'latest.ckpt')))
                            accu_loss = 0
                        if step % config.valid_step == 0:
                            print('epoch time ', (time.time() - last_time) / 60)
                            last_time = time.time()

                            miss_count = 0
                            false_count = 0
                            target_count = 0
                            wer = 0
                            valid_batch = self.data.valid_file_size * config.tfrecord_size // config.batch_size
                            text = ""
                            for i in range(valid_batch):
                                softmax, correctness, labels, _, _ = sess.run(
                                    [self.valid_model.softmax,
                                     self.valid_model.correctness,
                                     self.valid_model.labels,
                                     self.valid_model.stage_op,
                                     self.valid_model.input_filequeue_enqueue_op])
                                np.set_printoptions(precision=4,
                                                    threshold=np.inf,
                                                    suppress=True)

                                decode_output = [ctc_decode(s) for s in softmax]
                                for out in decode_output:
                                    text += str(out) + '\n'
                                    text += str(labels) + '\n'
                                    text += '=' * 20 + '\n'
                                result = [ctc_predict(seq, config.label_seqs)
                                          for seq in decode_output]
                                miss, target, false_accept = evaluate(
                                    result, correctness.tolist())

                                miss_count += miss
                                target_count += target
                                false_count += false_accept

                                wer += self.wer_cal.cal_batch_wer(labels,
                                                                  decode_output).sum()
                                # print(miss_count, false_count)
                            with open('./valid.txt', 'w') as f:
                                f.write(text)

                            miss_rate = miss_count / target_count
                            false_accept_rate = false_count / (
                                self.data.validation_size - target_count)
                            print('--------------------------------')
                            print('epoch %d' % self.epoch)
                            print('training loss:' + str(l))
                            print('learning rate:', lr, 'global step', step)
                            print('miss rate:' + str(miss_rate))
                            print('false_accept_rate: ' + str(false_accept_rate))
                            print(miss_count, '/', target_count)
                            print('wer', wer / self.data.validation_size)

                            if miss_rate + false_accept_rate < best_miss + best_false:
                                best_miss = miss_rate
                                best_false = false_accept_rate
                                saver.save(sess,
                                           save_path=(path_join(
                                               self.config.model_path,
                                               'best.ckpt')))
                                with open(path_join(
                                        self.config.model_path, 'best.pkl'),
                                        'wb') as f:
                                    best_tuple = (best_miss, best_false)
                                    pickle.dump(best_tuple, f)
                            if miss_rate + false_accept_rate < best_threshold:
                                best_count += 1
                                print('best_count', best_count)
                                best_list.append((miss_rate,
                                                  false_accept_rate, step,
                                                  best_count))
                                saver.save(sess,
                                           save_path=(path_join(
                                               self.config.model_path,
                                               'best' + str(
                                                   best_count) + '.ckpt')))

                    print(
                        'training finished, total epoch %d, the model will be saved in %s' % (
                            self.epoch, self.config.model_path))
                    saver.save(sess, save_path=(
                        path_join(self.config.model_path, 'latest.ckpt')))
                    print('best miss rate: %f\tbest false rate: %f' % (
                        best_miss, best_false))

                except tf.errors.OutOfRangeError:
                    print('Done training -- epoch limit reached')
                except Exception as e:
                    print(e)
                    traceback.print_exc()
                finally:
                    with open('best_list.pkl', 'wb') as f:
                        pickle.dump(best_list, f)
                    print('total time:%f hours' % (
                        (time.time() - st_time) / 3600))
                    # When done, ask the threads to stop.

            else:
                with open(config.rawdata_path + 'valid/' + "ctc_valid.pkl.sorted",
                          'rb') as f:
                    pkl = pickle.load(f)
                miss_count = 0
                false_count = 0
                target_count = 0

                valid_batch = self.data.valid_file_size * config.tfrecord_size // config.batch_size

                for i in range(valid_batch):
                    softmax, ctc_input, correctness, labels, _, _ = sess.run(
                        [self.valid_model.softmax,
                         self.valid_model.nn_outputs,
                         self.valid_model.correctness,
                         self.valid_model.labels,
                         self.valid_model.stage_op,
                         self.valid_model.input_filequeue_enqueue_op])
                    np.set_printoptions(precision=4,
                                        threshold=np.inf,
                                        suppress=True)

                    correctness = correctness.tolist()
                    decode_output = [ctc_decode(s) for s in softmax]
                    result = [ctc_predict(seq, config.label_seqs)
                              for seq in decode_output]
                    for k, r in enumerate(result):
                        if r != correctness[k]:
                            name = pkl[i * config.batch_size + k][0]
                            print("scp [email protected]:/ssd/keyword_raw/valid/%s ./"%name)
                            # print(pkl[i * config.batch_size + k])
                            # print(decode_output[k])
                            # print(labels[k])
                            with open('logits.txt', 'w') as f:
                                f.write(str(ctc_input[k]))

                    miss, target, false_accept = evaluate(
                        result, correctness)

                    miss_count += miss
                    target_count += target
                    false_count += false_accept

                print('--------------------------------')
                print('miss rate: %d/%d' % (miss_count, target_count))
                print('false_accept_rate: %d/%d' % (
                    false_count, self.data.validation_size - target_count))
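
evaluate is consumed here as (miss, target, false_accept) counts, with miss_count / target_count as the miss rate and false_count / (validation_size - target_count) as the false-accept rate. A sketch matching that contract (hypothetical; the project's evaluate may differ):

def evaluate(result, correctness):
    # result: predicted present/absent per utterance; correctness: ground truth.
    miss = target = false_accept = 0
    for predicted, truth in zip(result, correctness):
        if truth:
            target += 1
            if not predicted:
                miss += 1          # keyword present but not detected
        elif predicted:
            false_accept += 1      # keyword absent but the model fired
    return miss, target, false_accept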
Example #6
if __name__ == '__main__':
    # Arguments
    # Example: python main.py -n dA -sc -c
    parser = argparse.ArgumentParser(description='Chainer example')
    parser.add_argument('--gpu', '-g', default=-1, type=int, help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--network', '-n', default="MLP", type=str, help='Network structure')
    parser.add_argument('--optimizer', '-o', default="AdaDelta", type=str, help='Network optimizer')
    parser.add_argument('--normalize', '-nm', default=False, action='store_true', help='Apply normalizing to [0, 1]')
    parser.add_argument('--scale', '-sc', default=False, action='store_true', help='Apply scaling with mean and standard deviation')
    parser.add_argument('--chart', '-c', default=False, action='store_true', help='Draw and save charts')
    parser.add_argument('--epoch', '-e', default=100, type=int, help='Number of learning epochs')
    parser.add_argument('--batchsize', '-b', default=13, type=int, help='Number of records in a batch')
    parser.add_argument('--trainsize', '-t', default=390, type=int, help='Number of training records in whole dataset')
    args = parser.parse_args()
    # Logging for progress display
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    # Set the number of training records to a constant multiple of the batch size
    train_size = args.trainsize
    # Read the dataset
    (train_data, test_data, info) = reader.read_dataset(train_size, normalize=args.normalize, scale=args.scale)
    num_inputs = info["SHAPE_TRAIN_X"][1]  # number of input-layer units
    # Batch size for one step of stochastic gradient descent
    batch_size = args.batchsize
    # Number of training epochs
    num_epoch = args.epoch
    # Select the optimizer
    optimizer = select_optimizer(args.optimizer)
    # Run the model
    execute(args.network, train_data, test_data, batch_size, num_inputs, optimizer, num_epoch, args.gpu, args.chart)

Example #7

def train_model(data_path, config):
    x, y = reader.read_dataset(data_path)
    config = [int(i) for i in config[1:-1].split(' ')]
    bg = bgen.BatchGenerator(x, y, 200)
    nn = nnet.NNet(x.shape[1], y.shape[1], config)
    nn.minibatch_train(bg, 5)
    return nn


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--test-data', type=str)
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--train-data', type=str)
    parser.add_argument('--configuration', type=str)
    args = parser.parse_args()

    #config = [int(i) for i in args.configuration[1:-1].split(' ')]
    #print(args.test_data, args.dataset, args.train_data, config)

    if args.train_data:
        nn = train_model(args.train_data, args.configuration)
    else:
        nn = load_model(args.dataset)

    test_data = reader.read_dataset(
        args.test_data, resize=(28, 28) if args.dataset == 'Cat-Dog' else None)
    test_model(test_data, nn)
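
load_model and test_model are project-local and not shown. A minimal sketch of an evaluation helper compatible with the call above, assuming read_dataset returns (x, y) with one-hot labels as in train_model and that NNet exposes a predict method (both assumptions, not the project's confirmed API):

import numpy as np

def test_model(test_data, nn):
    x, y = test_data
    predictions = nn.predict(x)  # assumed NNet API
    # Compare predicted classes against the one-hot labels.
    accuracy = float(np.mean(predictions.argmax(axis=1) == y.argmax(axis=1)))
    print('test accuracy: %.4f' % accuracy)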
Example #8
    def run(self, TrainingModel):

        graph = tf.Graph()
        with graph.as_default(), tf.Session() as sess:

            self.data = read_dataset(self.config)

            if config.mode == 'train':
                print('building training model....')
                with tf.variable_scope("model"):
                    self.train_model = TrainingModel(self.config,
                                                     self.data.batch_input_queue(),
                                                     is_train=True)
                    self.train_model.config.show()
                print('building valid model....')
                with tf.variable_scope("model", reuse=True):
                    self.valid_model = TrainingModel(self.config,
                                                     self.data.valid_queue(),
                                                     is_train=False)
            else:
                with tf.variable_scope("model", reuse=False):
                    self.valid_model = TrainingModel(self.config,
                                                     self.data.valid_queue(),
                                                     is_train=False)
            saver = tf.train.Saver()

            # restore from stored models
            files = glob(path_join(self.config.model_path, '*.ckpt.*'))

            if len(files) > 0:

                saver.restore(sess, path_join(self.config.model_path,
                                              self.config.model_name))
                print('Model restored from: ' + self.config.model_path)
            else:
                print("Model doesn't exist.\nInitializing........")
                sess.run(tf.global_variables_initializer())

            sess.run(tf.local_variables_initializer())
            graph.finalize()

            best_accuracy = 1
            accu_loss = 0
            st_time = time.time()
            epoch_step = config.tfrecord_size * self.data.train_file_size // config.batch_size
            if os.path.exists(path_join(self.config.save_path, 'best.pkl')):
                with open(path_join(self.config.save_path, 'best.pkl'),
                          'rb') as f:
                    best_accuracy = pickle.load(f)
                    print('best accuracy', best_accuracy)
            else:
                print('no best record found')

            check_dir(self.config.save_path)

            if self.config.mode == 'train':

                if self.config.reset_global:
                    sess.run(self.train_model.reset_global_step)

                def handler_stop_signals(signum, frame):
                    global run
                    run = False
                    if not DEBUG:
                        print(
                            'training shut down, total step %s, the model will be saved in %s' % (
                                step, self.config.save_path))
                        saver.save(sess, save_path=(
                            path_join(self.config.save_path, 'latest.ckpt')))
                        print('best accuracy', best_accuracy)
                    sys.exit(0)

                signal.signal(signal.SIGINT, handler_stop_signals)
                signal.signal(signal.SIGTERM, handler_stop_signals)

                best_list = []
                best_threshold = 0.1
                best_count = 0
                # (miss,false,step,best_count)

                last_time = time.time()

                try:
                    sess.run([self.data.noise_stage_op,
                              self.data.noise_filequeue_enqueue_op,
                              self.train_model.stage_op,
                              self.train_model.input_filequeue_enqueue_op,
                              self.valid_model.stage_op,
                              self.valid_model.input_filequeue_enqueue_op])

                    va = tf.trainable_variables()
                    for i in va:
                        print(i.name)
                    while self.epoch < self.config.max_epoch:
                        _, _, _, _, _, l, lr, step, grads, vs = sess.run(
                            [self.train_model.train_op,
                             self.data.noise_stage_op,
                             self.data.noise_filequeue_enqueue_op,
                             self.train_model.stage_op,
                             self.train_model.input_filequeue_enqueue_op,
                             self.train_model.loss,
                             self.train_model.learning_rate,
                             self.train_model.global_step,
                             self.train_model.grads,
                             self.train_model.vs
                             ])
                        epoch = step // epoch_step
                        accu_loss += l
                        if epoch > self.epoch:
                            self.epoch = epoch
                            print('accumulated loss', accu_loss)
                            saver.save(sess, save_path=(
                                path_join(self.config.save_path,
                                          'latest.ckpt')))
                            print('latest.ckpt saved in %s' % (
                                path_join(self.config.save_path,
                                          'latest.ckpt')))
                            accu_loss = 0
                        if step % config.valid_step == 2:
                            print('epoch time ', (time.time() - last_time) / 60)
                            last_time = time.time()

                            total_speech = 0
                            total_silence = 0
                            total_miss = 0
                            total_false = 0
                            valid_batch = self.data.valid_file_size * config.tfrecord_size // config.batch_size
                            text = ""
                            for i in range(valid_batch):
                                soft, labels, seqlen, _, _ = sess.run(
                                    [
                                        self.valid_model.softmax,
                                        self.valid_model.labels,
                                        self.valid_model.seqLengths,
                                        self.valid_model.stage_op,
                                        self.valid_model.input_filequeue_enqueue_op])
                                np.set_printoptions(precision=4,
                                                    threshold=np.inf,
                                                    suppress=True)

                                logits = posterior_predict(soft, config.thres)
                                target_speech, target_silence, miss, false_trigger = frame_accurcacy(
                                    logits, labels, seqlen)
                                total_speech += target_speech
                                total_silence += target_silence
                                total_miss += miss
                                total_false += false_trigger

                            miss_rate = round(total_miss / total_speech, 4)
                            false_rate = round(total_false / total_silence, 4)
                            print('--------------------------------')
                            print('epoch %d' % self.epoch)
                            print('training loss:' + str(l))
                            print('learning rate:', lr, 'global step', step)
                            print('miss rate:' + str(miss_rate))
                            print('false rate: ' + str(false_rate))

                            if miss_rate + false_rate < best_accuracy:
                                best_accuracy = miss_rate + false_rate
                                saver.save(sess,
                                           save_path=(path_join(
                                               self.config.save_path,
                                               'best.ckpt')))
                                with open(path_join(
                                        self.config.save_path, 'best.pkl'),
                                        'wb') as f:
                                    best_tuple = (miss_rate, false_rate)
                                    pickle.dump(best_tuple, f)
                            if miss_rate + false_rate < best_threshold:
                                best_count += 1
                                print('best_count', best_count)
                                best_list.append(
                                    (miss_rate, false_rate, step,
                                     best_count))
                                saver.save(sess,
                                           save_path=(path_join(
                                               self.config.save_path,
                                               'best' + str(
                                                   best_count) + '.ckpt')))

                    print(
                        'training finished, total epoch %d, the model will be saved in %s' % (
                            self.epoch, self.config.save_path))
                    saver.save(sess, save_path=(
                        path_join(self.config.save_path, 'latest.ckpt')))
                    print('best accuracy: %f' % best_accuracy)

                except tf.errors.OutOfRangeError:
                    print('Done training -- epoch limit reached')
                except Exception as e:
                    print(e)
                    traceback.print_exc()
                finally:
                    with open('best_list.pkl', 'wb') as f:
                        pickle.dump(best_list, f)
                    print('total time:%f hours' % (
                        (time.time() - st_time) / 3600))
                    # When done, ask the threads to stop.

            else:
                total = 0
                wrong = 0

                valid_batch = self.data.valid_file_size * config.tfrecord_size // config.batch_size

                for i in range(valid_batch):
                    logits, labels, _, _ = sess.run(
                        [self.valid_model.outputs,
                         self.valid_model.labels,
                         self.valid_model.stage_op,
                         self.valid_model.input_filequeue_enqueue_op])
                    np.set_printoptions(precision=4,
                                        threshold=np.inf,
                                        suppress=True)

                    total_count, wrong_count = frame_accurcacy(
                        logits, labels)
                    total += total_count
                    wrong += wrong_count

                accuracy = 1 - wrong / total

                print('--------------------------------')
                print('accuracy: %f' % accuracy)
Example #9
out_dir = args.out_dir_path
out_dir = out_dir + "-" + model_type

U.mkdir_p(out_dir)
U.mkdir_p(out_dir + '/data')
U.mkdir_p(out_dir + '/preds')
U.mkdir_p(out_dir + '/models')
U.mkdir_p(out_dir + '/models/best_weights')
U.set_logger(out_dir)
U.print_args(args)

if args.is_test and args.test_path is None:
    logger.error("Please enter the path to the file for testing!")
    exit()

train, x, y = R.read_dataset(args.train_path, model=model_type)

# If it is a test
if args.is_test and args.test_path:
    test, test_x, test_y = R.read_dataset(args.test_path, model=model_type)
    test_x = np.array(test_x)
    test_y = np.array(test_y)

    train_x, train_y, dev_x, dev_y = H.getDevFoldDrawCards(x, y)
    logger.info("================ Testing ================")
    accuracy = M.run_model(train_x,
                           train_y,
                           dev_x,
                           dev_y,
                           test_x,
                           test_y,