def _compute_projections(tensor_slices, factors, svd_fun, out=None):
    """Compute one projection matrix per tensor slice.

    For slice ``i`` the matrix ``B @ (A[i] * C).T @ tensor_slices[i].T`` is
    decomposed with ``svd_fun`` and the projection is ``(U @ Vh).T``.

    Parameters
    ----------
    tensor_slices : iterable of matrices
        The slices to compute projections for.
    factors : sequence of three matrices
        Unpacked as ``A, B, C``. Row ``i`` of ``A`` scales ``C`` for slice ``i``.
    svd_fun : callable
        Called as ``svd_fun(matrix, n_eigenvecs=A.shape[1])``; must return a
        triple ``(U, S, Vh)``. Only ``U`` and ``Vh`` are used.
    out : list of matrices, optional
        Storage for the projections. When ``None``, a list of zero matrices
        of shape ``(tensor_slice.shape[0], C.shape[1])`` is allocated, each
        with the context (``T.context``) of its slice.

    Returns
    -------
    list of matrices
        ``out``, with entry ``i`` replaced by the projection for slice ``i``.

    Notes
    -----
    Iteration stops at the shortest of ``out``, the rows of ``A`` and
    ``tensor_slices``; entries of ``out`` beyond that are left untouched.
    """
    A, B, C = factors

    if out is None:
        out = []
        for tensor_slice in tensor_slices:
            projection_shape = (tensor_slice.shape[0], C.shape[1])
            out.append(T.zeros(projection_shape, **T.context(tensor_slice)))

    row_indices = range(T.shape(A)[0])
    for current, idx, slice_matrix in zip(out, row_indices, tensor_slices):
        scaled_c = A[idx] * C
        left_product = T.dot(B, T.transpose(scaled_c))
        to_decompose = T.dot(left_product, T.transpose(slice_matrix))
        # The singular values are not needed, only the two orthogonal factors.
        left_vecs, _, right_vecs_h = svd_fun(to_decompose, n_eigenvecs=A.shape[1])

        new_projection = T.transpose(T.dot(left_vecs, right_vecs_h))
        # Write through tl.index_update rather than plain item assignment
        # (presumably so backends with immutable tensors work -- TODO confirm).
        out[idx] = tl.index_update(current, tl.index[:], new_projection)

    return out
def _project_tensor_slices(tensor_slices, projections, out=None):
    """Project each slice with its projection matrix and stack the results.

    Entry ``i`` of the result is ``projections[i].T @ tensor_slices[i]``.

    Parameters
    ----------
    tensor_slices : sequence of matrices
        The slices to project.
    projections : sequence of matrices
        One projection matrix per slice, paired with the slices by position.
    out : tensor, optional
        Storage for the projected slices. When ``None``, a zero tensor of
        shape ``(len(tensor_slices), projections[0].shape[1],
        tensor_slices[0].shape[1])`` is allocated with the context
        (``T.context``) of the first slice.

    Returns
    -------
    tensor
        The tensor returned by the last ``tl.index_update`` call (or ``out``
        itself if there was nothing to iterate over). Whether this is the
        same object as the ``out`` passed in is up to the backend's
        ``index_update`` -- not visible from here.
    """
    if out is None:
        n_components = projections[0].shape[1]
        n_slices = len(tensor_slices)
        n_columns = tensor_slices[0].shape[1]
        out = T.zeros(
            (n_slices, n_components, n_columns), **T.context(tensor_slices[0])
        )

    for position, (proj, matrix) in enumerate(zip(projections, tensor_slices)):
        projected = T.dot(T.transpose(proj), matrix)
        # Rebind `out`: tl.index_update returns the updated tensor.
        out = tl.index_update(out, tl.index[position, :], projected)

    return out