Ejemplo n.º 1
0
    def interpolated_sum(self, cnns, coords, grids):
        """Bilinearly interpolate features from several grids and concatenate them.

        For each level ``i``, the x/y coordinates are scaled by ``grids[i]``.
        The enclosing cell corners (``floor`` and ``floor + 1`` on each axis)
        are looked up in ``cnns[i]`` through ``utils.gather_feature`` using the
        flat index ``x + y * grids[i]``. The four gathered results are then
        blended with bilinear weights.

        Args:
            cnns: Per-level feature sources, indexed in step with ``grids``.
            coords: Tensor indexed as ``coords[:, :, k]``, where ``k = 0`` is x
                and ``k = 1`` is y. Presumably normalized to [0, 1]; TODO confirm.
            grids: Per-level grid resolutions. Lookup indices are clamped to
                ``[0, grid - 1]`` on each axis.

        Returns:
            The per-level interpolated features concatenated along dim 2.
        """
        x_unit = coords[:, :, 0]
        y_unit = coords[:, :, 1]

        def scaled_corners(values, size):
            # Scale onto the grid; return (scaled, lower corner, upper corner).
            scaled = values * size
            lower = torch.floor(scaled)
            return scaled, lower, lower + 1

        level_outputs = []
        for level, size in enumerate(grids):
            feature_map = cnns[level]

            x_scaled, x_lo, x_hi = scaled_corners(x_unit, size)
            y_scaled, y_lo, y_hi = scaled_corners(y_unit, size)

            # The weights are taken from the unclamped corners. Only the
            # lookup indices below are clamped, so any index outside
            # [0, size - 1] falls back to the border cell.
            wx_lo = x_hi - x_scaled
            wx_hi = x_scaled - x_lo
            wy_lo = y_hi - y_scaled
            wy_hi = y_scaled - y_lo

            last = size - 1
            x_lo = torch.clamp(x_lo, 0, last)
            x_hi = torch.clamp(x_hi, 0, last)
            y_lo = torch.clamp(y_lo, 0, last)
            y_hi = torch.clamp(y_hi, 0, last)

            # Flat lookup index: x + y * size.
            f_lo_lo = utils.gather_feature(x_lo + y_lo * size, feature_map)
            f_lo_hi = utils.gather_feature(x_lo + y_hi * size, feature_map)
            f_hi_lo = utils.gather_feature(x_hi + y_lo * size, feature_map)
            f_hi_hi = utils.gather_feature(x_hi + y_hi * size, feature_map)

            # Add a singleton dim 2 so each weight broadcasts over the gathered
            # features. This assumes the features carry channels on dim 2;
            # TODO confirm.
            blended = (
                (wx_lo * wy_lo).unsqueeze(2) * f_lo_lo
                + (wx_lo * wy_hi).unsqueeze(2) * f_lo_hi
                + (wx_hi * wy_lo).unsqueeze(2) * f_hi_lo
                + (wx_hi * wy_hi).unsqueeze(2) * f_hi_hi
            )
            level_outputs.append(blended)

        return torch.cat(level_outputs, dim=2)
    def sampling(self, ids, features):
        """Gather features by index from several feature sources and concatenate them.

        For each position ``i`` along dimension 1 of ``ids``, the slice
        ``ids[:, i, :]`` is passed with ``features[i]`` to
        ``utils.gather_feature``. All results are concatenated along dim 2.

        Args:
            ids: Index tensor. Dimension 1 selects which entry of ``features``
                each slice is gathered from, so ``ids.size(1)`` must not
                exceed ``len(features)``.
            features: Sequence of feature sources, one per slice of ``ids``.

        Returns:
            The gathered features concatenated along dim 2.
        """
        gathered = [
            utils.gather_feature(ids[:, level, :], features[level])
            for level in range(ids.size(1))
        ]
        return torch.cat(gathered, dim=2)