Example No. 1
def hottaBasisVectors(X,sub_dim):
    d,n = X.shape
    if d < n:
        C = torch.mm(X, X.transpose(0,1))
        matrank = torch.matrix_rank(C)
        tmp_val, tmp_vec = torch.eig(C, eigenvectors = True)
        # keep only the real parts (the second column returned by torch.eig is zero here)
        tmp_val = tmp_val[:, 0]
        value, index = torch.sort(tmp_val,descending = True)
        eig_vec = tmp_vec[:,index[:matrank]]
        eig_val = value[:matrank]
        eig_vec  =  eig_vec[:,:sub_dim]
        eig_val  =  value[:sub_dim]
    else:
        C = torch.mm(X.transpose(0,1), X)
        matrank = torch.matrix_rank(C)
        tmp_val, tmp_vec = torch.eig(C, eigenvectors = True)
        
        # second column is zero if the eig vals are real
        tmp_val = tmp_val[:,0]

        value, index = torch.sort(tmp_val,descending = True)
#        tmp_vec = tmp_vec[:,index[:matrank]]
#        eig_vec = torch.mm(X, tmp_vec)
        eig_vec = torch.zeros((X.shape[0],matrank))
        for i in range(matrank):
            eig_vec[:,i] = (X.mv(tmp_vec[:,index[i]])).div((value[i]).sqrt())
        eig_vec  =  normalize_dataset(eig_vec[:,:sub_dim])
        eig_val  =  value[:sub_dim]

    return eig_vec, eig_val
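A hypothetical call illustrating the d < n branch above (not part of the original snippet); it assumes an older PyTorch where torch.eig is still available:

X = torch.randn(64, 200)              # 64-dim features, 200 samples, so d < n
basis, eig_val = hottaBasisVectors(X, sub_dim=10)
print(basis.shape, eig_val.shape)     # torch.Size([64, 10]) torch.Size([10])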
Example No. 2
 def rank(self, x, conv):
     if conv == "conv1":
         for f in range(x[0].shape[0]):
             self.rank1[f] += torch.matrix_rank(x[0][f])
     elif conv == "conv2":
         for f in range(x[0].shape[0]):
             self.rank2[f] += torch.matrix_rank(x[0][f])
Example No. 3
def main():

    # Projection
    n_data = 100
    circle_input, _ = get_circle_data(n_data)
    plane_input, _ = get_plane_data(n_data)

    projection_input = torch.hstack([circle_input, plane_input])

    discr, gener = train_gan(projection_input)

    func = gener
    inputs = gener.generate_generator_input(1)
    jac = get_jacobian(func=func, inputs=inputs).squeeze()
    print("Generator output ID is ", torch.matrix_rank(jac))

    proj_func = ProjectionFunction()
    func = lambda x: proj_func(gener(x))
    jac = get_jacobian(func=func, inputs=inputs).squeeze()
    print("ANN output ID is ", torch.matrix_rank(jac))

    evaluate(
        estimate_id=twonn_dimension,
        get_data=lambda n: get_parabolic_data(n)[1],
    )

    plt.show()
Example No. 4
def get_maximal_linearly_independent_system(x, max_tries=1000):
    assert x.shape[0] >= x.shape[1]
    rank = torch.matrix_rank(x)
    try_count = 0
    print("getting the maximal linearly independent system. col {}, rank {}".format(x.shape[1], rank))

    if rank == x.shape[1]:
        return list(range(x.shape[1]))

    for i in combinations(range(x.shape[1]),rank):
        if torch.matrix_rank(x[:,i]) == rank:
            return i
        try_count += 1
        if try_count >= max_tries:
            return -1
    raise Exception("maximal linearly independent system not found")
Example No. 5
 def __init__fixme(self, C_t):
     '''
     for numerical stability, use singular values to determine the column rank.
     discard left singular vectors with
              singular value < ratio_cutoff * max singular value
     '''
     # center and orthogonalize
     C0 = C_t - C_t.mean(0)
     mat_rank = torch.matrix_rank(C0)
     print(mat_rank)
     ss, vv, dd = torch.svd(C0)
     # self.Q_t, R_ = torch.qr(C_t - C_t.mean(0))
     # here we take care of C_t with non full column rank
     # we follow the same behavior as R::lm, where correlated covariates
     # will be discarded (they won't contribute to the computation or dof)
     # here we check and take care of this situation
     # specifically, we set Q columns with no contribution to zeros
     # and they won't be considered in dof calculation
     # get a binary vector indicating if Q_t columns are used.
     # Q_t_usage = (R_ != 0).sum(axis=0) != 0
     self.Q_t = ss
     # Q_t_usage =  (vv.abs() / vv.max()) >= ratio_cutoff
     self.Q_t[:, mat_rank:] = 0
     self.Q_t = self.Q_t.to(torch.float32)
     # self.Q_t[:, torch.logical_not(Q_t_usage)] = 0
     self.dof = C_t.shape[0] - 2 - mat_rank
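A minimal sketch of the singular-value cutoff described in the docstring above; column_rank_by_cutoff and ratio_cutoff are hypothetical names (the latter taken from the commented-out lines), and torch is assumed to be imported as in the other examples:

def column_rank_by_cutoff(C0, ratio_cutoff=1e-7):
    # count singular values no smaller than ratio_cutoff * (max singular value);
    # this is the numerically stable column-rank estimate the docstring refers to
    _, sv, _ = torch.svd(C0)
    return int((sv >= ratio_cutoff * sv.max()).sum())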
Example No. 6
 def get_feature_hook(self, module, input, output):
     a = output.size(0)
     b = output.size(1)
     # w = output.size(2)
     # h = output.size(3)
     # u, s, v = torch.svd(output.view(-1, w, h), compute_uv=False)  # s: [batch*channel, singular_values]
     # print(s[0])
     # s = torch.abs(a)
     # if self.s_threshold:
     #     s[torch.abs(s) < self.s_threshold] = 0
     # else:
     #     self.s_threshold = (10**-7) * max(w, h) * s[:,0]
     #     for i in range(int(s.size(0))):
     #         s[i][s[i] < self.s_threshold[i]] = 0
     # s = (torch.abs(s) > 0).view(a, b, s.size(-1))   # [batch, channel, singular_values]
     # s = s.sum(1).squeeze().float() / a   # [channel]
     c = torch.tensor([
         torch.matrix_rank(output[i, j, :, :]) for i in range(a)
         for j in range(b)
     ]).to(self.device)
     c = c.view(a, -1).float()
     c = c.sum(0)
     self.feature_result = self.feature_result * self.total + c
     self.total = self.total + a
     self.feature_result = self.feature_result / self.total
Example No. 7
def matrix_rank(x):
    if _TORCH_LESS_THAN_ONE:
        import numpy as np

        return torch.from_numpy(
            np.asarray(np.linalg.matrix_rank(x.detach().cpu().numpy()))
        ).type_as(x)
    with torch.no_grad():
        batches = x.shape[:-2]
        if batches:
            out = x.new(*batches)
            for idx in itertools.product(*map(range, batches)):
                out[idx] = torch.matrix_rank(x[idx])
            return out
        else:
            return torch.matrix_rank(x)
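A hypothetical usage sketch for the wrapper above (assuming itertools is imported and _TORCH_LESS_THAN_ONE is defined at module level, as in the source):

x = torch.randn(4, 3, 6, 5)          # a 4 x 3 batch of 6 x 5 matrices
ranks = matrix_rank(x)               # per-matrix ranks, returned with shape (4, 3)
single = matrix_rank(torch.eye(5))   # 2-D input falls through to torch.matrix_rank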
Example No. 8
    def lstsq(self, A, Y, lamb=0.0):
        """
        Differentiable least square
        :param A: m x n
        :param Y: m x 1
        """
        cols = A.shape[1]
        if np.isinf(A.data.cpu().numpy()).any():
            import ipdb;
            ipdb.set_trace()

        # Assuming A to be full column rank
        if cols == torch.matrix_rank(A):
            # Full column rank
            q, r = torch.qr(A)
            x = torch.inverse(r) @ q.transpose(1, 0) @ Y
        else:
            # rank(A) < n, do regularized least square.
            AtA = A.transpose(1, 0) @ A

            # get the smallest lambda that suits our purpose, so that the
            # error in the results is minimized.
            with torch.no_grad():
                lamb = best_lambda(AtA)
            A_dash = AtA + lamb * torch.eye(cols, device=A.get_device())
            Y_dash = A.transpose(1, 0) @ Y

            # if it still doesn't work, just set the lamb to be very high value.
            x = self.lstsq(A_dash, Y_dash, 1)
        return x
Example No. 9
    def _prox(self, metric_mat, _x_train, eta):
        # TODO: move this function to utilities
        eig_val, eig_vec = torch.symeig(metric_mat, True)
        med = eig_val.median() * 0.8
        if med < eta:
            eig_val = torch.relu(eig_val - med)
            eta *= med
        else:
            eig_val = torch.relu(eig_val - eta)

        space_dim = (eig_val > 1e-8).sum()
        s_dim = eig_val.shape[0] - space_dim
        new_metric = (
            eig_vec[:, s_dim:] @ eig_val[s_dim:].diag() @ eig_vec[:, s_dim:].T)
        space_dim = torch.matrix_rank(new_metric)

        if space_dim != new_metric.shape[0]:
            _x_data = lp_normalize(self.reddim_mat @ torch.cat(_x_train, 1),
                                   self.p_norm)
            autocorr_mat = new_metric @ _x_data @ _x_data.T
            d, sing_vec = autocorr_mat.cpu().eig(True)
            d = d.to(autocorr_mat)
            sing_vec = sing_vec.to(autocorr_mat)
            sing_vec = sing_vec * d[:, 0]
            sing_vec, _, _ = sing_vec.svd()
            proj_basis = sing_vec[:, :space_dim].T

            for _ in range(10):
                _tmp_new_metric = proj_basis @ new_metric @ proj_basis.T
                _rank = torch.matrix_rank(_tmp_new_metric)
                if _rank < space_dim:
                    space_dim = _rank
                    proj_basis = sing_vec[:, :space_dim].T
                    _tmp_new_metric = proj_basis @ new_metric @ proj_basis.T
                else:
                    break
            new_metric = _tmp_new_metric
            _reddim_mat = proj_basis @ self.reddim_mat

            _red_subs = proj_basis @ self.sub_basis.permute((2, 0, 1))
            sub_basis, _ = torch.qr(_red_subs)
            sub_basis = sub_basis.to(_red_subs.device).permute((1, 2, 0))

        else:
            _reddim_mat = self.reddim_mat
            sub_basis = self.sub_basis
        return new_metric, _reddim_mat, sub_basis, eta
Example No. 10
def lstq(A, Y, lamb=0.01):
    """
        Differentiable least square
        :param A: m x n
        :param Y: m x 1
        """
    # Assuming A to be full column rank
    cols = A.shape[1]
    print(torch.matrix_rank(A))
    if cols == torch.matrix_rank(A):
        q, r = torch.qr(A)
        x = torch.inverse(r) @ q.T @ Y
    else:
        A_dash = A.permute(1, 0) @ A + lamb * torch.eye(cols)
        Y_dash = A.permute(1, 0) @ Y
        x = lstq(A_dash, Y_dash)
    return x
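A hypothetical usage sketch (not from the source repository) showing the rank-deficient fallback path of lstq:

A = torch.randn(50, 10)
A[:, 0] = A[:, 1]                # force a column dependency so that rank(A) < n
Y = torch.randn(50, 1)
x = lstq(A, Y, lamb=0.01)        # falls back to the regularized normal equations
print(x.shape)                   # torch.Size([10, 1])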
Example No. 11
def get_feature_hook_googlenet(self, input, output):
    global feature_result
    global total
    a = output.shape[0]
    b = output.shape[1]
    c = torch.tensor([torch.matrix_rank(output[i,j,:,:]).item() for i in range(a) for j in range(b-12,b)])

    c = c.view(a, -1).float()
    c = c.sum(0)
    feature_result = feature_result * total + c
    total = total + a
    feature_result = feature_result / total
Example No. 12
def similarity_estimate(src, dst, estimate_scale=True):
    if not isinstance(src, torch.Tensor):
        src = torch.from_numpy(src).float()
    if not isinstance(dst, torch.Tensor):
        dst = torch.from_numpy(dst).float()

    src = src.cpu()
    dst = dst.cpu()

    num, dim = src.shape

    # compute mean of src and dst
    src_mean = torch.mean(src, dim=0)
    dst_mean = torch.mean(dst, dim=0)

    # Subtract mean from src and dst
    src_demean = src - src_mean
    dst_demean = dst - dst_mean

    A = torch.matmul(dst_demean.t(), src_demean) / num

    d = torch.ones(dim).float()
    if torch.linalg.det(A) < 0:
        d[dim - 1] = -1

    T = torch.eye(dim+1).float()

    U, S, V = torch.svd(A)

    rank = torch.matrix_rank(A)
    if rank == 0:
        return None
    elif rank == dim - 1:
        if torch.linalg.det(U) * torch.linalg.det(V) > 0.:
            # torch.svd returns V (not V^T), so transpose it when composing the rotation
            T[:dim, :dim] = torch.matmul(U, V.t())
        else:
            s = d[dim - 1]
            d[dim - 1] = -1
            T[:dim, :dim] = torch.matmul(torch.matmul(U, torch.diag(d).float()), V.t())
            d[dim - 1] = s
    else:
        T[:dim, :dim] = torch.matmul(torch.matmul(U, torch.diag(d).float()), V.t())

    if estimate_scale:
        scale = 1.0 / torch.var(src_demean, dim=0, unbiased=False).sum() * torch.matmul(S, d)
    else:
        scale = 1.0

    T[:dim, dim] = dst_mean - scale * (torch.matmul(T[:dim, :dim], src_mean.t()))
    T[:dim, :dim] *= scale

    return T
Example No. 13
    def __init__(self, mean, cov):
        self.mean = torch.Tensor(mean).cuda()

        cov = torch.Tensor(cov).cuda()

        # For estimating degrees of freedom
        self.rank = torch.matrix_rank(cov)

        self.cov = cov + torch.eye(cov.shape[0]).cuda() * epsilon

        self.inv = torch.inverse(self.cov)

        self.distribution = MultivariateNormal(self.mean, self.cov)
Example No. 14
def matrix_rank(x: torch.Tensor):  # pragma: no cover
    # inspired by
    # https://discuss.pytorch.org/t/multidimensional-svd/4366/2
    # extended here:
    if x.dim() == 2:
        result = torch.matrix_rank(x)
    else:
        batches = x.shape[:-2]
        other = x.shape[-2:]
        flat = x.view((-1, ) + other)
        slices = flat.unbind(0)
        ranks = []
        # I wish I had a parallel_for
        for i in range(flat.shape[0]):
            r = torch.matrix_rank(slices[i])
            # interesting,
            # ranks.append(r)
            # does not work on pytorch 1.0.0
            # but the below code does
            ranks += [r]
        result = torch.stack(ranks).view(batches)
    return result
Example No. 15
 def fit(self, Z: torch.tensor, Y: torch.tensor, eps: float = 0.01) -> None:
     m, p = Z.shape
     k = len(Y.unique())
     self.m = m
     self.k = k
     self.subspaces = []
     for j in range(k):
         Zj = Z[Y == j]
         A = Zj.T.matmul(Zj)
         rank = torch.matrix_rank(A, symmetric=True)
         eigval, eigvec = torch.symeig(A, eigenvectors=True)
         subspace = eigvec[:,
                           -rank:]  # sorted in ascending order of eigenvalues
         self.subspaces.append(subspace)
Example No. 16
def test_LSE_when_not_full_rank():
    m, n = 100, 30
    A = torch.randn(m, n)
    B = torch.randn(m, 1)

    A[:, 3] = A[:, 4] + A[:, 5]
    A[:, 0] = A[:, 7] + A[:, 10] + A[:, 21]
    print(torch.matrix_rank(A))
    w, mlis = get_w_by_LSE(A, B)
    print(w)
    print(mlis)
    diff = B - torch.mm(A[:, mlis], w)
    print(diff**2)
    print(torch.max(diff**2))
Example No. 17
 def lstsq(a, b):
     if linalg_lstsq_avail:
         x, residuals, _, _ = torch.linalg.lstsq(a,
                                                 b,
                                                 rcond=None,
                                                 driver='gelsd')
         return x, residuals
     else:
         n = a.shape[1]
         sol = torch.lstsq(b, a)[0]
         x = sol[:n]
         residuals = torch.norm(sol[n:], dim=0)**2
         return x, residuals if torch.matrix_rank(a) == n else torch.tensor(
             [], device=x.device)
Example No. 18
    def _fit(self, x_train, y_train):
        cls_data = [
            torch.cat(
                [x_train[j]
                 for j in range(len(y_train)) if y_train[j] == i], 1)
            for i in set(y_train)
        ]
        sub_basis = (pca_for_sets(cls_data, self.n_sdim,
                                  self.p_norm).contiguous().permute((2, 0, 1)))

        gram_mat = (sub_basis @ sub_basis.permute((0, 2, 1))).sum(0)
        full_dim = torch.matrix_rank(gram_mat)
        _, eig_vec = torch.symeig(gram_mat, eigenvectors=True)
        eig_vec = eig_vec.flip(1)[:, self.n_reducedim:full_dim]
        self.metric_mat = eig_vec @ eig_vec.T
        self.dic = ortha_subs(self.sub_basis, self.metric_mat)
Example No. 19
    def generate_matrix(self, dim_embedding=None):
        if dim_embedding is None:
            dim_embedding = np.random.randint(self.min_dim_embedding,
                                              self.max_dim_embedding + 1)

        while True:
            # matrix = torch.rand(dim_embedding, self.dim_feature) - 0.5
            matrix = torch.normal(0, 1, size=(dim_embedding, self.dim_feature))

            matrix2 = torch.mm(torch.t(matrix), matrix)

            if torch.matrix_rank(matrix2) == torch.tensor(dim_embedding):
                # return matrix / matrix2.det()
                return matrix
            else:
                pass
Example No. 20
def best_lambda(A):
    """
    Takes an underdetermined system and a small lambda value, and comes up
    with a lambda that makes the matrix A + lambda * I invertible. Assumes
    A is a square matrix.
    """
    lamb = 1e-6
    cols = A.shape[0]

    for i in range(7):
        A_dash = A + lamb * torch.eye(cols, device=A.get_device())
        if cols == torch.matrix_rank(A_dash):
            # we achieved the required rank
            break
        else:
            # factor by which to increase the lambda. Choosing 10 for performance.
            lamb *= 10
    return lamb
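A hypothetical usage sketch for best_lambda; it assumes a CUDA tensor, since A.get_device() is only valid for GPU tensors:

B = torch.randn(3, 5, device='cuda')
AtA = B.t() @ B                            # 5 x 5 but only rank 3
lamb = best_lambda(AtA)                    # smallest tried lambda making AtA + lamb * I full rank
A_reg = AtA + lamb * torch.eye(5, device='cuda')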
Example No. 21
    def get_orthogonal_out(self, x):
        m_sqrt = np.sqrt(x.shape[0])
        y = self.forward(x)

        rk = torch.matrix_rank(y).item()
        if rk < self.db['k']:
            import pdb
            pdb.set_trace()
        #	y[:, (self.db['k'] - rk):] = y[:, (self.db['k'] - rk):] + np.ran
        #print(rk)
        #import pdb; pdb.set_trace()

        YY = torch.mm(torch.t(y), y)
        L = torch.cholesky(YY)

        Li = m_sqrt * torch.t(torch.inverse(L))
        Yout = torch.mm(y, Li)
        return Yout
Example No. 22
def lstq(Y, A, lamb=0.0):
    """
    Differentiable least square
    :param A: m x n
    :param Y: m x 1
    """
    # Assuming A to be full column rank
    cols = A.shape[1]
    if cols == torch.matrix_rank(A):
        q, r = torch.qr(A)
        x = torch.inverse(r) @ q.T @ Y
    else:
        A_dash = A.permute(1, 0) @ A + lamb * torch.eye(cols)
        Y_dash = A.permute(1, 0) @ Y
        #if Y_dash.dim() == 1:
        #  Y_dash = Y_dash.view(-1, 1)
        x = lstq(Y_dash, A_dash)
    return x
Example No. 23
def nys_svrg(function,
             w,
             iteration=100,
             convergence=0.0001,
             learning_rate=torch.tensor(0.001, device='cuda'),
             lam=torch.tensor(0.01, device='cuda'),
             rho=torch.tensor(1, device='cuda')):
    initial = w.clone()
    hist = [None] * iteration
    C = torch.zeros(k, d).cuda()
    for i in range(iteration):
        previous_data = initial
        value = function(initial, lam)
        hist[i] = value
        idx = torch.randint(0, d, (k, )).cuda()
        grad = torch.autograd.grad(value, initial, create_graph=True)[0]
        for j in range(k):
            C[j] = torch.autograd.grad(grad[idx[j]],
                                       initial,
                                       create_graph=True)[0].t()
        W = C[:, idx]
        u, s, v = torch.svd(W, some=True)
        #u = u.cuda()
        s = torch.diag(s)
        #print(s)
        r = torch.matrix_rank(s)
        s = s[:r, :r]
        u = u[:, :r]
        s = torch.sqrt(torch.inverse(s))
        Z = torch.matmul(C.t(), torch.matmul(u, s))
        #print(Z.device)
        Q = 1 / (rho * rho) * torch.matmul(
            Z,
            torch.inverse(
                torch.eye(r, device='cuda') + torch.matmul(Z.t(), Z) / rho))
        #print(Q.device)
        initial -= learning_rate * (grad / rho -
                                    torch.matmul(Q, torch.matmul(Z.t(), grad)))
        print("epoch {}, obtain {}".format(i, value))
        if value < torch.tensor(convergence):
            print("break")
            break
    return hist
Example No. 24
def get_w_by_LSE(x, y):
    print("getting w by LSE")
    mlis = get_maximal_linearly_independent_system(x)
    if mlis == -1:
        return torch.ones(x.shape[1],1).cuda(), list(range(x.shape[1]))
    x = x[:, mlis]
    assert torch.matrix_rank(x) == x.shape[1]
    '''
    raw implementation
    '''
    # a = torch.matmul(torch.transpose(x,0,1),x)
    # a_inv = torch.inverse(a)
    # w = torch.chain_matmul(a_inv, torch.transpose(x, 0,1), y)

    '''
    now use torch.lstsq() to substitute the implementation above,
    because according to the test the results are almost the same, with only negligible differences.
    '''
    w = torch.lstsq(y,x)[0][:x.shape[1]]
    return w, mlis
Example No. 25
def gen_with_svd(mat_size, dtype):
    x = torch.randint(-5, 5, (mat_size, mat_size)).to(torch.cdouble)
    if dtype.is_complex:
        x += 1j * torch.randint(-5, 5, (mat_size, mat_size)).to(torch.cdouble)

    # Need at least one linearly dependent pair of rows
    x[1].copy_(x[0])

    k = torch.matrix_rank(x)
    assert k < mat_size

    u, _, v = x.svd()

    s = torch.randint(1, 10, (mat_size,)).to(torch.cdouble)
    if dtype.is_complex:
        s += 1j * torch.randint(1, 10, (mat_size,)).to(torch.cdouble)
    s[-k:] = 0

    matrix = (u * s.unsqueeze(-2)) @ v.transpose(-1, -2).conj()

    return matrix
Example No. 26
def get_svd_ranks_acts(h0, delta_h_1, h1, x1, L, tol=None):
    columns = ['layer', 'h0', 'delta_h_1', 'h1', 'x1', 'max']
    df = pd.DataFrame(columns=columns, index=range(1, L + 2))
    df.index.name = 'layer'

    with torch.no_grad():
        for l in df.index:
            if tol is None:
                df.loc[l, columns] = [l,
                                      torch.matrix_rank(h0[l]).item(),
                                      torch.matrix_rank(delta_h_1[l]).item(),
                                      torch.matrix_rank(h1[l]).item(),
                                      torch.matrix_rank(x1[l]).item(),
                                      min(h0[l].shape[0], h0[l].shape[1])]
            else:
                df.loc[l, columns] = [l,
                                      torch.matrix_rank(h0[l], tol=tol).item(),
                                      torch.matrix_rank(delta_h_1[l], tol=tol).item(),
                                      torch.matrix_rank(h1[l], tol=tol).item(),
                                      torch.matrix_rank(x1[l], tol=tol).item(),
                                      min(h0[l].shape[0], h0[l].shape[1])]

    return df
Example No. 27
    def command(self, state, choose_best=False):
        if not torch.is_tensor(state):
            state = torch.tensor(state)
        state = state.to(dtype=self.dtype, device=self.d)

        self.reset()

        for m in range(self.M):
            top_samples = self._sample_top_trajectories(state, self.num_elite)
            # fit the gaussian to those samples
            self.mean = torch.mean(top_samples, dim=0)
            self.cov = pytorch_cov(top_samples, rowvar=False)
            if torch.matrix_rank(self.cov) < self.cov.shape[0]:
                self.cov += self.cov_reg

        if choose_best and self.choose_best:
            top_sample = self._sample_top_trajectories(state, 1)
        else:
            top_sample = self.action_distribution.sample((1, ))

        # only apply the first action from this trajectory
        u = top_sample[0, self._slice_control(0)]

        return u
Example No. 28
def get_svd_ranks_weights(W0, Delta_W_1, Delta_W_2, L, tol=None):
    columns = ['layer', 'W0', 'Delta_W_1', 'Delta_W_2', 'max']
    df = pd.DataFrame(columns=columns, index=range(1, L + 2))
    df.index.name = 'layer'

    with torch.no_grad():
        for l in df.index:
            if tol is None:
                df.loc[l, columns] = [l,
                                      torch.matrix_rank(W0[l]).item(),
                                      torch.matrix_rank(Delta_W_1[l]).item(),
                                      torch.matrix_rank(Delta_W_2[l]).item(),
                                      min(W0[l].shape[0], W0[l].shape[1])]
            else:
                df.loc[l, columns] = [l,
                                      torch.matrix_rank(W0[l], tol=tol).item(),
                                      torch.matrix_rank(Delta_W_1[l], tol=tol).item(),
                                      torch.matrix_rank(Delta_W_2[l], tol=tol).item(),
                                      min(W0[l].shape[0], W0[l].shape[1])]

    return df
Example No. 29
def rank_torch(A, args=None):
    A_tensor = torch.from_numpy(A).cuda()
    return torch.matrix_rank(A_tensor).cpu().item()
Example No. 30
SEED = np.random.randint(1000)
print(f"Running with seed {SEED}")
np.random.seed(SEED)
torch.manual_seed(SEED)

from IPython import display
get_ipython().run_line_magic('pdb', 'off')
get_ipython().run_line_magic('matplotlib', 'inline')


## Define functions


get_jacobian = torch.autograd.functional.jacobian

get_rank = lambda x: int(torch.matrix_rank(x.squeeze()))

get_twonn = lambda x: twonn_dimension(x.detach().numpy().squeeze())

## Train GAN

n_data = 1000
circle_input, _ = get_circle_data(n_data)
plane_input, _ = get_plane_data(n_data)

projection_input = torch.hstack([circle_input, plane_input])

Activation = torch.nn.ReLU
Optimizer = torch.optim.Adam
generator_kwargs = dict(activation_function=Activation)
discr, gener = train_gan(