Esempio n. 1
0
    def init_from_tensor(self, tensor, **kwargs):
        """Replace this module's factors with a TT decomposition of ``tensor``.

        ``tensor_train`` is called with ``self.rank`` inside
        ``torch.no_grad()``, so the decomposition itself is not recorded by
        autograd. Each resulting core is wrapped in ``nn.Parameter`` and
        stored as a ``FactorList`` on ``self.factors``.

        Extra keyword arguments are accepted but NOT forwarded to
        ``tensor_train``.

        Returns ``self`` so the call can be chained.
        """
        # TODO: deal properly with wrong kwargs -- for now they are dropped.
        with torch.no_grad():
            tt_cores = tensor_train(tensor, self.rank)
        wrapped = [nn.Parameter(core) for core in tt_cores]
        self.factors = FactorList(wrapped)
        return self
Esempio n. 2
0
    def from_tensor(cls, tensor, tensorized_row_shape, tensorized_column_shape, rank='same', **kwargs):
        """Build an instance from the TT decomposition of a tensorized matrix.

        Parameters
        ----------
        tensor :
            Tensor whose trailing dimensions are
            ``tensorized_row_shape + tensorized_column_shape``. Any leading
            dimensions are collected into ``n_matrices`` (presumably a batch
            of matrices -- confirm against the class definition).
        tensorized_row_shape, tensorized_column_shape : sequence of int
            Tensorized row / column shapes. Tuples or lists are accepted; they
            are passed to ``cls`` unchanged.
        rank :
            Rank specification, normalised with
            ``tl.tt_tensor.validate_tt_rank`` against the full tensor shape.
        **kwargs :
            Forwarded to ``tensor_train``.

        Returns
        -------
        A ``cls`` instance built from the TT cores (each wrapped in
        ``nn.Parameter``), the tensorized shapes, the validated rank and
        ``n_matrices``.
        """
        # Coerce to tuples so the concatenation with ``n_matrices`` below also
        # works for list inputs. ``_ensure_tuple`` presumably returns a tuple,
        # and tuple + list raises TypeError.
        full_shape = tuple(tensorized_row_shape) + tuple(tensorized_column_shape)
        # Leading dimensions not covered by the tensorized shape.
        # NOTE(review): if ``full_shape`` is empty, ``[:-0]`` is an empty slice.
        n_matrices = _ensure_tuple(tensor.shape[:-len(full_shape)])
        rank = tl.tt_tensor.validate_tt_rank(n_matrices + full_shape, rank)

        # The decomposition is not part of the autograd graph.
        with torch.no_grad():
            factors = tensor_train(tensor, rank, **kwargs)

        return cls([nn.Parameter(f) for f in factors], tensorized_row_shape, tensorized_column_shape, rank=rank, n_matrices=n_matrices)
Esempio n. 3
0
    def from_tensor(cls, tensor, rank='same', **kwargs):
        """Construct an instance from the TT decomposition of ``tensor``.

        ``rank`` is first normalised with ``tl.tt_tensor.validate_tt_rank``
        against ``tensor.shape``. The decomposition then runs under
        ``torch.no_grad()``, and each core is wrapped in ``nn.Parameter``
        before being handed to ``cls``.

        Extra keyword arguments are accepted but NOT forwarded to
        ``tensor_train``.
        """
        validated_rank = tl.tt_tensor.validate_tt_rank(tensor.shape, rank)

        # TODO: deal properly with wrong kwargs -- for now they are dropped.
        with torch.no_grad():
            tt_cores = tensor_train(tensor, validated_rank)

        return cls([nn.Parameter(core) for core in tt_cores])
Esempio n. 4
0
def build_network(five_tuple, k, chi, omega):
    """Create ``k`` H objects, a core representation of each, and a data tensor.

    Parameters
    ----------
    five_tuple : tuple
        Unpacked as ``(s, a, P, R, gamma)``. Only ``s`` and ``a`` are used
        directly here; the whole tuple is forwarded to ``initialize_H``.
    k : int
        Number of H objects; ``initialize_H`` is called with orders 1..k.
    chi :
        Rank passed to ``tensor_train`` for every order after the first.
    omega :
        Indexable; only ``omega[0]`` is read, as an element-wise multiplier
        for the first H tensor.

    Returns
    -------
    (H, H_core, data)
        * ``H`` -- list of the objects returned by ``initialize_H``.
        * ``H_core`` -- one entry per order. For order 1 it is a one-element
          list holding a ``tn.Node`` of the masked tensor. For later orders it
          is the combined cores returned by ``put_mps`` from the TT factors.
        * ``data`` -- ``softmax_by_state`` applied to a random ``(s * a, 1)``
          tensor, with ``requires_grad`` set to True.
    """
    s, a, P, R, gamma = five_tuple
    data = softmax_by_state(torch.randn((s * a, 1)), s, a)
    data.requires_grad = True

    H, H_core = [], []
    for order in range(1, k + 1):
        h_obj = initialize_H(five_tuple, order)
        H.append(h_obj)
        if order == 1:
            # First order: keep the dense tensor, scaled element-wise by omega[0].
            dense = h_obj.get_tensor() * omega[0]
            H_core.append([tn.Node(dense, backend=backend)])
        else:
            # Higher orders: TT-decompose with rank chi, then combine via put_mps.
            tt_factors = tensor_train(h_obj.get_tensor(), chi).factors
            merged, _ = put_mps(tt_factors)
            H_core.append(merged)
    return H, H_core, data
Esempio n. 5
0
 def init_from_tensor(self, tensor, **kwargs):
     """Replace this module's factors with a TT decomposition of ``tensor``.

     ``tensor_train`` is called with ``self.rank`` plus any extra keyword
     arguments, inside ``torch.no_grad()``. The resulting cores are wrapped
     in ``nn.Parameter`` and stored as a ``FactorList`` on ``self.factors``.

     Returns ``self`` so the call can be chained.
     """
     with torch.no_grad():
         cores = tensor_train(tensor, self.rank, **kwargs)
     params = [nn.Parameter(core) for core in cores]
     self.factors = FactorList(params)
     return self