def __init__(self, grammar, device, hparams=stgs.VAE_HPARAMS):
    """
    Load trained encoder/decoder and grammar model.

    :param grammar: a nas_grammar.Grammar object
    :param device: device on which to run the VAE (e.g. a torch.device)
    :param hparams: dict, hyperparameters for the VAE and the grammar model
    """
    self._grammar = grammar
    self.device = device
    self.hp = hparams
    self.max_len = self.hp['max_len']
    self._productions = self._grammar.GCFG.productions()
    self._prod_map = make_prod_map(grammar.GCFG)
    self._parser = nltk.ChartParser(grammar.GCFG)
    self._tokenize = make_tokenizer(grammar.GCFG)
    self._n_chars = len(self._productions)
    self._lhs_map = grammar.lhs_map
    self.vae = NA_VAE(self.hp)
    self.vae.eval()
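# Usage sketch (illustrative, not from the original source): one way this
# wrapper might be constructed and pointed at a trained checkpoint. The class
# name `ModelWrapper`, the checkpoint path, and the `Grammar(...)` constructor
# call are assumptions; `torch.load` / `load_state_dict` are standard PyTorch.
#
#     import torch
#     from nas_grammar import Grammar
#
#     grammar = Grammar(...)  # project-specific construction, elided
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     wrapper = ModelWrapper(grammar, device)        # hypothetical class name
#     state = torch.load('checkpoints/vae.pt', map_location=device)
#     wrapper.vae.load_state_dict(state)             # weights now ready for eval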
def __init__(self, cfg, min_sample_depth, max_sample_depth, batch_size=256, seed=0):
    """
    :param cfg: an nltk.CFG object
    :param min_sample_depth: minimum depth of the sequences to sample
    :param max_sample_depth: maximum depth of the sequences to sample
    :param batch_size: number of sentences per batch
    :param seed: seed for the random number generator
    """
    super().__init__()
    random.seed(seed)
    self.cfg = cfg
    self.bsz = batch_size
    self.tokenizer = make_tokenizer(cfg)
    self.prod_map = make_prod_map(cfg)
    self.weighted_sampling = stgs.VAE_HPARAMS['weighted_sampling']
    self.temp = stgs.VAE_HPARAMS['temperature']
    # Min and max depths of the sequences to sample.
    self.min_sample_depth, self.max_sample_depth = min_sample_depth, max_sample_depth
    if self.weighted_sampling:
        # Deeper sequences get exponentially larger, temperature-scaled
        # weights (assumes max_sample_depth > min_sample_depth).
        len_range = max_sample_depth - min_sample_depth
        self.probs = np.array([
            np.exp(self.temp * n / len_range)
            for n in range(min_sample_depth, max_sample_depth + 1)
        ])
        self.probs /= self.probs.sum()
    # Max possible length of a sequence.
    self.max_len = stgs.VAE_HPARAMS['max_len']
    # Symbol representing a new layer.
    self.lay_symb = stgs.VAE_HPARAMS['layer_symbol']
    self.n_chars = len(cfg.productions())
    print(f'Grammar with {self.n_chars} productions.')
    self.sents = []
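# Usage sketch (illustrative, not from the original source): one way the
# weighted depth distribution computed above could be consumed when sampling.
# The method name `_draw_depth` is hypothetical; `np.random.choice` and
# `random.randint` are standard NumPy / stdlib calls.
#
#     def _draw_depth(self):
#         if self.weighted_sampling:
#             depths = np.arange(self.min_sample_depth, self.max_sample_depth + 1)
#             return int(np.random.choice(depths, p=self.probs))
#         return random.randint(self.min_sample_depth, self.max_sample_depth)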