Example 1
    def model_fn_body(self, features):
        hparams = self._hparams
        targets = features["targets"]
        inputs = features["inputs"]
        target_space = features["target_space_id"]

        inputs = common_layers.flatten4d3d(inputs)
        targets = common_layers.flatten4d3d(targets)

        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = (
             transformer.transformer_prepare_encoder(inputs, target_space,
                                                     hparams))
        (decoder_input, decoder_self_attention_bias
         ) = transformer.transformer_prepare_decoder(targets, hparams)

        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
        decoder_input = tf.nn.dropout(
            decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = transformer_revnet_encoder(
            encoder_input, encoder_self_attention_bias, hparams)

        decoder_output = transformer_revnet_decoder(
            decoder_input, encoder_output, decoder_self_attention_bias,
            encoder_decoder_attention_bias, hparams)
        decoder_output = tf.expand_dims(decoder_output, 2)

        return decoder_output
Example 2
def ae_transformer_internal(inputs, targets, target_space, hparams):
    """AE Transformer, main step used for training."""
    with tf.variable_scope("ae_transformer"):
        # Prepare inputs, targets, k.
        k = 2**hparams.num_compress_steps
        _, targets = common_layers.pad_to_same_length(
            targets, targets, final_length_divisible_by=k)
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")

        # Compress and ae.
        ae, hot, kl = ae_compress(targets, hparams.is_2d, hparams, "ae")
        tf.summary.histogram("hot", tf.reshape(tf.argmax(hot, axis=-1), [-1]))
        emb = ae_embed(hot, hparams, "ae", reuse=True)

        # Compress context and run autoregressive decoder on emb-hot.
        emb_flat = tf.expand_dims(common_layers.flatten4d3d(emb), axis=2)
        emb_flat = tf.stop_gradient(emb_flat)
        dec_c = decode(None, None, emb_flat, inputs, ed, hparams)
        dec_c = tf.reshape(dec_c, tf.shape(emb))
        c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
        reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(labels=hot,
                                                                   logits=c_z)
        # If not training, use the predicted z instead of the autoregressive one.
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)

        # Decompress, pass for ae loss.
        z = ae_decompress(emb, ae, targets, hparams.is_2d, hparams, "ae")
        kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.8),
                                              min_value=0.0001)
        reconstruct_loss *= common_layers.inverse_exp_decay(
            hparams.startup_steps)
        losses = {"kl": kl, "reconstruction": reconstruct_loss * 0.1}
        return z, losses
Example 3
def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
    """LSTM seq2seq model with attention, main step used for training."""
    with tf.variable_scope("lstm_seq2seq_attention"):
        # Flatten inputs.
        inputs = common_layers.flatten4d3d(inputs)
        # LSTM encoder.
        encoder_outputs, final_encoder_state = lstm(
            tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
        # LSTM decoder with attention
        shifted_targets = common_layers.shift_right(targets)
        decoder_outputs, _ = lstm_attention_decoder(
            common_layers.flatten4d3d(shifted_targets), hparams, train,
            "decoder", final_encoder_state, encoder_outputs)
        return tf.expand_dims(decoder_outputs, axis=2)
Example 4
def bytenet_internal(inputs, targets, hparams):
    """ByteNet, main step used for training."""
    with tf.variable_scope("bytenet"):
        # Flatten inputs and extend length by 50%.
        inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
        extend_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
        inputs_shape = inputs.shape.as_list()
        inputs = tf.pad(inputs, [[0, 0], [0, extend_length], [0, 0], [0, 0]])
        inputs_shape[1] = None
        inputs.set_shape(
            inputs_shape)  # Don't lose the other shapes when padding.
        # Pad inputs and targets to be the same length, divisible by 50.
        inputs, targets = common_layers.pad_to_same_length(
            inputs, targets, final_length_divisible_by=50)
        final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat,
                                              "SAME", "encoder", hparams)

        shifted_targets = common_layers.shift_right(targets)
        kernel = (hparams.kernel_height, hparams.kernel_width)
        decoder_start = common_layers.conv_block(
            tf.concat([final_encoder, shifted_targets], axis=3),
            hparams.hidden_size, [((1, 1), kernel)],
            padding="LEFT")

        return residual_dilated_conv(decoder_start, hparams.num_block_repeat,
                                     "LEFT", "decoder", hparams)
Example 5
    def encode(self, inputs, target_space, hparams):
        """Encode transformer inputs.

    Args:
      inputs: Transformer inputs [batch_size, input_length, hidden_dim]
      target_space: scalar, target space ID.
      hparams: hyperparmeters for model.

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encodre-decoder attention. [batch_size, input_length]
    """
        inputs = common_layers.flatten4d3d(inputs)

        encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
            transformer_prepare_encoder(inputs, target_space, hparams))

        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

        encoder_output = transformer_encoder(encoder_input,
                                             self_attention_bias, hparams)

        return encoder_output, encoder_decoder_attention_bias
Example 6
        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets
Example 7
    def model_fn_body(self, features):
        """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "tragets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        hparams = self._hparams

        inputs = features["inputs"]

        target_space = features["target_space_id"]
        encoder_output, encoder_decoder_attention_bias = self.encode(
            inputs, target_space, hparams)

        targets = features["targets"]
        targets = common_layers.flatten4d3d(targets)

        decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
            targets, hparams)

        return self.decode(decoder_input, encoder_output,
                           encoder_decoder_attention_bias,
                           decoder_self_attention_bias, hparams)
Example 8
def slicenet_internal(inputs, targets, target_space, problem_idx, hparams):
    """The slicenet model, main step used for training."""
    with tf.variable_scope("slicenet"):
        # Flatten inputs and encode.
        inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
        inputs_mask = 1.0 - embedding_to_padding(inputs)
        inputs = common_layers.add_timing_signal(inputs)  # Add position info.
        target_space_emb = embed_target_space(target_space,
                                              hparams.hidden_size)
        extra_layers = int(hparams.num_hidden_layers * 1.5)
        inputs_encoded = multi_conv_res(inputs,
                                        "SAME",
                                        "encoder",
                                        extra_layers,
                                        hparams,
                                        mask=inputs_mask)
        target_modality_name = hparams.problems[
            problem_idx].target_modality.name
        if "class_label_modality" in target_modality_name:
            # If we're just predicting a class, there is no use for a decoder.
            return inputs_encoded
        # Do the middle part.
        decoder_start, similarity_loss = slicenet_middle(
            inputs_encoded, targets, target_space_emb, inputs_mask, hparams)
        # Decode.
        decoder_final = multi_conv_res(decoder_start,
                                       "LEFT",
                                       "decoder",
                                       hparams.num_hidden_layers,
                                       hparams,
                                       mask=inputs_mask,
                                       source=inputs_encoded)
        return decoder_final, tf.reduce_mean(similarity_loss)
Example 9
def lstm_seq2seq_internal(inputs, targets, hparams, train):
    """The basic LSTM seq2seq model, main step used for training."""
    with tf.variable_scope("lstm_seq2seq"):
        # Flatten inputs.
        inputs = common_layers.flatten4d3d(inputs)
        # LSTM encoder.
        _, final_encoder_state = lstm(tf.reverse(inputs, axis=[1]), hparams,
                                      train, "encoder")
        # LSTM decoder.
        shifted_targets = common_layers.shift_right(targets)
        decoder_outputs, _ = lstm(common_layers.flatten4d3d(shifted_targets),
                                  hparams,
                                  train,
                                  "decoder",
                                  initial_state=final_encoder_state)
        return tf.expand_dims(decoder_outputs, axis=2)
Example 10
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Middle part of slicenet, connecting encoder and decoder."""
    def norm_fn(x, name):
        with tf.variable_scope(name, default_name="norm"):
            return common_layers.apply_norm(x, hparams.norm_type,
                                            hparams.hidden_size,
                                            hparams.norm_epsilon)

    # Flatten targets and embed target_space_id.
    targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
    target_space_emb = tf.tile(target_space_emb,
                               [tf.shape(targets_flat)[0], 1, 1, 1])

    # Calculate similarity loss (but don't run if not needed).
    if len(hparams.problems) > 1 and hparams.sim_loss_mult > 0.00001:
        targets_timed = common_layers.add_timing_signal(targets_flat)
        extra_layers = int(hparams.num_hidden_layers * 1.5)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            targets_encoded = multi_conv_res(targets_timed, "SAME", "encoder",
                                             extra_layers, hparams)
        with tf.variable_scope("similarity_loss"):
            similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
            similarity_loss *= hparams.sim_loss_mult
    else:
        similarity_loss = 0.0

    # Use attention from each target to look at input and retrieve.
    targets_shifted = common_layers.shift_right(targets_flat,
                                                pad_value=target_space_emb)
    if hparams.attention_type == "none":
        targets_with_attention = tf.zeros_like(targets_shifted)
    else:
        inputs_padding_bias = (1.0 -
                               mask) * -1e9  # Bias to not attend to padding.
        targets_with_attention = attention(targets_shifted,
                                           inputs_encoded,
                                           norm_fn,
                                           hparams,
                                           bias=inputs_padding_bias)

    # Positional targets: merge attention and raw.
    kernel = (hparams.kernel_height, hparams.kernel_width)
    targets_merged = common_layers.subseparable_conv_block(
        tf.concat([targets_with_attention, targets_shifted], axis=3),
        hparams.hidden_size, [((1, 1), kernel)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")

    return targets_merged, similarity_loss
Example 11
    def model_fn_body(self, features):
        hparams = self._hparams
        targets = features["targets"]
        inputs = features.get("inputs")
        target_space = features.get("target_space_id")

        inputs = common_layers.flatten4d3d(inputs)
        targets = common_layers.flatten4d3d(targets)

        (encoder_input, encoder_attention_bias,
         _) = (transformer.transformer_prepare_encoder(inputs, target_space,
                                                       hparams))
        (decoder_input,
         _) = (transformer.transformer_prepare_decoder(targets, hparams))

        encoder_mask = bias_to_mask(encoder_attention_bias)

        def residual_fn(x, y):
            return common_layers.layer_norm(
                x + tf.nn.dropout(y, 1.0 - hparams.residual_dropout))

        encoder_input = tf.nn.dropout(encoder_input,
                                      1.0 - hparams.residual_dropout)
        decoder_input = tf.nn.dropout(decoder_input,
                                      1.0 - hparams.residual_dropout)

        encoder_output = alt_transformer_encoder(encoder_input, residual_fn,
                                                 encoder_mask, hparams)

        decoder_output = alt_transformer_decoder(decoder_input, encoder_output,
                                                 residual_fn,
                                                 encoder_attention_bias,
                                                 hparams)

        decoder_output = tf.expand_dims(decoder_output, 2)

        return decoder_output
Example 12
    def targets_bottom(self, inputs):
        with tf.variable_scope(self.name):
            # Flatten inputs to a 3-d tensor and embed the RGB pixel values.
            shape = tf.shape(inputs)
            inputs = common_layers.flatten4d3d(inputs)
            ret = common_layers.embedding(tf.to_int32(inputs),
                                          self.top_dimensionality,
                                          self._body_input_depth,
                                          name="input_rgb_embedding")
            if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
                ret *= self._body_input_depth**0.5
            ret = tf.reshape(
                ret,
                [shape[0], shape[1], shape[2], self._body_input_depth * 3])
            return tf.layers.dense(ret, self._body_input_depth)
Example 13
    def model_fn_body(self, features):
        hparams = self._hparams
        targets = features["targets"]

        targets = common_layers.flatten4d3d(targets)

        (decoder_input,
         decoder_self_attention_bias) = transformer_prepare_decoder(
             targets, hparams)

        decoder_input = tf.nn.dropout(
            decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

        decoder_output = transformer_decoder(decoder_input, None,
                                             decoder_self_attention_bias, None,
                                             hparams)
        decoder_output = tf.expand_dims(decoder_output, 2)

        return decoder_output
Example 14
  def model_fn_body(self, features):
    inputs = features["inputs"]
    inputs.get_shape().assert_has_rank(4)

    hp = self._hparams

    out = inputs
    out = common_layers.flatten4d3d(out)

    # Conv layers
    assert hp.num_conv_layers == len(hp.pooling_windows)
    for i in xrange(hp.num_conv_layers):
      out = conv_layer(
          out,
          hp.hidden_size,
          hp.kernel_width,
          hp.stride,
          hp.pooling_windows[i],
          hp.dropout,
          dilation_rate=1,
          name="conv_%d" % (i + 1))

    # Dense dilated conv layers
    for i in xrange(hp.num_dconv_layers):
      dilation_rate = 2**(i + 1)
      dconv_out = conv_layer(
          out,
          hp.hidden_size,
          hp.kernel_width,
          stride=1,
          pooling_window=0,
          dropout_rate=hp.dropout,
          dilation_rate=dilation_rate,
          name="dconv_%d" % (i + 1))
      out = tf.concat([out, dconv_out], axis=2)

    # Fully connected layer
    out = fc_layer(out, hp.hidden_size, hp.dropout, name="fc")

    out.get_shape().assert_has_rank(3)
    out = tf.expand_dims(out, 2)
    return out
Example 15
def flatten(inputs):
    # Collapse [batch, height, width, depth] to [batch, height * width, 1, depth].
    return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
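
Every example above calls common_layers.flatten4d3d from tensor2tensor to collapse a 4-D feature tensor before an encoder, decoder, or embedding step, and most of them restore the 4-D layout afterwards with tf.expand_dims(..., 2). For reference, here is a minimal sketch of the behaviour these call sites assume, namely joining the two middle dimensions of a [batch, height, width, depth] tensor; it is an illustration written for this page, not the library's actual implementation:

import tensorflow as tf

def flatten4d3d_sketch(x):
    """Collapse [batch, height, width, depth] into [batch, height * width, depth].

    Illustrative only: the real common_layers.flatten4d3d in tensor2tensor may
    handle static shape information differently.
    """
    shape = tf.shape(x)
    return tf.reshape(x, [shape[0], shape[1] * shape[2], shape[3]])

The matching tf.expand_dims(output, 2) at the end of most model bodies simply reinserts a width dimension of 1, so the result is 4-D again and matches the layout the inputs started from.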