示例#1
0
 def load_data(self, args, mode, type):
     """Return the batched loader for the train or test split described by ``args``.

     Args:
         args: config object. For ``'train'`` it reads ``train_mfcc_dir``,
             ``train_label_dir`` and ``batch_size``; for ``'test'`` it reads
             ``test_mfcc_dir``, ``test_label_dir`` and ``test_batch_size``.
         mode: ``'train'`` or ``'test'``; also forwarded to ``load_batched_data``.
         type: forwarded unchanged to ``load_batched_data`` (the name shadows
             the builtin but is kept so keyword callers keep working).

     Returns:
         Whatever ``load_batched_data`` returns for the chosen split.

     Raises:
         TypeError: if ``mode`` is neither ``'train'`` nor ``'test'``.
     """
     if mode == 'train':
         return load_batched_data(
             args.train_mfcc_dir,
             args.train_label_dir,
             args.batch_size,
             mode,
             type,
         )
     if mode == 'test':
         # The test batch size is also written back onto ``args``
         # (presumably so later readers of ``args.batch_size`` see it).
         args.batch_size = args.test_batch_size
         return load_batched_data(
             args.test_mfcc_dir,
             args.test_label_dir,
             args.test_batch_size,
             mode,
             type,
         )
     raise TypeError('mode should be train or test.')
示例#2
0
 def load_data(self, feature_dir, label_dir, mode, level):
     """Return ``load_batched_data`` for the given feature and label directories.

     NOTE(review): ``batch_size`` is neither a parameter nor a local of this
     method, so it is resolved from module (global) scope at call time.
     """
     loader = load_batched_data
     return loader(feature_dir, label_dir, batch_size, mode, level)
示例#3
0
 def load_data(self, X, labels, batchSize, mode, level):
     """Forward every argument, in order, to ``load_batched_data`` and return its result."""
     loader_args = (X, labels, batchSize, mode, level)
     return load_batched_data(*loader_args)
示例#4
0
    def run(self):
        """Load the configured dataset, build the model, then train or test it.

        Apart from ``self._default_configs()``, every setting is read from
        module-level names rather than from ``self``: ``level``, ``mode``,
        ``batch_size``, ``num_epochs``, ``keep``, ``savedir``, ``resultdir``,
        ``logfile``, ``train_dataset``, ``test_dataset`` and ``model_fn``.
        """
        args = dotdict(self._default_configs())

        # load data
        X, labels = get_data(level, train_dataset, test_dataset, mode)
        totalN = len(X)
        print("X :", len(X))
        # ``//`` keeps this an int; on Python 3 ``/`` yields a float, which
        # np.zeros / range reject.
        num_batches = len(X) // batch_size

        # Longest ``x.shape[1]`` in the data, rounded up to a multiple of 5000;
        # used as the model's maxTimeSteps.
        maxTimeSteps = max([0] + [x.shape[1] for x in X])
        if maxTimeSteps % 5000 != 0:
            maxTimeSteps += 5000 - maxTimeSteps % 5000

        model = model_fn(args, maxTimeSteps)
        model.build_graph(args, maxTimeSteps)
        print(model.config)

        with tf.Session(graph=model.graph) as sess:
            self._restore_or_initialize(sess, model)
            if mode == 'train':
                self._train_epochs(sess, model, X, labels, num_batches, totalN)
            elif mode == 'test':
                self._evaluate_test_set(sess, model, X, labels, totalN)

    def _restore_or_initialize(self, sess, model):
        """Restore the latest checkpoint in ``savedir`` when ``keep`` is set, else run ``model.initial_op``."""
        # Compared with ``== True`` (as before) rather than truthiness, so a
        # non-bool ``keep`` such as the string 'False' does not trigger a restore.
        if keep == True:  # noqa: E712
            ckpt = tf.train.get_checkpoint_state(savedir)
            if ckpt and ckpt.model_checkpoint_path:
                model.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Model restored from:' + savedir)
            else:
                sess.run(model.initial_op)
        else:
            print('Initializing')
            sess.run(model.initial_op)

    @staticmethod
    def _make_feed_dict(model, batch):
        """Map one ``(inputs, (target_ixs, target_vals, target_shape), seq_lengths)`` batch onto the model placeholders."""
        batchInputs, batchTargetSparse, batchSeqLengths = batch
        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
        return {
            model.inputX: batchInputs,
            model.targetIxs: batchTargetIxs,
            model.targetVals: batchTargetVals,
            model.targetShape: batchTargetShape,
            model.seqLengths: batchSeqLengths,
        }

    def _train_epochs(self, sess, model, X, labels, num_batches, totalN):
        """Train for ``num_epochs`` epochs, writing summaries, checkpoints and log entries."""
        checkpoint_path = os.path.join(savedir, 'model.ckpt')
        writer = tf.summary.FileWriter("loggingdir", graph=model.graph)
        try:
            for epoch in range(num_epochs):
                # training
                start = time.time()
                print('Epoch {} ...'.format(epoch + 1))

                batchErrors = np.zeros(num_batches)
                # One loader per epoch. The old ``next(load_batched_data(...))``
                # restarted the loader for every batch, so (unless the loader
                # shuffles internally) it kept returning the same first batch.
                # zip stops after num_batches, or earlier if the loader runs dry.
                batches = load_batched_data(X, labels, batch_size, mode, level)
                for batch, data in zip(range(num_batches), batches):
                    feedDict = self._make_feed_dict(model, data)
                    if level == 'cha':
                        _, l, pre, y, er, summary = sess.run(
                            [
                                model.optimizer, model.loss,
                                model.predictions, model.targetY,
                                model.errorRate, model.summary_op
                            ],
                            feed_dict=feedDict)
                        writer.add_summary(summary, epoch * num_batches + batch)

                        batchErrors[batch] = er
                        print(
                            '\n{} mode, batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                            .format(level, batch + 1, num_batches,
                                    epoch + 1, num_epochs, l,
                                    er / batch_size))

                        # Kept inside this branch: ``y`` / ``pre`` are only
                        # bound here (other levels used to hit a NameError).
                        if (batch + 1) % 10 == 0:
                            # NOTE(review): decodes with type='phn' although
                            # this branch only runs for level 'cha' — confirm.
                            print('Truth:\n' +
                                  output_to_sequence(y, type='phn'))
                            print('Output:\n' +
                                  output_to_sequence(pre, type='phn'))
                    # NOTE(review): for any level other than 'cha' no sess.run
                    # happens, so nothing is trained and batchErrors stays 0.

                    if ((epoch * num_batches + batch + 1) % 20 == 0
                            or (epoch == num_epochs - 1
                                and batch == num_batches - 1)):
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))

                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' +
                      str(delta_time) + ' s')

                # Checkpoint after every epoch (the old ``(epoch + 1) % 1 == 0``
                # guard was always true).
                model.saver.save(sess, checkpoint_path, global_step=epoch)
                print('Model has been saved in {}'.format(savedir))

                epochER = batchErrors.sum() / totalN
                print('Epoch', epoch + 1, 'mean train error rate:', epochER)
                logging(model, logfile, epochER, epoch, delta_time,
                        mode='config')
                logging(model, logfile, epochER, epoch, delta_time,
                        mode=mode)
        finally:
            # Flush pending summaries and release the event file.
            writer.close()

    def _evaluate_test_set(self, sess, model, X, labels, totalN):
        """Decode every test batch, append truth/output lines to the result file, then log the error rate."""
        batchErrors = []
        result_path = os.path.join(resultdir, level + '_result.txt')
        with open(result_path, 'a') as result:
            for data in load_batched_data(X, labels, batch_size, mode, level):
                feedDict = self._make_feed_dict(model, data)
                # Evaluation only: model.optimizer (presumably the training op)
                # is deliberately not fetched, so test batches do not update weights.
                pre, y, er = sess.run(
                    [model.predictions, model.targetY, model.errorRate],
                    feed_dict=feedDict)
                batchErrors.append(er)
                result.write(output_to_sequence(y, type='phn') + '\n')
                result.write(output_to_sequence(pre, type='phn') + '\n')
                result.write('\n')
        # Same normalisation as the training epochs: summed batch errors / sample count.
        epochER = np.sum(batchErrors) / totalN
        print(' test error rate:', epochER)
        logging(model, logfile, epochER, mode=mode)