예제 #1
0
 def forward(self, mid, ref):
     """Gather the features of the top-K matches in ``ref`` for every pixel of ``mid``.

     Both inputs are L2-normalized along axis 1, a local cost volume with
     maximum displacement ``self.d`` is built by ``compute_cost_volume``, and
     for every pixel the ``self.K`` highest-scoring displacements are kept.
     The corresponding feature vectors are gathered from the padded ``ref``
     returned by ``compute_cost_volume``. The K gathered maps are then
     concatenated along axis 1 and passed through ``self.conv``.

     Args:
         mid: 4-D tensor unpacked as [B, C, H, W].
         ref: tensor with C channels; assumed to have the same spatial size
             as ``mid`` (TODO confirm against compute_cost_volume).

     Returns:
         ``self.conv`` applied to a [B, K*C, H, W] tensor.
     """
     B, C, H, W = mid.shape
     win = 2 * self.d + 1  # side length of the (2d+1)x(2d+1) search window
     padded_h = H + 2 * self.d
     padded_w = W + 2 * self.d  # row stride of the padded ref used for gathering
     mid = F.normalize(mid, p=2, axis=1)
     ref = F.normalize(ref, p=2, axis=1)
     cost_volume, ref = compute_cost_volume(
         mid, ref, max_displacement=self.d)  # [B, win**2, H, W]
     cost_volume = F.dimshuffle(cost_volume, (0, 2, 3, 1))
     cost_volume = cost_volume.reshape((-1, win**2))  # [B*H*W, win**2]
     # indices of the K best-scoring displacements for every pixel
     indices = F.top_k(cost_volume, k=self.K,
                       descending=True)[1]  # [B*H*W, K]
     del cost_volume  # free the cost volume before the gather loop
     ref_list = []  # K tensors of shape [B, C, H, W]
     origin_i_j = F.arange(0, H * W, 1)  # flat pixel ids (float32)
     origin_i = F.floor(origin_i_j / W)  # row of each pixel, (H*W, )
     origin_j = F.mod(origin_i_j, W)  # column of each pixel, (H*W, )
     del origin_i_j
     # flatten the spatial dims of the padded ref: [B, C, (H+2d)*(W+2d)]
     ref = ref.reshape((B, C, padded_h * padded_w))
     for i in range(self.K):
         index = indices[:, i]  # [B*H*W, ]
         index = index.reshape((-1, H * W))  # [B, H*W]
         # Decode the displacement id row-major into (di, dj), each in [0, 2d],
         # and add the pixel position. This assumes compute_cost_volume pads
         # ref by d on every side (consistent with the reshape above), so the
         # result is directly a coordinate in the padded map.
         index_i = F.floor(index / win) + origin_i  # [B, H*W]
         index_j = F.mod(index, win) + origin_j  # [B, H*W]
         # Flat index into the padded map: its rows are W + 2d wide, not W.
         index = index_i * padded_w + index_j  # [B, H*W]
         index = index.astype('int32')
         index = F.add_axis(index, axis=1)  # [B, 1, H*W]
         # the same spatial index is used for every channel
         index = F.broadcast_to(index, (B, C, H * W))
         output = F.gather(ref, axis=2, index=index)  # [B, C, H*W]
         ref_list.append(output.reshape((B, C, H, W)))
     return self.conv(F.concat(ref_list, axis=1))
예제 #2
0
파일: evaluate.py 프로젝트: KingsYR123/YR
 def inference_func(images):
     """Return L2-normalized embeddings using mirror test-time augmentation.

     The module-level ``model`` is switched to eval mode. The batch and a copy
     with its last axis reversed are each embedded, the two embeddings are
     summed, and the sum is normalized along axis 1.

     Args:
         images: batch tensor; the last axis is reversed for the mirrored
             pass (presumably the width axis of NCHW input — confirm).

     Returns:
         The normalized sum of the plain and mirrored embeddings.
     """
     model.eval()
     plain = model.forward_embedding_only(images)
     # classic test-time mirror augment: also embed the flipped batch
     mirrored = model.forward_embedding_only(images[:, :, :, ::-1])
     return F.normalize(plain + mirrored, axis=1)
예제 #3
0
파일: model.py 프로젝트: KingsYR123/YR
    def forward_embedding_only(self, images):
        """Compute normalized embeddings without computing any loss.

        Meant for evaluation. When ``self.use_stn`` is set, the images are
        first passed through ``self.stn``. The result goes through
        ``self.backbone`` and ``self.head``, and the output is normalized
        along axis 1.

        Args:
            images (Tensor): preprocessed images (shape: n * 3 * 112 * 112)

        Returns:
            embedding (Tensor): embedding normalized along axis 1
        """
        inputs = self.stn(images) if self.use_stn else images
        raw_embedding = self.head(self.backbone(inputs))
        return F.normalize(raw_embedding, axis=1)
예제 #4
0
 def forward(self, embedding):
     """Return the product of ``embedding`` and the row-normalized class weights.

     ``self.weight`` is normalized along axis 1 (presumably one row per
     class — confirm), and ``embedding`` is multiplied by its transpose.
     ``embedding`` is NOT normalized here; callers are expected to pass
     already-normalized embeddings. In that case every logit is a cosine
     similarity.
     """
     unit_weight = F.normalize(self.weight, axis=1)
     return F.matmul(embedding, unit_weight.transpose(1, 0))