コード例 #1
0
    def forward(
        self,
        lhs_pos: FloatTensorType,
        rhs_pos: FloatTensorType,
        lhs_neg: FloatTensorType,
        rhs_neg: FloatTensorType,
    ) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
        """Score embedding pairs by their negated Euclidean (L2) distance.

        Closer pairs receive higher (less negative) scores.

        Args:
            lhs_pos: Left-hand positive embeddings. Presumably shaped
                (num_chunks, num_pos_per_chunk, dim), going by the
                ``match_shape`` call below.
            rhs_pos: Right-hand positive embeddings, same shape as ``lhs_pos``.
            lhs_neg: Left-hand negatives. They share ``num_chunks`` and ``dim``
                with the positives; their middle dimension may differ.
            rhs_neg: Right-hand negatives, with the same constraints as
                ``lhs_neg``.

        Returns:
            A tuple ``(pos_scores, lhs_neg_scores, rhs_neg_scores)``:

            - ``pos_scores`` is the negated distance of each aligned positive
              pair. It is reduced over the last axis and computed in float32.
            - ``lhs_neg_scores`` is the negated output of
              ``batched_all_pairs_l2_dist(rhs_pos, lhs_neg)``.
            - ``rhs_neg_scores`` is the negated output of
              ``batched_all_pairs_l2_dist(lhs_pos, rhs_neg)``. Judging by the
              helper's name, both are presumably all-pairs distances within
              each chunk (TODO confirm).

            The negative scores are not cast to float32; they keep whatever
            dtype the helper returns.
        """
        # match_shape is defined elsewhere. It appears to check each tensor's
        # dims, treating -1 as "any size", and to return the resolved sizes.
        n_chunks, n_pos, emb_dim = match_shape(lhs_pos, -1, -1, -1)
        match_shape(rhs_pos, n_chunks, n_pos, emb_dim)
        for neg in (lhs_neg, rhs_neg):
            match_shape(neg, n_chunks, -1, emb_dim)

        # Squared distance of aligned positive pairs, in float32. The
        # subtraction yields a fresh tensor, so the in-place ops never touch
        # the caller's inputs.
        sq_dist = (lhs_pos.float() - rhs_pos.float()).pow_(2).sum(dim=-1)
        # Floor at a tiny epsilon before the sqrt, presumably because sqrt's
        # derivative is infinite at 0. Without the floor, identical pairs
        # would produce NaN gradients.
        dist = sq_dist.clamp_min_(1e-30).sqrt_()

        # Smaller distances are higher scores, so take their negatives.
        pos_scores = dist.neg()
        lhs_neg_scores = batched_all_pairs_l2_dist(rhs_pos, lhs_neg).neg()
        rhs_neg_scores = batched_all_pairs_l2_dist(lhs_pos, rhs_neg).neg()

        return pos_scores, lhs_neg_scores, rhs_neg_scores
コード例 #2
0
ファイル: model.py プロジェクト: pinellolab/simba_pbg
    def forward(
        self,
        lhs_pos: FloatTensorType,
        rhs_pos: FloatTensorType,
        lhs_neg: FloatTensorType,
        rhs_neg: FloatTensorType,
    ) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
        """Score embedding pairs by their dot product.

        Args:
            lhs_pos: Left-hand positive embeddings, laid out as
                (chunks, positives, dim), i.e. the ``cid`` layout of the
                einsum equivalents noted below.
            rhs_pos: Right-hand positive embeddings, same shape as ``lhs_pos``.
            lhs_neg: Left-hand negatives, laid out as (chunks, negatives, dim).
                They share chunks and dim with the positives.
            rhs_neg: Right-hand negatives, with the same layout as ``lhs_neg``.

        Returns:
            A tuple ``(pos_scores, lhs_neg_scores, rhs_neg_scores)``:

            - ``pos_scores`` has shape (chunks, positives) and holds the dot
              product of each aligned positive pair, computed in float32.
            - ``lhs_neg_scores`` has shape (chunks, positives, negatives) and
              scores every ``rhs_pos`` row against every ``lhs_neg`` row.
            - ``rhs_neg_scores`` has the same shape and scores every
              ``lhs_pos`` row against every ``rhs_neg`` row.

            The negative scores are not cast to float32.
        """
        # match_shape is defined elsewhere. It appears to check each tensor's
        # dims, treating -1 as "any size", and to return the resolved sizes.
        n_chunks, n_pos, emb_dim = match_shape(lhs_pos, -1, -1, -1)
        match_shape(rhs_pos, n_chunks, n_pos, emb_dim)
        for neg in (lhs_neg, rhs_neg):
            match_shape(neg, n_chunks, -1, emb_dim)

        # Row-wise dot product of aligned pairs. This equals (but is faster
        # than) torch.einsum('cid,cid->ci', ...).
        products = lhs_pos.float() * rhs_pos.float()
        pos_scores = products.sum(-1)

        # Batched matmul against the transposed negatives. This equals (but is
        # faster than) torch.einsum('cid,cjd->cij', ...).
        lhs_neg_t = lhs_neg.transpose(-1, -2)
        rhs_neg_t = rhs_neg.transpose(-1, -2)
        lhs_neg_scores = torch.bmm(rhs_pos, lhs_neg_t)
        rhs_neg_scores = torch.bmm(lhs_pos, rhs_neg_t)

        return pos_scores, lhs_neg_scores, rhs_neg_scores