def interpolated_sum(self, cnns, coords, grids):
    """Bilinearly sample every flattened feature grid at ``coords``.

    For each resolution ``grids[k]`` the x/y coordinates are scaled by that
    resolution. The four surrounding integer lattice corners are located, and
    their features are fetched from ``cnns[k]`` with ``utils.gather_feature``
    using the flat index ``x + y * resolution``. The four features are blended
    with bilinear weights. The per-grid results are concatenated along dim 2.

    Args:
        cnns: sequence of feature tensors, indexed in lock-step with ``grids``.
        coords: tensor read as ``coords[:, :, 0]`` (x) and ``coords[:, :, 1]``
            (y); presumably (batch, points, 2) with values in [0, 1] -- TODO
            confirm against callers.
        grids: sequence of grid side lengths, one per entry of ``cnns``.

    Returns:
        ``torch.cat`` of the per-grid interpolated features along dim 2.
    """
    x_coord = coords[:, :, 0]
    y_coord = coords[:, :, 1]
    per_grid = []
    for level, res in enumerate(grids):
        feats = cnns[level]

        x_scaled = x_coord * res
        y_scaled = y_coord * res
        x_lo = torch.floor(x_scaled)
        y_lo = torch.floor(y_scaled)
        x_hi = x_lo + 1
        y_hi = y_lo + 1

        # 1-D bilinear factors, computed from the *unclamped* corners so the
        # four products below always sum to 1.
        wx_lo = x_hi - x_scaled
        wx_hi = x_scaled - x_lo
        wy_lo = y_hi - y_scaled
        wy_hi = y_scaled - y_lo

        # Clamp only the lookup positions. Out-of-range corners therefore
        # collapse onto the border cell and keep their weight.
        last = res - 1
        cx_lo = torch.clamp(x_lo, 0, last)
        cx_hi = torch.clamp(x_hi, 0, last)
        cy_lo = torch.clamp(y_lo, 0, last)
        cy_hi = torch.clamp(y_hi, 0, last)

        # (x corner, y corner, weight) in the order (00, 01, 10, 11).
        corners = (
            (cx_lo, cy_lo, wx_lo * wy_lo),
            (cx_lo, cy_hi, wx_lo * wy_hi),
            (cx_hi, cy_lo, wx_hi * wy_lo),
            (cx_hi, cy_hi, wx_hi * wy_hi),
        )

        # NOTE(review): the flat indices are floating-point (derived from
        # torch.floor); presumably utils.gather_feature casts them to an
        # integer dtype -- confirm.
        blended = None
        for cx, cy, weight in corners:
            term = weight.unsqueeze(2) * utils.gather_feature(cx + cy * res, feats)
            blended = term if blended is None else blended + term
        per_grid.append(blended)

    return torch.cat(per_grid, dim=2)
def sampling(self, ids, features):
    """Gather features at precomputed flat indices and concatenate them.

    For each position ``i`` along dimension 1 of ``ids``, the index slice
    ``ids[:, i, :]`` is passed to ``utils.gather_feature`` together with
    ``features[i]``. The gathered tensors are concatenated along dim 2.

    Args:
        ids: index tensor whose dim 1 enumerates the feature sources;
            presumably (batch, num_sources, points) -- TODO confirm.
        features: sequence indexed in lock-step with ``ids.size(1)``.

    Returns:
        ``torch.cat`` of the gathered features along dim 2.
    """
    # No loop variable named ``id``: the original shadowed the builtin.
    gathered = [
        utils.gather_feature(ids[:, src, :], features[src])
        for src in range(ids.size(1))
    ]
    return torch.cat(gathered, dim=2)