Example #1
0
    def forward(self, x):
        """Self-attention over the sequence axis with a residual and ReLU.

        Steps: project ``x`` with ``self.init_map``, add
        ``self.position_encoding(eye)``, build values/keys/queries, attend,
        add the projected input (taken before the position encoding) back,
        and apply ReLU.

        Args:
            x: input tensor; the inline comments give its layout as (b, l, i)
                -- batch, length, input features (assumed; confirm with callers).

        Returns:
            ReLU(attended values + self.init_map(x)); (b, l, e) according to
            the inline shape comments.

        NOTE(review): ``eye`` is not defined in this method. It must resolve
        at module scope elsewhere in the file, otherwise this raises
        NameError. Confirm what it is meant to be.
        NOTE(review): ``querys`` is produced by ``self.key_mapping``, the same
        projection used for ``keys``. Queries therefore equal keys and the
        attention logits are symmetric. Confirm this weight sharing is
        intended rather than a missing query projection.
        """
        # x: bxlxi
        x = self.init_map(x)  # bxlxe
        ori = x  # residual branch: projected input, before the position encoding
        p = self.position_encoding(eye)
        x = x + p

        values = self.value_mapping(x)  #bxlxe
        keys = self.key_mapping(x)  #bxlxe
        querys = self.key_mapping(x)  # NOTE(review): reuses key_mapping -- see docstring

        #print('transformer', values.shape, keys.shape, querys.shape)

        # Logits are querys @ keys^T, with no 1/sqrt(e) scaling. Softmax runs
        # over axis 1, so within each batch every column of ``attention``
        # sums to 1.
        attention = F.softmax(F.batched_matrix_mul(querys,
                                                   keys.dimshuffle(0, 2, 1)),
                              axis=1)  #bxlxl
        #print(attention[0])
        #print(attention[0].sum(axis=0))
        #print('attention', attention.shape)
        # values^T @ attention: output position j is
        # sum_i attention[:, i, j] * values[:, i, :], i.e. it mixes the values
        # with the weights in column j of ``attention``.
        out = F.batched_matrix_mul(values.dimshuffle(0, 2, 1), attention)

        out = out.dimshuffle(0, 2, 1)  # swap back to (b, l, e) layout
        out = out + ori  # residual connection with the projected input
        out = F.relu(out)
        #a,b,c = out.shape[0], out.shape[1], out.shape[2]
        #tmp = out.reshape(-1, self.key_embedding)
        #i = tmp.shape[0]
        #out = self.norm(tmp)
        #out = out.reshape(a,b,c)

        return out
    def forward(self, x, position):
        """Mix ``x`` along the sequence axis using position-only attention.

        Attention weights are computed from ``position`` alone, through
        ``self.query_mapping1`` and ``self.key_mapping1``. ``x`` is used
        unmodified as the values, so its content does not affect where each
        position attends.

        Args:
            x: value tensor; assumed (b, l, e) from the original shape
                comments -- confirm with callers.
            position: positional features fed to the query/key projections.

        Returns:
            Tensor whose row j is sum_i weights[:, j, i] * x[:, i, :]. The
            weights are softmax-normalised over the key axis (axis 2).
        """
        q = self.query_mapping1(position)
        k = self.key_mapping1(position)

        # Logits (b, l_q, l_k). Softmax over the last axis makes each query's
        # weights sum to 1.
        scores = F.batched_matrix_mul(q, k.dimshuffle(0, 2, 1))
        weights = F.softmax(scores, axis=2)

        # (x^T @ weights^T)^T == weights @ x. This evaluates it in the
        # transposed form, exactly as the original implementation did.
        mixed_t = F.batched_matrix_mul(x.dimshuffle(0, 2, 1),
                                       weights.dimshuffle(0, 2, 1))
        return mixed_t.dimshuffle(0, 2, 1)
Example #3
0
    def forward(self, x):
        """Add channel-attention and frame-attention residuals to ``x``.

        ``x`` is treated as ``self.frames`` (N) frames stacked along the
        channel axis, i.e. (B, N*c, H, W) with c = channels // N.

        Two branches are computed:
        - Channel branch: ``self.A2``/``self.B2``/``self.D2`` give a (B, c, c)
          attention between per-frame channels, pooled over frames and pixels.
        - Frame branch: ``self.A3``/``self.B3``/``self.D3`` give a (B, N, N)
          attention between frames.

        The reshapes assume every projection outputs N*c channels with the
        same H, W as ``x``.

        Args:
            x: 4-D tensor (B, channels, H, W).

        Returns:
            ``x + E2 + E3``, where E2 and E3 are the two branch outputs
            reshaped to (B, N*c, H, W).
        """
        batch, channels, height, width = x.shape
        n_frames = self.frames
        c = channels // n_frames  # channels per frame
        hw = height * width

        def per_frame(t):
            # (B, N*c, H, W) -> (B, N, c, H, W)
            return t.reshape(batch, n_frames, c, height, width)

        def frame_major(t):
            # (B, N*c, H, W) -> (B, N, c*H*W). Same result as reshaping
            # through the 5-D view first, since reshape is row-major.
            return t.reshape(batch, n_frames, c * hw)

        # Channel-branch operands: (B, c, N*H*W) and (B, N*H*W, c).
        a2 = F.dimshuffle(per_frame(self.A2(x)), (0, 2, 1, 3, 4)).reshape(batch, c, n_frames * hw)
        b2 = F.dimshuffle(per_frame(self.B2(x)), (0, 1, 3, 4, 2)).reshape(batch, n_frames * hw, c)
        # Frame-branch operands: (B, N, c*H*W) and (B, c*H*W, N).
        a3 = frame_major(self.A3(x))
        b3 = F.dimshuffle(frame_major(self.B3(x)), (0, 2, 1))

        # Values for each branch.
        d2 = F.dimshuffle(per_frame(self.D2(x)), (0, 2, 1, 3, 4)).reshape(batch, c, n_frames * hw)
        d3 = frame_major(self.D3(x))

        channel_attn = F.softmax(F.batched_matrix_mul(a2, b2), axis=-1)  # (B, c, c)
        frame_attn = F.softmax(F.batched_matrix_mul(a3, b3), axis=-1)  # (B, N, N)

        # Restore the frame-major channel layout (B, N*c, H, W) before adding.
        e2 = F.batched_matrix_mul(channel_attn, d2).reshape(batch, c, n_frames, height, width)
        e2 = F.dimshuffle(e2, (0, 2, 1, 3, 4)).reshape(batch, n_frames * c, height, width)
        e3 = F.batched_matrix_mul(frame_attn, d3).reshape(batch, n_frames * c, height, width)
        return x + e2 + e3
Example #4
0
def matmul(a, b, transpose_b=None):
    """Multiply ``a`` by ``b`` over their last two axes.

    The rank of ``b`` selects the kernel:

    * rank == 3: ``F.batched_matrix_mul`` directly;
    * rank < 3: ``F.matrix_mul``;
    * rank > 3: the leading ``rank - 2`` (batch) axes of ``a`` are flattened
      into one, both operands are multiplied with ``F.batched_matrix_mul``,
      and the batch axes are restored afterwards.

    Args:
        a: left operand. On the rank > 3 path its leading axes define the
            batch and are assumed to match ``b``'s; this is not checked.
        b: right operand. Its rank drives the dispatch above.
        transpose_b: when truthy, swap the last two axes of ``b`` first via
            the module-level ``transpose`` helper.

    Returns:
        The product. On the rank > 3 path its shape is
        ``a.shape[:rank - 1] + b.shape[rank - 1:rank]``, where ``b`` is taken
        after the optional transpose.
    """
    rank = len(b.shape)
    if transpose_b:
        b = transpose(b, rank - 1, rank - 2)

    if rank == 3:
        return F.batched_matrix_mul(a, b)
    if rank < 3:
        return F.matrix_mul(a, b)

    # rank > 3: fold every batch axis into a single leading axis.
    a_dims = list(a.shape)
    b_dims = list(b.shape)
    flat = 1
    for extent in a_dims[:rank - 2]:
        flat *= extent
    a_flat = a.reshape(*([flat] + a_dims[rank - 2:rank]))
    b_flat = b.reshape(*([flat] + b_dims[rank - 2:rank]))
    product = F.batched_matrix_mul(a_flat, b_flat)
    return product.reshape(*(a_dims[:rank - 1] + b_dims[rank - 1:rank]))