Example #1
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
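        # dp(fn, ...) applies fn once per data-parallel shard and returns a
        # list of per-shard results (a minimal sketch of this helper appears
        # at the end of this listing).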
        x = dp(tf.squeeze, sharded_features["inputs"], 2)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        ffn_hidden_sizes = [
            int(s) for s in hparams.ffn_hidden_sizes.split(",")
        ]
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
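        # With mask_right, build a lower-triangular attention bias so each
        # position attends only to itself and earlier positions; otherwise the
        # bias is all zeros.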
        if hparams.mask_right:

            def _bias(x):
                return common_attention.attention_bias_lower_triangle(
                    tf.shape(x)[1])

            bias = dp(_bias, x)
        else:
            bias = tf.zeros([1, 1, 1, 1])
        if hparams.diet_experts:
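            # moe_hidden_sizes must contain exactly one size here; diet experts
            # use memory-efficient parameter storage with a dedicated
            # Adam-style optimizer.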
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)

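        # Per-token batch indices, consumed below by the expert-based attention
        # layers (att_local_expert, att_lsh).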
        batch_coordinate = dp(get_batch_coordinate, x)

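        # hparams.layers is a comma-separated list of sublayer names; each name
        # selects one of the branches below.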
        layers = hparams.layers.strip(",").split(",")
        for layer_num, layer_type in enumerate(layers):
            with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
                if _should_preprocess(layer_type):
                    x = preprocess(x)
                if layer_type == "timing":
                    y = dp(common_attention.add_timing_signal_nd, x)
                elif layer_type == "pos_emb":
                    y = dp(common_attention.add_positional_embedding_nd,
                           x,
                           hparams.max_length,
                           name="pos_emb")
                elif layer_type == "att":
                    y = dp(
                        common_attention.multihead_attention,
                        x,
                        None,
                        bias,  # bias
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                elif layer_type == "att_grouped":
                    multiplicative_overhead = (
                        hparams.multiplicative_overhead
                        if hparams.mode == ModeKeys.TRAIN else
                        hparams.multiplicative_overhead_eval)
                    y, loss = dp(
                        common_attention.grouped_attention_multihead,
                        x,
                        x,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        num_groups=hparams.attention_num_groups,
                        memory_target_density=hparams.memory_target_density,
                        multiplicative_overhead=multiplicative_overhead,
                        make_image_summary=hparams.attention_image_summary,
                        mask_right=hparams.mask_right,
                    )
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "att_memory_efficient":
                    assert hparams.layer_preprocess_sequence == "n"
                    y = dp(
                        common_attention.multihead_self_attention_memory_efficient,
                        x, bias, hparams.num_heads)
                elif layer_type == "att_local":
                    y = dp(
                        common_attention.multihead_attention,
                        x,
                        None,
                        None,  # bias
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=("local_mask_right"
                                        if hparams.mask_right else
                                        "local_unmasked"),
                        block_length=hparams.local_attention_window,
                        block_width=hparams.local_attention_window)
                elif layer_type == "att_pseudolocal":
                    # This is an inefficient implementation of local attention, for the
                    # purpose of testing model quality.
                    def _pseudolocal_bias(x):
                        return common_attention.attention_bias_local(
                            tf.shape(x)[1], hparams.local_attention_window,
                            0 if hparams.mask_right else
                            hparams.local_attention_window)

                    pseudolocal_bias = dp(_pseudolocal_bias, x)
                    y = dp(
                        common_attention.multihead_attention,
                        x,
                        None,
                        pseudolocal_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                elif layer_type == "att_local_expert":
                    y, loss = dp(
                        common_attention.local_expert_attention,
                        x,
                        k=hparams.attention_moe_k,
                        loss_coef=hparams.attention_load_balance,
                        attention_num_experts=hparams.attention_num_experts,
                        train=hparams.mode == ModeKeys.TRAIN,
                        batch_coordinate=batch_coordinate,
                        mask_right=hparams.mask_right,
                        split_batch=bool(hparams.attention_split_batch),
                        attention_kq_size=hparams.attention_kq_size,
                        attention_v_size=hparams.attention_v_size)
                    # TODO(avaswani, epot, noam): Do we need to divide by num shards?
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "att_lsh":
                    if hparams.lsh_truncated:
                        attention_fn = common_attention.multihead_attention_sparse_truncated
                    else:
                        attention_fn = common_attention.multihead_attention_sparse_dot_prod
                    y, loss = dp(
                        attention_fn,
                        x,
                        None,
                        None,  # Bias is computed inside
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,

                        # Additional parameters
                        bi=[
                            common_attention.BatchInfo(
                                coordinates=batch_coordinate[i],
                                order=None,  # No future mask
                            ) for i in range(dp.n)
                        ],
                        use_map_fn=False,
                        experts_params=dict(nb_hyperplanes=4))
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "moe":
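                    # Mixture-of-experts feed-forward layer; its load-balancing
                    # loss is accumulated into extra_loss.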
                    y, loss = expert_utils.distributed_moe(
                        dp,
                        self._ps_devices,
                        x,
                        hparams.mode == ModeKeys.TRAIN,
                        input_size=hparams.hidden_size,
                        expert_fn=expert_fn,
                        num_experts=hparams.moe_num_experts,
                        k=hparams.moe_k,
                        loss_coef=hparams.moe_loss_coef)
                    extra_loss += loss
                elif layer_type == "ffn":
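                    # Flatten to 2-D for the feed-forward expert, then restore
                    # the original shape.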
                    y = dp(
                        expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   ffn_hidden_sizes,
                                                   hparams.hidden_size),
                        dp(expert_utils.flatten_all_but_last, x))
                    y = dp(expert_utils.reshape_like, y, x)
                elif layer_type == "conv":
                    y = dp(
                        common_layers.conv1d,
                        x,
                        hparams.hidden_size,
                        hparams.kernel_height,
                        activation=tf.nn.relu,
                        padding="SAME",
                    )
                else:
                    assert False, "unknown sublayer %s" % layer_type
                if _should_postprocess(layer_type):
                    x = postprocess(x, y)
                else:
                    x = y
        x = preprocess(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss

    def body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        if hparams.use_inputs:
            decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
            decoder_self_attention_bias = None
        else:
            targets = sharded_features["targets"]
            targets = dp(tf.squeeze, targets, 2)
            (decoder_input, decoder_self_attention_bias,
             pad_remover) = dp(attention_lm_moe_prepare_decoder, targets,
                               hparams)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, decoder_input,
               1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0

        if not hparams.use_inputs:
            # As preprocess and postprocess are called with a batch of size one
            # (all batches concatenated), make sure that batch norm is not used
            # (it should not be either way).
            assert hparams.norm_type != "batch"

            tf.logging.info(
                "Applying Padding Remover for the attention experts")

            dp_remove_pad = functools.partial(dp,
                                              remove_pad,
                                              pad_remover=pad_remover,
                                              mode=hparams.mode)
            dp_restore_pad = functools.partial(dp,
                                               restore_pad,
                                               ref_x=x,
                                               pad_remover=pad_remover,
                                               mode=hparams.mode)
        else:
            # Using identity function: No effect
            dp_remove_pad = lambda x: x
            dp_restore_pad = lambda x: x

        if hparams.attention_exp_factor != 0:
            tf.logging.info(
                "Expand/compress tokens before sending them to experts")
            dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
                expand_batch_coordinates, x, hparams.attention_exp_factor)
            dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
                common_attention.deconv_elems_1d, x,
                hparams.attention_exp_factor, hparams.attention_exp_inputdim)
            dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
                common_attention.conv_elems_1d, x,
                hparams.attention_exp_factor, l)
        else:
            dp_expand_bc = lambda x: x
            dp_expand_x = lambda x: x
            dp_compress_x = lambda x, l: x

        def print_shape(x, suffix, debug=False):
            # To help debugging, print the input/output shapes at inference and
            # eval. Inference for long sequences can take a long time, so this
            # helps to track the progress of the generation.
            if not debug and hparams.mode == ModeKeys.TRAIN:
                return x
            return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

        with tf.name_scope("batch_coordinate_preprocess"):
            batch_coordinate = dp(get_batch_coordinate, x)
            batch_coordinate = dp_remove_pad(batch_coordinate)
            batch_coordinate = dp_expand_bc(batch_coordinate)
            batch_order = dp(get_batch_coordinate, x, axis=-1)
            batch_order = dp_remove_pad(batch_order)
            batch_order = dp_expand_bc(batch_order)

        x = dp(print_shape, x, "in")

        assert hparams.batch_size >= hparams.max_length

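        # If attention_layers is set, its length determines the number of
        # layers; otherwise fall back to num_hidden_layers.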
        num_hidden_layers = (len(hparams.attention_layers)
                             or hparams.num_hidden_layers)
        for layer in range(num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):

                # Use the layer type defined in attention_layers
                if hparams.attention_layers:
                    attention_type = LAYER_SYMBOLS[
                        hparams.attention_layers[layer]]
                else:
                    attention_type = hparams.attention_type

                with tf.variable_scope("attention_{}".format(attention_type)):
                    if attention_type in [
                            AttentionType.MULTIHEAD,
                            AttentionType.MULTIHEAD_FULL
                    ]:
                        attention_dot_type = ("local_mask_right"
                                              if hparams.attention_local else
                                              "dot_product")
                        if attention_type == AttentionType.MULTIHEAD_FULL:
                            attention_dot_type = "dot_product"
                        y = dp(common_attention.multihead_attention,
                               preprocess(x),
                               None,
                               decoder_self_attention_bias,
                               hparams.attention_key_channels
                               or hparams.hidden_size,
                               hparams.attention_value_channels
                               or hparams.hidden_size,
                               hparams.hidden_size,
                               hparams.num_heads,
                               hparams.attention_dropout,
                               attention_type=attention_dot_type,
                               block_length=hparams.attention_block_length,
                               name="decoder_self_attention")
                    elif attention_type == AttentionType.SPARSE_MULTIHEAD:
                        x_in = preprocess(x)
                        x_in = dp_remove_pad(x_in)
                        y, loss_experts = dp(
                            common_attention.multihead_attention_sparse_dot_prod,
                            x_in,
                            None,
                            None,  # Bias is computed inside
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,

                            # Additional parameters
                            bi=[
                                common_attention.BatchInfo(
                                    coordinates=batch_coordinate[i],
                                    order=batch_order[i],  # No future mask
                                ) for i in range(dp.n)
                            ],
                            use_map_fn=hparams.lsh_use_map_fn,
                            experts_params=dict(
                                nb_hyperplanes=hparams.lsh_num_hyperplanes),
                        )
                        y = dp_restore_pad(y)

                        # TODO(avaswani, epot, noam): Do we need to divide by num shards?
                        extra_loss += tf.add_n(loss_experts) / dp.n
                    elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
                        x_in = preprocess(x)
                        y, loss_experts = dp(
                            common_attention.multihead_attention_sparse_truncated,
                            x_in,
                            None,
                            None,  # Bias is computed inside
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,

                            # Additional parameters
                            bi=[
                                common_attention.BatchInfo(
                                    coordinates=batch_coordinate[i],
                                    order=batch_order[i],
                                ) for i in range(dp.n)
                            ],
                            mask_right=True,
                            experts_params=dict(
                                nb_hyperplanes=hparams.lsh_num_hyperplanes),
                        )

                        # TODO(avaswani, epot, noam): Do we need to divide by num shards?
                        extra_loss += tf.add_n(loss_experts) / dp.n
                    elif attention_type == AttentionType.MEMORY_EFFICIENT:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(
                            common_attention.multihead_self_attention_memory_efficient,
                            x,
                            decoder_self_attention_bias,
                            hparams.num_heads,
                            name="decoder_self_attention")
                    elif attention_type == AttentionType.MULTIHEAD_REDUCED:
                        y = dp(
                            common_attention.multihead_self_attention_reduced,
                            preprocess(x),
                            factor=hparams.attention_red_factor,
                            reduction_type=hparams.attention_reduction_type,
                            nonlinearity=hparams.attention_nonlinearity,
                            multihead_params=dict(
                                total_key_depth=hparams.attention_key_channels
                                or hparams.hidden_size,
                                total_value_depth=hparams.attention_value_channels
                                or hparams.hidden_size,
                                num_heads=hparams.num_heads,
                                dropout_rate=hparams.attention_dropout,
                            ))
                    elif attention_type == AttentionType.LOCAL_EXPERTS:
                        x_in = preprocess(x)
                        x_in = dp_remove_pad(x_in)
                        x_in = dp_expand_x(x_in)
                        y, loss = dp(
                            common_attention.local_expert_attention,
                            x_in,
                            k=hparams.attention_moe_k,
                            loss_coef=hparams.attention_load_balance,
                            attention_num_experts=hparams.attention_num_experts,
                            train=hparams.mode == ModeKeys.TRAIN,
                            batch_coordinate=batch_coordinate,
                            mask_right=not hparams.use_inputs,
                            split_batch=bool(hparams.attention_split_batch),
                            attention_num_head=hparams.attention_num_head,
                            attention_kq_size=hparams.attention_kq_size,
                            attention_v_size=hparams.attention_v_size)
                        y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
                        y = dp_restore_pad(y)
                        # TODO(avaswani, epot, noam): Do we need to divide by num shards?
                        extra_loss += tf.add_n(loss) / dp.n
                    else:
                        raise ValueError("Only {} supported for now.".format(
                            AttentionType.get_choices()))
                    x = postprocess(x, y)
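                # Position-wise feed-forward sublayer following the attention
                # sublayer.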
                with tf.variable_scope("ffn"):
                    if hparams.memory_efficient_ffn:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_layers.conv_hidden_relu_memory_efficient,
                               x, hparams.filter_size)
                    else:
                        additional_conv_params = {}
                        if hparams.use_sepconv:
                            additional_conv_params = dict(
                                padding="LEFT",
                                # Parameters copied from the transformer model
                                kernel_size=(3, 1),
                                second_kernel_size=(31, 1),
                            )
                        y = dp(common_layers.conv_hidden_relu,
                               preprocess(x),
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout,
                               **additional_conv_params)
                    x = postprocess(x, y)
        x = preprocess(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
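
Both methods rely on the dp(...) helper bound to self._data_parallelism. The
sketch below is only a rough mental model under stated assumptions (it is not
the real implementation behind self._data_parallelism): list arguments are
treated as pre-sharded with one entry per shard, everything else is broadcast,
and each call returns one result per shard, transposed into per-output lists
when the sharded function returns a tuple. SimpleParallelism and its methods
are hypothetical names introduced here for illustration.

class SimpleParallelism(object):
    """Minimal sketch of a dp-style data-parallelism helper (assumption)."""

    def __init__(self, n):
        self.n = n  # number of shards / devices

    def __call__(self, fn, *args, **kwargs):
        def pick(arg, i):
            # Lists are assumed to be pre-sharded (one entry per shard);
            # anything else is broadcast unchanged to every shard.
            return arg[i] if isinstance(arg, list) else arg

        results = [
            fn(*[pick(a, i) for a in args],
               **{k: pick(v, i) for k, v in kwargs.items()})
            for i in range(self.n)
        ]
        # Calls such as `y, loss = dp(...)` expect one list per output when the
        # sharded function returns tuples, so transpose the results.
        if results and isinstance(results[0], tuple):
            return tuple(list(out) for out in zip(*results))
        return results

# Hypothetical usage mirroring the call pattern above:
#   dp = SimpleParallelism(n=2)
#   squeezed = dp(tf.squeeze, sharded_inputs, 2)  # sharded_inputs: 2 tensors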