# Assumed module-level imports for the methods below; the surrounding file
# is also expected to define globals such as level, mode, batch_size,
# num_epochs, keep, savedir, resultdir, logfile, model_fn and the dataset
# variables used in run():
#   import os
#   import time
#   import numpy as np
#   import tensorflow as tf

# NOTE: three versions of load_data follow; in Python, each definition
# shadows the previous one, so only the last one is actually in effect.
def load_data(self, args, mode, type):
    if mode == 'train':
        return load_batched_data(args.train_mfcc_dir, args.train_label_dir,
                                 args.batch_size, mode, type)
    elif mode == 'test':
        args.batch_size = args.test_batch_size
        return load_batched_data(args.test_mfcc_dir, args.test_label_dir,
                                 args.test_batch_size, mode, type)
    else:
        raise TypeError('mode should be train or test.')
def load_data(self, feature_dir, label_dir, mode, level):
    # relies on a module-level batch_size; shadowed by the definition below
    return load_batched_data(feature_dir, label_dir, batch_size, mode, level)
def load_data(self, X, labels, batchSize, mode, level):
    return load_batched_data(X, labels, batchSize, mode, level)
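# run() below wraps its config dict in `dotdict` for attribute-style access
# (args.batch_size instead of args['batch_size']). A minimal sketch of such
# a helper, assuming the project does not already ship one:
class dotdict(dict):
    """dict whose keys can also be read and written as attributes."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__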
def run(self):
    # load configuration and data
    args_dict = self._default_configs()
    args = dotdict(args_dict)
    X, labels = get_data(level, train_dataset, test_dataset, mode)
    totalN = len(X)
    print("X :", len(X))
    num_batches = len(X) // batch_size  # integer division; plain `/` yields a float on Python 3

    # round the longest sequence length up to the next multiple of 5000
    maxLength = 0
    for x in X:
        maxLength = max(maxLength, x.shape[1])
    if maxLength % 5000 != 0:
        maxLength = maxLength + 5000 - maxLength % 5000
    maxTimeSteps = maxLength

    model = model_fn(args, maxTimeSteps)
    model.build_graph(args, maxTimeSteps)
    #num_params = count_params(model, mode='trainable')
    #all_num_params = count_params(model, mode='all')
    #model.config['trainable params'] = num_params
    #model.config['all params'] = all_num_params
    print(model.config)

    with tf.Session(graph=model.graph) as sess:
        # restore from a stored model, otherwise initialize from scratch
        if keep:
            ckpt = tf.train.get_checkpoint_state(savedir)
            if ckpt and ckpt.model_checkpoint_path:
                model.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Model restored from:' + savedir)
            else:
                sess.run(model.initial_op)
        else:
            print('Initializing')
            sess.run(model.initial_op)

        if mode == 'train':
            writer = tf.summary.FileWriter("loggingdir", graph=model.graph)
            for epoch in range(num_epochs):
                start = time.time()
                print('Epoch {} ...'.format(epoch + 1))
                batchErrors = np.zeros(num_batches)
                batchRandIxs = np.random.permutation(num_batches)
                # Build the batch generator once per epoch. Calling
                # next(load_batched_data(...)) inside the loop, as the code
                # originally did, recreates the generator on every iteration
                # and keeps returning the first batch.
                data_gen = load_batched_data(X, labels, batch_size, mode, level)
                for batch, batchOrigI in enumerate(batchRandIxs):
                    batchInputs, batchTargetSparse, batchSeqLengths = next(data_gen)
                    batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                    feedDict = {
                        model.inputX: batchInputs,
                        model.targetIxs: batchTargetIxs,
                        model.targetVals: batchTargetVals,
                        model.targetShape: batchTargetShape,
                        model.seqLengths: batchSeqLengths
                    }
                    # only the character-level branch is implemented here
                    if level == 'cha':
                        _, l, pre, y, er, summary = sess.run(
                            [model.optimizer, model.loss, model.predictions,
                             model.targetY, model.errorRate, model.summary_op],
                            feed_dict=feedDict)
                        writer.add_summary(summary, epoch * num_batches + batch)
                        batchErrors[batch] = er
                        print('\n{} mode, batch:{}/{}, epoch:{}/{}, train loss={:.3f}, mean train CER={:.3f}\n'
                              .format(level, batch + 1, len(batchRandIxs),
                                      epoch + 1, num_epochs, l, er / batch_size))
                        # `y` and `pre` exist only in this branch, so the
                        # sample printout has to stay inside it
                        if (batch + 1) % 10 == 0:
                            print('Truth:\n' + output_to_sequence(y, type='phn'))
                            print('Output:\n' + output_to_sequence(pre, type='phn'))
                    if ((epoch * len(batchRandIxs) + batch + 1) % 20 == 0 or
                            (epoch == num_epochs - 1 and batch == len(batchRandIxs) - 1)):
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess, checkpoint_path, global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))

                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' + str(delta_time) + ' s')

                # save a checkpoint at the end of every epoch
                checkpoint_path = os.path.join(savedir, 'model.ckpt')
                model.saver.save(sess, checkpoint_path, global_step=epoch)
                print('Model has been saved in {}'.format(savedir))

                epochER = batchErrors.sum() / totalN
                print('Epoch', epoch + 1, 'mean train error rate:', epochER)
                logging(model, logfile, epochER, epoch, delta_time, mode='config')
                logging(model, logfile, epochER, epoch, delta_time, mode=mode)

        elif mode == 'test':
            batchErrors = []
            for data in load_batched_data(X, labels, batch_size, mode, level):
                batchInputs, batchTargetSparse, batchSeqLengths = data
                batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                feedDict = {
                    model.inputX: batchInputs,
                    model.targetIxs: batchTargetIxs,
                    model.targetVals: batchTargetVals,
                    model.targetShape: batchTargetShape,
                    model.seqLengths: batchSeqLengths
                }
                # the optimizer must not run at test time, or it would keep
                # updating the weights on the test set
                l, pre, y, er = sess.run(
                    [model.loss, model.predictions, model.targetY, model.errorRate],
                    feed_dict=feedDict)
                batchErrors.append(er)
                with open(os.path.join(resultdir, level + '_result.txt'), 'a') as result:
                    result.write(output_to_sequence(y, type='phn') + '\n')
                    result.write(output_to_sequence(pre, type='phn') + '\n')
                    result.write('\n')
            # `epochER` was referenced here but never computed; accumulate the
            # per-batch errors and normalize as the train branch does
            epochER = np.array(batchErrors).sum() / totalN
            print('test error rate:', epochER)
            logging(model, logfile, epochER, mode=mode)
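# The (targetIxs, targetVals, targetShape) triple fed to the model above is
# the coordinate form of a tf.SparseTensor, which TensorFlow's CTC loss
# expects for the labels. A minimal sketch of how such a triple is typically
# built from a batch of label sequences; `sparse_tuple_from` is an
# illustration, not a helper defined in this repo:
def sparse_tuple_from(sequences):
    """Convert a list of integer label sequences to (indices, values, shape)."""
    indices, values = [], []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=np.int32)
    shape = np.asarray([len(sequences), max(len(s) for s in sequences)],
                       dtype=np.int64)
    return indices, values, shape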
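if __name__ == '__main__':
    # Hypothetical entry point: `Trainer` stands in for whatever class in
    # this file owns load_data() and run(); substitute the actual class name.
    runner = Trainer()
    runner.run()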