Example #1
import torch

# TensorList, TensorListList and BaseFeatureModel are utilities from the
# surrounding registration codebase and are assumed to be importable here.


class MultinomialModel(BaseFeatureModel):
    def __init__(self, parameters, K, features, repeat_list):
        self.params = parameters
        self.repeat_list = repeat_list
        self.initialize(K, features)

    def initialize(self, K, features):
        init_method = self.params.get("init_method", "dirichlet")
        if init_method == "dirichlet":
            # Sample each cluster's feature distribution from a Dirichlet prior whose
            # concentration is proportional to the relative feature mass in the data.
            init_distr_list = TensorList()
            relative_distr = features.sum(dim=1).sum_list()
            relative_distr = relative_distr / relative_distr.sum()
            for r in relative_distr:
                dirichlet = torch.distributions.dirichlet.Dirichlet(r)
                init_distr_list.append(dirichlet.sample((K, )))
        else:
            # "uniform" and any unrecognized init_method fall back to uniform
            # distributions over the feature clusters.
            init_distr_list = []
            for i in range(len(features)):
                init_distr_list.append(
                    torch.ones(K, self.params.num_feature_clusters).float() /
                    self.params.num_feature_clusters)

        self.distr = TensorListList(init_distr_list, repeat=self.repeat_list)

    def to(self, dev):
        self.distr = self.distr.to(dev)

    def posteriors(self, y):
        p = y.permute(1, 0) @ self.distr.permute(1, 0)  # marginalization
        return p

    def maximize(self, ap, ow, y, den):
        # M-step: update the multinomial parameters from the weighted responsibilities
        tmp = y / den.permute(1, 0)
        as_sum = (tmp @ (ow * ap)).sum_list()
        self.distr = self.distr * TensorListList(as_sum.permute(1, 0),
                                                 repeat=self.repeat_list)
        self.distr = self.distr / self.distr.sum(dim=1, keepdims=True)
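
For intuition, the following is a minimal, self-contained sketch of what MultinomialModel computes for a single point cloud, using plain torch tensors instead of the TensorList containers; the function name, shapes, and the random histogram features are illustrative assumptions, not part of the original code.

import torch

def multinomial_posteriors_sketch(y, K, seed=0):
    """Illustrative stand-in for MultinomialModel on one point cloud.

    y: (d, N) non-negative feature histograms, one column per point.
    Returns (N, K) unnormalized posteriors, one column per cluster.
    """
    torch.manual_seed(seed)

    # Dirichlet initialization, as in the "dirichlet" branch above:
    # concentration proportional to the relative feature mass.
    relative = y.sum(dim=1)
    relative = relative / relative.sum()
    distr = torch.distributions.dirichlet.Dirichlet(relative).sample((K, ))  # (K, d)

    # posteriors(): marginalize the feature bins against each cluster's distribution.
    return y.permute(1, 0) @ distr.permute(1, 0)  # (N, K)

# Example usage with random histogram features.
y = torch.rand(8, 100)   # 8 feature bins, 100 points
p = multinomial_posteriors_sketch(y, K=16)
print(p.shape)           # torch.Size([100, 16])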
Example #2
import numpy as np
import torch

# TensorList, TensorListList and BaseFeatureModel are utilities from the
# surrounding registration codebase and are assumed to be importable here.


class VonMisesModelList(BaseFeatureModel):
    def __init__(self, parameters, K, features, s, mu, repeat_list):
        self.params = parameters
        self.mu = self.initialize_mu(K, features)
        if len(mu) > 0:
            self.mu = TensorListList(mu, repeat=repeat_list)

        self.K = K
        self.repeat_list = repeat_list
        self.s2 = s * s
        self.local_posterior = 1

    def to(self, dev):
        self.mu = self.mu.to(dev)

    def initialize_mu(self, K, features):
        # Draw K random mean directions per point set and normalize them to unit length.
        X = TensorList()
        for TV in features:
            Xi = np.random.randn(TV[0].shape[0], K).astype(np.float32)
            Xi = torch.from_numpy(Xi).to(TV[0].device)
            Xi = Xi / torch.norm(Xi, dim=0, keepdim=True)
            X.append(Xi.permute(1, 0))

        return X

    def posteriors(self, y):
        # Unnormalized von Mises-Fisher likelihoods exp(<y, mu> / s^2) per cluster
        log_p = y.permute(1, 0) @ TensorListList(self.mu.permute(1, 0),
                                                 repeat=self.repeat_list)
        p = log_p / self.s2
        return p.exp()

    def maximize(self, a, y, den):
        # M-step: re-estimate the mean directions and renormalize them to unit length
        self.mu = ((y @ a).sum_list()).permute(1, 0)
        self.mu = self.mu / self.mu.norm(dim=-1, keepdim=True)

    def detach(self):
        self.mu = self.mu.detach()
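
In the same spirit, here is a minimal sketch of the von Mises-Fisher computations of VonMisesModelList for a single point cloud in plain torch, without the TensorList machinery; the function name, shapes, and the simple row-wise normalization of the responsibilities are illustrative assumptions.

import torch

def von_mises_step_sketch(y, K=16, s=0.3, seed=0):
    """Illustrative stand-in for VonMisesModelList on one point cloud.

    y: (d, N) unit-norm feature vectors, one column per point.
    Returns the (N, K) posteriors and the updated (K, d) mean directions.
    """
    torch.manual_seed(seed)
    d = y.shape[0]
    s2 = s * s

    # initialize_mu(): K random unit-length mean directions.
    mu = torch.randn(d, K)
    mu = (mu / torch.norm(mu, dim=0, keepdim=True)).permute(1, 0)  # (K, d)

    # posteriors(): unnormalized likelihoods exp(<y, mu> / s^2).
    p = (y.permute(1, 0) @ mu.permute(1, 0) / s2).exp()            # (N, K)

    # maximize(): responsibility-weighted mean directions, renormalized.
    a = p / p.sum(dim=-1, keepdim=True)   # simplified responsibilities for the sketch
    mu_new = (y @ a).permute(1, 0)                                 # (K, d)
    mu_new = mu_new / mu_new.norm(dim=-1, keepdim=True)
    return p, mu_new

# Example usage with random unit-norm features.
y = torch.randn(32, 200)
y = y / y.norm(dim=0, keepdim=True)
p, mu = von_mises_step_sketch(y)
print(p.shape, mu.shape)   # torch.Size([200, 16]) torch.Size([16, 32])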
Example #3
    # Method of the registration class; the helper functions and TensorList
    # containers used below come from the surrounding codebase.
    def register_point_sets(self, x):
        Vs = x["coordinates_ds"]
        features = x["features"]
        features_w = x["att"]
        repeat_list = [len(V) for V in Vs]
        init_R, init_t = get_init_transformation_list(
            Vs, self.params.get("mean_init", True))
        TVs = init_R @ Vs + init_t

        X = TensorList()
        Q = TensorList()
        mu = TensorList()
        for TV, Fs in zip(TVs, features):
            if self.params.cluster_init == "box":
                Xi = get_randn_box_cluster_means_list(TV, self.params.K)
            else:
                Xi = get_randn_sphere_cluster_means_list(
                    TV, self.params.K,
                    self.params.get("cluster_mean_scale", 1.0))
            Q.append(
                get_scaled_cluster_precisions_list(
                    TV, Xi, self.params.cluster_precision_scale))
            X.append(Xi.T)

        feature_distr = feature_models.MultinomialModel(
            self.params.feature_distr_parameters,
            self.params.K,
            features,
            repeat_list=repeat_list)

        feature_distr.to(self.params.device)

        X = TensorListList(X, repeat=repeat_list)
        self.betas = get_default_beta(Q, self.params.gamma)


        # Compute the observation weights and clamp outliers to a multiple of the mean
        if self.params.use_dare_weighting:
            observation_weights = empirical_estimate(Vs, self.params.ow_args)
            ow_reg_factor = 8.0
            ow_mean = observation_weights.mean(dim=0, keepdim=True)
            for idx in range(len(observation_weights)):
                for idxx in range(len(observation_weights[idx])):
                    cap = ow_reg_factor * ow_mean[idx][idxx]
                    ow = observation_weights[idx][idxx]
                    ow[ow > cap] = cap

        else:
            observation_weights = 1.0

        # Squared distances between the transformed points and the cluster centers
        ds = TVs.permute(1, 0).sqe(X).permute(1, 0)

        if self.params.debug:
            self.visdom.register(
                dict(pcds=Vs[0].cpu(), X=X[0][0].cpu(), c=None),
                'point_clouds', 2, 'init')
            time.sleep(1)

        Rs = init_R.to(self.params.device)
        ts = init_t.to(self.params.device)

        self.betas = TensorListList(self.betas, repeat=repeat_list)
        QL = TensorListList(Q, repeat=repeat_list)
        Riter = TensorListList()
        titer = TensorListList()
        TVs_iter = TensorListList()
        for i in range(self.params.num_iters):
            Qt = QL.permute(1, 0)

            # E-step: spatial likelihoods of the Gaussian mixture components
            ap = (-0.5 * ds * QL).exp() * QL.pow(1.5)

            if i < 1000:
                pyz_feature = feature_distr.posteriors(features)
            else:
                pyz_feature = 1.0

            a = ap * pyz_feature

            ac_den = a.sum(dim=-1, keepdim=True) + self.betas
            a = a / ac_den  # normalize row-wise
            a = a * observation_weights

            # Sufficient statistics for the weighted rigid alignment
            L = a.sum(dim=-2, keepdim=True).permute(1, 0)
            W = (Vs @ a) * QL

            b = L * Qt  # weights, b
            mW = W.sum(dim=-1, keepdim=True)
            mX = (b.permute(1, 0) @ X).permute(1, 0)
            z = L.permute(1, 0) @ Qt
            P = (W @ X).permute(1, 0) - mX @ mW.permute(1, 0) / z

            # Compute R and t in closed form (SVD-based Procrustes; the sign of
            # det(U V^T) guards against reflections)
            svd_list_list = P.cpu().svd()
            Rs = TensorListList()
            for svd_list in svd_list_list:
                Rs_list = TensorList()
                for svd in svd_list:
                    uu, vv = svd.U, svd.V
                    vvt = vv.permute(1, 0)
                    detuvt = uu @ vvt
                    detuvt = detuvt.det()
                    S = torch.ones(1, 3)
                    S[:, -1] = detuvt
                    Rs_list.append((uu * S) @ vvt)

                Rs.append(Rs_list)

            Rs = Rs.to(self.params.device)
            Riter.append(Rs)
            ts = (mX - Rs @ mW) / z
            titer.append(ts)
            TVs = Rs @ Vs + ts

            TVs_iter.append(TVs.clone())
            if self.params.debug:
                self.visdom.register(
                    dict(pcds=TVs[0].cpu(), X=X[0][0].cpu(), c=None),
                    'point_clouds', 2, 'registration-iter')
                time.sleep(0.2)
            # Update X
            den = L.sum_list()

            if self.params.fix_cluster_pos_iter < i:
                X = (TVs @ a).permute(1, 0)
                X = TensorListList(X.sum_list() / den, repeat_list)

            # Update Q
            ds = TVs.permute(1, 0).sqe(X).permute(1, 0)

            wn = (a * ds).sum(dim=-2, keepdim=True).sum_list()
            Q = (3 * den /
                 (wn.permute(1, 0) + 3 * den * self.params.epsilon)).permute(
                     1, 0)
            QL = TensorListList(Q, repeat=repeat_list)

            feature_distr.maximize(ap=ap,
                                   ow=observation_weights,
                                   y=features,
                                   den=ac_den)

        out = dict(Rs=Rs,
                   ts=ts,
                   X=X,
                   Riter=Riter[:-1],
                   titer=titer[:-1],
                   Vs=TVs,
                   Vs_iter=TVs_iter[:-1],
                   ow=observation_weights)
        return out
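
The SVD block inside the loop above is the standard weighted Procrustes (Kabsch) solution for the rigid transform. The sketch below reproduces that step in isolation with plain torch, for a single 3x3 cross-covariance matrix such as the P built from W, X, mX, mW and z in the loop; the function name and the noiseless test data are illustrative assumptions.

import math
import torch

def rotation_from_cross_covariance(P):
    """Closed-form rotation from a 3x3 cross-covariance matrix, mirroring the
    SVD block in register_point_sets (det(U V^T) guards against reflections)."""
    uu, _, vvt = torch.linalg.svd(P)   # P = U diag(S) V^T; vvt is already V^T
    det_uvt = (uu @ vvt).det()
    S = torch.ones(1, 3)
    S[:, -1] = det_uvt                 # flip the last column of U if det is -1
    return (uu * S) @ vvt

# Example usage: recover a known rotation from noiseless correspondences.
c, s = math.cos(0.5), math.sin(0.5)
R_true = torch.tensor([[c, -s, 0.],
                       [s, c, 0.],
                       [0., 0., 1.]])
V = torch.randn(3, 50)                  # source points
X = R_true @ V                          # target points
P = X @ V.permute(1, 0)                 # cross-covariance with unit weights
R_est = rotation_from_cross_covariance(P)
print(torch.allclose(R_est, R_true, atol=1e-5))   # True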