Example #1
class NASGrammarModel:
    def __init__(self, grammar, device, hparams=stgs.VAE_HPARAMS):
        """
        Load trained encoder/decoder and grammar model
        :param grammar: A nas_grammar.Grammar object
        :param hparams: dict, hyperparameters for the VAE and the grammar model
        """
        self._grammar = grammar
        self.device = device
        self.hp = hparams
        self.max_len = self.hp['max_len']
        self._productions = self._grammar.GCFG.productions()
        self._prod_map = make_prod_map(grammar.GCFG)
        self._parser = nltk.ChartParser(grammar.GCFG)
        self._tokenize = make_tokenizer(grammar.GCFG)
        self._n_chars = len(self._productions)
        self._lhs_map = grammar.lhs_map
        self.vae = NA_VAE(self.hp)
        self.vae.eval()

    @staticmethod
    def _pop_or_nothing(stack):
        # Pop the top item off the stack, or return 'Nothing' if the stack is empty.
        try:
            return stack.pop()
        except IndexError:
            return 'Nothing'

    @staticmethod
    def _prods_to_sent(prods):
        # Converts a list of production rules into a sentence by repeatedly expanding the left-most non-terminal.
        seq = [prods[0].lhs()]
        for prod in prods:
            if str(prod.lhs()) == 'Nothing':
                break
            for i, s in enumerate(seq):
                if s == prod.lhs():
                    seq = seq[:i] + list(prod.rhs()) + seq[i + 1:]
                    break
        try:
            return ''.join(seq)
        except TypeError:  # seq still contains non-terminals, so the derivation is incomplete
            return ''

    def encode(self, sents):
        """
        Returns the mean of the distribution, which is the predicted latent vector, for a one-hot vector of production
        rules.
        """
        one_hot = make_one_hot(self._grammar.GCFG, self._tokenize,
                               self._prod_map, sents, self.max_len,
                               self._n_chars).transpose(2, 1)  # (batch, n_chars, max_len)
        one_hot = one_hot.to(self.device)
        self.vae.eval()
        with torch.no_grad():
            mu_1, logvar_1, mu_2, logvar_2, z2 = self.vae.encode(one_hot)
            z1 = self.vae.reparameterize(mu_1, logvar_1)
            debed_z1 = self.vae.debed_1(z1)
            debed_z2 = self.vae.debed_2(z2)
            z = torch.cat([debed_z1, debed_z2], dim=1)

        return z, one_hot  # (batch, latent_sz)

    def _sample_using_masks(self, unmasked, logs=True):
        """
        Samples a one-hot vector from unmasked selection, masking at each timestep.
        /!\ This is probably where we will diverge from the Grammar-VAE paper because we need to introduce
        conditions on the nodes that can be selected as input for each layer, and the Agg types also
        depend on the node values selected (Agg == '-' iff ND == '-').
        :param unmasked: The output of the VAE's decoder, so a collection of logit vectors (i.e. before
        softmax); size (batch, timesteps, max_length)
        """
        x_hat = np.zeros_like(unmasked)
        # Create a stack (data structure) for each input in the batch, i.e. each sentence
        S = np.empty((unmasked.shape[0],), dtype=object)  # dimension 0 == number of sentences
        for i in range(S.shape[0]):
            S[i] = [str(self._grammar.start_index)]  # initialise each stack with the start symbol 'S'
        # Loop over time axis, sampling values and updating masks at every step
        for t in range(unmasked.shape[2]):
            next_nonterminal = [
                self._lhs_map[self._pop_or_nothing(a)] for a in S
            ]
            mask = self._grammar.masks[next_nonterminal]  # indices of valid productions for each next symbol
            if logs:
                masked_output = np.multiply(np.exp(unmasked[..., t]), mask) + 1e-100
            else:
                masked_output = np.multiply(unmasked[..., t], mask) + 1e-100
            # This comes from Kusner et al. 2016 - GANs for Sequences of Discrete Elements with the
            # Gumbel-Softmax Distribution, which builds on Jang et al. 2017 - Categorical
            # Reparameterization with Gumbel-Softmax, itself based on the Gumbel-Max trick presented
            # in Maddison et al. 2014 - A* Sampling: y ~ Softmax(h) is equivalent to setting
            # y = one_hot(argmax(h_i + g_i)) where the g_i are sampled independently from Gumbel(0, 1).
            sampled_output = np.argmax(
                np.random.gumbel(size=masked_output.shape) + np.log(masked_output),
                axis=-1)
            # Fill the highest-probability production rule with 1., all others are 0.
            x_hat[np.arange(unmasked.shape[0]), sampled_output, t] = 1.
            # Collect non-terminals in the RHS of the selected production and push them onto the stack in reverse order
            rhs = [
                filter(lambda a: (isinstance(a, nltk.grammar.Nonterminal) and
                                  str(a) != 'None'),
                       self._productions[i].rhs())
                for i in sampled_output
            ]  # one selected production per sentence
            for i in range(S.shape[0]):
                S[i].extend(list(map(str, rhs[i]))[::-1])
            if not S.any(): break  # stop when all stacks are empty
        return x_hat

    def decode(self, z=None, one_hot=None):
        """
        Sample from the grammar decoder using the CFG-based mask, and return a sequence of production rules.
        :param z: latent vector representing the sentence, of dimensions (batch, latent_sz). If None, must provide
        argument one_hot (for testing purposes, mainly).
        :param one_hot: If provided, decode the one_hot matrix directly instead of the decoded latent vector.
        """
        if z is None:  # testing purposes
            unmasked = one_hot
            logs = False
        else:  # normal regime
            logs = True
            # assert z.ndim == 2  # (batch, latent_sz)
            self.vae.eval()
            with torch.no_grad():
                unmasked = self.vae.decode(z).detach().cpu().numpy()  # (batch, n_chars, max_len)
        assert unmasked.shape[1:] == (self._n_chars, self.max_len), \
            f'unmasked.shape[1:] == {unmasked.shape[1:]}, expected {(self._n_chars, self.max_len)}.'
        x_hat = self._sample_using_masks(unmasked, logs)
        # Convert from one-hot to sequence of production rules
        prod_seq = [[
            self._productions[x_hat[index, :, t].argmax()]
            for t in range(x_hat.shape[2])
        ] for index in range(x_hat.shape[0])]
        return [self._prods_to_sent(prods) for prods in prod_seq], x_hat
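
A minimal usage sketch for the class above, assuming a trained NA_VAE checkpoint and a nas_grammar.Grammar instance are available; the grammar construction, weights path and sentence below are placeholders, not the project's actual API:

import torch
import nas_grammar  # module providing the Grammar object (per the docstring above)

grammar = nas_grammar.Grammar()  # placeholder: actual constructor arguments may differ
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NASGrammarModel(grammar, device)
model.vae.load_state_dict(torch.load('vae_weights.pt', map_location=device))  # assumed checkpoint path

z, one_hot = model.encode(['<some grammar sentence>'])  # z: (batch, latent_sz)
sents, x_hat = model.decode(z)  # decoded sentences + sampled one-hot productions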
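
The Gumbel-Max identity invoked in _sample_using_masks can be checked numerically with a standalone sketch (all names here are illustrative): sampling argmax(h + g) with g ~ Gumbel(0, 1) should reproduce the softmax probabilities of the logits h.

import numpy as np

rng = np.random.default_rng(0)
h = np.array([1.0, 0.5, -1.0])  # arbitrary logits
p = np.exp(h) / np.exp(h).sum()  # softmax probabilities
draws = np.argmax(h + rng.gumbel(size=(100_000, 3)), axis=-1)
freq = np.bincount(draws, minlength=3) / draws.size
print(p)     # target distribution Softmax(h)
print(freq)  # empirical frequencies; should closely match p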
Example #2
class IntegratedPredictor(pl.LightningModule):
    def __init__(
        self,
        grammar_mdl: NASGrammarModel,
        pretrained_vae: bool,
        freeze_vae: bool,
        vae_hparams: dict = stgs.VAE_HPARAMS,
        pred_hparams: dict = stgs.PRED_HPARAMS,
    ):
        super().__init__()
        # keep the separate hyperparameter dicts for convenience
        self.vae_hparams, self.pred_hparams = vae_hparams, pred_hparams

        # merge both hyperparameter sets into a single Namespace (with prefixed keys) for compatibility with PL:
        self.hparams = {}
        for k, v in vae_hparams.items():
            self.hparams['_'.join(['vae', k])] = v
        for k, v in pred_hparams.items():
            self.hparams['_'.join(['pred', k])] = v
        self.hparams = Namespace(**self.hparams)
        self.vae = NA_VAE(self.vae_hparams)
        if pretrained_vae:
            self.vae.load_state_dict(torch.load(self.hparams.vae_weights_path))
        if freeze_vae:
            self.vae.freeze()
            print('VAE encoder frozen.')
        self.predictor = PerfPredictor(grammar_mdl, self.pred_hparams)

    def forward(self, batch) -> torch.Tensor:
        one_hot, n_layers = batch
        mu_1, logvar_1, mu_2, logvar_2, z2 = self.vae.encode(
            one_hot.squeeze(1))
        z1 = self.vae.reparameterize(mu_1, logvar_1)
        debed_z1 = self.vae.debed_1(z1)
        debed_z2 = self.vae.debed_2(z2)
        z = torch.cat([debed_z1, debed_z2], dim=1)
        return self.predictor(z)

    def mixup_inputs(self, batch, alpha=1.0):
        """
        Returns mixed pairs of inputs and targets within a batch, and a lambda value sampled from a beta
        distribution.
        """
        one_hot, n_layers, y_true = batch
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1.

        idx = torch.randperm(y_true.size(0))
        mixed_onehot = lam * one_hot + (1 - lam) * one_hot[idx, ...]
        mixed_n_layers = lam * n_layers + (1 - lam) * n_layers[idx]
        return mixed_onehot, mixed_n_layers, y_true, y_true[idx], lam

    def mixup_criterion(self, criterion, y_pred, y_true_a, y_true_b, lam):
        return lam * criterion(y_pred, y_true_a) + (1 - lam) * criterion(
            y_pred, y_true_b)

    def loss_function(self, y_pred: torch.Tensor,
                      y_true: torch.Tensor) -> torch.Tensor:
        return self.predictor.loss_function(y_pred, y_true)

    def configure_optimizers(self):
        return self.predictor.configure_optimizers()

    def training_step(self, batch, batch_idx):
        one_hot, n_layers, y_true = batch
        if one_hot.dim() < 3: one_hot.unsqueeze_(0)
        if n_layers.dim() < 2: n_layers.unsqueeze_(0)
        if y_true.dim() < 2: y_true.unsqueeze_(0)

        if self.hparams.pred_mixup:
            one_hot, n_layers, y_true_a, y_true_b, lam = self.mixup_inputs(
                batch, self.hparams.pred_mixup_alpha)

        y_pred = self.forward((one_hot, n_layers))

        if self.hparams.pred_mixup:
            loss_val = self.mixup_criterion(self.loss_function, y_pred,
                                            y_true_a, y_true_b, lam)
        else:
            loss_val = self.loss_function(y_pred, y_true)

        lr = torch.tensor(
            self.predictor.optim.param_groups[0]["lr"]).type_as(loss_val)
        step = torch.tensor(self.global_step).type_as(lr)
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val.unsqueeze_(0)
            lr.unsqueeze_(0)
            step.unsqueeze_(0)
        logs = {"loss": loss_val.sqrt(), "lr": lr, "step": step}
        p_bar = {"global_step": step}
        return {
            "loss": loss_val.sqrt(),
            "lr": lr,
            "log": logs,
            "global_step": step,
            "progress_bar": p_bar,
        }

    def test_step(self, batch, batch_idx):
        one_hot, n_layers, y_true = batch
        y_pred = self.forward((one_hot, n_layers))
        loss_val = self.loss_function(y_pred, y_true)
        return {'test_loss': loss_val.sqrt()}

    def test_epoch_end(self, outputs):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'test_loss': test_loss_mean}

    def prepare_data(self):
        try:
            tr_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "train.csv",
                grammar_mdl=self.predictor.grammar_mdl)
            self.tst_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "test.csv",
                grammar_mdl=self.predictor.grammar_mdl)
        except FileNotFoundError:  # no train/test split available; fall back to the single fitness batch
            tr_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "fitnessbatch.csv",
                grammar_mdl=self.predictor.grammar_mdl,
            )
        if self.hparams.pred_val_set_pct > 0.0:
            val_len = int(self.hparams.pred_val_set_pct * len(tr_set))
            self.tr_set, self.val_set = random_split(
                tr_set, [len(tr_set) - val_len, val_len])
        else:
            self.tr_set, self.val_set = tr_set, None

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.tr_set,
            batch_size=self.hparams.pred_batch_sz,
            shuffle=True,
            drop_last=False,
            num_workers=self.hparams.pred_num_workers,
        )

    def test_dataloader(self):
        return DataLoader(self.tst_set,
                          batch_size=1,
                          shuffle=False,
                          drop_last=False,
                          num_workers=self.hparams.pred_num_workers)
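
A minimal training sketch for the integrated predictor, assuming a NASGrammarModel built as in Example #1; the Trainer arguments are illustrative and follow the older pytorch_lightning API that this class targets (use_dp/use_ddp2):

import pytorch_lightning as pl

predictor = IntegratedPredictor(grammar_mdl=grammar_model,  # a NASGrammarModel, as in Example #1
                                pretrained_vae=True,        # loads weights from hparams.vae_weights_path
                                freeze_vae=True)
trainer = pl.Trainer(max_epochs=10, gpus=1)  # illustrative arguments
trainer.fit(predictor)   # prepare_data() builds the train/val splits
trainer.test(predictor)  # evaluates on test.csv via test_dataloader()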