Example #1
def transformer_text_encoder(inputs, target_space, hparams, name=None):
    """Transformer text encoder over inputs with unmasked full attention.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size].
    target_space: int. Used for encoding inputs under a target space id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
    with tf.variable_scope(name, default_name="transformer_text_encoder"):
        inputs = common_layers.flatten4d3d(inputs)
        (encoder_input, encoder_self_attention_bias,
         ed) = transformer.transformer_prepare_encoder(
             inputs, target_space=target_space, hparams=hparams)
        encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
        encoder_output = transformer.transformer_encoder(
            encoder_input, encoder_self_attention_bias, hparams)
        return encoder_output, ed
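A minimal usage sketch for the helper above (not part of the original repository): it assumes Tensor2Tensor on TF 1.x is installed, that transformer_text_encoder as defined above is in scope together with its imports, and it borrows transformer_base() only as a stand-in hyperparameter set; the target space id 0 is a placeholder.

import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()            # stand-in hparams (hidden_size=512)
inputs = tf.zeros([8, 20, 1, hparams.hidden_size])  # [batch, length, 1, hidden_size]
encoder_output, ed = transformer_text_encoder(
    inputs, target_space=0, hparams=hparams)
# encoder_output: [8, 20, hidden_size]; ed: [8, 1, 1, 20] bias for padded tokens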
def transformer_encoder_ht(inputs,
                           target_space,
                           hparams,
                           features=None,
                           losses=None):
    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(inputs,
                                                target_space,
                                                hparams,
                                                features=features))

    # encoder_input = tf.nn.dropout(encoder_input,
    #                               1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output = transformer.transformer_encoder(
        encoder_input,
        self_attention_bias,
        hparams,
        # nonpadding=transformer.features_to_nonpadding(features, "inputs"),
        nonpadding=None,
        save_weights_to=None,
        losses=losses)

    # encoder_output = tf.expand_dims(encoder_output, 2)

    return encoder_output
Example #3
def transformer_text_encoder(x,
                             space_id,
                             hparams,
                             name="transformer_text_encoder"):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, 1, hparams.hidden_size].
    space_id: int, id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name):
    x = common_layers.flatten4d3d(x)
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed
Example #4
    def encode(self, encoder_input, target_space, hparams):
        dir_path = os.path.dirname(os.path.realpath(__file__))
        config_file = os.path.join(dir_path, "config.yml")
        with open(config_file) as f:
            config = yaml.safe_load(f)  # safe_load: plain config data, no arbitrary YAML tags
        enc_name = config["model_params"].split('_')[0][3:]

        if enc_name == "simple":
            encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias = transformer.transformer_prepare_encoder(
                encoder_input, target_space, hparams)
            encoder_input = tf.nn.dropout(
                encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
            encoder_output = transformer.transformer_encoder(
                encoder_input, encoder_self_attention_bias, hparams)
        else:
            encoder_input, encoder_self_attention_bias_slices, encoder_decoder_attention_bias_slices = parallel_transformer_prepare_encoder(
                encoder_input, target_space, hparams)
            encoder_input = tf.nn.dropout(
                encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
            encoder_output = getattr(encode_fn, enc_name)(
                encoder_input, encoder_self_attention_bias_slices, hparams,
                "encoder")
            encoder_decoder_attention_bias = tf.stack(
                encoder_decoder_attention_bias_slices)
            encoder_decoder_attention_bias = tf.reduce_mean(
                encoder_decoder_attention_bias, 0)
        return encoder_output, encoder_decoder_attention_bias
Example #5
  def encode(self, stories, questions, target_space, hparams,
             unused_features=None):
    """Encode transformer inputs.

    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparmeters for model.
      unused_features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encodre-decoder attention. [batch_size, input_length]
    """

    inputs = tf.concat([stories, questions], axis=1)
    # inputs = common_layers.flatten4d3d(inputs)

    (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, target_space, hparams))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output = transformer.transformer_encoder(encoder_input,
      encoder_self_attention_bias, hparams,
      # nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights)

    return encoder_output
Example #6
def transformer_text_encoder(x,
                             space_id,
                             hparams,
                             name="transformer_text_encoder"):
    """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, 1, hparams.hidden_size].
    space_id: int, id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
    with tf.variable_scope(name):
        x = common_layers.flatten4d3d(x)
        (encoder_input, encoder_self_attention_bias,
         ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
        encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
        return transformer.transformer_encoder(encoder_input,
                                               encoder_self_attention_bias,
                                               hparams), ed
def encode(x, x_space, hparams, name):
  """Transformer preparations and encoder."""
  with tf.variable_scope(name):
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, x_space, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed
def encode(x, x_space, hparams, name):
  """Transformer preparations and encoder."""
  with tf.variable_scope(name):
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, x_space, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed
    def encode_syntax_template(self, template_embs, template_bias):
        with tf.variable_scope('syntax_encoder', reuse=tf.AUTO_REUSE):
            # template_mask = tf.cast(
            #     tf.equal(template_ids[:, 0, :], self.data.vocab.pad_id), tf.float32)
            # template_bias = common_attention.attention_bias_ignore_padding(template_mask)
            # template_embs = self._embedding_fn(
            #     template_ids, self.shared_tensors['syntax_embedding_table'])
            template_outputs = transformer.transformer_encoder(
                template_embs, template_bias, self.hparams)
        return template_outputs, template_bias
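The commented-out lines above correspond to the usual recipe for turning padded token ids into the bias tensor that transformer_encoder expects. A standalone sketch of just that step (the token ids and pad id 0 below are illustrative placeholders):

import tensorflow as tf
from tensor2tensor.layers import common_attention

ids = tf.constant([[5, 7, 9, 0, 0],
                   [3, 4, 0, 0, 0]])                  # 0 == pad id (placeholder)
padding = tf.to_float(tf.equal(ids, 0))               # 1.0 at padded positions
bias = common_attention.attention_bias_ignore_padding(padding)
# bias: [batch, 1, 1, length] with large negative values at padded positions,
# passed as encoder_self_attention_bias to transformer.transformer_encoder.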
Example #10
def create_t2t_transformer_encoder(
    x_in: "tf.Tensor",
    mask: "tf.Tensor",
    attention_weights: Dict[Text, "tf.Tensor"],
    hparams: "HParams",
    C2: float,
    is_training: "tf.Tensor",
) -> "tf.Tensor":
    """Create t2t transformer encoder."""

    with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE):
        x = create_tf_fnn(
            x_in,
            [hparams.hidden_size],
            hparams.layer_prepostprocess_dropout,
            C2,
            is_training,
            layer_name_suffix="pre_embed",
            activation=None,
            use_bias=False,
            kernel_initializer=tf.random_normal_initializer(
                0.0, hparams.hidden_size**-0.5),
        )
        if hparams.multiply_embedding_mode == "sqrt_depth":
            x *= hparams.hidden_size**0.5

        x *= tf.expand_dims(mask, -1)
        (
            x,
            self_attention_bias,
            encoder_decoder_attention_bias,
        ) = transformer_prepare_encoder(x, None, hparams)

        x *= tf.expand_dims(mask, -1)

        x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)

        attn_bias_for_padding = None
        # Otherwise the encoder will just use encoder_self_attention_bias.
        if hparams.unidirectional_encoder:
            attn_bias_for_padding = encoder_decoder_attention_bias

        x = transformer_encoder(
            x,
            self_attention_bias,
            hparams,
            nonpadding=mask,
            save_weights_to=attention_weights,
            attn_bias_for_padding=attn_bias_for_padding,
        )

        x *= tf.expand_dims(mask, -1)

        return tf.nn.dropout(tf.nn.relu(x),
                             1.0 - hparams.layer_prepostprocess_dropout)
Example #11
def encoder(name, hparams, inputs, target_space):
    """Compute encoder outputs and attention bias."""
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = (transformer_prepare_encoder(
             inputs, target_space, hparams))
        encoder_input = tf.nn.dropout(
            encoder_input, rate=hparams.layer_prepostprocess_dropout)
        encoder_output = transformer_encoder(encoder_input,
                                             encoder_self_attention_bias,
                                             hparams)
        return encoder_output, encoder_decoder_attention_bias
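A side note on the two tf.nn.dropout conventions mixed across these examples: most pass a keep probability (1.0 - p) as the legacy second positional argument, while the example above uses the newer rate= keyword available from TF 1.13 on. Under TF 1.x the two calls below behave identically:

import tensorflow as tf

p = 0.1                               # e.g. hparams.layer_prepostprocess_dropout
x = tf.zeros([2, 4, 8])
y_legacy = tf.nn.dropout(x, 1.0 - p)  # legacy keep_prob positional argument
y_rate = tf.nn.dropout(x, rate=p)     # rate= keyword, same dropout probability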
Example #12
def transformer_text_encoder(inputs,
                             space_id,
                             hparams,
                             name="transformer_text_enc"):
    """Transformer text encoder."""
    with tf.variable_scope(name):
        x = common_layers.flatten4d3d(inputs)
        (encoder_input, encoder_self_attention_bias,
         ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
        encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
        return transformer.transformer_encoder(encoder_input,
                                               encoder_self_attention_bias,
                                               hparams), ed
def transformer_encoder(features,
                        hparams,
                        embed_scope=None,
                        embed_token_fn=common_embed.embed_tokens,
                        attention_weights=None):
    """Encodes a screen using Transformer.

  Args:
    features: the feature dict.
    hparams: the hyperparameter.
    embed_scope: the scope for token embedding.
    embed_token_fn: the embed function.
    attention_weights: the attention_weights dict.
  Returns:
    encoder_outputs: a Tensor of shape
        [batch_size, num_steps, max_object_count, hidden_size]
    encoder_attn_bias: A tensor of shape
        [batch_size, num_steps, max_object_count]
  """
    tf.logging.info("Using Transformer screen encoder")
    # Remove the default positional encoding in Transformer
    object_embed, object_mask, encoder_attn_bias = prepare_encoder_input(
        features=features,
        hparams=hparams,
        embed_scope=embed_scope,
        embed_token_fn=embed_token_fn)
    with tf.variable_scope("encode_screen", reuse=tf.AUTO_REUSE):
        shape = tf.shape(object_embed)
        with tf.control_dependencies(
            [tf.assert_equal(shape[3], hparams.hidden_size)]):
            object_embed = tf.reshape(
                object_embed,
                [shape[0] * shape[1], shape[2], hparams.hidden_size])
        encoder_input = tf.nn.dropout(object_embed,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        self_attention_bias = tf.expand_dims(tf.expand_dims(tf.reshape(
            encoder_attn_bias, [shape[0] * shape[1], shape[2]]),
                                                            axis=1),
                                             axis=1)
        encoder_output = transformer.transformer_encoder(
            encoder_input=encoder_input,
            encoder_self_attention_bias=self_attention_bias,
            hparams=hparams,
            save_weights_to=attention_weights,
            make_image_summary=not common_layers.is_xla_compiled())
        encoder_output = tf.reshape(encoder_output,
                                    [shape[0], shape[1], shape[2], shape[3]])
        return encoder_output, object_mask, encoder_attn_bias
Example #14
def te_encode(input_seq, hparams, target_space, features, name):
    input_seq = common_layers.flatten4d3d(input_seq)

    (encoder_input, encoder_self_attention_bias, _) = (
        transformer_prepare_encoder(input_seq, target_space, hparams))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer_encoder(
        encoder_input,
        encoder_self_attention_bias,
        hparams,
        nonpadding=features_to_nonpadding(features, "input_seq"))
    encoder_output = tf.expand_dims(encoder_output, 2)
    return encoder_output
  def encode(self, features, input_key):
    hparams = self._hparams
    inputs = common_layers.flatten4d3d(features[input_key])

    (encoder_input, encoder_self_attention_bias, _) = (
        transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
                                                hparams))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input,
        encoder_self_attention_bias,
        hparams,
        nonpadding=transformer.features_to_nonpadding(features, input_key))

    encoder_output = tf.reduce_mean(encoder_output, axis=1)

    return encoder_output
  def encode(self, features, input_key):
    hparams = self._hparams
    inputs = common_layers.flatten4d3d(features[input_key])

    (encoder_input, encoder_self_attention_bias, _) = (
        transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
                                                hparams))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input,
        encoder_self_attention_bias,
        hparams,
        nonpadding=transformer.features_to_nonpadding(features, input_key))

    encoder_output = tf.reduce_mean(encoder_output, axis=1)

    return encoder_output
Example #17
def sim_encode(inputs, target_space, hparams, features):
    # inputs = tf.Print(inputs, [tf.shape(inputs)], "input", summarize=10)
    inputs = common_layers.flatten4d3d(inputs)

    (encoder_input, encoder_self_attention_bias,
     _) = (transformer.transformer_prepare_encoder(inputs, target_space,
                                                   hparams))
    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input,
        encoder_self_attention_bias,
        hparams,
        nonpadding=transformer.features_to_nonpadding(features, "inputs"))

    positional_mean = tf.nn.l2_normalize(tf.reduce_mean(encoder_output, 1), 1)
    # out_norm = tf.norm(positional_mean)
    # positional_mean = tf.Print(positional_mean , [out_norm], "enc_out: (should be b_size**0.5) ", summarize=10)
    # positional_mean = tf.Print(positional_mean , [tf.shape(positional_mean)], "enc_out: (should be (b_size, h_size)) ", summarize=10)
    return positional_mean
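Because sim_encode L2-normalizes the mean-pooled encoder output, comparing two encoded batches reduces to a dot product. A hypothetical follow-up sketch, with inputs_a, inputs_b, target_space, hparams, and features as caller-supplied placeholders; the shared variable scope is an assumption to keep both calls on the same weights:

with tf.variable_scope("sim_encoder", reuse=tf.AUTO_REUSE):
    emb_a = sim_encode(inputs_a, target_space, hparams, features)  # [batch, hidden_size]
with tf.variable_scope("sim_encoder", reuse=tf.AUTO_REUSE):
    emb_b = sim_encode(inputs_b, target_space, hparams, features)  # shared weights
cosine_similarity = tf.reduce_sum(emb_a * emb_b, axis=-1)          # in [-1, 1], shape [batch]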
Example #18
    def forward(self, contexts_emb, contexts, abbr_inp_emb, longform_emb=None):
        """
        :param contexts_emb: [batch_size, context_len, emb_dim]
        :param contexts: a list of tensors of words, [batch_size] * context_len
        :param abbr_inp_emb: [batch_size, 1, emb_dim]
        :param longform_emb: [batch_size, longform_len, emb_dim]
        :return:
               decoder_output: predicted abbr embedding, [batch_size, 1, emb_dim]
        """
        saved_weights = {}
        extra_loss = None

        contexts_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(contexts, axis=1),
                         self.voc.encode(constant.PAD))))

        contexts_emb = tf.nn.dropout(
            contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
        abbr_inp_emb = tf.nn.dropout(
            abbr_inp_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)

        # [batch_size, context_len, emb_dim]
        encoder_output = transformer.transformer_encoder(
            contexts_emb,
            contexts_bias,
            hparams=self.hparams,
            save_weights_to=saved_weights)

        # [batch_size, 1, emb_dim]
        decoder_output = transformer.transformer_decoder(
            abbr_inp_emb,
            encoder_output,
            decoder_self_attention_bias=tf.zeros(
                [self.model_config.batch_size, 1, 1, 1]),
            encoder_decoder_attention_bias=contexts_bias,
            hparams=self.hparams,
            save_weights_to=saved_weights)

        return decoder_output, saved_weights, extra_loss
Example #19
    def body(self, features):
        hparams = self._hparams
        inputs = features["inputs"]
        target_space = features["target_space_id"]

        inputs = common_layers.flatten4d3d(inputs)

        (encoder_input, encoder_self_attention_bias,
         _) = (transformer.transformer_prepare_encoder(inputs, target_space,
                                                       hparams))

        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = transformer.transformer_encoder(
            encoder_input,
            encoder_self_attention_bias,
            hparams,
            nonpadding=transformer.features_to_nonpadding(features, "inputs"))

        encoder_output = encoder_output[:, :1, :]
        encoder_output = tf.expand_dims(encoder_output, 2)

        return encoder_output
Example #20
def transformer_text_encoder(x,
                             space_id,
                             hparams,
                             name="transformer_text_encoder"):
    """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, hidden_dim].
    space_id: int, id.
    hparams: Dict, hyperparameters.
    name: string, variable scope.

  Returns:
    x: Tensor of shape [batch, length, hidden_dim].
    ed: Tensor, bias for padded tokens in the input, shape [batch, length]
  """
    with tf.variable_scope(name):
        x = common_layers.flatten4d3d(x)
        (encoder_input, encoder_self_attention_bias,
         ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
        encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
        return transformer.transformer_encoder(encoder_input,
                                               encoder_self_attention_bias,
                                               hparams), ed
Example #21
def vae_transformer_internal(inputs, targets, target_space, hparams):
    """VAE Transformer, main step used for training."""
    with tf.variable_scope("vae_transformer"):
        is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
        # Prepare inputs, targets, and k.
        inputs = common_layers.flatten4d3d(inputs)
        targets = common_layers.flatten4d3d(targets)
        k = 2**hparams.num_compress_steps
        _, targets = common_layers.pad_to_same_length(
            inputs, targets, final_length_divisible_by=k)

        # Transformer preparations and encoder.
        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias
         ) = transformer.transformer_prepare_encoder(inputs, target_space,
                                                     hparams)
        residual_fn = transformer.get_residual_fn(hparams)
        encoder_input = tf.nn.dropout(encoder_input,
                                      1.0 - hparams.residual_dropout)
        encoder_output = transformer.transformer_encoder(
            encoder_input, residual_fn, encoder_self_attention_bias, hparams)

        def get_decoder_autoregressive():
            """Decoder input for autoregressive computation."""
            (a, b) = transformer.transformer_prepare_decoder(targets, hparams)
            return (a, b, tf.constant(0.0))

        # 10% of the time we compress all-zeros, as will be at decoding start.
        prob_targets = 0.9 if is_training else 1.0
        to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                              lambda: targets, lambda: tf.zeros_like(targets))
        z, kl_loss = compress_vae(to_compress, hparams, "vae")
        # Decompress.
        for i in xrange(hparams.num_compress_steps):
            j = hparams.num_hidden_layers - i - 1
            z = decompress(z, hparams, "decompress_%d" % j)

        def get_decoder_from_vae():
            """Decoder input computed by VAE."""
            # Return decoder stuff.
            (a, b) = transformer.transformer_prepare_decoder(
                tf.squeeze(z, axis=2), hparams)
            return (a, b, kl_loss)

        # Randomize decoder inputs.
        prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7
        step = tf.to_float(tf.contrib.framework.get_global_step())
        if not is_training:
            prob_do_vae = tf.cond(tf.less(step,
                                          40000.0), lambda: tf.constant(0.0),
                                  lambda: tf.constant(1.0))
        (decoder_input, decoder_self_attention_bias,
         kl_loss2) = tf.cond(tf.less(tf.random_uniform([]), prob_do_vae),
                             get_decoder_from_vae, get_decoder_autoregressive)

        # Transformer decoder.
        decoder_output = transformer.transformer_decoder(
            decoder_input, encoder_output, residual_fn,
            decoder_self_attention_bias, encoder_decoder_attention_bias,
            hparams)
        decoder_output = tf.expand_dims(decoder_output, 2)

        cond_self = tf.cond(tf.less(step, 30000.0), lambda: tf.constant(1.0),
                            lambda: tf.constant(0.0))
        prob_self = 0.4 if is_training else cond_self
        (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]),
                                         prob_self), lambda: (z, kl_loss),
                                 lambda: (decoder_output, kl_loss2))

        kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0
        return ret, kl_loss
    def build(self, features):
        src_ids = features['src_ids']
        trg_ids = None
        self.batch_size = tf.shape(src_ids)[0]
        if self.is_training:
            trg_ids = features['trg_ids']

        with tf.variable_scope('src_encoder'):
            self.shared_tensors['src_ids'] = src_ids
            src_mask = tf.cast(tf.equal(src_ids, self.data.vocab.pad_id),
                               tf.float32)
            src_bias = common_attention.attention_bias_ignore_padding(src_mask)
            self.shared_tensors['src_bias'] = src_bias
            self.shared_tensors['src_mask'] = src_mask

            src_embs = self._embedding_fn(src_ids)
            src_embs = common_attention.add_timing_signal_1d(src_embs)

            if 'syntax_gen' in self.flags.control_mode:
                template_comp_ids = features['template_comp_ids']

                # print_op = tf.print("template_comp_ids output:", template_comp_ids)
                # with tf.control_dependencies([print_op]):
                #     template_comp_ids = tf.identity(template_comp_ids)

                template_embs = self._embedding_fn(
                    template_comp_ids,
                    self.shared_tensors['syntax_embedding_table'])
                template_scale = tf.get_variable(
                    'template_scale',
                    shape=[1, self.flags.syntax_level, 1, 1],
                    trainable=True,
                    dtype=tf.float32)
                template_embs *= template_scale
                template_embs = tf.reduce_mean(template_embs, axis=1)
                src_embs += template_embs

            if 'gpt2' in self.flags.model_mode:
                src_outputs = model.gpt2_encoder(self.hparams,
                                                 src_embs,
                                                 encoder_bias=src_bias)
            elif 't2t' in self.flags.model_mode:
                src_outputs = transformer.transformer_encoder(
                    src_embs, src_bias, self.hparams)
            elif 'bert' in self.flags.model_mode:
                bert_model = BertModel(
                    config=BertConfig.from_json_file(
                        self.flags.bert_config_file),
                    is_training=self.is_training,
                    input_ids=src_ids,
                    input_mask=1.0 - src_mask,
                    embeddings=self.shared_tensors['word_embedding_table'])
                src_outputs = bert_model.get_sequence_output()
            else:
                raise ValueError('model_mode not known.')

            self.shared_tensors['src_outputs'] = src_outputs

            if self.flags.control_mode:
                control_ids = features['control_ids']
                control_mask = tf.cast(
                    tf.equal(control_ids, self.data.vocab.pad_id), tf.float32)
                control_bias = common_attention.attention_bias_ignore_padding(
                    control_mask)
                control_embs = self._embedding_fn(control_ids)

                if 'gpt2' in self.flags.model_mode:
                    control_outputs = model.gpt2_encoder(
                        self.hparams, control_embs, encoder_bias=control_bias)
                elif 't2t' in self.flags.model_mode or 'bert' in self.flags.model_mode:
                    control_outputs = transformer.transformer_encoder(
                        control_embs,
                        control_bias,
                        self.hparams,
                        name='control_encoder')
                else:
                    raise ValueError('model_mode not known.')
                self.shared_tensors['control_vec'] = features['control_vec']
                self.shared_tensors['control_outputs'] = control_outputs
                self.shared_tensors['control_bias'] = control_bias
                self.shared_tensors['extra_vec'] = features['extra_vec']

            # if 'syntax_gen' in self.flags.control_mode:
            #     template_comp_ids = features['template_comp_ids']
            #     template_comp_outputs, template_comp_bias = self.encode_syntax_template(template_comp_ids)
            #     self.shared_tensors['template_comp_outputs'] = template_comp_outputs
            #     self.shared_tensors['template_comp_bias'] = template_comp_bias

        batch_go = tf.tile(
            tf.expand_dims(self._embedding_fn(self.data.vocab.go_id), axis=0),
            [self.batch_size, 1])
        batch_go_id = tf.tile(
            tf.constant(self.data.vocab.go_id, tf.int32, shape=[
                1,
            ]), [self.batch_size])
        self.shared_tensors['batch_go'] = batch_go
        self.shared_tensors['batch_go_id'] = batch_go_id

        batch_syntax_go = tf.tile(
            tf.expand_dims(self._embedding_fn(self.data.syntax_vocab.go_id),
                           axis=0), [self.batch_size, 1])
        batch_syntax_go_id = tf.tile(
            tf.constant(self.data.syntax_vocab.go_id, tf.int32, shape=[
                1,
            ]), [self.batch_size])
        self.shared_tensors['batch_syntax_go'] = batch_syntax_go
        self.shared_tensors['batch_syntax_go_id'] = batch_syntax_go_id

        outputs = {}
        outputs['src_ids'] = src_ids

        if self.flags.control_mode:
            outputs["control_vec"] = self.shared_tensors['control_vec']
        # if 'predict' in self.flags.control_mode:
        #     control_vec, outputs = self.classify(
        #         outputs,
        #         self.shared_tensors['control_vec'],
        #         "fix_predict" in self.flags.control_mode)
        #     self.shared_tensors['control_vec'] = control_vec
        if self.flags.control_mode:
            if "flatten" not in self.flags.control_mode:
                # print_op = tf.print("Debug output:", self.shared_tensors['control_vec'])
                # with tf.control_dependencies([print_op]):
                #     self.shared_tensors['control_vec'] = tf.identity(self.shared_tensors['control_vec'])

                duplicate_copies = self.flags.dimension // self.data.control_vec_len
                batch_size = self.flags.train_batch_size if self.is_training else self.flags.eval_batch_size
                control_vec = tf.concat([
                    tf.reshape(
                        tf.transpose(
                            tf.tile(
                                tf.expand_dims(
                                    self.shared_tensors['control_vec'][o, :],
                                    axis=0), [duplicate_copies, 1])),
                        [1, self.flags.dimension]) for o in range(batch_size)
                ], axis=0)
                more_control_vec = tf.zeros([
                    batch_size,
                    self.flags.dimension % self.data.control_vec_len
                ])
                if not self.is_training and self.flags.beam_search_size > 1:
                    more_control_vec = tf.zeros([
                        batch_size * self.flags.beam_search_size,
                        self.flags.dimension % self.data.control_vec_len
                    ])
                self.shared_tensors['control_vec'] = tf.concat(
                    [control_vec, more_control_vec], axis=1)
            else:
                score = tf.expand_dims(self.shared_tensors['control_vec'],
                                       axis=-1)
                score = tf.tile(score, [1, 1, self.flags.dimension])
                self.shared_tensors['control_vec'] = score
        if "encoder" in self.flags.control_mode:
            src_outputs = self.update_embedding(src_outputs, False)
            self.shared_tensors['src_outputs'] = src_outputs

        with tf.variable_scope("trg_decoder"):
            if self.is_training:
                # Generate syntax
                if 'syntax_gen' in self.flags.control_mode:
                    syntax_losses = []
                    template_simp_ids = features['template_simp_ids']

                    # print_op = tf.print("template_simp_ids output:", template_simp_ids)
                    # with tf.control_dependencies([print_op]):
                    #     template_simp_ids = tf.identity(template_simp_ids)

                    template_simp_ids_layers = tf.unstack(template_simp_ids,
                                                          axis=1)
                    for l_id in range(self.flags.syntax_level):
                        template_simp_ids_layer = template_simp_ids_layers[
                            l_id]

                        # print_op = tf.print("template_simp_ids_layer %s output:" % l_id, template_simp_ids_layer)
                        # with tf.control_dependencies([print_op]):
                        #     template_simp_ids_layer = tf.identity(template_simp_ids_layer)

                        template_simp_ids_layer_list = tf.unstack(
                            template_simp_ids_layer, axis=1)
                        template_simp_ids_layer_inp_list = [
                            batch_syntax_go_id
                        ] + template_simp_ids_layer_list[:-1]
                        template_simp_emb_list = self._embedding_fn(
                            template_simp_ids_layer_inp_list,
                            self.shared_tensors['syntax_embedding_table'])
                        template_simp_emb = tf.stack(template_simp_emb_list,
                                                     axis=1)

                        template_mask = tf.cast(
                            tf.equal(template_simp_ids_layers[0],
                                     self.data.vocab.pad_id), tf.float32)
                        template_bias = common_attention.attention_bias_ignore_padding(
                            template_mask)

                        if l_id == 0:
                            self.shared_tensors[
                                'template_prev_simp_outputs'] = None
                            self.shared_tensors['template_simp_bias'] = None
                        else:
                            template_simp_prev_ids_layers = template_simp_ids_layers[:l_id]
                            template_simp_prev_ids = tf.stack(
                                template_simp_prev_ids_layers, axis=1)
                            template_simp_prev_embs = self._embedding_fn(
                                template_simp_prev_ids,
                                self.shared_tensors['syntax_embedding_table'])
                            cur_template_scale = template_scale[:, :l_id, :, :]
                            template_simp_prev_embs *= cur_template_scale
                            template_simp_prev_embs = tf.reduce_mean(
                                template_simp_prev_embs, axis=1)
                            template_simp_outputs, template_simp_bias = self.encode_syntax_template(
                                template_simp_prev_embs, template_bias)
                            self.shared_tensors[
                                'template_prev_simp_outputs'] = template_simp_outputs
                            self.shared_tensors[
                                'template_simp_bias'] = template_simp_bias

                        syntax_outputs = self.decode_syntax_template(
                            template_simp_emb)

                        syntax_logits = tf.nn.conv1d(
                            syntax_outputs,
                            tf.expand_dims(
                                self.shared_tensors['proj_syntax_w'], axis=0),
                            1, 'SAME') + tf.expand_dims(tf.expand_dims(
                                self.shared_tensors['proj_syntax_b'], axis=0),
                                                        axis=0)
                        # syntax_gen = tf.argmax(syntax_logits, axis=-1)
                        syntax_weight = tf.cast(
                            tf.not_equal(template_simp_ids_layer,
                                         self.data.syntax_vocab.pad_id),
                            tf.float32)
                        syntax_loss = sequence_loss(
                            logits=syntax_logits,
                            targets=template_simp_ids_layer,
                            weights=syntax_weight)
                        syntax_losses.append(syntax_loss)

                    outputs['loss_syntax'] = tf.add_n(syntax_losses)
                    outputs['perplexity_syntax'] = tf.exp(
                        outputs['loss_syntax'])
                    tf.summary.scalar("loss_syntax", outputs['loss_syntax'])
                    tf.summary.scalar("perplexity_syntax",
                                      outputs['perplexity_syntax'])

                    template_simp_prev_ids_layers = template_simp_ids_layers
                    template_simp_prev_ids = tf.stack(
                        template_simp_prev_ids_layers, axis=1)
                    template_simp_prev_embs = self._embedding_fn(
                        template_simp_prev_ids,
                        self.shared_tensors['syntax_embedding_table'])
                    cur_template_scale = template_scale
                    template_simp_prev_embs *= cur_template_scale
                    template_simp_prev_embs = tf.reduce_mean(
                        template_simp_prev_embs, axis=1)
                    template_simp_outputs, template_simp_bias = self.encode_syntax_template(
                        template_simp_prev_embs, template_bias)
                    self.shared_tensors[
                        'template_simp_outputs'] = template_simp_outputs
                    self.shared_tensors[
                        'template_simp_bias'] = template_simp_bias

                # Generate sentence
                trg_ids_list = tf.unstack(trg_ids, axis=1)
                trg_input_ids_list = [batch_go_id] + trg_ids_list[:-1]
                trg_emb_list = self._embedding_fn(trg_input_ids_list)
                trg_input_ids = tf.stack(trg_input_ids_list, axis=1)
                trg_output_ids = tf.stack(trg_ids_list, axis=1)
                trg_emb = tf.stack(trg_emb_list, axis=1)

                decoder_outputs = self.decode_srcs_to_trgs(
                    trg_emb=trg_emb,
                    trg_input_ids=trg_input_ids,
                    outputs=outputs)
                word_logits = tf.nn.conv1d(
                    decoder_outputs,
                    tf.expand_dims(self.shared_tensors['proj_word_w'], axis=0),
                    1, 'SAME') + tf.expand_dims(tf.expand_dims(
                        self.shared_tensors['proj_word_b'], axis=0),
                                                axis=0)
                word_gen = tf.argmax(word_logits, axis=-1)
                outputs['gen'] = word_gen
                outputs['logits'] = word_logits

                weight = tf.cast(
                    tf.not_equal(trg_output_ids, self.data.vocab.pad_id),
                    tf.float32)
                loss = sequence_loss(logits=word_logits,
                                     targets=trg_output_ids,
                                     weights=weight)
                outputs['loss_decoder'] = loss
                outputs['perplexity_decoder'] = tf.exp(loss)
                tf.summary.scalar("loss_decoder", outputs['loss_decoder'])
                tf.summary.scalar("perplexity_decoder",
                                  outputs['perplexity_decoder'])
                # if 'predict' in self.flags.control_mode:
                #     # outputs['loss_length'] = outputs['loss_length']
                #     # outputs['loss_syntax'] = outputs['loss_syntax']
                #     # outputs['loss'] += outputs['loss_split']
                #     outputs["loss_pred"] = outputs['loss_length'] + outputs['loss_syntax'] + outputs['loss_split']
                #     tf.summary.scalar("loss_length", outputs['loss_length'])
                #     tf.summary.scalar("loss_syntax", outputs['loss_syntax'])
                #     tf.summary.scalar("loss_split", outputs['loss_split'])

            else:
                outputs['gen_src_syntax_ids'] = features['template_comp_ids']
                confident_scores = []
                self._tile_variables()

                if 'syntax_gen' in self.flags.control_mode:

                    def symbol_to_syntax_logits_fn(gen_ids):
                        cur_ids = tf.concat([
                            tf.expand_dims(batch_syntax_go_id, axis=-1),
                            gen_ids[:, 1:]
                        ],
                                            axis=1)
                        cur_embs = tf.nn.embedding_lookup(
                            self.shared_tensors['syntax_embedding_table'],
                            cur_ids)
                        cur_outputs = self.decode_syntax_template(cur_embs)
                        cur_logit = tf.matmul(
                            cur_outputs[:, -1, :],
                            self.shared_tensors['proj_syntax_w']
                        ) + self.shared_tensors['proj_syntax_b']
                        return cur_logit

                    template_simp_prev_ids_layers = []
                    for l_id in range(self.flags.syntax_level):
                        if l_id == 0:
                            self.shared_tensors[
                                'template_prev_simp_outputs'] = None
                            self.shared_tensors['template_simp_bias'] = None
                        else:
                            template_simp_prev_ids = tf.stack(
                                template_simp_prev_ids_layers, axis=1)
                            template_simp_prev_embs = self._embedding_fn(
                                template_simp_prev_ids,
                                self.shared_tensors['syntax_embedding_table'])
                            cur_template_scale = template_scale[:, :l_id, :, :]
                            template_simp_prev_embs *= cur_template_scale
                            template_simp_prev_embs = tf.reduce_mean(
                                template_simp_prev_embs, axis=1)

                            template_mask = tf.cast(
                                tf.equal(template_simp_prev_ids_layers[-1],
                                         self.data.vocab.pad_id), tf.float32)
                            template_bias = common_attention.attention_bias_ignore_padding(
                                template_mask)

                            template_simp_outputs, template_simp_bias = self.encode_syntax_template(
                                template_simp_prev_embs, template_bias)
                            self.shared_tensors[
                                'template_prev_simp_outputs'] = template_simp_outputs
                            self.shared_tensors[
                                'template_simp_bias'] = template_simp_bias

                        beam_ids, beam_score = beam_search.beam_search(
                            symbols_to_logits_fn=symbol_to_syntax_logits_fn,
                            initial_ids=tf.ones([self.flags.eval_batch_size],
                                                tf.int32) *
                            self.data.syntax_vocab.go_id,
                            beam_size=self.flags.beam_search_size,
                            decode_length=self.flags.max_syntax_trg_len,
                            vocab_size=self.data.syntax_vocab.size(),
                            alpha=0.6,
                            eos_id=self.data.syntax_vocab.eos_id)
                        top_beam_ids = beam_ids[:, 0, 1:]
                        top_beam_ids = tf.pad(
                            top_beam_ids,
                            [[0, 0],
                             [
                                 0, self.flags.max_syntax_trg_len -
                                 tf.shape(top_beam_ids)[1]
                             ]])
                        confident_score = -beam_score[:, 0] / tf.to_float(
                            tf.shape(top_beam_ids)[1])

                        confident_scores.append(confident_score)
                        # outputs['gen_src_syntax_ids'] = features['template_comp_ids']
                        # outputs['gen_trg_syntax_ids'] = top_beam_ids
                        # outputs['gen_trg_syntax_scores'] = confident_score
                        template_simp_prev_ids_layers.append(top_beam_ids)

                    template_simp_prev_ids = tf.stack(
                        template_simp_prev_ids_layers, axis=1)
                    outputs['gen_trg_syntax_ids'] = template_simp_prev_ids
                    outputs['gen_trg_syntax_scores'] = tf.add_n(
                        confident_scores)
                    template_simp_prev_embs = self._embedding_fn(
                        template_simp_prev_ids,
                        self.shared_tensors['syntax_embedding_table'])
                    template_simp_prev_embs *= template_scale
                    template_simp_prev_embs = tf.reduce_mean(
                        template_simp_prev_embs, axis=1)

                    template_mask = tf.cast(
                        tf.equal(template_simp_prev_ids_layers[-1],
                                 self.data.vocab.pad_id), tf.float32)
                    template_bias = common_attention.attention_bias_ignore_padding(
                        template_mask)
                    template_simp_outputs, template_simp_bias = self.encode_syntax_template(
                        template_simp_prev_embs, template_bias)
                    self.shared_tensors[
                        'template_simp_outputs'] = template_simp_outputs
                    self.shared_tensors[
                        'template_simp_bias'] = template_simp_bias

                def symbol_to_logits_fn(gen_ids):
                    cur_ids = tf.concat(
                        [tf.expand_dims(batch_go_id, axis=-1), gen_ids[:, 1:]],
                        axis=1)
                    cur_embs = tf.nn.embedding_lookup(
                        self.shared_tensors['word_embedding_table'], cur_ids)
                    cur_outputs = self.decode_srcs_to_trgs(
                        trg_emb=cur_embs, trg_input_ids=cur_ids)
                    cur_logit = tf.matmul(
                        cur_outputs[:, -1, :],
                        self.shared_tensors['proj_word_w']
                    ) + self.shared_tensors['proj_word_b']
                    return cur_logit

                beam_ids, beam_score = beam_search.beam_search(
                    symbols_to_logits_fn=symbol_to_logits_fn,
                    initial_ids=tf.ones([self.flags.eval_batch_size],
                                        tf.int32) * self.data.vocab.go_id,
                    beam_size=self.flags.beam_search_size,
                    decode_length=self.flags.max_trg_len,
                    vocab_size=self.data.vocab.size() +
                    len(self.data.vocab.more_tokens),
                    alpha=0.6,
                    eos_id=self.data.vocab.eos_id)
                top_beam_ids = beam_ids[:, 0, 1:]
                top_beam_ids = tf.pad(
                    top_beam_ids,
                    [[0, 0],
                     [0, self.flags.max_trg_len - tf.shape(top_beam_ids)[1]]])
                confident_score = -beam_score[:, 0] / tf.to_float(
                    tf.shape(top_beam_ids)[1])
                outputs['gen_trg_ids'] = top_beam_ids
                outputs['gen_trg_scores'] = confident_score
                if self.flags.control_mode:
                    outputs['control_ids'] = features['control_ids']

        return outputs
Example #23
def encode_decode_task(features, hparams, train, attention_weights=None):
    """Model core graph for the one-shot action.

  Args:
    features: a dictionary contains "inputs" that is a tensor in shape of
        [batch_size, num_tokens], "verb_id_seq" that is in shape of
        [batch_size, num_actions], "object_spans" and "param_span" tensor
        in shape of [batch_size, num_actions, 2]. 0 is used as padding or
        non-existent values.
    hparams: the general hyperparameters for the model.
    train: the train mode.
    attention_weights: the dict to keep attention weights for analysis.
  Returns:
    loss_dict: the losses for training.
    prediction_dict: the predictions for action tuples.
    areas: the area encodings of the task.
    scope: the embedding scope.
  """
    del train
    input_embeddings, scope = common_embed.embed_tokens(
        features["task"], hparams.task_vocab_size, hparams.hidden_size,
        hparams)
    with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE):
        encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0)
        input_embeddings = tf.multiply(tf.expand_dims(encoder_nonpadding, 2),
                                       input_embeddings)
        encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
            transformer.transformer_prepare_encoder(input_embeddings,
                                                    None,
                                                    hparams,
                                                    features=None))
        encoder_input = tf.nn.dropout(encoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_encoder == "transformer":
            encoder_output = transformer.transformer_encoder(
                encoder_input,
                self_attention_bias,
                hparams,
                save_weights_to=attention_weights,
                make_image_summary=not common_layers.is_xla_compiled())
        else:
            raise ValueError("Unsupported instruction encoder %s" %
                             (hparams.instruction_encoder))
        span_rep = hparams.get("span_rep", "area")
        area_encodings, area_starts, area_ends = area_utils.compute_sum_image(
            encoder_output, max_area_width=hparams.max_span)
        current_shape = tf.shape(area_encodings)
        if span_rep == "area":
            area_encodings, _, _ = area_utils.compute_sum_image(
                encoder_output, max_area_width=hparams.max_span)
        elif span_rep == "basic":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=False)
        elif span_rep == "coref":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=True)
        else:
            raise ValueError("xyz")
        areas = {}
        areas["encodings"] = area_encodings
        areas["starts"] = area_starts
        areas["ends"] = area_ends
        with tf.control_dependencies([
                tf.print("encoder_output", tf.shape(encoder_output)),
                tf.assert_equal(current_shape,
                                tf.shape(area_encodings),
                                summarize=100)
        ]):
            paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32)
        padding_sum, _, _ = area_utils.compute_sum_image(
            tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2),
            max_area_width=hparams.max_span)
        num_areas = common_layers.shape_list(area_encodings)[1]
        area_paddings = tf.reshape(tf.minimum(tf.to_float(padding_sum), 1.0),
                                   [-1, num_areas])
        areas["bias"] = area_paddings
        decoder_nonpadding = tf.to_float(
            tf.greater(features["verb_refs"][:, :, 1],
                       features["verb_refs"][:, :, 0]))
        if hparams.instruction_encoder == "lstm":
            hparams_decoder = copy.copy(hparams)
            hparams_decoder.set_hparam("pos", "none")
        else:
            hparams_decoder = hparams
        decoder_input, decoder_self_attention_bias = _prepare_decoder_input(
            area_encodings,
            decoder_nonpadding,
            features,
            hparams_decoder,
            embed_scope=scope)
        decoder_input = tf.nn.dropout(decoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_decoder == "transformer":
            decoder_output = transformer.transformer_decoder(
                decoder_input=decoder_input,
                encoder_output=encoder_output,
                decoder_self_attention_bias=decoder_self_attention_bias,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                hparams=hparams_decoder)
        else:
            raise ValueError("Unsupported instruction encoder %s" %
                             (hparams.instruction_encoder))
        return decoder_output, decoder_nonpadding, areas, scope
Example #24
    def create_model(self):
        with tf.variable_scope('variables'):
            contexts = []
            for _ in range(self.model_config.max_context_len):
                contexts.append(
                    tf.zeros(self.model_config.batch_size,
                             tf.int32,
                             name='context_input'))

            sense_inps, abbr_sinps, abbr_einps = [], [], []
            for _ in range(self.model_config.max_abbrs):
                sense_inps.append(
                    tf.zeros(self.model_config.batch_size,
                             tf.int32,
                             name='sense_input'))
                abbr_sinps.append(
                    tf.zeros([self.model_config.batch_size],
                             tf.int32,
                             name='sense__sinput'))
                abbr_einps.append(
                    tf.zeros([self.model_config.batch_size],
                             tf.int32,
                             name='sense_einput'))

            num_abbr = tf.zeros([self.model_config.batch_size],
                                tf.float32,
                                name='num_abbr')

        with tf.variable_scope('model'):
            contexts_emb = tf.stack(self.embedding_fn(contexts, self.embs),
                                    axis=1)
            contexts_emb_bias = common_attention.attention_bias_ignore_padding(
                tf.to_float(
                    tf.equal(tf.stack(contexts, axis=1),
                             self.data.voc.encode(constant.PAD))))
            contexts_emb = tf.nn.dropout(
                contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
            encoder_outputs = transformer.transformer_encoder(
                contexts_emb, contexts_emb_bias, self.hparams)

            if self.model_config.aggregate_mode == 'selfattn':
                selfattn_w = tf.get_variable(
                    'selfattn_w', [1, self.model_config.dimension, 1],
                    tf.float32,
                    initializer=tf.contrib.layers.xavier_initializer())
                selfattn_b = tf.get_variable(
                    'selfattn_b', [1, 1, 1],
                    tf.float32,
                    initializer=tf.contrib.layers.xavier_initializer())
                weight = tf.nn.tanh(
                    tf.nn.conv1d(encoder_outputs, selfattn_w, 1, 'SAME') +
                    selfattn_b)
                encoder_outputs *= weight
                aggregate_state = tf.reduce_mean(encoder_outputs, axis=1)
            else:
                aggregate_state = tf.reduce_mean(encoder_outputs, axis=1)

        with tf.variable_scope('pred'):
            proj_w = tf.get_variable(
                'proj_w', [self.model_config.dimension, self.data.sen_cnt],
                tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())
            proj_b = tf.get_variable(
                'proj_b', [self.data.sen_cnt],
                tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())

            losses = []
            preds = []
            for abbr_id in range(self.model_config.max_abbrs):
                abbr_sinp = abbr_sinps[abbr_id]
                abbr_einp = abbr_einps[abbr_id]
                sense_inp = sense_inps[abbr_id]

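                # The difference of two sequence masks keeps only the logit slots
                # in [abbr_sinp, abbr_einp), i.e. the candidate senses of this
                # abbreviation; everything outside that range is zeroed below.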
                mask = tf.to_float(
                    tf.sequence_mask(
                        abbr_einp, self.data.sen_cnt)) - tf.to_float(
                            tf.sequence_mask(abbr_sinp, self.data.sen_cnt))
                logits = tf.matmul(aggregate_state, proj_w) + proj_b
                logits *= mask

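                # Cross-entropy against the gold sense id; loss_mask (presumably
                # sense id 0 marks an unused abbreviation slot) zeroes the loss
                # for padded slots.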
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=logits, labels=sense_inp)
                loss_mask = tf.to_float(tf.not_equal(sense_inp, 0))
                loss *= loss_mask

                losses.append(loss)
                preds.append(tf.nn.top_k(logits, k=5, sorted=True)[1])

        preds = tf.stack(preds, axis=1)
        obj = {
            'contexts': contexts,
            'sense_inp': sense_inps,
            'abbr_sinp': abbr_sinps,
            'abbr_einp': abbr_einps,
            'num_abbr': num_abbr,
            'preds': preds,
        }
        return tf.add_n(losses) / num_abbr, obj
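
A minimal sketch of the encoder call pattern these examples share: embed the tokens, build a padding bias, apply dropout, and run transformer.transformer_encoder. It assumes TensorFlow 1.x and tensor2tensor are importable; encode_tokens, token_ids, embedding_table and pad_id are illustrative names, not taken from the examples above.

import tensorflow as tf
from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer


def encode_tokens(token_ids, embedding_table, hparams, pad_id=0):
    """token_ids: [batch, length] int32 -> encoder output [batch, length, hidden_size]."""
    emb = tf.gather(embedding_table, token_ids)  # [batch, length, hidden_size]
    # Bias is a large negative value at PAD positions, 0 elsewhere: [batch, 1, 1, length].
    bias = common_attention.attention_bias_ignore_padding(
        tf.to_float(tf.equal(token_ids, pad_id)))
    emb = tf.nn.dropout(emb, 1.0 - hparams.layer_prepostprocess_dropout)
    return transformer.transformer_encoder(emb, bias, hparams)
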
Example #25
    def context_encoder(self, contexts_emb, contexts, abbr_inp_emb=None):
        """

        :param contexts_emb: a tensor of [batch_size, max_context_len, emb_dim]
        :param contexts: a list of [max_context_len, batch_size]
        :param abbr_inp_emb: a tensor of [batch_size, context_len, emb_dim], in transformer_abbr_encoder
        :return:
            encoder_output: [batch_size, context_len, channel_dim]
            weights: a list of multihead weights, num_layer elements,
                     each of which is [batch_size, num_head, context_len, context_len]
            extra_loss: None for the plain 't2t' encoder, otherwise a scalar ACT loss
        """
        weights = {}
        # Create a bias tensor as a mask (large negative values at padded positions):
        # input=[batch_size, context_len], output=[batch_size, 1, 1, context_len]
        contexts_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(contexts, axis=1),
                         self.voc.encode(constant.PAD))))
        # add dropout to context input [batch_size, max_context_len, emb_dim]
        contexts_emb = tf.nn.dropout(
            contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
        # get the output vector of transformer, [batch_size, context_len, channel_dim]
        # encoder_ouput = transformer.transformer_encoder_abbr(
        #     contexts_emb, contexts_bias, abbr_inp_emb,
        #     tf.zeros([self.model_config.batch_size,1,1,1]), self.hparams,
        #     save_weights_to=weights)
        if self.model_config.encoder_mode == 't2t':
            encoder_output = transformer.transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            extra_loss = None
        elif self.model_config.encoder_mode == 'ut2t':
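            # The Universal Transformer encoder also returns ACT statistics
            # (ponder times and remainders); their mean, scaled by
            # act_loss_weight, is added as an auxiliary loss.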
            encoder_output, extra_output = universal_transformer_util.universal_transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            enc_ponder_times, enc_remainders = extra_output
            extra_loss = (self.hparams.act_loss_weight *
                          tf.reduce_mean(enc_ponder_times + enc_remainders))
        elif self.model_config.encoder_mode == 'abbr_ut2t':
            encoder_output, extra_output = universal_transformer_util.universal_transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            enc_ponder_times, enc_remainders = extra_output
            extra_loss = (self.hparams.act_loss_weight *
                          tf.reduce_mean(enc_ponder_times + enc_remainders))

            encoder_output2, extra_output2 = universal_transformer_util.universal_transformer_decoder(
                abbr_inp_emb, encoder_output,
                tf.zeros([self.model_config.batch_size, 1, 1, 1]),
                contexts_bias, self.hparams)
            enc_ponder_times2, enc_remainders2 = extra_output2
            extra_loss2 = (self.hparams.act_loss_weight *
                           tf.reduce_mean(enc_ponder_times2 + enc_remainders2))
            extra_loss += extra_loss2

        else:
            raise ValueError('Unknown encoder_mode.')

        return encoder_output, weights, extra_loss
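
The 'ut2t' and 'abbr_ut2t' branches above follow the same Adaptive Computation Time (ACT) pattern. Below is a minimal sketch of that pattern, assuming tensor2tensor's universal_transformer_util is importable and hparams uses an ACT recurrence (so the extra output is a (ponder_times, remainders) pair); ut_encode_with_act_loss is an illustrative helper name, not part of the examples.

import tensorflow as tf
from tensor2tensor.models.research import universal_transformer_util


def ut_encode_with_act_loss(inputs, attention_bias, hparams):
    """Runs the Universal Transformer encoder and returns its ACT penalty."""
    encoder_output, (ponder_times, remainders) = (
        universal_transformer_util.universal_transformer_encoder(
            inputs, attention_bias, hparams))
    act_loss = hparams.act_loss_weight * tf.reduce_mean(ponder_times + remainders)
    return encoder_output, act_loss
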
def hierarchical_context_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 contexts,
                                 context_self_attention_biases,
                                 features,
                                 hparams,
                                 name="discourse_aware_encoder",
                                 save_weights_to=None,
                                 make_image_summary=True,
                                 losses=None):
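    # Discourse-aware encoding: run a shared (N-1)-layer encoder over the input
    # and over each context, pool every context into one vector via attentive
    # sum, re-encode the pooled contexts with a 1-layer encoder, then fuse them
    # into the final encoder layer through cross-attention and a gated sum.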
    input_x = encoder_input
    context_xs = {}
    for context_name in contexts:
        context_xs[context_name] = contexts[context_name]
    context_paddings = {}
    context_nonpaddings = {}
    context_pad_removers = {}

    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        input_padding = common_attention.attention_bias_to_padding(
            encoder_self_attention_bias)
        input_nonpadding = 1.0 - input_padding
        for context_name in context_self_attention_biases:
            context_paddings[
                context_name] = common_attention.attention_bias_to_padding(
                    context_self_attention_biases[context_name])
            context_nonpaddings[
                context_name] = 1.0 - context_paddings[context_name]

        input_pad_remover = None
        for context_name in context_paddings:
            context_pad_removers[context_name] = None
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            input_pad_remover = expert_utils.PadRemover(input_padding)
            for context_name in context_paddings:
                context_pad_removers[context_name] = expert_utils.PadRemover(
                    context_paddings[context_name])

        temp_hparam = tf.contrib.training.HParams(
        )  # copy hparams except num_hidden_layers -> num_hidden_layers - 1
        for key, val in hparams.values().items():
            temp_hparam.add_hparam(key, val)
        temp_hparam.set_hparam("num_hidden_layers",
                               hparams.num_hidden_layers - 1)
        encoder_output = transformer_with_contexts_layers.transformer_encoder(
            input_x,
            encoder_self_attention_bias,
            temp_hparam,
            nonpadding=features_to_nonpadding(features, "inputs"),
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary)

        context_encoded_outputs = {}
        for context_name in context_xs:
            context_encoded_outputs[
                context_name] = transformer_with_contexts_layers.transformer_encoder(
                    context_xs[context_name],
                    context_self_attention_biases[context_name],
                    temp_hparam,
                    nonpadding=features_to_nonpadding(features, context_name),
                    save_weights_to=save_weights_to,
                    make_image_summary=make_image_summary)

        with tf.variable_scope("hierarchical_context_encoder",
                               reuse=tf.AUTO_REUSE):
            for context_name in context_encoded_outputs:
                # self attention feed-forward
                _y = ffn_self_attention_layer(
                    context_encoded_outputs[context_name],
                    hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    save_weights_to=save_weights_to,
                    name="attentive_sum")
                # mean over sequence length
                context_encoded_outputs[context_name] = tf.reduce_mean(
                    _y, axis=1, keep_dims=True)

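            # Each context is now a single pooled vector; concatenating them
            # along the length axis yields one short "context sequence".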
            encoded_contexts = [
                context_encoded_outputs[context_name]
                for context_name in context_encoded_outputs
            ]
            encoded_contexts = tf.concat(encoded_contexts, axis=1)

            temp_hparam = tf.contrib.training.HParams(
            )  # copy hparams except num_hidden_layers -> 1
            for key, val in hparams.values().items():
                temp_hparam.add_hparam(key, val)
            temp_hparam.set_hparam("num_hidden_layers", 1)
            context_padding = common_attention.embedding_to_padding(
                encoded_contexts)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                context_padding)

            encoded_contexts = transformer_encoder(encoded_contexts,
                                                   ignore_padding, temp_hparam)

        with tf.variable_scope("encoder/layer_%d" % hparams.num_hidden_layers,
                               reuse=tf.AUTO_REUSE):
            with tf.variable_scope("context_input_attention"):
                context_padding = common_attention.embedding_to_padding(
                    encoded_contexts)
                ignore_padding = common_attention.attention_bias_ignore_padding(
                    context_padding)
                _y = common_attention.multihead_attention(
                    common_layers.layer_preprocess(encoder_output, hparams),
                    encoded_contexts,
                    ignore_padding,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    attention_type=hparams.self_attention_type,
                    save_weights_to=save_weights_to,
                    make_image_summary=make_image_summary,
                    max_relative_position=hparams.max_relative_position,
                    dropout_broadcast_dims=attention_dropout_broadcast_dims,
                    max_length=hparams.get("max_length"),
                    vars_3d=hparams.get("attention_variables_3d"))
                encoded_contexts = common_layers.layer_postprocess(
                    encoder_output, _y, hparams)

            with tf.variable_scope("input_self_attention"):
                _y = common_attention.multihead_attention(
                    common_layers.layer_preprocess(encoder_output, hparams),
                    None,
                    encoder_self_attention_bias,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    attention_type=hparams.self_attention_type,
                    save_weights_to=save_weights_to,
                    max_relative_position=hparams.max_relative_position,
                    make_image_summary=make_image_summary,
                    dropout_broadcast_dims=attention_dropout_broadcast_dims,
                    max_length=hparams.get("max_length"),
                    vars_3d=hparams.get("attention_variables_3d"))
                encoder_output = common_layers.layer_postprocess(
                    encoder_output, _y, hparams)

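            # Gated sum: a sigmoid gate computed from [contexts; input] blends
            # the context-attended and self-attended representations per channel.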
            with tf.variable_scope("gated_sum"):
                _depth = common_layers.shape_list(encoder_output)[-1]
                gate = tf.layers.dense(tf.concat(
                    [encoded_contexts, encoder_output], axis=-1),
                                       _depth,
                                       activation=tf.nn.sigmoid)
                if save_weights_to:
                    save_weights_to["gated_sum"] = gate
                encoder_output = gate * encoder_output + (
                    1. - gate) * encoded_contexts

            with tf.variable_scope("ffn"):
                _y = transformer_ffn_layer(common_layers.layer_preprocess(
                    encoder_output, hparams),
                                           hparams,
                                           input_pad_remover,
                                           conv_padding="SAME",
                                           nonpadding_mask=input_nonpadding,
                                           losses=losses)
                encoder_output = common_layers.layer_postprocess(
                    encoder_output, _y, hparams)

    return common_layers.layer_preprocess(encoder_output, hparams)
    def create_model(self):
        with tf.variable_scope('variables'):
            abstr_ph = []
            for _ in range(self.model_config.max_abstr_len):
                abstr_ph.append(tf.zeros(self.model_config.batch_size, tf.int32, name='abstract_input'))

            kwords_ph = []
            for _ in range(self.model_config.max_cnt_kword):
                kword = []
                for _ in range(self.model_config.max_kword_len):
                    kword.append(tf.zeros(self.model_config.batch_size, tf.int32, name='kword_input'))
                kwords_ph.append(kword)

            # Train for length control
            if self.is_train:
                kword_occupies_ph = []
                for _ in range(self.model_config.max_cnt_kword):
                    kword_occupies_ph.append(
                        tf.zeros(self.model_config.batch_size, tf.float32, name='kword_occupy_input'))

            emb_abstr, emb_kword, proj_w, proj_b = self.get_embedding()
            abstr = tf.stack(self.embedding_fn(abstr_ph, emb_abstr), axis=1)
            kwords = []
            for kword_idx in range(self.model_config.max_cnt_kword):
                kwords.append(self.embedding_fn(kwords_ph[kword_idx], emb_kword))

        with tf.variable_scope('model_encoder'):
            if self.hparams.pos == 'timing':
                abstr = common_attention.add_timing_signal_1d(abstr)
            encoder_embed_inputs = tf.nn.dropout(abstr,
                                                 1.0 - self.hparams.layer_prepostprocess_dropout)
            abstr_bias = common_attention.attention_bias_ignore_padding(
                tf.to_float(tf.equal(tf.stack(abstr_ph, axis=1),
                                     self.voc_kword.encode(constant.SYMBOL_PAD))))
            abstr_outputs = transformer.transformer_encoder(
                encoder_embed_inputs, abstr_bias, self.hparams)

        losses = []
        targets = []
        pred_occupies = []
        obj = {}

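        # hist_vector accumulates the mean embedding of the keywords generated
        # so far, serving as a coverage-style signal for the decoder and the
        # length predictor when cov_mode includes 'kp_attn'.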
        hist_vector = None
        if 'kp_attn' in self.model_config.cov_mode:
            hist_vector = tf.zeros(
                [self.model_config.batch_size, 1, self.model_config.dimension,])

        with tf.variable_scope('model_decoder'):
            if self.model_config.subword_vocab_size:
                go_id = self.voc_kword.encode(constant.SYMBOL_GO)[0]
            else:
                go_id = self.voc_kword.encode(constant.SYMBOL_GO)
            batch_go = tf.tile(
                tf.expand_dims(self.embedding_fn(go_id, emb_kword), axis=0),
                [self.model_config.batch_size, 1])

            for kword_idx in range(self.model_config.max_cnt_kword):
                if self.is_train:
                    kword = kwords[kword_idx][:-1]
                    kword_ph = kwords_ph[kword_idx]
                    kword_output, kword_output_list = self.decode_step(
                        kword, abstr_outputs, abstr_bias, batch_go, hist_vector=hist_vector)
                    kword_logit_list = [self.output_to_logit(o, proj_w, proj_b) for o in kword_output_list]
                    kword_target_list = [tf.argmax(o, output_type=tf.int32, axis=-1)
                                         for o in kword_logit_list]

                    kword_lossbias = [
                        tf.to_float(tf.not_equal(d, self.voc_kword.encode(constant.SYMBOL_PAD)))
                        for d in kword_ph]
                    kword_lossbias = tf.stack(kword_lossbias, axis=1)
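                    # With number_samples > 0, sampled softmax approximates the
                    # full softmax over the keyword vocabulary inside sequence_loss.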
                    if self.model_config.number_samples > 0:
                        loss_fn = tf.nn.sampled_softmax_loss
                    else:
                        loss_fn = None
                    loss = sequence_loss(logits=tf.stack(kword_logit_list, axis=1),
                                         targets=tf.stack(kword_ph, axis=1),
                                         weights=kword_lossbias,
                                         softmax_loss_function=loss_fn,
                                         w=proj_w,
                                         b=proj_b,
                                         decoder_outputs=tf.stack(kword_output_list, axis=1),
                                         number_samples=self.model_config.number_samples
                                         )
                    kword_target = tf.stack(kword_target_list, axis=1)
                    targets.append(kword_target)

                    if 'kp_attn' in self.model_config.cov_mode:
                        kword_embed = self.embedding_fn(kword_ph, emb_kword)
                        hist_vector += tf.expand_dims(tf.reduce_mean(
                            tf.stack(kword_embed, axis=1), axis=1), axis=1)

                    # Train for length control
                    pred_occupy = self.get_pred_occupy_logit(hist_vector, abstr_outputs)
                    occupy_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=pred_occupy, labels=kword_occupies_ph[kword_idx])
                    loss += tf.reduce_mean(occupy_loss)
                    pred_occupies.append(pred_occupy)

                    losses.append(loss)
                else:
                    loss, kword_target = self.transformer_beam_search(
                        abstr_outputs, abstr_bias, emb_kword, proj_w, proj_b, hist_vector=hist_vector)

                    targets.append(kword_target)
                    losses = loss

                    if 'kp_attn' in self.model_config.cov_mode:
                        kword_embed = self.embedding_fn(kword_target, emb_kword)
                        hist_vector += tf.expand_dims(tf.reduce_mean(kword_embed, axis=1), axis=1)

                    pred_occupy = tf.round(tf.sigmoid(self.get_pred_occupy_logit(hist_vector, abstr_outputs)))
                    pred_occupies.append(pred_occupy)

                tf.get_variable_scope().reuse_variables()
        if targets:
            obj['targets'] = tf.stack(targets, axis=1)
        obj['abstr_ph'] = abstr_ph
        obj['kwords_ph'] = kwords_ph
        if self.is_train:
            obj['kword_occupies_ph'] = kword_occupies_ph
        pred_occupies = tf.stack(pred_occupies, axis=1)
        obj['pred_occupies'] = pred_occupies

        if type(losses) is list:
            losses = tf.add_n(losses)
        return losses, obj
Example #28
    def create_model(self):
        with tf.variable_scope('variables'):
            abstr_ph = []
            for _ in range(self.model_config.max_abstr_len):
                abstr_ph.append(
                    tf.zeros(self.model_config.batch_size,
                             tf.int32,
                             name='abstract_input'))

            kwords_ph = []
            for _ in range(self.model_config.max_cnt_kword):
                kword = []
                for _ in range(self.model_config.max_kword_len):
                    kword.append(
                        tf.zeros(self.model_config.batch_size,
                                 tf.int32,
                                 name='kword_input'))
                kwords_ph.append(kword)

            emb_abstr, emb_kword, proj_w, proj_b = self.get_embedding()
            abstr = tf.stack(self.embedding_fn(abstr_ph, emb_abstr), axis=1)
            kwords = []
            for kword_idx in range(self.model_config.max_cnt_kword):
                kwords.append(
                    self.embedding_fn(kwords_ph[kword_idx], emb_kword))

        with tf.variable_scope('model_encoder'):
            if self.hparams.pos == 'timing':
                abstr = common_attention.add_timing_signal_1d(abstr)
            encoder_embed_inputs = tf.nn.dropout(
                abstr, 1.0 - self.hparams.layer_prepostprocess_dropout)
            abstr_bias = common_attention.attention_bias_ignore_padding(
                tf.to_float(
                    tf.equal(tf.stack(abstr_ph, axis=1),
                             self.voc_kword.encode(constant.SYMBOL_PAD))))
            abstr_outputs = transformer.transformer_encoder(
                encoder_embed_inputs, abstr_bias, self.hparams)

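            # attn_stick appears to be a per-head coverage memory of shape
            # [batch, num_heads, 1, dimension // num_heads]; generated keyword
            # embeddings are projected and added to it after every step.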
            if 'tuzhaopeng' in self.model_config.cov_mode:
                attn_stick = tf.ones([
                    self.model_config.batch_size, self.model_config.num_heads,
                    1,
                    self.model_config.dimension // self.model_config.num_heads
                ], tf.float32, 'attn_memory')

        losses = []
        targets = []
        obj = {}
        with tf.variable_scope('model_decoder'):
            for kword_idx in range(self.model_config.max_cnt_kword):
                if self.is_train:
                    kword = kwords[kword_idx][:-1]
                    kword_ph = kwords_ph[kword_idx]
                    kword_output_list, new_attn_stick = self.decode_step(
                        kword, abstr_outputs, abstr_bias, attn_stick)
                    kword_logit_list = [
                        self.output_to_logit(o, proj_w, proj_b)
                        for o in kword_output_list
                    ]
                    kword_target_list = [
                        tf.argmax(o, output_type=tf.int32, axis=-1)
                        for o in kword_logit_list
                    ]
                    attn_stick = new_attn_stick

                    if self.model_config.number_samples > 0:
                        loss_fn = tf.nn.sampled_softmax_loss
                    else:
                        loss_fn = None
                    kword_lossbias = [
                        tf.to_float(
                            tf.not_equal(
                                d, self.voc_kword.encode(constant.SYMBOL_PAD)))
                        for d in kword_ph
                    ]
                    kword_lossbias = tf.stack(kword_lossbias, axis=1)
                    loss = sequence_loss(
                        logits=tf.stack(kword_logit_list, axis=1),
                        targets=tf.stack(kword_ph, axis=1),
                        weights=kword_lossbias,
                        softmax_loss_function=loss_fn,
                        w=proj_w,
                        b=proj_b,
                        decoder_outputs=tf.stack(kword_output_list, axis=1),
                        number_samples=self.model_config.number_samples)
                    targets.append(tf.stack(kword_target_list, axis=1))

                    if 'tuzhaopeng' in self.model_config.cov_mode and 'kp_attn' in self.model_config.cov_mode:
                        target_emb = tf.stack(self.embedding_fn(
                            kword_target_list, emb_kword),
                                              axis=1)
                        target_emb = common_attention.split_heads(
                            target_emb, self.model_config.num_heads)
                        target_emb = tf.reduce_mean(target_emb, axis=2)
                        target_emb_trans = tf.get_variable(
                            'dim_weight_trans',
                            shape=[
                                1,
                                target_emb.get_shape()[-1].value,
                                target_emb.get_shape()[-1].value
                            ],
                            dtype=tf.float32,
                            initializer=tf.contrib.layers.xavier_initializer())
                        target_emb = tf.nn.conv1d(target_emb, target_emb_trans,
                                                  1, 'SAME')
                        target_emb = tf.expand_dims(target_emb, axis=2)
                        attn_stick += target_emb
                    losses.append(loss)
                else:
                    if self.model_config.beam_search_size > 0:
                        loss, target, new_attn_stick = self.transformer_beam_search(
                            abstr_outputs,
                            abstr_bias,
                            emb_kword,
                            proj_w,
                            proj_b,
                            attn_stick=attn_stick)
                    else:
                        loss, target, new_attn_stick = self.greed_search(
                            kword_idx,
                            abstr_outputs,
                            abstr_bias,
                            emb_kword,
                            proj_w,
                            proj_b,
                            attn_stick=attn_stick)
                    targets.append(target)
                    losses = loss
                    attn_stick = new_attn_stick
                    if 'tuzhaopeng' in self.model_config.cov_mode and 'kp_attn' in self.model_config.cov_mode:
                        target.set_shape([
                            self.model_config.batch_size,
                            self.model_config.max_kword_len
                        ])
                        target_list = tf.unstack(target, axis=1)
                        target_emb = tf.stack(self.embedding_fn(
                            target_list, emb_kword),
                                              axis=1)
                        target_emb = common_attention.split_heads(
                            target_emb, self.model_config.num_heads)
                        target_emb = tf.reduce_mean(target_emb, axis=2)
                        target_emb_trans = tf.get_variable(
                            'dim_weight_trans',
                            shape=[
                                1,
                                target_emb.get_shape()[-1].value,
                                target_emb.get_shape()[-1].value
                            ],
                            dtype=tf.float32,
                            initializer=tf.contrib.layers.xavier_initializer())
                        target_emb = tf.nn.conv1d(target_emb, target_emb_trans,
                                                  1, 'SAME')
                        target_emb = tf.expand_dims(target_emb, axis=2)
                        attn_stick += target_emb
                tf.get_variable_scope().reuse_variables()
        if targets:
            obj['targets'] = tf.stack(targets, axis=1)
        obj['abstr_ph'] = abstr_ph
        obj['kwords_ph'] = kwords_ph
        obj['attn_stick'] = attn_stick
        if type(losses) is list:
            losses = tf.add_n(losses)
        return losses, obj
    def transformer_fn(self,
                       sentence_complex_input_placeholder, emb_complex,
                       sentence_simple_input_placeholder, emb_simple,
                       w, b,
                       rule_id_input_placeholder, rule_target_input_placeholder,
                       mem_contexts, mem_outputs,
                       global_step, score, comp_features, obj):
        encoder_mask = tf.to_float(
            tf.equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                     self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
        encoder_attn_bias = common_attention.attention_bias_ignore_padding(encoder_mask)

        obj_tensors = {}

        train_mode = self.model_config.train_mode
        if self.model_config.bert_mode:
            # Leave space for decoder when static seq
            gpu_id = 0 if train_mode == 'static_seq' or train_mode == 'static_self-critical' or 'direct' in self.model_config.memory else 1
            with tf.device('/device:GPU:%s' % gpu_id):
                sentence_complex_input = tf.stack(sentence_complex_input_placeholder, axis=1)
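                # BERT encoder path: input_mask is 1 for real tokens and 0 for
                # PAD (hence 1.0 - encoder_mask); the BERT embedding table
                # replaces emb_complex (and, depending on tie_embedding,
                # emb_simple and the output projection).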
                bert_model = BertModel(
                    BertConfig.from_json_file(self.model_config.bert_config),
                    self.is_train, sentence_complex_input,
                    input_mask=1.0-encoder_mask, token_type_ids=None, use_one_hot_embeddings=False)
                encoder_embed_inputs = bert_model.embedding_output
                encoder_outputs = bert_model.sequence_output
                emb_complex = bert_model.embedding_table # update emb complex
                if (self.model_config.tie_embedding == 'all' or
                        self.model_config.tie_embedding == 'enc_dec'):
                    emb_simple = bert_model.embedding_table
                if (self.model_config.tie_embedding == 'all' or
                        self.model_config.tie_embedding == 'dec_out'):
                    emb_w_proj = tf.get_variable(
                        'emb_w_proj', shape=[self.model_config.dimension, self.model_config.dimension],
                        initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                    w = tf.matmul(bert_model.embedding_table, emb_w_proj)

                if 'direct' in self.model_config.memory:
                    with tf.device('/device:GPU:1'):
                        direct_mask = tf.to_float(
                            tf.equal(tf.stack(rule_target_input_placeholder, axis=1),
                                     self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
                        direct_bert_model = BertModel(
                            BertConfig.from_json_file(self.model_config.bert_config),
                            self.is_train, tf.stack(rule_target_input_placeholder, axis=1),
                            input_mask=1.0 - direct_mask, token_type_ids=None, use_one_hot_embeddings=False,
                            embedding_table=emb_simple,
                            scope='direct')
                        direct_bert_output = direct_bert_model.sequence_output
                        obj_tensors['direct_bert_bias'] = common_attention.attention_bias_ignore_padding(direct_mask)
                        obj_tensors['direct_bert_output'] = direct_bert_output
        else:
            encoder_embed_inputs = tf.stack(
                self.embedding_fn(sentence_complex_input_placeholder, emb_complex), axis=1)
            if self.hparams.pos == 'timing':
                encoder_embed_inputs = common_attention.add_timing_signal_1d(encoder_embed_inputs)
                print('Use positional encoding in encoder text.')

            if self.model_config.subword_vocab_size and self.model_config.seg_mode:
                encoder_embed_inputs = common_attention.add_positional_embedding(
                    encoder_embed_inputs, 100, 'seg_embedding',
                    positions=obj['line_comp_segids'])
                print('Add segment embedding.')

            with tf.variable_scope('transformer_encoder'):
                encoder_embed_inputs = tf.nn.dropout(encoder_embed_inputs,
                                                     1.0 - self.hparams.layer_prepostprocess_dropout)

                if self.model_config.architecture == 'ut2t':
                    encoder_outputs, encoder_extra_output = universal_transformer_util.universal_transformer_encoder(
                        encoder_embed_inputs, encoder_attn_bias, self.hparams)
                    enc_ponder_times, enc_remainders = encoder_extra_output
                    extra_encoder_loss = (
                            self.hparams.act_loss_weight *
                            tf.reduce_mean(enc_ponder_times + enc_remainders))
                    if self.is_train:
                        obj_tensors['extra_encoder_loss'] = extra_encoder_loss
                else:
                    encoder_outputs = transformer.transformer_encoder(
                        encoder_embed_inputs, encoder_attn_bias, self.hparams)

                # Update score based on multiplier
                score, pred_score_tuple = self.update_score(
                    score, encoder_outputs=encoder_outputs, encoder_mask=tf.to_float(
                        tf.not_equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                                     self.data.vocab_complex.encode(constant.SYMBOL_PAD))),
                    comp_features=comp_features)

                encoder_outputs = self.update_encoder_embedding(encoder_outputs, score)

        encoder_embed_inputs_list = tf.unstack(encoder_embed_inputs, axis=1)

        with tf.variable_scope('transformer_decoder', reuse=tf.AUTO_REUSE):
            if self.model_config.subword_vocab_size or 'bert_token' in self.model_config.bert_mode:
                go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)[0]
            else:
                go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)
            batch_go = tf.tile(
                tf.expand_dims(self.embedding_fn(go_id, emb_simple), axis=0),
                [self.model_config.batch_size, 1])

            # For static_seq train_mode
            if self.model_config.npad_mode == 'static_seq':
                with tf.variable_scope('npad'):
                    npad_w = tf.get_variable(
                        'npad_w', shape=[1, self.model_config.dimension, self.model_config.dimension],
                        initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                    obj_tensors['npad_w'] = npad_w

            if self.is_train and (train_mode == 'teacher' or
                                  train_mode == 'teachercritical' or train_mode == 'teachercriticalv2'):
                # General train
                print('Use general training process.')
                decoder_embed_inputs_list = self.embedding_fn(
                    sentence_simple_input_placeholder[:-1], emb_simple)
                final_output, decoder_output, cur_context = self.decode_step(
                    decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                    rule_id_input_placeholder, mem_contexts, mem_outputs, global_step, score, batch_go,
                    obj_tensors)

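                # Position-wise projection to vocabulary logits: a 1-D conv with
                # kernel transpose(w) plus bias b, applied at every time step.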
                decoder_logit = (
                        tf.nn.conv1d(final_output, tf.expand_dims(tf.transpose(w), axis=0), 1, 'SAME') +
                        tf.expand_dims(tf.expand_dims(b, axis=0), axis=0))
                decoder_target_list = []
                decoder_logit_list = tf.unstack(decoder_logit, axis=1)
                for logit in decoder_logit_list:
                    decoder_target_list.append(tf.argmax(logit, output_type=tf.int32, axis=-1))

                decoder_output_list = [
                    tf.squeeze(d, 1)
                    for d in tf.split(decoder_output, self.model_config.max_simple_sentence, axis=1)]
                final_output_list = [
                    tf.squeeze(d, 1)
                    for d in tf.split(final_output, self.model_config.max_simple_sentence, axis=1)]

                if self.model_config.pointer_mode:
                    segment_mask = None
                    if 'line_comp_segids' in obj:
                        segment_mask = obj['line_comp_segids']
                    decoder_logit_list = word_distribution(
                        decoder_logit_list, decoder_output_list, encoder_outputs, encoder_embed_inputs,
                        sentence_complex_input_placeholder, obj_tensors, self.model_config, self.data, segment_mask)
            elif self.is_train and (train_mode == 'static_seq' or train_mode == 'static_self-critical'):
                decoder_target_list = []
                decoder_logit_list = []
                decoder_embed_inputs_list = []
                # Overridden below for the following 3 lists
                final_output_list = []
                decoder_output_list = []
                contexts = []
                sample_target_list = []
                sample_logit_list = []

                gpu_assign_interval = int(self.model_config.max_simple_sentence / 3)
                for step in range(self.model_config.max_simple_sentence):
                    gpu_id = int(step / gpu_assign_interval)
                    if gpu_id > 3:
                        gpu_id = 3
                    gpu_id += 1
                    with tf.device('/device:GPU:%s' % gpu_id):
                        print('Step%s with GPU%s' % (step, gpu_id))
                        final_outputs, _, cur_context = self.decode_step(
                            decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                            rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                            score, batch_go, obj_tensors)

                        final_output_list = [
                            tf.squeeze(d, 1)
                            for d in tf.split(final_outputs, step+1, axis=1)]
                        final_output = final_output_list[-1]

                        # if self.model_config.npad_mode == 'static_seq':
                        #     final_output = tf.matmul(final_output, npad_w)

                        last_logit_list = self.output_to_logit(final_output, w, b)
                        last_target_list = tf.argmax(last_logit_list, output_type=tf.int32, axis=-1)
                        decoder_logit_list.append(last_logit_list)
                        decoder_target_list.append(last_target_list)
                        decoder_embed_inputs_list.append(
                            tf.stop_gradient(self.embedding_fn(last_target_list, emb_simple)))
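                        # Self-critical variant: additionally sample a token from
                        # the softmax and record its probability for the
                        # policy-gradient (REINFORCE-style) loss computed elsewhere.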
                        if train_mode == 'static_self-critical':
                            last_sample_list = tf.multinomial(last_logit_list, 1)
                            sample_target_list.append(last_sample_list)
                            indices = tf.stack(
                                [tf.range(0, self.model_config.batch_size, dtype=tf.int64),
                                 tf.squeeze(last_sample_list)],
                                axis=-1)
                            sample_logit_list.append(tf.gather_nd(tf.nn.softmax(last_logit_list), indices))
            else:
                # Beam Search
                print('Use Beam Search with Beam Search Size %d.' % self.model_config.beam_search_size)
                return self.transformer_beam_search(encoder_outputs, encoder_attn_bias, encoder_embed_inputs_list,
                                                    sentence_complex_input_placeholder, emb_simple, w, b,
                                                    rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                                                    score, obj, obj_tensors)

        gt_target_list = sentence_simple_input_placeholder
        output = ModelOutput(
            contexts=cur_context if 'rule' in self.model_config.memory else None,
            encoder_outputs=encoder_outputs,
            decoder_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None,
            final_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None,
            decoder_logit_list=decoder_logit_list if train_mode != 'dynamic_self-critical' else None,
            gt_target_list=gt_target_list,
            encoder_embed_inputs_list=tf.unstack(encoder_embed_inputs, axis=1),
            decoder_target_list=decoder_target_list,
            sample_logit_list=sample_logit_list if train_mode == 'static_self-critical' else None,
            sample_target_list=sample_target_list if train_mode == 'static_self-critical' else None,
            pred_score_tuple=pred_score_tuple if 'pred' in self.model_config.tune_mode else None,
            obj_tensors=obj_tensors,
        )
        return output
Example #30
    def transformer_fn(self, sentence_complex_input_placeholder, emb_complex,
                       sentence_simple_input_placeholder, emb_simple, w, b,
                       rule_id_input_placeholder, mem_contexts, mem_outputs,
                       global_step):
        encoder_embed_inputs = tf.stack(self.embedding_fn(
            sentence_complex_input_placeholder, emb_complex),
                                        axis=1)
        encoder_attn_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                         self.data.vocab_complex.encode(constant.SYMBOL_PAD))))
        if self.hparams.pos == 'timing':
            encoder_embed_inputs = common_attention.add_timing_signal_1d(
                encoder_embed_inputs)
            print('Use positional encoding in encoder text.')

        with tf.variable_scope('transformer_encoder'):
            encoder_embed_inputs = tf.nn.dropout(
                encoder_embed_inputs,
                1.0 - self.hparams.layer_prepostprocess_dropout)
            encoder_outputs = transformer.transformer_encoder(
                encoder_embed_inputs, encoder_attn_bias, self.hparams)

        encoder_embed_inputs_list = tf.unstack(encoder_embed_inputs, axis=1)
        with tf.variable_scope('transformer_decoder'):
            train_mode = self.model_config.train_mode
            if self.is_train and (train_mode == 'teacher'
                                  or train_mode == 'teachercritical'
                                  or train_mode == 'teachercriticalv2'):
                # General train
                print('Use general training process.')
                decoder_embed_inputs_list = self.embedding_fn(
                    sentence_simple_input_placeholder[:-1], emb_simple)
                final_output_list, decoder_output_list, cur_context = self.decode_step(
                    decoder_embed_inputs_list, encoder_outputs,
                    encoder_attn_bias, rule_id_input_placeholder, mem_contexts,
                    mem_outputs, global_step)
                decoder_logit_list = [
                    self.output_to_logit(o, w, b) for o in final_output_list
                ]
                decoder_target_list = [
                    tf.argmax(o, output_type=tf.int32, axis=-1)
                    for o in decoder_logit_list
                ]
            elif self.is_train and train_mode == 'dynamic_self-critical':
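                # Dynamic self-critical decoding: per-step greedy targets,
                # sampled tokens and their probabilities are accumulated in
                # TensorArrays inside a tf.while_loop.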
                decoder_target_tensor = tf.TensorArray(
                    tf.int32,
                    size=0,
                    dynamic_size=True,
                    clear_after_read=False,
                    element_shape=[
                        self.model_config.batch_size,
                    ])
                sampled_target_tensor = tf.TensorArray(
                    tf.int32,
                    size=0,
                    dynamic_size=True,
                    clear_after_read=False,
                    element_shape=[
                        self.model_config.batch_size,
                    ])
                sampled_logit_tensor = tf.TensorArray(
                    tf.float32,
                    size=0,
                    dynamic_size=True,
                    clear_after_read=False,
                    element_shape=[
                        self.model_config.batch_size,
                    ])

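                # Note: despite its name, _is_finished is the while_loop cond and
                # returns True while more decoding steps remain.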
                def _is_finished(step, decoder_target_tensor,
                                 sampled_target_tensor, sampled_logit_tensor):
                    return tf.less(step, self.model_config.max_simple_sentence)

                def _recursive(step, decoder_target_tensor,
                               sampled_target_tensor, sampled_logit_tensor):
                    decoder_target_stack = tf.transpose(
                        decoder_target_tensor.stack(), perm=[1, 0])

                    def get_empty_emb():
                        decoder_emb_inputs = tf.zeros([
                            self.model_config.batch_size, 1,
                            self.model_config.dimension
                        ])
                        return decoder_emb_inputs

                    def get_emb():
                        batch_go = tf.zeros([
                            self.model_config.batch_size, 1,
                            self.model_config.dimension
                        ])
                        decoder_emb_inputs = tf.concat([
                            batch_go,
                            tf.gather(emb_simple, decoder_target_stack)
                        ],
                                                       axis=1)
                        return decoder_emb_inputs

                    decoder_emb_inputs = tf.cond(tf.equal(step, 0),
                                                 lambda: get_empty_emb(),
                                                 lambda: get_emb())

                    final_outputs, _, _ = self.decode_inputs_to_outputs(
                        decoder_emb_inputs, encoder_outputs, encoder_attn_bias,
                        rule_id_input_placeholder, mem_contexts, mem_outputs,
                        global_step)
                    final_output = final_outputs[:, -1, :]
                    decoder_logit = tf.add(
                        tf.matmul(final_output, tf.transpose(w)), b)
                    decoder_target = tf.stop_gradient(
                        tf.argmax(decoder_logit, output_type=tf.int32,
                                  axis=-1))
                    sampled_target = tf.cast(
                        tf.squeeze(tf.multinomial(decoder_logit, 1), axis=1),
                        tf.int32)

                    indices = tf.stack([
                        tf.range(
                            0, self.model_config.batch_size, dtype=tf.int32),
                        tf.squeeze(sampled_target)
                    ],
                                       axis=-1)
                    logit_unit = tf.gather_nd(
                        tf.nn.softmax(decoder_logit, axis=1), indices)

                    decoder_target_tensor = decoder_target_tensor.write(
                        step, decoder_target)
                    sampled_target_tensor = sampled_target_tensor.write(
                        step, sampled_target)
                    sampled_logit_tensor = sampled_logit_tensor.write(
                        step, logit_unit)

                    return step + 1, decoder_target_tensor, sampled_target_tensor, sampled_logit_tensor

                step = tf.constant(0)
                (_, decoder_target_tensor, sampled_target_tensor,
                 sampled_logit_tensor) = tf.while_loop(
                     _is_finished,
                     _recursive, [
                         step, decoder_target_tensor, sampled_target_tensor,
                         sampled_logit_tensor
                     ],
                     back_prop=True,
                     parallel_iterations=1,
                     swap_memory=False)

                decoder_target_tensor = decoder_target_tensor.stack()
                decoder_target_tensor.set_shape([
                    self.model_config.max_simple_sentence,
                    self.model_config.batch_size
                ])
                decoder_target_tensor = tf.transpose(decoder_target_tensor,
                                                     perm=[1, 0])
                decoder_target_list = tf.unstack(decoder_target_tensor, axis=1)

                sampled_target_tensor = sampled_target_tensor.stack()
                sampled_target_tensor.set_shape([
                    self.model_config.max_simple_sentence,
                    self.model_config.batch_size
                ])
                sampled_target_tensor = tf.transpose(sampled_target_tensor,
                                                     perm=[1, 0])
                sampled_target_list = tf.unstack(sampled_target_tensor, axis=1)

                sampled_logit_tensor = sampled_logit_tensor.stack()
                sampled_logit_tensor.set_shape([
                    self.model_config.max_simple_sentence,
                    self.model_config.batch_size
                ])
                sampled_logit_tensor = tf.transpose(sampled_logit_tensor,
                                                    perm=[1, 0])
                sampled_logit_list = tf.unstack(sampled_logit_tensor, axis=1)

            else:
                # Beam Search
                print('Use Beam Search with Beam Search Size %d.' %
                      self.model_config.beam_search_size)
                return self.transformer_beam_search(
                    encoder_outputs, encoder_attn_bias,
                    encoder_embed_inputs_list,
                    sentence_complex_input_placeholder, emb_simple, w, b,
                    rule_id_input_placeholder, mem_contexts, mem_outputs,
                    global_step)

        gt_target_list = sentence_simple_input_placeholder
        output = ModelOutput(
            contexts=cur_context
            if 'rule' in self.model_config.memory else None,
            encoder_outputs=encoder_outputs,
            decoder_outputs_list=final_output_list
            if train_mode != 'dynamic_self-critical' else None,
            final_outputs_list=final_output_list
            if train_mode != 'dynamic_self-critical' else None,
            decoder_logit_list=decoder_logit_list
            if train_mode != 'dynamic_self-critical' else None,
            gt_target_list=gt_target_list,
            encoder_embed_inputs_list=tf.unstack(encoder_embed_inputs, axis=1),
            decoder_target_list=decoder_target_list,
            sample_logit_list=sampled_logit_list
            if train_mode == 'dynamic_self-critical' else None,
            sample_target_list=sampled_target_list
            if train_mode == 'dynamic_self-critical' else None)
        return output