def forward(self, mid, ref):
    """Gather the K best-matching reference patches for every pixel of ``mid``.

    A cosine-similarity cost volume between ``mid`` and ``ref`` is computed over
    a (2d+1) x (2d+1) displacement window; for each pixel the top-K displacements
    are used to gather the corresponding features from the (padded) reference
    map, and the K gathered maps are concatenated and fused by ``self.conv``.

    Args:
        mid (Tensor): query feature map, shape [B, C, H, W].
        ref (Tensor): reference feature map, shape [B, C, H, W]
            (padded internally by ``compute_cost_volume``).

    Returns:
        Tensor: output of ``self.conv`` applied to the K concatenated
        gathered maps, input shape [B, K*C, H, W].
    """
    B, C, H, W = mid.shape
    # L2-normalize along channels so the cost volume is cosine similarity.
    mid = F.normalize(mid, p=2, axis=1)
    ref = F.normalize(ref, p=2, axis=1)
    # cost_volume: [B, (2d+1)**2, H, W]. The returned ref is spatially
    # padded by d on each side (it is reshaped below with H+2d, W+2d).
    cost_volume, ref = compute_cost_volume(
        mid, ref, max_displacement=self.d)
    # [B, H, W, (2d+1)**2] -> [B*H*W, (2d+1)**2] for a per-pixel top-k.
    cost_volume = F.dimshuffle(cost_volume, (0, 2, 3, 1))
    cost_volume = cost_volume.reshape((-1, (2 * self.d + 1)**2))
    # Indices of the K best displacements per pixel. [B*H*W, K]
    indices = F.top_k(cost_volume, k=self.K, descending=True)[1]
    del cost_volume

    ref_list = []  # collects K gathered maps, each [B, C, H, W]
    # Per-pixel (row, col) coordinates on the unpadded H x W grid.
    origin_i_j = F.arange(0, H * W, 1)  # float32
    origin_i = F.floor(origin_i_j / W)  # (H*W,)
    origin_j = F.mod(origin_i_j, W)     # (H*W,)
    del origin_i_j

    # Flatten the padded reference for gather; its spatial extent is
    # (H + 2d) x (W + 2d), so the row stride of the flat layout is W + 2d.
    W_pad = W + 2 * self.d
    ref = ref.reshape((B, C, (H + 2 * self.d) * W_pad))
    for k in range(self.K):
        index = indices[:, k].reshape((-1, H * W))  # [B, H*W]
        # Decompose the displacement index into (di, dj) in [0, 2d] and add
        # the pixel's own coordinates -> absolute position in the padded map.
        index_i = F.floor(index / (2 * self.d + 1)) + origin_i  # [B, H*W]
        index_j = F.mod(index, (2 * self.d + 1)) + origin_j     # [B, H*W]
        # BUG FIX: the flat index must use the padded width (W + 2d), not W —
        # ref was padded by compute_cost_volume and reshaped accordingly above;
        # using W here gathered the wrong pixels for every row but the first.
        index = (index_i * W_pad + index_j).astype('int32')     # [B, H*W]
        index = F.add_axis(index, axis=1)                       # [B, 1, H*W]
        index = F.broadcast_to(index, (B, C, H * W))
        output = F.gather(ref, axis=2, index=index)             # [B, C, H*W]
        ref_list.append(output.reshape((B, C, H, W)))
    return self.conv(F.concat(ref_list, axis=1))
def inference_func(images):
    """Compute test-time embeddings with horizontal-flip augmentation.

    The embedding of the original batch and of its horizontally mirrored
    copy are summed, then L2-normalized along the channel axis — the
    classic mirror-augment trick for face-recognition evaluation.
    """
    model.eval()
    mirrored = images[:, :, :, ::-1]  # flip along the width axis
    fused = (model.forward_embedding_only(images)
             + model.forward_embedding_only(mirrored))
    return F.normalize(fused, axis=1)
def forward_embedding_only(self, images):
    """Run a forward pass that yields only the embedding (no loss).

    Expected to be useful during evaluation.

    Args:
        images (Tensor): preprocessed images (shape: n * 3 * 112 * 112).

    Returns:
        embedding (Tensor): L2-normalized embedding.
    """
    # Optional spatial-transformer alignment before feature extraction.
    x = self.stn(images) if self.use_stn else images
    x = self.head(self.backbone(x))
    return F.normalize(x, axis=1)
def forward(self, embedding):
    """Return cosine-similarity logits between embeddings and class weights.

    The class weight rows are L2-normalized here; ``embedding`` is assumed
    to have been normalized already by the caller, so the matmul yields
    cosine similarities.
    """
    normalized_weight = F.normalize(self.weight, axis=1)
    return F.matmul(embedding, normalized_weight.transpose(1, 0))