def run(self):
        """Load the data, build the model and run ``num_epochs`` passes over it.

        The module-level ``mode`` selects the behavior:

        * ``'train'``: one optimizer step per batch; a checkpoint every 20
          batches (counted across epochs), at the last batch of the last epoch
          and at the end of every epoch; per-epoch logging of
          ``sum(batch errors) / totalN``.
        * ``'test'``: no optimizer step; ``sum(batch errors) / totalN`` is
          logged and the last batch's truth/output is appended to
          ``<resultdir>/<level>_result.txt`` after every epoch.

        All configuration (``mode``, ``level``, ``keep``, ``savedir``,
        ``num_epochs``, ``batch_size``, ``logfile``, ``resultdir``) is read
        from module-level names, not from parameters.

        Raises:
            ValueError: if ``level`` is not ``'cha'``/``'phn'`` or ``mode`` is
                not ``'train'``/``'test'``; raised when the first batch is
                processed.
        """
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        batchedData, maxTimeSteps, totalN = self.load_data(args,
                                                           mode=mode,
                                                           type=level)
        model = model_fn(args, maxTimeSteps)

        # count the num of params
        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params
        print(model.config)

        def _save_checkpoint(sess, model, epoch):
            # Save under <savedir>/model.ckpt, tagged with the epoch index as
            # global_step (so every save within one epoch reuses one name).
            checkpoint_path = os.path.join(savedir, 'model.ckpt')
            model.saver.save(sess, checkpoint_path, global_step=epoch)
            print('Model has been saved in {}'.format(savedir))

        with tf.Session(graph=model.graph) as sess:
            # Initialize up front: if keep is set but savedir holds no
            # checkpoint, nothing else would initialize the variables and the
            # first training/eval sess.run would fail. A successful restore
            # then overwrites the values of the variables it covers.
            sess.run(model.initial_op)
            # restore from stored model
            if keep == True:
                ckpt = tf.train.get_checkpoint_state(savedir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + savedir)
            else:
                print('Initializing')

            for epoch in range(num_epochs):
                ## training
                start = time.time()
                if mode == 'train':
                    print('Epoch', epoch + 1, '...')
                batchErrors = np.zeros(len(batchedData))
                # Visit the batches in a fresh random order every epoch.
                batchRandIxs = np.random.permutation(len(batchedData))

                for batch, batchOrigI in enumerate(batchRandIxs):
                    batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                        batchOrigI]
                    # Targets arrive as (indices, values, shape) and are fed
                    # into three separate placeholders.
                    batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                    feedDict = {
                        model.inputX: batchInputs,
                        model.targetIxs: batchTargetIxs,
                        model.targetVals: batchTargetVals,
                        model.targetShape: batchTargetShape,
                        model.seqLengths: batchSeqLengths
                    }

                    if level == 'cha':
                        # The graph computes the error itself; it is divided
                        # by batch_size for display, so it presumably is a
                        # sum over the batch (TODO confirm).
                        if mode == 'train':
                            _, l, pre, y, er = sess.run(
                                [model.optimizer, model.loss, model.predictions,
                                 model.targetY, model.errorRate],
                                feed_dict=feedDict)

                            batchErrors[batch] = er
                            print(
                                '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), epoch + 1,
                                        num_epochs, l, er / batch_size))

                        elif mode == 'test':
                            l, pre, y, er = sess.run(
                                [model.loss, model.predictions, model.targetY,
                                 model.errorRate],
                                feed_dict=feedDict)
                            batchErrors[batch] = er
                            print(
                                '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), l, er / batch_size))
                        else:
                            raise ValueError('mode %s is not supported now' %
                                             str(mode))

                    elif level == 'phn':
                        # The error is computed on the host from the
                        # predicted and target label values.
                        if mode == 'train':
                            _, l, pre, y = sess.run(
                                [model.optimizer, model.loss, model.predictions,
                                 model.targetY],
                                feed_dict=feedDict)

                            er = get_edit_distance([pre.values], [y.values],
                                                   True, level)
                            print(
                                '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train PER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), epoch + 1,
                                        num_epochs, l, er))
                            # er is reported as a mean, so scale it back up by
                            # the number of sequences for the epoch total.
                            batchErrors[batch] = er * len(batchSeqLengths)
                        elif mode == 'test':
                            l, pre, y = sess.run(
                                [model.loss, model.predictions, model.targetY],
                                feed_dict=feedDict)
                            er = get_edit_distance([pre.values], [y.values],
                                                   True, level)
                            print(
                                '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test PER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), l, er))
                            batchErrors[batch] = er * len(batchSeqLengths)
                        else:
                            raise ValueError('mode %s is not supported now' %
                                             str(mode))

                    else:
                        raise ValueError('level %s is not supported now' %
                                         str(level))

                    # NOTE: abandon the remaining batches of this epoch when the
                    # batch error divided by batch_size is exactly 1.0.
                    # NOTE(review): for 'phn', er is already printed as a mean
                    # PER, so dividing by batch_size again may keep this from
                    # ever firing -- confirm the intended threshold.
                    if er / batch_size == 1.0:
                        break

                    if batch % 30 == 0:
                        print('Truth:\n' + output_to_sequence(y, type=level))
                        print('Output:\n' +
                              output_to_sequence(pre, type=level))

                    # Checkpoint every 20 batches counted across epochs, and
                    # after the very last batch of training.
                    if mode == 'train' and (
                        (epoch * len(batchRandIxs) + batch + 1) % 20 == 0 or
                        (epoch == num_epochs - 1
                         and batch == len(batchRandIxs) - 1)):
                        _save_checkpoint(sess, model, epoch)

                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' +
                      str(delta_time) + ' s')

                if mode == 'train':
                    # Always checkpoint at the end of an epoch.
                    _save_checkpoint(sess, model, epoch)
                    epochER = batchErrors.sum() / totalN
                    print('Epoch', epoch + 1, 'mean train error rate:',
                          epochER)
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode='config')
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode=mode)

                if mode == 'test':
                    # Only the last processed batch's truth/output is recorded.
                    with open(os.path.join(resultdir, level + '_result.txt'),
                              'a') as result:
                        result.write(output_to_sequence(y, type=level) + '\n')
                        result.write(
                            output_to_sequence(pre, type=level) + '\n')
                        result.write('\n')
                    epochER = batchErrors.sum() / totalN
                    print(' test error rate:', epochER)
                    logging(model, logfile, epochER, mode=mode)
Ejemplo n.º 2
0
    def run(self):
        """Train or evaluate the model over every feature/label sub-directory.

        The (feature_dir, label_dir) pairs returned by ``get_data`` are
        shuffled together. Then, for each pair, the data is loaded, a new
        model graph is built and ``num_epochs`` passes are run in a new
        session. The module-level ``mode`` selects the behavior:

        * ``'train'``: an optimizer step per batch; a checkpoint every 20
          batches (counted across epochs), at the last batch of the last
          epoch and at the end of every epoch; per-epoch logging.
        * ``'dev'`` / ``'test'``: no optimizer step;
          ``sum(batch errors) / totalN`` is logged and the last batch's
          truth/output is appended to ``<resultdir>/<level>_result.txt``.

        All configuration (``datadir``, dataset names, ``level``, ``mode``,
        ``keep``, ``savedir``, ``num_epochs``, ``batch_size``, ``logfile``,
        ``resultdir``) is read from module-level names.

        Raises:
            ValueError: if ``get_data`` returns no sub-directories, or if
                ``level``/``mode`` is unsupported (only ``level == 'cha'``
                with ``mode`` in ``'train'``/``'dev'``/``'test'`` is handled).
        """
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        feature_dirs, label_dirs = get_data(datadir, level, train_dataset,
                                            dev_dataset, test_dataset, mode)
        # Every sub-directory is loaded (and gets its own model) inside the
        # loop below, so nothing is loaded up front; just fail clearly when
        # there is nothing to process.
        if len(feature_dirs) == 0:
            raise ValueError('no feature directories found for datadir %s' %
                             str(datadir))

        ## shuffle feature_dir and label_dir by same order
        FL_pair = list(zip(feature_dirs, label_dirs))
        random.shuffle(FL_pair)
        feature_dirs, label_dirs = zip(*FL_pair)

        def _save_checkpoint(sess, model, epoch):
            # Save under <savedir>/model.ckpt, tagged with the epoch index as
            # global_step (so every save within one epoch reuses one name).
            checkpoint_path = os.path.join(savedir, 'model.ckpt')
            model.saver.save(sess, checkpoint_path, global_step=epoch)
            print('Model has been saved in {}'.format(savedir))

        # enumerate gives each pair's position directly (O(1), and correct
        # even if the same directory appears more than once).
        for id_dir, (feature_dir, label_dir) in enumerate(
                zip(feature_dirs, label_dirs)):
            print('dir id:{}'.format(id_dir))
            batchedData, maxTimeSteps, totalN = self.load_data(
                feature_dir, label_dir, mode, level)
            model = model_fn(args, maxTimeSteps)
            num_params = count_params(model, mode='trainable')
            all_num_params = count_params(model, mode='all')
            model.config['trainable params'] = num_params
            model.config['all params'] = all_num_params
            print(model.config)
            with tf.Session(graph=model.graph) as sess:
                # Initialize first so that variables are usable even when keep
                # is set but savedir has no checkpoint yet; a successful
                # restore overwrites the variables it covers.
                sess.run(model.initial_op)
                # NOTE(review): each sub-directory gets a freshly built graph
                # and session, so unless keep is set (and a checkpoint exists)
                # every sub-directory starts again from fresh initialization --
                # confirm this is intended.
                # restore from stored model
                if keep == True:
                    ckpt = tf.train.get_checkpoint_state(savedir)
                    if ckpt and ckpt.model_checkpoint_path:
                        model.saver.restore(sess, ckpt.model_checkpoint_path)
                        print('Model restored from:' + savedir)
                else:
                    print('Initializing')
                for epoch in range(num_epochs):
                    ## training
                    start = time.time()
                    if mode == 'train':
                        print('Epoch {} ...'.format(epoch + 1))

                    batchErrors = np.zeros(len(batchedData))
                    # Visit the batches in a fresh random order every epoch.
                    batchRandIxs = np.random.permutation(len(batchedData))

                    for batch, batchOrigI in enumerate(batchRandIxs):
                        batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                            batchOrigI]
                        # Targets arrive as (indices, values, shape) and are
                        # fed into three separate placeholders.
                        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                        feedDict = {
                            model.inputX: batchInputs,
                            model.targetIxs: batchTargetIxs,
                            model.targetVals: batchTargetVals,
                            model.targetShape: batchTargetShape,
                            model.seqLengths: batchSeqLengths
                        }

                        if level == 'cha':
                            # The error comes from the graph and is divided by
                            # batch_size for display, so it presumably is a
                            # sum over the batch (TODO confirm).
                            if mode == 'train':
                                _, l, pre, y, er = sess.run(
                                    [model.optimizer, model.loss,
                                     model.predictions, model.targetY,
                                     model.errorRate],
                                    feed_dict=feedDict)

                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                                    .format(level, totalN, id_dir + 1,
                                            len(feature_dirs), batch + 1,
                                            len(batchRandIxs), epoch + 1,
                                            num_epochs, l, er / batch_size))

                            elif mode == 'dev':
                                l, pre, y, er = sess.run(
                                    [model.loss, model.predictions,
                                     model.targetY, model.errorRate],
                                    feed_dict=feedDict)
                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},dev loss={:.3f},mean dev CER={:.3f}\n'
                                    .format(level, totalN, id_dir + 1,
                                            len(feature_dirs), batch + 1,
                                            len(batchRandIxs), l,
                                            er / batch_size))

                            elif mode == 'test':
                                l, pre, y, er = sess.run(
                                    [model.loss, model.predictions,
                                     model.targetY, model.errorRate],
                                    feed_dict=feedDict)
                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'
                                    .format(level, totalN, id_dir + 1,
                                            len(feature_dirs), batch + 1,
                                            len(batchRandIxs), l,
                                            er / batch_size))
                            else:
                                raise ValueError(
                                    'mode %s is not supported now' % str(mode))
                        else:
                            # Only 'cha' is implemented here ('seq2seq', 'phn',
                            # ... would otherwise hit an undefined er below).
                            raise ValueError('level %s is not supported now' %
                                             str(level))

                        # NOTE: abandon the remaining batches of this epoch
                        # when the batch error divided by batch_size is
                        # exactly 1.0.
                        if er / batch_size == 1.0:
                            break

                        if batch % 20 == 0:
                            print('Truth:\n' +
                                  output_to_sequence(y, type=level))
                            print('Output:\n' +
                                  output_to_sequence(pre, type=level))

                        # Checkpoint every 20 batches counted across epochs,
                        # and after the very last batch of training.
                        if mode == 'train' and (
                            (epoch * len(batchRandIxs) + batch + 1) % 20 == 0
                                or (epoch == num_epochs - 1
                                    and batch == len(batchRandIxs) - 1)):
                            _save_checkpoint(sess, model, epoch)

                    end = time.time()
                    delta_time = end - start
                    print('Epoch ' + str(epoch + 1) + ' needs time:' +
                          str(delta_time) + ' s')

                    if mode == 'train':
                        # Always checkpoint at the end of an epoch.
                        _save_checkpoint(sess, model, epoch)
                        epochER = batchErrors.sum() / totalN
                        print('Epoch', epoch + 1, 'mean train error rate:',
                              epochER)
                        logging(model,
                                logfile,
                                epochER,
                                epoch,
                                delta_time,
                                mode='config')
                        logging(model,
                                logfile,
                                epochER,
                                epoch,
                                delta_time,
                                mode=mode)

                    if mode == 'test' or mode == 'dev':
                        # Only the last processed batch's truth/output is
                        # recorded.
                        with open(
                                os.path.join(resultdir, level + '_result.txt'),
                                'a') as result:
                            result.write(
                                output_to_sequence(y, type=level) + '\n')
                            result.write(
                                output_to_sequence(pre, type=level) + '\n')
                            result.write('\n')
                        epochER = batchErrors.sum() / totalN
                        print(' test error rate:', epochER)
                        logging(model, logfile, epochER, mode=mode)