Example #1
    def encode_no_lookup(self, embedded_inputs, inputs_mask):
        """Encoder step for transformer given already-embedded inputs

      Args:
        embedded_inputs: int tensor with shape [batch_size, input_length, emb_size].
        inputs_mask: tensor with shape [batch_size, input_length]

      Returns:
        float tensor with shape [batch_size, input_length, hidden_size]
      """
        (encoder_input, self_attention_bias,
         _) = (t2t_transformer.transformer_prepare_encoder(
             embedded_inputs, self.target_space, self.hparams))

        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - self.hparams.layer_prepostprocess_dropout)

        (encoder_output, encoder_extra_output) = (
            universal_transformer_util.universal_transformer_encoder(
                encoder_input,
                self_attention_bias,
                self.hparams,
                nonpadding=inputs_mask,
                save_weights_to=self.model.attention_weights))

        return encoder_output, encoder_extra_output
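
A minimal usage sketch for encode_no_lookup, assuming "model" is a hypothetical instance of the class above with target_space, hparams, and model.attention_weights already configured; shapes are illustrative.

# Hypothetical call site (not part of the original example).
import tensorflow as tf

embedded = tf.random_normal([8, 20, 512])    # [batch_size, input_length, emb_size]
mask = tf.ones([8, 20], dtype=tf.float32)    # nonpadding mask: 1.0 = real token, 0.0 = padding
encoder_output, encoder_extra_output = model.encode_no_lookup(embedded, mask)
# encoder_output: [8, 20, hidden_size]; encoder_extra_output carries ACT statistics
# (ponder times and remainders) when hparams.recurrence_type == "act".
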
Example #2
def universal_transformer_encoder(inputs,
                                  target_space,
                                  hparams,
                                  features=None,
                                  make_image_summary=False):
    """Runs the Universal Transformer encoder; returns (encoder_output, act_loss)."""

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(inputs,
                                                target_space,
                                                hparams,
                                                features=features))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    (encoder_output,
     encoder_extra_output) = universal_transformer_util.universal_transformer_encoder(
         encoder_input,
         self_attention_bias,
         hparams,
         nonpadding=transformer.features_to_nonpadding(features, "inputs"),
         save_weights_to=None,
         make_image_summary=make_image_summary)

    if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
        ponder_times, remainders = encoder_extra_output
        act_loss = hparams.act_loss_weight * tf.reduce_mean(ponder_times +
                                                            remainders)

        return encoder_output, act_loss
    else:
        return encoder_output, tf.constant(0.0, tf.float32)
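
The "act" branch above turns the encoder's extra output into a scalar ponder cost. A standalone sketch of that reduction, assuming TensorFlow 1.x and illustrative values:

import tensorflow as tf

# ponder_times / remainders come back from the ACT encoder with one value per
# position, e.g. shape [batch_size, input_length].
ponder_times = tf.constant([[2.0, 3.0], [1.0, 4.0]])
remainders = tf.constant([[0.3, 0.1], [0.7, 0.2]])
act_loss_weight = 0.01  # hparams.act_loss_weight in the example above

act_loss = act_loss_weight * tf.reduce_mean(ponder_times + remainders)

with tf.Session() as sess:
    print(sess.run(act_loss))  # 0.01 * mean([2.3, 3.1, 1.7, 4.2]) = 0.02825
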
Example #3
    def encode(self,
               inputs,
               target_space,
               hparams,
               features=None,
               losses=None,
               **kwargs):
        """Encode Universal Transformer inputs.

    It is similar to "transformer.encode", but it uses
    "universal_transformer_util.universal_transformer_encoder" instead of
    "transformer.transformer_encoder".

    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: Unused.
      **kwargs: additional arguments to pass to encoder_function

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
          encoder_extra_output: extra encoder output used in some variants of
            the model (e.g. in ACT, to pass the ponder times to the body).
    """
        del losses

        inputs = common_layers.flatten4d3d(inputs)

        encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
            transformer.transformer_prepare_encoder(inputs,
                                                    target_space,
                                                    hparams,
                                                    features=features))

        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

        (encoder_output, encoder_extra_output) = (
            universal_transformer_util.universal_transformer_encoder(
                encoder_input,
                self_attention_bias,
                hparams,
                nonpadding=transformer.features_to_nonpadding(
                    features, "inputs"),
                save_weights_to=self.attention_weights))

        return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
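
The encode method above first flattens its 4-D inputs with common_layers.flatten4d3d, folding the spatial height dimension into the length. A shape-only sketch, assuming tensor2tensor is installed; the sizes are illustrative:

import tensorflow as tf
from tensor2tensor.layers import common_layers

inputs_4d = tf.zeros([8, 20, 1, 512])             # [batch_size, input_length, input_height, hidden_dim]
inputs_3d = common_layers.flatten4d3d(inputs_4d)  # [batch_size, input_length * input_height, hidden_dim]
print(inputs_3d.shape)                            # expected: (8, 20, 512)
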
Example #4
  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Encode Universal Transformer inputs.

    It is similar to "transformer.encode", but it uses
    "universal_transformer_util.universal_transformer_encoder" instead of
    "transformer.transformer_encoder".

    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: Unused.

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
          encoder_extra_output: extra encoder output used in some variants of
            the model (e.g. in ACT, to pass the ponder times to the body).
    """
    del losses

    inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    (encoder_output, encoder_extra_output) = (
        universal_transformer_util.universal_transformer_encoder(
            encoder_input,
            self_attention_bias,
            hparams,
            nonpadding=transformer.features_to_nonpadding(features, "inputs"),
            save_weights_to=self.attention_weights))

    return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
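
This variant derives the nonpadding mask from the features dictionary, which matters for "packed" datasets where several examples share one row. Below is a simplified sketch of what transformer.features_to_nonpadding(features, "inputs") amounts to, assuming the tensor2tensor convention of an "inputs_segmentation" feature; it is not the library code itself.

import tensorflow as tf

def nonpadding_from_features(features, key="inputs"):
    # Packed datasets carry a "<key>_segmentation" feature whose entries are
    # positive segment ids for real tokens and 0 for padding; clipping at 1.0
    # turns that into the nonpadding mask the encoder expects.
    seg_key = key + "_segmentation"
    if features and seg_key in features:
        return tf.minimum(tf.to_float(features[seg_key]), 1.0)
    return None  # the encoder then infers padding from the attention bias

features = {"inputs_segmentation": tf.constant([[1, 1, 2, 2, 0, 0]])}
nonpadding = nonpadding_from_features(features)   # [[1., 1., 1., 1., 0., 0.]]
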
Example #5
def universal_transformer_encoder(inputs, target_space, hparams,
                                  features=None, make_image_summary=False):

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    (encoder_output,
     encoder_extra_output) = universal_transformer_util.universal_transformer_encoder(
        encoder_input,
        self_attention_bias,
        hparams,
        nonpadding=transformer.features_to_nonpadding(features, "inputs"),
        save_weights_to=None,
        make_image_summary=make_image_summary)

    # encoder_output = tf.expand_dims(encoder_output, 2)

    return encoder_output
Example #6
  def encode(self, stories, questions, target_space, hparams,
             features=None):
    """Encode transformer inputs.

    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparmeters for model.
      unused_features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encodre-decoder attention. [batch_size, input_length]
    """

    inputs = tf.concat([stories, questions], axis=1)
    # inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(inputs, target_space, hparams,
                                                features=features))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    (encoder_output,
     extra_output) = universal_transformer_util.universal_transformer_encoder(
            encoder_input, self_attention_bias, hparams,
            nonpadding=transformer.features_to_nonpadding(features, "inputs"),
            save_weights_to=self.attention_weights)

    return encoder_output, encoder_decoder_attention_bias, extra_output
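
Example #6 feeds the concatenation of story and question representations through a single encoder. A shape-only sketch of that first step (sizes illustrative):

import tensorflow as tf

stories = tf.zeros([8, 50, 512])     # [batch_size, story_length, hidden_dim]
questions = tf.zeros([8, 10, 512])   # [batch_size, question_length, hidden_dim]
inputs = tf.concat([stories, questions], axis=1)
print(inputs.shape)                  # (8, 60, 512): one sequence, story followed by question
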
Example #7
    def transformer_fn(self,
                       sentence_complex_input_placeholder, emb_complex,
                       sentence_simple_input_placeholder, emb_simple,
                       w, b,
                       rule_id_input_placeholder, rule_target_input_placeholder,
                       mem_contexts, mem_outputs,
                       global_step, score, comp_features, obj):
        encoder_mask = tf.to_float(
            tf.equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                     self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
        encoder_attn_bias = common_attention.attention_bias_ignore_padding(encoder_mask)

        obj_tensors = {}

        train_mode = self.model_config.train_mode
        if self.model_config.bert_mode:
            # Leave space for decoder when static seq
            gpu_id = 0 if (train_mode == 'static_seq' or
                           train_mode == 'static_self-critical' or
                           'direct' in self.model_config.memory) else 1
            with tf.device('/device:GPU:%s' % gpu_id):
                sentence_complex_input = tf.stack(sentence_complex_input_placeholder, axis=1)
                bert_model = BertModel(
                    BertConfig.from_json_file(self.model_config.bert_config),
                    self.is_train, sentence_complex_input,
                    input_mask=1.0-encoder_mask, token_type_ids=None, use_one_hot_embeddings=False)
                encoder_embed_inputs = bert_model.embedding_output
                encoder_outputs = bert_model.sequence_output
                emb_complex = bert_model.embedding_table # update emb complex
                if (self.model_config.tie_embedding == 'all' or
                        self.model_config.tie_embedding == 'enc_dec'):
                    emb_simple = bert_model.embedding_table
                if (self.model_config.tie_embedding == 'all' or
                        self.model_config.tie_embedding == 'dec_out'):
                    emb_w_proj = tf.get_variable(
                        'emb_w_proj', shape=[self.model_config.dimension, self.model_config.dimension],
                        initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                    w = tf.matmul(bert_model.embedding_table, emb_w_proj)

                if 'direct' in self.model_config.memory:
                    with tf.device('/device:GPU:1'):
                        direct_mask = tf.to_float(
                            tf.equal(tf.stack(rule_target_input_placeholder, axis=1),
                                     self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
                        direct_bert_model = BertModel(
                            BertConfig.from_json_file(self.model_config.bert_config),
                            self.is_train, tf.stack(rule_target_input_placeholder, axis=1),
                            input_mask=1.0 - direct_mask, token_type_ids=None, use_one_hot_embeddings=False,
                            embedding_table=emb_simple,
                            scope='direct')
                        direct_bert_output = direct_bert_model.sequence_output
                        obj_tensors['direct_bert_bias'] = common_attention.attention_bias_ignore_padding(direct_mask)
                        obj_tensors['direct_bert_output'] = direct_bert_output
        else:
            encoder_embed_inputs = tf.stack(
                self.embedding_fn(sentence_complex_input_placeholder, emb_complex), axis=1)
            if self.hparams.pos == 'timing':
                encoder_embed_inputs = common_attention.add_timing_signal_1d(encoder_embed_inputs)
                print('Use positional encoding in encoder text.')

            if self.model_config.subword_vocab_size and self.model_config.seg_mode:
                encoder_embed_inputs = common_attention.add_positional_embedding(
                    encoder_embed_inputs, 100, 'seg_embedding',
                    positions=obj['line_comp_segids'])
                print('Add segment embedding.')

            with tf.variable_scope('transformer_encoder'):
                encoder_embed_inputs = tf.nn.dropout(encoder_embed_inputs,
                                                     1.0 - self.hparams.layer_prepostprocess_dropout)

                if self.model_config.architecture == 'ut2t':
                    encoder_outputs, encoder_extra_output = universal_transformer_util.universal_transformer_encoder(
                        encoder_embed_inputs, encoder_attn_bias, self.hparams)
                    enc_ponder_times, enc_remainders = encoder_extra_output
                    extra_encoder_loss = (
                            self.hparams.act_loss_weight *
                            tf.reduce_mean(enc_ponder_times + enc_remainders))
                    if self.is_train:
                        obj_tensors['extra_encoder_loss'] = extra_encoder_loss
                else:
                    encoder_outputs = transformer.transformer_encoder(
                        encoder_embed_inputs, encoder_attn_bias, self.hparams)

                # Update score based on multiplier
                score, pred_score_tuple = self.update_score(
                    score, encoder_outputs=encoder_outputs, encoder_mask=tf.to_float(
                        tf.not_equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                                     self.data.vocab_complex.encode(constant.SYMBOL_PAD))),
                    comp_features=comp_features)

                encoder_outputs = self.update_encoder_embedding(encoder_outputs, score)

        encoder_embed_inputs_list = tf.unstack(encoder_embed_inputs, axis=1)

        with tf.variable_scope('transformer_decoder', reuse=tf.AUTO_REUSE):
            if self.model_config.subword_vocab_size or 'bert_token' in self.model_config.bert_mode:
                go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)[0]
            else:
                go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)
            batch_go = tf.tile(
                tf.expand_dims(self.embedding_fn(go_id, emb_simple), axis=0),
                [self.model_config.batch_size, 1])

            # For static_seq train_mode
            if self.model_config.npad_mode == 'static_seq':
                with tf.variable_scope('npad'):
                    npad_w = tf.get_variable(
                        'npad_w', shape=[1, self.model_config.dimension, self.model_config.dimension],
                        initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                    obj_tensors['npad_w'] = npad_w

            if self.is_train and (train_mode == 'teacher' or
                                  train_mode == 'teachercritical' or train_mode == 'teachercriticalv2'):
                # General train
                print('Use Generally Process.')
                decoder_embed_inputs_list = self.embedding_fn(
                    sentence_simple_input_placeholder[:-1], emb_simple)
                final_output, decoder_output, cur_context = self.decode_step(
                    decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                    rule_id_input_placeholder, mem_contexts, mem_outputs, global_step, score, batch_go,
                    obj_tensors)

                decoder_logit = (
                        tf.nn.conv1d(final_output, tf.expand_dims(tf.transpose(w), axis=0), 1, 'SAME') +
                        tf.expand_dims(tf.expand_dims(b, axis=0), axis=0))
                decoder_target_list = []
                decoder_logit_list = tf.unstack(decoder_logit, axis=1)
                for logit in decoder_logit_list:
                    decoder_target_list.append(tf.argmax(logit, output_type=tf.int32, axis=-1))

                decoder_output_list = [
                    tf.squeeze(d, 1)
                    for d in tf.split(decoder_output, self.model_config.max_simple_sentence, axis=1)]
                final_output_list = [
                    tf.squeeze(d, 1)
                    for d in tf.split(final_output, self.model_config.max_simple_sentence, axis=1)]

                if self.model_config.pointer_mode:
                    segment_mask = None
                    if 'line_comp_segids' in obj:
                        segment_mask = obj['line_comp_segids']
                    decoder_logit_list = word_distribution(
                        decoder_logit_list, decoder_output_list, encoder_outputs, encoder_embed_inputs,
                        sentence_complex_input_placeholder, obj_tensors, self.model_config, self.data, segment_mask)
            elif self.is_train and (train_mode == 'static_seq' or train_mode == 'static_self-critical'):
                decoder_target_list = []
                decoder_logit_list = []
                decoder_embed_inputs_list = []
                # The next three lists are overridden inside the decode loop below
                final_output_list = []
                decoder_output_list = []
                contexts = []
                sample_target_list = []
                sample_logit_list = []

                gpu_assign_interval = int(self.model_config.max_simple_sentence / 3)
                for step in range(self.model_config.max_simple_sentence):
                    gpu_id = int(step / gpu_assign_interval)
                    if gpu_id > 3:
                        gpu_id = 3
                    gpu_id += 1
                    with tf.device('/device:GPU:%s' % gpu_id):
                        print('Step%s with GPU%s' % (step, gpu_id))
                        final_outputs, _, cur_context = self.decode_step(
                            decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                            rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                            score, batch_go, obj_tensors)

                        final_output_list = [
                            tf.squeeze(d, 1)
                            for d in tf.split(final_outputs, step+1, axis=1)]
                        final_output = final_output_list[-1]

                        # if self.model_config.npad_mode == 'static_seq':
                        #     final_output = tf.matmul(final_output, npad_w)

                        last_logit_list = self.output_to_logit(final_output, w, b)
                        last_target_list = tf.argmax(last_logit_list, output_type=tf.int32, axis=-1)
                        decoder_logit_list.append(last_logit_list)
                        decoder_target_list.append(last_target_list)
                        decoder_embed_inputs_list.append(
                            tf.stop_gradient(self.embedding_fn(last_target_list, emb_simple)))
                        if train_mode == 'static_self-critical':
                            last_sample_list = tf.multinomial(last_logit_list, 1)
                            sample_target_list.append(last_sample_list)
                            indices = tf.stack(
                                [tf.range(0, self.model_config.batch_size, dtype=tf.int64),
                                 tf.squeeze(last_sample_list)],
                                axis=-1)
                            sample_logit_list.append(tf.gather_nd(tf.nn.softmax(last_logit_list), indices))
            else:
                # Beam Search
                print('Use Beam Search with Beam Search Size %d.' % self.model_config.beam_search_size)
                return self.transformer_beam_search(encoder_outputs, encoder_attn_bias, encoder_embed_inputs_list,
                                                    sentence_complex_input_placeholder, emb_simple, w, b,
                                                    rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                                                    score, obj, obj_tensors)

        gt_target_list = sentence_simple_input_placeholder
        output = ModelOutput(
            contexts=cur_context if 'rule' in self.model_config.memory else None,
            encoder_outputs=encoder_outputs,
            decoder_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None,
            final_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None,
            decoder_logit_list=decoder_logit_list if train_mode != 'dynamic_self-critical' else None,
            gt_target_list=gt_target_list,
            encoder_embed_inputs_list=tf.unstack(encoder_embed_inputs, axis=1),
            decoder_target_list=decoder_target_list,
            sample_logit_list=sampled_logit_list if train_mode == 'dynamic_self-critical' else None,
            sample_target_list=sampled_target_list if train_mode == 'dynamic_self-critical' else None,
            pred_score_tuple=pred_score_tuple if 'pred' in self.model_config.tune_mode else None,
            obj_tensors=obj_tensors,
        )
        return output
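
Examples #7 through #9 all build the encoder attention bias from a padding mask via common_attention.attention_bias_ignore_padding. A minimal sketch of that step, assuming TensorFlow 1.x and tensor2tensor; the token ids are illustrative, with 0 standing in for the pad symbol:

import tensorflow as tf
from tensor2tensor.layers import common_attention

token_ids = tf.constant([[5, 7, 0, 0]])           # [batch_size, length], 0 = PAD
pad_mask = tf.to_float(tf.equal(token_ids, 0))    # 1.0 at padded positions
attn_bias = common_attention.attention_bias_ignore_padding(pad_mask)
# attn_bias has shape [batch_size, 1, 1, length] with a large negative value at
# padded positions, so the attention softmax gives them ~zero weight.
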
Example #8
    def create_model_cui(self):
        assert self.model_config.extra_loss
        self.global_step_cui = tf.get_variable('global_step_cui',
                                               initializer=tf.constant(
                                                   0, dtype=tf.int64),
                                               trainable=False)

        with tf.variable_scope('cui'):
            # Semantic type embedding
            if 'stype' in self.model_config.extra_loss:
                self.stype_embs = tf.get_variable(
                    'stype_embs',
                    [len(self.data.id2stype), self.model_config.dimension],
                    tf.float32,
                    initializer=tf.contrib.layers.xavier_initializer())

            abbr_inp = tf.zeros(self.model_config.batch_size,
                                tf.int32,
                                name='abbr_input')
            sense_inp = tf.zeros(self.model_config.batch_size,
                                 tf.int32,
                                 name='sense_input')

            inputs = []
            # Make sure extra_loss is defined even when the definition loss is disabled.
            extra_loss = tf.constant(0.0)
            if 'def' in self.model_config.extra_loss:
                defs = []
                for _ in range(self.model_config.max_def_len):
                    defs.append(
                        tf.zeros(self.model_config.batch_size,
                                 tf.int32,
                                 name='def_input'))

                defs_stack = tf.stack(defs, axis=1)
                defs_embed = embedding_fn(defs_stack, self.embs)
                defs_bias = common_attention.attention_bias_ignore_padding(
                    tf.to_float(
                        tf.equal(defs_stack,
                                 self.data.voc.encode(constant.PAD))))
                defs_embed = tf.nn.dropout(
                    defs_embed,
                    1.0 - self.hparams.layer_prepostprocess_dropout)
                # defs_output = transformer.transformer_encoder(
                #     defs_embed, defs_bias, self.hparams)
                defs_output, def_extra_output = universal_transformer_util.universal_transformer_encoder(
                    defs_embed, defs_bias, self.hparams)
                def_enc_ponder_times, def_enc_remainders = def_extra_output
                extra_loss = (
                    self.hparams.act_loss_weight *
                    tf.reduce_mean(def_enc_ponder_times + def_enc_remainders))
                defs_output = tf.reduce_mean(defs_output, axis=1)
                inputs.append(defs_output)

            if 'stype' in self.model_config.extra_loss:
                stype_inp = tf.zeros(self.model_config.batch_size,
                                     tf.int32,
                                     name='stype_input')
                style_emb = embedding_fn(stype_inp, self.stype_embs)
                inputs.append(style_emb)

            if len(inputs) > 1:
                inputs = tf.concat(inputs, axis=1)
            elif len(inputs) == 1:
                inputs = inputs[0]
            aggregate_state = tf.contrib.layers.fully_connected(
                inputs, self.model_config.dimension, activation_fn=None)
            mask = tf.nn.embedding_lookup(self.mask_embs, abbr_inp)
            logits = self.get_logits(aggregate_state, mask)

            self.loss_cui = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=sense_inp) + extra_loss

            with tf.variable_scope('cui_optimization'):
                optim = get_optim(self.model_config)
                self.perplexity_cui = tf.exp(tf.reduce_mean(self.loss_cui))
                self.train_op_cui = optim.minimize(self.loss_cui)
                self.increment_global_step_cui = tf.assign_add(
                    self.global_step_cui, 1)

            self.obj_cui = {'abbr_inp': abbr_inp, 'sense_inp': sense_inp}

            if 'def' in self.model_config.extra_loss:
                self.obj_cui['def'] = defs

            if 'stype' in self.model_config.extra_loss:
                self.obj_cui['stype'] = stype_inp
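
In the example above, the definition encoder's output is mean-pooled over time and the ACT ponder cost is added on top of the sense-classification loss. A stripped-down sketch of that final combination, assuming TensorFlow 1.x; the shapes and stand-in names are illustrative:

import tensorflow as tf

defs_output = tf.zeros([8, 30, 512])            # [batch_size, max_def_len, dimension] from the UT encoder
pooled = tf.reduce_mean(defs_output, axis=1)    # [batch_size, dimension] definition representation
logits = tf.layers.dense(pooled, 1000)          # stand-in for the example's get_logits(aggregate_state, mask)
sense_inp = tf.zeros([8], dtype=tf.int32)       # gold sense ids
extra_loss = tf.constant(0.03)                  # ACT ponder cost, computed as in the example

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=sense_inp) + extra_loss   # per-example loss plus the shared ponder cost
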
Example #9
    def context_encoder(self, contexts_emb, contexts, abbr_inp_emb=None):
        """

        :param contexts_emb: a tensor of [batch_size, max_context_len, emb_dim]
        :param contexts: a list of [max_context_len, batch_size]
        :param abbr_inp_emb: a tensor of [batch_size, context_len, emb_dim], in transformer_abbr_encoder
        :return:
            encoder_output: [batch_size, context_len, channel_dim]
            weights: a list of multihead weights, num_layer elements,
                     each of which is [batch_size, num_head, context_len, context_len]
            extra_loss: None
        """
        weights = {}
        # Create a bias tensor as mask (large negative values at padded positions);
        # input: [batch_size, context_len], output: [batch_size, 1, 1, context_len]
        contexts_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(contexts, axis=1),
                         self.voc.encode(constant.PAD))))
        # add dropout to context input [batch_size, max_context_len, emb_dim]
        contexts_emb = tf.nn.dropout(
            contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
        # get the output vector of transformer, [batch_size, context_len, channel_dim]
        # encoder_ouput = transformer.transformer_encoder_abbr(
        #     contexts_emb, contexts_bias, abbr_inp_emb,
        #     tf.zeros([self.model_config.batch_size,1,1,1]), self.hparams,
        #     save_weights_to=weights)
        if self.model_config.encoder_mode == 't2t':
            encoder_output = transformer.transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            extra_loss = None
        elif self.model_config.encoder_mode == 'ut2t':
            encoder_output, extra_output = universal_transformer_util.universal_transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            enc_ponder_times, enc_remainders = extra_output
            extra_loss = (self.hparams.act_loss_weight *
                          tf.reduce_mean(enc_ponder_times + enc_remainders))
        elif self.model_config.encoder_mode == 'abbr_ut2t':
            encoder_output, extra_output = universal_transformer_util.universal_transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            enc_ponder_times, enc_remainders = extra_output
            extra_loss = (self.hparams.act_loss_weight *
                          tf.reduce_mean(enc_ponder_times + enc_remainders))

            encoder_output2, extra_output2 = universal_transformer_util.universal_transformer_decoder(
                abbr_inp_emb, encoder_output,
                tf.zeros([self.model_config.batch_size, 1, 1, 1]),
                contexts_bias, self.hparams)
            enc_ponder_times2, enc_remainders2 = extra_output2
            extra_loss2 = (self.hparams.act_loss_weight *
                           tf.reduce_mean(enc_ponder_times2 + enc_remainders2))
            extra_loss += extra_loss2

        else:
            raise ValueError('Unknown encoder_mode.')

        return encoder_output, weights, extra_loss
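
A hypothetical call into context_encoder, assuming "model" is an instance of the class above with encoder_mode set to 't2t', 'ut2t', or 'abbr_ut2t', and that contexts_emb, contexts, and abbr_inp_emb were built as described in the docstring; the names here are illustrative.

# contexts_emb: [batch_size, max_context_len, emb_dim] looked-up embeddings
# contexts:     list of max_context_len tensors, each [batch_size], holding token ids
encoder_output, weights, extra_loss = model.context_encoder(
    contexts_emb, contexts, abbr_inp_emb=abbr_inp_emb)
# encoder_output: [batch_size, context_len, channel_dim]
# weights: attention weights captured via save_weights_to
# extra_loss: None for 't2t'; otherwise the ACT ponder cost (plus the decoder's
#             ponder cost when encoder_mode == 'abbr_ut2t')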