def initialize(self, input_dim: int, hidden_dim: int,
               init_scale: float = 0.001,
               basis: coo_matrix = None,
               encoder_depth: int = 1,
               imputer: Callable[[torch.Tensor], torch.Tensor] = None,
               batch_size: int = 10, bias=True):
    """Create the encoder, decoder, variational parameters and buffers.

    Parameters
    ----------
    input_dim : int
        Passed to ``random_linkage`` to build a tree when ``basis`` is None.
        It is not used otherwise: ``self.input_dim`` is always set to the
        number of rows of the basis.
    hidden_dim : int
        Width of the encoder output / decoder input.
    init_scale : float
        Standard deviation of the zero-mean normal used to initialize the
        encoder and decoder weights and ``eta``.
    basis : coo_matrix, optional
        Sparse basis in COO format. When None, one is generated from a
        random tree.
    encoder_depth : int
        Number of linear layers in the encoder. Values above 1 interleave
        the linear layers with ``nn.Softplus``.
    imputer : callable, optional
        Stored as ``self.imputer``. Defaults to adding 1 to its input.
    batch_size : int
        Number of rows of the ``eta`` parameter.
    bias : bool
        Whether the encoder linear layers carry a bias term. The decoder
        never does.
    """
    self.hidden_dim = hidden_dim
    self.bias = bias

    # Psi must be dimension D - 1 x D
    if basis is None:
        tree = random_linkage(input_dim)
        basis = sparse_balance_basis(tree)[0].copy()
    indices = np.vstack((basis.row, basis.col))
    # Pass the size explicitly: torch otherwise infers it from the largest
    # index, which truncates trailing all-zero rows/columns of the basis.
    Psi = torch.sparse_coo_tensor(
        indices.copy(), basis.data.astype(np.float32).copy(),
        size=tuple(basis.shape), requires_grad=False)
    self.input_dim = Psi.shape[0]

    # Always assign the attribute; previously a caller-supplied imputer was
    # ignored and ``self.imputer`` was left undefined.
    self.imputer = (lambda x: x + 1) if imputer is None else imputer

    if encoder_depth > 1:
        self.first_encoder = nn.Linear(
            self.input_dim, hidden_dim, bias=self.bias)
        layers = [self.first_encoder]
        for _ in range(encoder_depth - 1):
            layers.append(nn.Softplus())
            layers.append(
                nn.Linear(hidden_dim, hidden_dim, bias=self.bias))
        self.encoder = nn.Sequential(*layers)

        # initialize
        for encoder_layer in self.encoder:
            if isinstance(encoder_layer, nn.Linear):
                encoder_layer.weight.data.normal_(0.0, init_scale)
    else:
        self.encoder = nn.Linear(
            self.input_dim, hidden_dim, bias=self.bias)
        self.encoder.weight.data.normal_(0.0, init_scale)

    self.decoder = nn.Linear(hidden_dim, self.input_dim, bias=False)
    self.variational_logvars = nn.Parameter(torch.zeros(hidden_dim))
    self.log_sigma_sq = nn.Parameter(torch.tensor(0.01))
    self.eta = nn.Parameter(torch.zeros(batch_size, self.input_dim))
    self.eta.data.normal_(0.0, init_scale)
    self.decoder.weight.data.normal_(0.0, init_scale)

    zI = torch.ones(self.hidden_dim).to(self.eta.device)
    zm = torch.zeros(self.hidden_dim).to(self.eta.device)
    self.register_buffer('Psi', Psi)
    self.register_buffer('zI', zI)
    self.register_buffer('zm', zm)
def extract_observation_embeddings(vae_model, tree, return_type='dataframe'):
    """ Extracts observation embeddings from model (i.e. OTUs).

    The observation embeddings are all represented in CLR coordinates.

    Parameters
    ----------
    vae_model : MultVAE
        Pretrained Multinomial VAE
    tree : skbio.TreeNode
        The tree used to train the VAE
    return_type : str
        Options include 'torch' (alias 'tensor'), 'array', 'dataframe'
        (default='dataframe')

    Returns
    -------
    torch.Tensor, np.ndarray or pd.DataFrame
        The product ``Psi.T @ W`` in the requested container. The dataframe
        is indexed by the names of the tips of ``tree``.

    Raises
    ------
    ValueError
        If ``return_type`` is not one of the supported options.
    """
    # Validate up front so an unsupported option fails loudly instead of
    # silently returning None after doing all the work.
    if return_type not in ('torch', 'tensor', 'array', 'dataframe'):
        raise ValueError(f'return type {return_type} is not supported.')

    # ILR representation of the VAE decoder loadings
    W = vae_model.vae.decoder.weight
    Psi, _ = sparse_balance_basis(tree)

    if return_type in ('torch', 'tensor'):
        indices = np.vstack((Psi.row, Psi.col))
        # Explicit size keeps the tensor shape equal to the basis shape
        # rather than whatever the largest stored index implies.
        Psi = torch.sparse_coo_tensor(
            indices.copy(), Psi.data.astype(np.float32).copy(),
            size=tuple(Psi.shape), requires_grad=False).coalesce()
        return Psi.T @ W

    loadings = Psi.T @ W.detach().numpy()
    if return_type == 'array':
        return loadings

    names = [n.name for n in tree.tips()]
    return pd.DataFrame(loadings, index=names)
def get_basis(input_dim, basis=None):
    """Return the basis as a coalesced float32 torch sparse COO tensor.

    Parameters
    ----------
    input_dim : int
        Passed to ``random_linkage`` to build a tree when ``basis`` is None;
        unused when a basis is supplied.
    basis : scipy.sparse.coo_matrix, optional
        Sparse basis in COO format. When None, one is generated from a
        random tree via ``sparse_balance_basis``.

    Returns
    -------
    torch.Tensor
        Coalesced sparse tensor with the same shape as ``basis``.
    """
    if basis is None:
        tree = random_linkage(input_dim)
        basis = sparse_balance_basis(tree)[0].copy()
    indices = np.vstack((basis.row, basis.col))
    # Give the size explicitly: inferring it from the indices drops any
    # trailing all-zero rows/columns and fails on a basis with no entries.
    Psi = torch.sparse_coo_tensor(
        indices.copy(), basis.data.astype(np.float32).copy(),
        size=tuple(basis.shape), requires_grad=False).coalesce()
    return Psi
def test_sparse_balance_basis_base_case(self):
    """Check the basis and keys produced for the two-tip tree ``(a,b);``."""
    root = TreeNode.read([u"(a,b);"])

    # Single expected row: (-sqrt(1/2), sqrt(1/2)).
    half_root = np.sqrt(1. / 2)
    expected_basis = coo_matrix(np.array([[-half_root, half_root]]))
    expected_keys = [root.name]

    observed_basis, observed_keys = sparse_balance_basis(root)

    assert_coo_allclose(expected_basis, observed_basis)
    self.assertListEqual(expected_keys, observed_keys)
def test_sparse_balance_basis_unbalanced(self):
    """Check the basis and keys produced for the tree ``((a,b)c, d);``."""
    root = TreeNode.read([u"((a,b)c, d);"])

    sixth_root = np.sqrt(1. / 6)
    half_root = np.sqrt(1. / 2)
    # First row is keyed by the root, second by its first child.
    expected_rows = np.array([
        [-sixth_root, -sixth_root, np.sqrt(2. / 3)],
        [-half_root, half_root, 0],
    ])
    expected_basis = coo_matrix(expected_rows)
    expected_keys = [root.name, root[0].name]

    observed_basis, observed_keys = sparse_balance_basis(root)

    assert_coo_allclose(expected_basis, observed_basis)
    self.assertListEqual(expected_keys, observed_keys)
def test_sparse_balance_basis_unbalanced(self):
    """Compare ``sparse_balance_basis`` output for ``((a,b)c, d);``
    against a hand-written two-row basis and the matching node names."""
    newick = u"((a,b)c, d);"
    parsed = TreeNode.read([newick])

    row_for_root = [-np.sqrt(1. / 6), -np.sqrt(1. / 6), np.sqrt(2. / 3)]
    row_for_first_child = [-np.sqrt(1. / 2), np.sqrt(1. / 2), 0]
    want_basis = coo_matrix(np.array([row_for_root, row_for_first_child]))
    want_keys = [parsed.name, parsed[0].name]

    got_basis, got_keys = sparse_balance_basis(parsed)

    assert_coo_allclose(want_basis, got_basis)
    self.assertListEqual(want_keys, got_keys)
def test_sparse_balance_basis_base_case(self):
    """Compare ``sparse_balance_basis`` output for ``(a,b);`` against a
    hand-written one-row basis and the root's name."""
    newick = u"(a,b);"
    parsed = TreeNode.read([newick])

    only_row = [-np.sqrt(1. / 2), np.sqrt(1. / 2)]
    want_basis = coo_matrix(np.array([only_row]))
    want_keys = [parsed.name]

    got_basis, got_keys = sparse_balance_basis(parsed)

    assert_coo_allclose(want_basis, got_basis)
    self.assertListEqual(want_keys, got_keys)
def __init__(self, input_dim, hidden_dim, init_scale=0.001,
             use_analytic_elbo=True, encoder_depth=1,
             likelihood='gaussian', basis=None, bias=False):
    """Build the encoder, decoder and variational parameters.

    Parameters
    ----------
    input_dim : int
        Passed to ``random_linkage`` to build a tree when ``basis`` is None.
        It is not used otherwise: ``self.input_dim`` is always set to the
        number of rows of the basis.
    hidden_dim : int
        Width of the encoder output / decoder input.
    init_scale : float
        Standard deviation of the zero-mean normal used to initialize the
        encoder weights.
    use_analytic_elbo : bool
        Stored on the instance as ``self.use_analytic_elbo``.
    encoder_depth : int
        Number of linear layers in the encoder. Values above 1 interleave
        the linear layers with ``nn.Softplus``.
    likelihood : str
        Stored on the instance as ``self.likelihood``.
    basis : scipy.sparse.coo_matrix, optional
        Sparse basis in COO format. When None, one is generated from a
        random tree.
    bias : bool
        Whether the encoder and decoder linear layers carry a bias term.
    """
    super(LinearVAE, self).__init__()
    self.bias = bias
    self.hidden_dim = hidden_dim
    self.likelihood = likelihood
    self.use_analytic_elbo = use_analytic_elbo

    if basis is None:
        tree = random_linkage(input_dim)
        basis = sparse_balance_basis(tree)[0].copy()
    indices = np.vstack((basis.row, basis.col))
    # Pass the size explicitly: torch otherwise infers it from the largest
    # index, which truncates trailing all-zero rows/columns of the basis.
    Psi = torch.sparse_coo_tensor(
        indices.copy(), basis.data.astype(np.float32).copy(),
        size=tuple(basis.shape), requires_grad=False)
    self.input_dim = Psi.shape[0]
    self.register_buffer('Psi', Psi)

    if encoder_depth > 1:
        self.first_encoder = nn.Linear(
            self.input_dim, hidden_dim, bias=self.bias)
        layers = [self.first_encoder]
        for _ in range(encoder_depth - 1):
            layers.append(nn.Softplus())
            layers.append(
                nn.Linear(hidden_dim, hidden_dim, bias=self.bias))
        self.encoder = nn.Sequential(*layers)

        # initialize
        for encoder_layer in self.encoder:
            if isinstance(encoder_layer, nn.Linear):
                encoder_layer.weight.data.normal_(0.0, init_scale)
    else:
        self.encoder = nn.Linear(
            self.input_dim, hidden_dim, bias=self.bias)
        self.encoder.weight.data.normal_(0.0, init_scale)

    self.decoder = nn.Linear(hidden_dim, self.input_dim, bias=self.bias)
    # Default imputation adds 1 to its input.
    self.imputer = lambda x: x + 1
    self.variational_logvars = nn.Parameter(torch.zeros(hidden_dim))
    self.log_sigma_sq = nn.Parameter(torch.tensor(0.0))
def test_sparse_balance_basis_unbalanced2(self):
    """Check the basis and keys produced for the tree ``(d, (a,b)c);``."""
    root = TreeNode.read([u"(d, (a,b)c);"])

    sixth_root = np.sqrt(1. / 6)
    half_root = np.sqrt(1. / 2)
    # First row is keyed by the root, second by its second child.
    expected_rows = np.array([
        [-np.sqrt(2. / 3), sixth_root, sixth_root],
        [0, -half_root, half_root],
    ])
    expected_basis = coo_matrix(expected_rows)
    expected_keys = [root.name, root[1].name]

    observed_basis, observed_keys = sparse_balance_basis(root)

    assert_coo_allclose(expected_basis, observed_basis,
                        atol=1e-7, rtol=1e-7)
    self.assertListEqual(expected_keys, observed_keys)
def test_sparse_balance_basis_unbalanced2(self):
    """Compare ``sparse_balance_basis`` output for ``(d, (a,b)c);``
    against a hand-written two-row basis and the matching node names."""
    newick = u"(d, (a,b)c);"
    parsed = TreeNode.read([newick])

    row_for_root = [-np.sqrt(2. / 3), np.sqrt(1. / 6), np.sqrt(1. / 6)]
    row_for_second_child = [0, -np.sqrt(1. / 2), np.sqrt(1. / 2)]
    want_basis = coo_matrix(np.array([row_for_root, row_for_second_child]))
    want_keys = [parsed.name, parsed[1].name]

    got_basis, got_keys = sparse_balance_basis(parsed)

    tolerances = {'atol': 1e-7, 'rtol': 1e-7}
    assert_coo_allclose(want_basis, got_basis, **tolerances)
    self.assertListEqual(want_keys, got_keys)
def test_sparse_balance_basis_invalid(self):
    """Expect ``ValueError`` when reading the three-child tree
    ``(a,b,c);`` and building its balance basis."""

    def read_and_build():
        # Parsing stays inside the checked callable, as both steps are
        # covered by the ValueError expectation.
        newick = u"(a,b,c);"
        sparse_balance_basis(TreeNode.read([newick]))

    self.assertRaises(ValueError, read_and_build)
def ilr_basis(nwk):
    """Read a tree and return its sparse balance basis.

    Parameters
    ----------
    nwk
        Tree source handed to ``TreeNode.read``.

    Returns
    -------
    The first element returned by ``sparse_balance_basis`` for the
    bifurcated copy of the tree.
    """
    tree = TreeNode.read(nwk)
    # Work on a copy so the tree as read is left untouched.
    bifurcated = tree.copy()
    bifurcated.bifurcate()
    # Build the basis from the bifurcated copy; previously the original
    # tree was used, which discarded the bifurcation step entirely.
    basis = sparse_balance_basis(bifurcated)[0]
    return basis