def mvpca(Xs, cross_covariance=True):
    """Multiview PCA: returns ``(SI, cov)``, an identity matrix over the joint
    dimension and the multiview covariance of the centered views (all blocks,
    or only the within-view blocks when ``cross_covariance=False``)."""
    if cross_covariance and len(set(X.size(1) for X in Xs)) > 1:
        raise ValueError(
            f"MvPCA with cross covariance only works with multiview data with "
            f"same dimensions, got dimensions of {tuple(X.size(1) for X in Xs)}")
    if _has_ops():
        return torchsl.ops.mvpca(Xs, cross_covariance)
    options = dict(dtype=Xs[0].dtype, device=Xs[0].device)
    num_components = Xs[0].size(0)
    dims = dimensions(Xs)
    # Center each view before building the covariance blocks.
    Xs_centered = [X - X.mean(0) for X in Xs]
    if cross_covariance:
        cov = multiview_covariance_matrix(
            dims, lambda j, r: Xs_centered[j].t().mm(Xs_centered[r]), options)
    else:
        # Without cross covariance only the within-view (diagonal) blocks are kept.
        cov = multiview_covariance_matrix(
            dims,
            lambda j, r: Xs_centered[j].t().mm(Xs_centered[r]) if j == r else 0,
            options)
    cov /= num_components - 1
    SI = torch.eye(sum(dims), **options)
    return SI, cov

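# Usage sketch (hypothetical data; cross covariance requires equal view
# dimensions, as enforced above):
#
#     Xs = [torch.randn(100, 20), torch.randn(100, 20)]
#     SI, cov = mvpca(Xs)                          # full 40x40 joint covariance
#     SI, cov = mvpca(Xs, cross_covariance=False)  # block-diagonal per-view covariances
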
def pca(X):
    """Single-view PCA: returns ``(SI, cov)``, an identity matrix and the
    sample covariance of the centered data."""
    if _has_ops():
        return torchsl.ops.pca(X)
    options = dict(dtype=X.dtype, device=X.device)
    num_samples = X.size(0)
    dim = X.size(1)
    X_centered = X - X.mean(0)
    cov = X_centered.t().mm(X_centered).div(num_samples - 1)
    SI = torch.eye(dim, **options)
    return SI, cov

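# Usage sketch (the eigendecomposition below is an assumption about how the
# returned pair is consumed; it is not part of this module). For plain PCA the
# first return value is the identity, so an ordinary symmetric
# eigendecomposition of ``cov`` suffices:
#
#     X = torch.randn(100, 20)
#     SI, cov = pca(X)
#     eigvals, eigvecs = torch.linalg.eigh(cov)    # ascending eigenvalues
#     W = eigvecs[:, -5:].flip(-1)                 # top-5 principal directions
#     X_proj = (X - X.mean(0)).mm(W)
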
def mvpls(Xs):
    """Multiview PLS: returns ``(SI, Sb)`` where ``Sb`` contains only the
    between-view (off-diagonal) blocks."""
    if _has_ops():
        return torchsl.ops.mvpls(Xs)
    options = dict(dtype=Xs[0].dtype, device=Xs[0].device)
    num_views = len(Xs)
    num_components = Xs[0].size(0)
    num_samples = num_views * num_components
    dims = dimensions(Xs)
    I = torch.eye(num_components, **options)
    B = torch.ones(num_components, num_components, **options) / num_samples
    SI = torch.eye(dims.sum(), **options)
    # Only cross-view (j != r) blocks contribute.
    Sb = multiview_covariance_matrix(
        dims,
        lambda j, r: Xs[j].t().mm(I - B).mm(Xs[r]) if j != r else 0,
        options)
    return SI, Sb

def pclda(X, y, y_unique=None, beta=1, q=1):
    """Pairwise-class discriminant (pc-LDA style) loss aggregated over all
    class pairs; returns a single scalar."""
    if y_unique is None:
        y_unique = torch.unique(y)
    if _has_ops():
        return torchsl.ops.pclda(X, y, y_unique, beta, q)
    options = dict(dtype=X.dtype, device=X.device)
    num_samples = y.size(0)
    num_classes = y_unique.size(0)
    ecs = class_vectors(y, y_unique).to(dtype=options['dtype'])
    ucs = class_means(X, ecs)
    y_unique_counts = ecs.sum(1)
    out_dimension = X.size(1)
    pairs = torch.combinations(torch.arange(num_classes, dtype=torch.long), r=2)

    # Per-class averaging (class_W) and selection (class_I) matrices.
    class_W = torch.empty(num_classes, num_samples, num_samples, **options)
    class_I = torch.empty(num_classes, num_samples, num_samples, **options)
    for ci in range(num_classes):
        class_W[ci] = ecs[ci].unsqueeze(0).t().mm(ecs[ci].unsqueeze(0)).div_(
            y_unique_counts[ci])
        class_I[ci] = torch.eye(num_samples, **options) * ecs[ci]
    W = class_W.sum(dim=0)
    I = torch.eye(num_samples, **options)

    # Per-class and pooled within-class scatter.
    class_Sw = torch.empty(num_classes, out_dimension, out_dimension, **options)
    for ci in range(num_classes):
        class_Sw[ci] = X.t().mm(class_I[ci] - class_W[ci]).mm(X)
    Sw = X.t().mm(I - W).mm(X)

    out = 0
    for ca, cb in pairs:
        Sw_ab = beta * (y_unique_counts[ca] * class_Sw[ca]
                        + y_unique_counts[cb] * class_Sw[cb])
        Sw_ab.div_(y_unique_counts[ca] + y_unique_counts[cb]).add_((1 - beta) * Sw)
        du_ab = ucs[ca].sub(ucs[cb]).unsqueeze_(0)
        # Sb_ab = du_ab.t().mm(du_ab)
        # out += y_unique_counts[ca] * y_unique_counts[cb] * (torch.trace(Sb_ab) / torch.trace(Sw_ab)).pow_(-q)
        out += y_unique_counts[ca] * y_unique_counts[cb] * (
            du_ab.mm(Sw_ab.inverse()).mm(du_ab.t())).pow_(-q)
    out /= num_samples * num_samples
    return out

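# Usage sketch (hypothetical shapes; ``pclda`` reduces all class pairs to one
# scalar that gets smaller as class means move apart relative to the blended
# pairwise within-class scatter):
#
#     X = torch.randn(200, 20)
#     y = torch.randint(0, 3, (200,))
#     loss = pclda(X, y, beta=0.5, q=1)
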
def mvda(Xs, y, y_unique=None, alpha_vc=0, reg_vc=1e-4):
    """Multi-view Discriminant Analysis scatter matrices ``(Sw, Sb)``;
    ``alpha_vc > 0`` adds a view-consistency term regularized by ``reg_vc``."""
    if y_unique is None:
        y_unique = torch.unique(y)
    if _has_ops():
        return torchsl.ops.mvda(Xs, y, y_unique, alpha_vc, reg_vc)
    options = dict(dtype=Xs[0].dtype, device=Xs[0].device)
    num_views = len(Xs)
    num_components = y.size(0)
    num_samples = num_views * num_components
    num_classes = y_unique.size(0)
    ecs = class_vectors(y, y_unique).to(dtype=options['dtype'])
    y_unique_counts = ecs.sum(1)
    dims = dimensions(Xs)

    W = torch.zeros(num_components, num_components, **options)
    for ci in range(num_classes):
        W += ecs[ci].unsqueeze(1).mm(ecs[ci].unsqueeze(0)).div(
            num_views * y_unique_counts[ci])
    I = torch.eye(num_components, **options)
    B = torch.ones(num_components, num_components, **options) / num_samples

    Sw = multiview_covariance_matrix(
        dims,
        lambda j, r: Xs[j].t().mm(I - W).mm(Xs[r]) if j == r
        else Xs[j].t().mm(-W).mm(Xs[r]),
        options)
    Sb = multiview_covariance_matrix(
        dims, lambda j, r: Xs[j].t().mm(W - B).mm(Xs[r]), options)

    if alpha_vc != 0:
        # Optional view-consistency term added to Sw.
        Ps = [Xs[vi].mm(Xs[vi].t()) for vi in range(num_views)]
        Ps = [(Ps[vi].add_(
            torch.zeros(num_components, num_components, **options).fill_diagonal_(
                reg_vc * Ps[vi].trace()))).inverse().mm(Xs[vi])
            for vi in range(num_views)]
        Sw += multiview_covariance_matrix(
            dims,
            lambda j, r: 2 * (num_views - 1) * Ps[j].t().mm(Ps[r]) if j == r
            else -2 * Ps[j].t().mm(Ps[r]),
            options) * alpha_vc
    return Sw, Sb

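# Usage sketch (hypothetical data; the eigen-solver below is an assumption
# about how the scatter pair is consumed downstream, not part of this module):
#
#     Xs = [torch.randn(100, 20), torch.randn(100, 30)]
#     y = torch.randint(0, 3, (100,))
#     Sw, Sb = mvda(Xs, y)
#     # Generalized eigenvectors of (Sb, Sw); simple but not numerically robust:
#     eigvals, eigvecs = torch.linalg.eig(torch.linalg.solve(Sw, Sb))
#     order = eigvals.real.argsort(descending=True)
#     W = eigvecs.real[:, order[:2]]                 # joint 50x2 projection
#     Ws = torch.split(W, [X.size(1) for X in Xs])   # per-view projections
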
def mvmda(Xs, y, y_unique=None):
    """Scatter matrices ``(Sw, Sb)`` for MvMDA on multiview labeled data."""
    if y_unique is None:
        y_unique = torch.unique(y)
    if _has_ops():
        return torchsl.ops.mvmda(Xs, y, y_unique)
    options = dict(dtype=Xs[0].dtype, device=Xs[0].device)
    num_views = len(Xs)
    num_components = y.size(0)
    num_classes = y_unique.size(0)
    ecs = class_vectors(y, y_unique).to(dtype=options['dtype'])
    y_unique_counts = ecs.sum(1)
    dims = dimensions(Xs)

    W = torch.zeros(num_components, num_components, **options)
    J = torch.zeros_like(W)
    B = torch.zeros_like(W)
    for ci in range(num_classes):
        tmp = ecs[ci].unsqueeze(1).mm(ecs[ci].unsqueeze(0)).div(y_unique_counts[ci])
        W += tmp.div(y_unique_counts[ci])
        J += tmp
    del tmp
    W *= num_classes / (num_views * num_views)
    J /= num_views
    for ca in range(num_classes):
        for cb in range(num_classes):
            B += ecs[ca].unsqueeze(1).mm(ecs[cb].unsqueeze(0)).div(
                y_unique_counts[ca] * y_unique_counts[cb])
    B /= num_views * num_views
    I = torch.eye(num_components, **options)

    Sw = multiview_covariance_matrix(
        dims,
        lambda j, r: Xs[j].t().mm(I - J).mm(Xs[r]) if j == r
        else Xs[j].t().mm(-J).mm(Xs[r]),
        options)
    Sb = multiview_covariance_matrix(
        dims, lambda j, r: Xs[j].t().mm(W - B).mm(Xs[r]), options)
    return Sw, Sb

def lda(X, y, y_unique=None):
    """Fisher LDA scatter matrices: returns ``(Sw, Sb)``."""
    if y_unique is None:
        y_unique = torch.unique(y)
    if _has_ops():
        return torchsl.ops.lda(X, y, y_unique)
    options = dict(dtype=X.dtype, device=X.device)
    num_samples = X.size(0)
    num_classes = y_unique.size(0)
    ecs = class_vectors(y, y_unique).to(dtype=options['dtype'])
    y_unique_counts = ecs.sum(1)

    W = torch.zeros(num_samples, num_samples, **options)
    for ci in range(num_classes):
        W += ecs[ci].unsqueeze(1).mm(ecs[ci].unsqueeze(0)).div_(y_unique_counts[ci])
    I = torch.eye(num_samples, **options)
    B = torch.ones(num_samples, num_samples, **options) / num_samples

    Sw = X.t().mm(I - W).mm(X)
    Sb = X.t().mm(W - B).mm(X)
    return Sw, Sb

def pcmvda_loss(Ys, y, y_unique=None, beta=1, q=1):
    """pc-MvDA loss over all class pairs, computed on already-projected views
    of equal output dimension; returns a single scalar."""
    if len(set(Y.size(1) for Y in Ys)) > 1:
        raise ValueError(
            f"pc-MvDA only works on projected data with same dimensions, "
            f"got dimensions of {tuple(Y.size(1) for Y in Ys)}.")
    if y_unique is None:
        y_unique = torch.unique(y)
    if _has_ops():
        return torch.ops.torchsl.pcmvda(Ys, y, y_unique, beta, q)
    Xs_cat = torch.cat(Ys, 0)
    options = dict(dtype=Xs_cat.dtype, device=Xs_cat.device)
    num_views = len(Ys)
    num_components = y.size(0)
    num_classes = y_unique.size(0)
    ecs = class_vectors(y, y_unique).to(dtype=options['dtype'])
    ucs = class_means(Ys, ecs)
    y_unique_counts = ecs.sum(1)
    out_dimension = Xs_cat.size(1)
    covariance_dims = torch.tensor(num_components, dtype=torch.long).repeat(num_views)
    pairs = torch.combinations(torch.arange(num_classes, dtype=torch.long), r=2)

    # Per-class averaging (class_W) and selection (class_I) matrices.
    class_W = torch.empty(num_classes, num_components, num_components, **options)
    class_I = torch.empty(num_classes, num_components, num_components, **options)
    for ci in range(num_classes):
        class_W[ci] = ecs[ci].unsqueeze(0).t().mm(ecs[ci].unsqueeze(0)).div_(
            num_views * y_unique_counts[ci])
        class_I[ci] = torch.eye(num_components, **options) * ecs[ci]
    W = class_W.sum(dim=0)
    I = torch.eye(num_components, **options)

    # Per-class and pooled within-class scatter over the concatenated views.
    class_Sw = torch.empty(num_classes, out_dimension, out_dimension, **options)
    for ci in range(num_classes):
        class_Sw[ci] = Xs_cat.t().mm(
            multiview_covariance_matrix(
                covariance_dims,
                lambda j, r: class_I[ci] - class_W[ci] if j == r else -class_W[ci],
                options)).mm(Xs_cat)
    Sw = Xs_cat.t().mm(
        multiview_covariance_matrix(
            covariance_dims,
            lambda j, r: I - W if j == r else -W,
            options)).mm(Xs_cat)

    out = 0
    for ca, cb in pairs:
        Sw_ab = beta * (y_unique_counts[ca] * class_Sw[ca]
                        + y_unique_counts[cb] * class_Sw[cb])
        Sw_ab.div_(y_unique_counts[ca] + y_unique_counts[cb]).add_((1 - beta) * Sw)
        du_ab = sum(uc[ca] for uc in ucs).sub(
            sum(uc[cb] for uc in ucs)).div_(num_views).unsqueeze_(0)
        Sb_ab = du_ab.t().mm(du_ab)
        out += y_unique_counts[ca] * y_unique_counts[cb] * (
            torch.trace(Sb_ab) / torch.trace(Sw_ab)).pow_(-q)
        # out += y_unique_counts[ca] * y_unique_counts[cb] * (du_ab.mm(Sw_ab.inverse()).mm(du_ab.t())).pow_(-q)
    out /= num_components * num_components
    return out
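

# Usage sketch (hypothetical projected views; as with ``pclda`` the result is
# a single scalar that shrinks as the per-pair class separation grows relative
# to the pairwise within-class scatter):
#
#     Ys = [torch.randn(100, 5), torch.randn(100, 5)]   # two views, same output dim
#     y = torch.randint(0, 4, (100,))
#     loss = pcmvda_loss(Ys, y, beta=0.5, q=1)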