def dump_token_embeddings(vocab_file, options_file, weight_file, outfile):
    """
    Given an input vocabulary file, dump all the token embeddings to the
    outfile.  The result can be used as the embedding_weight_file when
    constructing a BidirectionalLanguageModel.

    Patched to print progress
    """
    with open(options_file, 'r') as fin:
        options = json.load(fin)
    max_word_length = options['char_cnn']['max_characters_per_token']

    vocab = UnicodeCharsVocabulary(vocab_file, max_word_length)
    batcher = Batcher(vocab_file, max_word_length)
    print('Computing {} LM token vectors'.format(vocab.size))

    ids_placeholder = tf.placeholder(
        'int32',
        shape=(None, None, max_word_length)
    )
    print('Building Language model')
    model = BidirectionalLanguageModel(
        options_file, weight_file, ids_placeholder
    )

    embedding_op = model.get_ops()['token_embeddings']

    n_tokens = vocab.size
    embed_dim = int(embedding_op.shape[2])

    embeddings = np.zeros((n_tokens, embed_dim), dtype=lm_model.DTYPE)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        for k in tqdm(range(n_tokens), desc='Computing LM token embeddings', ncols=80):
            token = vocab.id_to_word(k)
            char_ids = batcher.batch_sentences([[token]])[0, 1, :].reshape(
                1, 1, -1)
            embeddings[k, :] = sess.run(
                embedding_op, feed_dict={ids_placeholder: char_ids}
            )

    with h5py.File(outfile, 'w') as fout:
        fout.create_dataset(
            'embedding', embeddings.shape, dtype='float32', data=embeddings
        )
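

# A minimal usage sketch (file names are assumptions, not fixed by this
# module): first dump the token embeddings, then, typically in a fresh graph,
# pass the HDF5 file back as embedding_weight_file so the language model looks
# up token vectors instead of running its char-CNN:
#
#   dump_token_embeddings('vocab.txt', 'options.json', 'lm_weights.hdf5',
#                         'token_embeddings.hdf5')
#
#   token_ids = tf.placeholder('int32', shape=(None, None))
#   lm = BidirectionalLanguageModel(
#       'options.json', 'lm_weights.hdf5', token_ids,
#       embedding_weight_file='token_embeddings.hdf5',
#       use_character_inputs=False)

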
class CapeElmoQaModel(Model):
    """ Base classes for Cape ELMo models """
    def __init__(self,
                 encoder: DocumentAndQuestionEncoder,
                 lm_model: LanguageModel,
                 per_sentence: bool,
                 max_batch_size: int,
                 word_embed: Optional[WordEmbedder],
                 char_embed: Optional[CharWordEmbedder] = None,
                 preprocessor: Optional[TextPreprocessor] = None):
        if word_embed is None and char_embed is None:
            raise ValueError(
                "Must specify at least one of word_embed or char_embed")
        self.preprocessor = preprocessor
        self.max_batch_size = max_batch_size
        self.lm_model = lm_model
        self.per_sentence = per_sentence
        self.word_embed = word_embed
        self.char_embed = char_embed
        self.encoder = encoder
        if self.per_sentence:
            self._max_num_sentences = self.max_batch_size * 30  # TODO hard coded for SQuAD
        else:
            self._max_num_sentences = self.max_batch_size
        self._batcher = None
        self._max_word_size = None

        # placeholders
        self._is_train_placeholder = None
        self._batch_len_placeholders = None
        self._question_char_ids_placeholder = None
        self._context_char_ids_placeholder = None
        self._context_sentence_ixs = None

        # Patrick's Caching edits
        self._cached_doc_placeholder = None
        self.document_embedding_dim = -1

    @property
    def token_lookup(self):
        """
        Are we using pre-computed word vectors, or running the LM's CNN to dynmacially derive
        word vectors from characters.
        """
        return self.lm_model.embed_weights_file is not None

    def init(self, corpus, loader: ResourceLoader):
        if self.word_embed is not None:
            self.word_embed.set_vocab(
                corpus, loader, None if self.preprocessor is None else
                self.preprocessor.special_tokens())
        if self.char_embed is not None:
            self.char_embed.embeder.set_vocab(corpus)

    def set_inputs(self,
                   datasets: List[ParagraphAndQuestionDataset],
                   word_vec_loader=None):
        voc = set()
        for dataset in datasets:
            voc.update(dataset.get_vocab())

        input_spec = datasets[0].get_spec()
        for dataset in datasets[1:]:
            input_spec += dataset.get_spec()

        return self.set_input_spec(input_spec, voc, word_vec_loader)

    def set_input_spec(self, input_spec, voc, word_vec_loader=None):
        if word_vec_loader is None:
            word_vec_loader = ResourceLoader()
        if self.word_embed is not None:
            self.word_embed.init(word_vec_loader, voc)
        if self.char_embed is not None:
            self.char_embed.embeder.init(word_vec_loader, voc)

        batch_size = input_spec.batch_size
        self.encoder.init(
            input_spec, True, self.word_embed,
            None if self.char_embed is None else self.char_embed.embeder)
        self._is_train_placeholder = tf.placeholder(tf.bool, (),
                                                    name='is_train')

        if self.token_lookup:
            self._batcher = TokenBatcher(self.lm_model.lm_vocab_file)
            self._question_char_ids_placeholder = tf.placeholder(
                tf.int32, (batch_size, None), name='lm_q_ids_pl')
            self._context_char_ids_placeholder = tf.placeholder(
                tf.int32, (batch_size, None), name='lm_c_ids_pl')
            self._max_word_size = input_spec.max_word_size
            self._context_sentence_ixs = None
        else:
            input_spec.max_word_size = 50  # TODO hack: hard-coded to the LM's max_characters_per_token
            self._batcher = Batcher(self.lm_model.lm_vocab_file, 50)
            self._max_word_size = input_spec.max_word_size
            self._question_char_ids_placeholder = tf.placeholder(
                tf.int32, (batch_size, None, self._max_word_size))
            if self.per_sentence:
                self._context_char_ids_placeholder = tf.placeholder(
                    tf.int32, (None, None, self._max_word_size))
                self._context_sentence_ixs = tf.placeholder(
                    tf.int32, (batch_size, 3, None, 3))
            else:
                self._context_char_ids_placeholder = tf.placeholder(
                    tf.int32, (batch_size, None, self._max_word_size))
                self._context_sentence_ixs = None
        return self.get_placeholders()

    def get_placeholders(self):
        pls = self.encoder.get_placeholders() + [
            self._is_train_placeholder, self._question_char_ids_placeholder,
            self._context_char_ids_placeholder
        ]
        if self._context_sentence_ixs is not None:
            pls.append(self._context_sentence_ixs)
        if hasattr(self, '_cached_doc_placeholder'):
            pls.append(self._cached_doc_placeholder)
        return pls

    def get_predictions_for(self, input_tensors: Dict[Tensor, Tensor]):
        return self._build_elmo(input_tensors, is_prod=False)

    def get_production_predictions_for(self, input_tensors: Dict[Tensor,
                                                                 Tensor]):
        return self._build_elmo(input_tensors, is_prod=True)

    def _build_elmo(self, input_tensors: Dict[Tensor, Tensor], is_prod=False):
        is_train = input_tensors[self._is_train_placeholder]
        enc = self.encoder

        q_lm_model = BidirectionalLanguageModel(
            self.lm_model.options_file,
            self.lm_model.weight_file,
            input_tensors[self._question_char_ids_placeholder],
            embedding_weight_file=self.lm_model.embed_weights_file,
            use_character_inputs=not self.token_lookup,
            max_batch_size=self.max_batch_size)
        q_lm_encoding = q_lm_model.get_ops()["lm_embeddings"]

        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            c_lm_model = BidirectionalLanguageModel(
                self.lm_model.options_file,
                self.lm_model.weight_file,
                input_tensors[self._context_char_ids_placeholder],
                embedding_weight_file=self.lm_model.embed_weights_file,
                use_character_inputs=not self.token_lookup,
                max_batch_size=self._max_num_sentences)
            c_lm_encoding = c_lm_model.get_ops()["lm_embeddings"]

        if self.per_sentence:
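            # When the LM runs per sentence, c_lm_encoding has shape
            # (num_sentences, 3, max_sentence_len, dim); gather_nd re-assembles
            # it into (batch, 3, num_context_words, dim) using the index tensor
            # built in encode() below.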
            c_lm_encoding = tf.gather_nd(
                c_lm_encoding, input_tensors[self._context_sentence_ixs])

        q_mask = input_tensors[enc.question_len]
        c_mask = input_tensors[enc.context_len]

        q_embed = []
        c_embed = []

        if enc.question_chars in input_tensors:
            with tf.variable_scope("char-embed"):
                q, c = self.char_embed.embed(
                    is_train, (input_tensors[enc.question_chars],
                               input_tensors[enc.question_word_len]),
                    (input_tensors[enc.context_chars],
                     input_tensors[enc.context_word_len]))
            q_embed.append(q)
            c_embed.append(c)

        if enc.question_words in input_tensors:
            with tf.variable_scope("word-embed"):
                q, c = self.word_embed.embed(
                    is_train, (input_tensors[enc.question_words], q_mask),
                    (input_tensors[enc.context_words], c_mask))
            q_embed.append(q)
            c_embed.append(c)

        if enc.question_features in input_tensors:
            q_embed.append(input_tensors.get(enc.question_features))
            c_embed.append(input_tensors.get(enc.context_features))

        q_embed = tf.concat(q_embed, axis=2)
        c_embed = tf.concat(c_embed, axis=2)

        answer = [
            input_tensors[x] for x in enc.answer_encoder.get_placeholders()
        ]
        if not is_prod:
            return self._get_predictions_for(is_train, q_embed, q_mask,
                                             q_lm_encoding, c_embed, c_mask,
                                             c_lm_encoding, answer)
        else:
            cached_context = input_tensors[self._cached_doc_placeholder]
            return self._get_production_predictions_for(
                is_train, cached_context, q_embed, q_mask, q_lm_encoding,
                c_embed, c_mask, c_lm_encoding, answer)

    def _get_predictions_for(self, is_train, question_embed, question_mask,
                             question_lm, context_embed, context_mask,
                             context_lm, answer) -> Prediction:
        raise NotImplementedError()

    def _get_production_predictions_for(self, is_train, cached_context,
                                        question_embed, question_mask,
                                        question_lm, context_embed,
                                        context_mask, context_lm,
                                        answer) -> Prediction:
        raise NotImplementedError()

    def encode(self,
               batch: List[ContextAndQuestion],
               is_train: bool,
               cached_doc=None):
        if len(batch) > self.max_batch_size:
            raise ValueError(
                "The model can only use a batch <= %d, but got %d" %
                (self.max_batch_size, len(batch)))
        data = self.encoder.encode(batch, is_train)
        data[self._question_char_ids_placeholder] = \
            self._batcher.batch_sentences([q.question for q in batch])
        data[self._is_train_placeholder] = is_train
        if cached_doc is not None:
            data[self._cached_doc_placeholder] = cached_doc
            context_word_dim = cached_doc.shape[1]
        else:
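            # No cached document supplied: feed a small dummy array just to
            # satisfy the placeholder, and take the context length from the
            # encoder's own output instead.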
            data[self._cached_doc_placeholder] = np.zeros(
                (1, 1, self.document_embedding_dim))
            context_word_dim = data[self.encoder.context_words].shape[1]

        if not self.per_sentence:
            data[self._context_char_ids_placeholder] = \
                self._batcher.batch_sentences([x.get_context() for x in batch])
        else:
            data[self._context_char_ids_placeholder] = \
                self._batcher.batch_sentences(flatten_iterable([x.sentences for x in batch]))

            # Compute indices such that
            # context_sentence_ixs[batch#, k, doc_word#] = (flat_sentence#, k, sentence_word#)
            # for each word and each of the 3 LM layers. gather_nd uses this to
            # map the per-sentence LM encodings back to the flattened
            # per-document word order.
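            # Worked example (illustration only): for the first document in a
            # batch whose sentences have lengths [2, 3], for every k:
            #   context_sentence_ixs[0, k, 0] = [0, k, 0]   # sent 0, word 0
            #   context_sentence_ixs[0, k, 1] = [0, k, 1]   # sent 0, word 1
            #   context_sentence_ixs[0, k, 2] = [1, k, 0]   # sent 1, word 0
            #   context_sentence_ixs[0, k, 3] = [1, k, 1]   # sent 1, word 1
            #   context_sentence_ixs[0, k, 4] = [1, k, 2]   # sent 1, word 2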
            context_sentence_ixs = np.zeros(
                (len(batch), 3, context_word_dim, 3), dtype=np.int32)
            total_sent_ix = 0
            for ix, point in enumerate(batch):
                word_ix = 0
                for sent_ix, sent in enumerate(point.sentences):
                    for w_ix in range(len(sent)):
                        for k in range(3):
                            context_sentence_ixs[ix, k, word_ix] = [
                                total_sent_ix, k, w_ix
                            ]
                        word_ix += 1
                    total_sent_ix += 1
            data[self._context_sentence_ixs] = context_sentence_ixs
        return data