Example #1
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None):
  """Original Transformer decoder."""
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      decoder_output = transformer.transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_bias,
          encoder_decoder_attention_bias,
          hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      # No class-label inputs here; the branch below that adds them assumes
      # hparams.drop_inputs is True (tf.reshape(None, ...) would fail).
      inputs = None
      # Reshape targets to [batch, img_len, img_len, channels * hidden_size];
      # otherwise prepare_image will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels*hparams.hidden_size])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
    # Reshape since t2t expects 4d tensors.
    decoder_output_shape = common_layers.shape_list(decoder_output)
    decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1, 1,
                                                 hparams.hidden_size])
    return decoder_output
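
# A minimal usage sketch for the "translate" branch above; assumes TF1 graph
# mode and tensor2tensor's transformer_base hparams, with illustrative shapes.
import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
batch, src_len, tgt_len = 8, 20, 15
encoder_output = tf.zeros([batch, src_len, hparams.hidden_size])
# Broadcastable to [batch, num_heads, tgt_len, src_len]; zeros = no masking.
encoder_decoder_attention_bias = tf.zeros([batch, 1, 1, src_len])
targets = tf.zeros([batch, tgt_len, 1, hparams.hidden_size])  # 4d, as t2t expects
out = decode_transformer(encoder_output, encoder_decoder_attention_bias,
                         targets, hparams, name="decoder", task="translate")
# out: [batch, tgt_len, 1, hidden_size]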
Example #2
def decode_transformer(encoder_output, encoder_decoder_attention_bias, targets,
                       hparams, name):
  """Original Transformer decoder."""
  with tf.variable_scope(name):
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_bias = (
        transformer.transformer_prepare_decoder(targets, hparams))

    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    decoder_output = transformer.transformer_decoder(
        decoder_input, encoder_output, decoder_self_bias,
        encoder_decoder_attention_bias, hparams)
    decoder_output = tf.expand_dims(decoder_output, axis=2)
    # Reshape since t2t expects 4d tensors.
    decoder_output_shape = common_layers.shape_list(decoder_output)
    decoder_output = tf.reshape(
        decoder_output, [decoder_output_shape[0], -1, 1, hparams.hidden_size])
    return decoder_output
Example #3
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None,
                       causal=True):
    """Original Transformer decoder."""
    orig_hparams = hparams
    with tf.variable_scope(name):
        if task is None:
            task = hparams.task
        if task == "translate":
            targets = common_layers.flatten4d3d(targets)

            decoder_input, decoder_self_bias = (
                transformer.transformer_prepare_decoder(targets, hparams))

            decoder_input = tf.nn.dropout(
                decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

            if not causal:
                decoder_self_bias *= 0.

            decoder_output = transformer.transformer_decoder(
                decoder_input, encoder_output, decoder_self_bias,
                encoder_decoder_attention_bias, hparams)
            decoder_output = tf.expand_dims(decoder_output, axis=2)
        else:
            assert task == "image"
            # No class-label inputs here; the branch below that adds them
            # assumes hparams.drop_inputs is True (tf.reshape(None, ...)
            # would fail).
            inputs = None
            # Reshape targets to [batch, img_len, img_len,
            # channels * hidden_size]; otherwise prepare_image will choke.
            targets = tf.reshape(targets, [
                tf.shape(targets)[0], hparams.img_len, hparams.img_len,
                hparams.num_channels * hparams.hidden_size
            ])

            # Prepare decoder inputs and bias.
            # TODO(nikip): Make prepare_decoder return bias
            decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
            bias = None

            # Add class label to decoder input.
            if not hparams.drop_inputs:
                decoder_input += tf.reshape(inputs, [
                    common_layers.shape_list(targets)[0], 1, 1,
                    hparams.hidden_size
                ])
            decoder_output = cia.transformer_decoder_layers(
                decoder_input,
                encoder_output=None,
                num_layers=hparams.num_decoder_layers
                or hparams.num_hidden_layers,
                hparams=hparams,
                self_attention_bias=bias,
                attention_type=hparams.dec_attention_type,
                name="decoder")
        # Reshape since t2t expects 4d tensors.
        decoder_output_shape = common_layers.shape_list(decoder_output)
        decoder_output = tf.reshape(
            decoder_output,
            [decoder_output_shape[0], -1, 1, hparams.hidden_size])
        hparams = orig_hparams  # Restore in case a branch swapped hparams.
        return decoder_output
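
# Note on causal=False above: the self-attention bias returned by
# transformer_prepare_decoder is a lower-triangular (causal) mask, so zeroing
# it lets every target position attend to the full sequence. A small
# illustration, assuming tensor2tensor's common_attention:
from tensor2tensor.layers import common_attention

causal_bias = common_attention.attention_bias_lower_triangle(4)  # [1, 1, 4, 4]
non_causal_bias = causal_bias * 0.  # all positions visible, as when causal=False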
Example #4
    def decode(self,
               decoder_input,
               encoder_output,
               encoder_decoder_attention_bias,
               decoder_self_attention_bias,
               hparams,
               cache=None,
               nonpadding=None):
        """Decode Transformer outputs from encoder representation.

    Args:
      decoder_input: inputs to bottom of the model.
          [batch_size, decoder_length, hidden_dim]
      encoder_output: Encoder representation.
          [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
          encoder-decoder attention. [batch_size, input_length]
      decoder_self_attention_bias: Bias and mask weights for decoder
          self-attention. [batch_size, decoder_length]
      hparams: hyperparmeters for model.
      cache: dict, containing tensors which are the results of previous
          attentions, used for fast decoding.
      nonpadding: optional Tensor with shape [batch_size, decoder_length]

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        decoder_input = tf.nn.dropout(
            decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

        decoder_output = transformer_decoder(
            decoder_input,
            encoder_output,
            decoder_self_attention_bias,
            encoder_decoder_attention_bias,
            hparams,
            cache=cache,
            nonpadding=nonpadding,
            save_weights_to=self.attention_weights)

        if (common_layers.is_on_tpu()
                and hparams.mode == tf.estimator.ModeKeys.TRAIN):
            # TPU does not react kindly to extra dimensions.
            # TODO(noam): remove this once TPU is more forgiving of extra dims.
            return decoder_output
        else:
            # Expand since t2t expects 4d tensors.

            m = self.sentence_cache.Query(
                tf.reshape(decoder_output,
                           [hparams.batch_size, -1, hparams.hidden_size]))
            #m = tf.py_func(self.sentence_cache.QueryMultipleEntries, [decoder_output], tf.float32)

            lambd = self.calculate_mixing_weight(
                tf.reshape(decoder_output,
                           [hparams.batch_size, -1, hparams.hidden_size]), m)

            m = tf.reshape(m, tf.shape(decoder_output))

            lambd = tf.reshape(
                lambd, (tf.shape(decoder_output)[0], -1, hparams.hidden_size))

            if self.hparams.use_cache:
                return tf.expand_dims(lambd * decoder_output +
                                      (1.0 - lambd) * m,
                                      axis=2)
            else:
                return tf.expand_dims(decoder_output, axis=2)
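
    # A minimal sketch of the gated cache mixing above: lambd is a
    # per-position gate in [0, 1] interpolating between the decoder state and
    # the cached sentence representation. The helper name is illustrative,
    # not part of the module's API.
    def mix_with_cache(self, decoder_output, cached, lambd):
        """All arguments: [batch_size, length, hidden_dim]."""
        return lambd * decoder_output + (1.0 - lambd) * cached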
Example #5
def encode_decode_task(features, hparams, train, attention_weights=None):
    """Model core graph for the one-shot action.

  Args:
    features: a dictionary contains "inputs" that is a tensor in shape of
        [batch_size, num_tokens], "verb_id_seq" that is in shape of
        [batch_size, num_actions], "object_spans" and "param_span" tensor
        in shape of [batch_size, num_actions, 2]. 0 is used as padding or
        non-existent values.
    hparams: the general hyperparameters for the model.
    train: the train mode.
    attention_weights: the dict to keep attention weights for analysis.
  Returns:
    loss_dict: the losses for training.
    prediction_dict: the predictions for action tuples.
    areas: the area encodings of the task.
    scope: the embedding scope.
  """
    del train
    input_embeddings, scope = common_embed.embed_tokens(
        features["task"], hparams.task_vocab_size, hparams.hidden_size,
        hparams)
    with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE):
        encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0)
        input_embeddings = tf.multiply(tf.expand_dims(encoder_nonpadding, 2),
                                       input_embeddings)
        encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
            transformer.transformer_prepare_encoder(input_embeddings,
                                                    None,
                                                    hparams,
                                                    features=None))
        encoder_input = tf.nn.dropout(encoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_encoder == "transformer":
            encoder_output = transformer.transformer_encoder(
                encoder_input,
                self_attention_bias,
                hparams,
                save_weights_to=attention_weights,
                make_image_summary=not common_layers.is_xla_compiled())
        else:
            raise ValueError("Unsupported instruction encoder %s" %
                             (hparams.instruction_encoder))
        span_rep = hparams.get("span_rep", "area")
        # Compute the span start/end indices once; the encodings may be
        # replaced below depending on span_rep.
        area_encodings, area_starts, area_ends = area_utils.compute_sum_image(
            encoder_output, max_area_width=hparams.max_span)
        current_shape = tf.shape(area_encodings)
        if span_rep == "area":
            pass  # Sum-pooled ("area") encodings were computed above.
        elif span_rep == "basic":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=False)
        elif span_rep == "coref":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=True)
        else:
            raise ValueError("xyz")
        areas = {}
        areas["encodings"] = area_encodings
        areas["starts"] = area_starts
        areas["ends"] = area_ends
        with tf.control_dependencies([
                tf.print("encoder_output", tf.shape(encoder_output)),
                tf.assert_equal(current_shape,
                                tf.shape(area_encodings),
                                summarize=100)
        ]):
            paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32)
        padding_sum, _, _ = area_utils.compute_sum_image(
            tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2),
            max_area_width=hparams.max_span)
        num_areas = common_layers.shape_list(area_encodings)[1]
        area_paddings = tf.reshape(tf.minimum(tf.to_float(padding_sum), 1.0),
                                   [-1, num_areas])
        areas["bias"] = area_paddings
        decoder_nonpadding = tf.to_float(
            tf.greater(features["verb_refs"][:, :, 1],
                       features["verb_refs"][:, :, 0]))
        if hparams.instruction_encoder == "lstm":
            hparams_decoder = copy.copy(hparams)
            hparams_decoder.set_hparam("pos", "none")
        else:
            hparams_decoder = hparams
        decoder_input, decoder_self_attention_bias = _prepare_decoder_input(
            area_encodings,
            decoder_nonpadding,
            features,
            hparams_decoder,
            embed_scope=scope)
        decoder_input = tf.nn.dropout(decoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_decoder == "transformer":
            decoder_output = transformer.transformer_decoder(
                decoder_input=decoder_input,
                encoder_output=encoder_output,
                decoder_self_attention_bias=decoder_self_attention_bias,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                hparams=hparams_decoder)
        else:
            raise ValueError("Unsupported instruction encoder %s" %
                             (hparams.instruction_encoder))
        return decoder_output, decoder_nonpadding, areas, scope
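
# Illustrative sketch of the span ("area") encodings built above, assuming
# area_utils.compute_sum_image sums token encodings over every contiguous
# span of width 1..max_area_width (a simplification: the real helper also
# returns the span start/end indices).
import tensorflow as tf

def span_sums(encodings, max_area_width, length):
    """encodings: [batch, length, hidden] -> [batch, num_spans, hidden]."""
    spans = []
    for width in range(1, max_area_width + 1):
        for start in range(length - width + 1):
            spans.append(tf.reduce_sum(encodings[:, start:start + width, :], 1))
    return tf.stack(spans, axis=1)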
Example #6
    def decode_srcs_to_trgs(self, trg_emb, trg_input_ids=None, outputs=None):
        trg_emb = common_attention.add_timing_signal_1d(trg_emb)

        trg_emb_fn = None
        control_flatten_outputs = None
        control_flatten_bias = None
        if 'control_vec' in self.shared_tensors and self.flags.control_mode:
            if "flatten" not in self.flags.control_mode:
                if 'bert' in self.flags.model_mode:
                    # For BERT, the embedding update is deferred and applied
                    # inside the BERT model via this callback.
                    trg_emb_fn = lambda trg_emb: self.update_embedding(trg_emb)
                else:
                    trg_emb = self.update_embedding(trg_emb)
            else:
                control_flatten_outputs = self.shared_tensors['control_vec']
                control_flatten_bias = tf.zeros([1, 1, 1, 1])

        control_outputs, control_bias = None, None
        if 'control_outputs' in self.shared_tensors:
            control_outputs = self.shared_tensors['control_outputs']
            control_bias = self.shared_tensors['control_bias']

        trg_length = tf.shape(trg_emb)[1]
        if 'gpt2' in self.flags.model_mode:
            trg_outputs = model.gpt2_decoder(
                self.hparams,
                trg_emb,
                encoder_outputs=self.shared_tensors['src_outputs'],
                encoder_bias=self.shared_tensors['src_bias'])
        elif 't2t' in self.flags.model_mode:
            trg_self_attention_bias = common_attention.attention_bias_lower_triangle(
                trg_length)
            trg_outputs = transformer.transformer_decoder(
                decoder_input=trg_emb,
                decoder_self_attention_bias=trg_self_attention_bias,
                encoder_output=self.shared_tensors['src_outputs'],
                encoder_decoder_attention_bias=self.shared_tensors['src_bias'],
                hparams=self.hparams,
                external_output=control_outputs,
                external_bias=control_bias,
                external_output2=control_flatten_outputs,
                external_bias2=control_flatten_bias,
                external_output3=self.shared_tensors['template_simp_outputs'],
                external_bias3=self.shared_tensors['template_simp_bias'],
                name='trg_decoder')
        elif 'bert' in self.flags.model_mode:
            trg_mask = common_attention.attention_bias_bert(trg_length, -1, 0)
            bert_model = BertModel(
                config=BertConfig.from_json_file(self.flags.bert_config_file),
                is_training=self.is_training,
                input_ids=trg_input_ids,
                input_mask=trg_mask,
                embeddings=self.shared_tensors['word_embedding_table'],
                encoder_ids=self.shared_tensors['src_ids'],
                encoder_outpus=self.shared_tensors['src_outputs'],
                encoder_mask=1.0 - self.shared_tensors['src_mask'],
                trg_emb_fn=trg_emb_fn)
            trg_outputs = bert_model.get_sequence_output()
        else:
            raise ValueError('model_mode not known')
        return trg_outputs
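
# Side note on the first line of decode_srcs_to_trgs: add_timing_signal_1d
# adds the fixed sinusoidal position encoding to the target embeddings in
# place of learned position embeddings. A tiny illustration with
# tensor2tensor's common_attention:
import tensorflow as tf
from tensor2tensor.layers import common_attention

trg_emb = tf.zeros([2, 10, 512])
trg_emb = common_attention.add_timing_signal_1d(trg_emb)  # shape unchanged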
Example #7
def vae_transformer_internal(inputs, targets, target_space, hparams):
    """VAE Transformer, main step used for training."""
    with tf.variable_scope("vae_transformer"):
        is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
        # Prepare inputs, targets, and k.
        inputs = common_layers.flatten4d3d(inputs)
        targets = common_layers.flatten4d3d(targets)
        k = 2**hparams.num_compress_steps
        _, targets = common_layers.pad_to_same_length(
            inputs, targets, final_length_divisible_by=k)

        # Transformer preparations and encoder.
        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias
         ) = transformer.transformer_prepare_encoder(inputs, target_space,
                                                     hparams)
        residual_fn = transformer.get_residual_fn(hparams)
        encoder_input = tf.nn.dropout(encoder_input,
                                      1.0 - hparams.residual_dropout)
        encoder_output = transformer.transformer_encoder(
            encoder_input, residual_fn, encoder_self_attention_bias, hparams)

        def get_decoder_autoregressive():
            """Decoder input for autoregressive computation."""
            (a, b) = transformer.transformer_prepare_decoder(targets, hparams)
            return (a, b, tf.constant(0.0))

        # 10% of the time we compress all-zeros, as will be the case at
        # decoding start.
        prob_targets = 0.9 if is_training else 1.0
        to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                              lambda: targets, lambda: tf.zeros_like(targets))
        z, kl_loss = compress_vae(to_compress, hparams, "vae")
        # Decompress.
        for i in xrange(hparams.num_compress_steps):
            j = hparams.num_hidden_layers - i - 1
            z = decompress(z, hparams, "decompress_%d" % j)

        def get_decoder_from_vae():
            """Decoder input computed by VAE."""
            # Return decoder stuff.
            (a, b) = transformer.transformer_prepare_decoder(
                tf.squeeze(z, axis=2), hparams)
            return (a, b, kl_loss)

        # Randomly switch between VAE and autoregressive decoder inputs.
        prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7
        step = tf.to_float(tf.contrib.framework.get_global_step())
        if not is_training:
            prob_do_vae = tf.cond(tf.less(step,
                                          40000.0), lambda: tf.constant(0.0),
                                  lambda: tf.constant(1.0))
        (decoder_input, decoder_self_attention_bias,
         kl_loss2) = tf.cond(tf.less(tf.random_uniform([]), prob_do_vae),
                             get_decoder_from_vae, get_decoder_autoregressive)

        # Transformer decoder.
        decoder_output = transformer.transformer_decoder(
            decoder_input, encoder_output, residual_fn,
            decoder_self_attention_bias, encoder_decoder_attention_bias,
            hparams)
        decoder_output = tf.expand_dims(decoder_output, 2)

        cond_self = tf.cond(tf.less(step, 30000.0), lambda: tf.constant(1.0),
                            lambda: tf.constant(0.0))
        prob_self = 0.4 if is_training else cond_self
        (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]),
                                         prob_self), lambda: (z, kl_loss),
                                 lambda: (decoder_output, kl_loss2))

        kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0
        return ret, kl_loss
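
# The schedules above follow a common warm-up pattern: tensor2tensor's
# common_layers.inverse_exp_decay(max_step) returns a scalar rising from near
# 0 to 1.0 as the global step approaches max_step, so prob_do_vae ramps toward
# 0.7 over the first 40k steps and the KL weight anneals in over 50k steps.
# A minimal sketch of the stochastic branch selection this drives (TF1):
import tensorflow as tf

def random_branch(prob_a, branch_a, branch_b):
    """Runs branch_a with probability prob_a, else branch_b, per session run."""
    return tf.cond(tf.less(tf.random_uniform([]), prob_a), branch_a, branch_b)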