Ejemplo n.º 1
0
def main(unused_argv):
  """Load encoder model(s) from a JSON config file and run one eval task.

  Requires --data_dir and --model_config. The config file may hold either a
  single model-config object or a list of them; each is loaded into one
  shared EncoderManager before evaluation.

  Raises:
    ValueError: if a required flag is missing or eval_task is unknown.
  """
  if not FLAGS.data_dir:
    raise ValueError("--data_dir is required.")
  if not FLAGS.model_config:
    raise ValueError("--model_config is required.")

  encoder = encoder_manager.EncoderManager()

  with open(FLAGS.model_config) as json_config_file:
    model_configs = json.load(json_config_file)

  # A single JSON object means one model; normalize to a list so the
  # loading loop below handles both shapes.
  if isinstance(model_configs, dict):
    model_configs = [model_configs]

  for mdl_cfg in model_configs:
    encoder.load_model(configuration.model_config(mdl_cfg, mode="encode"))

  if FLAGS.eval_task in ["MR", "CR", "SUBJ", "MPQA"]:
    results = eval_classification.eval_nested_kfold(
        encoder, FLAGS.eval_task, FLAGS.data_dir, use_nb=False)
    # eval_nested_kfold returns per-fold scores first; report their mean.
    print('Mean score', np.mean(results[0]))
  elif FLAGS.eval_task == "SICK":
    eval_sick.evaluate(encoder, evaltest=True, loc=FLAGS.data_dir)
  elif FLAGS.eval_task == "MSRP":
    eval_msrp.evaluate(
        encoder, evalcv=True, evaltest=True, use_feats=False, loc=FLAGS.data_dir)
  elif FLAGS.eval_task == "TREC":
    eval_trec.evaluate(encoder, evalcv=True, evaltest=True, loc=FLAGS.data_dir)
  else:
    raise ValueError("Unrecognized eval_task: %s" % FLAGS.eval_task)

  encoder.close()
Ejemplo n.º 2
0
def load_model(FLAGS):
    """Build the sentence-representation model selected by FLAGS.sr_model.

    Supported values:
      'IS'  -- InferSent (loads checkpoint + GloVe vectors from FLAGS.is_dir)
      'QT'  -- Quick-Thought (EncoderManager over FLAGS.model_config)
      'USE' -- Universal Sentence Encoder from TF Hub

    Returns:
        The constructed model object.

    Raises:
        ValueError: if FLAGS.sr_model is none of the supported values.
        (Previously an unknown value fell through every branch and the
        final `return model` raised a confusing NameError.)
    """
    if FLAGS.sr_model == 'IS':
        # Load InferSent: checkpoint first, then point it at GloVe vectors.
        MODEL_PATH = os.path.join(FLAGS.is_dir, 'encoder/infersent1.pkl')

        params_model = {
            'bsize': 64,
            'word_emb_dim': 300,
            'enc_lstm_dim': 2048,
            'pool_type': 'max',
            'dpout_model': 0.0,
            'version': 1
        }
        model = InferSent(params_model)
        model.load_state_dict(torch.load(MODEL_PATH))
        W2V_PATH = os.path.join(FLAGS.is_dir,
                                'dataset/GloVe/glove.840B.300d.txt')
        model.set_w2v_path(W2V_PATH)
    elif FLAGS.sr_model == 'QT':
        # Load Quick-Thought: one EncoderManager holding every model listed
        # in the JSON config (a lone config object is treated as a list of one).
        model = encoder_manager.EncoderManager()

        with open(FLAGS.model_config) as json_config_file:
            model_configs = json.load(json_config_file)
        if isinstance(model_configs, dict):
            model_configs = [model_configs]

        for mdl_cfg in model_configs:
            model.load_model(configuration.model_config(mdl_cfg, mode='encode'))
    elif FLAGS.sr_model == 'USE':
        model = hub.Module(
            'https://tfhub.dev/google/universal-sentence-encoder-large/2')
    else:
        raise ValueError('Unrecognized sr_model: %s' % FLAGS.sr_model)

    return model
Ejemplo n.º 3
0
    def load_model(self, FLAGS):
        """Load the Quick-Thought encoder(s) described by FLAGS.model_config.

        The JSON file may contain a single model-config object or a list of
        them; every entry is loaded into one EncoderManager.

        Returns:
            An EncoderManager with all configured models loaded.
        """
        tf.logging.set_verbosity(tf.logging.INFO)

        model = encoder_manager.EncoderManager()

        with open(FLAGS.model_config) as json_config_file:
            model_configs = json.load(json_config_file)
        # Normalize a single config dict to a one-element list.
        if isinstance(model_configs, dict):
            model_configs = [model_configs]

        for mdl_cfg in model_configs:
            model.load_model(configuration.model_config(mdl_cfg, mode="encode"))

        return model
Ejemplo n.º 4
0
def main(unused_argv):
    """Load the configured skip-thoughts encoder(s) and run FLAGS.eval_task.

    Either or both of the unidirectional / bidirectional encoders are loaded,
    depending on which checkpoint-path flags are set.

    Raises:
        ValueError: if --data_dir is missing or eval_task is unknown.
    """
    if not FLAGS.data_dir:
        raise ValueError("--data_dir is required.")

    encoder = encoder_manager.EncoderManager()

    # Load whichever encoder variants were given a checkpoint.
    if FLAGS.uni_checkpoint_path:
        print("Loading unidirectional model...")
        encoder.load_model(configuration.model_config(),
                           FLAGS.uni_vocab_file,
                           FLAGS.uni_embeddings_file,
                           FLAGS.uni_checkpoint_path)

    if FLAGS.bi_checkpoint_path:
        print("Loading bidirectional model...")
        encoder.load_model(configuration.model_config(bidirectional_encoder=True),
                           FLAGS.bi_vocab_file,
                           FLAGS.bi_embeddings_file,
                           FLAGS.bi_checkpoint_path)

    task = FLAGS.eval_task
    if task in ["MR", "CR", "SUBJ", "MPQA"]:
        # we changed from skip-thoughts to ec.
        ec.eval_nested_kfold(encoder, task, FLAGS.data_dir, use_nb=False)
    elif task == "SICK":
        eval_sick.evaluate(encoder, evaltest=True, loc=FLAGS.data_dir)
    elif task == "MSRP":
        eval_msrp.evaluate(encoder, evalcv=True, evaltest=True,
                           use_feats=True, loc=FLAGS.data_dir)
    elif task == "TREC":
        eval_trec.evaluate(encoder, evalcv=True, evaltest=True,
                           loc=FLAGS.data_dir)
    else:
        raise ValueError("Unrecognized eval_task: %s" % task)

    encoder.close()
Ejemplo n.º 5
0
from TfidfVocabCreator import Tfidf_Creator as TFIDF
import configuration
import encoder_manager
import numpy as np
import document as doc

# Set paths to the model.
VOCAB_FILE = "./data/vocab.txt"
EMBEDDING_MATRIX_FILE = 'use_trained_model'
CHECKPOINT_PATH = "./model/train/model.ckpt-22638"
# Glob pattern for the input text files fed to the TF-IDF builder.
TEXT_FILE = "/home/shmuelfeld/Desktop/inputFiles_Heb/*.txt"

data = []
# Build a TF-IDF dictionary over the corpus, then load the sentence encoder
# once at import time so the helper functions below can reuse it.
tfidf = TFIDF()
tfidf_dict = tfidf.get_tfidf_dic(TEXT_FILE)
encoder = encoder_manager.EncoderManager()
encoder.load_model(configuration.model_config(),
                   vocabulary_file=VOCAB_FILE,
                   embedding_matrix_file=EMBEDDING_MATRIX_FILE,
                   checkpoint_path=CHECKPOINT_PATH)


def sentence_to_vec(sentence):
    """Encode one sentence via the module-level encoder (as a 1-element batch)."""
    return encoder.encode([sentence])


def sens2vec(list_of_sentences, total_tfidf):
    # NOTE(review): this snippet is truncated in the source — the loop body
    # and return statement are missing, so the function is incomplete as shown.
    multed = []
    for sent in list_of_sentences:
Ejemplo n.º 6
0
def main(unused_argv):
    """Load encoder(s) from FLAGS.model_config and run FLAGS.eval_task.

    Handles the classic classification/similarity tasks (MR/CR/SUBJ/MPQA,
    SICK, MSRP, TREC) plus several NLI-style tasks (SNLI, MultiNLI, XNLI and
    their machine-translated Turkish variants), optionally tokenizing with a
    SentencePiece model when --sentencepiece_model_path is set.

    Raises:
        ValueError: if a required flag is missing or eval_task is unknown.
    """
    if not FLAGS.data_dir:
        raise ValueError("--data_dir is required.")
    if not FLAGS.model_config:
        raise ValueError("--model_config is required.")

    encoder = encoder_manager.EncoderManager()

    with open(FLAGS.model_config) as json_config_file:
        model_configs = json.load(json_config_file)

    # A single JSON object means one model; normalize to a list.
    if isinstance(model_configs, dict):
        model_configs = [model_configs]

    sp = None
    if FLAGS.sentencepiece_model_path:
        print('Loading sentencepiece model', FLAGS.sentencepiece_model_path)
        sp = spm.SentencePieceProcessor()
        sp.Load(FLAGS.sentencepiece_model_path)

    for mdl_cfg in model_configs:
        encoder.load_model(configuration.model_config(mdl_cfg, mode="encode"))

    # Shared pieces of the per-task file metadata. Translated ("MT-TR")
    # datasets store sentences under translate-* keys; originals do not.
    nli_label_classes = ['contradiction', 'entailment', 'neutral']
    translated_keys = {
        'sentence1': 'translate-sentence1',
        'sentence2': 'translate-sentence2'
    }
    original_keys = {'sentence1': 'sentence1', 'sentence2': 'sentence2'}

    # Task name -> (evaluate function, file metadata). All NLI-style tasks
    # share the same call signature, so they dispatch through this table.
    nli_tasks = {
        'SNLI-MT-TR': (eval_nli.evaluate, {
            'file_names': {
                'train': 'snli_train_translation.jsonl',
                'dev': 'snli_dev_translation.jsonl',
                'test': 'snli_test_translation.jsonl'
            },
            'sentence_keys': translated_keys,
            'label_classes': nli_label_classes
        }),
        'SNLI': (eval_nli.evaluate, {
            'file_names': {
                'train': 'snli_1.0_train.jsonl',
                'dev': 'snli_1.0_dev.jsonl',
                'test': 'snli_1.0_test.jsonl'
            },
            'sentence_keys': original_keys,
            'label_classes': nli_label_classes
        }),
        'MULTINLI-MT-TR-MATCHED': (eval_multinli.evaluate, {
            'file_names': {
                'train': 'multinli_train_translation.jsonl',
                'dev': 'multinli_dev_matched_translation.jsonl',
                'test': 'multinli_0.9_test_matched_translation_unlabeled.jsonl',
                'test_output':
                'multinli_0.9_test_matched_translation_unlabeled_output.csv'
            },
            'sentence_keys': translated_keys,
            'label_classes': nli_label_classes
        }),
        'MULTINLI-MATCHED': (eval_multinli.evaluate, {
            'file_names': {
                'train': 'multinli_1.0_train.jsonl',
                'dev': 'multinli_1.0_dev_matched.jsonl',
                'test': 'multinli_0.9_test_matched_unlabeled.jsonl',
                'test_output': 'multinli_0.9_test_matched_unlabeled_output.csv'
            },
            'sentence_keys': original_keys,
            'label_classes': nli_label_classes
        }),
        'MULTINLI-MT-TR-MISMATCHED': (eval_multinli.evaluate, {
            'file_names': {
                'train': 'multinli_train_translation.jsonl',
                'dev': 'multinli_dev_mismatched_translation.jsonl',
                'test':
                'multinli_0.9_test_mismatched_translation_unlabeled.jsonl',
                'test_output':
                'multinli_0.9_test_mismatched_translation_unlabeled_output.csv'
            },
            'sentence_keys': translated_keys,
            'label_classes': nli_label_classes
        }),
        'MULTINLI-MISMATCHED': (eval_multinli.evaluate, {
            'file_names': {
                'train': 'multinli_1.0_train.jsonl',
                'dev': 'multinli_1.0_dev_mismatched.jsonl',
                'test': 'multinli_0.9_test_mismatched_unlabeled.jsonl',
                'test_output':
                'multinli_0.9_test_mismatched_unlabeled_output.csv'
            },
            'sentence_keys': original_keys,
            'label_classes': nli_label_classes
        }),
        'XNLI-MT-TR': (eval_xnli.evaluate, {
            'file_names': {
                'train': 'multinli_train_translation.jsonl',
                'dev': 'xnli_dev_translation.jsonl',
                'test': 'xnli_test_translation.jsonl'
            },
            'sentence_keys': translated_keys,
            'label_classes': nli_label_classes,
            'language': 'any'
        }),
        'XNLI': (eval_xnli.evaluate, {
            'file_names': {
                'train': 'multinli_1.0_train.jsonl',
                'dev': 'xnli.dev.jsonl',
                'test': 'xnli.test.jsonl'
            },
            'sentence_keys': original_keys,
            'label_classes': nli_label_classes,
            'language': 'en'
        }),
    }

    if FLAGS.eval_task in ["MR", "CR", "SUBJ", "MPQA"]:
        results = eval_classification.eval_nested_kfold(encoder,
                                                        FLAGS.eval_task,
                                                        FLAGS.data_dir,
                                                        use_nb=False)
        # eval_nested_kfold returns per-fold scores first; report their mean.
        print('Mean score', np.mean(results[0]))
    elif FLAGS.eval_task == "SICK":
        eval_sick.evaluate(encoder, evaltest=True, loc=FLAGS.data_dir, sp=sp)
    elif FLAGS.eval_task == "MSRP":
        eval_msrp.evaluate(encoder,
                           evalcv=True,
                           evaltest=True,
                           use_feats=False,
                           loc=FLAGS.data_dir)
    elif FLAGS.eval_task == "TREC":
        eval_trec.evaluate(encoder,
                           evalcv=True,
                           evaltest=True,
                           loc=FLAGS.data_dir)
    elif FLAGS.eval_task in nli_tasks:
        evaluate_fn, file_meta_data = nli_tasks[FLAGS.eval_task]
        evaluate_fn(encoder,
                    evaltest=True,
                    loc=FLAGS.data_dir,
                    file_meta_data=file_meta_data,
                    sp=sp)
    else:
        raise ValueError("Unrecognized eval_task: %s" % FLAGS.eval_task)

    encoder.close()