Example #1
0
    def run(self):
        """Train or evaluate a model for ``num_epochs`` passes over the data.

        The behaviour is driven by names this method reads from the enclosing
        scope rather than from parameters: ``mode`` ('train' or 'test'),
        ``level`` ('cha' or 'phn'), ``keep``, ``savedir``, ``resultdir``,
        ``logfile``, ``num_epochs``, ``batch_size`` and ``model_fn``.

        Every epoch visits the batches in a fresh random order.  For 'cha'
        the per-batch error is fetched from ``model.errorRate``; for 'phn' it
        is computed with ``get_edit_distance`` on the fetched predictions and
        targets.  In 'train' mode the optimizer op is run as well, checkpoints
        are written to ``savedir`` and the epoch error is logged.  In 'test'
        mode the last decoded truth/output pair of the epoch is appended to
        ``<resultdir>/<level>_result.txt`` and the epoch error is logged.

        If ``keep`` is set but ``savedir`` holds no checkpoint, the variables
        are initialised from scratch instead of being left uninitialised.

        Raises:
            ValueError: if ``mode`` or ``level`` has an unsupported value.
        """
        # Fail fast: any other value would leave `er`, `pre` and `y` unbound
        # inside the batch loop below.
        if level not in ('cha', 'phn'):
            raise ValueError('Unsupported level: {!r}'.format(level))
        if mode not in ('train', 'test'):
            raise ValueError('Unsupported mode: {!r}'.format(mode))
        is_training = mode == 'train'
        # Label used in the progress messages for the error metric.
        metric = 'CER' if level == 'cha' else 'PER'

        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        batchedData, maxTimeSteps, totalN = self.load_data(args,
                                                           mode=mode,
                                                           type=level)
        model = model_fn(args, maxTimeSteps)

        # count the num of params and expose them through the model config
        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params
        print(model.config)

        def save_checkpoint(sess, epoch):
            """Save the session under savedir, tagged with the epoch index."""
            checkpoint_path = os.path.join(savedir, 'model.ckpt')
            model.saver.save(sess, checkpoint_path, global_step=epoch)
            print('Model has been saved in {}'.format(savedir))

        with tf.Session(graph=model.graph) as sess:
            # restore from stored model when requested and available
            restored = False
            if keep:
                ckpt = tf.train.get_checkpoint_state(savedir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + savedir)
                    restored = True
                else:
                    print('No checkpoint found in:' + savedir)
            if not restored:
                # Also covers "keep requested but nothing to restore";
                # otherwise the first sess.run would hit uninitialised
                # variables.
                print('Initializing')
                sess.run(model.initial_op)

            # The ops to run are the same for every batch, so build the list
            # once.  The optimizer is only fetched when training and
            # model.errorRate only at character level.
            fetches = [model.loss, model.predictions, model.targetY]
            if level == 'cha':
                fetches.append(model.errorRate)
            if is_training:
                fetches.insert(0, model.optimizer)

            for epoch in range(num_epochs):
                start = time.time()
                if is_training:
                    print('Epoch', epoch + 1, '...')
                batchErrors = np.zeros(len(batchedData))
                batchRandIxs = np.random.permutation(len(batchedData))
                num_batches = len(batchRandIxs)
                # Stay None when the epoch has no batches at all.
                pre = y = None

                for batch, batchOrigI in enumerate(batchRandIxs):
                    batchInputs, batchTargetSparse, batchSeqLengths = (
                        batchedData[batchOrigI])
                    (batchTargetIxs, batchTargetVals,
                     batchTargetShape) = batchTargetSparse
                    feedDict = {
                        model.inputX: batchInputs,
                        model.targetIxs: batchTargetIxs,
                        model.targetVals: batchTargetVals,
                        model.targetShape: batchTargetShape,
                        model.seqLengths: batchSeqLengths
                    }

                    results = sess.run(fetches, feed_dict=feedDict)
                    if is_training:
                        # Drop the optimizer's (unused) result.
                        results = results[1:]

                    if level == 'cha':
                        l, pre, y, er = results
                        batchErrors[batch] = er
                        # The fetched error is divided by batch_size only
                        # for display.
                        shown_er = er / batch_size
                    else:
                        l, pre, y = results
                        er = get_edit_distance([pre.values], [y.values],
                                               True, level)
                        # Scaled by the number of sequences in this batch
                        # before being accumulated.
                        batchErrors[batch] = er * len(batchSeqLengths)
                        shown_er = er

                    if is_training:
                        print(
                            '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train {}={:.3f}\n'
                            .format(level, totalN, batch + 1, num_batches,
                                    epoch + 1, num_epochs, l, metric,
                                    shown_er))
                    else:
                        print(
                            '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test {}={:.3f}\n'
                            .format(level, totalN, batch + 1, num_batches, l,
                                    metric, shown_er))

                    # NOTE(review): abandons the rest of the epoch when
                    # er / batch_size is exactly 1.0; the reason is not
                    # documented in this code -- confirm the intent.
                    if er / batch_size == 1.0:
                        break

                    # Periodically show a decoded sample.
                    if batch % 30 == 0:
                        print('Truth:\n' + output_to_sequence(y, type=level))
                        print('Output:\n' +
                              output_to_sequence(pre, type=level))

                    if is_training:
                        # Checkpoint every 20 global steps and on the very
                        # last batch of the final epoch.
                        step = epoch * num_batches + batch + 1
                        is_final_batch = (epoch == num_epochs - 1
                                          and batch == num_batches - 1)
                        if step % 20 == 0 or is_final_batch:
                            save_checkpoint(sess, epoch)

                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' +
                      str(delta_time) + ' s')

                if is_training:
                    # Checkpoint at the end of every epoch.
                    save_checkpoint(sess, epoch)
                    epochER = batchErrors.sum() / totalN
                    print('Epoch', epoch + 1, 'mean train error rate:',
                          epochER)
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode='config')
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode=mode)
                else:
                    # Only the last processed batch of the epoch is written;
                    # skip the write when no batch was processed.
                    if y is not None:
                        result_path = os.path.join(resultdir,
                                                   level + '_result.txt')
                        with open(result_path, 'a') as result:
                            result.write(
                                output_to_sequence(y, type=level) + '\n')
                            result.write(
                                output_to_sequence(pre, type=level) + '\n')
                            result.write('\n')
                    epochER = batchErrors.sum() / totalN
                    print(' test error rate:', epochER)
                    logging(model, logfile, epochER, mode=mode)
Example #2
0
    def test(self):
        """Evaluate the latest checkpoint in ``args.save_dir`` on test data.

        Configuration comes from ``self.args`` (``model``, ``level``,
        ``save_dir``, ``batch_size`` and ``task`` are read here).  Batches
        are visited in random order; after each one the decoded truth and
        output are printed and appended to ``<args.task>_result.txt``.  The
        overall error rate is printed and passed to ``logging`` together with
        ``self.logfile``.

        Raises:
            ValueError: if ``args.level`` or ``args.model`` is unsupported.
            IOError: if no checkpoint is found in ``args.save_dir``.
        """
        args = self.args
        # Fail fast: any other level would leave `pre`, `y` and `er` unbound
        # inside the batch loop below.
        if args.level not in ('cha', 'phn'):
            raise ValueError('Unsupported level: {!r}'.format(args.level))

        # load data
        batchedData, maxTimeSteps, totalN = self.load_data(args,
                                                           mode='test',
                                                           type=args.level)
        if args.model == 'ResNet':
            model = ResNet(args, maxTimeSteps)
        elif args.model == 'BiRNN':
            model = BiRNN(args, maxTimeSteps)
        elif args.model == 'DBiRNN':
            model = DBiRNN(args, maxTimeSteps)
        else:
            raise ValueError('Unsupported model: {!r}'.format(args.model))

        # count the num of params and expose them through the model config
        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params

        with tf.Session(graph=model.graph) as sess:
            ckpt = tf.train.get_checkpoint_state(args.save_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                # Nothing in this method initialises the variables, so
                # running the graph without a restored checkpoint cannot
                # work.
                raise IOError('No checkpoint found in: ' + args.save_dir)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
            print('Model restored from:' + args.save_dir)

            batchErrors = np.zeros(len(batchedData))
            batchRandIxs = np.random.permutation(len(batchedData))
            num_batches = len(batchRandIxs)
            for batch, batchOrigI in enumerate(batchRandIxs):
                batchInputs, batchTargetSparse, batchSeqLengths = (
                    batchedData[batchOrigI])
                (batchTargetIxs, batchTargetVals,
                 batchTargetShape) = batchTargetSparse
                feedDict = {
                    model.inputX: batchInputs,
                    model.targetIxs: batchTargetIxs,
                    model.targetVals: batchTargetVals,
                    model.targetShape: batchTargetShape,
                    model.seqLengths: batchSeqLengths
                }

                if args.level == 'cha':
                    # Character level: the error is fetched from the graph.
                    l, pre, y, er = sess.run([
                        model.loss, model.predictions, model.targetY,
                        model.errorRate
                    ],
                                             feed_dict=feedDict)
                    batchErrors[batch] = er
                    print(
                        '\ntotal:{},batch:{}/{},loss={:.3f},mean CER={:.3f}\n'.
                        format(totalN, batch + 1, num_batches, l,
                               er / args.batch_size))
                else:
                    # Phoneme level: the error is computed from the fetched
                    # predictions and targets.
                    l, pre, y = sess.run(
                        [model.loss, model.predictions, model.targetY],
                        feed_dict=feedDict)
                    er = get_edit_distance([pre.values], [y.values], True,
                                           'test', args.level)
                    # NOTE(review): the printed value divides er by
                    # batch_size while the accumulated value multiplies it
                    # by the number of sequences -- confirm which scale
                    # get_edit_distance returns.
                    print(
                        '\ntotal:{},batch:{}/{},loss={:.3f},mean PER={:.3f}\n'.
                        format(totalN, batch + 1, num_batches, l,
                               er / args.batch_size))
                    batchErrors[batch] = er * len(batchSeqLengths)

                truth_text = output_to_sequence(y, type=args.level)
                print('Truth:\n' + truth_text)
                output_text = output_to_sequence(pre, type=args.level)
                print('Output:\n' + output_text)

                # Append after every batch so partial results are kept on
                # disk as the evaluation progresses.
                with open(args.task + '_result.txt', 'a') as result:
                    result.write(output_to_sequence(y, type=args.level) + '\n')
                    result.write(
                        output_to_sequence(pre, type=args.level) + '\n')
                    result.write('\n')
            epochER = batchErrors.sum() / totalN
            print(args.task + ' test error rate:', epochER)
            logging(model, self.logfile, epochER, mode='test')