Example #1
def test_model():
    # Load the pretrained word embeddings and their vocabulary.
    vocab, embeddings = data_helper.load_embeddings(config.get('data', 'embedding_file'))
    # Build the classifier and restore its weights from the saved model directory.
    model = RNNModel(embeddings, num_classes=5)
    model.load(config.get('data', 'model_dir'))
    # Read the treebank test split and map tokens to vocabulary indices.
    test_data = data_helper.load_data(os.path.join(config.get('data', 'treebank_dir'), 'test.txt'))
    numeric_test_samples = data_helper.convert_to_numeric_samples(test_data, vocab, num_classes=5)
    # Evaluate the restored model on the numeric test samples.
    model.eval(numeric_test_samples)
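The config.get('data', ...) calls suggest a configparser-style configuration object. A minimal sketch of the assumed layout, with placeholder paths that are not taken from the original project:

import configparser

# Hypothetical configuration mirroring the options read by test_model();
# the section name matches the calls above, the paths are placeholders only.
config = configparser.ConfigParser()
config['data'] = {
    'embedding_file': 'data/embeddings.txt',
    'model_dir': 'models/rnn',
    'treebank_dir': 'data/treebank',
}
# test_model() then resolves each option via config.get('data', <option>).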
Example #2
class EnsembleModel(_ModelBase):
    def __init__(self):
        super().__init__()
        # Main MFCC-based RNN classifier plus two pitch-based classifiers
        # specialised for the easily confused class pairs (0, 1) and (6, 7).
        self.mfcc_rnn = RNNModel()
        self.pitch_clf01 = PitchModel([0, 1])
        self.pitch_clf67 = PitchModel([6, 7])

        # Restore pretrained weights for all three models.
        self.mfcc_rnn.load()
        self.pitch_clf01.load()
        self.pitch_clf67.load()

        # Put the RNN's underlying classifier into evaluation mode.
        self.mfcc_rnn.clf.eval()

    def test(self):
        dev_data = self.reader.mini_batch_iterator(self.reader.val_person)
        y, pred = [], []
        for itr, total_iter, feat, label, files in dev_data:
            pred_, prob_ = self.mfcc_rnn.test_iter(itr, total_iter, feat,
                                                   label, files)
            # Defer to the specialised pitch classifiers when the RNN is not
            # confident about a prediction in either of the ambiguous pairs.
            for i, p in enumerate(pred_):
                if p in [0, 1] and prob_[i][p] < 0.8:
                    pred_[i] = self.pitch_clf01.test_iter(*feat[i])
                if p in [6, 7] and prob_[i][p] < 0.7:
                    pred_[i] = self.pitch_clf67.test_iter(*feat[i])
            y.extend(label)
            pred.extend(pred_)

            printer.info('%d/%d' % (itr, total_iter))
            # for i,_ in enumerate(pred_):
            #    if pred_[i] != label[i]:
            #       logger.info(files[i])
            # if itr > 1000: break
        acc = accuracy_score(y, pred)
        printer.info(acc)
        cm = confusion_matrix(y, pred)
        # Persist the confusion matrix for later inspection.
        with open('models/cm.pkl', 'wb') as f:
            pickle.dump(cm, f)
        print(cm)
        return acc

    def interactive(self):
        test_data = self.reader.new_file_detect_iterator()
        y, pred = [], []
        for itr, total_iter, feat, label, files in test_data:
            pred_, prob_ = self.mfcc_rnn.test_iter(itr, total_iter, feat,
                                                   label, files)
            # Same low-confidence fallback as in test().
            for i, p in enumerate(pred_):
                if p in [0, 1] and prob_[i][p] < 0.8:
                    pred_[i] = self.pitch_clf01.test_iter(*feat[i])
                if p in [6, 7] and prob_[i][p] < 0.7:
                    pred_[i] = self.pitch_clf67.test_iter(*feat[i])
            print('***Prediction***\n%s\n' % output_dict[pred_[0]])
            print('***Confidence***\n%s\n' % str(prob_[0][pred_[0]]))
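The core idea in test() and interactive() is a confidence-threshold fallback: keep the main classifier's prediction unless the winning class belongs to a known-ambiguous pair and its probability is low. A self-contained sketch of that pattern with dummy inputs (nothing below comes from the original project):

import numpy as np

def fallback_predict(primary_probs, specialist, classes, threshold):
    # Start from the primary model's argmax predictions.
    preds = primary_probs.argmax(axis=1)
    for i, p in enumerate(preds):
        # Defer to the specialist only for low-confidence ambiguous classes.
        if p in classes and primary_probs[i, p] < threshold:
            preds[i] = specialist(i)
    return preds

probs = np.array([[0.55, 0.30, 0.15],   # low-confidence class 0
                  [0.10, 0.85, 0.05]])  # confident class 1
preds = fallback_predict(probs, specialist=lambda i: 2,
                         classes=[0, 1], threshold=0.8)
print(preds)  # -> [2 1]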
Example #3
        exit(0)

    # Create the summary and checkpoint directories if they do not exist.
    tf.gfile.MakeDirs(config.summary_dir)
    tf.gfile.MakeDirs(config.checkpoint_dir)

    model = RNNModel(config)

    with tf.Session() as sess:

        # Initialise global/local variables and lookup tables, then restore
        # the model's saved weights into the session.
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            tf.tables_initializer()
        ])

        model.load(sess)
        trainer = Trainer(sess, model, config)

        if args.train:
            trainer.train()

        if args.test:
            # Read one bracket string per line from stdin, compute its
            # statistics and print the predictions next to the targets.
            for line in sys.stdin:
                try:
                    data_x = np.array([brackets.parse_string(line)])
                    data_y = np.array([brackets.statistics(b) for b in data_x])
                    _, preds, _, _ = trainer.predict(data_x, data_y)

                    for x, y, p in zip(data_x, data_y, preds):
                        print(x, ',', *y, ',', *p)
                except Exception:
                    # Skip lines that cannot be parsed.
                    continue
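The snippet assumes an args object with train and test attributes; a hypothetical argument parser that would provide them (the flag names are assumptions, not taken from the project):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train', action='store_true')
parser.add_argument('--test', action='store_true')
args = parser.parse_args()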