Example #1
import os

import torch
import torch.nn.functional as F
from torch.autograd import Variable

# N21TextData, N21Converter and SentimentDataset are project-level helpers
# assumed to be imported from the surrounding package.


def predict(token_vocab, target_vocab, sent):
    # load trained model
    model_fn = os.path.join(os.path.dirname(__file__), 'model.pt')
    model = torch.load(model_fn)

    batch_size = 1
    num_steps = model.num_steps

    in_sent = '{}\t{}'.format('___DUMMY_CLASS___', sent)

    # convert the raw sentence into id data
    pred_data = N21TextData(in_sent, mode='sentence')
    pred_id_data = N21Converter.convert(pred_data, target_vocab, token_vocab)
    pred_data_set = SentimentDataset(pred_id_data, batch_size, num_steps)

    # fetch a single batch for prediction
    a_batch_data = next(pred_data_set.predict_iterator)
    b_sentiment_id, b_token_ids, b_weight = a_batch_data

    # convert numpy to pytorch
    b_sentiment_id = Variable(torch.from_numpy(b_sentiment_id))
    b_token_ids = Variable(torch.from_numpy(b_token_ids))

    # feed-forward (inference)
    model.batch_size = batch_size
    model.initial_state = model.reset_initial_state()
    logit = model(b_token_ids)
    np_probs = F.softmax(logit, dim=-1).data[0].numpy()  # class probabilities

    # pick the best class and its probability
    best_index = np_probs.argmax()
    best_prob = np_probs[best_index]
    best_target_class = target_vocab.get_symbol(best_index)

    print("[Prediction] {} ===> {}".format(sent, best_target_class))
Example #2
import os

import tensorflow as tf


def predict(token_vocab, target_vocab, sent):
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force CPU-only prediction
    model_dir = "./trained_models"

    # convert the raw sentence into id data
    in_sent = '{}\t{}'.format('___DUMMY_CLASS___', sent)
    pred_data = N21TextData(in_sent, mode='sentence')
    pred_id_data = N21Converter.convert(pred_data, target_vocab, token_vocab)
    pred_data_set = SentimentDataset(pred_id_data, 1, 128)  # batch_size=1, num_steps=128

    # fetch a single batch for prediction
    a_batch_data = next(pred_data_set.predict_iterator)
    b_sentiment_id, b_token_ids, b_weight = a_batch_data

    # Restore the graph
    # note that frozen_graph.tf.pb contains the graph definition with parameter values in binary format
    _graph_fn = os.path.join(model_dir, 'frozen_graph.tf.pb')
    with tf.gfile.GFile(_graph_fn, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    with tf.Session(graph=graph) as sess:
        # to inspect the loaded graph, uncomment:
        # for n in tf.get_default_graph().as_graph_def().node: print(n.name)

        # make interface for input
        pl_token = graph.get_tensor_by_name('import/model/pl_tokens:0')
        pl_keep_prob = graph.get_tensor_by_name('import/model/pl_keep_prob:0')

        # make interface for output
        out_pred = graph.get_tensor_by_name('import/model/out_pred:0')
        out_probs = graph.get_tensor_by_name('import/model/out_probs:0')

        # predict sentence
        b_best_pred_index, b_pred_probs = sess.run([out_pred, out_probs],
                                                   feed_dict={
                                                       pl_token: b_token_ids,
                                                       pl_keep_prob: 1.0,
                                                   })

        best_pred_index = b_best_pred_index[0]
        pred_probs = b_pred_probs[0]

        best_target_class = target_vocab.get_symbol(best_pred_index)
        print(pred_probs[best_pred_index])
        best_prob = float(pred_probs[best_pred_index])  # int() would truncate the probability to 0
        print(best_target_class, best_prob)
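The frozen_graph.tf.pb consumed above is typically produced by freezing the trained TF1 graph. A minimal sketch, assuming a trained session sess is available and the output node names match the tensors fetched above (model/out_pred, model/out_probs):

import os

import tensorflow as tf


def freeze_graph(sess, model_dir="./trained_models"):
    # hypothetical helper: bake the trained variables into constants so the
    # prediction code above can load a single self-contained .pb file
    frozen_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['model/out_pred', 'model/out_probs'])
    graph_fn = os.path.join(model_dir, 'frozen_graph.tf.pb')
    with tf.gfile.GFile(graph_fn, "wb") as f:
        f.write(frozen_def.SerializeToString())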
Example #3
import os

# Vocab, N21TextData and N21Converter are assumed to be imported from the project.


def load_data():

    # vocab loader
    token_vocab_fn = os.path.join(os.path.dirname(__file__), 'data',
                                  'token.vocab.txt')
    token_vocab = Vocab(token_vocab_fn, mode='token')
    target_vocab_fn = os.path.join(os.path.dirname(__file__), 'data',
                                   'target.vocab.txt')
    target_vocab = Vocab(target_vocab_fn, mode='target')

    # load train data
    #train_data_fn  = os.path.join( os.path.dirname(__file__), 'data', 'train.sent_data.txt')
    train_data_fn = os.path.join(os.path.dirname(__file__), 'data',
                                 'small.train.sent_data.txt')
    train_txt_data = N21TextData(train_data_fn)

    # convert text data to id data
    train_id_data = N21Converter.convert(train_txt_data, target_vocab,
                                         token_vocab)

    return train_id_data, token_vocab, target_vocab
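Taken together, the loader and the predictors above can be wired up roughly as follows. A minimal sketch, assuming predict is one of the functions from the earlier examples, the vocab/data files exist under data/, and the sample sentence is purely illustrative:

if __name__ == '__main__':
    # load vocabularies (the train id data is only needed for training)
    train_id_data, token_vocab, target_vocab = load_data()

    # classify a raw sentence with one of the predict() implementations above
    predict(token_vocab, target_vocab, "this movie was surprisingly good")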