コード例 #1
0
 def __getitem__(self,
                 idx: int) -> Tuple[torch.Tensor, torch.Tensor, float]:
     """Return the encoded sample stored at positional row ``idx``.

     Column 0 of ``self.data`` is read as the sentence and column 1 as its
     fitness value.

     Args:
         idx: Positional row index into ``self.data`` (used with ``iloc``).

     Returns:
         A tuple ``(one_hot, n_layers, fitness)``:
         - ``one_hot``: the ``make_one_hot`` encoding of the sentence (as a
           one-element batch) with dims 1 and 2 swapped.
         - ``n_layers``: float32 scalar tensor equal to
           ``(number of '/' in the sentence + 1) / stgs.PRED_HPARAMS['max_depth']``.
         - ``fitness``: the raw column-1 value of the row.
     """
     sentence = self.data.iloc[idx, 0]
     fitness = self.data.iloc[idx, 1]
     sent = [sentence]  # one-element batch for make_one_hot
     one_hot = make_one_hot(self.gr_mdl._grammar.GCFG,
                            self.gr_mdl._tokenize, self.gr_mdl._prod_map,
                            sent, self.gr_mdl.max_len,
                            self.gr_mdl._n_chars).transpose(2, 1)
     # Count '/' inside the sentence itself. Calling .count on the wrapping
     # list would only match an element equal to '/', which yields 0 for
     # virtually every sentence and makes n_layers a constant.
     n_layers = torch.tensor(
         (sentence.count('/') + 1) * 1.0 / stgs.PRED_HPARAMS['max_depth'],
         dtype=torch.float32,
         requires_grad=False)
     return one_hot, n_layers, fitness
コード例 #2
0
 def generate_one_hots(self):
     """Sample a batch of sentences and return their one-hot encodings.

     Calls ``self.generate_sentence()`` ``self.bsz`` times. Element 0 of each
     result is stored (in order) on ``self.sents``; element 1 is collected as
     that sentence's length.

     Returns:
         A pair ``(one_hots, lengths)``: the ``make_one_hot`` encoding of
         ``self.sents`` with its last two dimensions swapped, and a float32
         tensor of the collected lengths.
     """
     samples = [self.generate_sentence() for _ in range(self.bsz)]

     sentences = []
     lengths = []
     for sample in samples:
         sentences.append(sample[0])
         lengths.append(sample[1])
     self.sents = sentences

     encoded = make_one_hot(self.cfg, self.tokenizer, self.prod_map,
                            self.sents,
                            max_len=self.max_len, n_chars=self.n_chars)
     one_hots = encoded.transpose(-2, -1)
     length_tensor = torch.tensor(lengths, dtype=torch.float32)
     return one_hots, length_tensor
コード例 #3
0
ファイル: NASGrammarModel.py プロジェクト: zhenglidreams/VAES
 def encode(self, sents):
     """Encode sentences to the VAE's first encoder output, tiled per sample.

     The sentences are one-hot encoded over the grammar's production rules
     with ``make_one_hot`` (dims 1 and 2 swapped) and moved to
     ``self.device``. With the VAE in eval mode and gradients disabled, the
     first output of ``self.vae.encode`` (presumably the posterior mean) is
     repeated ``self.num_samples`` times along a new leading axis, which is
     then permuted into second position.

     Args:
         sents: Sentences accepted by ``make_one_hot``.

     Returns:
         ``(z, one_hot)``. Assuming the encoder output is
         ``(batch, latent_sz)`` (TODO confirm), ``z`` is
         ``(batch, num_samples, latent_sz)``; ``one_hot`` is the encoded
         input tensor on ``self.device``.
     """
     encoded = make_one_hot(self._grammar.GCFG, self._tokenize,
                            self._prod_map, sents, self.max_len,
                            self._n_chars)
     one_hot = encoded.transpose(2, 1).to(self.device)
     self.vae.eval()
     with torch.no_grad():
         mean = self.vae.encode(one_hot)[0]
         tiled = mean.repeat(self.num_samples, 1, 1)  # (num_samples, ...) leading axis
         z = tiled.permute(1, 0, 2)  # move the sample axis to position 1
     return z, one_hot
コード例 #4
0
 def encode(self, sents):
     """Encode sentences into latents via the VAE's reparameterization step.

     The sentences are one-hot encoded over the grammar's production rules
     with ``make_one_hot`` (dims 1 and 2 swapped) and moved to
     ``self.device``. With the VAE in eval mode and gradients disabled, the
     three outputs of ``self.vae.encode`` are passed to
     ``self.vae.reparameterize``. ``z`` is therefore the reparameterized
     result, not the encoder's first output (``mu``) directly.
     NOTE(review): if ``reparameterize`` draws noise, ``z`` is stochastic even
     in eval mode; confirm against the VAE implementation.

     Args:
         sents: Sentences accepted by ``make_one_hot``.

     Returns:
         ``(z, one_hot)``: the reparameterized latent (presumably
         ``(batch, latent_sz)`` — TODO confirm) and the encoded input tensor
         on ``self.device``.
     """
     encoded = make_one_hot(
         self._grammar.GCFG,
         self._tokenize,
         self._prod_map,
         sents,
         self.max_len,
         self._n_chars,
     )
     one_hot = encoded.transpose(2, 1).to(self.device)
     self.vae.eval()
     with torch.no_grad():
         mean, log_var, extra = self.vae.encode(one_hot)
         z = self.vae.reparameterize(mean, log_var, extra)
     return z, one_hot