def __init__(self, model):
    # vocabulary
    self.code_vocab = utils.load_vocab_pk(config.code_vocab_path)
    self.code_vocab_size = len(self.code_vocab)
    self.ast_vocab = utils.load_vocab_pk(config.ast_vocab_path)
    self.ast_vocab_size = len(self.ast_vocab)
    self.nl_vocab = utils.load_vocab_pk(config.nl_vocab_path)
    self.nl_vocab_size = len(self.nl_vocab)

    # dataset
    self.dataset = data.CodePtrDataset(code_path=config.test_code_path,
                                       ast_path=config.test_sbt_path,
                                       nl_path=config.test_nl_path)
    self.dataset_size = len(self.dataset)
    self.dataloader = DataLoader(dataset=self.dataset,
                                 batch_size=config.test_batch_size,
                                 collate_fn=lambda *args: utils.unsort_collate_fn(args,
                                                                                  code_vocab=self.code_vocab,
                                                                                  ast_vocab=self.ast_vocab,
                                                                                  nl_vocab=self.nl_vocab,
                                                                                  raw_nl=True))

    # model
    if isinstance(model, str):
        self.model = models.Model(code_vocab_size=self.code_vocab_size,
                                  ast_vocab_size=self.ast_vocab_size,
                                  nl_vocab_size=self.nl_vocab_size,
                                  model_file_path=os.path.join(config.model_dir, model),
                                  is_eval=True)
    elif isinstance(model, dict):
        self.model = models.Model(code_vocab_size=self.code_vocab_size,
                                  ast_vocab_size=self.ast_vocab_size,
                                  nl_vocab_size=self.nl_vocab_size,
                                  model_state_dict=model,
                                  is_eval=True)
    else:
        raise Exception('Parameter \'model\' for class \'Test\' must be file name or state_dict of the model.')
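# A minimal, self-contained sketch (not the project's code) of the str-vs-dict dispatch
# used above, on a toy module with plain torch.save / torch.load; the file name
# 'toy_model.pt' and the helper load_model are illustrative assumptions:
import torch
import torch.nn as nn


def load_model(model, net):
    # accept either a checkpoint file path or an in-memory state_dict
    if isinstance(model, str):
        net.load_state_dict(torch.load(model, map_location='cpu'))
    elif isinstance(model, dict):
        net.load_state_dict(model)
    else:
        raise TypeError('model must be a checkpoint file name or a state_dict')
    net.eval()
    return net


net = nn.Linear(4, 2)
torch.save(net.state_dict(), 'toy_model.pt')
load_model('toy_model.pt', nn.Linear(4, 2))    # load weights from a file
load_model(net.state_dict(), nn.Linear(4, 2))  # or pass a state_dict directly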
def __init__(self, model, vocab):
    assert isinstance(model, dict) or isinstance(model, str)
    assert isinstance(vocab, tuple) or isinstance(vocab, str)

    # dataset
    logger.info('-' * 100)
    logger.info('Loading test dataset')
    self.dataset = data.CodePtrDataset(mode='test')
    self.dataset_size = len(self.dataset)
    logger.info('Size of test dataset: {}'.format(self.dataset_size))
    logger.info('The dataset is successfully loaded')
    self.dataloader = DataLoader(dataset=self.dataset,
                                 batch_size=config.test_batch_size,
                                 collate_fn=lambda *args: utils.collate_fn(args,
                                                                           source_vocab=self.source_vocab,
                                                                           code_vocab=self.code_vocab,
                                                                           ast_vocab=self.ast_vocab,
                                                                           nl_vocab=self.nl_vocab,
                                                                           raw_nl=True))

    # vocab
    logger.info('-' * 100)
    if isinstance(vocab, tuple):
        logger.info('Vocabularies are passed from parameters')
        assert len(vocab) == 4
        self.source_vocab, self.code_vocab, self.ast_vocab, self.nl_vocab = vocab
    else:
        logger.info('Vocabularies are read from dir: {}'.format(vocab))
        self.source_vocab = utils.load_vocab(vocab, 'source')
        self.code_vocab = utils.load_vocab(vocab, 'code')
        self.ast_vocab = utils.load_vocab(vocab, 'ast')
        self.nl_vocab = utils.load_vocab(vocab, 'nl')

    # vocabulary sizes
    self.source_vocab_size = len(self.source_vocab)
    self.code_vocab_size = len(self.code_vocab)
    self.ast_vocab_size = len(self.ast_vocab)
    self.nl_vocab_size = len(self.nl_vocab)
    logger.info('Size of source vocabulary: {} -> {}'.format(self.source_vocab.origin_size, self.source_vocab_size))
    logger.info('Size of code vocabulary: {} -> {}'.format(self.code_vocab.origin_size, self.code_vocab_size))
    logger.info('Size of ast vocabulary: {}'.format(self.ast_vocab_size))
    logger.info('Size of nl vocabulary: {} -> {}'.format(self.nl_vocab.origin_size, self.nl_vocab_size))
    logger.info('Vocabularies are successfully loaded')

    # model
    logger.info('-' * 100)
    logger.info('Building model')
    self.model = models.Model(source_vocab_size=self.source_vocab_size,
                              code_vocab_size=self.code_vocab_size,
                              ast_vocab_size=self.ast_vocab_size,
                              nl_vocab_size=self.nl_vocab_size,
                              is_eval=True,
                              model=model)
    # model device
    logger.info('Model device: {}'.format(next(self.model.parameters()).device))
    # log model statistics
    logger.info('Trainable parameters: {}'.format(utils.human_format(utils.count_params(self.model))))
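# The project's utils.count_params and utils.human_format are not shown in this
# snippet; the stand-ins below are an assumption about their behavior, given only
# as a self-contained illustration of the statistics logged above:
import torch.nn as nn


def count_params(model):
    # number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def human_format(num):
    # e.g. 1001000 -> '1.00M'
    for unit in ['', 'K', 'M', 'B']:
        if abs(num) < 1000:
            return '{:.2f}{}'.format(num, unit)
        num /= 1000.0
    return '{:.2f}T'.format(num)


print(human_format(count_params(nn.Linear(1000, 1000))))  # '1.00M'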
def __init__(self, vocab_file_path=None, model_file_path=None):
    """
    :param vocab_file_path: tuple of (code vocab path, ast vocab path, nl vocab path); if given,
                            load the vocabularies from these paths instead of building them
    :param model_file_path: if given, load the model state from this file
    """
    # dataset
    self.train_dataset = data.CodePtrDataset(code_path=config.train_code_path,
                                             ast_path=config.train_sbt_path,
                                             nl_path=config.train_nl_path)
    self.train_dataset_size = len(self.train_dataset)
    self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                       batch_size=config.batch_size,
                                       shuffle=True,
                                       collate_fn=lambda *args: utils.unsort_collate_fn(args,
                                                                                        code_vocab=self.code_vocab,
                                                                                        ast_vocab=self.ast_vocab,
                                                                                        nl_vocab=self.nl_vocab))

    # vocab
    self.code_vocab: utils.Vocab
    self.ast_vocab: utils.Vocab
    self.nl_vocab: utils.Vocab
    # load vocab from given path
    if vocab_file_path:
        code_vocab_path, ast_vocab_path, nl_vocab_path = vocab_file_path
        self.code_vocab = utils.load_vocab_pk(code_vocab_path)
        self.ast_vocab = utils.load_vocab_pk(ast_vocab_path)
        self.nl_vocab = utils.load_vocab_pk(nl_vocab_path)
    # otherwise, build a new vocab from the training set
    else:
        self.code_vocab = utils.Vocab('code_vocab')
        self.ast_vocab = utils.Vocab('ast_vocab')
        self.nl_vocab = utils.Vocab('nl_vocab')
        codes, asts, nls = self.train_dataset.get_dataset()
        for code, ast, nl in zip(codes, asts, nls):
            self.code_vocab.add_sentence(code)
            self.ast_vocab.add_sentence(ast)
            self.nl_vocab.add_sentence(nl)

        self.origin_code_vocab_size = len(self.code_vocab)
        self.origin_nl_vocab_size = len(self.nl_vocab)

        # trim vocabulary
        self.code_vocab.trim(config.code_vocab_size)
        self.nl_vocab.trim(config.nl_vocab_size)
        # save vocabulary
        self.code_vocab.save(config.code_vocab_path)
        self.ast_vocab.save(config.ast_vocab_path)
        self.nl_vocab.save(config.nl_vocab_path)
        self.code_vocab.save_txt(config.code_vocab_txt_path)
        self.ast_vocab.save_txt(config.ast_vocab_txt_path)
        self.nl_vocab.save_txt(config.nl_vocab_txt_path)

    self.code_vocab_size = len(self.code_vocab)
    self.ast_vocab_size = len(self.ast_vocab)
    self.nl_vocab_size = len(self.nl_vocab)

    # model
    self.model = models.Model(code_vocab_size=self.code_vocab_size,
                              ast_vocab_size=self.ast_vocab_size,
                              nl_vocab_size=self.nl_vocab_size,
                              model_file_path=model_file_path)
    self.params = list(self.model.code_encoder.parameters()) + \
        list(self.model.ast_encoder.parameters()) + \
        list(self.model.reduce_hidden.parameters()) + \
        list(self.model.decoder.parameters())

    # optimizer
    self.optimizer = Adam([
        {'params': self.model.code_encoder.parameters(), 'lr': config.code_encoder_lr},
        {'params': self.model.ast_encoder.parameters(), 'lr': config.ast_encoder_lr},
        {'params': self.model.reduce_hidden.parameters(), 'lr': config.reduce_hidden_lr},
        {'params': self.model.decoder.parameters(), 'lr': config.decoder_lr},
    ], betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    if config.use_lr_decay:
        self.lr_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                step_size=config.lr_decay_every,
                                                gamma=config.lr_decay_rate)

    # best score and model (state dict)
    self.min_loss: float = 1000
    self.best_model: dict = {}
    self.best_epoch_batch: (int, int) = (None, None)

    # eval instance
    self.eval_instance = eval.Eval(self.get_cur_state_dict())

    # early stopping
    self.early_stopping = None
    if config.use_early_stopping:
        self.early_stopping = utils.EarlyStopping()

    config.model_dir = os.path.join(config.model_dir, utils.get_timestamp())
    if not os.path.exists(config.model_dir):
        os.makedirs(config.model_dir)
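# A minimal, self-contained sketch of the per-module learning-rate pattern used in the
# optimizer above; the modules and learning-rate values are illustrative, not the
# project's config values:
import torch.nn as nn
from torch.optim import Adam, lr_scheduler

code_encoder = nn.Linear(8, 8)
decoder = nn.Linear(8, 8)
optimizer = Adam([
    {'params': code_encoder.parameters(), 'lr': 1e-3},
    {'params': decoder.parameters(), 'lr': 5e-4},
], betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
# StepLR multiplies the learning rate of every parameter group by gamma
# after each step_size calls to scheduler.step()
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)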
def __init__(self):
    # dataset
    logger.info('-' * 100)
    logger.info('Loading training and validation dataset')
    self.dataset = data.CodePtrDataset(mode='train')
    self.dataset_size = len(self.dataset)
    logger.info('Size of training dataset: {}'.format(self.dataset_size))
    self.dataloader = DataLoader(dataset=self.dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 collate_fn=lambda *args: utils.collate_fn(args,
                                                                           source_vocab=self.source_vocab,
                                                                           code_vocab=self.code_vocab,
                                                                           ast_vocab=self.ast_vocab,
                                                                           nl_vocab=self.nl_vocab))

    # valid dataset
    self.valid_dataset = data.CodePtrDataset(mode='valid')
    self.valid_dataset_size = len(self.valid_dataset)
    self.valid_dataloader = DataLoader(dataset=self.valid_dataset,
                                       batch_size=config.valid_batch_size,
                                       collate_fn=lambda *args: utils.collate_fn(args,
                                                                                 source_vocab=self.source_vocab,
                                                                                 code_vocab=self.code_vocab,
                                                                                 ast_vocab=self.ast_vocab,
                                                                                 nl_vocab=self.nl_vocab))
    logger.info('Size of validation dataset: {}'.format(self.valid_dataset_size))
    logger.info('The datasets are successfully loaded')

    # vocab
    logger.info('-' * 100)
    logger.info('Building vocabularies')

    sources, codes, asts, nls = self.dataset.get_dataset()

    self.source_vocab = utils.build_word_vocab(dataset=sources,
                                               vocab_name='source',
                                               ignore_case=True,
                                               max_vocab_size=config.source_vocab_size,
                                               save_dir=config.vocab_root)
    self.source_vocab_size = len(self.source_vocab)
    logger.info('Size of source vocab: {} -> {}'.format(self.source_vocab.origin_size, self.source_vocab_size))

    self.code_vocab = utils.build_word_vocab(dataset=codes,
                                             vocab_name='code',
                                             ignore_case=True,
                                             max_vocab_size=config.code_vocab_size,
                                             save_dir=config.vocab_root)
    self.code_vocab_size = len(self.code_vocab)
    logger.info('Size of code vocab: {} -> {}'.format(self.code_vocab.origin_size, self.code_vocab_size))

    self.ast_vocab = utils.build_word_vocab(dataset=asts,
                                            vocab_name='ast',
                                            ignore_case=True,
                                            save_dir=config.vocab_root)
    self.ast_vocab_size = len(self.ast_vocab)
    logger.info('Size of ast vocab: {}'.format(self.ast_vocab_size))

    self.nl_vocab = utils.build_word_vocab(dataset=nls,
                                           vocab_name='nl',
                                           ignore_case=True,
                                           max_vocab_size=config.nl_vocab_size,
                                           save_dir=config.vocab_root)
    self.nl_vocab_size = len(self.nl_vocab)
    logger.info('Size of nl vocab: {} -> {}'.format(self.nl_vocab.origin_size, self.nl_vocab_size))

    logger.info('Vocabularies are successfully built')

    # model
    logger.info('-' * 100)
    logger.info('Building the model')
    self.model = models.Model(source_vocab_size=self.source_vocab_size,
                              code_vocab_size=self.code_vocab_size,
                              ast_vocab_size=self.ast_vocab_size,
                              nl_vocab_size=self.nl_vocab_size)
    # model device
    logger.info('Model device: {}'.format(next(self.model.parameters()).device))
    # log model statistics
    logger.info('Trainable parameters: {}'.format(utils.human_format(utils.count_params(self.model))))

    # optimizer
    self.optimizer = Adam([
        {'params': self.model.parameters(), 'lr': config.learning_rate},
    ])
    self.criterion = nn.CrossEntropyLoss(ignore_index=self.nl_vocab.get_pad_index())

    if config.use_lr_decay:
        self.lr_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                step_size=1,
                                                gamma=config.lr_decay_rate)

    # early stopping
    self.early_stopping = None
    if config.use_early_stopping:
        self.early_stopping = utils.EarlyStopping(patience=config.early_stopping_patience,
                                                  high_record=False)
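# A minimal sketch of how ignore_index keeps padded target positions out of the loss
# computed by the criterion above; the pad index value 0 is an assumption for
# illustration only:
import torch
import torch.nn as nn

pad_index = 0
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
logits = torch.randn(4, 10)                           # 4 decoded positions, vocab size 10
targets = torch.tensor([3, 7, pad_index, pad_index])  # last two positions are padding
loss = criterion(logits, targets)                     # padding contributes nothing to the loss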