Example #1
    def __init__(self,
                 tokenizer: Tokenizer = None,
                 source_token_indexers: Dict[str, TokenIndexer] = None,
                 target_token_indexers: Dict[str, TokenIndexer] = None,
                 source_max_tokens: int = 400,
                 target_max_tokens: int = 100,
                 separate_namespaces: bool = False,
                 target_namespace: str = "target_tokens",
                 save_copy_fields: bool = False,
                 save_pgn_fields: bool = False) -> None:
        super().__init__(lazy=True)

        # At most one of the copy-field / pointer-generator-field outputs may be enabled at a time.
        assert not (save_pgn_fields and save_copy_fields)

        self._source_max_tokens = source_max_tokens
        self._target_max_tokens = target_max_tokens

        self._tokenizer = tokenizer or WordTokenizer(
            word_splitter=SimpleWordSplitter())

        tokens_indexer = {"tokens": SingleIdTokenIndexer()}
        self._source_token_indexers = source_token_indexers or tokens_indexer
        self._target_token_indexers = target_token_indexers or tokens_indexer

        self._save_copy_fields = save_copy_fields
        self._save_pgn_fields = save_pgn_fields
        self._target_namespace = "tokens"
        if separate_namespaces:
            self._target_namespace = target_namespace
            second_tokens_indexer = {
                "tokens": SingleIdTokenIndexer(namespace=target_namespace)
            }
            self._target_token_indexers = target_token_indexers or second_tokens_indexer
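
A minimal sketch (assuming the AllenNLP 0.x API used throughout these examples) of the defaults this constructor falls back to, and of the separate target namespace it sets up when separate_namespaces=True:

from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers.word_splitter import SimpleWordSplitter

# Default tokenizer and indexers, mirroring the constructor above.
tokenizer = WordTokenizer(word_splitter=SimpleWordSplitter())
source_token_indexers = {"tokens": SingleIdTokenIndexer()}

# With separate_namespaces=True the target side gets its own vocabulary namespace.
target_token_indexers = {"tokens": SingleIdTokenIndexer(namespace="target_tokens")}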
Example #2
class TestSimpleWordSplitter:
    word_splitter = SimpleWordSplitter()
    def test_tokenize_handles_complex_punctuation(self):
        sentence = "this (sentence) has 'crazy' \"punctuation\"."
        expected_tokens = ["this", "(", "sentence", ")", "has", "'", "crazy", "'", '"',
                           "punctuation", '"', "."]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_tokenize_handles_contraction(self):
        sentence = "it ain't joe's problem; would've been yesterday"
        expected_tokens = ["it", "ai", "n't", "joe", "'s", "problem", ";", "would", "'ve", "been",
                           "yesterday"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_tokenize_handles_multiple_contraction(self):
        sentence = "wouldn't've"
        expected_tokens = ["would", "n't", "'ve"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_tokenize_handles_final_apostrophe(self):
        sentence = "the jones' house"
        expected_tokens = ["the", "jones", "'", "house"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_tokenize_handles_special_cases(self):
        sentence = "mr. and mrs. jones, etc., went to, e.g., the store"
        expected_tokens = ["mr.", "and", "mrs.", "jones", ",", "etc.", ",", "went", "to", ",",
                           "e.g.", ",", "the", "store"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens
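
The splitting behaviour these tests pin down can be checked directly; a small sketch (assuming AllenNLP 0.x, where split_words returns Token objects):

from allennlp.data.tokenizers.word_splitter import SimpleWordSplitter

splitter = SimpleWordSplitter()
print([t.text for t in splitter.split_words("it ain't joe's problem")])
# ['it', 'ai', "n't", 'joe', "'s", 'problem']
print([t.text for t in splitter.split_words("the jones' house")])
# ['the', 'jones', "'", 'house']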
Example #3
 def __init__(
     self,
     tokenizer: Tokenizer = None,
     source_token_indexers: Dict[str, TokenIndexer] = None,
     target_token_indexers: Dict[str, TokenIndexer] = None,
     source_max_tokens: int = 400,
     target_max_tokens: int = 100,
     separate_namespaces: bool = False,
     target_namespace: str = "target_tokens",
     save_copy_fields: bool = False,
     save_pgn_fields: bool = False,
 ) -> None:
     if not tokenizer:
         tokenizer = WordTokenizer(word_splitter=SimpleWordSplitter())
     super().__init__(
         tokenizer=tokenizer,
         source_token_indexers=source_token_indexers,
         target_token_indexers=target_token_indexers,
         source_max_tokens=source_max_tokens,
         target_max_tokens=target_max_tokens,
         separate_namespaces=separate_namespaces,
         target_namespace=target_namespace,
         save_copy_fields=save_copy_fields,
         save_pgn_fields=save_pgn_fields,
     )
    def __init__(self,
                 langs_list: List[str],
                 ae_steps: List[str] = None,
                 bt_steps: List[str] = None,
                 para_steps: List[str] = None,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._undefined_lang_id = "xx"
        self._tokenizer = tokenizer or WordTokenizer(
            word_splitter=SimpleWordSplitter())
        self._token_indexers = token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self._denoising_dataset_reader = ParallelDatasetReader(
            lang1_tokenizer=tokenizer,
            lang1_token_indexers=token_indexers,
            lazy=lazy,
            denoising=True)
        self._backtranslation_dataset_reader = BacktranslationDatasetReader(
            tokenizer=tokenizer, token_indexers=token_indexers, lazy=lazy)
        self._parallel_dataset_reader = ParallelDatasetReader(
            lang1_tokenizer=tokenizer,
            lang1_token_indexers=token_indexers,
            lazy=lazy)

        self._mingler = RoundRobinMingler(dataset_name_field="lang_pair",
                                          take_at_a_time=1)

        self._langs_list = langs_list
        self._ae_steps = ae_steps
        self._bt_steps = bt_steps
        self._para_steps = para_steps
Example #5
 def __init__(self, bert_model):
     lower_case = "uncased" in bert_model
     self.bert_indexer, self.tokenizer = self.get_bert_indexer(
         bert_model, lower_case=lower_case)
     self.tokenizer_bert = MyBertWordSplitter(do_lower_case=lower_case)
     self.spacy_splitter = SpacyWordSplitter(keep_spacy_tokens=True)
     self.just_space_tokenization = JustSpacesWordSplitter()
     self.simple_tokenization = SimpleWordSplitter()
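
The lower-casing flag above is inferred purely from the model name, for example:

for name in ("bert-base-uncased", "bert-base-cased"):
    print(name, "->", "uncased" in name)
# bert-base-uncased -> True
# bert-base-cased -> False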
Example #6
 def __init__(
     self,
     word_splitter: WordSplitter = SimpleWordSplitter(),
     word_filter: WordFilter = PassThroughWordFilter(),
     word_stemmer: WordStemmer = PassThroughWordStemmer()
 ) -> None:
     self.word_splitter = word_splitter
     self.word_filter = word_filter
     self.word_stemmer = word_stemmer
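
A sketch of how the three collaborators above are typically composed in the tokenizer's tokenize() method (illustrative only, not the library's exact code):

 def tokenize(self, text: str):
     # split -> filter -> stem, in that order
     words = self.word_splitter.split_words(text)
     words = self.word_filter.filter_words(words)
     return [self.word_stemmer.stem_word(word) for word in words]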
 def __init__(self,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False) -> None:
     super().__init__(lazy)
     self._tokenizer = tokenizer or WordTokenizer(
         word_splitter=SimpleWordSplitter())
     self._token_indexers = token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
 def __init__(self,
              tokens_per_instance: int = None,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False) -> None:
      super().__init__(lazy)
     self._token_indexers = token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     splitter = SimpleWordSplitter()
     self._tokenizer = tokenizer or WordTokenizer(word_splitter=splitter)
     self._tokens_per_instance = tokens_per_instance
     self._lower = False
    def test_ria_reader(self):
        tokenizer = WordTokenizer(word_splitter=SimpleWordSplitter())
        reader = RIAReader(tokenizer)
        dataset = reader.read(RIA_EXAMPLE_FILE)
        for sample in dataset:
            self.assertEqual(sample.fields["source_tokens"][0].text,
                             START_SYMBOL)
            self.assertEqual(sample.fields["source_tokens"][-1].text,
                             END_SYMBOL)
            self.assertGreater(len(sample.fields["source_tokens"]), 2)

            self.assertEqual(sample.fields["target_tokens"][0].text,
                             START_SYMBOL)
            self.assertEqual(sample.fields["target_tokens"][-1].text,
                             END_SYMBOL)
            self.assertGreater(len(sample.fields["target_tokens"]), 2)
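
START_SYMBOL and END_SYMBOL asserted here are AllenNLP's standard sentinel tokens (assuming the 0.x releases these examples target):

from allennlp.common.util import START_SYMBOL, END_SYMBOL  # "@start@" and "@end@"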
Example #10
class TestSimpleWordSplitter(AllenNlpTestCase):
    def setUp(self):
        super(TestSimpleWordSplitter, self).setUp()
        self.word_splitter = SimpleWordSplitter()

    def test_tokenize_handles_complex_punctuation(self):
        sentence = "this (sentence) has 'crazy' \"punctuation\"."
        expected_tokens = ["this", "(", "sentence", ")", "has", "'", "crazy", "'", '"',
                           "punctuation", '"', "."]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_tokenize_handles_contraction(self):
        sentence = "it ain't joe's problem; would've been yesterday"
        expected_tokens = ["it", "ai", "n't", "joe", "'s", "problem", ";", "would", "'ve", "been",
                           "yesterday"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_batch_tokenization(self):
        sentences = ["This is a sentence",
                     "This isn't a sentence.",
                     "This is the 3rd sentence.",
                     "Here's the 'fourth' sentence."]
        batch_split = self.word_splitter.batch_split_words(sentences)
        separately_split = [self.word_splitter.split_words(sentence) for sentence in sentences]
        assert len(batch_split) == len(separately_split)
        for batch_sentence, separate_sentence in zip(batch_split, separately_split):
            assert len(batch_sentence) == len(separate_sentence)
            for batch_word, separate_word in zip(batch_sentence, separate_sentence):
                assert batch_word.text == separate_word.text

    def test_tokenize_handles_multiple_contraction(self):
        sentence = "wouldn't've"
        expected_tokens = ["would", "n't", "'ve"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_tokenize_handles_final_apostrophe(self):
        sentence = "the jones' house"
        expected_tokens = ["the", "jones", "'", "house"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens

    def test_tokenize_handles_special_cases(self):
        sentence = "mr. and mrs. jones, etc., went to, e.g., the store"
        expected_tokens = ["mr.", "and", "mrs.", "jones", ",", "etc.", ",", "went", "to", ",",
                           "e.g.", ",", "the", "store"]
        tokens = [t.text for t in self.word_splitter.split_words(sentence)]
        assert tokens == expected_tokens
 def __init__(self,
              lang1_tokenizer: Tokenizer = None,
              lang2_tokenizer: Tokenizer = None,
              lang1_token_indexers: Dict[str, TokenIndexer] = None,
              lang2_token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False,
              denoising: bool = False) -> None:
     super().__init__(lazy)
     self._lang1_tokenizer = lang1_tokenizer or WordTokenizer(
         word_splitter=SimpleWordSplitter())
     self._lang2_tokenizer = lang2_tokenizer or self._lang1_tokenizer
     self._lang1_token_indexers = lang1_token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     self._lang2_token_indexers = lang2_token_indexers or self._lang1_token_indexers
     self._denoising = denoising
    def test_cnn_dailymail_reader(self):
        tokenizer = WordTokenizer(word_splitter=SimpleWordSplitter())
        reader = CNNDailyMailReader(tokenizer,
                                    cnn_tokenized_dir=TEST_STORIES_DIR,
                                    separate_namespaces=False)
        dataset = reader.read(TEST_URLS_FILE)
        for sample in dataset:
            self.assertEqual(sample.fields["source_tokens"][0].text,
                             START_SYMBOL)
            self.assertEqual(sample.fields["source_tokens"][-1].text,
                             END_SYMBOL)
            self.assertGreater(len(sample.fields["source_tokens"]), 2)

            self.assertEqual(sample.fields["target_tokens"][0].text,
                             START_SYMBOL)
            self.assertEqual(sample.fields["target_tokens"][-1].text,
                             END_SYMBOL)
            self.assertGreater(len(sample.fields["target_tokens"]), 2)
    def test_ria_copy_reader(self):
        tokenizer = WordTokenizer(word_splitter=SimpleWordSplitter())
        reader = RIAReader(tokenizer,
                           separate_namespaces=True,
                           save_copy_fields=True)
        dataset = reader.read(RIA_EXAMPLE_FILE)
        vocabulary = Vocabulary.from_instances(dataset)

        for sample in dataset:
            sample.index_fields(vocabulary)
            self.assertIsNotNone(sample.fields["source_tokens"])
            self.assertIsNotNone(sample.fields["target_tokens"])
            self.assertIsNotNone(sample.fields["metadata"].metadata)
            self.assertIsNotNone(sample.fields["source_token_ids"].array)
            self.assertIsNotNone(sample.fields["target_token_ids"].array)
            self.assertIsNotNone(
                sample.fields["source_to_target"]._mapping_array)
            self.assertIsNotNone(
                sample.fields["source_to_target"]._target_namespace)
Example #15
    def __init__(self,
                 lazy: bool = False,
                 paper_features_path: str = None,
                 word_splitter: WordSplitter = None,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 data_file: Optional[str] = None,
                 samples_per_query: int = 5,
                 margin_fraction: float = 0.5,
                 ratio_hard_negatives: float = 0.5,
                 predict_mode: bool = False,
                 max_num_authors: Optional[int] = 5,
                 ratio_training_samples: Optional[float] = None,
                 max_sequence_length: Optional[int] = -1,
                 cache_path: Optional[str] = None,
                 overwrite_cache: Optional[bool] = False,
                 use_cls_token: Optional[bool] = None,
                 concat_title_abstract: Optional[bool] = None,
                 coviews_file: Optional[str] = None,
                 included_text_fields: Optional[str] = None,
                 use_paper_feature_cache: bool = True) -> None:
        """
        Args:
            lazy: if False, instances are read eagerly and returned as a list
            paper_features_path: path to the paper features json file (result of scripts.generate_paper_features.py)
            candidates_path: path to the candidate papers
            tokenizer: tokenizer to be used for tokenizing strings
            token_indexers: token indexer for indexing vocab
            data_file: path to the data file (e.g., citations)
            samples_per_query: number of triplets to generate for each query
            margin_fraction: minimum margin of co-views between positive and negative samples
            ratio_hard_negatives: ratio of training data that is selected from hard negatives
                remaining is allocated to easy negatives. should be set to 1.0 in case of similar click data
            predict_mode: if `True` the model only considers the current paper and returns an embedding
                otherwise the model uses the triplet format to train the embedder
            author_id_embedder: Embedder for author ids
            s2_id_embedder: Embedder for representing s2 ids
            other_id_embedder: Embedder for representing other ids (e.g., id assigned by metadata)
            max_num_authors: maximum number of authors,
            ratio_training_samples: Limits training to proportion of all training instances
            max_sequence_length: Longer sequences would be truncated (if -1 then there would be no truncation)
            cache_path: Path to file to cache instances, if None, instances won't be cached.
                If specified, instances are cached after being created so next time they are not created
                again from scratch
            overwrite_cache: If true, it overwrites the cached files. Each file corresponds to
                all instances created from the train, dev or test set.
            use_cls_token: Like BERT, use an additional CLS token at the beginning (for transformers)
            concat_title_abstract: Whether to consider title and abstract as a single field.
            coviews_file: Only for backward compatibility to work with older models (renamed to 
                `data_file` in newer models), leave this empty as it won't have any effect
            included_text_fields: space delimited fields to concat to the title: e.g., `title abstract authors`
            use_paper_feature_cache: set to False to disable the in-memory cache of paper features
        """
        super().__init__(lazy)
        self._word_splitter = word_splitter or SimpleWordSplitter()
        self._tokenizer = tokenizer or WordTokenizer(self._word_splitter)
        self._token_indexers = token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self._token_indexer_author_id = {
            "tokens": SingleIdTokenIndexer(namespace='author')
        }
        self._token_indexer_author_position = \
            {"tokens": SingleIdTokenIndexer(namespace='author_positions')}

        self._token_indexer_venue = {
            "tokens": SingleIdTokenIndexer(namespace='venue')
        }
        self._token_indexer_id = {
            "tokens": SingleIdTokenIndexer(namespace='id')
        }

        with open(paper_features_path) as f_in:
            self.papers = json.load(f_in)
        self.samples_per_query = samples_per_query
        self.margin_fraction = margin_fraction
        self.ratio_hard_negatives = ratio_hard_negatives

        self.predict_mode = predict_mode
        self.max_sequence_length = max_sequence_length
        self.use_cls_token = use_cls_token

        if data_file and not predict_mode:
            # logger.info(f'reading contents of the file at: {coviews_file}')
            with open(data_file) as f_in:
                self.dataset = json.load(f_in)
            # logger.info(f'reading complete. Total {len(self.dataset)} records found.')
            root_path, _ = os.path.splitext(data_file)
            # for multitask interleaving reader, track which dataset the instance is coming from
            self.data_source = root_path.split('/')[-1]
        else:
            self.dataset = None
            self.data_source = None

        self.max_num_authors = max_num_authors

        self.triplet_generator = TripletGenerator(
            paper_ids=list(self.papers.keys()),
            coviews=self.dataset,
            margin_fraction=margin_fraction,
            samples_per_query=samples_per_query,
            ratio_hard_negatives=ratio_hard_negatives)
        # paper_id -> paper features; serves as a cache for the _get_paper_features function
        self.paper_feature_cache = {}

        self.ratio_training_samples = float(
            ratio_training_samples) if ratio_training_samples else None

        self.cache_path = cache_path
        self.overwrite_cache = overwrite_cache
        self.data_file = data_file
        self.paper_features_path = paper_features_path

        self.concat_title_abstract = concat_title_abstract
        self.included_text_fields = set(
            included_text_fields.split()) if included_text_fields else set()
        self.use_paper_feature_cache = use_paper_feature_cache

        self.abstract_delimiter = [Token('[SEP]')]
        self.author_delimiter = [Token('[unused0]')]
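
A hypothetical instantiation under the same assumptions (the class name and file paths below are placeholders; only the keyword arguments come from the constructor signature above):

reader = PaperDatasetReader(                    # placeholder class name
    paper_features_path="paper_features.json",  # placeholder path
    data_file="citations.json",                 # placeholder path
    samples_per_query=5,
    margin_fraction=0.5,
    ratio_hard_negatives=0.5,
    included_text_fields="title abstract",
)
instances = reader.read("train")  # read() takes whatever split/path this reader's _read() expects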
Example #16
parser = argparse.ArgumentParser()
parser.add_argument("--in_file", type=str, required=True)
parser.add_argument("--out_dir", type=str, required=True)
parser.add_argument("--split",
                    help="which part of dataset",
                    type=str,
                    required=True)
parser.add_argument("--max_sent_len",
                    help="maximum sentence length",
                    type=int,
                    default=200)
args = parser.parse_args()
in_file = args.in_file
out_dir = args.out_dir
split = args.split
max_sent_len = args.max_sent_len

jsondecoder = json.JSONDecoder()

tokenizer = SimpleWordSplitter()

premise_fp = open(out_dir + "/" + split + ".premise", "w")
hypothesis_fp = open(out_dir + "/" + split + ".hypothesis", "w")
label_fp = open(out_dir + "/" + split + ".label", "w")
index_fp = open(out_dir + "/" + split + ".index", "w")

with open(in_file, "r") as in_fp:
    for line in tqdm(in_fp.readlines()):
        struct = jsondecoder.decode(line)

        hypothesis = struct["claim"]

        premise_idx = 0
        for sentence in struct["predicted_sentences"]:
            underlined_title = sentence[0]
 def setUp(self):
     super(TestSimpleWordSplitter, self).setUp()
     self.word_splitter = SimpleWordSplitter()
# from retrieval.fever_doc_db import FeverDocDB

parser = argparse.ArgumentParser()
parser.add_argument("--in_file", type=str, required=True)
parser.add_argument("--out_file", type=str, required=True)
args = parser.parse_args()
in_file = args.in_file
out_file = args.out_file

if os.path.exists(out_file):
    raise ValueError("Output already exists")

jsondecoder = json.JSONDecoder()
jsonencoder = json.JSONEncoder()

tokenizer = SimpleWordSplitter()
print("Tokenizing")

with open(in_file, "r") as in_fp:
    with open(out_file, "w") as out_fp:
        for line in tqdm(in_fp.readlines()):
            struct = jsondecoder.decode(line)

            tok = tokenizer.split_words(struct["claim"])
            tokenized = " ".join(map(lambda x: x.text, tok))
            struct["claim"] = tokenized

            result = jsonencoder.encode(struct)
            out_fp.write(result + "\n")
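
For one JSONL record, the loop above rewrites the claim in place with space-joined tokens; an illustrative round-trip (example values only, token boundaries follow the SimpleWordSplitter behaviour shown in the tests above):

record = {"id": 1, "claim": "it ain't joe's problem"}
tok = tokenizer.split_words(record["claim"])
record["claim"] = " ".join(t.text for t in tok)
print(record)  # {'id': 1, 'claim': "it ai n't joe 's problem"}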
Example #19
def fever_app(caller):


    global db, tokenizer, text_encoder, encoder, X_train, M_train, X, M, Y_train, Y,params,sess, n_batch_train, db_file, \
        drqa_index, max_page, max_sent, encoder_path, bpe_path, n_ctx, n_batch, model_file
    global n_vocab, n_special, n_y, max_len, clf_token, eval_lm_losses, eval_clf_losses, eval_mgpu_clf_losses, \
        eval_logits, eval_mgpu_logits

    LogHelper.setup()
    logger = LogHelper.get_logger("papelo")

    logger.info("Load config")
    config = json.load(open(os.getenv("CONFIG_FILE","configs/config-docker.json")))
    globals().update(config)
    print(globals())

    logger.info("Set Seeds")
    random.seed(42)
    np.random.seed(42)
    tf.set_random_seed(42)

    logger.info("Load FEVER DB")
    db = FeverDocDB(db_file)
    retrieval = TopNDocsTopNSents(db, max_page, max_sent, True, False, drqa_index)

    logger.info("Init word tokenizer")
    tokenizer = SimpleWordSplitter()

    # Prepare text encoder
    logger.info("Load BPE Text Encoder")
    text_encoder = TextEncoder(encoder_path, bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    n_y = 3
    encoder['_start_'] = len(encoder)
    encoder['_delimiter_'] = len(encoder)
    encoder['_classify_'] = len(encoder)
    clf_token = encoder['_classify_']
    n_special = 3
    max_len = n_ctx // 2 - 2

    n_batch_train = n_batch

    logger.info("Create TF Placeholders")
    X_train = tf.placeholder(tf.int32, [n_batch, 1, n_ctx, 2])
    M_train = tf.placeholder(tf.float32, [n_batch, 1, n_ctx])
    X = tf.placeholder(tf.int32, [None, 1, n_ctx, 2])
    M = tf.placeholder(tf.float32, [None, 1, n_ctx])

    Y_train = tf.placeholder(tf.int32, [n_batch])
    Y = tf.placeholder(tf.int32, [None])

    logger.info("Model Setup")
    eval_logits, eval_clf_losses, eval_lm_losses = model(X, M, Y, train=False, reuse=None)
    eval_mgpu_logits, eval_mgpu_clf_losses, eval_mgpu_lm_losses = mgpu_predict(X_train, M_train, Y_train)

    logger.info("Create TF Session")
    params = find_trainable_variables('model')

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=float(os.getenv("TF_GPU_MEMORY_FRACTION","0.5")))
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options))
    sess.run(tf.global_variables_initializer())
    sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(model_file))])

    logger.info("Ready")

    def predict(instances):
        predictions = []

        for instance in tqdm(instances):
            sents = retrieval.get_sentences_for_claim(instance["claim"])
            found_evidence = resolve_evidence(sents)
            instance["tokenized_claim"] = " ".join(map(lambda x: x.text, tokenizer.split_words(instance["claim"])))

            sub_instances = make_instances(instance, found_evidence)
            sub_predictions = predict_sub_instances(text_encoder, sub_instances)

            refute_evidence = [i for i, x in enumerate(sub_predictions) if x == 2]
            support_evidence = [i for i, x in enumerate(sub_predictions) if x == 0]

            if len(support_evidence):
                predicted_label = "SUPPORTS"
                predicted_evidence = [[found_evidence[i]["title"], found_evidence[i]["line_number"]] for i in support_evidence]
            elif len(refute_evidence):
                predicted_label = "REFUTES"
                predicted_evidence = [[found_evidence[i]["title"], found_evidence[i]["line_number"]] for i in refute_evidence]
            else:
                predicted_label = "NOT ENOUGH INFO"
                predicted_evidence = []

            predictions.append({"predicted_label": predicted_label,
                                "predicted_evidence": predicted_evidence})

        return predictions

    return caller(predict)
 def setUp(self):
     super(TestSimpleWordSplitter, self).setUp()
     self.word_splitter = SimpleWordSplitter()
from allennlp.data import Instance
from allennlp.data.dataset import Batch
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.fields import MetadataField, SequenceLabelField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.data.tokenizers.word_splitter import SimpleWordSplitter
from overrides import overrides

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

torch.manual_seed(1)

# TODO: find a better place for this
golden_tokenizer = WordTokenizer(word_splitter=SimpleWordSplitter())
golden_token_indexers = {
    "golden_tokens": SingleIdTokenIndexer(namespace="tokens")
}


def string_to_fields(string: str, tokenizer: Tokenizer,
                     token_indexers: Dict[str, TokenIndexer]):
    tokenized_string = tokenizer.tokenize(string)
    tokenized_string.insert(0, Token(END_SYMBOL))
    field = TextField(tokenized_string, token_indexers)

    # TODO: always use single id token indexer and tokenizer default/bpe cause we will have bert/elmo passed to main str
    tokenized_golden_string = golden_tokenizer.tokenize(string)
    tokenized_golden_string.append(
        Token(END_SYMBOL))  # with eos at the end for loss compute