def run(self):
    """Train or evaluate a model inside a single TF session.

    Builds the model from default configs, optionally restores a
    checkpoint, then loops over epochs/batches feeding sparse CTC-style
    targets, printing per-batch loss and error rate, checkpointing
    periodically, and logging the epoch-level error rate.

    NOTE(review): `mode`, `level`, `keep`, `savedir`, `num_epochs`,
    `batch_size`, `model_fn`, `logfile` and `resultdir` are not defined
    in this method and are not visible in this file chunk -- presumably
    module-level globals set by the surrounding script; confirm before
    reusing this method elsewhere.
    """
    # load data
    args_dict = self._default_configs()
    args = dotdict(args_dict)
    batchedData, maxTimeSteps, totalN = self.load_data(args, mode=mode, type=level)
    model = model_fn(args, maxTimeSteps)
    # count the num of params (recorded into the config for logging)
    num_params = count_params(model, mode='trainable')
    all_num_params = count_params(model, mode='all')
    model.config['trainable params'] = num_params
    model.config['all params'] = all_num_params
    print(model.config)
    with tf.Session(graph=model.graph) as sess:
        # restore from stored model when resuming; otherwise initialize.
        # NOTE(review): if keep is True but no checkpoint is found, the
        # session runs with uninitialized variables -- confirm intended.
        if keep == True:
            ckpt = tf.train.get_checkpoint_state(savedir)
            if ckpt and ckpt.model_checkpoint_path:
                model.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Model restored from:' + savedir)
        else:
            print('Initializing')
            sess.run(model.initial_op)
        for epoch in range(num_epochs):
            ## training: batches are visited in a fresh random order each epoch
            start = time.time()
            if mode == 'train':
                print('Epoch', epoch + 1, '...')
            batchErrors = np.zeros(len(batchedData))
            batchRandIxs = np.random.permutation(len(batchedData))
            for batch, batchOrigI in enumerate(batchRandIxs):
                # each batch carries a sparse-tensor target triple
                batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                    batchOrigI]
                batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                feedDict = {
                    model.inputX: batchInputs,
                    model.targetIxs: batchTargetIxs,
                    model.targetVals: batchTargetVals,
                    model.targetShape: batchTargetShape,
                    model.seqLengths: batchSeqLengths
                }
                if level == 'cha':
                    # character level: the graph computes the error rate itself
                    if mode == 'train':
                        _, l, pre, y, er = sess.run([
                            model.optimizer, model.loss, model.predictions,
                            model.targetY, model.errorRate
                        ], feed_dict=feedDict)
                        batchErrors[batch] = er
                        print(
                            '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                            .format(level, totalN, batch + 1,
                                    len(batchRandIxs), epoch + 1, num_epochs,
                                    l, er / batch_size))
                    elif mode == 'test':
                        l, pre, y, er = sess.run([
                            model.loss, model.predictions, model.targetY,
                            model.errorRate
                        ], feed_dict=feedDict)
                        batchErrors[batch] = er
                        print(
                            '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'
                            .format(level, totalN, batch + 1,
                                    len(batchRandIxs), l, er / batch_size))
                elif level == 'phn':
                    # phoneme level: error rate computed outside the graph via
                    # edit distance between prediction and target sparse values.
                    # NOTE(review): get_edit_distance is called with 4 args here
                    # but with 5 args ('test', args.level) in test() below --
                    # verify the helper's signature; one of the call sites is
                    # probably stale.
                    if mode == 'train':
                        _, l, pre, y = sess.run([
                            model.optimizer, model.loss, model.predictions,
                            model.targetY
                        ], feed_dict=feedDict)
                        er = get_edit_distance([pre.values], [y.values], True,
                                               level)
                        print(
                            '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train PER={:.3f}\n'
                            .format(level, totalN, batch + 1,
                                    len(batchRandIxs), epoch + 1, num_epochs,
                                    l, er))
                        # weight by batch size so epochER below averages per utterance
                        batchErrors[batch] = er * len(batchSeqLengths)
                    elif mode == 'test':
                        l, pre, y = sess.run(
                            [model.loss, model.predictions, model.targetY],
                            feed_dict=feedDict)
                        er = get_edit_distance([pre.values], [y.values], True,
                                               level)
                        print(
                            '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test PER={:.3f}\n'
                            .format(level, totalN, batch + 1,
                                    len(batchRandIxs), l, er))
                        batchErrors[batch] = er * len(batchSeqLengths)
                # NOTE(review): aborts the whole epoch as soon as one batch's
                # scaled error hits exactly 1.0; for 'phn' `er` is an edit
                # distance (not a rate), so this guard's semantics differ by
                # level -- confirm this early-exit is intended.
                if er / batch_size == 1.0:
                    break
                # periodically show a decoded truth/prediction pair
                if batch % 30 == 0:
                    print('Truth:\n' + output_to_sequence(y, type=level))
                    print('Output:\n' + output_to_sequence(pre, type=level))
                # checkpoint every 20 global steps and at the very last batch
                if mode == 'train' and (
                        (epoch * len(batchRandIxs) + batch + 1) % 20 == 0 or
                        (epoch == num_epochs - 1 and
                         batch == len(batchRandIxs) - 1)):
                    checkpoint_path = os.path.join(savedir, 'model.ckpt')
                    model.saver.save(sess, checkpoint_path, global_step=epoch)
                    print('Model has been saved in {}'.format(savedir))
            end = time.time()
            delta_time = end - start
            print('Epoch ' + str(epoch + 1) + ' needs time:' +
                  str(delta_time) + ' s')
            if mode == 'train':
                # save once per epoch ((epoch + 1) % 1 is always 0)
                if (epoch + 1) % 1 == 0:
                    checkpoint_path = os.path.join(savedir, 'model.ckpt')
                    model.saver.save(sess, checkpoint_path, global_step=epoch)
                    print('Model has been saved in {}'.format(savedir))
                epochER = batchErrors.sum() / totalN
                print('Epoch', epoch + 1, 'mean train error rate:', epochER)
                logging(model, logfile, epochER, epoch, delta_time,
                        mode='config')
                logging(model, logfile, epochER, epoch, delta_time, mode=mode)
            if mode == 'test':
                # NOTE(review): only the final batch's y/pre are written here,
                # since the write happens after the batch loop -- confirm that
                # is intended rather than per-batch output.
                with open(os.path.join(resultdir, level + '_result.txt'),
                          'a') as result:
                    result.write(output_to_sequence(y, type=level) + '\n')
                    result.write(output_to_sequence(pre, type=level) + '\n')
                    result.write('\n')
                epochER = batchErrors.sum() / totalN
                print(' test error rate:', epochER)
                logging(model, logfile, epochER, mode=mode)
def test(self):
    """Evaluate a trained model on the test set.

    Loads batched test data, rebuilds the graph selected by ``args.model``,
    restores the latest checkpoint from ``args.save_dir``, runs one pass
    over every batch printing per-batch loss/error and appending decoded
    truth/output sequences to ``<args.task>_result.txt``, then prints and
    logs the aggregate error rate.

    Raises:
        ValueError: if ``args.model`` is not one of the supported
            architectures ('ResNet', 'BiRNN', 'DBiRNN').
    """
    # load data
    args = self.args
    batchedData, maxTimeSteps, totalN = self.load_data(args,
                                                       mode='test',
                                                       type=args.level)
    # build the graph for the requested architecture
    if args.model == 'ResNet':
        model = ResNet(args, maxTimeSteps)
    elif args.model == 'BiRNN':
        model = BiRNN(args, maxTimeSteps)
    elif args.model == 'DBiRNN':
        model = DBiRNN(args, maxTimeSteps)
    else:
        # Fail fast with a clear message; previously an unknown model name
        # left `model` unbound and crashed later with UnboundLocalError.
        raise ValueError('unsupported model: {}'.format(args.model))
    # record parameter counts into the config for logging
    num_params = count_params(model, mode='trainable')
    all_num_params = count_params(model, mode='all')
    model.config['trainable params'] = num_params
    model.config['all params'] = all_num_params
    with tf.Session(graph=model.graph) as sess:
        # NOTE(review): if no checkpoint exists the session proceeds with
        # uninitialized variables -- confirm that is acceptable here.
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            model.saver.restore(sess, ckpt.model_checkpoint_path)
            print('Model restored from:' + args.save_dir)
        batchErrors = np.zeros(len(batchedData))
        batchRandIxs = np.random.permutation(len(batchedData))
        for batch, batchOrigI in enumerate(batchRandIxs):
            # each batch carries a sparse-tensor target triple
            batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                batchOrigI]
            batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
            feedDict = {
                model.inputX: batchInputs,
                model.targetIxs: batchTargetIxs,
                model.targetVals: batchTargetVals,
                model.targetShape: batchTargetShape,
                model.seqLengths: batchSeqLengths
            }
            if args.level == 'cha':
                # character level: the graph computes the error rate itself
                l, pre, y, er = sess.run([
                    model.loss, model.predictions, model.targetY,
                    model.errorRate
                ], feed_dict=feedDict)
                batchErrors[batch] = er
                print(
                    '\ntotal:{},batch:{}/{},loss={:.3f},mean CER={:.3f}\n'.
                    format(totalN, batch + 1, len(batchRandIxs), l,
                           er / args.batch_size))
            elif args.level == 'phn':
                # phoneme level: error rate computed outside the graph via
                # edit distance between prediction and target sparse values
                l, pre, y = sess.run(
                    [model.loss, model.predictions, model.targetY],
                    feed_dict=feedDict)
                er = get_edit_distance([pre.values], [y.values], True, 'test',
                                       args.level)
                print(
                    '\ntotal:{},batch:{}/{},loss={:.3f},mean PER={:.3f}\n'.
                    format(totalN, batch + 1, len(batchRandIxs), l,
                           er / args.batch_size))
                # weight by batch size so epochER below averages per utterance
                batchErrors[batch] = er * len(batchSeqLengths)
            print('Truth:\n' + output_to_sequence(y, type=args.level))
            print('Output:\n' + output_to_sequence(pre, type=args.level))
            # append every batch's decoded truth/output pair to the result file
            with open(args.task + '_result.txt', 'a') as result:
                result.write(output_to_sequence(y, type=args.level) + '\n')
                result.write(
                    output_to_sequence(pre, type=args.level) + '\n')
                result.write('\n')
        epochER = batchErrors.sum() / totalN
        print(args.task + ' test error rate:', epochER)
        logging(model, self.logfile, epochER, mode='test')