def forward(self, x):
    """Apply one self-attention step followed by a residual add and ReLU.

    Shape notes are carried over from the original inline annotations:
    ``x`` is annotated as b x l x i, and the output of ``init_map`` as
    b x l x e.

    Args:
        x: input tensor handed to ``self.init_map``.

    Returns:
        ``relu(attended + embedded)``, where ``embedded`` is the output of
        ``self.init_map`` before the positional term is added.
    """
    embedded = self.init_map(x)  # bxlxe
    residual = embedded

    # NOTE(review): `eye` is not defined inside this method; presumably it
    # is a module-level name -- confirm it exists where this class lives.
    embedded = embedded + self.position_encoding(eye)

    values = self.value_mapping(embedded)  # bxlxe
    keys = self.key_mapping(embedded)  # bxlxe
    # NOTE(review): the queries are projected with `key_mapping` too, not a
    # separate query projection -- confirm this sharing is intentional.
    querys = self.key_mapping(embedded)

    scores = F.batched_matrix_mul(querys, keys.dimshuffle(0, 2, 1))
    # Normalised along axis 1, so each column of the bxlxl map sums to one.
    attention = F.softmax(scores, axis=1)  # bxlxl

    # (values^T @ attention)^T: mix the rows of `values` with those weights.
    attended = F.batched_matrix_mul(values.dimshuffle(0, 2, 1), attention)
    attended = attended.dimshuffle(0, 2, 1)

    # A reshape -> self.norm -> reshape step used to follow the ReLU in the
    # original source but was commented out; it remains disabled here.
    return F.relu(attended + residual)
def forward(self, x, position):
    """Mix ``x`` with attention weights derived only from ``position``.

    Queries and keys are both projected from ``position``; ``x`` is used
    unprojected as the values (annotated b x l x e in the original).

    Args:
        x: value tensor (annotated b x l x i in the original signature).
        position: tensor fed to ``self.query_mapping1`` and
            ``self.key_mapping1``.

    Returns:
        The result of ``(x^T @ attention^T)^T`` computed per batch element.
    """
    querys = self.query_mapping1(position)
    keys = self.key_mapping1(position)

    scores = F.batched_matrix_mul(querys, keys.dimshuffle(0, 2, 1))
    attention = F.softmax(scores, axis=2)  # bxlxl

    values_t = x.dimshuffle(0, 2, 1)
    mixed = F.batched_matrix_mul(values_t, attention.dimshuffle(0, 2, 1))
    return mixed.dimshuffle(0, 2, 1)
def forward(self, x):
    """Add channel-wise and frame-wise attention outputs to the input.

    ``x`` is unpacked as (B, C, H, W); its channel axis is treated as
    ``self.frames`` groups of ``C // self.frames`` channels each. Two
    attention maps are built from the outputs of the sub-modules
    ``A2/B2/D2`` (per-channel, [B, C', C']) and ``A3/B3/D3`` (per-frame,
    [B, N, N]).

    Args:
        x: 4-D input tensor.

    Returns:
        ``x + E2 + E3``, where both attention outputs are reshaped back to
        (B, N * C', H, W).
    """
    batch, channels, height, width = x.shape
    frames = self.frames
    per_frame = channels // frames

    split_shape = (batch, frames, per_frame, height, width)
    frame_major = (0, 2, 1, 3, 4)  # swaps the frame and channel axes
    positions = frames * height * width
    frame_size = per_frame * height * width

    # Projections, evaluated in the same order as the original code.
    chan_query = F.dimshuffle(self.A2(x).reshape(*split_shape), frame_major)
    chan_query = chan_query.reshape(batch, per_frame, positions)
    chan_key = F.dimshuffle(self.B2(x).reshape(*split_shape), (0, 1, 3, 4, 2))
    chan_key = chan_key.reshape(batch, positions, per_frame)

    frame_query = self.A3(x).reshape(*split_shape)
    frame_query = frame_query.reshape(batch, frames, frame_size)
    frame_key = self.B3(x).reshape(*split_shape)
    frame_key = F.dimshuffle(
        frame_key.reshape(batch, frames, frame_size), (0, 2, 1)
    )

    chan_value = F.dimshuffle(self.D2(x).reshape(*split_shape), frame_major)
    chan_value = chan_value.reshape(batch, per_frame, positions)
    frame_value = self.D3(x).reshape(*split_shape)
    frame_value = frame_value.reshape(batch, frames, frame_size)

    chan_attention = F.softmax(
        F.batched_matrix_mul(chan_query, chan_key), axis=-1
    )  # [B, C, C]
    frame_attention = F.softmax(
        F.batched_matrix_mul(frame_query, frame_key), axis=-1
    )  # [B, N, N]

    # Channel branch: restore frame-major layout before merging N and C.
    chan_out = F.batched_matrix_mul(chan_attention, chan_value)
    chan_out = chan_out.reshape(batch, per_frame, frames, height, width)
    chan_out = F.dimshuffle(chan_out, frame_major)
    chan_out = chan_out.reshape(batch, frames * per_frame, height, width)

    # Frame branch is already frame-major, so one reshape is enough.
    frame_out = F.batched_matrix_mul(frame_attention, frame_value)
    frame_out = frame_out.reshape(batch, frames * per_frame, height, width)

    return x + chan_out + frame_out
def matmul(a, b, transpose_b=None):
    """Multiply ``a`` by ``b``, dispatching on the rank of ``b``.

    * rank below 3: plain ``F.matrix_mul``.
    * rank 3: ``F.batched_matrix_mul``.
    * rank above 3: all leading axes are flattened into one batch axis,
      the batched product is taken, and the leading axes are restored.

    Args:
        a: left operand.
        b: right operand; its rank selects the code path.
        transpose_b: when truthy, ``b`` is first passed through
            ``transpose`` with its last two axis indices (presumably
            swapping them -- confirm against that helper).

    Returns:
        The product tensor. On the rank > 3 path its shape is
        ``a.shape[:-1] + b.shape[-1:]`` (``b``'s shape taken after the
        optional transpose).
    """
    rank = len(b.shape)
    if transpose_b:
        b = transpose(b, rank - 1, rank - 2)

    if rank < 3:
        return F.matrix_mul(a, b)
    if rank == 3:
        return F.batched_matrix_mul(a, b)

    lhs_shape = list(a.shape)
    rhs_shape = list(b.shape)

    # The flattened batch size comes from `a`'s leading axes and is reused
    # for `b`, so both operands must hold the same number of matrices.
    flat_batch = 1
    for extent in lhs_shape[:rank - 2]:
        flat_batch *= extent

    lhs = a.reshape(*([flat_batch] + lhs_shape[rank - 2:rank]))
    rhs = b.reshape(*([flat_batch] + rhs_shape[rank - 2:rank]))
    product = F.batched_matrix_mul(lhs, rhs)
    return product.reshape(*(lhs_shape[:rank - 1] + rhs_shape[rank - 1:rank]))