Esempio n. 1
0
    def forward(self, x, reverse, logdet):
        """Affine coupling step split along the channel axis.

        The first ``C // 2`` channels (``x1``) pass through unchanged and are
        fed to ``self.m_function``. Its output is split into a log-scale
        ``logs`` and a shift ``t`` (both squashed with ``tanh``), which are
        applied to the remaining channels ``x2``:

        * forward:  ``x2 <- (x2 + t) * exp(logs)``
        * reverse:  ``x2 <- x2 / exp(logs) - t``

        :param x: 5-D tensor [B, C, m, m, d] with ``d == param['eig']``
            (trailing size validated by ``layers.assert_real``).
        :param reverse: if truthy, apply the inverse transform.
        :param logdet: running log-determinant; only used in the forward
            direction.
        :return: ``(x_out, logdet + param['reg'] * dlogdet)`` in the forward
            direction, where ``dlogdet`` is ``logs`` summed over every
            non-batch axis. In the reverse direction only ``x_out`` is
            returned (the log-determinant is not updated).
        """
        batch, channels, m = layers.assert_real(x, param['eig'])
        depth = param['eig']
        half = channels // 2
        x1 = x[:, :half, ...]  # conditioning half, passed through unchanged
        x2 = x[:, half:, ...]  # transformed half

        # Fold the trailing axis into channels:
        # [B, C//2, m, m, d] -> [B, C//2, d, m, m] -> [B, C//2 * d, m, m].
        flat_channels = int(channels * depth / 2)
        conditioner_in = x1.permute(0, 1, 4, 2, 3).reshape(batch, flat_channels, m, m)
        assert flat_channels == self.inchannel
        # m_function must return 2 * flat_channels channels: the first block
        # becomes logs, the second becomes t (see the reshapes in unfold).
        stats = self.m_function(conditioner_in)

        def unfold(chunk):
            # [B, C//2 * d, m, m] -> [B, C//2, d, m, m] -> [B, C//2, m, m, d],
            # then bounded to (-1, 1) with tanh.
            chunk = chunk.reshape(batch, half, depth, m, m).permute(0, 1, 3, 4, 2)
            return torch.tanh(chunk)

        logs = unfold(stats[:, :flat_channels, ...])
        t = unfold(stats[:, flat_channels:, ...])
        scale = torch.exp(logs)

        if reverse:
            # Undo the forward map in the opposite order: divide, then unshift.
            x2 = x2 / scale - t
            return torch.cat([x1, x2], dim=1)  # [B, C, m, m, d]

        x2 = (x2 + t) * scale
        dlogdet = logs.sum(dim=(1, 2, 3, 4))  # one value per batch element
        return torch.cat([x1, x2], dim=1), logdet + param['reg'] * dlogdet
Esempio n. 2
0
    def prior(self, dti, eig, odf=None, sample=False):
        """Log-likelihood of (dti, eig, odf) under the prior, or an ODF sample.

        ``dti`` and ``eig`` each get a fixed Gaussian with zero mean and zero
        log-std. The concatenation ``[dti, eig]`` along the last axis is passed
        through ``self.dti2odf`` as a [B * C, dti + eig, m, m] input. The result
        gives the mean and log-std of a Gaussian over ``odf``, so the ODF
        distribution is conditioned on the DTI/eigen inputs.

        :param dti: expected [B, C, m, m, param['dti']] (the reshape below
            relies on this; trailing size validated by ``layers.assert_real``).
        :param eig: expected [B, C, m, m, param['eig']].
        :param odf: expected [B, C, m, m, param['odf']]; required unless
            ``sample`` is truthy.
        :param sample: if truthy, return ``godf.sample`` from the conditional
            ODF Gaussian instead of a log-likelihood.
        :return: ``param['reg'] * logp(dti) + param['reg'] * logp(eig)
            + logp(odf)`` when not sampling, otherwise ``godf.sample``.
        :raises ValueError: if ``sample`` is falsy and ``odf`` is None.
        """
        if not sample and odf is None:
            # Fail early with a clear message instead of inside godf.logp(None).
            raise ValueError("prior() requires `odf` unless sample=True")

        B, C, m = layers.assert_real(dti, param['dti'])
        # Fixed priors on the inputs: mean 0, log-std 0.
        # torch.zeros_like already matches the input's dtype and device.
        gdti = layers.gaussian_real(mean=torch.zeros_like(dti),
                                    logsd=torch.zeros_like(dti),
                                    dim=param['dti'])

        geig = layers.gaussian_real(mean=torch.zeros_like(eig),
                                    logsd=torch.zeros_like(eig),
                                    dim=param['eig'])

        # Predict the parameters of the conditional ODF Gaussian from [dti, eig].
        y = torch.cat([dti, eig], dim=-1)  # [B, C, m, m, dti+eig]
        y = y.reshape(B * C, m, m, param['dti'] + param['eig'])
        y = y.permute(0, 3, 1, 2)  # [B * C, dti+eig, m, m]
        y = self.dti2odf(y)  # [B * C, 2*odf, m, m]
        y = y.reshape(B, C, 2 * param['odf'], m, m)
        y = y.permute(0, 1, 3, 4, 2)  # [B, C, m, m, 2*odf]
        xmean = y[..., :param['odf']]  # [B, C, m, m, odf]
        # The predicted log-std is used as-is (not bounded with tanh).
        xlogstd = y[..., param['odf']:]  # [B, C, m, m, odf]
        godf = layers.gaussian_real(mean=xmean,
                                    logsd=xlogstd,
                                    dim=param['odf'])

        if sample:
            # NOTE(review): `sample` is read as an attribute, not called;
            # presumably gaussian_real precomputes it — verify in `layers`.
            return godf.sample

        logpdti = gdti.logp(dti)
        logpeig = geig.logp(eig)
        logpodf = godf.logp(odf)
        return param['reg'] * logpdti + param['reg'] * logpeig + logpodf
Esempio n. 3
0
 def expmap(self, v):
     """Exponential map from the tangent space at the north pole onto the sphere.

     The north pole is ``e0 = (1, 0, ..., 0)`` in ``param['odf'] + 1``
     dimensions and ``v`` is the tangent vector without its (zero) first
     coordinate. The result is ``cos(|v|) * e0 + sin(|v|) * v / (|v| + epsilon)``,
     with the first coordinate replaced by its absolute value.

     Built out-of-place with ``torch.cat`` for three reasons:

     * No in-place write overwrites a tensor that autograd saved for backward.
     * The output keeps ``v``'s dtype and device (no float32 downcast, no CPU
       round-trip).
     * The expression is unchanged from the original formula.

     :param v: tensor whose trailing axis has size ``param['odf']``
         (validated by ``layers.assert_real``).
     :return: tensor of shape ``v.shape[:-1] + (param['odf'] + 1,)``.
     """
     layers.assert_real(v, param['odf'])
     v_norm = torch.norm(v, dim=-1, keepdim=True)  # [..., 1]
     # Coordinate 0: cos(|v|) (the tangent vector has no e0 component),
     # folded onto the non-negative half-space with abs.
     first = torch.abs(torch.cos(v_norm))
     # Remaining coordinates: sin(|v|) * v / |v|, with epsilon guarding |v| = 0.
     rest = torch.sin(v_norm) * v / (v_norm + epsilon)
     return torch.cat([first, rest], dim=-1)
Esempio n. 4
0
 def logmap(self, odf):  # NOTE
     """Logarithmic map from the sphere to the tangent space at the north pole.

     For a point ``p`` with ``theta = arccos(p[0])`` this returns
     ``p[1:] * theta / (sin(theta) + epsilon)``, i.e. the tangent vector at
     ``e0 = (1, 0, ..., 0)`` with its zero first coordinate dropped.

     :param odf: tensor whose trailing axis has size ``param['odf'] + 1``
         (validated by ``layers.assert_real``); presumably unit-norm points
         on the sphere — TODO confirm with callers.
     :return: tensor of shape ``odf.shape[:-1] + (param['odf'],)``.
     """
     d = param['odf'] + 1
     layers.assert_real(odf, d)
     # Clamp so round-off (|p[0]| slightly above 1) gives theta = 0 or pi
     # instead of NaN from acos.
     cos_theta = torch.clamp(odf[..., 0], -1.0, 1.0)
     theta = torch.acos(cos_theta).unsqueeze(-1)  # [..., 1]
     sin_theta = torch.sin(theta)  # [..., 1]
     # Subtracting cos(theta) * e0 would only change coordinate 0, which is
     # dropped, so rescale the remaining coordinates directly. epsilon guards
     # sin(theta) = 0.
     return odf[..., 1:] * theta / (sin_theta + epsilon)