# Train the VAE, save its weights, then reload them for evaluation.
trainer = pl.Trainer(
    gpus=-1,
    val_check_interval=9999,
    early_stop_callback=None,
    distributed_backend=None,
    logger=logger,
    max_steps=max_steps,
    max_epochs=max_steps,
    checkpoint_callback=checkpoint,
    weights_save_path=checkpoint_path,
)
trainer.fit(vae)
torch.save(vae.state_dict(), f'{checkpoint_path}/weights_256.pt')

# Rebuild the VAE's sentence generator with a fresh seed, then reload the
# weights just saved and switch to evaluation mode on the GPU.
vae.get_data_generator(min_depth, max_depth, seed=int(time()))
# checkpoint = torch.load(cwd + '/base_ckpt/full_grammar.hdf5',
#                         map_location=lambda storage, loc: storage)
# vae.load_state_dict(checkpoint['state_dict'])
vae.load_state_dict(torch.load(f'{checkpoint_path}/weights_256.pt'))
vae = vae.cuda()
vae.eval()

# Wrap the grammar and the trained VAE in a NASGrammarModel.
grammar_mdl = NASGrammarModel(grammar, device='cuda')
grammar_mdl.vae = vae

# Sample one batch of 200 grammar-generated sentences and move it to the GPU.
orig_sents = []
torch.cuda.empty_cache()
gen = SentenceGenerator(grammar.GCFG, min_depth, max_depth, batch_size=200)
dataloader = DataLoader(gen, batch_size=1)
batch = next(iter(dataloader))
orig_one_hots, lens = batch
print(orig_one_hots.size())
orig_one_hots, lens = orig_one_hots.to('cuda'), lens.to('cuda')
orig_sents.extend(gen.sents)
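# Optional sanity check (sketch, not part of the original script): run the
# sampled batch through the trained encoder and inspect the latent shapes.
# Assumes the DataLoader's leading singleton batch dimension can simply be
# squeezed away before calling vae.encode(), mirroring the squeeze used in
# IntegratedPredictor.forward below.
with torch.no_grad():
    mu_1, logvar_1, mu_2, logvar_2, z2 = vae.encode(orig_one_hots.squeeze(0))
    print(mu_1.size(), z2.size())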
class IntegratedPredictor(pl.LightningModule):
    """Performance predictor operating on latents from the (optionally frozen) VAE encoder."""

    def __init__(
        self,
        grammar_mdl: NASGrammarModel,
        pretrained_vae: bool,
        freeze_vae: bool,
        vae_hparams: dict = stgs.VAE_HPARAMS,
        pred_hparams: dict = stgs.PRED_HPARAMS,
    ):
        super().__init__()
        # Keep separate hyperparameter dicts for convenience...
        self.vae_hparams, self.pred_hparams = vae_hparams, pred_hparams
        # ...and a combined, prefixed Namespace for compatibility with PL.
        self.hparams = {}
        for k, v in vae_hparams.items():
            self.hparams['_'.join(['vae', k])] = v
        for k, v in pred_hparams.items():
            self.hparams['_'.join(['pred', k])] = v
        self.hparams = Namespace(**self.hparams)

        self.vae = NA_VAE(self.vae_hparams)
        if pretrained_vae:
            self.vae.load_state_dict(torch.load(self.hparams.vae_weights_path))
            if freeze_vae:
                self.vae.freeze()
                print('VAE encoder frozen.')
        self.predictor = PerfPredictor(grammar_mdl, self.pred_hparams)

    def forward(self, batch) -> torch.Tensor:
        one_hot, n_layers = batch
        # Encode the one-hot sentence, then concatenate the de-embedded
        # level-1 and level-2 latents before passing them to the predictor.
        mu_1, logvar_1, mu_2, logvar_2, z2 = self.vae.encode(one_hot.squeeze(1))
        z1 = self.vae.reparameterize(mu_1, logvar_1)
        debed_z1 = self.vae.debed_1(z1)
        debed_z2 = self.vae.debed_2(z2)
        z = torch.cat([debed_z1, debed_z2], dim=1)
        return self.predictor(z)

    def mixup_inputs(self, batch, alpha=1.0):
        """
        Returns mixed pairs of inputs and targets within a batch, together with
        a lambda value sampled from a Beta(alpha, alpha) distribution.
        """
        one_hot, n_layers, y_true = batch
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1.
        idx = torch.randperm(y_true.size(0))
        mixed_onehot = lam * one_hot + (1 - lam) * one_hot[idx, ...]
        mixed_n_layers = lam * n_layers + (1 - lam) * n_layers[idx]
        return mixed_onehot, mixed_n_layers, y_true, y_true[idx], lam

    def mixup_criterion(self, criterion, y_pred, y_true_a, y_true_b, lam):
        return lam * criterion(y_pred, y_true_a) \
            + (1 - lam) * criterion(y_pred, y_true_b)

    def loss_function(self, y_pred: torch.Tensor,
                      y_true: torch.Tensor) -> torch.Tensor:
        return self.predictor.loss_function(y_pred, y_true)

    def configure_optimizers(self):
        return self.predictor.configure_optimizers()

    def training_step(self, batch, batch_idx):
        one_hot, n_layers, y_true = batch
        # Make sure single samples still carry a batch dimension.
        if one_hot.dim() < 3:
            one_hot.unsqueeze_(0)
        if n_layers.dim() < 2:
            n_layers.unsqueeze_(0)
        if y_true.dim() < 2:
            y_true.unsqueeze_(0)
        if self.hparams.pred_mixup:
            one_hot, n_layers, y_true_a, y_true_b, lam = self.mixup_inputs(
                batch, self.hparams.pred_mixup_alpha)
        y_pred = self.forward((one_hot, n_layers))
        if self.hparams.pred_mixup:
            loss_val = self.mixup_criterion(self.loss_function, y_pred,
                                            y_true_a, y_true_b, lam)
        else:
            loss_val = self.loss_function(y_pred, y_true)
        lr = torch.tensor(
            self.predictor.optim.param_groups[0]["lr"]).type_as(loss_val)
        step = torch.tensor(self.global_step).type_as(lr)
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val.unsqueeze_(0)
            lr.unsqueeze_(0)
            step.unsqueeze_(0)
        logs = {"loss": loss_val.sqrt(), "lr": lr, "step": step}
        p_bar = {"global_step": step}
        return {
            "loss": loss_val.sqrt(),
            "lr": lr,
            "log": logs,
            "global_step": step,
            "progress_bar": p_bar,
        }

    def test_step(self, batch, batch_idx):
        one_hot, n_layers, y_true = batch
        y_pred = self.forward((one_hot, n_layers))
        loss_val = self.loss_function(y_pred, y_true)
        return {'test_loss': loss_val.sqrt()}

    def test_epoch_end(self, outputs):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'test_loss': test_loss_mean}

    def prepare_data(self):
        try:
            tr_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "train.csv",
                grammar_mdl=self.predictor.grammar_mdl)
            self.tst_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "test.csv",
                grammar_mdl=self.predictor.grammar_mdl)
        except:
            # Fall back to the single fitness batch when the train/test csv
            # files are unavailable (no separate test set in this case).
            tr_set = SentenceRetrieverWithVaeTraining(
                stgs.PRED_BATCH_PATH / "fitnessbatch.csv",
                grammar_mdl=self.predictor.grammar_mdl,
            )
        if self.hparams.pred_val_set_pct > 0.0:
            # Carve a validation split off the training set.
            val_len = int(self.hparams.pred_val_set_pct * len(tr_set))
            self.tr_set, self.val_set = random_split(
                tr_set, [len(tr_set) - val_len, val_len])
        else:
            self.tr_set, self.val_set = tr_set, None

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.tr_set,
            batch_size=self.hparams.pred_batch_sz,
            shuffle=True,
            drop_last=False,
            num_workers=self.hparams.pred_num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.tst_set,
            batch_size=1,
            shuffle=False,
            drop_last=False,
            num_workers=self.hparams.pred_num_workers,
        )
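# Usage sketch (not part of the original code): train the integrated predictor
# on top of the frozen, pretrained VAE built above. The trainer arguments here
# are placeholders; this assumes stgs.PRED_BATCH_PATH holds the expected csv
# files and that VAE_HPARAMS['weights_path'] points at the weights saved earlier.
predictor = IntegratedPredictor(grammar_mdl, pretrained_vae=True, freeze_vae=True)
pred_trainer = pl.Trainer(gpus=1, max_epochs=100, logger=logger)
pred_trainer.fit(predictor)
pred_trainer.test(predictor)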