class NASGrammarModel:
    def __init__(self, grammar, device, hparams=stgs.VAE_HPARAMS):
        """
        Load trained encoder/decoder and grammar model.
        :param grammar: A nas_grammar.Grammar object
        :param device: device on which the VAE tensors are placed
        :param hparams: dict, hyperparameters for the VAE and the grammar model
        """
        self._grammar = grammar
        self.device = device
        self.hp = hparams
        self.max_len = self.hp['max_len']
        self._productions = self._grammar.GCFG.productions()
        self._prod_map = make_prod_map(grammar.GCFG)
        self._parser = nltk.ChartParser(grammar.GCFG)
        self._tokenize = make_tokenizer(grammar.GCFG)
        self._n_chars = len(self._productions)
        self._lhs_map = grammar.lhs_map
        self.vae = NA_VAE(self.hp)
        self.vae.eval()

    @staticmethod
    def _pop_or_nothing(stack):
        # Try to pop the item at the top of stack S; return 'Nothing' if the stack is empty.
        try:
            return stack.pop()
        except IndexError:
            return 'Nothing'

    @staticmethod
    def _prods_to_sent(prods):
        # Convert a list of productions into a sentence.
        seq = [prods[0].lhs()]
        for prod in prods:
            if str(prod.lhs()) == 'Nothing':
                break
            for i, s in enumerate(seq):
                if s == prod.lhs():
                    seq = seq[:i] + list(prod.rhs()) + seq[i + 1:]
                    break
        try:
            return ''.join(seq)
        except TypeError:  # non-terminals left in the sequence: no valid sentence
            return ''

    def encode(self, sents):
        """
        Returns the mean of the distribution, which is the predicted latent vector, for a one-hot
        vector of production rules.
        """
        one_hot = make_one_hot(self._grammar.GCFG, self._tokenize, self._prod_map, sents,
                               self.max_len, self._n_chars).transpose(2, 1)  # (1, batch, max_len, n_chars)
        one_hot = one_hot.to(self.device)
        self.vae.eval()
        with torch.no_grad():
            mu_1, logvar_1, mu_2, logvar_2, z2 = self.vae.encode(one_hot)
            z1 = self.vae.reparameterize(mu_1, logvar_1)
            debed_z1 = self.vae.debed_1(z1)
            debed_z2 = self.vae.debed_2(z2)
            z = torch.cat([debed_z1, debed_z2], dim=1)
        return z, one_hot  # (batch, latent_sz)

    def _sample_using_masks(self, unmasked, logs=True):
        """
        Samples a one-hot vector from the unmasked selection, masking at each timestep.
        /!\ This is probably where we will diverge from the Grammar-VAE paper, because we need to
        introduce conditions on the nodes that can be selected as input for each layer, and the Agg
        types also depend on the node values selected (Agg == '-' iff ND == '-').
        :param unmasked: the output of the VAE's decoder, i.e. a collection of logit vectors
            (before softmax); size (batch, n_chars, max_len)
        """
        x_hat = np.zeros_like(unmasked)
        # Create a stack (data structure) for each input in the batch, i.e. each sentence.
        S = np.empty((unmasked.shape[0],), dtype=object)  # dimension 0 == number of sentences
        for i in range(S.shape[0]):
            S[i] = [str(self._grammar.start_index)]  # initialise each stack with the start symbol 'S'
        # Loop over the time axis, sampling values and updating masks at every step.
        for t in range(unmasked.shape[2]):
            next_nonterminal = [self._lhs_map[self._pop_or_nothing(a)] for a in S]
            mask = self._grammar.masks[next_nonterminal]  # indices of valid productions for the next symbol
            if logs:
                masked_output = np.multiply(np.exp(unmasked[..., t]), mask) + 1e-100
            else:
                masked_output = np.multiply(unmasked[..., t], mask) + 1e-100
            # This comes from Kusner et al. 2016 - GANs for Sequences of Discrete Elements with the
            # Gumbel-Softmax Distribution, using work done in Jang et al. 2017 - Categorical
            # Reparameterization with Gumbel-Softmax, which itself uses the Gumbel-Max trick presented
            # in Maddison et al. 2014 - A* Sampling.
            # y ~ Softmax(h) is equivalent to setting y = one_hot(argmax(h_i + g_i)), where the g_i
            # are independently sampled from Gumbel(0, 1).
            sampled_output = np.argmax(
                np.add(np.random.gumbel(size=masked_output.shape), np.log(masked_output)),
                axis=-1)
            # Set the sampled production rule to 1.; all others stay 0.
            x_hat[np.arange(unmasked.shape[0]), sampled_output, t] = 1.
            # Collect non-terminals in the RHS of the selected production and push them onto the
            # stack in reverse order.
            rhs = [
                filter(lambda a: (isinstance(a, nltk.grammar.Nonterminal) and (str(a) != 'None')),
                       self._productions[i].rhs())
                for i in sampled_output
            ]  # single output per sentence
            for i in range(S.shape[0]):
                S[i].extend(list(map(str, rhs[i]))[::-1])
            if not S.any():
                break  # stop when all stacks are empty
        return x_hat

    def decode(self, z=None, one_hot=None):
        """
        Sample from the grammar decoder using the CFG-based mask, and return a sequence of
        production rules.
        :param z: latent vector representing the sentence, of dimensions (batch, latent_sz). If
            None, the one_hot argument must be provided (mainly for testing purposes).
        :param one_hot: if provided, decode this one-hot matrix directly instead of a decoded
            latent vector.
        """
        if z is None:  # testing purposes
            unmasked = one_hot
            logs = False
        else:  # normal regime
            logs = True
            # assert z.ndim == 2  # (batch, latent_sz)
            self.vae.eval()
            with torch.no_grad():
                unmasked = self.vae.decode(z).detach().cpu().numpy()  # (batch, n_chars, max_len)
        assert unmasked.shape[1:] == (self._n_chars, self.max_len), \
            f'unmasked.shape[1:] == {unmasked.shape[1:]}, expected {(self._n_chars, self.max_len)}.'
        x_hat = self._sample_using_masks(unmasked, logs)
        # Convert from one-hot to a sequence of production rules.
        prod_seq = [[self._productions[x_hat[index, :, t].argmax()]
                     for t in range(x_hat.shape[2])]
                    for index in range(x_hat.shape[0])]
        return [self._prods_to_sent(prods) for prods in prod_seq], x_hat
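
# Illustrative usage sketch (not part of the original API): a minimal encode/decode round trip
# through the grammar model. It assumes `grammar` is a nas_grammar.Grammar instance, `sentences`
# is a list of grammar-valid architecture strings, and that trained VAE weights have already been
# loaded into `model.vae`.
def _example_roundtrip(grammar, sentences, device=torch.device('cpu')):
    model = NASGrammarModel(grammar, device)
    z, one_hot = model.encode(sentences)  # z: (batch, latent_sz)
    decoded, x_hat = model.decode(z)      # grammar-valid sentences + sampled one-hot rules
    return decoded
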
class IntegratedPredictor(pl.LightningModule):
    def __init__(
        self,
        grammar_mdl: NASGrammarModel,
        pretrained_vae: bool,
        freeze_vae: bool,
        vae_hparams: dict = stgs.VAE_HPARAMS,
        pred_hparams: dict = stgs.PRED_HPARAMS,
    ):
        super().__init__()
        # Keep separate hyperparameter dicts for convenience.
        self.vae_hparams, self.pred_hparams = vae_hparams, pred_hparams
        # Build a combined Namespace of prefixed hyperparameters for compatibility with PL.
        self.hparams = {}
        for k, v in vae_hparams.items():
            self.hparams['_'.join(['vae', k])] = v
        for k, v in pred_hparams.items():
            self.hparams['_'.join(['pred', k])] = v
        self.hparams = Namespace(**self.hparams)
        self.vae = NA_VAE(self.vae_hparams)
        if pretrained_vae:
            self.vae.load_state_dict(torch.load(self.hparams.vae_weights_path))
        if freeze_vae:
            self.vae.freeze()
            print('VAE encoder frozen.')
        self.predictor = PerfPredictor(grammar_mdl, self.pred_hparams)

    def forward(self, batch) -> torch.Tensor:
        one_hot, n_layers = batch
        mu_1, logvar_1, mu_2, logvar_2, z2 = self.vae.encode(one_hot.squeeze(1))
        z1 = self.vae.reparameterize(mu_1, logvar_1)
        debed_z1 = self.vae.debed_1(z1)
        debed_z2 = self.vae.debed_2(z2)
        z = torch.cat([debed_z1, debed_z2], dim=1)
        return self.predictor(z)

    def mixup_inputs(self, batch, alpha=1.0):
        """
        Returns mixed pairs of inputs and targets within a batch, and a lambda value sampled from
        a beta distribution.
        """
        one_hot, n_layers, y_true = batch
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1.
        idx = torch.randperm(y_true.size(0))
        mixed_onehot = lam * one_hot + (1 - lam) * one_hot[idx, ...]
        mixed_n_layers = lam * n_layers + (1 - lam) * n_layers[idx]
        return mixed_onehot, mixed_n_layers, y_true, y_true[idx], lam

    def mixup_criterion(self, criterion, y_pred, y_true_a, y_true_b, lam):
        return lam * criterion(y_pred, y_true_a) + (1 - lam) * criterion(y_pred, y_true_b)

    def loss_function(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        return self.predictor.loss_function(y_pred, y_true)

    def configure_optimizers(self):
        return self.predictor.configure_optimizers()

    def training_step(self, batch, batch_idx):
        one_hot, n_layers, y_true = batch
        if one_hot.dim() < 3:
            one_hot.unsqueeze_(0)
        if n_layers.dim() < 2:
            n_layers.unsqueeze_(0)
        if y_true.dim() < 2:
            y_true.unsqueeze_(0)
        if self.hparams.pred_mixup:
            one_hot, n_layers, y_true_a, y_true_b, lam = self.mixup_inputs(
                batch, self.hparams.pred_mixup_alpha)
        y_pred = self.forward((one_hot, n_layers))
        if self.hparams.pred_mixup:
            loss_val = self.mixup_criterion(self.loss_function, y_pred, y_true_a, y_true_b, lam)
        else:
            loss_val = self.loss_function(y_pred, y_true)
        lr = torch.tensor(self.predictor.optim.param_groups[0]["lr"]).type_as(loss_val)
        step = torch.tensor(self.global_step).type_as(lr)
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val.unsqueeze_(0)
            lr.unsqueeze_(0)
            step.unsqueeze_(0)
        logs = {"loss": loss_val.sqrt(), "lr": lr, "step": step}
        p_bar = {"global_step": step}
        return {
            "loss": loss_val.sqrt(),
            "lr": lr,
            "log": logs,
            "global_step": step,
            "progress_bar": p_bar,
        }

    def test_step(self, batch, batch_idx):
        one_hot, n_layers, y_true = batch
        y_pred = self.forward((one_hot, n_layers))
        loss_val = self.loss_function(y_pred, y_true)
        return {'test_loss': loss_val.sqrt()}

    def test_epoch_end(self, outputs):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'test_loss': test_loss_mean}

    def prepare_data(self):
        try:
            tr_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "train.csv",
                grammar_mdl=self.predictor.grammar_mdl)
            self.tst_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "test.csv",
                grammar_mdl=self.predictor.grammar_mdl)
        except FileNotFoundError:
            # Fall back to the single fitness batch when no train/test split is present.
            tr_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "fitnessbatch.csv",
                grammar_mdl=self.predictor.grammar_mdl,
            )
        if self.hparams.pred_val_set_pct > 0.0:
            val_len = int(self.hparams.pred_val_set_pct * len(tr_set))
            self.tr_set, self.val_set = random_split(
                tr_set, [len(tr_set) - val_len, val_len])
        else:
            self.tr_set, self.val_set = tr_set, None

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.tr_set,
            batch_size=self.hparams.pred_batch_sz,
            shuffle=True,
            drop_last=False,
            num_workers=self.hparams.pred_num_workers,
        )

    def test_dataloader(self):
        return DataLoader(self.tst_set,
                          batch_size=1,
                          shuffle=False,
                          drop_last=False,
                          num_workers=self.hparams.pred_num_workers)