Example #1
def main(args):
    model = LightningLinearVAE(args)
    # Optionally seed the model with precomputed eigenvectors/eigenvalues
    if args.eigvectors is not None and args.eigvalues is not None:
        eigvectors = np.loadtxt(args.eigvectors)
        eigvalues = np.loadtxt(args.eigvalues)
        model.set_eigs(eigvectors, eigvalues)
    trainer = Trainer(
        max_epochs=args.epochs,
        gpus=args.gpus,
        check_val_every_n_epoch=1,
        gradient_clip_val=args.grad_clip,
    )
    ckpt_path = os.path.join(
        args.output_directory,
        trainer.logger.name,
        f"linear_vae_version_{trainer.logger.version}",
        "checkpoints",
    )
    # Keep the best checkpoint by validation loss, evaluated once per epoch
    checkpoint_callback = ModelCheckpoint(filepath=ckpt_path,
                                          period=1,
                                          monitor='val_loss',
                                          mode='min',
                                          verbose=True)
    # Attach the callback after construction (pattern from older Lightning versions)
    trainer.checkpoint_callback = checkpoint_callback

    trainer.fit(model)
    torch.save(model.state_dict(), args.output_directory + '/last_ckpt.pt')
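This example attaches the checkpoint callback to the Trainer after construction and uses the filepath=/period= arguments, both of which belong to older PyTorch Lightning releases. A minimal sketch of the same setup against a newer API (1.x/2.x), assuming dirpath, every_n_epochs and the Trainer's callbacks argument are available in the installed version:

# Sketch only: dirpath/every_n_epochs replace the older filepath/period arguments,
# and callbacks are passed to the Trainer instead of being assigned afterwards.
# Device flags (gpus/accelerator/devices) are omitted; their spelling differs between releases.
checkpoint_callback = ModelCheckpoint(
    dirpath=ckpt_path,
    monitor='val_loss',
    mode='min',
    every_n_epochs=1,
    verbose=True,
)
trainer = Trainer(
    max_epochs=args.epochs,
    gradient_clip_val=args.grad_clip,
    callbacks=[checkpoint_callback],
)
trainer.fit(model)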
Example #2
def main(hparams):

    # Seed torch, numpy and random for reproducible runs
    torch.manual_seed(hparams.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(hparams.seed)
    random.seed(hparams.seed)

    module = PointCloudDenoising(hparams)

    # Debug mode: run a single batch through train/val with no logging or checkpoints
    if hparams.debug:
        trainer = Trainer(gpus=hparams.n_gpu,
                          fast_dev_run=True,
                          logger=False,
                          checkpoint_callback=False,
                          distributed_backend='dp')
    else:
        trainer = Trainer(
            gpus=hparams.n_gpu,
            early_stop_callback=None,
            distributed_backend='dp',
        )
        os.makedirs('./lightning_logs', exist_ok=True)
        os.makedirs(trainer.logger.log_dir)
        # save_top_k=-1 keeps every epoch's checkpoint rather than only the best
        trainer.checkpoint_callback = ModelCheckpoint(
            filepath=trainer.logger.log_dir, save_top_k=-1)

    trainer.fit(module)
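The manual seeding block at the top of this example can be collapsed into Lightning's own helper. A minimal sketch, assuming pytorch_lightning.seed_everything exists in the installed version (it seeds random, numpy and torch in one call); the cuDNN flags still have to be set separately for full determinism:

import pytorch_lightning as pl
import torch

# One-call replacement for the manual torch/numpy/random seeding above.
pl.seed_everything(hparams.seed)
# Determinism of cuDNN kernels is still controlled explicitly.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False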
Example #3
def main(parser, fast_dev_run) -> None:
    args = parser.parse_args()
    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)

    model = get_model(args)

    # Stop training once the monitored metric stops improving
    early_stop_callback = EarlyStopping(
        monitor=args.monitor,
        min_delta=0.0,
        patience=args.patience,
        verbose=True,
        mode=args.metric_mode,
    )

    trainer = Trainer(logger=setup_testube_logger(args),
                      checkpoint_callback=True,
                      early_stop_callback=early_stop_callback,
                      default_root_dir=args.savedir,
                      gpus=args.gpus,
                      distributed_backend=args.distributed_backend,
                      precision=args.precision,
                      amp_level=args.amp_level,
                      max_epochs=args.max_epochs,
                      min_epochs=args.min_epochs,
                      accumulate_grad_batches=args.accumulate_grad_batches,
                      val_percent_check=args.val_percent_check,
                      fast_dev_run=fast_dev_run,
                      num_sanity_val_steps=0)

    # Checkpoints are written under <savedir>/<logger name>/version_<n>/checkpoints
    ckpt_path = os.path.join(
        trainer.default_root_dir,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )
    # initialize Model Checkpoint Saver
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=args.save_top_k,
        verbose=True,
        monitor=args.monitor,
        period=1,
        mode=args.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback

    trainer.fit(model)
Example #4
def main(hparams) -> None:

    set_seed(hparams.seed)
    model = BERTClassifier(hparams)

    early_stop_callback = EarlyStopping(
        monitor=hparams.monitor,
        min_delta=0.0,
        patience=hparams.patience,
        verbose=True,
        mode=hparams.metric_mode,
    )
    save_dir = os.environ['HOME'] + "/data/lightning_experiments/"
    # Multi-node DDP training with mixed precision (use_amp is the older Lightning flag)
    trainer = Trainer(
        logger=setup_testube_logger(save_dir),
        checkpoint_callback=True,
        early_stop_callback=early_stop_callback,
        default_save_path=save_dir,
        gpus=hparams.gpus,
        num_nodes=hparams.num_nodes,
        distributed_backend="ddp",
        use_amp=True,
        log_gpu_memory='all',
        max_epochs=hparams.max_epochs,
        min_epochs=hparams.min_epochs,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        val_percent_check=hparams.val_percent_check,
    )

    ckpt_path = os.path.join(
        trainer.default_save_path,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=hparams.save_top_k,
        verbose=True,
        monitor=hparams.monitor,
        period=1,
        mode=hparams.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback

    trainer.fit(model)
Example #5
def main(args):
    print('args', args)
    # Build the model, then optionally restore weights from a previous checkpoint
    model = LightningBatchLinearVAE(args)
    if args.load_from_checkpoint is not None:
        checkpoint = torch.load(
            args.load_from_checkpoint,
            map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
    print(model)
    # Optionally seed the model with precomputed eigenvectors/eigenvalues
    if args.eigvectors is not None and args.eigvalues is not None:
        eigvectors = np.loadtxt(args.eigvectors)
        eigvalues = np.loadtxt(args.eigvalues)
        model.set_eigs(eigvectors, eigvalues)
    trainer = Trainer(
        max_epochs=args.epochs,
        gpus=args.gpus,
        check_val_every_n_epoch=1,
        gradient_clip_val=args.grad_clip,
        accumulate_grad_batches=args.grad_accum
    )
    ckpt_path = os.path.join(
        args.output_directory,
        trainer.logger.name,
        f"catvae_version_{trainer.logger.version}",
        "checkpoints",
    )
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        period=1,
        monitor='val_loss',
        mode='min',
        verbose=True
    )
    trainer.checkpoint_callback = checkpoint_callback

    trainer.fit(model)
    torch.save(model.state_dict(),
               args.output_directory + '/last_ckpt.pt')
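As an aside, LightningModule exposes load_from_checkpoint, which restores weights (and, if saved, hyperparameters) in one call. A hypothetical alternative to the manual torch.load / load_state_dict sequence above, assuming the checkpoint was written by Lightning itself:

# Hypothetical replacement for the manual restore; whether the stored
# hyperparameters match this run depends on how the checkpoint was saved.
model = LightningBatchLinearVAE.load_from_checkpoint(args.load_from_checkpoint)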
Example #6
def setup_training(hparams: HyperOptArgumentParser) -> Trainer:
    """
    Setup for the training loop.
    :param hparams: HyperOptArgumentParser

    Returns:
        - pytorch_lightning Trainer
    """
    if hparams.verbose:
        log.info(hparams)

    if hparams.early_stopping:
        # Enable Early stopping
        early_stop_callback = EarlyStopping(
            monitor=hparams.monitor,
            min_delta=hparams.min_delta,
            patience=hparams.patience,
            verbose=hparams.verbose,
            mode=hparams.metric_mode,
        )
    else:
        early_stop_callback = None

    # A fixed epoch count pins both min_epochs and max_epochs to the same value
    if hparams.epochs > 0:
        hparams.min_epochs = hparams.epochs
        hparams.max_epochs = hparams.epochs

    # Configure the trainer
    trainer = Trainer(
        logger=setup_testube_logger(),
        checkpoint_callback=True,
        early_stop_callback=early_stop_callback,
        default_save_path="experiments/",
        gradient_clip_val=hparams.gradient_clip_val,
        gpus=hparams.gpus,
        show_progress_bar=False,
        overfit_pct=hparams.overfit_pct,
        check_val_every_n_epoch=hparams.check_val_every_n_epoch,
        fast_dev_run=False,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        max_epochs=hparams.max_epochs,
        min_epochs=hparams.min_epochs,
        train_percent_check=hparams.train_percent_check,
        val_percent_check=hparams.val_percent_check,
        val_check_interval=hparams.val_check_interval,
        log_save_interval=hparams.log_save_interval,
        row_log_interval=hparams.row_log_interval,
        distributed_backend=hparams.distributed_backend,
        precision=hparams.precision,
        weights_summary=hparams.weights_summary,
        resume_from_checkpoint=hparams.resume_from_checkpoint,
        profiler=hparams.profiler,
        log_gpu_memory="all",
    )

    ckpt_path = os.path.join(
        trainer.default_save_path,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )

    # initialize Model Checkpoint Saver
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=hparams.save_top_k,
        verbose=hparams.verbose,
        monitor=hparams.monitor,
        save_weights_only=hparams.save_weights_only,
        period=hparams.period,
        mode=hparams.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback
    return trainer
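A typical call site for this helper builds the model separately and hands it to the returned Trainer. A minimal, hypothetical usage sketch (the model class name is illustrative, not part of the example):

# Hypothetical usage of setup_training(); MyLightningModel stands in for any pl.LightningModule.
trainer = setup_training(hparams)
model = MyLightningModel(hparams)
trainer.fit(model)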
Example #7
def main(hparams) -> None:
    """
    Main training routine specific for this project
    :param hparams:
    """
    set_seed(hparams.seed)
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    model = BERTClassifier(hparams)

    # ------------------------
    # 2 INIT EARLY STOPPING
    # ------------------------
    early_stop_callback = EarlyStopping(
        monitor=hparams.monitor,
        min_delta=0.0,
        patience=hparams.patience,
        verbose=True,
        mode=hparams.metric_mode,
    )
    # ------------------------
    # 3 INIT TRAINER
    # ------------------------
    trainer = Trainer(
        logger=setup_testube_logger(),
        checkpoint_callback=True,
        early_stop_callback=early_stop_callback,
        default_save_path="experiments/",
        gpus=hparams.gpus,
        distributed_backend=hparams.distributed_backend,
        use_amp=hparams.use_16bit,
        max_epochs=hparams.max_epochs,
        min_epochs=hparams.min_epochs,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        log_gpu_memory=hparams.log_gpu_memory,
        val_percent_check=hparams.val_percent_check,
    )

    # --------------------------------
    # 4 INIT MODEL CHECKPOINT CALLBACK
    # --------------------------------
    ckpt_path = os.path.join(
        trainer.default_save_path,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )
    # initialize Model Checkpoint Saver
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=hparams.save_top_k,
        verbose=True,
        monitor=hparams.monitor,
        period=1,
        mode=hparams.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback

    # ------------------------
    # 5 START TRAINING
    # ------------------------
    trainer.fit(model)
Example #8
    def __init__(self,
                 pl_trainer: pl.Trainer,
                 model: pl.LightningModule,
                 population_tasks: mp.Queue,
                 tune_hparams: Dict,
                 process_position: int,
                 global_epoch: mp.Value,
                 max_epoch: int,
                 full_parallel: bool,
                 pbt_period: int = 4,
                 pbt_monitor: str = 'val_loss',
                 logger_info=None,
                 dataloaders: Optional[Dict] = None):
        """

        Args:
            pl_trainer:
            model:
            population_tasks:
            tune_hparams:
            process_position:
            global_epoch:
            max_epoch:
            full_parallel:
            pbt_period:
            **dataloaders:
        """
        super().__init__()
        # Precision used when formatting the monitored metric in checkpoint filenames
        monitor_precision = 32
        # Set checkpoint dirpath
        #checkpoint_dirpath = pl_trainer.checkpoint_callback.dirpath
        #period = pl_trainer.checkpoint_callback.period
        # Formatting checkpoints
        checkpoint_format = '{task:03d}-{' + f'{pbt_monitor}:.{monitor_precision}f' + '}'
        checkpoint_filepath = os.path.join(pl_trainer.logger.log_dir,
                                           checkpoint_format)

        # For TaskSaving
        print(logger_info)

        checkpoint_dirpath = pl_trainer.logger.log_dir

        pl_trainer.checkpoint_callback = TaskSaving(
            filepath=checkpoint_filepath,
            monitor=pbt_monitor,
            population_tasks=population_tasks,
            period=1,
            full_parallel=full_parallel,
        )

        # For EarlyStopping
        pl_trainer.early_stop_callback = EarlyStopping(
            global_epoch=global_epoch, max_global_epoch=max_epoch)

        # For TaskLoading
        pl_trainer.callbacks = [
            TaskLoading(population_tasks=population_tasks,
                        global_epoch=global_epoch,
                        filepath=checkpoint_filepath,
                        monitor=pbt_monitor,
                        tune_hparams=tune_hparams,
                        pbt_period=pbt_period)
        ]

        # Replace the logger with a TensorBoardLogger configured for this worker
        #if isinstance(pl_trainer.logger, pl.loggers.TensorBoardLogger):
        pl_trainer.logger = loggers.TensorBoardLogger(
            save_dir=logger_info['save_dir'],
            name=logger_info['name'],
            version=logger_info['version'],
            task=process_position,
        )

        # Set process_position
        pl_trainer.process_position = process_position
        # pl_trainer.logger._version = f'worker_{process_position}'
        # Store the configured trainer, model and shared state on this wrapper
        self.trainer = pl_trainer
        self.model = model
        self.global_epoch = global_epoch
        self.population_tasks = population_tasks
        self.max_epoch = max_epoch
        self.dataloaders = dataloaders or {}
        print(dataloaders)