class TransformerCache(Transformer):
    """Attention net.  See file docstring."""
    def __init__(self, *args, **kwargs):
        super(TransformerCache, self).__init__(*args, **kwargs)
        self.attention_weights = dict()  # For visualizing attention heads.
        self.sentence_cache = LRUCache_new(self.hparams.hidden_size,
                                           max_size=20,
                                           batch_size=self.hparams.batch_size)

        with tf.variable_scope("sentence_level_cache"):
            self.m_weight = tf.get_variable(
                'm_weight',
                shape=[self.hparams.hidden_size, self.hparams.hidden_size])
            self.s_weight = tf.get_variable(
                's_weight',
                shape=[self.hparams.hidden_size, self.hparams.hidden_size])
            self.cache_flag = tf.Variable(0,
                                          trainable=False,
                                          name='flag',
                                          dtype=tf.int64)

    def calculate_mixing_weight(self, s, m):
        """Returns a sigmoid gate over hidden dims for mixing s with cached m."""
        return tf.sigmoid(
            tf.einsum('jl,ikl->ikj', self.m_weight, m) +
            tf.einsum('jl,ikl->ikj', self.s_weight, s))
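    # A quick numeric sketch of the gate above (illustrative values only): with
    # hidden_size = 2, s = [[[1., 0.]]], m = [[[0., 1.]]] and both weight
    # matrices equal to the identity, each einsum reduces to x @ W.T, so the
    # pre-activation is [[[1., 1.]]] and lambda = sigmoid([[[1., 1.]]]) is
    # roughly [[[0.73, 0.73]]].  decode() below uses this per-dimension gate to
    # mix the decoder output with the cached sentence representation.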

    def decode(self,
               decoder_input,
               encoder_output,
               encoder_decoder_attention_bias,
               decoder_self_attention_bias,
               hparams,
               cache=None,
               nonpadding=None):
        """Decode Transformer outputs from encoder representation.

    Args:
      decoder_input: inputs to bottom of the model.
          [batch_size, decoder_length, hidden_dim]
      encoder_output: Encoder representation.
          [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
          encoder-decoder attention. [batch_size, input_length]
      decoder_self_attention_bias: Bias and mask weights for decoder
          self-attention. [batch_size, decoder_length]
      hparams: hyperparameters for the model.
      cache: dict, containing tensors which are the results of previous
          attentions, used for fast decoding.
      nonpadding: optional Tensor with shape [batch_size, decoder_length]

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        decoder_input = tf.nn.dropout(
            decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

        decoder_output = transformer_decoder(
            decoder_input,
            encoder_output,
            decoder_self_attention_bias,
            encoder_decoder_attention_bias,
            hparams,
            cache=cache,
            nonpadding=nonpadding,
            save_weights_to=self.attention_weights)

        if (common_layers.is_on_tpu()
                and hparams.mode == tf.estimator.ModeKeys.TRAIN):
            # TPU does not react kindly to extra dimensions.
            # TODO(noam): remove this once TPU is more forgiving of extra dims.
            return decoder_output
        else:
            # Query the sentence cache and mix it into the decoder output,
            # then expand since t2t expects 4d tensors.

            m = self.sentence_cache.Query(
                tf.reshape(decoder_output,
                           [hparams.batch_size, -1, hparams.hidden_size]))

            lambd = self.calculate_mixing_weight(
                tf.reshape(decoder_output,
                           [hparams.batch_size, -1, hparams.hidden_size]), m)

            m = tf.reshape(m, tf.shape(decoder_output))

            lambd = tf.reshape(
                lambd, (tf.shape(decoder_output)[0], -1, hparams.hidden_size))

            if self.hparams.use_cache:
                return tf.expand_dims(lambd * decoder_output +
                                      (1.0 - lambd) * m,
                                      axis=2)
            else:
                return tf.expand_dims(decoder_output, axis=2)

    def body(self, features):
        """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "tragets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        hparams = self._hparams

        if self.has_input:
            inputs = features["inputs"]
            target_space = features["target_space_id"]
            encoder_output, encoder_decoder_attention_bias = self.encode(
                inputs, target_space, hparams, features=features)
        else:
            encoder_output, encoder_decoder_attention_bias = (None, None)

        targets = features["targets"]
        targets = common_layers.flatten4d3d(targets)

        decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
            targets, hparams, features=features)

        decoder_output = self.decode(decoder_input,
                                     encoder_output,
                                     encoder_decoder_attention_bias,
                                     decoder_self_attention_bias,
                                     hparams,
                                     nonpadding=features_to_nonpadding(
                                         features, "targets"))

        self.cache_flag = self.sentence_cache.Add(
            tf.squeeze(features["targets_raw"], [2, 3]),
            tf.squeeze(decoder_output, 2), tf.squeeze(decoder_output, 2))
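        # Shapes going into the cache update: "targets_raw" is typically
        # [batch_size, length, 1, 1] in t2t, so squeezing axes [2, 3] leaves
        # [batch_size, length] ids, and squeezing axis 2 of decoder_output
        # leaves [batch_size, length, hidden_size] states; ids and states are
        # added to the sentence cache together.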

        # Cast the flag so it can be added to the float decoder output below.
        self.cache_flag = tf.cast(self.cache_flag, tf.float32)

        expected_attentions = features.get("expected_attentions")
        if expected_attentions is not None:
            attention_loss = common_attention.encoder_decoder_attention_loss(
                expected_attentions, self.attention_weights,
                hparams.expected_attention_loss_type,
                hparams.expected_attention_loss_multiplier)
            return decoder_output, {"attention_loss": attention_loss}

        # Adding the (cast) flag keeps the sentence-cache update above as part
        # of the returned graph.
        return decoder_output + self.cache_flag

    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0,
                     sentence_cache=None):
        """Fast decoding.

    Implements both greedy and beam search decoding; beam search is used iff
    beam_size > 1, otherwise the beam-search-related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.target_modality

        if self.has_input:
            inputs = features["inputs"]
            if target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = common_layers.shape_list(
                    inputs)[1] + decode_length

            # TODO(llion): Clean up this reshaping logic.
            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.input_modality["inputs"]
            with tf.variable_scope(input_modality.name):
                inputs = input_modality.bottom_sharded(inputs, dp)
            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features["target_space_id"],
                    hparams,
                    features=features)
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = None
        else:
            # The problem has no inputs.
            # In this case, features["inputs"] contains partial targets.
            # We force the outputs to begin with these sequences.
            encoder_output = None
            encoder_decoder_attention_bias = None
            partial_targets = tf.squeeze(tf.to_int64(features["inputs"]),
                                         [2, 3])
            partial_targets_length = common_layers.shape_list(
                partial_targets)[1]
            decode_length += partial_targets_length
            batch_size = tf.shape(partial_targets)[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
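            # Sinusoidal position signal covering decode_length + 1 positions;
            # preprocess_targets() below adds the slice for the current step to
            # each decoded symbol's embedding.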

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: ids of the decoder inputs. [batch_size, 1]
        i: scalar, step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)
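            # Presumably this mirrors training, where transformer_prepare_decoder
            # shifts the targets right so the first decoder input is all zeros.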

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
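        # attention_bias_lower_triangle(3), for example, yields a [1, 1, 3, 3]
        # bias that is 0 on and below the diagonal and a large negative number
        # (about -1e9) above it, masking attention to future positions.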
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode,
                                  targets,
                                  cache.get("encoder_output"),
                                  cache.get("encoder_decoder_attention_bias"),
                                  bias,
                                  hparams,
                                  cache,
                                  nonpadding=features_to_nonpadding(
                                      features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            if partial_targets is not None:
                # If the position is within the given partial targets, we alter the
                # logits to always return those values.
                # A faster approach would be to process the partial targets in one
                # iteration in order to fill the corresponding parts of the cache.
                # This would require broader changes, though.
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size, 0.0, -1e9)

                ret = tf.cond(tf.less(i, partial_targets_length),
                              forced_logits, lambda: ret)
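                # Worked example of the forcing above: with vocab_size 4 and a
                # partial-target id of 2, forced_logits() returns
                # [-1e9, -1e9, 0., -1e9], so the decoder always emits token 2
                # while i is still inside the partial targets.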
            return ret, cache, body_outputs

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            sentence_cache=self.sentence_cache,
            cache_flag=self.cache_flag)
        if partial_targets is not None:
            ret["outputs"] = ret["outputs"][:, partial_targets_length:]
        return ret