Example 1
    def __init__(self, model):

        # vocabulary
        self.code_vocab = utils.load_vocab_pk(config.code_vocab_path)
        self.code_vocab_size = len(self.code_vocab)
        self.ast_vocab = utils.load_vocab_pk(config.ast_vocab_path)
        self.ast_vocab_size = len(self.ast_vocab)
        self.nl_vocab = utils.load_vocab_pk(config.nl_vocab_path)
        self.nl_vocab_size = len(self.nl_vocab)

        # dataset
        self.dataset = data.CodePtrDataset(code_path=config.test_code_path,
                                           ast_path=config.test_sbt_path,
                                           nl_path=config.test_nl_path)
        self.dataset_size = len(self.dataset)
        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=config.test_batch_size,
            collate_fn=lambda *args: utils.unsort_collate_fn(
                args,
                code_vocab=self.code_vocab,
                ast_vocab=self.ast_vocab,
                nl_vocab=self.nl_vocab,
                raw_nl=True))

        # model
        if isinstance(model, str):
            self.model = models.Model(code_vocab_size=self.code_vocab_size,
                                      ast_vocab_size=self.ast_vocab_size,
                                      nl_vocab_size=self.nl_vocab_size,
                                      model_file_path=os.path.join(
                                          config.model_dir, model),
                                      is_eval=True)
        elif isinstance(model, dict):
            self.model = models.Model(code_vocab_size=self.code_vocab_size,
                                      ast_vocab_size=self.ast_vocab_size,
                                      nl_vocab_size=self.nl_vocab_size,
                                      model_state_dict=model,
                                      is_eval=True)
        else:
            raise TypeError(
                "Parameter 'model' for class 'Test' must be a file name or a state_dict of the model."
            )
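
A minimal usage sketch of this constructor (the class name Test comes from the error message above; the checkpoint file name is a hypothetical placeholder):

import os
import torch

# 'best_model.pt' is a placeholder; Test resolves it against config.model_dir
test = Test('best_model.pt')

# alternatively, pass an in-memory state_dict instead of a file name
state_dict = torch.load(os.path.join(config.model_dir, 'best_model.pt'))
test = Test(state_dict)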
Example 2
    def __init__(self, model, vocab):

        assert isinstance(model, (dict, str))
        assert isinstance(vocab, (tuple, str))

        # dataset
        logger.info('-' * 100)
        logger.info('Loading test dataset')
        self.dataset = data.CodePtrDataset(mode='test')
        self.dataset_size = len(self.dataset)
        logger.info('Size of test dataset: {}'.format(self.dataset_size))

        logger.info('The dataset is successfully loaded')

        # note: the vocab attributes referenced inside collate_fn are resolved
        # lazily, so the dataloader can be built before the vocabs are loaded below
        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=config.test_batch_size,
            collate_fn=lambda *args: utils.collate_fn(
                args,
                source_vocab=self.source_vocab,
                code_vocab=self.code_vocab,
                ast_vocab=self.ast_vocab,
                nl_vocab=self.nl_vocab,
                raw_nl=True))

        # vocab
        logger.info('-' * 100)
        if isinstance(vocab, tuple):
            logger.info('Vocabularies are passed from parameters')
            assert len(vocab) == 4
            self.source_vocab, self.code_vocab, self.ast_vocab, self.nl_vocab = vocab
        else:
            logger.info('Vocabularies are read from dir: {}'.format(vocab))
            self.source_vocab = utils.load_vocab(vocab, 'source')
            self.code_vocab = utils.load_vocab(vocab, 'code')
            self.ast_vocab = utils.load_vocab(vocab, 'ast')
            self.nl_vocab = utils.load_vocab(vocab, 'nl')

        # vocabulary
        self.source_vocab_size = len(self.source_vocab)
        self.code_vocab_size = len(self.code_vocab)
        self.ast_vocab_size = len(self.ast_vocab)
        self.nl_vocab_size = len(self.nl_vocab)

        logger.info('Size of source vocabulary: {} -> {}'.format(
            self.source_vocab.origin_size, self.source_vocab_size))
        logger.info('Size of code vocabulary: {} -> {}'.format(
            self.code_vocab.origin_size, self.code_vocab_size))
        logger.info('Size of ast vocabulary: {}'.format(self.ast_vocab_size))
        logger.info('Size of nl vocabulary: {} -> {}'.format(
            self.nl_vocab.origin_size, self.nl_vocab_size))

        logger.info('Vocabularies are successfully built')

        # model
        logger.info('-' * 100)
        logger.info('Building model')
        self.model = models.Model(source_vocab_size=self.source_vocab_size,
                                  code_vocab_size=self.code_vocab_size,
                                  ast_vocab_size=self.ast_vocab_size,
                                  nl_vocab_size=self.nl_vocab_size,
                                  is_eval=True,
                                  model=model)
        # model device
        logger.info('Model device: {}'.format(next(self.model.parameters()).device))
        # log model statistic
        logger.info('Trainable parameters: {}'.format(utils.human_format(utils.count_params(self.model))))
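
A sketch of the two accepted forms of the vocab parameter, matching the assertions and branches above (the model name, Vocab objects, and directory path are placeholders):

# vocabularies passed directly as a 4-tuple of Vocab objects
test = Test(model='best_model.pt',
            vocab=(source_vocab, code_vocab, ast_vocab, nl_vocab))

# or read from a vocabulary directory via utils.load_vocab
test = Test(model='best_model.pt', vocab='vocab_dir')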
Example 3
    def __init__(self, vocab_file_path=None, model_file_path=None):
        """

        :param vocab_file_path: tuple of code vocab, ast vocab, nl vocab, if given, build vocab by given path
        :param model_file_path:
        """

        # dataset
        self.train_dataset = data.CodePtrDataset(
            code_path=config.train_code_path,
            ast_path=config.train_sbt_path,
            nl_path=config.train_nl_path)
        self.train_dataset_size = len(self.train_dataset)
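        # the vocab attributes referenced inside collate_fn are resolved lazily,
        # so the dataloader can be built before the vocabs are created below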
        self.train_dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            collate_fn=lambda *args: utils.unsort_collate_fn(
                args,
                code_vocab=self.code_vocab,
                ast_vocab=self.ast_vocab,
                nl_vocab=self.nl_vocab))

        # vocab
        self.code_vocab: utils.Vocab
        self.ast_vocab: utils.Vocab
        self.nl_vocab: utils.Vocab
        # load vocab from given path
        if vocab_file_path:
            code_vocab_path, ast_vocab_path, nl_vocab_path = vocab_file_path
            self.code_vocab = utils.load_vocab_pk(code_vocab_path)
            self.ast_vocab = utils.load_vocab_pk(ast_vocab_path)
            self.nl_vocab = utils.load_vocab_pk(nl_vocab_path)
        # new vocab
        else:
            self.code_vocab = utils.Vocab('code_vocab')
            self.ast_vocab = utils.Vocab('ast_vocab')
            self.nl_vocab = utils.Vocab('nl_vocab')
            codes, asts, nls = self.train_dataset.get_dataset()
            for code, ast, nl in zip(codes, asts, nls):
                self.code_vocab.add_sentence(code)
                self.ast_vocab.add_sentence(ast)
                self.nl_vocab.add_sentence(nl)

            self.origin_code_vocab_size = len(self.code_vocab)
            self.origin_nl_vocab_size = len(self.nl_vocab)

            # trim vocabulary
            self.code_vocab.trim(config.code_vocab_size)
            self.nl_vocab.trim(config.nl_vocab_size)
            # save vocabulary
            self.code_vocab.save(config.code_vocab_path)
            self.ast_vocab.save(config.ast_vocab_path)
            self.nl_vocab.save(config.nl_vocab_path)
            self.code_vocab.save_txt(config.code_vocab_txt_path)
            self.ast_vocab.save_txt(config.ast_vocab_txt_path)
            self.nl_vocab.save_txt(config.nl_vocab_txt_path)

        self.code_vocab_size = len(self.code_vocab)
        self.ast_vocab_size = len(self.ast_vocab)
        self.nl_vocab_size = len(self.nl_vocab)

        # model
        self.model = models.Model(code_vocab_size=self.code_vocab_size,
                                  ast_vocab_size=self.ast_vocab_size,
                                  nl_vocab_size=self.nl_vocab_size,
                                  model_file_path=model_file_path)
        self.params = list(self.model.code_encoder.parameters()) + \
            list(self.model.ast_encoder.parameters()) + \
            list(self.model.reduce_hidden.parameters()) + \
            list(self.model.decoder.parameters())

        # optimizer
        self.optimizer = Adam([
            {
                'params': self.model.code_encoder.parameters(),
                'lr': config.code_encoder_lr
            },
            {
                'params': self.model.ast_encoder.parameters(),
                'lr': config.ast_encoder_lr
            },
            {
                'params': self.model.reduce_hidden.parameters(),
                'lr': config.reduce_hidden_lr
            },
            {
                'params': self.model.decoder.parameters(),
                'lr': config.decoder_lr
            },
        ], betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

        if config.use_lr_decay:
            self.lr_scheduler = lr_scheduler.StepLR(
                self.optimizer,
                step_size=config.lr_decay_every,
                gamma=config.lr_decay_rate)

        # best score and model (state dict)
        self.min_loss: float = float('inf')
        self.best_model: dict = {}
        self.best_epoch_batch = (None, None)  # (epoch, batch) of the best model

        # eval instance
        self.eval_instance = eval.Eval(self.get_cur_state_dict())

        # early stopping
        self.early_stopping = None
        if config.use_early_stopping:
            self.early_stopping = utils.EarlyStopping()

        config.model_dir = os.path.join(config.model_dir,
                                        utils.get_timestamp())
        if not os.path.exists(config.model_dir):
            os.makedirs(config.model_dir)
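
A usage sketch of this trainer; the class name Train and the file names are assumptions for illustration:

# fresh run: build vocabularies from the training set and a new model
trainer = Train()

# resume: load pickled vocabularies and saved model weights (paths are placeholders)
trainer = Train(vocab_file_path=('code_vocab.pk', 'ast_vocab.pk', 'nl_vocab.pk'),
                model_file_path='best_model.pt')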
Example 4
    def __init__(self):

        # dataset
        logger.info('-' * 100)
        logger.info('Loading training and validation datasets')
        self.dataset = data.CodePtrDataset(mode='train')
        self.dataset_size = len(self.dataset)
        logger.info('Size of training dataset: {}'.format(self.dataset_size))
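        # the vocab attributes referenced inside collate_fn are resolved lazily,
        # so both dataloaders can be built before the vocabularies are built below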
        self.dataloader = DataLoader(dataset=self.dataset,
                                     batch_size=config.batch_size,
                                     shuffle=True,
                                     collate_fn=lambda *args: utils.collate_fn(
                                         args,
                                         source_vocab=self.source_vocab,
                                         code_vocab=self.code_vocab,
                                         ast_vocab=self.ast_vocab,
                                         nl_vocab=self.nl_vocab))

        # valid dataset
        self.valid_dataset = data.CodePtrDataset(mode='valid')
        self.valid_dataset_size = len(self.valid_dataset)
        self.valid_dataloader = DataLoader(
            dataset=self.valid_dataset,
            batch_size=config.valid_batch_size,
            collate_fn=lambda *args: utils.collate_fn(
                args,
                source_vocab=self.source_vocab,
                code_vocab=self.code_vocab,
                ast_vocab=self.ast_vocab,
                nl_vocab=self.nl_vocab))
        logger.info('Size of validation dataset: {}'.format(
            self.valid_dataset_size))
        logger.info('The datasets are successfully loaded')

        # vocab
        logger.info('-' * 100)
        logger.info('Building vocabularies')

        sources, codes, asts, nls = self.dataset.get_dataset()

        self.source_vocab = utils.build_word_vocab(
            dataset=sources,
            vocab_name='source',
            ignore_case=True,
            max_vocab_size=config.source_vocab_size,
            save_dir=config.vocab_root)
        self.source_vocab_size = len(self.source_vocab)
        logger.info('Size of source vocab: {} -> {}'.format(
            self.source_vocab.origin_size, self.source_vocab_size))

        self.code_vocab = utils.build_word_vocab(
            dataset=codes,
            vocab_name='code',
            ignore_case=True,
            max_vocab_size=config.code_vocab_size,
            save_dir=config.vocab_root)
        self.code_vocab_size = len(self.code_vocab)
        logger.info('Size of code vocab: {} -> {}'.format(
            self.code_vocab.origin_size, self.code_vocab_size))

        self.ast_vocab = utils.build_word_vocab(dataset=asts,
                                                vocab_name='ast',
                                                ignore_case=True,
                                                save_dir=config.vocab_root)
        self.ast_vocab_size = len(self.ast_vocab)
        logger.info('Size of ast vocab: {}'.format(self.ast_vocab_size))

        self.nl_vocab = utils.build_word_vocab(
            dataset=nls,
            vocab_name='nl',
            ignore_case=True,
            max_vocab_size=config.nl_vocab_size,
            save_dir=config.vocab_root)
        self.nl_vocab_size = len(self.nl_vocab)
        logger.info('Size of nl vocab: {} -> {}'.format(
            self.nl_vocab.origin_size, self.nl_vocab_size))

        logger.info('Vocabularies are successfully built')

        # model
        logger.info('-' * 100)
        logger.info('Building the model')
        self.model = models.Model(source_vocab_size=self.source_vocab_size,
                                  code_vocab_size=self.code_vocab_size,
                                  ast_vocab_size=self.ast_vocab_size,
                                  nl_vocab_size=self.nl_vocab_size)
        # model device
        logger.info('Model device: {}'.format(
            next(self.model.parameters()).device))
        # log model statistic
        logger.info('Trainable parameters: {}'.format(
            utils.human_format(utils.count_params(self.model))))

        # optimizer
        self.optimizer = Adam([
            {
                'params': self.model.parameters(),
                'lr': config.learning_rate
            },
        ])

        self.criterion = nn.CrossEntropyLoss(
            ignore_index=self.nl_vocab.get_pad_index())

        if config.use_lr_decay:
            self.lr_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                    step_size=1,
                                                    gamma=config.lr_decay_rate)

        # early stopping
        self.early_stopping = None
        if config.use_early_stopping:
            self.early_stopping = utils.EarlyStopping(
                patience=config.early_stopping_patience, high_record=False)
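
The CrossEntropyLoss above skips padded target positions; a self-contained sketch of that behavior (the pad index 0 is illustrative, the real value comes from nl_vocab.get_pad_index()):

import torch
import torch.nn as nn

PAD = 0  # illustrative pad index
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

logits = torch.randn(4, 10)               # (num_tokens, vocab_size)
targets = torch.tensor([3, PAD, 7, PAD])  # PAD positions contribute no loss
loss = criterion(logits, targets)         # averaged over non-pad targets only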