예제 #1
0
 def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
     """Return a batch of identity matrices.

     ``t``, ``X`` and ``Y`` are accepted but never read; the result
     depends only on ``self.M``, ``self.D`` and ``self.device``.

     Returns:
         Tensor of shape ``M x D x D`` located on ``self.device`` in which
         every ``D x D`` slice is the identity matrix.
     """
     # Allocate on the target device right away instead of building the
     # tensor on the CPU and copying it across afterwards.
     ones = torch.ones([self.M, self.D], device=self.device)
     return torch.diag_embed(ones)  # M x D x D
 def covariance_matrix(self):
     """Build the dense covariance ``W @ W^T + diag(d)``.

     ``W`` is ``self._unbroadcasted_cov_factor`` and ``d`` is
     ``self._unbroadcasted_cov_diag``. The sum is expanded to
     ``batch_shape + event_shape + event_shape`` before it is returned.
     """
     factor = self._unbroadcasted_cov_factor
     low_rank_part = torch.matmul(factor, factor.transpose(-1, -2))
     diagonal_part = torch.diag_embed(self._unbroadcasted_cov_diag)
     target_shape = self._batch_shape + self._event_shape + self._event_shape
     return (low_rank_part + diagonal_part).expand(target_shape)
예제 #3
0
def test_diagflat(dtype, k):
    """Check ``PyTorchBackend.diagflat`` against ``torch.diag_embed``.

    A seeded random vector of length 16 is placed on the ``k``-th diagonal
    by the backend, and the result is compared with what
    ``torch.diag_embed(..., offset=k)`` produces for the same vector.
    """
    torch_backend = pytorch_backend.PyTorchBackend()
    vector = torch_backend.randn((16, ), dtype=dtype, seed=10)
    actual = torch_backend.diagflat(vector, k=k)
    expected = torch.diag_embed(vector, offset=k)
    np.testing.assert_allclose(expected, actual)
예제 #4
0
 def _eval_covar_matrix(self):
     cf = self.covar_factor
     return cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
예제 #5
0
    def forward(self, input, derivative=0):
        """Compute the pairwise distance between the electrons
        or its derivative. \n

        When required, the derivative is computed with respect to the first
        electron, i.e.

        .. math::
            \\frac{dr_{ij}}{dx_i}

        which is different from:

        .. math::
            \\frac{d r_{ij}}{dx_j}

        Args:
            input (torch.tensor): position of the electrons \n
                                  size : Nbatch x [Nelec x Ndim]
            derivative (int, optional): degree of the derivative. \n
                                        Defaults to 0.

        Returns:
            torch.tensor: distance (or derivative) matrix \n
                          Nbatch x Nelec x Nelec if derivative = 0 \n
                          Nbatch x Ndim x  Nelec x Nelec if derivative = 1,2

        Raises:
            ValueError: if ``derivative`` is not 0, 1 or 2.

        """

        # get the distance matrices (their square root is taken below)
        input_ = input.view(-1, self.nelec, self.ndim)
        dist = self._get_distance_quadratic(input_)

        # epsilon on the diagonal, needed for back propagation. Its shape and
        # dtype follow dist, so the same tensor serves every branch below.
        eps_ = self.eps * \
            torch.diag(dist.new_ones(dist.shape[-1])).expand_as(dist)

        # extract the diagonal, as it can sometimes be negative
        # due to numerical noise
        diag = torch.diag_embed(torch.diagonal(dist, dim1=-1, dim2=-2))

        # remove the diagonal and add eps for back propagation
        dist = torch.sqrt(dist - diag + eps_)

        if derivative == 0:
            return dist

        # coordinate differences x_i - x_j, shared by both derivatives
        diff_axis = input_.transpose(1, 2).unsqueeze(3)
        diff_axis = diff_axis - diff_axis.transpose(2, 3)

        if derivative == 1:
            invr = (1. / (dist + eps_)).unsqueeze(1)
            return diff_axis * invr

        if derivative == 2:
            invr3 = (1. / (dist**3 + eps_)).unsqueeze(1)
            diff_axis = diff_axis**2

            # for each axis, sum the squared differences along the two other
            # axes; the index table hard-codes three spatial dimensions
            diff_axis = diff_axis[:, [[1, 2], [2, 0], [0, 1]], ...].sum(2)
            return (diff_axis * invr3)

        # previously an unsupported order fell through and returned None
        raise ValueError(
            'derivative must be 0, 1 or 2, got {}'.format(derivative))
예제 #6
0
 def backward(ctx, grads):
     """Keep only the gradient entries on the diagonal of dims 0 and 1.

     The diagonal taken over the first two dimensions is re-embedded as a
     diagonal matrix and permuted back to the incoming layout, so every
     entry whose first two indices differ becomes zero. ``ctx`` is unused.
     The shape comments follow the original annotations.
     """
     # (T, T, N, H) -> (N, H, T)
     diagonal_grads = torch.diagonal(grads, dim1=0, dim2=1)
     # (N, H, T) -> (N, H, T, T)
     embedded = torch.diag_embed(diagonal_grads)
     # (N, H, T, T) -> (T, T, N, H)
     return embedded.permute(2, 3, 0, 1)
예제 #7
0
def glasso_predict(data, COLLECT=True):
    """Run ``args.L`` ADMM iterations for graphical lasso and score the result.

    The function reads the module-level ``args`` (``lambda_init``, ``rho``,
    ``INIT_DIAG``, ``theta_init_offset``, ``N``, ``L``) and ``USE_CUDA``, and
    calls helpers defined elsewhere (``convert_to_torch``, ``get_obj_val``,
    ``get_convergence_loss``, ``torch_sqrtm``, ``metrics.report_metrics``).

    Args:
        data: pair ``(theta, S)``. Per the original notes, ``theta`` holds the
            true matrices and ``S`` the observed ones, both K x N x N.
        COLLECT: when true, convergence statistics are appended to the second
            return value before every iteration.

    Returns:
        tuple: ``([fdr, tpr, fpr, shd, nnz, nnz_true, ps, cond_theta_pred,
        cond_theta_true], res_conv)`` where every ``res_conv`` row is
        ``[cv_loss, obj_pred, obj_true, cv_loss_off_diag]``.
    """
    print('Running ADMM: augmented lagrangian')
    theta, S = data
    # theta -> K_train x N x N (Matrix)
    # S -> K_train x N x N (observed vector)
    lambda_f = args.lambda_init
    rho_l1 = args.rho

    theta_true = convert_to_torch(theta, TESTING_FLAG=True)
    S = convert_to_torch(S, TESTING_FLAG=True)

    if args.INIT_DIAG == 1:
        print(' extract batchwise diagonals, add offset and take inverse')
        batch_diags = 1 / (torch.diagonal(S, offset=0, dim1=-2, dim2=-1) +
                           args.theta_init_offset)
        theta_init = torch.diag_embed(batch_diags)
    else:
        print('***************** (S+theta_offset*I)^-1 is used')
        theta_init = torch.inverse(S + args.theta_init_offset *
                                   torch.eye(args.N).expand_as(S).type_as(S))

    # lower bound used by the soft-thresholding step
    zero = torch.Tensor([0])
    if USE_CUDA:
        zero = zero.cuda()

    # the whole input is processed as a single batch
    theta_pred = theta_init
    Sb = S
    U = torch.zeros(Sb.shape).type(Sb.type())

    identity_mat = torch.eye(Sb.shape[-1]).expand_as(Sb)
    if USE_CUDA:
        identity_mat = identity_mat.cuda()

    res_conv = []
    obj_true = get_obj_val(theta_true, S, rho_l1)
    for _ in range(args.L):
        if COLLECT:
            theta_pred_diag = torch.diag_embed(
                torch.diagonal(theta_pred, offset=0, dim1=-2, dim2=-1))
            theta_true_diag = torch.diag_embed(
                torch.diagonal(theta_true, offset=0, dim1=-2, dim2=-1))
            cv_loss = get_convergence_loss(theta_pred, theta_true)
            # same loss with the diagonals removed from both matrices
            cv_loss_off_diag = get_convergence_loss(
                theta_pred - theta_pred_diag, theta_true - theta_true_diag)
            obj_pred = get_obj_val(theta_pred, S, rho_l1)
            res_conv.append([cv_loss, obj_pred, obj_true, cv_loss_off_diag])

        # step 1 : ADMM
        b = 1.0 / lambda_f * Sb - theta_pred + U
        b2_4ac = torch.matmul(b.transpose(-1, -2),
                              b) + 4.0 / lambda_f * identity_mat
        sqrt_term = torch_sqrtm(b2_4ac)
        theta_k1 = 1.0 / 2 * (-1 * b + sqrt_term)
        # step 2 : ADMM, element-wise soft threshold of theta_k1 + U
        theta_pred = torch.sign(theta_k1 + U) * torch.max(
            zero,
            torch.abs(theta_k1 + U) - rho_l1 / lambda_f)  # Z_k1
        # step 3: ADMM
        U = U + theta_k1 - theta_pred

    print(
        'fdr, tpr, fpr, shd, nnz, nnz/theta_pred.size, nnz_true,nnz_true/theta_true.size, ps, np.linalg.cond(theta_pred), np.linalg.cond(theta_true)'
    )
    theta_pred = theta_pred.data.cpu().numpy()
    theta_true = theta_true.data.cpu().numpy()
    fdr, tpr, fpr, shd, nnz, nnz_true, ps = metrics.report_metrics(
        theta_true, theta_pred)
    cond_theta_pred = np.linalg.cond(theta_pred)
    cond_theta_true = np.linalg.cond(theta_true)
    print(fdr, tpr, fpr, shd, nnz, nnz / theta_pred.size, nnz_true,
          nnz_true / theta_true.size, ps, cond_theta_pred, cond_theta_true)
    return [
        fdr, tpr, fpr, shd, nnz, nnz_true, ps, cond_theta_pred, cond_theta_true
    ], res_conv
예제 #8
0
def gista_glasso(data, c=args.c, eps=1e-12, COLLECT=True):
    """Run G-ISTA for graphical lasso and score the resulting estimate.

    The function reads the module-level ``args`` (``INIT_DIAG``,
    ``theta_init_offset``, ``N``, ``MAX_EPOCH``) and ``USE_CUDA``, and calls
    helpers defined elsewhere (``convert_to_torch``, ``get_duality_gap``,
    ``get_obj_val``, ``get_convergence_loss``, ``eta``, ``check_conditions``,
    ``get_step_size``, ``metrics.report_metrics``).

    Args:
        data: pair ``(theta, S)``. Per the original notes, ``theta`` holds the
            true matrix and ``S`` the observed one.
        c: base of the backtracking factor; trial ``j`` scales the step by
            ``c**j``.
        eps: iteration stops once the duality gap is no larger than this.
        COLLECT: when true, convergence statistics are appended to the second
            return value before every iteration.

    Returns:
        tuple: ``([fdr, tpr, fpr, shd, nnz, nnz_true, ps, cond_theta_pred,
        cond_theta_true], res_conv)`` where every ``res_conv`` row is
        ``[cv_loss, obj_pred, obj_true, -1]``.
    """
    c = torch.Tensor([c])
    theta, S = data
    # theta -> K_train x N x N (Matrix)
    # S -> K_train x N x N (observed vector)
    theta_true = convert_to_torch(theta, TESTING_FLAG=True)
    S = convert_to_torch(S, TESTING_FLAG=True)
    if args.INIT_DIAG == 1:
        print(' extract batchwise diagonals, add offset and take inverse')
        batch_diags = 1 / (torch.diagonal(S, offset=0, dim1=-2, dim2=-1) +
                           args.theta_init_offset)
        theta_init = torch.diag_embed(batch_diags)
    else:
        print('***************** (S+theta_offset*I)^-1 is used')
        theta_init = torch.inverse(S + args.theta_init_offset *
                                   torch.eye(args.N).expand_as(S).type_as(S))
    if USE_CUDA:
        c = c.cuda()

    # torch.cholesky is deprecated and torch.eig has been removed from
    # PyTorch; the torch.linalg functions are their replacements.
    ll = torch.linalg.cholesky(theta_init)  # lower triangular
    theta_pred = torch.matmul(ll, ll.transpose(-1, -2))
    min_eig = torch.min(torch.linalg.eigvals(theta_pred).real)
    # All the inputs are ready for the G-ISTA
    delta = get_duality_gap(theta_pred, S)  # initial duality gap
    step_size = min_eig**2
    epoch = 0
    res_conv = []
    obj_true = get_obj_val(theta_true, S)
    while delta > eps and epoch < args.MAX_EPOCH:
        if COLLECT:
            cv_loss = get_convergence_loss(theta_pred, theta_true)
            obj_pred = get_obj_val(theta_pred, S)
            # the off-diagonal loss is not computed here; -1 is a placeholder
            res_conv.append([cv_loss, obj_pred, obj_true, -1])

        theta_pred = torch.matmul(ll, ll.transpose(-1, -2))
        theta_prev = theta_pred.clone()
        # Step 1 & 2: line search and update theta
        diff_term = S - torch.inverse(theta_pred)
        update_flag = 0
        for j in np.arange(1, 10):
            trial_step = (c**j) * step_size
            next_theta = eta(theta_pred - trial_step * diff_term, trial_step)
            if check_conditions(next_theta, theta_prev, S, trial_step) == 1:
                # conditions satisfied
                theta_pred = next_theta
                update_flag = 1
                break
        if update_flag == 0:
            # no trial step was accepted: use the squared smallest eigenvalue
            print('**********changing the step size to min eigval')
            min_eig = torch.min(torch.linalg.eigvals(theta_pred).real)
            step_size = min_eig**2
            theta_pred = eta(theta_pred - step_size * diff_term, step_size)
        # Step 3: set next step size
        ll = torch.linalg.cholesky(theta_pred)
        step_size = get_step_size(theta_pred, theta_prev)
        # Step 4: Calc the duality gap
        delta = get_duality_gap(theta_pred, S)
        epoch += 1

    theta_pred = theta_pred.data.cpu().numpy()
    theta_true = theta_true.data.cpu().numpy()

    fdr, tpr, fpr, shd, nnz, nnz_true, ps = metrics.report_metrics(
        theta_true, theta_pred)
    cond_theta_pred = np.linalg.cond(theta_pred)
    cond_theta_true = np.linalg.cond(theta_true)
    print('Accuracy metrics: fdr ', fdr, ' tpr ', tpr, ' fpr ', fpr, ' shd ',
          shd, ' nnz ', nnz, ' nnz_true ', nnz_true, ' sign_match ', ps,
          ' pred_cond ', cond_theta_pred, ' true_cond ', cond_theta_true)
    return [
        fdr, tpr, fpr, shd, nnz, nnz_true, ps, cond_theta_pred, cond_theta_true
    ], res_conv
예제 #9
0
    def forward(self, xf, debug=False, manual_debug=False):  # gtvforward
        """Filter ``xf`` with nine consecutive ``self.qpsolve`` passes.

        Edge weights ``w`` are computed once from ``self.cnnf`` features. For
        every pass a matrix ``diag(A @ support_L) - A`` with ``A = W / Z`` is
        assembled, where ``Z`` comes from ``opt.H`` applied to the current
        signal, and handed to ``self.qpsolve``. The first pass uses ``xf``,
        the eight following ones use the previous output.

        Args:
            xf: input reshaped to ``(batch, opt.channels, width**2, 1)``.
            debug: when true, stores the unclamped ``u`` on ``self.u`` and
                logs summary statistics.
            manual_debug: accepted but not read.

        Returns:
            Tensor viewed as ``(batch, opt.channels, opt.width, opt.width)``.
        """
        sigma = self.weight_sigma
        if self.opt.legacy:
            u = self.cnnu.forward(xf)
            u = u.unsqueeze(1).unsqueeze(1)
        else:
            u = self.uu.forward()
        if debug:
            self.u = u.clone()
        u = torch.clamp(u, self.opt.u_min, self.opt.u_max)

        batch = xf.shape[0]
        idx_a = self.opt.connectivity_idx[0]
        idx_b = self.opt.connectivity_idx[1]

        def build_matrix(z, w):
            # NOTE(review): the view below hard-codes 3 while
            # self.opt.channels is used elsewhere - confirm channels == 3.
            W = self.base_W.clone()
            Z = W.clone()
            edge_w = w.view(batch, 3, -1)
            edge_z = torch.abs(z.view(batch, 3, -1))
            # fill both orientations of every connected pair
            W[:, :, idx_a, idx_b] = edge_w
            W[:, :, idx_b, idx_a] = edge_w
            Z[:, :, idx_a, idx_b] = edge_z
            Z[:, :, idx_b, idx_a] = edge_z
            # keep the denominator at or above support_zmax
            Z = torch.max(Z, self.support_zmax)
            A = W / Z
            row_sum = A @ self.support_L
            return torch.diag_embed(row_sum.squeeze(-1)) - A

        z = self.opt.H.matmul(
            xf.view(batch, xf.shape[1], self.opt.width**2, 1))

        E = self.cnnf.forward(xf)
        Fs = (self.opt.H.matmul(
            E.view(E.shape[0], E.shape[1], self.opt.width**2, 1))**2)

        w = torch.exp(-(Fs.sum(axis=1)) / (sigma**2))
        if debug:
            # separate name so the weight sigma is not overwritten
            message = f"Sample WEIGHT SUM: {w[0, :, :].sum().item():.4f} || Mean Processed u: {u.mean().item():.4f}"
            self.logger.info(message)

        w = w.unsqueeze(1).repeat(1, self.opt.channels, 1, 1)

        y = xf.view(batch, self.opt.channels, -1, 1)
        xhat = self.qpsolve(build_matrix(z, w), u, y, self.support_identity,
                            self.opt.channels)

        # eight refinement passes, each rebuilt from the previous output
        for _ in range(8):
            z = self.opt.H.matmul(xhat)
            xhat = self.qpsolve(build_matrix(z, w), u, xhat,
                                self.support_identity, self.opt.channels)

        return xhat.view(xhat.shape[0], self.opt.channels, self.opt.width,
                         self.opt.width)
예제 #10
0
    def test_GPyTorchPosterior(self):
        """Exercise ``GPyTorchPosterior`` for float and double dtypes.

        Covers the basic attributes, ``rsample`` with and without base
        samples, the approximate root-decomposition path (patched through
        ``ROOT_DECOMP_PATH``) and sampling from a batched posterior.
        """
        for dtype in (torch.float, torch.double):
            tkwargs = {"dtype": dtype, "device": self.device}
            n = 3
            mean = torch.rand(n, **tkwargs)
            variance = 1 + torch.rand(n, **tkwargs)
            covar = variance.diag()
            mvn = MultivariateNormal(mean, lazify(covar))
            posterior = GPyTorchPosterior(mvn=mvn)

            # basic attributes
            self.assertEqual(posterior.device.type, self.device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([n, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            self.assertTrue(
                torch.equal(posterior.variance, variance.unsqueeze(-1)))

            # rsample without base samples
            samples = posterior.rsample()
            self.assertEqual(samples.shape, torch.Size([1, n, 1]))
            for sample_shape in ([4], [4, 2]):
                samples = posterior.rsample(
                    sample_shape=torch.Size(sample_shape))
                self.assertEqual(
                    samples.shape, torch.Size(sample_shape + [n, 1]))

            # the approximate root decomposition must be used when enabled
            root_patch = mock.patch(
                ROOT_DECOMP_PATH, return_value=torch.linalg.cholesky(covar)
            )
            with root_patch as mock_func:
                with gpt_settings.max_cholesky_size(0):
                    with gpt_settings.fast_computations(
                        covar_root_decomposition=True
                    ):
                        # fresh objects are needed so no cached result is hit
                        mvn = MultivariateNormal(mean, lazify(covar))
                        posterior = GPyTorchPosterior(mvn=mvn)
                        posterior.rsample(sample_shape=torch.Size([4]))
                        mock_func.assert_called_once()

            # rsample with base samples of an incompatible shape
            base_samples = torch.randn(4, n, 1, **tkwargs)
            with self.assertRaises(RuntimeError):
                posterior.rsample(
                    sample_shape=torch.Size([3]), base_samples=base_samples
                )

            # identical base samples must give identical draws
            for sample_shape in ([4], [4, 2]):
                base_samples = torch.randn(*sample_shape, n, 1, **tkwargs)
                first = posterior.rsample(
                    sample_shape=torch.Size(sample_shape),
                    base_samples=base_samples,
                )
                second = posterior.rsample(
                    sample_shape=torch.Size(sample_shape),
                    base_samples=base_samples,
                )
                self.assertTrue(torch.allclose(first, second))

            # collapse_batch_dims
            b_mean = torch.rand(2, 3, **tkwargs)
            b_variance = 1 + torch.rand(2, 3, **tkwargs)
            b_mvn = MultivariateNormal(
                b_mean, lazify(torch.diag_embed(b_variance)))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 1, 3, 1, **tkwargs)
            b_samples = b_posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=b_base_samples
            )
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
예제 #11
0
 def lazy_covariance_matrix(self):
     """Get lazy covariance matrix.

     ``CholLazyTensor`` is given the Cholesky factor ``L`` and represents
     ``L @ L^T``. For a diagonal covariance ``diag(variance)`` that factor
     is ``diag(sqrt(variance))``; passing the variances themselves would
     yield ``diag(variance ** 2)``.
     """
     return CholLazyTensor(torch.diag_embed(self.variance.sqrt()))
예제 #12
0
 def _make_gpytorch_posterior(self, shape, dtype):
     """Create a ``GPyTorchPosterior`` with random mean and diagonal covariance.

     The mean is uniform in [0, 1) and the variances are one plus a uniform
     draw; both have shape ``shape`` and live on ``self.device``.
     """
     tkwargs = {"dtype": dtype, "device": self.device}
     mean = torch.rand(*shape, **tkwargs)
     variance = 1 + torch.rand(*shape, **tkwargs)
     mvn = MultivariateNormal(mean, lazify(torch.diag_embed(variance)))
     return GPyTorchPosterior(mvn=mvn)
예제 #13
0
    def predict(self):
        """Score every candidate word of the test split against every slot.

        For each candidate the per-slot score is the sum of the model's slot
        output, the log-density of the candidate features under that slot's
        Gaussian and the slot's log word probability. The scores are turned
        into probabilities with a softmax over the slots.

        Returns:
            list: one entry per sample; each entry is a list with one dict
            per candidate holding ``"domain"``, ``"slot"``, ``"prob"`` and
            ``"word"``.
        """
        dataset = self.dataset["test"]
        dataloader = dataset.batch_delivery(self.cfg.batch_size)
        self.model.eval()

        # Slot-one-hot distribution.
        # [slot_num, vocab_len]
        slot_word_dist = F.log_softmax(
            torch.FloatTensor(
                self.model.get_unnormalized_phi()),
            dim=-1)
        assert torch.isnan(slot_word_dist).sum().item() == 0

        # Slot-ce distribution.
        # [slot_num, feature_dim]
        slot_mean_dist = torch.FloatTensor(
            self.model.get_beta_mean())
        # [slot_num, emb_dim]
        slot_stdvar_dist = torch.FloatTensor(
            self.model.get_beta_logvar()).exp().sqrt()
        if self.cfg.use_gpu:
            slot_word_dist = slot_word_dist.cuda()
            slot_mean_dist = slot_mean_dist.cuda()
            slot_stdvar_dist = slot_stdvar_dist.cuda()
        # NOTE(review): slot_stdvar_dist is sqrt(exp(logvar)), i.e. it looks
        # like a standard deviation, yet it is passed as the covariance
        # diagonal - confirm this matches how the model was trained.
        slot_emb_dist = [
            MultivariateNormal(
                loc=slot_mean_dist[k],
                covariance_matrix=torch.diag_embed(
                    slot_stdvar_dist[k])) for k in range(
                self.cfg.slot_num)]

        predictions = []

        with torch.no_grad():
            # Iterating over the tqdm wrapper already advances the bar, so no
            # manual update() call is made inside the loop.
            pbar = tqdm(
                dataloader,
                desc="Validating progress")
            for features, candidates in pbar:
                oh, ce, mask, lens, candis = dataset.add_padding(
                    features, candidates)
                if self.cfg.use_gpu:
                    oh = oh.cuda()
                    ce = ce.cuda()
                    mask = mask.cuda()
                domains, slots = self.model(
                    oh, ce, mask, compute_loss=False)
                # [batch_size, padded_candi_len, slot_num]
                padded_logps = torch.stack([slots[:, k].unsqueeze(-1) +
                                            slot_emb_dist[k].log_prob(ce) +
                                            slot_word_dist[k][candis]
                                            for k in range(self.cfg.slot_num)], dim=-1).cpu()
                assert torch.isnan(padded_logps).sum().item() == 0

                # iterating over batch_size
                for i in range(oh.shape[0]):
                    domain = domains[i]
                    true_len = lens[i]
                    # [candi_len]
                    true_candis = candis[i, :true_len]
                    # [candi_len, slot_num]
                    logps = padded_logps[i, :true_len]
                    probs = torch.softmax(logps, dim=-1)

                    prediction = [{"domain": domain.item(),
                                   "slot": torch.argmax(prob).item(),
                                   "prob": prob.data.numpy(),
                                   "word": self.vocab.itos[candi_index],
                                   } for candi_index,
                                  prob in zip(true_candis, probs)]

                    predictions.append(prediction)
        return predictions
 def subtract_diag(self, gram):
     """Subtract the main diagonal from ``gram`` in place and return it.

     The diagonal is taken over the last two dimensions, the same pair
     ``torch.diag_embed`` writes to, so batched input of shape
     ``(..., N, N)`` is handled as well as a single ``N x N`` matrix.

     Args:
         gram: square matrix or batch of square matrices; it is modified.

     Returns:
         The same tensor object as ``gram``.
     """
     diag_elements = torch.diag_embed(
         torch.diagonal(gram, dim1=-2, dim2=-1))
     gram -= diag_elements
     return gram
예제 #15
0
def gmm_loss(log_market_shares,
             prod_char,
             market_masks,
             own_mat,
             model_ids_to_rows,
             beta_mat,
             xw_mat,
             z_mat,
             mdraws,
             mweights,
             theta,
             gmm_weights,
             delta_0=None,
             full_output=False,
             atol=1e-06,
             rtol=1e-06):
    """Evaluate the GMM objective for the current ``theta``.

    ``batch_invert_shares`` and ``batch_shares`` (defined elsewhere) supply
    ``delta`` and the share terms. A linear system per market gives ``eta``,
    which is subtracted from ``prod_char[:, :, 0]``; the log of the result is
    stacked with ``delta`` to form the regression target whose residuals
    enter the moments.

    Args:
        log_market_shares, prod_char, market_masks, mdraws, mweights, theta:
            forwarded to ``batch_invert_shares`` / ``batch_shares``.
        own_mat: multiplied element-wise with the share Jacobian.
        model_ids_to_rows: matrix aggregating the moment rows.
        beta_mat, xw_mat: ``beta = beta_mat @ y`` and the residual is
            ``y - xw_mat @ beta``.
        z_mat: multiplied row-wise with the residuals.
        gmm_weights: weighting matrix of the quadratic form.
        delta_0, atol, rtol: forwarded to ``batch_invert_shares``.
        full_output: when true, the inverse of the moment covariance is
            returned as a fourth value.

    Returns:
        ``(loss, delta, beta)`` or, with ``full_output``,
        ``(loss, delta, beta, cov_mat)``.
    """

    delta, __ = batch_invert_shares(log_market_shares,
                                    prod_char,
                                    market_masks,
                                    mdraws,
                                    mweights,
                                    theta,
                                    delta_0=delta_0,
                                    atol=atol,
                                    rtol=rtol)

    s, cs, ws = batch_shares(delta,
                             prod_char,
                             market_masks,
                             mdraws,
                             mweights,
                             theta,
                             full_output=True)

    ww = ws * mdraws[:, 0:1, :] * theta[0]
    Jp = torch.diag_embed(ww.sum(dim=2)) - torch.bmm(ww,
                                                     torch.transpose(cs, 1, 2))
    # ones on the diagonal of masked-out entries
    OJp = Jp * own_mat + torch.diag_embed(1 - market_masks)
    # torch.solve(B, A) has been removed from PyTorch; torch.linalg.solve
    # takes the coefficient matrix first.
    eta = torch.linalg.solve(OJp, -s[:, :, None])
    # drop only the trailing column axis so the market and product axes
    # survive even when one of them has length one
    eta = eta.squeeze(-1)
    costs = prod_char[:, :, 0] - eta

    mc = torch.masked_select(costs, market_masks.bool())
    # replace negative values with a small positive one before the log
    mc[(mc < 0).detach()] = 0.001
    log_mc = torch.log(mc)

    y = torch.cat((torch.masked_select(delta, market_masks.bool()), log_mc), 0)
    beta = beta_mat @ y
    res = y - xw_mat @ beta

    g_hat = z_mat * res[:, None]
    g_hat_agg = model_ids_to_rows @ g_hat
    g_hat_mean = g_hat_agg.mean(dim=0)

    loss = g_hat_mean.t() @ gmm_weights @ g_hat_mean

    if full_output:
        with torch.no_grad():
            fact = 1.0 / (g_hat_agg.size(0) - 1)
            cov_mat = g_hat_agg - g_hat_mean
            cov_mat = fact * cov_mat.t().matmul(cov_mat)
            cov_mat = torch.inverse(cov_mat)
        return loss, delta, beta, cov_mat
    else:
        return loss, delta, beta
예제 #16
0
def fast_gaussian_multi_integral(
        mu0: torch.Tensor,
        mu1: torch.Tensor,
        var0: torch.Tensor,
        var1: torch.Tensor,
        forward: bool = True,
        need_zeta: bool = True,
        diag1: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Combine two Gaussians in information form and return one block.

    The precision and information vector of the second Gaussian are
    zero-padded to the size of the first one and added to it. The combined
    mean and covariance are then split into two blocks of size
    ``mu1.size(-1)``; the second block is returned when ``forward`` is true,
    the first one otherwise.

    Args:
        mu0, var0: mean and covariance of the larger Gaussian.
        mu1, var1: mean and covariance of the smaller Gaussian; with
            ``diag1`` true, ``var1`` holds the diagonal only.
        forward: pad on the trailing side and return the second block when
            true; pad on the leading side and return the first block when
            false.
        need_zeta: when true, compute ``scale`` with ``calculate_zeta``.
        diag1: whether ``var1`` is a diagonal.

    Returns:
        ``(scale, res_mu, res_sigma)``; ``scale`` is ``None`` when
        ``need_zeta`` is false.
    """
    dim0 = mu0.size()[-1]
    dim1 = mu1.size()[-1]

    assert dim0 > dim1

    # padding specifications for the vector and for the matrix
    if forward:
        mu_pad, var_pad = (0, dim1), (0, dim1, 0, dim1)
    else:
        mu_pad, var_pad = (dim1, 0), (dim1, 0, dim1, 0)

    # precision matrices
    lambda0 = torch.inverse(var0)
    lambda1 = (torch.diag_embed(1.0 / var1).float()
               if diag1 else torch.inverse(var1))

    # information vectors; unsqueeze creates new tensors, so the inputs are
    # not modified in place
    eta0 = torch.matmul(lambda0, mu0.unsqueeze(-1)).squeeze(-1)
    eta1 = torch.matmul(lambda1, mu1.unsqueeze(-1)).squeeze(-1)

    eta_new = eta0 + F.pad(eta1, mu_pad, 'constant', 0)
    lambda_new = lambda0 + F.pad(lambda1, var_pad, 'constant', 0)

    sigma_new = torch.inverse(lambda_new)
    mu_new = torch.matmul(sigma_new, eta_new.unsqueeze(-1)).squeeze(-1)

    scale = None
    if need_zeta:
        zeta0 = calculate_zeta(eta0, lambda0, mu=mu0)
        zeta1 = calculate_zeta(eta1, lambda1, mu=mu1)
        zeta_new = calculate_zeta(eta_new, lambda_new, sig=sigma_new)
        scale = zeta0 + zeta1 - zeta_new

    block = 1 if forward else 0
    sections = [dim1, dim1]

    res_mu = torch.split(mu_new, sections, dim=-1)[block]
    sigma_rows = torch.split(sigma_new, sections, dim=-2)[block]
    res_sigma = torch.split(sigma_rows, sections, dim=-1)[block]

    return scale, res_mu, res_sigma
    def test_symeig(self):
        """Compare ``LazyTensor.symeig`` with ``torch.linalg.eigh``.

        For double and float precision the eigenvalues, the matrix rebuilt
        from the decomposition, the eigenvectors (up to sign, where the
        eigenvalues are distinct) and the gradients are checked against the
        dense reference. The ``eigenvectors=False`` call must return ``None``
        for the eigenvectors.
        """
        for name, dtype in (("double", torch.double), ("float", torch.float)):
            tolerances = self.tolerances["symeig"][name]

            lazy_tensor = self.create_lazy_tensor().detach()
            lazy_tensor = lazy_tensor.requires_grad_(True)
            lazy_tensor_copy = lazy_tensor.clone().detach()
            lazy_tensor_copy = lazy_tensor_copy.requires_grad_(True)
            evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)

            # forward pass through the lazy implementation
            with linalg_dtypes(dtype):
                evals_unsorted, evecs_unsorted = lazy_tensor.symeig(
                    eigenvectors=True)
                evecs_unsorted = evecs_unsorted.evaluate()

            # LazyTensor.symeig does not sort the eigenvalues, so sort them
            # here and reorder the eigenvector columns accordingly
            evals, order = torch.sort(evals_unsorted, dim=-1, descending=False)
            column_index = order.unsqueeze(-2).expand(evecs_unsorted.shape)
            evecs = torch.gather(evecs_unsorted, dim=-1, index=column_index)

            # dense reference
            evals_actual, evecs_actual = torch.linalg.eigh(
                evaluated.type(dtype))
            evals_actual = evals_actual.to(dtype=evaluated.dtype)
            evecs_actual = evecs_actual.to(dtype=evaluated.dtype)

            # forward checks
            self.assertAllClose(evals, evals_actual, **tolerances)
            rebuilt = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
            self.assertAllClose(rebuilt, evaluated, **tolerances)

            # eigenvectors are only compared where no eigenvalue is repeated
            any_evals_repeated = False
            evecs_abs = evecs.abs()
            evecs_actual_abs = evecs_actual.abs()
            batch_ranges = [range(b) for b in evals_actual.shape[:-1]]
            for idx in itertools.product(*batch_ranges):
                eval_i = evals_actual[idx]
                # detach to avoid pytorch/pytorch#41389
                num_unique = torch.unique(eval_i.detach()).shape[-1]
                if num_unique != eval_i.shape[-1]:
                    any_evals_repeated = True
                    continue
                self.assertAllClose(evecs_abs[idx], evecs_actual_abs[idx],
                                    **tolerances)

            # backward pass with the same random weights on both sides
            symeig_grad = torch.randn_like(evals)
            ((evals * symeig_grad).sum()).backward()
            ((evals_actual * symeig_grad).sum()).backward()

            # gradients are only compared when no eigenvalue was repeated
            if not any_evals_repeated:
                pairs = zip(lazy_tensor.representation(),
                            lazy_tensor_copy.representation())
                for arg, arg_copy in pairs:
                    if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
                        self.assertAllClose(arg.grad, arg_copy.grad,
                                            **tolerances)

            # eigenvectors=False must not return eigenvectors
            _, evecs = lazy_tensor.symeig(eigenvectors=False)
            self.assertIsNone(evecs)
예제 #18
0
def normalize_adj(A):
    """Scale a stack of adjacency matrices on both sides by diagonal matrices.

    The left factor holds ``1 / sqrt(A.sum(dim=1))`` on its diagonal and the
    right factor holds ``1 / sqrt(A.sum(dim=2))``; the result is
    ``left @ A @ right`` computed with ``stacked_spmm``.
    """
    inv_sqrt_sum_dim1 = 1.0 / torch.sqrt(A.sum(dim=1))
    inv_sqrt_sum_dim2 = 1.0 / torch.sqrt(A.sum(dim=2))
    left_scale = torch.diag_embed(inv_sqrt_sum_dim1)
    right_scale = torch.diag_embed(inv_sqrt_sum_dim2)
    # NOTE(review): the original author left the remark "swap D_in and D_out"
    # on the first product -- confirm which side each factor belongs on.
    return stacked_spmm(stacked_spmm(left_scale, A), right_scale)
예제 #19
0
def axe_loss(logits: torch.FloatTensor,
             logit_lengths: torch.Tensor,
             targets: torch.LongTensor,
             target_lengths: torch.Tensor,
             blank_index: torch.LongTensor,
             delta: torch.FloatTensor,
             reduction: str = 'mean',
             label_smoothing: float = None,
             return_a: bool = False
            ) -> Union[torch.FloatTensor, List[torch.Tensor]]:
    """Aligned Cross Entropy
    Marjan Ghazvininejad, Vladimir Karpukhin, Luke Zettlemoyer, Omer Levy, in arXiv 2020
    https://arxiv.org/abs/2004.01655

    Computes the aligned cross entropy loss with a parallel scheme: the
    dynamic-programming table ``A`` is filled one anti-diagonal at a time, so
    that all cells of an anti-diagonal are computed with batched tensor ops.

    Recurrence (``Y`` = targets, ``P_j`` = predicted distribution at step j)::

        A[0, 0] = 0
        A[i, 0] = A[i-1, 0] - delta * log P_1(Y_i)
        A[0, j] = A[0, j-1] - log P_j(BLANK)
        A[i, j] = min(A[i-1, j-1] - log P_j(Y_i),          # align
                      A[i,   j-1] - log P_j(BLANK),        # skip prediction
                      A[i-1, j  ] - delta * log P_j(Y_i))  # skip target

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    logit_lengths : ``torch.Tensor``, required.
        A ``torch.Tensor`` of size (batch_size,)
        which contains lengths of the logits
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    target_lengths : ``torch.Tensor``, required.
        A ``torch.Tensor`` of size (batch_size,)
        which contains lengths of the targets
    blank_index : ``torch.LongTensor``, required.
        A ``torch.LongTensor``, An index of special blank token.
    delta : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` for penalizing skip target operators.
    reduction : ``str``, optional.
        Specifies the reduction to apply to the output.
        Only "mean" is implemented. Default "mean".
    label_smoothing : ``float``, optional
        Label smoothing weight; smoothing is applied only when it is not
        ``None`` and greater than 0.
    return_a : ``bool``, optional.
        Whether to return the matrix of conditional axe values. Default is False.

    Returns
    -------
    The reduced loss, or ``(loss, A)`` when ``return_a`` is True, where ``A`` is
    the detached table of size
    (batch_size, target_sequence_length + 1, logits_sequence_length + 1).

    Raises
    ------
    AssertionError
        If the batch sizes differ or the logit sequence is not strictly
        longer than the target sequence.
    NotImplementedError
        If ``reduction`` is not "mean".
    """

    assert targets.size(0) == logits.size(0), f'Inconsistency of batch size,  {targets.size(0)} of targets and {logits.size(0)} of logits.'

    batch_size, logits_sequence_length, num_class = logits.shape
    _, target_sequence_length = targets.shape

    # The anti-diagonal write-back below relies on this; it is loop-invariant,
    # so check it once up front instead of on every diagonal.
    assert logits_sequence_length > target_sequence_length, "assuming target length < logit length."

    device = logits.device

    # for torch.gather
    targets = targets.unsqueeze(-1)  # batch_size, target_sequence_length, 1

    # (batch_size, target_sequence_length + 1, logits_sequence_length + 1)
    batch_A = torch.zeros(batch_size, target_sequence_length + 1, logits_sequence_length + 1).to(device)
    batch_blank_index = torch.full((batch_size, 1), blank_index, dtype=torch.long).to(device)

    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # A_{i,0} = A_{i-1,0} - delta * log P_1 (Y_i)
    for i in range(1, target_sequence_length + 1):
        # batch_A[:, i, 0] is calculated from targets[:, i-1, :], because batch_A added 0-th row
        batch_A[:, i, 0] = batch_A[:, i-1, 0] - delta * torch.gather(log_probs[:, 0, :], dim=1, index=targets[:, i-1, :]).squeeze(-1)

    # A_{0,j} = A_{0,j-1} - log P_j ("BLANK")
    # Skipping a prediction is NOT scaled by delta (delta only penalizes
    # skip-target), matching the skip-prediction cost used inside the main loop.
    for j in range(1, logits_sequence_length + 1):
        # batch_A[:, 0, j] is calculated from log_probs[:, j-1, :], because batch_A added 0-th column
        batch_A[:, 0, j] = batch_A[:, 0, j-1] - torch.gather(log_probs[:, j-1, :], dim=1, index=batch_blank_index).squeeze(-1)

    # Flip the logit dim so that anti-diagonals of A become ordinary diagonals,
    # which torch.diagonal / torch.diag_embed can address directly.
    batch_A_flip = batch_A.flip(-1)  # (batch_size, target_sequence_length + 1, logits_sequence_length + 1)
    log_probs_flip = log_probs.flip(-2)  # (batch_size, sequence_length, num_classes)

    for i in range(logits_sequence_length - 1, -target_sequence_length, -1):

        # batch_A_flip_sets[:, :, :, 0] : batch_A[:, i  , j-1]
        # batch_A_flip_sets[:, :, :, 1] : batch_A[:, i-1, j  ]
        # batch_A_flip_sets[:, :, :, 2] : batch_A[:, i-1, j-1]
        batch_A_flip_sets = torch.stack((batch_A_flip.roll(shifts=-1, dims=-1),
                                         batch_A_flip.roll(shifts=1, dims=-2),
                                         batch_A_flip.roll(shifts=(1, -1), dims=(-2, -1))),
                                        dim=-1)

        # trimming the boundary cells, which were initialized above:
        # - the first row (A_{0,j})
        # - the last column of the flipped table (A_{i,0})
        batch_A_flip_sets_trim = batch_A_flip_sets[:, 1:, :-1, :]

        # extracting anti-diagonal part
        # (batch, 3, num_diag) -> (batch, num_diag, 3)
        A_diag = batch_A_flip_sets_trim.diagonal(offset=i, dim1=-3, dim2=-2).transpose(-1, -2)
        num_diag = A_diag.size(1)

        # Diagonal `i` of the (targets x flipped-logits) grid visits cells
        # (k, k + i) for i >= 0 and (k - i, k) for i < 0, k = 0..num_diag-1.
        logit_start = max(i, 0)
        target_start = max(-i, 0)

        # log_probs_flip_diag : (batch, num_diag, num_class)
        log_probs_flip_diag = log_probs_flip[:, logit_start:logit_start + num_diag, :]
        # targets_diag : (batch, num_diag, 1)
        targets_diag = targets[:, target_start:target_start + num_diag, :]

        # log P_j(Y_i) and log P_j(BLANK) for every cell of this diagonal: (batch, num_diag)
        target_log_probs = torch.gather(log_probs_flip_diag, dim=2, index=targets_diag).squeeze(-1)
        blank_log_probs = torch.gather(log_probs_flip_diag, dim=2, index=batch_blank_index.expand(-1, num_diag).unsqueeze(-1)).squeeze(-1)

        # align, skip_prediction, skip_target
        batch_align = A_diag[:, :, 2] - target_log_probs
        batch_skip_prediction = A_diag[:, :, 0] - blank_log_probs
        batch_skip_target = A_diag[:, :, 1] - delta * target_log_probs

        # (batch_size, num_diag, 3)
        operations = torch.stack((batch_align, batch_skip_prediction, batch_skip_target), dim=-1)

        # (batch_size, num_diag)
        diag_axe = torch.min(operations, dim=-1).values

        # Write the diagonal back. The target cells are still zero, so `+=`
        # acts as an assignment restricted to the diagonal.
        if i > (logits_sequence_length - target_sequence_length):
            # (batch_size, logits_length, logits_length)
            # -> (batch_size, targets_length, logits_length)
            axe = torch.diag_embed(diag_axe, offset=i, dim1=-2, dim2=-1)
            batch_A_flip[:, 1:, :-1] += axe[:, :target_sequence_length, :]
        elif i > 0:
            # (batch_size, targets_length, targets_length), placed at column offset i
            axe = torch.diag_embed(diag_axe, offset=0, dim1=-2, dim2=-1)
            batch_A_flip[:, 1:, i: i + target_sequence_length] += axe
        else:
            # (batch_size, targets_length, targets_length)
            axe = torch.diag_embed(diag_axe, offset=i, dim1=-2, dim2=-1)
            batch_A_flip[:, 1:, :target_sequence_length] += axe

    # recover correct order in logit dim
    batch_A = batch_A_flip.flip(-1)

    # rm 0-th row and column
    _batch_A = batch_A[:, 1:, 1:]

    ## Gather A_nm, avoiding masks
    # index_m : (batch_size, target_sequence_length, 1)
    index_m = logit_lengths.unsqueeze(-1).expand(-1, _batch_A.size(1)).unsqueeze(-1).long()

    # gather m-th column
    # batch_A_m : (batch_size, target_sequence_length)
    # NOTE: a logit length of 0 would make (index_m - 1) out of bounds.
    batch_A_m = torch.gather(_batch_A, dim=2, index=(index_m - 1))
    batch_A_m = batch_A_m.squeeze(-1)

    # index_n : (batch_size, 1)
    index_n = target_lengths.unsqueeze(-1).long()

    # gather n-th row
    # batch_A_nm : (batch_size, 1)
    batch_A_nm = torch.gather(batch_A_m, dim=1, index=(index_n - 1))

    # batch_A_nm : (batch_size)
    batch_A_nm = batch_A_nm.squeeze(-1)

    if reduction == "mean":
        axe_nm = batch_A_nm.mean()
    else:
        raise NotImplementedError

    # Refs fairseq nat_loss.
    # https://github.com/pytorch/fairseq/blob/6f6461b81ac457b381669ebc8ea2d80ea798e53a/fairseq/criterions/nat_loss.py#L70
    # NOTE(review): the original author was not sure this smoothing term is reasonable.
    if label_smoothing is not None and label_smoothing > 0.0:
        axe_nm = axe_nm * (1.0-label_smoothing) - log_probs.mean() * label_smoothing

    if return_a:
        return axe_nm, batch_A.detach()

    return axe_nm
예제 #20
0
    def dynam_prediction(self,
                         kp,
                         graph,
                         action=None,
                         eps=5e-2,
                         env=None,
                         node_params=None):
        """Predict the next keypoint state (mean + diagonal covariance).

        Pipeline: encode every history step per node/edge with
        ``self.model_dynam_encode``, aggregate over the history axis with
        ``self.model_dynam_node_forward`` / ``self.model_dynam_edge_forward``,
        then decode with ``self.model_dynam_decode``. The decoder output is
        treated as a *change* of the mean relative to the last history step
        plus raw per-dimension variances.

        Args:
            kp: B x n_his x n_kp x (state_dim + state_dim**2) keypoint states
                (mean followed by a flattened covariance), per the views below.
            graph: tuple ``(node_attr, edge_attr, edge_type, edge_type_logits)``;
                ``edge_type_logits`` is unpacked but not used here.
            action: optional action tensor.
                NOTE(review): the code inserts a trailing axis and then views
                the result with ``action_dim`` as last size, which looks like it
                expects B x n_his x n_kp input with ``action_dim == 1`` -- confirm.
            eps: not referenced in this method.
            env: not referenced in this method.
            node_params: per-node parameters viewed as
                B x n_kp x ``args.node_params``. Although the default is
                ``None``, ``.view`` is called on it unconditionally, so a
                tensor must be supplied.

        Returns:
            kp_pred: B x n_kp x (state_dim + state_dim**2).
        """
        # kp: B x n_his x n_kp x (mean + covariance)
        # action: B x n_his x n_kp x action_dim

        args = self.args
        nf = args.nf_hidden_dy * 4
        action_dim = args.action_dim
        node_attr_dim = args.node_attr_dim
        edge_attr_dim = args.edge_attr_dim
        edge_type_num = args.edge_type_num

        B, n_his, n_kp, _ = kp.size()

        # node_attr: B x n_kp x node_attr_dim
        # edge_attr: B x n_kp x n_kp x edge_attr_dim
        # edge_type: B x n_kp x n_kp x edge_type_num
        # edge_type_logits: B x n_kp x n_kp x edge_type_num
        node_attr, edge_attr, edge_type, edge_type_logits = graph

        # node_enc: B x n_his x n_kp x nf
        # edge_enc: B x n_his x (n_kp * n_kp) x nf

        # Per-step node input: keypoint state, node attributes and node
        # parameters (the latter two broadcast over the history axis).
        node_enc = torch.cat([
            kp,
            node_attr.view(B, 1, n_kp, node_attr_dim).repeat(1, n_his, 1, 1),
            node_params.view(B, 1, n_kp, args.node_params).repeat(
                1, n_his, 1, 1)
        ], 3)
        # Per-step edge input: states of both endpoints plus edge attributes.
        edge_enc = torch.cat([
            torch.cat([
                kp[:, :, :, None, :].repeat(1, 1, 1, n_kp, 1),
                kp[:, :, None, :, :].repeat(1, 1, n_kp, 1, 1)
            ], 4),
            edge_attr.view(B, 1, n_kp, n_kp, edge_attr_dim).repeat(
                1, n_his, 1, 1, 1)
        ], 4)

        # Fold the history axis into the batch axis for the encoder.
        node_enc, edge_enc = self.model_dynam_encode(
            node_enc.view(
                B * n_his, n_kp, node_attr_dim +
                (args.state_dim + args.state_dim**2) + args.node_params),
            edge_enc.view(
                B * n_his, n_kp, n_kp,
                edge_attr_dim + 2 * (args.state_dim + args.state_dim**2)),
            edge_type[:, None, :, :, :].repeat(1, n_his, 1, 1,
                                               1).view(B * n_his, n_kp, n_kp,
                                                       edge_type_num),
            start_idx=args.edge_st_idx)

        node_enc = node_enc.view(B, n_his, n_kp, nf)
        edge_enc = edge_enc.view(B, n_his, n_kp * n_kp, nf)

        # node_enc: B x n_kp x n_his x nf
        # edge_enc: B x (n_kp * n_kp) x n_his x nf
        node_enc = node_enc.transpose(1,
                                      2).contiguous().view(B, n_kp, n_his, nf)
        edge_enc = edge_enc.transpose(1, 2).contiguous().view(
            B, n_kp * n_kp, n_his, nf)

        # node_enc: B x n_kp x n_his x (nf + node_attr_dim + action_dim)
        # kp_node: B x n_kp x n_his x (state_dim + state_dim**2)
        kp_node = kp.transpose(1, 2).contiguous().view(
            B, n_kp, n_his, (args.state_dim + args.state_dim**2))

        node_enc = torch.cat([
            kp_node, node_enc,
            node_attr.view(B, n_kp, 1, node_attr_dim).repeat(1, 1, n_his, 1)
        ], 3)

        # edge_enc: B x (n_kp * n_kp) x n_his x (nf + edge_attr_dim + action_dim)
        # kp_edge: B x (n_kp * n_kp) x n_his x 2 * (state_dim + state_dim**2)
        kp_edge = torch.cat([
            kp_node[:, :, None, :, :].repeat(1, 1, n_kp, 1, 1),
            kp_node[:, None, :, :, :].repeat(1, n_kp, 1, 1, 1)
        ], 4)
        kp_edge = kp_edge.view(B, n_kp**2, n_his,
                               2 * (args.state_dim + args.state_dim**2))

        edge_enc = torch.cat([
            kp_edge, edge_enc,
            edge_attr.view(B, n_kp**2, 1, edge_attr_dim).repeat(
                1, 1, n_his, 1)
        ], 3)

        # append action
        if action is not None:

            # Adds a trailing axis; `action` stays in this expanded form for
            # the second `if action is not None` block further down.
            action = action[:, :, :, None]
            action_t = action.transpose(1, 2).contiguous()
            action_t_r = action_t[:, :,
                                  None, :, :].repeat(1, 1, n_kp, 1, 1).view(
                                      B, n_kp**2, n_his, action_dim)
            action_t_s = action_t[:,
                                  None, :, :, :].repeat(1, n_kp, 1, 1, 1).view(
                                      B, n_kp**2, n_his, action_dim)
            # print('node_enc', node_enc.size(), 'edge_enc', edge_enc.size())
            # print('action_t', action_t.size(), 'action_t_r', action_t_r.size(), 'action_t_s', action_t_s.size())
            node_enc = torch.cat([node_enc, action_t], 3)
            edge_enc = torch.cat([edge_enc, action_t_r, action_t_s], 3)

        # Aggregate over the history axis.
        # node_enc: B x n_kp x nf
        # edge_enc: B x n_kp x n_kp x nf
        node_enc = self.model_dynam_node_forward(
            node_enc.view(B * n_kp, n_his, -1)).view(B, n_kp, nf)
        edge_enc = self.model_dynam_edge_forward(
            edge_enc.view(B * n_kp**2, n_his, -1)).view(B, n_kp, n_kp, nf)

        # Decoder input: aggregated encoding + attributes + the most recent
        # history step of the keypoint state (index -1 along n_his).
        node_enc = torch.cat([node_enc, node_attr, kp_node[:, :, -1]], 2)
        edge_enc = torch.cat([
            edge_enc, edge_attr, kp_edge[:, :, -1].view(
                B, n_kp, n_kp, 2 * (args.state_dim + args.state_dim**2))
        ], 3)

        if action is not None:
            # print('node_enc', node_enc.size(), 'edge_enc', edge_enc.size(), 'action', action.size())
            # Only the last history step of the action is fed to the decoder.
            action_r = action[:, :, :, None, :].repeat(1, 1, 1, n_kp, 1)
            action_s = action[:, :, None, :, :].repeat(1, 1, n_kp, 1, 1)
            node_enc = torch.cat([node_enc, action[:, -1]], 2)
            edge_enc = torch.cat([edge_enc, action_r[:, -1], action_s[:, -1]],
                                 3)

        kp_pred = self.model_dynam_decode(node_enc,
                                          edge_enc,
                                          edge_type,
                                          start_idx=args.edge_st_idx,
                                          ignore_edge=True)

        # kp_pred: B x n_kp x (mean + covariance)
        # Predicting change in state
        # Earlier 2-D-only formulation, kept for reference:
        # kp_pred = torch.cat([
        #     kp[:, -1, :, :args.state_dim] + kp_pred[:, :, :args.state_dim],               # mean
        #     F.relu(kp_pred[:, :, 2:3]) + args.gauss_std,        # covar (0, 0), need to > 0
        #     torch.zeros(B, n_kp, 1).cuda(),                     # covar (0, 1)
        #     kp_pred[:, :, 3:4],                                 # covar (1, 0)
        #     F.relu(kp_pred[:, :, 4:5]) + args.gauss_std],       # covar (1, 1), need to > 0
        #     dim=2)
        # Mean: last observed mean plus the predicted delta.
        # Covariance: relu(raw) + gauss_std on the diagonal, zeros elsewhere,
        # flattened to state_dim**2. The `.view` requires the decoder to emit
        # exactly state_dim variance channels after the mean.
        kp_pred = torch.cat(
            [
                kp[:, -1, :, :args.state_dim] +
                kp_pred[:, :, :args.state_dim],  # mean
                torch.diag_embed(
                    F.relu(kp_pred[:, :, args.state_dim:]) +
                    args.gauss_std).view(B, n_kp,
                                         args.state_dim * args.state_dim)
            ],
            dim=2)

        return kp_pred
예제 #21
0
 def __init__(self, dim: int, num_operations: int):
     """Create one learnable affine map per operation, starting at identity.

     ``linear_transformations`` starts as ``num_operations`` stacked
     ``dim x dim`` identity matrices and ``translations`` as zeros of shape
     ``(num_operations, dim)``.
     """
     super().__init__(dim, num_operations)
     unit_diagonals = torch.ones(()).expand(num_operations, dim)
     identity_stack = torch.diag_embed(unit_diagonals)
     zero_offsets = torch.zeros((self.num_operations, self.dim))
     self.linear_transformations = nn.Parameter(identity_stack)
     self.translations = nn.Parameter(zero_offsets)
    def ops_2_to_2(self, inputs, dim, normalization='inf', normalization_val=1.0):  # N x D x m x m
        """Compute the 15 basis operations of an equivariant 2 -> 2 linear layer.

        Args:
            inputs: tensor of shape N x D x m x m.
            dim: the size ``m`` of the last two axes; a Python int or a
                0-dim integer tensor.
            normalization: ``'inf'`` divides every op that contains a sum by
                ``m`` (or ``m**2`` for sums over all entries); ``None`` or any
                other value leaves the ops unnormalized.
            normalization_val: unused; kept for interface compatibility.

        Returns:
            List of 15 tensors, each of shape N x D x m x m, in the order
            op1 .. op15 described inline below.
        """
        diag_part = torch.diagonal(inputs, dim1=-2, dim2=-1)   # N x D x m
        sum_diag_part = torch.sum(diag_part, dim=2, keepdim=True)  # N x D x 1
        sum_of_rows = torch.sum(inputs, dim=3)  # N x D x m
        sum_of_cols = torch.sum(inputs, dim=2)  # N x D x m
        sum_all = torch.sum(sum_of_rows, dim=2)  # N x D

        # op1 - (1234) - extract diag
        op1 = torch.diag_embed(diag_part)  # N x D x m x m

        # op2 - (1234) + (12)(34) - place sum of diag on diag
        op2 = torch.diag_embed(sum_diag_part.repeat(1, 1, dim))  # N x D x m x m

        # op3 - (1234) + (123)(4) - place sum of row i on diag ii
        op3 = torch.diag_embed(sum_of_rows)  # N x D x m x m

        # op4 - (1234) + (124)(3) - place sum of col i on diag ii
        op4 = torch.diag_embed(sum_of_cols)  # N x D x m x m

        # op5 - (1234) + (124)(3) + (123)(4) + (12)(34) + (12)(3)(4) - place sum of all entries on diag
        op5 = torch.diag_embed(torch.unsqueeze(sum_all, dim=2).repeat(1, 1, dim))  # N x D x m x m

        # op6 - (14)(23) + (13)(24) + (24)(1)(3) + (124)(3) + (1234) - place sum of col i on row i
        op6 = torch.unsqueeze(sum_of_cols, dim=3).repeat(1, 1, 1, dim)  # N x D x m x m

        # op7 - (14)(23) + (23)(1)(4) + (234)(1) + (123)(4) + (1234) - place sum of row i on row i
        op7 = torch.unsqueeze(sum_of_rows, dim=3).repeat(1, 1, 1, dim)  # N x D x m x m

        # op8 - (14)(2)(3) + (134)(2) + (14)(23) + (124)(3) + (1234) - place sum of col i on col i
        op8 = torch.unsqueeze(sum_of_cols, dim=2).repeat(1, 1, dim, 1)  # N x D x m x m

        # op9 - (13)(24) + (13)(2)(4) + (134)(2) + (123)(4) + (1234) - place sum of row i on col i
        op9 = torch.unsqueeze(sum_of_rows, dim=2).repeat(1, 1, dim, 1)  # N x D x m x m

        # op10 - (1234) + (14)(23) - identity
        op10 = inputs  # N x D x m x m

        # op11 - (1234) + (13)(24) - transpose
        op11 = inputs.permute(0, 1, 3, 2)  # N x D x m x m

        # op12 - (1234) + (234)(1) - place ii element in row i
        op12 = torch.unsqueeze(diag_part, dim=3).repeat(1, 1, 1, dim)  # N x D x m x m

        # op13 - (1234) + (134)(2) - place ii element in col i
        op13 = torch.unsqueeze(diag_part, dim=2).repeat(1, 1, dim, 1)  # N x D x m x m

        # op14 - (34)(1)(2) + (234)(1) + (134)(2) + (1234) + (12)(34) - place sum of diag in all entries
        op14 = torch.unsqueeze(sum_diag_part, dim=3).repeat(1, 1, dim, dim)   # N x D x m x m

        # op15 - sum of all ops - place sum of all entries in all entries
        op15 = torch.unsqueeze(torch.unsqueeze(sum_all, dim=2), dim=3).repeat(1, 1, dim, dim)  # N x D x m x m

        if normalization is not None:
            # as_tensor lets `dim` be a plain int as well as a tensor.
            float_dim = torch.as_tensor(dim).type(torch.FloatTensor)
            # Compare by value: `is 'inf'` only worked through string interning.
            if normalization == 'inf':
                op2 = torch.div(op2, float_dim)
                op3 = torch.div(op3, float_dim)
                op4 = torch.div(op4, float_dim)
                op5 = torch.div(op5, float_dim**2)
                op6 = torch.div(op6, float_dim)
                op7 = torch.div(op7, float_dim)
                op8 = torch.div(op8, float_dim)
                op9 = torch.div(op9, float_dim)
                op14 = torch.div(op14, float_dim)
                op15 = torch.div(op15, float_dim**2)

        return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15]
예제 #23
0
def kabsch_transformation_estimation(x1,
                                     x2,
                                     weights=None,
                                     normalize_w=True,
                                     eps=1e-7,
                                     best_k=0,
                                     w_threshold=0):
    """
    Torch differentiable implementation of the weighted Kabsch algorithm (https://en.wikipedia.org/wiki/Kabsch_algorithm). Based on the correspondences and weights calculates
    the optimal rotation matrix in the sense of the Frobenius norm (RMSD), based on the estimated rotation matrix it then estimates the translation vector hence solving
    the Procrustes problem. This implementation supports batch inputs.

    The caller's ``weights`` tensor is never modified.

    Args:
        x1            (torch array): points of the first point cloud [b,n,3]
        x2            (torch array): correspondences for the PC1 established in the feature space [b,n,3]
        weights       (torch array): weights denoting if the correspondence is an inlier (~1) or an outlier (~0) [b,n]
        normalize_w   (bool)       : flag for normalizing the weights to sum to 1
        eps           (float)      : small constant added to weight sums to avoid division by zero
        best_k        (int)        : number of correspondences with highest weights to be used (if 0 all are used).
                                     The selection is made from the weights of the FIRST batch element and applied to the whole batch.
        w_threshold   (float)      : only use weights higher than this w_threshold (if 0 all are used)
    Returns:
        rot_matrices  (torch array): estimated rotation matrices [b,3,3]
        trans_vectors (torch array): estimated translation vectors [b,3,1]
        res           (torch array): pointwise residuals (Euclidean distance) [b,n]
        gradient_not_valid (bool): True when the SVD raised and the identity transform is returned as a fallback
                                   (the gradient is NOT valid); False when the SVD succeeded.

    """
    if weights is None:
        weights = torch.ones(x1.shape[0],
                             x1.shape[1]).type_as(x1).to(x1.device)

    if normalize_w:
        sum_weights = torch.sum(weights, dim=1, keepdim=True) + eps
        weights = (weights / sum_weights)

    weights = weights.unsqueeze(2)

    if best_k > 0:
        # detach(): .numpy() raises on tensors that require grad; the indices
        # themselves carry no gradient.
        indices = np.argpartition(weights.detach().cpu().numpy(), -best_k,
                                  axis=1)[0, -best_k:, 0]
        weights = weights[:, indices, :]
        x1 = x1[:, indices, :]
        x2 = x2[:, indices, :]

    if w_threshold > 0:
        # Out-of-place: `weights` may still be a view of the caller's tensor
        # (normalize_w=False, best_k=0), which must not be modified.
        weights = weights.masked_fill(weights < w_threshold, 0)

    weight_sum = torch.sum(weights, dim=1).unsqueeze(1) + eps
    x1_mean = torch.matmul(weights.transpose(1, 2), x1) / weight_sum
    x2_mean = torch.matmul(weights.transpose(1, 2), x2) / weight_sum

    x1_centered = x1 - x1_mean
    x2_centered = x2 - x2_mean

    # diag(w) @ x2_centered is a row-wise scaling; broadcasting avoids
    # materializing the dense [b,n,n] diagonal matrix.
    cov_mat = torch.matmul(x1_centered.transpose(1, 2),
                           weights * x2_centered)

    try:
        u, s, v = torch.svd(cov_mat)
    except Exception:
        # Deliberate best-effort: if the SVD fails, fall back to the identity
        # transform and tell the caller the gradient is not valid.
        r = torch.eye(3, device=x1.device)
        r = r.repeat(x1_mean.shape[0], 1, 1)
        t = torch.zeros((x1_mean.shape[0], 3, 1), device=x1.device)

        res = transformation_residuals(x1, x2, r, t)

        return r, t, res, True

    # Determinant of v @ u^T, placed in the last diagonal slot below so the
    # product is a proper rotation (det = +1) rather than a reflection.
    tm_determinant = torch.det(
        torch.matmul(v.transpose(1, 2), u.transpose(1, 2)))

    determinant_matrix = torch.diag_embed(
        torch.cat((torch.ones((tm_determinant.shape[0], 2),
                              device=x1.device), tm_determinant.unsqueeze(1)),
                  1))

    rotation_matrix = torch.matmul(
        v, torch.matmul(determinant_matrix, u.transpose(1, 2)))

    # translation vector
    translation_matrix = x2_mean.transpose(1, 2) - torch.matmul(
        rotation_matrix, x1_mean.transpose(1, 2))

    # Residuals
    res = transformation_residuals(x1, x2, rotation_matrix, translation_matrix)

    return rotation_matrix, translation_matrix, res, False
예제 #24
0
 def _create_marginal_input(self, batch_shape=torch.Size()):
     mat = torch.randn(*batch_shape, 5, 5)
     eye = torch.diag_embed(torch.ones(*batch_shape, 5))
     return MultivariateNormal(torch.randn(*batch_shape, 5), mat @ mat.transpose(-1, -2) + eye)
예제 #25
0
    def forward(self, images, imu_data, prev_lstm_states, prev_pose,
                prev_state, prev_covar, T_imu_cam):
        """Run the visual-odometry network and the EKF over an image sequence.

        For each consecutive image pair the EKF predicts from the IMU data,
        the VO module produces a visual measurement (plus a covariance), and
        the EKF is updated and composed with the previous pose.

        Returns a tuple of
        ``(vis_meas, vis_meas_covar, lstm_states, poses, states, covars)``
        where every entry except ``lstm_states`` is stacked along dim 1. The
        pose/state/covar stacks include the initial ``prev_*`` values.
        """
        # The first three measurement components are scaled by par.k4.
        covar_scale = torch.ones(6, device=images.device)
        covar_scale[0:3] = covar_scale[0:3] * par.k4
        imu_noise_covar = self.get_imu_noise_covar()

        if prev_covar is None:
            # Default initial covariance: diag(sqrt**2 + eps), one copy per batch item.
            init_diag = (self.init_covar_diag_sqrt * self.init_covar_diag_sqrt +
                         par.init_covar_diag_eps)
            prev_covar = torch.diag(init_diag).repeat(images.shape[0], 1, 1)

        encoded_images = self.vo_module.encode_image(images)

        pose_history = [prev_pose]
        state_history = [prev_state]
        covar_history = [prev_covar]
        measurements = []
        measurement_covars = []
        lstm_states = prev_lstm_states

        # One step per consecutive image pair (equals imu_data.size(1) - 1).
        for step in range(images.size(1) - 1):
            # ekf predict
            pred_states, pred_covars = self.ekf_module.predict(
                imu_data[:, step], imu_noise_covar, state_history[-1],
                covar_history[-1])

            feature_vector = encoded_images[:, step]
            if par.hybrid_recurrency and par.enable_ekf:
                # Prepend the predicted state and flattened covariance to the
                # image encoding before it goes into the LSTM.
                state_so3 = IMUKalmanFilter.state_to_so3(pred_states[-1])
                covar_flat = pred_covars[-1].view(
                    -1, IMUKalmanFilter.STATE_VECTOR_DIM**2)
                feature_vector = torch.cat(
                    [state_so3, covar_flat, feature_vector], -1)

            # get vis measurement
            vo_output, lstm_states = self.vo_module.forward_one_ts(
                feature_vector, lstm_states)
            vis_meas = vo_output[:, 0:6]

            # process vis meas covar
            if not par.vis_meas_covar_use_fixed:
                covar_diag = par.vis_meas_covar_init_guess * \
                    10 ** (par.vis_meas_covar_beta *
                           torch.tanh(par.vis_meas_covar_gamma * vo_output[:, 6:12]))
            else:
                fixed_diag = torch.tensor(par.vis_meas_fixed_covar,
                                          dtype=torch.float32,
                                          device=vis_meas.device)
                fixed_diag = fixed_diag * covar_scale
                # NOTE(review): tiles the fixed diagonal shape[0] x shape[1]
                # times -- confirm this matches the shape the EKF expects.
                covar_diag = fixed_diag.repeat(
                    vis_meas.shape[0], vis_meas.shape[1], 1)

            # The EKF consumes the covariance with the scale divided out; the
            # returned covariance keeps it.
            covar_for_ekf = torch.diag_embed(covar_diag / covar_scale.view(1, 6))
            vis_meas_covar = torch.diag_embed(covar_diag)

            # ekf correct
            est_state, est_covar = self.ekf_module.update(
                pred_states[-1], pred_covars[-1], vis_meas.unsqueeze(-1),
                covar_for_ekf, T_imu_cam)
            new_pose, new_state, new_covar = self.ekf_module.composition(
                pose_history[-1], est_state, est_covar)

            pose_history.append(new_pose)
            state_history.append(new_state)
            covar_history.append(new_covar)
            measurements.append(vis_meas)
            measurement_covars.append(vis_meas_covar)

        stacked_meas = torch.stack(measurements, 1)
        stacked_meas_covar = torch.stack(measurement_covars, 1)
        stacked_poses = torch.stack(pose_history, 1)
        stacked_states = torch.stack(state_history, 1)
        stacked_covars = torch.stack(covar_history, 1)
        return (stacked_meas, stacked_meas_covar, lstm_states,
                stacked_poses, stacked_states, stacked_covars)
예제 #26
0
 def get_distribution(self, mean, std):
     """Return the action distribution for the given ``mean`` and ``std``.

     A ``Normal`` when ``self.action_dim == 1``; otherwise a
     ``MultivariateNormal`` whose second argument is ``diag(std)``.
     """
     if self.action_dim != 1:
         # NOTE(review): `std` is embedded as-is on the diagonal -- confirm
         # whether the second argument is meant to be std or variance.
         return MultivariateNormal(mean, torch.diag_embed(std))
     return Normal(mean, std)
 def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
     """Return ``0.4 * diag(X)``; ``t`` and ``Y`` are not used."""
     diag_of_state = torch.diag_embed(X)  # M x D x D
     return 0.4 * diag_of_state
예제 #28
0
def LBO(V, F):
    """Assemble ``L = A^{-1} W`` from the outputs of ``LBO_slim``.

    Returns ``(L, area_matrix, area_matrix_inv, W)`` where the two area
    matrices are batched diagonal matrices built from ``V_area`` and
    ``1 / V_area``.
    """
    W, V_area = LBO_slim(V, F)
    mass = torch.diag_embed(V_area)
    mass_inv = torch.diag_embed(1 / V_area)
    laplacian = torch.bmm(mass_inv, W)  # VALIDATED
    return laplacian, mass, mass_inv, W
예제 #29
0
def rho_precision(logs, rho, z_dim):
    """Build a batch of tridiagonal ``z_dim x z_dim`` precision matrices.

    Off-diagonals are ``-rho``; the main diagonal is ``1`` at both ends and
    ``1 + rho**2`` in between; everything is scaled per batch element by
    ``exp(-logs) / (1 - rho**2)``.
    """
    off_diag = -rho.unsqueeze(1).expand(-1, z_dim - 1)
    lower = torch.diag_embed(off_diag, offset=-1)
    upper = torch.diag_embed(off_diag, offset=1)

    # Interior diagonal entries, padded with 1 at both ends.
    interior = (1 + rho * rho).unsqueeze(1).expand(-1, z_dim - 2)
    main = torch.diag_embed(F.pad(interior, (1, 1), value=1.))

    scale = torch.exp(-logs) / (1 - rho * rho)
    return (lower + upper + main) * scale[:, None, None]
예제 #30
0
def vec_to_diag(vec):
    """Place the last dimension of ``vec`` on the main diagonal of a square matrix."""
    diagonal_matrix = _torch.diag_embed(vec, offset=0)
    return diagonal_matrix