Esempio n. 1
0
def get_predictions(model, data_reader, data_path, batch_size=4000):
    """Predict labels for a dataset and return gold labels, predictions and label set.

    The dataset at ``data_path`` is read with ``data_reader`` and fully
    materialized into a list. Instances are passed to
    ``TextClassifierPredictor.predict_batch_instance`` in chunks of at most
    ``batch_size`` instances, so a large set is not predicted in one call.

    Args:
        model: Trained classifier. Its ``vocab`` is queried for the
            ``'labels'`` namespace to turn predicted indices into label strings.
        data_reader: Dataset reader used to read ``data_path``; it is also
            handed to the predictor.
        data_path: Path passed to ``data_reader.read``.
        batch_size: Maximum number of instances per prediction call.
            Defaults to 4000, the bound that was previously hard-coded.

    Returns:
        A tuple ``(actuals, predictions, labels)``:
        ``actuals`` is ``str(instance['label'].label)`` for every instance,
        ``predictions`` is the label string at ``argmax(probs)`` for every
        prediction, and ``labels`` is the list of all label strings in the
        model's ``'labels'`` vocabulary namespace.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    predictor = TextClassifierPredictor(model=model,
                                        dataset_reader=data_reader)
    data = list(data_reader.read(data_path))
    size = len(data)

    preds = []
    if size > batch_size:
        # Number of *full* chunks; the trailing partial chunk (if any) is
        # still processed by the loop below, as before.
        times = size // batch_size
        print(f"Set is too big; total size: {size}. "
              f"Batching {times} times.")
        for lower in range(0, size, batch_size):
            upper = min(lower + batch_size, size)
            print(f"Lower: {lower}, Upper: {upper}")
            preds += predictor.predict_batch_instance(data[lower:upper])
    elif data:
        # Small enough for a single call. An empty dataset skips the
        # predictor entirely rather than sending it an empty batch.
        preds = predictor.predict_batch_instance(data)

    # The predictor wraps the same `model` object, so read the vocabulary
    # from it directly instead of through the predictor's private `_model`.
    labelmap = model.vocab.get_index_to_token_vocabulary('labels')

    predictions = [labelmap[np.argmax(pred['probs'])] for pred in preds]
    actuals = [str(instance['label'].label) for instance in data]
    labels = list(labelmap.values())

    return actuals, predictions, labels
    if simple_lstm:
        EMBEDDING_DIM = 128
        HIDDEN_DIM = 128
        reader = StanfordSentimentTreeBankDatasetReader()
        train_dataset = reader.read('data/stanfordSentimentTreebank/trees/train.txt')
        dev_dataset = reader.read('data/stanfordSentimentTreebank/trees/dev.txt')
        test_dataset = reader.read('data/stanfordSentimentTreebank/trees/test.txt')
        vocab = Vocabulary.from_instances(train_dataset + dev_dataset, min_count={'tokens': 3})
        token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'), embedding_dim=EMBEDDING_DIM)
        word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
        lstm = PytorchSeq2VecWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
        model = LstmClassifier(word_embeddings, lstm, vocab)
        with open("models/simple_LSTM_sentiment_classifier.th", 'rb') as f:
            model.load_state_dict(torch.load(f))
        predictor = TextClassifierPredictor(model, dataset_reader=reader)
        test_results = predictor.predict_batch_instance(test_dataset)

    # ELMo LSTM
    if elmo_lstm:
        elmo_embedding_dim = 256
        HIDDEN_DIM = 128
        elmo_token_indexer = ELMoTokenCharactersIndexer()
        reader = StanfordSentimentTreeBankDatasetReader(token_indexers={'tokens': elmo_token_indexer})
        train_dataset = reader.read('data/stanfordSentimentTreebank/trees/train.txt')
        dev_dataset = reader.read('data/stanfordSentimentTreebank/trees/dev.txt')
        test_dataset = reader.read('data/stanfordSentimentTreebank/trees/test.txt')
        vocab = Vocabulary.from_instances(train_dataset + dev_dataset, min_count={'tokens': 3})
        options_file = 'data/elmo/elmo_2x1024_128_2048cnn_1xhighway_options.json'
        weight_file = 'data/elmo/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
        elmo_embedder = ElmoTokenEmbedder(options_file, weight_file)
        word_embeddings = BasicTextFieldEmbedder({"tokens": elmo_embedder})