def __init__(self,
                 batch_size,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 max_relative_position=16):
        """Create the operators and constants used by relative-position values attention.

        Args:
            batch_size (int): Batch size; multiplied by ``num_attention_heads``
                to form ``self.batch_num`` and used in ``self.shp_return``.
            from_seq_length (int): Length of the "from" (query) sequence; used
                in ``self.shp_return``.
            to_seq_length (int): Length of the "to" sequence; passed as
                ``length`` to ``RelaPosEmbeddingsGenerator``.
            num_attention_heads (int): Number of attention heads. Default: 1.
            size_per_head (int): Width of each head; sets the ``depth`` of the
                relative-position embeddings and the ``1 / sqrt(size_per_head)``
                scale. Default: 512.
            use_one_hot_embeddings (bool): Forwarded to
                ``RelaPosEmbeddingsGenerator``. Default: False.
            initializer_range (float): Forwarded to
                ``RelaPosEmbeddingsGenerator``. Default: 0.02.
            do_return_2d_tensor (bool): Selects ``self.shp_return``:
                ``(batch_size * from_seq_length, num_attention_heads * size_per_head)``
                when True, otherwise
                ``(batch_size, from_seq_length, num_attention_heads * size_per_head)``.
                Default: False.
            use_relative_positions (bool): Stored as an attribute; not consulted
                inside ``__init__``. Default: False.
            dtype: Dtype of the ``self.scores_mul`` scaling tensor.
                Default: ``mstype.float32``.
            compute_type: Destination dtype of ``self.cast_compute_type``.
                Default: ``mstype.float32``.
            max_relative_position (int): Forwarded as ``max_relative_position``
                to ``RelaPosEmbeddingsGenerator``. Default: 16 (the value that
                was previously hard-coded).
        """
        super(BertAttentionRelativePositionValues, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.use_relative_positions = use_relative_positions
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        # Axis permutations; presumably consumed by construct(), which is not
        # visible here -- confirm before changing.
        self.trans_shape_position = (1, 2, 0, 3)
        self.trans_shape_relative = (2, 0, 1, 3)

        # Attention-score scale 1 / sqrt(size_per_head) as a one-element Tensor.
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))],
                                 dtype=dtype)
        self.trans_shape = (0, 2, 1, 3)

        self.reshape = P.Reshape()
        # A single Mul primitive is enough; it is not re-created later.
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.batch_num = batch_size * num_attention_heads
        self.matmul = P.BatchMatMul()
        self.do_return_2d_tensor = do_return_2d_tensor
        if self.do_return_2d_tensor:
            self.shp_return = (batch_size * from_seq_length,
                               num_attention_heads * size_per_head)
        else:
            self.shp_return = (batch_size, from_seq_length,
                               num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        self._generate_relative_positions_embeddings = \
            RelaPosEmbeddingsGenerator(length=self.to_seq_length,
                                       depth=self.size_per_head,
                                       max_relative_position=max_relative_position,
                                       initializer_range=initializer_range,
                                       use_one_hot_embeddings=use_one_hot_embeddings)
        self.fill = P.Fill()
        self.type = P.DType()
        self.cast = P.Cast()
Ejemplo n.º 2
0
    def __init__(self,
                 batch_size,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 max_relative_position=16):
        """Create the operators and constants used by relative-position keys attention.

        Args:
            batch_size (int): Batch size; multiplied by ``num_attention_heads``
                to form ``self.batch_num``.
            from_seq_length (int): Length of the "from" (query) sequence.
            to_seq_length (int): Length of the "to" sequence; passed as
                ``length`` to ``RelaPosEmbeddingsGenerator``.
            num_attention_heads (int): Number of attention heads. Default: 1.
            size_per_head (int): Width of each head; sets the ``depth`` of the
                relative-position embeddings and the ``1 / sqrt(size_per_head)``
                scale. Default: 512.
            use_one_hot_embeddings (bool): Forwarded to
                ``RelaPosEmbeddingsGenerator``. Default: False.
            initializer_range (float): Forwarded to
                ``RelaPosEmbeddingsGenerator``. Default: 0.02.
            use_relative_positions (bool): Stored as an attribute; not consulted
                inside ``__init__``. Default: False.
            dtype: Accepted for signature compatibility; not used inside
                ``__init__``. Default: ``mstype.float32``.
            compute_type: Destination dtype of ``self.cast_compute_type``.
                Default: ``mstype.float32``.
            max_relative_position (int): Forwarded as ``max_relative_position``
                to ``RelaPosEmbeddingsGenerator``. Default: 16 (the value that
                was previously hard-coded).
        """
        super(BertAttentionRelativePositionKeys, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.use_relative_positions = use_relative_positions
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        # Axis permutations; presumably consumed by construct(), which is not
        # visible here -- confirm before changing.
        self.trans_shape_position = (1, 2, 0, 3)
        self.trans_shape_relative = (2, 0, 1, 3)

        # Attention-score scale 1 / sqrt(size_per_head), kept as a plain Python
        # float (not a Tensor).
        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))

        self.reshape = P.Reshape()
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        # Batched matmul with the second operand transposed.
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.batch_num = batch_size * num_attention_heads
        self.cast = P.Cast()

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        self._generate_relative_positions_embeddings = \
            RelaPosEmbeddingsGenerator(length=self.to_seq_length,
                                       depth=self.size_per_head,
                                       max_relative_position=max_relative_position,
                                       initializer_range=initializer_range,
                                       use_one_hot_embeddings=use_one_hot_embeddings)
 }, {
     'id':
     'RelaPosMatrixGenerator',
     'group':
     'RelaPosMatrixGenerator',
     'block':
     RelaPosMatrixGenerator(length=128, max_relative_position=16)
 }, {
     'id':
     'RelaPosEmbeddingsGenerator',
     'group':
     'RelaPosEmbeddingsGenerator',
     'block':
     RelaPosEmbeddingsGenerator(length=128,
                                depth=64,
                                max_relative_position=16,
                                initializer_range=0.02,
                                use_one_hot_embeddings=False)
 }, {
     'id':
     'BertAttention',
     'group':
     'BertAttention',
     'block':
     BertAttention(batch_size=1,
                   from_tensor_width=1024,
                   to_tensor_width=1024,
                   from_seq_length=128,
                   to_seq_length=128,
                   num_attention_heads=16,
                   size_per_head=64,