# Shared imports for the snippets below. `FactorList` and `_ensure_tuple` are
# assumed to be tltorch-style internal helpers; `softmax_by_state`,
# `initialize_H`, `put_mps` and `backend` are assumed to be defined elsewhere
# in the repository.
import tensorly as tl
import tensornetwork as tn
import torch
from tensorly.decomposition import tensor_train
from torch import nn


def init_from_tensor(self, tensor, **kwargs):
    with torch.no_grad():
        # TODO: deal properly with wrong kwargs
        factors = tensor_train(tensor, self.rank)
    self.factors = FactorList([nn.Parameter(f) for f in factors])
    return self
@classmethod
def from_tensor(cls, tensor, tensorized_row_shape, tensorized_column_shape, rank='same', **kwargs):
    full_shape = tensorized_row_shape + tensorized_column_shape
    # Any leading dimensions (before the tensorized row/column modes) index
    # the batch of matrices being jointly factorized.
    n_matrices = _ensure_tuple(tensor.shape[:-len(full_shape)])
    rank = tl.tt_tensor.validate_tt_rank(n_matrices + full_shape, rank)
    with torch.no_grad():
        factors = tensor_train(tensor, rank, **kwargs)
    return cls([nn.Parameter(f) for f in factors],
               tensorized_row_shape, tensorized_column_shape,
               rank=rank, n_matrices=n_matrices)
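# A minimal usage sketch for the classmethod above, assuming it is defined on
# a BlockTT-style tensorized-matrix class (the name `BlockTT` is hypothetical).
# A batch of 2 matrices, each (4*4) x (4*4), is tensorized to row/column
# shapes (4, 4) and jointly decomposed in TT format.
matrices = torch.randn(2, 4, 4, 4, 4)  # leading dim -> n_matrices = (2,)
tt_matrix = BlockTT.from_tensor(matrices,
                                tensorized_row_shape=(4, 4),
                                tensorized_column_shape=(4, 4),
                                rank='same')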
@classmethod
def from_tensor(cls, tensor, rank='same', **kwargs):
    shape = tensor.shape
    rank = tl.tt_tensor.validate_tt_rank(shape, rank)
    with torch.no_grad():
        # TODO: deal properly with wrong kwargs
        factors = tensor_train(tensor, rank)
    return cls([nn.Parameter(f) for f in factors])
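# A minimal usage sketch for the classmethod above, assuming it is defined on
# a TT-factorized tensor class (the name `TTTensor` is hypothetical): a dense
# 4th-order tensor is decomposed into trainable TT factors.
dense = torch.randn(5, 6, 7, 8)
tt = TTTensor.from_tensor(dense, rank='same')  # rank may also be an explicit tuple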
def build_network(five_tuple, k, chi, omega):
    """Build the k layers H and their tensor-network cores H_core from an
    MDP five-tuple (s, a, P, R, gamma)."""
    s, a, P, R, gamma = five_tuple
    H = []
    H_core = []
    # Random state-action values, normalized per state via softmax.
    data = softmax_by_state(torch.randn((s * a, 1)), s, a)
    data.requires_grad = True
    for i in range(k):
        H.append(initialize_H(five_tuple, i + 1))
        if i >= 1:
            # Higher layers: TT-decompose the dense tensor with bond
            # dimension chi, then regroup the factors into MPS cores.
            factors = tensor_train(H[i].get_tensor(), chi).factors
            combined_cores, _ = put_mps(factors)
            H_core.append(combined_cores)
        else:
            # First layer: mask the dense tensor with omega[0] and wrap it
            # in a single tensornetwork node.
            masked_H = H[i].get_tensor() * omega[i]
            H_core.append([tn.Node(masked_H, backend=backend)])
    return H, H_core, data
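# A hedged usage sketch for build_network. The MDP quantities below are
# illustrative placeholders only; the shapes of the omega masks must
# broadcast against the tensors returned by initialize_H (an assumption,
# since initialize_H is defined elsewhere).
s, a = 4, 2                                       # numbers of states and actions
P = torch.rand(s * a, s)                          # transition matrix (placeholder)
R = torch.rand(s * a, 1)                          # rewards (placeholder)
gamma = 0.9                                       # discount factor
k, chi = 3, 8                                     # depth and TT bond dimension
omega = [torch.ones(s * a, 1) for _ in range(k)]  # per-layer masks (placeholder shape)
H, H_core, data = build_network((s, a, P, R, gamma), k, chi, omega)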
def init_from_tensor(self, tensor, **kwargs):
    with torch.no_grad():
        factors = tensor_train(tensor, self.rank, **kwargs)
    self.factors = FactorList([nn.Parameter(f) for f in factors])
    return self
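# A minimal usage sketch for init_from_tensor, assuming `self` is an already
# constructed TT-factorized tensor whose `rank` was fixed at construction
# time (the `TTTensor.new(...)` constructor shown here is hypothetical).
tt = TTTensor.new(shape=(5, 6, 7, 8), rank=8)
tt.init_from_tensor(torch.randn(5, 6, 7, 8))  # overwrite factors from a dense tensor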