def body_sharded(self, sharded_features):
        # ========= Prepare the input and target =========

        hparams = self._hparams
        dp = self._data_parallelism
        targets = sharded_features["targets"]
        inputs = sharded_features["inputs"]
        target_space = sharded_features["target_space_id"]

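        # Flatten the 4D [batch, height, width, hidden] features into
        # 3D [batch, length, hidden] tensors before building the transformer.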
        inputs = dp(common_layers.flatten4d3d, inputs)
        targets = dp(common_layers.flatten4d3d, targets)

        def dp_preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def dp_postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = dp(
             transformer.transformer_prepare_encoder, inputs, target_space,
             hparams)
        (decoder_input, decoder_self_attention_bias) = dp(
            transformer.transformer_prepare_decoder, targets, hparams)
        encoder_input = dp(tf.nn.dropout, encoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        decoder_input = dp(tf.nn.dropout, decoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)

        cache = dict(extra_loss=0.0)

        def prepostprocess(fct):
            """Apply processing and capture the extra loss."""
            @expert_utils.add_var_scope()
            def decorated(x, *args, **kwargs):
                x = dp_preprocess(x)
                y, loss = fct(x, *args, **kwargs)
                cache["extra_loss"] += loss
                return dp_postprocess(x, y)

            return decorated

        # ========= Compute the transformer architecture =========

        def extract_layer_types(layer_types):
            """Parse the layer string.

            Args:
              layer_types (str): String containing the network architecture. See
                top file comment for examples of format.

            Returns:
              list[tuple[str, str]]: Encoder layers: list of (attention, feed-forward)
              list[tuple[str, str, str]]: Decoder layers: list of (self-attention,
                enc-dec attention, feed-forward)
            """
            # If the architecture has not been explicitly set, construct a
            # standard transformer with the fallback values
            if not layer_types:
                layer_types = SEP_LAYER.join([hparams.default_att] *
                                             hparams.num_hidden_layers)

            # If the encoder is not explicitly defined, it will have the same
            # structure as the decoder
            layer_types = layer_types.split(SEP_ENCODEC)
            if len(layer_types) == 1:
                layer_types *= 2

            # Some models don't need the encoder (e.g. language modeling)
            # TODO(epot): What are the other conditions (has_input ?)
            if hparams.prepend_mode != "none":
                layer_types[0] = ""

            # Extend the blocks and fill them with the default values if not specified
            final_layers = ([], [])
            for i, blocks_str in enumerate(layer_types):
                for blocks_str in blocks_str.split(SEP_LAYER):
                    if not blocks_str:
                        continue
                    blocks_list = blocks_str.split(SEP_FF)
                    # Fall back to the default values where a block type is not
                    # specified. If the encoder is empty, do not use the
                    # encoder-decoder attention.
                    self_att = blocks_list[0] or hparams.default_att
                    ende_att = hparams.default_att if layer_types[0] else "_"
                    ff = hparams.default_ff
                    if len(blocks_list) > 1:
                        ff = blocks_list[-1]
                    if len(blocks_list) == 3:
                        ende_att = blocks_list[1]
                    if i == 0:  # Encoder
                        blocks_tuple = (self_att, ff)
                    elif i == 1:  # Decoder
                        blocks_tuple = (self_att, ende_att, ff)
                    final_layers[i].append(blocks_tuple)

            return final_layers

        # ========= Construct the transformer encoder and decoder =========

        encoder_layers, decoder_layers = extract_layer_types(
            hparams.layer_types)

        layers = common_attention.get_standardized_layers(
            hparams=hparams,
            dp=dp,
            ps_devices=self._ps_devices,
        )

        if hparams.mode == tf.estimator.ModeKeys.TRAIN:

            # Display the encoder-decoder architecture
            def print_layer(name, layers):
                tf.logging.info("{} architecture:".format(name))
                for i, l in enumerate(layers):
                    tf.logging.info(" * Layer {}: {}".format(i, " - ".join(l)))

            print_layer("Encoder", encoder_layers)
            print_layer("Decoder", decoder_layers)

        encoder_outputs = []

        x = encoder_input
        with tf.variable_scope("encoder"):
            for layer_num, block_types in enumerate(encoder_layers):
                # Each encoder layer is composed of two blocks:
                # * self-attention block
                # * feed-forward block
                att_type, ff_type = block_types
                with tf.variable_scope("layer_{}".format(layer_num)):
                    x = prepostprocess(layers[att_type])(
                        x,
                        bias=encoder_self_attention_bias,
                        name="att_{}".format(att_type),
                    )
                    x = prepostprocess(layers[ff_type])(
                        x, name="ff_{}".format(ff_type))
                encoder_outputs.append(x)
            if encoder_outputs:
                encoder_outputs[-1] = dp_preprocess(x)

        x = decoder_input
        with tf.variable_scope("decoder"):
            for layer_num, block_types in enumerate(decoder_layers):
                # Each decoder layer is composed of three blocks:
                # * self-attention block
                # * encoder-decoder attention block (optional)
                # * feed-forward block
                self_att_type, att_ende_type, ff_type = block_types
                with tf.variable_scope("layer_{}".format(layer_num)):
                    x = prepostprocess(layers[self_att_type])(
                        x,
                        bias=decoder_self_attention_bias,
                        name="self_att_{}".format(self_att_type),
                    )
                    # Only add the encoder-decoder attention layer if there is
                    # an encoder
                    if encoder_outputs:
                        x = prepostprocess(layers[att_ende_type])(
                            x,
                            memory_antecedent=encoder_outputs[-1],
                            bias=encoder_decoder_attention_bias,
                            name="att_ende_{}".format(att_ende_type),
                        )
                    x = prepostprocess(layers[ff_type])(
                        x, name="ff_{}".format(ff_type))
            # If normalization is done in layer_preprocess, then it should also be
            # done on the output, since the output can grow very large, being the sum
            # of a whole stack of unnormalized layer outputs.
            x = dp_preprocess(x)

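        # Re-insert the dimension removed by flatten4d3d so the body output is
        # 4D again, as expected by the rest of the model.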
        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, cache["extra_loss"]
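The layer string consumed by extract_layer_types is easiest to understand in isolation. Below is a minimal, self-contained sketch of the same parsing logic outside the model; the separator characters and default block names are assumptions standing in for the module-level SEP_ENCODEC / SEP_LAYER / SEP_FF constants and the hparams defaults, not the values from the original file.

# Minimal sketch of the layer-string parsing (assumed separators and defaults).
SEP_ENCODEC = "#"   # assumed: splits the encoder string from the decoder string
SEP_LAYER = "/"     # assumed: splits one layer from the next
SEP_FF = "-"        # assumed: splits the attention type(s) from the feed-forward type
DEFAULT_ATT = "a"   # assumed default attention block
DEFAULT_FF = "fc"   # assumed default feed-forward block

def parse_layer_types(layer_types, num_hidden_layers=2):
    """Return (encoder_layers, decoder_layers) the way extract_layer_types does."""
    if not layer_types:
        layer_types = SEP_LAYER.join([DEFAULT_ATT] * num_hidden_layers)
    parts = layer_types.split(SEP_ENCODEC)
    if len(parts) == 1:  # no explicit encoder: reuse the decoder description
        parts *= 2
    final_layers = ([], [])
    for i, side in enumerate(parts):
        for block in side.split(SEP_LAYER):
            if not block:
                continue
            pieces = block.split(SEP_FF)
            self_att = pieces[0] or DEFAULT_ATT
            ende_att = DEFAULT_ATT if parts[0] else "_"
            ff = pieces[-1] if len(pieces) > 1 else DEFAULT_FF
            if len(pieces) == 3:
                ende_att = pieces[1]
            final_layers[i].append(
                (self_att, ff) if i == 0 else (self_att, ende_att, ff))
    return final_layers

# Two attention layers in the encoder, two attention + feed-forward layers in
# the decoder (all block names hypothetical):
print(parse_layer_types("a/a#a-fc/a-fc"))
# -> ([('a', 'fc'), ('a', 'fc')], [('a', 'a', 'fc'), ('a', 'a', 'fc')])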
Example #2
  def body_sharded(self, sharded_features):
    # ========= Prepare the input and target =========

    hparams = self._hparams
    dp = self._data_parallelism

    # Process input
    inputs = sharded_features["inputs"]
    target_space = sharded_features["target_space_id"]
    (
        encoder_input,
        encoder_self_attention_bias,
        encoder_decoder_attention_bias,
    ) = dp(self._prepare_encoder, inputs, target_space)

    # Process output
    targets = sharded_features["targets"]
    decoder_input, decoder_self_attention_bias = dp(
        self._prepare_decoder, targets
    )

    def dp_preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def dp_postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    cache = dict(extra_loss=0.0)

    def prepostprocess(fct):
      """Apply processing and capture the extra loss."""
      @expert_utils.add_var_scope()
      def decorated(x, *args, **kwargs):
        x = dp_preprocess(x)
        y, loss = fct(x, *args, **kwargs)
        cache["extra_loss"] += loss
        return dp_postprocess(x, y)
      return decorated

    # ========= Compute the transformer architecture =========

    encoder_layers, decoder_layers = self._extract_layer_types()

    layers = common_attention.get_standardized_layers(
        hparams=hparams,
        dp=dp,
        ps_devices=self._ps_devices,
    )

    if hparams.mode == tf.estimator.ModeKeys.TRAIN:

      # Display the encoder-decoder architecture
      def print_layer(name, layers):
        tf.logging.info("{} architecture:".format(name))
        for i, l in enumerate(layers):
          tf.logging.info(" * Layer {}: {}".format(i, " - ".join(l)))
      print_layer("Encoder", encoder_layers)
      print_layer("Decoder", decoder_layers)

    # ========= Construct the transformer encoder and decoder =========

    encoder_outputs = []

    x = encoder_input
    with tf.variable_scope("encoder"):
      for layer_num, block_types in enumerate(encoder_layers):
        # Each encoder layer is composed of two blocks:
        # * self-attention block
        # * feed-forward block
        att_type, ff_type = block_types
        with tf.variable_scope("layer_{}".format(layer_num)):
          x = prepostprocess(layers[att_type])(
              x,
              bias=encoder_self_attention_bias,
              name="att_{}".format(att_type),
          )
          x = prepostprocess(layers[ff_type])(
              x,
              name="ff_{}".format(ff_type)
          )
        encoder_outputs.append(x)
      if encoder_outputs:
        encoder_outputs[-1] = dp_preprocess(x)

    x = decoder_input
    with tf.variable_scope("decoder"):
      for layer_num, block_types in enumerate(decoder_layers):
        # Each decoder layer is composed of three blocks:
        # * self-attention block
        # * encoder-decoder attention block (optional)
        # * feed-forward block
        self_att_type, att_ende_type, ff_type = block_types
        with tf.variable_scope("layer_{}".format(layer_num)):
          x = prepostprocess(layers[self_att_type])(
              x,
              bias=decoder_self_attention_bias,
              name="self_att_{}".format(self_att_type),
          )
          # Only add the encoder-decoder attention layer if there is an encoder
          if encoder_outputs:
            x = prepostprocess(layers[att_ende_type])(
                x,
                memory_antecedent=encoder_outputs[-1],
                bias=encoder_decoder_attention_bias,
                name="att_ende_{}".format(att_ende_type),
            )
          x = prepostprocess(layers[ff_type])(
              x,
              name="ff_{}".format(ff_type)
          )
      # If normalization is done in layer_preprocess, then it should also be
      # done on the output, since the output can grow very large, being the sum
      # of a whole stack of unnormalized layer outputs.
      x = dp_preprocess(x)

    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, cache["extra_loss"]
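What the prepostprocess() wrapper adds around every block is easier to see without the data-parallelism indirection. The sketch below is a framework-free illustration of the same pattern, assuming a layer-norm-before / residual-after configuration, which is only one possible setting of hparams.layer_preprocess_sequence and layer_postprocess_sequence: each block returns (output, extra_loss), and the extra losses, e.g. mixture-of-experts load-balancing terms, accumulate in cache["extra_loss"].

import numpy as np

def layer_norm(x, eps=1e-6):
    """Toy layer normalization over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def prepostprocess(block_fn, cache):
    """Preprocess the input, run the block, accumulate its loss, add the residual."""
    def wrapped(x, **kwargs):
        x_pre = layer_norm(x)                      # "preprocess" step
        y, extra_loss = block_fn(x_pre, **kwargs)  # block returns (output, extra_loss)
        cache["extra_loss"] += extra_loss          # capture e.g. the MoE load loss
        return x_pre + y                           # "postprocess": residual connection
    return wrapped

# Hypothetical block: scaled identity "attention" that also reports a small loss.
def dummy_block(x, name=None):
    return 0.1 * x, 0.01

cache = dict(extra_loss=0.0)
x = np.ones((2, 4, 8))                      # (batch, length, hidden)
x = prepostprocess(dummy_block, cache)(x, name="att_dummy")
print(x.shape, cache["extra_loss"])         # (2, 4, 8) 0.01

Note that the next example changes one detail of this wrapper: it feeds the block the preprocessed tensor (x_preprocessed) but keeps the residual on the raw input, calling dp_postprocess(x, y) with the original x.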
Example #3
  def body_sharded(self, sharded_features):
    # ========= Prepare the input and target =========

    hparams = self._hparams
    dp = self._data_parallelism

    # Process input
    inputs = sharded_features["inputs"]
    target_space = sharded_features["target_space_id"]
    (
        encoder_input,
        encoder_self_attention_bias,
        encoder_decoder_attention_bias,
    ) = dp(self._prepare_encoder, inputs, target_space)

    # Process output
    targets = sharded_features["targets"]
    decoder_input, decoder_self_attention_bias = dp(
        self._prepare_decoder, targets
    )

    def dp_preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def dp_postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    cache = dict(extra_loss=0.0)

    def prepostprocess(fct):
      """Apply processing and capture the extra loss."""
      @expert_utils.add_var_scope()
      def decorated(x, *args, **kwargs):
        x_preprocessed = dp_preprocess(x)
        y, loss = fct(x_preprocessed, *args, **kwargs)
        cache["extra_loss"] += loss
        return dp_postprocess(x, y)
      return decorated

    # ========= Compute the transformer architecture =========

    encoder_layers, decoder_layers = self._extract_layer_types()

    layers = common_attention.get_standardized_layers(
        hparams=hparams,
        dp=dp,
    )

    if hparams.mode == tf.estimator.ModeKeys.TRAIN:

      # Display the encoder-decoder architecture
      def print_layer(name, layers):
        tf.logging.info("{} architecture:".format(name))
        for i, l in enumerate(layers):
          tf.logging.info(" * Layer {}: {}".format(i, " - ".join(l)))
      print_layer("Encoder", encoder_layers)
      print_layer("Decoder", decoder_layers)

    # ========= Construct the transformer encoder and decoder =========

    encoder_outputs = []

    x = encoder_input
    with tf.variable_scope("encoder"):
      for layer_num, block_types in enumerate(encoder_layers):
        # Each encoder layer is composed of two blocks:
        # * self-attention block
        # * feed-forward block
        att_type, ff_type = block_types
        with tf.variable_scope("layer_{}".format(layer_num)):
          x = prepostprocess(layers[att_type])(
              x,
              bias=encoder_self_attention_bias,
              name="att_{}".format(att_type),
          )
          x = prepostprocess(layers[ff_type])(
              x,
              name="ff_{}".format(ff_type)
          )
        encoder_outputs.append(x)
      if encoder_outputs:
        encoder_outputs[-1] = dp_preprocess(x)

    x = decoder_input
    with tf.variable_scope("decoder"):
      for layer_num, block_types in enumerate(decoder_layers):
        # Each decoder layer is composed of three blocks:
        # * self-attention block
        # * encoder-decoder attention block (optional)
        # * feed-forward block
        self_att_type, att_ende_type, ff_type = block_types
        with tf.variable_scope("layer_{}".format(layer_num)):
          x = prepostprocess(layers[self_att_type])(
              x,
              bias=decoder_self_attention_bias,
              name="self_att_{}".format(self_att_type),
          )
          # Only add the encoder-decoder attention layer if there is an encoder
          if encoder_outputs:
            x = prepostprocess(layers[att_ende_type])(
                x,
                memory_antecedent=encoder_outputs[-1],
                bias=encoder_decoder_attention_bias,
                name="att_ende_{}".format(att_ende_type),
            )
          x = prepostprocess(layers[ff_type])(
              x,
              name="ff_{}".format(ff_type)
          )
      # If normalization is done in layer_preprocess, then it should also be
      # done on the output, since the output can grow very large, being the sum
      # of a whole stack of unnormalized layer outputs.
      x = dp_preprocess(x)

    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, cache["extra_loss"]
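Every operation in these snippets is funneled through dp = self._data_parallelism, which applies a function once per device shard and returns the list of per-shard results (in tensor2tensor this is typically an expert_utils.Parallelism object). The toy stand-in below only illustrates that calling convention; it is not the real implementation and ignores device placement.

class ToyParallelism(object):
    """Illustrative stand-in for the data-parallelism object used as dp(...)."""

    def __init__(self, n_shards):
        self.n = n_shards

    def __call__(self, fn, *args, **kwargs):
        # Each positional argument is either a per-shard list or a value that is
        # broadcast to all shards; the result is a list with one entry per shard.
        outputs = []
        for i in range(self.n):
            sharded_args = [a[i] if isinstance(a, list) else a for a in args]
            outputs.append(fn(*sharded_args, **kwargs))
        return outputs

dp = ToyParallelism(2)
inputs = [[1, 2], [3, 4]]                   # one sub-list per shard
print(dp(sum, inputs))                      # [3, 7]  -- fn applied shard by shard
print(dp(lambda x, b: x + b, [10, 20], 1))  # [11, 21] -- scalar broadcast to all shards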