class RandomMachineReaderModel(CapeMachineReaderModelInterface):
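    """Stand-in reader that produces deterministic pseudo-random outputs.

    Tokenization is real, but the document embedding and the answer logits are
    seeded from hashes of the input text, so repeated calls on the same
    document and question are reproducible without loading a trained model.
    """
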
    def __init__(self, _):
        self.tokenizer = NltkAndPunctTokenizer()

    def tokenize(self, text):
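        """Split text into tokens and per-token character spans."""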
        tokens = self.tokenizer.tokenize_paragraph_flat(text)
        spans = self.tokenizer.convert_to_spans(text, [tokens])[0]
        return tokens, spans

    def get_document_embedding(self, text):
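        # Seed numpy from a hash of the document text so the fake embedding is
        # reproducible, then return a random (n_tokens, 240) matrix.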
        np.random.seed(
            int(hashlib.sha1(text.encode()).hexdigest(), 16) % 10**8)
        document_tokens, _ = self.tokenize(text)
        return np.random.random((len(document_tokens), 240))

    def get_logits(self, question, document_embedding):
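        # Seed from hashes of the question and the document embedding so the
        # fake logits are reproducible for a given (question, document) pair.
        # End logits reuse the start logits with the first few positions
        # flattened to the minimum start logit.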
        question_tokens, _ = self.tokenize(question)
        n_words = document_embedding.shape[0]
        qseed = int(hashlib.sha1(question.encode()).hexdigest(), 16) % 10**8
        dseed = int(np.sum(document_embedding) * 10**6) % 10**8
        np.random.seed(dseed + qseed)
        start_logits = np.random.random(n_words)
        off = np.random.randint(1, 5)
        end_logits = np.concatenate(
            [np.zeros(off) + np.min(start_logits), start_logits[off:]])
        return start_logits[:n_words], end_logits[:n_words]


class CapeDocQAMachineReaderModel(CapeMachineReaderModelInterface):
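    """Serves a pickled, pre-trained reading-comprehension model through the
    Cape machine reader interface.

    The TensorFlow graph and session are built once per instance. Answering is
    split into two phases: get_document_embedding computes the per-token
    context representation, which get_logits then passes back as cached_doc so
    the document is not re-encoded for every question.
    """
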
    def __init__(self, machine_reader_config):
        self.tokenizer = NltkAndPunctTokenizer()
        self.config = machine_reader_config
        self.model = self._load_model()
        self.sess = tf.Session()
        (self.start_logits, self.end_logits,
         self.context_rep) = self._build_model()
        self._initialize()

    def _load_model(self):
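        """Unpickle the serialized model and repoint its language-model
        resource paths (weights, vocab, token embeddings, options) at the
        files named in the config.
        """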
        with open(self.config.model_pickle_file, 'rb') as f:
            model = pickle.load(f)

        model.lm_model.weight_file = self.config.lm_weights_file
        model.lm_model.lm_vocab_file = self.config.vocab_file
        model.lm_model.embed_weights_file = self.config.lm_token_weights_file
        model.lm_model.options_file = self.config.lm_options_file
        return model

    def _build_model(self):
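        """Build the TensorFlow prediction graph and return the start-logit,
        end-logit and context-representation tensors.
        """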
        with open(self.config.vocab_file, encoding="utf-8") as f:
            vocab_to_init_with = {
                line.strip()
                for line in f if line.strip() not in vocab_to_ignore
            }
        self.model.word_embed.vec_name = self.config.word_vector_file
        with self.sess.as_default():
            self.model.set_input_spec(
                ParagraphAndQuestionSpec(None, None, None, 14),
                vocab_to_init_with,
                word_vec_loader=ResourceLoader(
                    load_vec_fn=lambda x, y: load_word_vectors(
                        x, y, is_path=True)))
            pred = self.model.get_production_predictions_for(
                {x: x
                 for x in self.model.get_placeholders()})
        return pred.start_logits, pred.end_logits, self.model.context_rep

    def _initialize(self):
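        """Restore non-language-model variables from the checkpoint; variables
        whose names start with "bilm" are initialized separately.
        """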
        all_vars = tf.global_variables() + tf.get_collection(
            tf.GraphKeys.SAVEABLE_OBJECTS)
        lm_var_names = {x.name for x in all_vars if x.name.startswith("bilm")}
        vars_to_restore = [x for x in all_vars if x.name not in lm_var_names]
        saver = tf.train.Saver(vars_to_restore)
        saver.restore(self.sess, self.config.checkpoint_file)
        self.sess.run(
            tf.variables_initializer(
                [x for x in all_vars if x.name in lm_var_names]))

    def tokenize(self, text):
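        """Split text into tokens and per-token character spans."""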
        tokens = self.tokenizer.tokenize_paragraph_flat(text)
        spans = self.tokenizer.convert_to_spans(text, [tokens])[0]
        return tokens, spans

    def get_document_embedding(self, text):
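        """Run the encoder over the document (paired with a placeholder
        question) and return its per-token context representation.
        """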
        document_tokens, _ = self.tokenize(text)
        test_question = ParagraphAndQuestion(document_tokens,
                                             ['dummy', 'question'], None,
                                             "cape_question", 'cape_document')
        feed = self.model.encode([test_question], False, cached_doc=None)
        return self.sess.run(self.model.context_rep, feed_dict=feed)[0]

    def get_logits(self, question, document_embedding):
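        """Score a question against a previously computed document embedding.

        A dummy document of matching length satisfies the input spec; the real
        representation is passed in via cached_doc, so the document is not
        re-encoded here.
        """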
        question_tokens, _ = self.tokenize(question)
        n_words = document_embedding.shape[0]
        dummy_document = ['dummy'] * n_words
        test_question = ParagraphAndQuestion(dummy_document, question_tokens,
                                             None, "cape_question",
                                             'cape_document')
        feed = self.model.encode(
            [test_question],
            False,
            cached_doc=document_embedding[np.newaxis, :, :])
        start_logits, end_logits = self.sess.run(
            [self.start_logits, self.end_logits], feed_dict=feed)
        return start_logits[0][:n_words], end_logits[0][:n_words]
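

# Minimal usage sketch of the two-phase reader API above, using the random
# model so no trained weights are needed. The example text and the greedy
# argmax span decoding are illustrative only; they are not the production
# answer decoder.
def _demo_random_reader():
    reader = RandomMachineReaderModel(None)
    document = "The quick brown fox jumps over the lazy dog near the river."
    question = "What jumps over the lazy dog?"

    # The document embedding is computed once and can be reused across
    # questions against the same document.
    _, spans = reader.tokenize(document)
    document_embedding = reader.get_document_embedding(document)
    start_logits, end_logits = reader.get_logits(question, document_embedding)

    # Pick the best start token, then the best end token at or after it, and
    # map the token span back to character offsets in the original text.
    start_idx = int(np.argmax(start_logits))
    end_idx = start_idx + int(np.argmax(end_logits[start_idx:]))
    char_start, char_end = int(spans[start_idx][0]), int(spans[end_idx][1])
    return document[char_start:char_end]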