def init_data(self, use_cuda):
            self.test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(4)
                turbo_transformers.set_num_threads(4)

            torch.set_grad_enabled(False)
            self.head_count = 16
            self.model_dim = 1024  # model_dim must be divisible by head_count
            self.size_per_head = self.model_dim // self.head_count

            self.query_seq_len_list = query_seq_len_list
            self.key_seq_len_list = key_seq_len_list
            # build the torch model
            self.model = MultiHeadedAttention(self.head_count, self.model_dim)
            self.model.eval()

            if use_cuda:
                self.model.to(self.test_device)

            # prepare torch input data
            self.Q_list = []
            for query_seq_len in query_seq_len_list:
                Q = torch.rand(
                    size=(
                        1,
                        query_seq_len,  #from_seq
                        self.model_dim),
                    dtype=torch.float32,
                    device=self.test_device)
                self.Q_list.append(Q)

            self.K_list = []
            self.V_list = []
            for key_seq_len in key_seq_len_list:
                K = torch.rand(
                    size=(
                        1,
                        key_seq_len,  #to_seq
                        self.model_dim),
                    dtype=torch.float32,
                    device=self.test_device)

                V = torch.rand(
                    size=(
                        1,
                        key_seq_len,  #to_seq
                        self.model_dim),
                    dtype=torch.float32,
                    device=self.test_device)
                self.K_list.append(K)
                self.V_list.append(V)

            # prepare turbo smart batch model
            self.turbo_smart_pad = turbo_transformers.MultiHeadedAttentionSmartBatch.from_onmt(
                self.model)
Example #2
class TransformerEncoderLayer(nn.Module):
    """
    A single layer of the transformer encoder.

    Args:
        d_model (int): the dimension of keys/values/queries in
                   MultiHeadedAttention, also the input size of
                   the first-layer of the PositionwiseFeedForward.
        heads (int): the number of head for MultiHeadedAttention.
        d_ff (int): the second-layer of the PositionwiseFeedForward.
        dropout (float): dropout probability(0-1.0).
    """
    def __init__(self,
                 d_model,
                 heads,
                 d_ff,
                 dropout,
                 max_relative_positions=0):
        super(TransformerEncoderLayer, self).__init__()

        self.self_attn = MultiHeadedAttention(
            heads,
            d_model,
            dropout=dropout,
            max_relative_positions=max_relative_positions)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, mask):
        """
        Args:
            inputs (FloatTensor): ``(batch_size, src_len, model_dim)``
            mask (LongTensor): ``(batch_size, src_len, src_len)``

        Returns:
            (FloatTensor):

            * outputs ``(batch_size, src_len, model_dim)``
        """
        input_norm = self.layer_norm(inputs)
        context, _ = self.self_attn(input_norm,
                                    input_norm,
                                    input_norm,
                                    mask=mask,
                                    type="self")
        out = self.dropout(context) + inputs
        return self.feed_forward(out)

    def update_dropout(self, dropout):
        self.self_attn.update_dropout(dropout)
        self.feed_forward.update_dropout(dropout)
        self.dropout.p = dropout
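
A minimal usage sketch for the layer above (not from the source): it assumes MultiHeadedAttention and PositionwiseFeedForward come from the same OpenNMT-py version these snippets target, the tensor shapes follow the forward() docstring, and d_model=512, heads=8, d_ff=2048 are illustrative values.

import torch

layer = TransformerEncoderLayer(d_model=512, heads=8, d_ff=2048, dropout=0.1)
layer.eval()

batch_size, src_len = 2, 7
inputs = torch.rand(batch_size, src_len, 512)  # (batch_size, src_len, model_dim)
# per the docstring the mask is (batch_size, src_len, src_len); True/nonzero entries
# mark positions to be masked out, so an all-False mask attends everywhere
mask = torch.zeros(batch_size, src_len, src_len, dtype=torch.bool)

with torch.no_grad():
    out = layer(inputs, mask)
print(out.shape)  # torch.Size([2, 7, 512])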
Example #4
    def from_onmt(multi_headed_attn: OnmtMultiHeadedAttention,
                  is_trans_weight: bool = False):
        attn_params = {k: v for k, v in multi_headed_attn.named_parameters()}
        if multi_headed_attn.max_relative_positions != 0:
            raise "multi_headed_attn's max_relative_positions should be 0!"

        with torch.no_grad():
            att = MultiHeadedAttention(
                *(MultiHeadedAttention.pack_parameter(multi_headed_attn,
                                                      is_trans_weight)),
                multi_headed_attn.head_count)
            return att
Example #5
 def from_onmt(multi_headed_attn: OnmtMultiHeadedAttention,
               layer_norm: TorchLayerNorm,
               is_trans_weight: bool = False):
     ln_params = {k: v for k, v in layer_norm.named_parameters()}
     attn_params = {k: v for k, v in multi_headed_attn.named_parameters()}
     with torch.no_grad():
         att = MultiHeadedAttention(
             *(MultiHeadedAttention.pack_parameter(multi_headed_attn,
                                                   is_trans_weight)),
             convert2tt_tensor(ln_params['weight']),
             convert2tt_tensor(ln_params['bias']),
             multi_headed_attn.head_count)
         return att
Example #6
    def pack_parameter(multi_headed_attn: OnmtMultiHeadedAttention,
                       is_trans_weight: Optional[bool] = False):
        # linear_keys.weight
        # linear_keys.bias
        # linear_values.weight
        # linear_values.bias
        # linear_query.weight
        # linear_query.bias
        # final_linear.weight
        # final_linear.bias
        attn_params = {k: v for k, v in multi_headed_attn.named_parameters()}
        if multi_headed_attn.max_relative_positions != 0:
            raise "multi_headed_attn's max_relative_positions should be 0!"

        # merge linear_query.weight, linear_keys.weight and linear_values.weight together as qkv.weight
        if is_trans_weight:
            qkv_weight = torch.cat((attn_params['linear_query.weight'],
                                    attn_params['linear_keys.weight'],
                                    attn_params['linear_values.weight']), 0)
            k_w = convert2tt_tensor(attn_params['linear_keys.weight'])
            v_w = convert2tt_tensor(attn_params['linear_values.weight'])
            q_w = convert2tt_tensor(attn_params['linear_query.weight'])
            f_w = convert2tt_tensor(attn_params['final_linear.weight'])
        else:
            qkv_weight = torch.clone(
                torch.t(
                    torch.cat((attn_params['linear_query.weight'],
                               attn_params['linear_keys.weight'],
                               attn_params['linear_values.weight']),
                              0).contiguous()).contiguous())
            k_w = convert2tt_tensor(
                torch.clone(
                    torch.t(attn_params['linear_keys.weight']).contiguous()))
            v_w = convert2tt_tensor(
                torch.clone(
                    torch.t(attn_params['linear_values.weight']).contiguous()))
            q_w = convert2tt_tensor(
                torch.clone(
                    torch.t(attn_params['linear_query.weight']).contiguous()))
            f_w = convert2tt_tensor(
                torch.clone(
                    torch.t(attn_params['final_linear.weight']).contiguous()))

        qkv_bias = torch.cat(
            (attn_params['linear_query.bias'], attn_params['linear_keys.bias'],
             attn_params['linear_values.bias']), 0)
        return (k_w, convert2tt_tensor(attn_params['linear_keys.bias']), v_w,
                convert2tt_tensor(attn_params['linear_values.bias']), q_w,
                convert2tt_tensor(attn_params['linear_query.bias']), f_w,
                convert2tt_tensor(attn_params['final_linear.bias']),
                convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias))
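
# Illustrative sketch (not from the original file): pack_parameter above returns the
# four projection weights/biases plus the fused qkv weight/bias as a 10-tuple, in the
# order of its return statement. head_count=16 and model_dim=1024 are assumed values.
onmt_attn = OnmtMultiHeadedAttention(16, 1024)  # OpenNMT-py attention: (head_count, model_dim)
onmt_attn.eval()
(k_w, k_b, v_w, v_b,
 q_w, q_b, f_w, f_b,
 qkv_w, qkv_b) = MultiHeadedAttention.pack_parameter(onmt_attn, is_trans_weight=False)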
        def init_data(self, use_cuda):
            self.test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(4)
                turbo_transformers.set_num_threads(4)

            torch.set_grad_enabled(False)
            self.head_count = 16
            self.model_dim = 1024  # model_dim must be divisible by head_count
            self.size_per_head = self.model_dim // self.head_count

            onmt_multi_headed_attention = MultiHeadedAttention(
                self.head_count, self.model_dim)
            onmt_multi_headed_attention.eval()
            torch_layernorm = torch.nn.LayerNorm(self.model_dim, eps=1e-6)
            torch_layernorm.eval()

            if use_cuda:
                onmt_multi_headed_attention.to(self.test_device)
                torch_layernorm.to(self.test_device)

            K = torch.rand(
                size=(
                    batch_size,
                    key_seq_len,  #from_seq
                    self.model_dim),
                dtype=torch.float32,
                device=self.test_device)
            V = torch.rand(size=(batch_size, key_seq_len, self.model_dim),
                           dtype=torch.float32,
                           device=self.test_device)
            Q = torch.rand(
                size=(
                    batch_size,
                    query_seq_len,  #to_seq
                    self.model_dim),
                dtype=torch.float32,
                device=self.test_device)

            turbo_attn_trans = turbo_transformers.MultiHeadedAttention.from_onmt(
                onmt_multi_headed_attention,
                torch_layernorm,
                is_trans_weight=True)
            turbo_attn_notrans = turbo_transformers.MultiHeadedAttention.from_onmt(
                onmt_multi_headed_attention,
                torch_layernorm,
                is_trans_weight=False)

            if with_quantize_dynamic and not use_cuda:
                self.q_onmt_multi_headed_attention = torch.quantization.quantize_dynamic(
                    onmt_multi_headed_attention)
            return onmt_multi_headed_attention, torch_layernorm, turbo_attn_trans, turbo_attn_notrans, Q, K, V
    class TestMultiHeadedAttentionSmartBatch(unittest.TestCase):
        def init_data(self, use_cuda):
            self.test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(4)
                turbo_transformers.set_num_threads(4)

            torch.set_grad_enabled(False)
            self.head_count = 16
            self.model_dim = 1024  # model_dim must be divisible by head_count
            self.size_per_head = self.model_dim // self.head_count

            self.query_seq_len_list = query_seq_len_list
            self.key_seq_len_list = key_seq_len_list
            # build the torch model
            self.model = MultiHeadedAttention(self.head_count, self.model_dim)
            self.model.eval()

            if use_cuda:
                self.model.to(self.test_device)

            # prepare torch input data
            self.Q_list = []
            for query_seq_len in query_seq_len_list:
                Q = torch.rand(
                    size=(
                        1,
                        query_seq_len,  #from_seq
                        self.model_dim),
                    dtype=torch.float32,
                    device=self.test_device)
                self.Q_list.append(Q)

            self.K_list = []
            self.V_list = []
            for key_seq_len in key_seq_len_list:
                K = torch.rand(
                    size=(
                        1,
                        key_seq_len,  #to_seq
                        self.model_dim),
                    dtype=torch.float32,
                    device=self.test_device)

                V = torch.rand(
                    size=(
                        1,
                        key_seq_len,  #to_seq
                        self.model_dim),
                    dtype=torch.float32,
                    device=self.test_device)
                self.K_list.append(K)
                self.V_list.append(V)

            # prepare turbo smart batch model
            self.turbo_smart_pad = turbo_transformers.MultiHeadedAttentionSmartBatch.from_onmt(
                self.model)

        def check_torch_and_turbo(self, use_cuda, num_iter=1):
            self.init_data(use_cuda)

            device = "GPU" if use_cuda else "CPU"
            info = f"\"({device}, {set_layer_cache}, {pre_layernorm}, {post_add_input}, {attn_type})\""

            # TODO(jiaruifang) test scenario where mask is not None.
            attention_mask = None
            layer_cache_torch = None

            res_list = []
            for Q, K, V in zip(self.Q_list, self.K_list, self.V_list):
                res, _ = self.model(
                    Q if attn_type == "self" else K,  #K,
                    Q if attn_type == "self" else V,  #V,
                    Q,
                    mask=attention_mask,
                    layer_cache=None,  #layer_cache_torch
                    attn_type=attn_type)
                res_list.append(res)

            # concat res_list together
            for i in range(len(res_list)):
                if i == 0:
                    concat_res = res_list[i]
                else:
                    concat_res = torch.cat((concat_res, res_list[i]), 1)

            self.assertTrue(
                concat_res.size()[1] == sum(self.query_seq_len_list))

            # concat K, Q, V together
            for i in range(len(self.query_seq_len_list)):
                if i == 0:
                    concat_Q = self.Q_list[i]
                    concat_K = self.K_list[i]
                    concat_V = self.V_list[i]
                else:
                    concat_Q = torch.cat((concat_Q, self.Q_list[i]), 1)
                    concat_K = torch.cat((concat_K, self.K_list[i]), 1)
                    concat_V = torch.cat((concat_V, self.V_list[i]), 1)

            self.assertTrue(concat_Q.size()[1] == sum(self.query_seq_len_list))
            self.assertTrue(concat_K.size()[1] == sum(self.key_seq_len_list))
            self.assertTrue(concat_V.size()[1] == sum(self.key_seq_len_list))
            self.assertTrue(attn_type == "self" or attn_type == "context")

            pad_res, _ = self.turbo_smart_pad(concat_K,
                                              concat_V,
                                              concat_Q,
                                              self.query_seq_len_list,
                                              self.key_seq_len_list,
                                              mask=attention_mask,
                                              layer_cache=None,
                                              attn_type=attn_type)

            diff = pad_res - concat_res
            # print(diff)
            print(torch.max(diff))
            self.assertTrue(
                numpy.allclose(pad_res.cpu(),
                               concat_res.cpu(),
                               atol=1e-3,
                               rtol=1e-3))

        def test_multi_headed_attention(self):
            # self.check_torch_and_turbo(use_cuda=False)
            if torch.cuda.is_available() and \
                    turbo_transformers.config.is_compiled_with_cuda():
                self.check_torch_and_turbo(use_cuda=True)
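
The SmartBatch test above runs the torch model once per (Q, K, V) pair, concatenates the per-sequence outputs along the sequence dimension, and compares them against a single smart-batched turbo call. As shown, the snippets rely on names the original file defines in a test-factory closure (query_seq_len_list, key_seq_len_list, attn_type, and the flags referenced in the info string). A hypothetical set of module-level definitions, which would sit at the top of the file before the classes, might look like the following; the values are illustrative, not from the source, and the onmt import path assumes OpenNMT-py's module layout:

import unittest
import numpy
import torch
import turbo_transformers
from onmt.modules.multi_headed_attn import MultiHeadedAttention

# Example closure values (assumptions, not from the original test factory).
query_seq_len_list = [7, 3]
key_seq_len_list = [9, 8]
attn_type = "context"
set_layer_cache = pre_layernorm = post_add_input = False

if __name__ == '__main__':
    unittest.main()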