Example No. 1
def train(config: DictConfig):
    filter_warnings()
    print_config(config)
    seed_everything(config.seed)

    known_models = {"code2seq": get_code2seq, "code2class": get_code2class, "typed-code2seq": get_typed_code2seq}
    if config.name not in known_models:
        print(f"Unknown model: {config.name}, try on of {known_models.keys()}")

    vocabulary = Vocabulary.load_vocabulary(join(config.data_folder, config.dataset.name, config.vocabulary_name))
    model, data_module = known_models[config.name](config, vocabulary)

    # define logger
    wandb_logger = WandbLogger(
        project=f"{config.name}-{config.dataset.name}", log_model=True, offline=config.log_offline
    )
    wandb_logger.watch(model)
    # define model checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=wandb_logger.experiment.dir,
        filename="{epoch:02d}-{val_loss:.4f}",
        period=config.save_every_epoch,
        save_top_k=-1,
    )
    upload_checkpoint_callback = UploadCheckpointCallback(wandb_logger.experiment.dir)
    # define early stopping callback
    early_stopping_callback = EarlyStopping(
        patience=config.hyper_parameters.patience, monitor="val_loss", verbose=True, mode="min"
    )
    # define callback for printing intermediate result
    print_epoch_result_callback = PrintEpochResultCallback("train", "val")
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateMonitor("step")
    trainer = Trainer(
        max_epochs=config.hyper_parameters.n_epochs,
        gradient_clip_val=config.hyper_parameters.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.val_every_epoch,
        log_every_n_steps=config.log_every_epoch,
        logger=wandb_logger,
        gpus=gpu,
        progress_bar_refresh_rate=config.progress_bar_refresh_rate,
        callbacks=[
            lr_logger,
            early_stopping_callback,
            checkpoint_callback,
            upload_checkpoint_callback,
            print_epoch_result_callback,
        ],
        resume_from_checkpoint=config.resume_from_checkpoint,
    )

    trainer.fit(model=model, datamodule=data_module)
    trainer.test()
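A minimal sketch of how a train entry point like this is usually invoked, assuming the DictConfig is composed by Hydra; the config_path and config_name values below are placeholders, not taken from the original project:

import hydra
from omegaconf import DictConfig


@hydra.main(config_path="configs", config_name="code2seq")  # hypothetical config location
def main(config: DictConfig):
    # Hydra composes the YAML files into a DictConfig and passes it along.
    train(config)


if __name__ == "__main__":
    main()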
Example No. 2
    def load_config(self, args):
        '''
        Load the config specified by args.config_path. The config will be copied into
        args.checkpoint_path if there is no config there. A few hyperparameters in the
        config, listed at the bottom of this file, can be overwritten by passing them
        as arguments when running this code, e.g.
            ./run.sh checkpoints/tmp test --batch_size=30
        '''
        self.model_path = args.checkpoint_path
        self.summaries_path = self.model_path + '/summaries'
        self.checkpoints_path = self.model_path + '/checkpoints'
        self.tests_path = self.model_path + '/tests'
        self.config_path = args.config_path if args.config_path else self.model_path + '/config'

        # Read the config; it is restored into the checkpoint directory below if no config exists there yet (or if --cleanup is set).
        sys.stderr.write('Reading a config from %s ...\n' % (self.config_path))
        config = pyhocon.ConfigFactory.parse_file(self.config_path)
        config_restored_path = os.path.join(self.model_path, 'config')
        if not os.path.exists(self.summaries_path):
            os.makedirs(self.summaries_path)
        if not os.path.exists(self.checkpoints_path):
            os.makedirs(self.checkpoints_path)
        if not os.path.exists(self.tests_path):
            os.makedirs(self.tests_path)

        # Overwrite configs by temporary args. They have higher priorities than those in the config file.
        if 'dataset_type' in args and args.dataset_type:
            config['dataset_type'] = args.dataset_type
        if 'train_data_path' in args and args.train_data_path:
            config['dataset_path']['train'] = args.train_data_path
        if 'test_data_path' in args and args.test_data_path:
            config['dataset_path']['test'] = args.test_data_path
        if 'num_train_data' in args and args.num_train_data:
            config['num_train_data'] = args.num_train_data
        if 'vocab_size' in args and args.vocab_size:
            config['vocab_size'] = args.vocab_size
        if 'batch_size' in args and args.batch_size:
            config['batch_size'] = args.batch_size
        if 'target_attribute' in args and args.target_attribute:
            config['target_attribute'] = args.target_attribute

        # The restored config in the checkpoint will be overwritten when --cleanup=True is passed.
        if args.cleanup or not os.path.exists(config_restored_path):
            sys.stderr.write('Restore the config to %s ...\n' %
                             (config_restored_path))

            with open(config_restored_path, 'w') as f:
                sys.stdout = f
                common.print_config(config)
                sys.stdout = sys.__stdout__
        config = common.recDotDict(config)  # Allows dot-access.

        # Merge into the default config: old models that predate recently added hyperparameters fall back to the defaults, while any hyperparameters present in the model's config take priority.
        default_config.update(config)
        config = default_config

        print(config)
        return config
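As a rough illustration, the args object here could come from an argparse parser along these lines; the exact flag set is an assumption inferred from the attributes the method reads, and the real runner (./run.sh) may differ:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('checkpoint_path', help='model directory (summaries/checkpoints/tests are created under it)')
parser.add_argument('--config_path', default=None, help='defaults to <checkpoint_path>/config when omitted')
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--vocab_size', type=int, default=None)
parser.add_argument('--cleanup', action='store_true', help='rewrite the config stored in the checkpoint directory')
args = parser.parse_args()
# config = manager.load_config(args)  # hypothetical caller owning this method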
Example No. 3
  def load_config(self, args):
    self.model_path = args.checkpoint_path
    self.summaries_path = self.model_path + '/summaries'
    self.checkpoint_path = self.model_path + '/checkpoints'
    self.tests_path = self.model_path + '/tests'
    self.config_path = args.config_path if args.config_path else self.model_path + '/config'

    # Read and restore config
    sys.stderr.write('Reading a config from %s ...\n' % (self.config_path))
    config = pyhocon.ConfigFactory.parse_file(self.config_path)
    config_restored_path = os.path.join(self.model_path, 'config')
    if not os.path.exists(self.summaries_path):
      os.makedirs(self.summaries_path)
    if not os.path.exists(self.checkpoint_path):
      os.makedirs(self.checkpoint_path)
    if not os.path.exists(self.tests_path):
      os.makedirs(self.tests_path)

    # Overwrite configs by temporary args. They take priority over the values stored in the model's config.
    if 'dataset_type' in args and args.dataset_type:
      config['dataset_type'] = args.dataset_type
    if 'train_data_size' in args and args.train_data_size:
      config['dataset_info']['train']['max_lines'] = args.train_data_size
    if 'train_data_path' in args and args.train_data_path:
      config['dataset_info']['train']['path'] = args.train_data_path
    if 'test_data_path' in args and args.test_data_path:
      config['dataset_info']['test']['path'] = args.test_data_path
    if 'batch_size' in args and args.batch_size:
      config['batch_size'] = args.batch_size
    if 'w_vocab_size' in args and args.w_vocab_size:
      config['w_vocab_size'] = args.w_vocab_size
    if 'target_attribute' in args and args.target_attribute:
      config['target_attribute'] = args.target_attribute

    if args.cleanup or not os.path.exists(config_restored_path):
      sys.stderr.write('Restore the config to %s ...\n' % (config_restored_path))

      with open(config_restored_path, 'w') as f:
        sys.stdout = f
        common.print_config(config)
        sys.stdout = sys.__stdout__
    config = common.recDotDict(config)
    default_config.update(config)
    config = default_config
    return config
Example No. 4
    def get_config(self, args):
        self.model_path = args.checkpoint_path
        self.summaries_path = self.model_path + '/summaries'
        self.checkpoints_path = self.model_path + '/checkpoints'
        self.tests_path = self.model_path + '/tests'
        self.config_path = args.config_path if args.config_path else self.model_path + '/config'

        # Read and restore config
        sys.stderr.write('Reading a config from %s ...\n' % (self.config_path))
        config = pyhocon.ConfigFactory.parse_file(self.config_path)
        config_restored_path = os.path.join(self.model_path, 'config')
        if not os.path.exists(self.summaries_path):
            os.makedirs(self.summaries_path)
        if not os.path.exists(self.checkpoints_path):
            os.makedirs(self.checkpoints_path)
        if not os.path.exists(self.tests_path):
            os.makedirs(self.tests_path)

        if args.cleanup or not os.path.exists(config_restored_path):
            sys.stderr.write('Restore the config to %s ...\n' %
                             (config_restored_path))

            with open(config_restored_path, 'w') as f:
                sys.stdout = f
                common.print_config(config)
                sys.stdout = sys.__stdout__
        config = common.recDotDict(config)

        default_config.update(config)
        config = default_config

        # Override configs by temporary args.
        if args.test_data_path:
            config.dataset_path.test = args.test_data_path
        if args.batch_size:
            config.batch_size = args.batch_size
        config.debug = args.debug
        return config
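The default_config.update(config) pattern shared by the last three examples merges the loaded config over the defaults: keys present in the loaded config win, and keys it lacks fall back to the default values. A small standalone illustration with made-up keys:

default_config = {'batch_size': 32, 'dropout': 0.5}  # hypothetical defaults
config = {'batch_size': 64}                          # values loaded from the checkpoint

default_config.update(config)  # loaded values override the defaults
config = default_config
print(config)  # {'batch_size': 64, 'dropout': 0.5}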
Example No. 5
def train(config: DictConfig, resume_from_checkpoint: str = None):
    filter_warnings()
    print_config(config)
    seed_everything(config.seed)

    known_models = {
        "token": get_token_based,
        "vuldeepecker": get_VDP,
        "vgdetector": get_VGD,
        "sysevr": get_SYS,
        "mulvuldeepecker": get_MULVDP,
        "code2seq": get_C2S,
        "code2vec": get_C2V
    }

    vocab = {
        "token": Vocabulary_token,
        "vuldeepecker": Vocabulary_token,
        "vgdetector": Vocabulary_token,
        "sysevr": Vocabulary_token,
        "mulvuldeepecker": Vocabulary_token,
        "code2seq": Vocabulary_c2s,
        "code2vec": Vocabulary_c2s
    }
    if config.name not in known_models:
        print(f"Unknown model: {config.name}, try on of {known_models.keys()}")
        return
    vocab_path = join(config.data_folder, config.name, config.dataset.name, "vocab.pkl")
    if os.path.exists(vocab_path):
        vocabulary = vocab[config.name].load_vocabulary(vocab_path)
    else:
        vocabulary = None
    model, data_module = known_models[config.name](config, vocabulary)
    # define logger
    # wandb logger
    # wandb_logger = WandbLogger(project=f"{config.name}-{config.dataset.name}",
    #                            log_model=True,
    #                            offline=config.log_offline)
    # wandb_logger.watch(model)
    # checkpoint_callback = ModelCheckpoint(
    #     dirpath=wandb_logger.experiment.dir,
    #     filename="{epoch:02d}-{val_loss:.4f}",
    #     period=config.save_every_epoch,
    #     save_top_k=-1,
    # )
    # upload_checkpoint_callback = UploadCheckpointCallback(
    #     wandb_logger.experiment.dir)

    # tensorboard logger
    tensorlogger = TensorBoardLogger(join("ts_logger", config.name),
                                     config.dataset.name)
    # define model checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=join(tensorlogger.log_dir, "checkpoints"),
        monitor="val_loss",
        filename="{epoch:02d}-{val_loss:.4f}",
        period=config.save_every_epoch,
        save_top_k=3,
    )
    upload_checkpoint_callback = UploadCheckpointCallback(
        join(tensorlogger.log_dir, "checkpoints"))

    # define early stopping callback
    early_stopping_callback = EarlyStopping(
        patience=config.hyper_parameters.patience,
        monitor="val_loss",
        verbose=True,
        mode="min")
    # define callback for printing intermediate result
    print_epoch_result_callback = PrintEpochResultCallback("train", "val")
    collect_test_res_callback = CollectTestResCallback(config)
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateMonitor("step")
    trainer = Trainer(
        max_epochs=config.hyper_parameters.n_epochs,
        gradient_clip_val=config.hyper_parameters.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.val_every_epoch,
        log_every_n_steps=config.log_every_epoch,
        logger=[tensorlogger],
        reload_dataloaders_every_epoch=config.hyper_parameters.reload_dataloader,
        gpus=gpu,
        progress_bar_refresh_rate=config.progress_bar_refresh_rate,
        callbacks=[
            lr_logger, early_stopping_callback, checkpoint_callback,
            print_epoch_result_callback, upload_checkpoint_callback,
            collect_test_res_callback
        ],
        resume_from_checkpoint=resume_from_checkpoint,
    )

    trainer.fit(model=model, datamodule=data_module)
    trainer.test()
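Resuming an interrupted run only needs the extra parameter; a hedged usage sketch, where the checkpoint path is a placeholder for one written by the ModelCheckpoint callback above:

ckpt = "ts_logger/code2seq/<dataset>/version_0/checkpoints/epoch=09-val_loss=0.1234.ckpt"  # placeholder path
train(config, resume_from_checkpoint=ckpt)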
Example No. 6
def train_treelstm(config: DictConfig):
    filter_warnings()
    seed_everything(config.seed)
    dgl.seed(config.seed)

    print_config(config, ["hydra", "log_offline"])

    data_module = JsonlDataModule(config)
    data_module.prepare_data()
    data_module.setup()
    model: LightningModule
    if "max_types" in config and "max_type_parts" in config:
        model = TypedTreeLSTM2Seq(config, data_module.vocabulary)
    else:
        model = TreeLSTM2Seq(config, data_module.vocabulary)

    # define logger
    wandb_logger = WandbLogger(project=f"tree-lstm-{config.dataset}",
                               log_model=False,
                               offline=config.log_offline)
    wandb_logger.watch(model)
    # define model checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=wandb_logger.experiment.dir,
        filename="{epoch:02d}-{val_loss:.4f}",
        period=config.save_every_epoch,
        save_top_k=-1,
    )
    upload_checkpoint_callback = UploadCheckpointCallback(
        wandb_logger.experiment.dir)
    # define early stopping callback
    early_stopping_callback = EarlyStopping(patience=config.patience,
                                            monitor="val_loss",
                                            verbose=True,
                                            mode="min")
    # define callback for printing intermediate result
    print_epoch_result_callback = PrintEpochResultCallback("train", "val")
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateMonitor("step")
    trainer = Trainer(
        max_epochs=config.n_epochs,
        gradient_clip_val=config.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.val_every_epoch,
        log_every_n_steps=config.log_every_step,
        logger=wandb_logger,
        gpus=gpu,
        progress_bar_refresh_rate=config.progress_bar_refresh_rate,
        callbacks=[
            lr_logger,
            early_stopping_callback,
            checkpoint_callback,
            upload_checkpoint_callback,
            print_epoch_result_callback,
        ],
        resume_from_checkpoint=config.resume_checkpoint,
    )

    trainer.fit(model=model, datamodule=data_module)
    trainer.test()