def to_tensor(self):
    """Reconstruct the dense tensor from the block-TT factors.

    Each entry of ``self.tensorized_shape`` describes one mode:

    * an ``int`` entry is a batched mode: consecutive factors share the same
      einsum index and the dimension is kept as-is;
    * any other entry is a tensorized mode: the running result and the next
      factor get distinct indices, and the two dimensions are merged by the
      subsequent reshape.

    The factors are contracted left to right along their rank dimensions.
    The leading and trailing rank axes are then squeezed away, and the
    result is reshaped to ``self.tensor_shape``.

    Returns
    -------
    tensor reshaped to ``self.tensor_shape``.
    """
    first = ord('d')  # 'a', 'b', 'c' are reserved for the rank indices
    lhs_syms, rhs_syms, out_syms = [], [], []
    for mode, size in enumerate(self.tensorized_shape):
        own = chr(first + mode)
        lhs_syms.append(own)
        if isinstance(size, int):
            # Batched mode: shared index, appears once in the output.
            rhs_syms.append(own)
            out_syms.append(own)
        else:
            # Tensorized mode: the new factor gets its own index and both
            # indices are kept so the reshape below can merge them.
            other = chr(first + self.order + mode)
            rhs_syms.append(other)
            out_syms.extend((own, other))
    in1_eq = ''.join(lhs_syms)
    in2_eq = ''.join(rhs_syms)
    out_eq = ''.join(out_syms)
    equation = f'a{in1_eq}b,b{in2_eq}c->a{out_eq}c'

    for position, factor in enumerate(self.factors):
        if position == 0:
            res = factor
            continue
        merged_shape = list(res.shape)
        for mode, size in enumerate(self.tensorized_shape):
            if not isinstance(size, int):
                merged_shape[mode + 1] *= factor.shape[mode + 1]
        merged_shape[-1] = factor.shape[-1]  # new trailing rank
        res = tl.reshape(tl.einsum(equation, res, factor), merged_shape)

    # Drop the boundary rank axes (expected to be of size 1).
    return tl.reshape(res.squeeze(0).squeeze(-1), self.tensor_shape)
def linear_blocktt(tensor, tt_matrix, transpose=True):
    """Apply a block-TT matrix to a batch of inputs without densifying it.

    Parameters
    ----------
    tensor : tensor
        Input, reshaped to ``(-1, *tt_matrix.tensorized_shape[axis])``
        where ``axis`` is 1 if ``transpose`` else 0.
    tt_matrix : block-TT matrix
        Object exposing ``tensorized_shape`` (a pair of per-core dimension
        tuples) and ``factors``, each indexed as (rank, row, col, rank).
    transpose : bool, default True
        If True, contract the input with the factors' column (axis 2)
        dimensions and return the row dimensions. Otherwise contract with
        the row (axis 1) dimensions and return the column dimensions.

    Returns
    -------
    tensor of shape ``(batch, -1)``.
    """
    contraction_axis = 1 if transpose else 0
    ndim = len(tt_matrix.tensorized_shape[contraction_axis])
    tensor = tensor.reshape(-1, *tt_matrix.tensorized_shape[contraction_axis])

    batch = 'a'
    first = ord(batch) + 1
    in_syms = [chr(first + k) for k in range(ndim)]
    out_syms = [chr(first + ndim + k) for k in range(ndim)]
    rank_syms = [chr(first + 2 * ndim + k) for k in range(ndim + 1)]

    factor_subs = []
    for k in range(ndim):
        if transpose:
            row, col = out_syms[k], in_syms[k]
        else:
            row, col = in_syms[k], out_syms[k]
        factor_subs.append(rank_syms[k] + row + col + rank_syms[k + 1])

    in_idx = batch + ''.join(in_syms)
    out_idx = batch + ''.join(out_syms)
    # Boundary rank symbols appear in a single operand only and are summed.
    eq = in_idx + ',' + ','.join(factor_subs) + '->' + out_idx
    res = tl.einsum(eq, tensor, *tt_matrix.factors)
    return tl.reshape(res, (tl.shape(res)[0], -1))
def higher_order_moment(tensor, order):
    """Computes the (empirical) Higher-Order Moment.

    Parameters
    ----------
    tensor : 2D-tensor -- or ND-tensor
        matrix of size (n_samples, n_features)
        or tensor of size (n_samples, D1, ..., DN)
    order : int
        order of the higher-order moment to compute (>= 1)

    Returns
    -------
    tensor : moment
        if tensor is a matrix of size (n_samples, n_features),
        tensor of size (n_features, )*order;
        otherwise tensor of size (D1, ..., DN)*order.
        The outer products of each sample with itself are averaged over
        the ``n_samples`` axis.

    Raises
    ------
    ValueError
        If ``order < 1`` or the equation needs more than 52 einsum symbols.
    """
    import string

    if order < 1:
        raise ValueError(f'order must be a positive integer, got {order}.')

    shape = tl.shape(tensor)
    n_samples = shape[0]
    n_modes = len(shape) - 1  # feature modes per sample

    symbols = string.ascii_letters
    if 1 + order * n_modes > len(symbols):
        raise ValueError(
            f'Cannot build a moment of order {order} for a tensor with '
            f'{n_modes} feature modes: not enough einsum symbols.')

    # Every operand shares the sample index, but each copy of the tensor gets
    # its own feature symbols so that the output is their outer product.
    batch = symbols[0]
    blocks = [symbols[1 + k * n_modes:1 + (k + 1) * n_modes]
              for k in range(order)]
    eq = ','.join(batch + block for block in blocks) + '->' + ''.join(blocks)
    return tl.einsum(eq, *([tensor] * order)) / n_samples
def tt_matrix_to_tensor(tt_matrix):
    """Returns the full tensor whose TT-Matrix decomposition is given by `tt_matrix`.

    Re-assembles the cores of ``tt_matrix``, which represent a tensor in
    TT-Matrix format, into the corresponding full tensor.

    Parameters
    ----------
    tt_matrix : sequence of 4D-arrays
        TT-Matrix factors (known as cores) of shape
        (rank_k, left_dim_k, right_dim_k, rank_{k+1}).

    Returns
    -------
    output_tensor : ndarray
        Full tensor of shape
        (left_dim_0, ..., left_dim_{N-1}, right_dim_0, ..., right_dim_{N-1}).
    """
    ndim = len(tt_matrix)
    # Symbols: left (row) indices, then right (column) indices, then ranks.
    # Only lowercase letters are used, which limits this to 8 cores.
    start_in = ord('a')
    start_out = start_in + ndim
    start_rank = start_out + ndim

    factors_idx = []
    for i in range(ndim):
        idx = [start_rank + i, start_in + i, start_out + i, start_rank + i + 1]
        factors_idx.append(''.join(chr(j) for j in idx))

    # Request all left modes, then all right modes, directly in the output
    # layout so no separate transpose is needed. The boundary rank symbols
    # appear in a single operand only and are summed (they are expected to
    # be of size 1).
    out_idx = (''.join(chr(start_in + i) for i in range(ndim))
               + ''.join(chr(start_out + i) for i in range(ndim)))
    eq = ','.join(factors_idx) + '->' + out_idx
    return tl.einsum(eq, *tt_matrix)
def tt_factorized_linear(tt_vec, ttm_weights):
    """Multiply a TT tensor by a TT matrix, core by core, staying in TT form.

    Parameters
    ----------
    tt_vec : tensor train tensor
        Its cores are indexed as 3D arrays (rank, dim, rank).
    ttm_weights : tensor train matrix
        Its cores are indexed as 4D arrays (rank, out_dim, in_dim, rank);
        ``in_dim`` is contracted with the middle dimension of the
        matching ``tt_vec`` core.

    Returns
    -------
    TTTensor
        Core ``k`` has shape (w_rank_k * v_rank_k, out_dim_k,
        w_rank_{k+1} * v_rank_{k+1}), with the ``tt_vec`` rank as the
        major part of each merged rank.
    """
    cores = []
    for k in range(len(tt_vec)):
        weight_shape = ttm_weights[k].shape
        vec_core = tt_vec[k]
        vec_shape = vec_core.shape
        # (v_rank, w_rank, out_dim, v_rank', w_rank') before merging ranks
        product = tl.einsum('abc,debf->adecf', vec_core, ttm_weights[k])
        merged_shape = (
            weight_shape[0] * vec_shape[0],
            weight_shape[1],
            weight_shape[3] * vec_shape[2],
        )
        cores.append(tl.reshape(product, merged_shape))
    return TTTensor(cores)
def tensor_dot_tucker(tensor, tucker, modes):
    """Contract a dense tensor with a Tucker tensor kept in factorized form.

    Parameters
    ----------
    tensor : tl.tensor
        Dense input tensor.
    tucker : Tucker tensor
        Object exposing ``tensor_shape``, ``order``, ``core`` and
        ``factors``; factor ``k`` is indexed as (dim_k, rank_k).
    modes : int or pair of int lists
        Contraction modes, as accepted by ``_validate_contraction_modes``.

    Returns
    -------
    tl.tensor
        Uncontracted modes of ``tensor`` followed by the uncontracted modes
        of the Tucker tensor.
    """
    modes_tensor, modes_tucker = _validate_contraction_modes(
        tl.shape(tensor), tucker.tensor_shape, modes)
    input_order = tensor.ndim
    weight_order = tucker.order

    # Symbol groups: core ranks, Tucker (full) dimensions, input dimensions.
    rank_sym = [einsum_symbols[k] for k in range(weight_order)]
    tucker_sym = [einsum_symbols[weight_order + k] for k in range(weight_order)]
    tensor_sym = [einsum_symbols[2 * weight_order + k] for k in range(tensor.ndim)]

    # Output = input symbols + Tucker symbols, minus every contracted mode.
    # Tucker positions are popped first (they sit after the input ones), each
    # group from the highest index down so earlier indices stay valid.
    output_sym = tensor_sym + tucker_sym
    for mode in sorted(modes_tucker, reverse=True):
        output_sym.pop(input_order + mode)
    for mode in sorted(modes_tensor, reverse=True):
        output_sym.pop(mode)

    # Contracted input modes reuse the matching Tucker symbol.
    for pos, mode in enumerate(modes_tensor):
        tensor_sym[mode] = tucker_sym[modes_tucker[pos]]

    factor_subs = ','.join(f'{s}{r}' for s, r in zip(tucker_sym, rank_sym))
    eq = (''.join(tensor_sym) + ',' + ''.join(rank_sym) + ',' + factor_subs
          + '->' + ''.join(output_sym))
    return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
def tensordot(tensor1, tensor2, modes, batched_modes=()):
    """Batched tensor contraction between two tensors on specified modes.

    Parameters
    ----------
    tensor1 : tl.tensor
    tensor2 : tl.tensor
    modes : int list or int
        modes on which to contract tensor1 and tensor2
    batched_modes : int or tuple[int]
        modes that are shared (element-wise) by both tensors and kept once
        in the output

    Returns
    -------
    contraction : tensor1 contracted with tensor2 on the specified modes;
        the remaining modes of tensor1 (batched ones included) come first,
        followed by the remaining modes of tensor2.
    """
    modes1, modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, modes)
    batch_modes1, batch_modes2 = _validate_contraction_modes(
        tensor1.shape, tensor2.shape, batched_modes, batched_modes=True)

    first = ord('a')
    ndim1 = tl.ndim(tensor1)
    sym1 = [chr(first + k) for k in range(ndim1)]
    sym2 = [chr(first + ndim1 + k) for k in range(tl.ndim(tensor2))]

    # Contracted and batched modes of tensor2 reuse tensor1's symbols.
    for m1, m2 in zip(modes1 + batch_modes1, modes2 + batch_modes2):
        sym2[m2] = sym1[m1]

    dropped2 = modes2 + batch_modes2
    out_sym = [s for k, s in enumerate(sym1) if k not in modes1]
    out_sym += [s for k, s in enumerate(sym2) if k not in dropped2]

    equation = ''.join(sym1) + ',' + ''.join(sym2) + '->' + ''.join(out_sym)
    return tl.einsum(equation, tensor1, tensor2)
def tensor_dot_cp(tensor, cp, modes):
    """Contracts a dense tensor with a CP tensor kept in factorized form.

    Parameters
    ----------
    tensor : tl.tensor
        Dense input tensor.
    cp : CP tensor
        Object exposing ``weights`` and ``factors``, plus ``tensor_shape``
        (or ``shape``); factor ``k`` is indexed as (dim_k, rank).
    modes : int or pair of int lists
        Contraction modes, as accepted by ``_validate_contraction_modes``;
        ``modes_tensor[k]`` is contracted with ``modes_cp[k]``. The two lists
        do not need to be sorted.

    Returns
    -------
    tensor
        The uncontracted modes of ``tensor`` (in order) followed by the
        uncontracted modes of ``cp`` (in order).
    """
    try:
        cp_shape = cp.tensor_shape
    except AttributeError:
        cp_shape = cp.shape
    modes_tensor, modes_cp = _validate_contraction_modes(
        tl.shape(tensor), cp_shape, modes)

    tensor_order = tl.ndim(tensor)
    # Pair contracted modes positionally. A running counter over CP modes
    # would silently mis-pair them whenever ``modes_cp`` is not ascending.
    cp_to_tensor = dict(zip(modes_cp, modes_tensor))

    # CP rank = 'a', tensor modes start at 'b'
    start = ord('b')
    eq_in = ''.join(chr(start + index) for index in range(tensor_order))
    eq_res = ''.join(eq_in[i] for i in range(tensor_order)
                     if i not in modes_tensor)

    eq_factors = []
    next_free = start + tensor_order  # symbols for uncontracted CP modes
    for i in range(len(cp.factors)):
        if i in cp_to_tensor:
            # Shared index with the paired tensor mode
            eq_factors.append(f'{eq_in[cp_to_tensor[i]]}a')
        else:
            free_sym = chr(next_free)
            eq_factors.append(f'{free_sym}a')
            eq_res += free_sym
            next_free += 1

    eq = eq_in + ',a,' + ','.join(eq_factors) + '->' + eq_res
    return tl.einsum(eq, tensor, cp.weights, *cp.factors)
def __getitem__(self, indices):
    """Index the factorized tensor mode by mode.

    Each entry of ``indices`` is matched with one entry of
    ``self.tensorized_shape``:

    * ``int`` entries are "batched" modes: the index (int, slice or list) is
      applied identically to every factor;
    * other (tensorized) entries indexed with ``:`` or ``()`` are kept
      whole, still in factorized form;
    * tensorized entries indexed with anything else (int, partial slice,
      list) are unravelled with ``np.unravel_index`` into one index per
      factor. The factors must then be contracted into a dense result.

    Returns
    -------
    If any tensorized mode was partially indexed, the dense contraction of
    the indexed factors (passed through ``squeeze()``). Otherwise a new
    ``self.__class__`` built from the indexed factors, ``output_shape`` and
    ``self.rank``.
    """
    factors = self.factors
    if not isinstance(indices, Iterable):
        indices = [indices]
    if len(indices) < self.ndim:
        # Pad missing trailing indices with ':'
        indices = list(indices)
        indices.extend([slice(None)] * (self.ndim - len(indices)))
    elif len(indices) > self.ndim:
        indices = [indices]  # We're only indexing the first dimension

    output_shape = []
    # NOTE(review): indexed_factors and indexed_ndim are never used below.
    indexed_factors = []
    ndim = len(self.factors)
    indexed_ndim = len(indices)

    contract_factors = False  # If True, the result is dense and the factors must be contracted
    contraction_op = []  # Per kept mode: 'b' (batched, shared index) or 'm' (multiply/merge)
    eq_in1 = 'a'  # Previously contracted factors (rank_0, dim_0, ..., dim_N, rank_k)
    eq_in2 = 'b'  # Current factor (rank_k, dim_0', ..., dim_N', rank_{k+1})
    eq_out = 'a'  # Output contracted factor (rank_0, dim_0", ..., dim_N", rank_{k+1})
    # where either:
    #     i. dim_k" = dim_k' = dim_k (contraction_op='b' for batched)
    #  or ii. dim_k" = dim_k' x dim_k (contraction_op='m' for multiply)

    idx = ord('d')  # Next free einsum character ('a', 'b', 'c' are the ranks)

    pad = (slice(None), )  # index previous dimensions with [:], to avoid using .take(dim=k)
    add_pad = False  # whether this mode survives, so the padding must grow after indexing

    for (index, shape) in zip(indices, self.tensorized_shape):
        if isinstance(shape, int):
            # We are indexing a "batched" mode, not a tensorized one
            if not isinstance(index, (np.integer, int)):
                if isinstance(index, slice):
                    index = list(range(*index.indices(shape)))

                output_shape.append(len(index))
                add_pad = True
                contraction_op += 'b'  # batched
                eq_in1 += chr(idx)
                eq_in2 += chr(idx)
                eq_out += chr(idx)
                idx += 1
            # else: an integer index removes this mode from every factor
            index = [index] * ndim  # same index applied to every factor
        else:
            # We are indexing a tensorized mode

            # NOTE(review): if ``index`` is a NumPy array, ``index == slice(None)``
            # may not yield a plain bool; confirm array indices are supported here.
            if index == slice(None) or index == ():
                # Keeping all indices (:)
                output_shape.append(shape)

                eq_in1 += chr(idx)
                eq_in2 += chr(idx + 1)
                eq_out += chr(idx) + chr(idx + 1)
                idx += 2
                add_pad = True
                index = [index] * ndim
                contraction_op += 'm'  # multiply
            else:
                contract_factors = True

                if isinstance(index, slice):
                    # Since we've already filtered out :, this is a partial slice
                    # Convert into list of flat indices over the full tensorized mode
                    max_index = math.prod(shape)
                    index = list(range(*index.indices(max_index)))

                if isinstance(index, Iterable):
                    output_shape.append(len(index))
                    contraction_op += 'b'  # batched: unravelled indices are shared across factors
                    eq_in1 += chr(idx)
                    eq_in2 += chr(idx)
                    eq_out += chr(idx)
                    idx += 1
                    add_pad = True

                # One (array of) index per factor, one entry per dim of ``shape``
                index = np.unravel_index(index, shape)

        # Index every factor along the current mode. (The comprehension's
        # ``idx`` is local to it and does not clobber the character counter.)
        factors = [ff[pad + (idx, )] for (ff, idx) in zip(factors, index)]  # + factors[indexed_ndim:]

        if add_pad:
            pad += (slice(None), )
            add_pad = False

    # output_shape.extend(self.tensorized_shape[indexed_ndim:])

    if contract_factors:
        eq_in2 += 'c'
        eq_in1 += 'b'
        eq_out += 'c'
        eq = eq_in1 + ',' + eq_in2 + '->' + eq_out
        for i, factor in enumerate(factors):
            if not i:
                res = factor
            else:
                out_shape = list(res.shape)
                for j, s in enumerate(factor.shape[1:-1]):
                    if contraction_op[j] == 'm':
                        # merged mode: previous dim x current dim
                        out_shape[j + 1] *= s
                out_shape[-1] = factor.shape[-1]  # Last rank
                res = tl.reshape(tl.einsum(eq, res, factor), out_shape)
        # NOTE(review): squeeze() drops every size-1 axis, not only the
        # boundary ranks; a kept mode of length 1 would disappear too.
        return res.squeeze()
    else:
        return self.__class__(factors, output_shape, self.rank)