def __init__(
    self,
    grammar_mdl: NASGrammarModel,
    pretrained_vae: bool,
    freeze_vae: bool,
    vae_hparams: dict = stgs.VAE_HPARAMS,
    pred_hparams: dict = stgs.PRED_HPARAMS,
):
    """Assemble the VAE and the performance predictor.

    :param grammar_mdl: grammar model passed straight to ``PerfPredictor``.
    :param pretrained_vae: if True, load VAE weights with ``torch.load`` from
        the path stored under ``'weights_path'`` in ``vae_hparams``.
    :param freeze_vae: if True, call ``freeze()`` on the VAE after it is built.
    :param vae_hparams: dict of VAE hyperparameters
        (default: ``stgs.VAE_HPARAMS``, stored by reference, not copied).
    :param pred_hparams: dict of predictor hyperparameters
        (default: ``stgs.PRED_HPARAMS``, stored by reference, not copied).
    """
    super().__init__()

    # Keep each hyperparameter group available as its own dict.
    self.vae_hparams = vae_hparams
    self.pred_hparams = pred_hparams

    # Also expose one flat Namespace (for PyTorch Lightning compatibility):
    # every key is prefixed with the name of the group it came from.
    flat_hparams = {}
    for prefix, group in (('vae', vae_hparams), ('pred', pred_hparams)):
        for name, value in group.items():
            flat_hparams[prefix + '_' + name] = value
    self.hparams = Namespace(**flat_hparams)

    self.vae = NA_VAE(self.vae_hparams)
    if pretrained_vae:
        weights = torch.load(self.hparams.vae_weights_path)
        self.vae.load_state_dict(weights)
    if freeze_vae:
        # NOTE(review): the message mentions only the encoder, but freeze() is
        # called on the whole VAE object -- confirm what NA_VAE.freeze covers.
        self.vae.freeze()
        print('VAE encoder frozen.')

    self.predictor = PerfPredictor(grammar_mdl, self.pred_hparams)
def __init__(self, grammar, device, hparams=stgs.VAE_HPARAMS):
    """
    Set up grammar lookup tables and an NA_VAE model in eval mode.

    No weights are loaded here: ``self.vae`` is freshly constructed from
    ``hparams`` and switched to eval mode.
    NOTE(review): an earlier docstring said "Load trained encoder/decoder";
    confirm whether trained weights are loaded by the caller.

    :param grammar: A nas_grammar.Grammar object; its ``GCFG`` and ``lhs_map``
        attributes are read.
    :param device: stored on ``self.device``; not otherwise used here.
    :param hparams: dict, hyperparameters for the VAE and the grammar model;
        must contain ``'max_len'``. Stored by reference, not copied.
    """
    # Keep the constructor arguments.
    self._grammar = grammar
    self.device = device
    self.hp = hparams
    self.max_len = hparams['max_len']

    # Grammar-derived helpers, all built from the same context-free grammar.
    gcfg = grammar.GCFG
    self._productions = gcfg.productions()
    self._prod_map = make_prod_map(gcfg)
    self._parser = nltk.ChartParser(gcfg)
    self._tokenize = make_tokenizer(gcfg)
    self._n_chars = len(self._productions)  # one entry per production rule
    self._lhs_map = grammar.lhs_map

    # Fresh model from the hyperparameters, put into evaluation mode.
    model = NA_VAE(self.hp)
    self.vae = model
    model.eval()
t_var = globals()[t] if isinstance(t_var, (torch.Tensor, np.ndarray)): print(t, ':', t_var.shape) elif isinstance(t_var, (list, dict, tuple, str)): print(t, ':', len(t_var)) if value: print(t_var) else: pass min_depth = 3 max_depth = stgs.VAE_HPARAMS['max_depth'] # maximum network depth print(f'Using maximum sequence length of {stgs.VAE_HPARAMS["max_len"]}.') torch.cuda.empty_cache() vae = NA_VAE(stgs.VAE_HPARAMS) vae = vae.float() vae = vae.cuda() # vae.load_state_dict(torch.load(f'{checkpoint_path}/weights.pt')) torch.cuda.empty_cache() version = datetime.strftime(datetime.fromtimestamp(seed), '%Y-%m-%d..%H.%M.%S') logger = TensorBoardLogger(checkpoint_path, version=version) checkpoint = ModelCheckpoint(filepath=checkpoint_path, save_top_k=1, verbose=True, monitor='loss', mode='min') early_stop = EarlyStopping( monitor='loss',