コード例 #1
0
    def initialize(self,
                   input_dim: int,
                   hidden_dim: int,
                   init_scale: float = 0.001,
                   basis: coo_matrix = None,
                   encoder_depth: int = 1,
                   imputer: Callable[[torch.Tensor], torch.Tensor] = None,
                   batch_size: int = 10,
                   bias=True):
        """Create the encoder, decoder, parameters and buffers of the model.

        Parameters
        ----------
        input_dim : int
            Size handed to ``random_linkage`` when no ``basis`` is given.
            Note that ``self.input_dim`` is afterwards set to the number of
            rows of the basis tensor, not to this value.
        hidden_dim : int
            Width of the latent layer (and of every hidden encoder layer).
        init_scale : float
            Standard deviation of the zero-mean normal used to initialize
            the encoder/decoder weights and ``eta``.
        basis : coo_matrix, optional
            Sparse basis in COO format. A basis derived from a random tree
            is used when omitted.
        encoder_depth : int
            Number of linear layers in the encoder; values above 1 insert
            ``Softplus`` activations between the layers.
        imputer : callable, optional
            Function applied to tensors to impute values. Defaults to
            adding a pseudocount of one.
        batch_size : int
            Number of rows of the ``eta`` parameter.
        bias : bool
            Whether the encoder layers carry a bias term (the decoder
            never does).
        """
        self.hidden_dim = hidden_dim
        self.bias = bias
        # Psi must be dimension D - 1 x D
        if basis is None:
            tree = random_linkage(input_dim)
            basis = sparse_balance_basis(tree)[0].copy()
        indices = np.vstack((basis.row, basis.col))
        Psi = torch.sparse_coo_tensor(indices.copy(),
                                      basis.data.astype(np.float32).copy(),
                                      requires_grad=False)

        self.input_dim = Psi.shape[0]
        # Always set the attribute: previously a user-supplied imputer was
        # dropped and ``self.imputer`` stayed undefined.
        if imputer is None:
            self.imputer = lambda x: x + 1
        else:
            self.imputer = imputer

        if encoder_depth > 1:
            self.first_encoder = nn.Linear(self.input_dim,
                                           hidden_dim,
                                           bias=self.bias)
            num_encoder_layers = encoder_depth
            layers = []
            layers.append(self.first_encoder)
            for layer_i in range(num_encoder_layers - 1):
                layers.append(nn.Softplus())
                layers.append(nn.Linear(hidden_dim, hidden_dim,
                                        bias=self.bias))
            self.encoder = nn.Sequential(*layers)

            # initialize
            for encoder_layer in self.encoder:
                if isinstance(encoder_layer, nn.Linear):
                    encoder_layer.weight.data.normal_(0.0, init_scale)

        else:
            self.encoder = nn.Linear(self.input_dim,
                                     hidden_dim,
                                     bias=self.bias)
            self.encoder.weight.data.normal_(0.0, init_scale)

        self.decoder = nn.Linear(hidden_dim, self.input_dim, bias=False)
        self.variational_logvars = nn.Parameter(torch.zeros(hidden_dim))
        self.log_sigma_sq = nn.Parameter(torch.tensor(0.01))
        self.eta = nn.Parameter(torch.zeros(batch_size, self.input_dim))
        self.eta.data.normal_(0.0, init_scale)
        self.decoder.weight.data.normal_(0.0, init_scale)
        # Constant vectors of ones/zeros of latent size, kept as buffers so
        # they follow the module across devices.
        zI = torch.ones(self.hidden_dim).to(self.eta.device)
        zm = torch.zeros(self.hidden_dim).to(self.eta.device)
        self.register_buffer('Psi', Psi)
        self.register_buffer('zI', zI)
        self.register_buffer('zm', zm)
コード例 #2
0
ファイル: util.py プロジェクト: flatironinstitute/catvae
def extract_observation_embeddings(vae_model, tree, return_type='dataframe'):
    """ Extracts observation embeddings from model (i.e. OTUs).

    The observation embeddings are all represented in CLR coordinates.

    Parameters
    ----------
    vae_model : MultVAE
        Pretrained Multinomial VAE
    tree : skbio.TreeNode
        The tree used to train the VAE
    return_type : str
        Options include 'tensor' (alias 'torch'), 'array', 'dataframe'
        (default='dataframe')

    Returns
    -------
    torch.Tensor, np.ndarray or pd.DataFrame
        The decoder loadings multiplied by the transposed balance basis,
        in the container selected by ``return_type``. The dataframe is
        indexed by the tip names of ``tree``.

    Raises
    ------
    ValueError
        If ``return_type`` is not one of the supported options.
    """
    # Validate up front so an unsupported option fails loudly instead of
    # silently returning None after doing all the work.
    if return_type not in ('tensor', 'torch', 'array', 'dataframe'):
        raise ValueError(f'return type {return_type} is not supported.')

    # ILR representation of the VAE decoder loadings
    W = vae_model.vae.decoder.weight
    Psi, _ = sparse_balance_basis(tree)
    # 'torch' is the spelling the code historically accepted, 'tensor' the
    # one the documentation advertises; honor both.
    if return_type in ('tensor', 'torch'):
        indices = np.vstack((Psi.row, Psi.col))
        Psi = torch.sparse_coo_tensor(indices.copy(),
                                      Psi.data.astype(np.float32).copy(),
                                      requires_grad=False).coalesce()
        return Psi.T @ W

    # .cpu() is a no-op for CPU tensors and makes the numpy conversion
    # work when the weights live on another device.
    loadings = Psi.T @ W.detach().cpu().numpy()
    if return_type == 'array':
        return loadings
    names = [n.name for n in tree.tips()]
    return pd.DataFrame(loadings, index=names)
コード例 #3
0
def get_basis(input_dim, basis=None):
    """Return the basis as a coalesced float32 sparse COO torch tensor.

    When ``basis`` is omitted, one is derived from a random tree built by
    ``random_linkage(input_dim)``; otherwise the given COO matrix is
    converted as is.
    """
    if basis is None:
        basis = sparse_balance_basis(random_linkage(input_dim))[0].copy()

    # Row/column coordinates stacked into the 2 x nnz layout torch expects.
    coords = np.vstack((basis.row, basis.col)).copy()
    values = basis.data.astype(np.float32).copy()
    sparse = torch.sparse_coo_tensor(coords, values, requires_grad=False)
    return sparse.coalesce()
コード例 #4
0
    def test_sparse_balance_basis_base_case(self):
        """A two-tip tree gives one balance row, keyed by the parsed node."""
        root = TreeNode.read([u"(a,b);"])

        half = np.sqrt(1. / 2)
        expected_basis = coo_matrix(np.array([[-half, half]]))
        expected_keys = [root.name]

        observed_basis, observed_keys = sparse_balance_basis(root)

        assert_coo_allclose(expected_basis, observed_basis)
        self.assertListEqual(expected_keys, observed_keys)
コード例 #5
0
    def test_sparse_balance_basis_unbalanced(self):
        """Three-tip tree with the cherry on the left: two balance rows."""
        root = TreeNode.read([u"((a,b)c, d);"])

        sixth = np.sqrt(1. / 6)
        half = np.sqrt(1. / 2)
        rows = [
            [-sixth, -sixth, np.sqrt(2. / 3)],
            [-half, half, 0],
        ]
        expected_basis = coo_matrix(np.array(rows))
        expected_keys = [root.name, root[0].name]

        observed_basis, observed_keys = sparse_balance_basis(root)

        assert_coo_allclose(expected_basis, observed_basis)
        self.assertListEqual(expected_keys, observed_keys)
コード例 #6
0
ファイル: test_balances.py プロジェクト: biocore/gneiss
    def test_sparse_balance_basis_unbalanced(self):
        """Check basis and keys for the unbalanced tree ``((a,b)c, d);``."""
        newick = u"((a,b)c, d);"
        parsed = TreeNode.read([newick])

        top_row = [-np.sqrt(1. / 6), -np.sqrt(1. / 6), np.sqrt(2. / 3)]
        bottom_row = [-np.sqrt(1. / 2), np.sqrt(1. / 2), 0]
        want_basis = coo_matrix(np.array([top_row, bottom_row]))
        want_keys = [parsed.name, parsed[0].name]

        got_basis, got_keys = sparse_balance_basis(parsed)

        assert_coo_allclose(want_basis, got_basis)
        self.assertListEqual(want_keys, got_keys)
コード例 #7
0
ファイル: test_balances.py プロジェクト: biocore/gneiss
    def test_sparse_balance_basis_base_case(self):
        """Check basis and keys for the minimal tree ``(a,b);``."""
        newick = u"(a,b);"
        parsed = TreeNode.read([newick])

        only_row = [-np.sqrt(1. / 2), np.sqrt(1. / 2)]
        want_basis = coo_matrix(np.array([only_row]))
        want_keys = [parsed.name]

        got_basis, got_keys = sparse_balance_basis(parsed)

        assert_coo_allclose(want_basis, got_basis)
        self.assertListEqual(want_keys, got_keys)
コード例 #8
0
ファイル: linear_vae.py プロジェクト: mortonjt/metvae
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 init_scale=0.001,
                 use_analytic_elbo=True,
                 encoder_depth=1,
                 likelihood='gaussian',
                 basis=None,
                 bias=False):
        """Set up the basis buffer, encoder, decoder and variance parameters.

        ``basis`` is a sparse COO matrix; when it is omitted a basis is
        derived from a random tree of size ``input_dim``. ``self.input_dim``
        is then taken from the number of rows of the basis tensor.
        ``encoder_depth`` above 1 stacks linear layers separated by
        ``Softplus``; otherwise the encoder is a single linear layer. All
        encoder weights are drawn from N(0, ``init_scale``).
        """
        super(LinearVAE, self).__init__()
        self.bias = bias
        self.hidden_dim = hidden_dim
        self.likelihood = likelihood
        self.use_analytic_elbo = use_analytic_elbo

        if basis is None:
            basis = sparse_balance_basis(random_linkage(input_dim))[0].copy()
        coords = np.vstack((basis.row, basis.col))
        values = basis.data.astype(np.float32)
        Psi = torch.sparse_coo_tensor(coords.copy(), values.copy(),
                                      requires_grad=False)
        self.input_dim = Psi.shape[0]
        self.register_buffer('Psi', Psi)

        if encoder_depth > 1:
            self.first_encoder = nn.Linear(self.input_dim, hidden_dim,
                                           bias=self.bias)
            stack = [self.first_encoder]
            # Each extra level adds an activation followed by a square
            # hidden-to-hidden linear layer.
            for _ in range(encoder_depth - 1):
                stack.extend([
                    nn.Softplus(),
                    nn.Linear(hidden_dim, hidden_dim, bias=self.bias),
                ])
            self.encoder = nn.Sequential(*stack)
            linear_layers = [module for module in self.encoder
                             if isinstance(module, nn.Linear)]
        else:
            self.encoder = nn.Linear(self.input_dim, hidden_dim,
                                     bias=self.bias)
            linear_layers = [self.encoder]

        # Small-scale normal initialization of every encoder weight matrix.
        for linear in linear_layers:
            linear.weight.data.normal_(0.0, init_scale)

        self.decoder = nn.Linear(hidden_dim, self.input_dim, bias=self.bias)
        self.imputer = lambda x: x + 1
        self.variational_logvars = nn.Parameter(torch.zeros(hidden_dim))
        self.log_sigma_sq = nn.Parameter(torch.tensor(0.0))
コード例 #9
0
    def test_sparse_balance_basis_unbalanced2(self):
        """Three-tip tree with the cherry on the right: two balance rows."""
        root = TreeNode.read([u"(d, (a,b)c);"])

        sixth = np.sqrt(1. / 6)
        half = np.sqrt(1. / 2)
        rows = [
            [-np.sqrt(2. / 3), sixth, sixth],
            [0, -half, half],
        ]
        expected_basis = coo_matrix(np.array(rows))
        expected_keys = [root.name, root[1].name]

        observed_basis, observed_keys = sparse_balance_basis(root)

        assert_coo_allclose(expected_basis, observed_basis,
                            atol=1e-7, rtol=1e-7)
        self.assertListEqual(expected_keys, observed_keys)
コード例 #10
0
ファイル: test_balances.py プロジェクト: biocore/gneiss
    def test_sparse_balance_basis_unbalanced2(self):
        """Check basis and keys for the unbalanced tree ``(d, (a,b)c);``."""
        newick = u"(d, (a,b)c);"
        parsed = TreeNode.read([newick])

        top_row = [-np.sqrt(2. / 3), np.sqrt(1. / 6), np.sqrt(1. / 6)]
        bottom_row = [0, -np.sqrt(1. / 2), np.sqrt(1. / 2)]
        want_basis = coo_matrix(np.array([top_row, bottom_row]))
        want_keys = [parsed.name, parsed[1].name]

        got_basis, got_keys = sparse_balance_basis(parsed)

        assert_coo_allclose(want_basis, got_basis, atol=1e-7, rtol=1e-7)
        self.assertListEqual(want_keys, got_keys)
コード例 #11
0
 def test_sparse_balance_basis_invalid(self):
     """A three-way split at the root must be rejected with ValueError."""
     newick = u"(a,b,c);"
     with self.assertRaises(ValueError):
         # Parsing stays inside the context, as in the original test.
         sparse_balance_basis(TreeNode.read([newick]))
コード例 #12
0
ファイル: test_balances.py プロジェクト: biocore/gneiss
 def test_sparse_balance_basis_invalid(self):
     """Expect ValueError for the trifurcating tree ``(a,b,c);``."""
     trifurcating = u"(a,b,c);"
     with self.assertRaises(ValueError):
         parsed = TreeNode.read([trifurcating])
         sparse_balance_basis(parsed)
コード例 #13
0
def ilr_basis(nwk):
    """Build the sparse balance basis for a Newick tree.

    Parameters
    ----------
    nwk
        Tree source handed directly to ``TreeNode.read``.

    Returns
    -------
    The first element returned by ``sparse_balance_basis`` for the
    bifurcated tree.
    """
    tree = TreeNode.read(nwk)
    # Bifurcate a copy so the parsed tree itself is left untouched.
    t = tree.copy()
    t.bifurcate()
    # Use the bifurcated copy: the original code passed ``tree`` here, which
    # discarded the bifurcation performed above.
    basis = sparse_balance_basis(t)[0]
    return basis