def _tucker3(X, n_components, tol, max_iter, init_type, random_state=None):
    """
    3 dimensional Tucker decomposition.

    This code is meant to be a tutorial/testing example... in general
    _tuckerN should be more compact and equivalent mathematically.

    Each sweep replaces every factor by the leading ``n_components`` left
    singular vectors of the corresponding unfolding of ``X`` multiplied by
    the Kronecker product of the other two factors. The core tensor ``G``
    is then formed by applying ``tmult`` with each transposed factor along
    its own mode.

    Parameters
    ----------
    X : ndarray
        Tensor to decompose; must have exactly 3 dimensions.
    n_components : int
        Number of singular vectors kept for each factor.
    tol : float
        Early-stopping threshold on the relative change of
        ``||G||^2 - ||X||^2`` between two consecutive sweeps.
    max_iter : int
        Maximum number of sweeps. If it is below 1 no sweep is run and
        ``G`` is computed from the initial factors.
    init_type : {"random", "hosvd"}
        Selects ``_random_init`` or ``_hosvd_init`` for the starting factors.
    random_state : optional
        Passed to ``_random_init`` (only used when ``init_type == "random"``).

    Returns
    -------
    G, A, B, C
        Core tensor and the factor matrices for modes 0, 1 and 2.

    Raises
    ------
    ValueError
        If ``X`` is not 3 dimensional or ``init_type`` is not recognized.
    """
    if len(X.shape) != 3:
        raise ValueError("Tucker3 decomposition only supports 3 dimensions!")
    if init_type == "random":
        A, B, C = _random_init(X, n_components, random_state)
    elif init_type == "hosvd":
        A, B, C = _hosvd_init(X, n_components)
    else:
        # Without this branch an unknown init_type fell through and failed
        # later with an unbound-name error.
        raise ValueError("init_type must be 'random' or 'hosvd', got %r"
                         % (init_type,))

    def _core(a, b, c):
        # Multiply X by each transposed factor along that factor's mode.
        return tmult(tmult(tmult(X, a.T, 0), b.T, 1), c.T, 2)

    err = 1E10
    X_sq = np.sum(X ** 2)
    G = None
    for _ in range(max_iter):
        err_old = err
        U = linalg.svd(matricize(X, 0).dot(np.kron(C, B)),
                       full_matrices=False)[0]
        A = U[:, :n_components]
        U = linalg.svd(matricize(X, 1).dot(np.kron(C, A)),
                       full_matrices=False)[0]
        B = U[:, :n_components]
        U = linalg.svd(matricize(X, 2).dot(np.kron(B, A)),
                       full_matrices=False)[0]
        C = U[:, :n_components]
        G = _core(A, B, C)
        err = np.sum(G ** 2) - X_sq
        # The factors have orthonormal columns (they come from SVDs), so
        # (presuming tmult is the mode-n product) ||G||^2 <= ||X||^2 and err
        # is <= 0. Comparing against |err_old| matters: with the signed value
        # the relative change went negative and the loop always stopped after
        # the second sweep. The multiplicative form also avoids dividing by
        # zero on an exact fit.
        if np.abs(err - err_old) <= tol * np.abs(err_old):
            break
    if G is None:
        # max_iter < 1: no sweep ran, so project onto the initial factors.
        G = _core(A, B, C)
    return G, A, B, C
def _tuckerN(X, n_components, tol, max_iter, init_type, random_state=None):
    """Generalized Tucker decomposition.

    N-dimensional version of the alternating SVD scheme. Each sweep
    replaces the factor of every mode by the leading ``n_components`` left
    singular vectors of that mode's unfolding of ``X`` multiplied by the
    Kronecker product of all other factors, taken in reverse mode order.
    The core tensor ``G`` is obtained by applying ``tmult`` with every
    transposed factor along its mode.

    Parameters
    ----------
    X : ndarray
        Tensor to decompose.
    n_components : int
        Number of singular vectors kept for each factor.
    tol : float
        Early-stopping threshold on the relative change of
        ``||G||^2 - ||X||^2`` between two consecutive sweeps.
    max_iter : int
        Maximum number of sweeps. If it is below 1 no sweep is run and
        ``G`` is computed from the initial factors.
    init_type : {"random", "hosvd"}
        Selects ``_random_init`` or ``_hosvd_init`` for the starting factors.
    random_state : optional
        Passed to ``_random_init`` (only used when ``init_type == "random"``).

    Returns
    -------
    list
        ``[G, factor_0, factor_1, ...]``: the core tensor followed by one
        factor matrix per mode.

    Raises
    ------
    ValueError
        If ``init_type`` is not recognized.
    """
    if init_type == "random":
        components = _random_init(X, n_components, random_state)
    elif init_type == "hosvd":
        components = _hosvd_init(X, n_components)
    else:
        # Without this branch an unknown init_type fell through and failed
        # later with an unbound-name error.
        raise ValueError("init_type must be 'random' or 'hosvd', got %r"
                         % (init_type,))
    # Private mutable copy: factors are replaced by index below, which would
    # fail if an initializer returned a tuple.
    components = list(components)

    def _core(factors):
        # Multiply X by each transposed factor along that factor's mode.
        core = X
        for mode, factor in enumerate(factors):
            core = tmult(core, factor.T, mode)
        return core

    err = 1E10
    X_sq = np.sum(X ** 2)
    G = None
    for _ in range(max_iter):
        err_old = err
        for idx in range(len(components)):
            others = components[:idx] + components[idx + 1:]
            # Kronecker product of the other factors, last mode first,
            # e.g. kron(C, B) for mode 0 when the factors are (A, B, C).
            p1 = reduce(np.kron, others[:-1][::-1], others[-1])
            U = linalg.svd(matricize(X, idx).dot(p1), full_matrices=False)[0]
            components[idx] = U[:, :n_components]
        G = _core(components)
        err = np.sum(G ** 2) - X_sq
        # The factors have orthonormal columns (they come from SVDs), so
        # (presuming tmult is the mode-n product) ||G||^2 <= ||X||^2 and err
        # is <= 0. Comparing against |err_old| matters: with the signed value
        # the relative change went negative and the loop always stopped after
        # the second sweep. The multiplicative form also avoids dividing by
        # zero on an exact fit.
        if np.abs(err - err_old) <= tol * np.abs(err_old):
            break
    if G is None:
        # max_iter < 1: no sweep ran, so project onto the initial factors.
        G = _core(components)
    return [G] + components
def mod_tmult(arg0, arg1):
    """Call ``tmult`` with a packed ``(matrix, mode)`` pair.

    The ``(accumulator, item)`` signature makes this usable as the
    function argument of ``reduce`` over a sequence of such pairs.

    Parameters
    ----------
    arg0 :
        Tensor passed as the first argument of ``tmult``.
    arg1 : sequence
        Indexable pair; ``arg1[0]`` is the matrix and ``arg1[1]`` the mode.

    Returns
    -------
    Whatever ``tmult(arg0, arg1[0], arg1[1])`` returns.
    """
    matrix = arg1[0]
    mode = arg1[1]
    return tmult(arg0, matrix, mode)