# Shared imports for the snippets in this section. Each snippet targets an
# older pytorch-lightning release (ModelCheckpoint with filepath=/period=,
# assigning trainer.checkpoint_callback directly, etc.). The model classes
# and helpers referenced below (LightningLinearVAE, PointCloudDenoising,
# BERTClassifier, LightningBatchLinearVAE, get_model, set_seed, log,
# TaskSaving, TaskLoading, and the custom EarlyStopping in the last snippet)
# are project-specific and not defined here.
import os
import random

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import HyperOptArgumentParser


def main(args):
    model = LightningLinearVAE(args)
    # Optionally warm-start the model from a precomputed eigendecomposition.
    if args.eigvectors is not None and args.eigvalues is not None:
        eigvectors = np.loadtxt(args.eigvectors)
        eigvalues = np.loadtxt(args.eigvalues)
        model.set_eigs(eigvectors, eigvalues)
    trainer = Trainer(
        max_epochs=args.epochs,
        gpus=args.gpus,
        check_val_every_n_epoch=1,
        gradient_clip_val=args.grad_clip,
    )
    ckpt_path = os.path.join(
        args.output_directory,
        trainer.logger.name,
        f"linear_vae_version_{trainer.logger.version}",
        "checkpoints",
    )
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        period=1,
        monitor='val_loss',
        mode='min',
        verbose=True,
    )
    trainer.checkpoint_callback = checkpoint_callback
    trainer.fit(model)
    torch.save(model.state_dict(),
               os.path.join(args.output_directory, 'last_ckpt.pt'))
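# Usage sketch for main() above: a minimal argparse entry point wiring up the
# fields the function reads. The flag types and defaults are assumptions; the
# original project's CLI is not shown in this section.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a linear VAE')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--gpus', type=int, default=None)
    parser.add_argument('--grad-clip', type=float, default=0.0)
    parser.add_argument('--eigvectors', type=str, default=None)
    parser.add_argument('--eigvalues', type=str, default=None)
    parser.add_argument('--output-directory', required=True)
    main(parser.parse_args())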
def main(hparams):
    # Full determinism: seed torch, numpy, and python RNGs.
    torch.manual_seed(hparams.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(hparams.seed)
    random.seed(hparams.seed)

    module = PointCloudDenoising(hparams)
    if hparams.debug:
        # Single fast_dev_run pass, no logging or checkpoints.
        trainer = Trainer(gpus=hparams.n_gpu,
                          fast_dev_run=True,
                          logger=False,
                          checkpoint_callback=False,
                          distributed_backend='dp')
    else:
        trainer = Trainer(
            gpus=hparams.n_gpu,
            early_stop_callback=None,
            distributed_backend='dp',
        )
        # logger.log_dir only exists outside debug mode (logger=False above),
        # so directory setup and checkpointing belong in this branch.
        os.makedirs('./lightning_logs', exist_ok=True)
        os.makedirs(trainer.logger.log_dir)
        trainer.checkpoint_callback = ModelCheckpoint(
            filepath=trainer.logger.log_dir, save_top_k=-1)
    trainer.fit(module)
def main(parser, fast_dev_run) -> None:
    args = parser.parse_args()
    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)
    model = get_model(args)
    early_stop_callback = EarlyStopping(
        monitor=args.monitor,
        min_delta=0.0,
        patience=args.patience,
        verbose=True,
        mode=args.metric_mode,
    )
    trainer = Trainer(
        logger=setup_testube_logger(args),
        checkpoint_callback=True,
        early_stop_callback=early_stop_callback,
        default_root_dir=args.savedir,
        gpus=args.gpus,
        distributed_backend=args.distributed_backend,
        precision=args.precision,
        amp_level=args.amp_level,
        max_epochs=args.max_epochs,
        min_epochs=args.min_epochs,
        accumulate_grad_batches=args.accumulate_grad_batches,
        val_percent_check=args.val_percent_check,
        fast_dev_run=fast_dev_run,
        num_sanity_val_steps=0,
    )
    ckpt_path = os.path.join(
        trainer.default_root_dir,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )
    # initialize Model Checkpoint Saver
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=args.save_top_k,
        verbose=True,
        monitor=args.monitor,
        period=1,
        mode=args.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback
    trainer.fit(model)
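# `setup_testube_logger` is used by several snippets here but never defined;
# different snippets pass it an args namespace, a save_dir string, or nothing.
# A plausible sketch against the old pytorch_lightning TestTubeLogger API that
# matches these Trainer flags; the default path and experiment name are
# assumptions, not the original helper:
from pytorch_lightning.loggers import TestTubeLogger


def setup_testube_logger(save_dir: str = "experiments/") -> TestTubeLogger:
    return TestTubeLogger(save_dir=save_dir, name="lightning_logs")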
def main(hparams) -> None:
    set_seed(hparams.seed)
    model = BERTClassifier(hparams)
    early_stop_callback = EarlyStopping(
        monitor=hparams.monitor,
        min_delta=0.0,
        patience=hparams.patience,
        verbose=True,
        mode=hparams.metric_mode,
    )
    save_dir = os.environ['HOME'] + "/data/lightning_experiments/"
    trainer = Trainer(
        logger=setup_testube_logger(save_dir),
        checkpoint_callback=True,
        early_stop_callback=early_stop_callback,
        default_save_path=save_dir,
        gpus=hparams.gpus,
        num_nodes=hparams.num_nodes,
        distributed_backend="ddp",
        use_amp=True,
        log_gpu_memory='all',
        max_epochs=hparams.max_epochs,
        min_epochs=hparams.min_epochs,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        val_percent_check=hparams.val_percent_check,
    )
    ckpt_path = os.path.join(
        trainer.default_save_path,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=hparams.save_top_k,
        verbose=True,
        monitor=hparams.monitor,
        period=1,
        mode=hparams.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback
    trainer.fit(model)
def main(args):
    print('args', args)
    model = LightningBatchLinearVAE(args)
    if args.load_from_checkpoint is not None:
        checkpoint = torch.load(
            args.load_from_checkpoint,
            map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
    print(model)
    if args.eigvectors is not None and args.eigvalues is not None:
        eigvectors = np.loadtxt(args.eigvectors)
        eigvalues = np.loadtxt(args.eigvalues)
        model.set_eigs(eigvectors, eigvalues)
    trainer = Trainer(
        max_epochs=args.epochs,
        gpus=args.gpus,
        check_val_every_n_epoch=1,
        gradient_clip_val=args.grad_clip,
        accumulate_grad_batches=args.grad_accum,
    )
    ckpt_path = os.path.join(
        args.output_directory,
        trainer.logger.name,
        f"catvae_version_{trainer.logger.version}",
        "checkpoints",
    )
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        period=1,
        monitor='val_loss',
        mode='min',
        verbose=True,
    )
    trainer.checkpoint_callback = checkpoint_callback
    trainer.fit(model)
    torch.save(model.state_dict(),
               os.path.join(args.output_directory, 'last_ckpt.pt'))
def setup_training(hparams: HyperOptArgumentParser) -> Trainer:
    """
    Setup for the training loop.

    :param hparams: HyperOptArgumentParser

    Returns:
        - pytorch_lightning Trainer
    """
    if hparams.verbose:
        log.info(hparams)

    if hparams.early_stopping:
        # Enable Early stopping
        early_stop_callback = EarlyStopping(
            monitor=hparams.monitor,
            min_delta=hparams.min_delta,
            patience=hparams.patience,
            verbose=hparams.verbose,
            mode=hparams.metric_mode,
        )
    else:
        early_stop_callback = None

    # configure trainer
    if hparams.epochs > 0.0:
        hparams.min_epochs = hparams.epochs
        hparams.max_epochs = hparams.epochs

    trainer = Trainer(
        logger=setup_testube_logger(),
        checkpoint_callback=True,
        early_stop_callback=early_stop_callback,
        default_save_path="experiments/",
        gradient_clip_val=hparams.gradient_clip_val,
        gpus=hparams.gpus,
        show_progress_bar=False,
        overfit_pct=hparams.overfit_pct,
        check_val_every_n_epoch=hparams.check_val_every_n_epoch,
        fast_dev_run=False,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        max_epochs=hparams.max_epochs,
        min_epochs=hparams.min_epochs,
        train_percent_check=hparams.train_percent_check,
        val_percent_check=hparams.val_percent_check,
        val_check_interval=hparams.val_check_interval,
        log_save_interval=hparams.log_save_interval,
        row_log_interval=hparams.row_log_interval,
        distributed_backend=hparams.distributed_backend,
        precision=hparams.precision,
        weights_summary=hparams.weights_summary,
        resume_from_checkpoint=hparams.resume_from_checkpoint,
        profiler=hparams.profiler,
        log_gpu_memory="all",
    )

    ckpt_path = os.path.join(
        trainer.default_save_path,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )

    # initialize Model Checkpoint Saver
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=hparams.save_top_k,
        verbose=hparams.verbose,
        monitor=hparams.monitor,
        save_weights_only=hparams.save_weights_only,
        period=hparams.period,
        mode=hparams.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback
    return trainer
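# `setup_training` builds and configures the Trainer but never calls fit().
# A hedged driver sketch: the caller supplies parsed hparams and an
# already-constructed LightningModule, both of which come from the
# surrounding project rather than this section.
def run_training(hparams, model) -> None:
    trainer = setup_training(hparams)
    trainer.fit(model)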
def main(hparams) -> None: """ Main training routine specific for this project :param hparams: """ set_seed(hparams.seed) # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ model = BERTClassifier(hparams) # ------------------------ # 2 INIT EARLY STOPPING # ------------------------ early_stop_callback = EarlyStopping( monitor=hparams.monitor, min_delta=0.0, patience=hparams.patience, verbose=True, mode=hparams.metric_mode, ) # ------------------------ # 3 INIT TRAINER # ------------------------ trainer = Trainer( logger=setup_testube_logger(), checkpoint_callback=True, early_stop_callback=early_stop_callback, default_save_path="experiments/", gpus=hparams.gpus, distributed_backend=hparams.distributed_backend, use_amp=hparams.use_16bit, max_epochs=hparams.max_epochs, min_epochs=hparams.min_epochs, accumulate_grad_batches=hparams.accumulate_grad_batches, log_gpu_memory=hparams.log_gpu_memory, val_percent_check=hparams.val_percent_check, ) # -------------------------------- # 4 INIT MODEL CHECKPOINT CALLBACK # ------------------------------- ckpt_path = os.path.join( trainer.default_save_path, trainer.logger.name, f"version_{trainer.logger.version}", "checkpoints", ) # initialize Model Checkpoint Saver checkpoint_callback = ModelCheckpoint( filepath=ckpt_path, save_top_k=hparams.save_top_k, verbose=True, monitor=hparams.monitor, period=1, mode=hparams.metric_mode, ) trainer.checkpoint_callback = checkpoint_callback # ------------------------ # 5 START TRAINING # ------------------------ trainer.fit(model)
def __init__(self,
             pl_trainer: pl.Trainer,
             model: pl.LightningModule,
             population_tasks: mp.Queue,
             tune_hparams: Dict,
             process_position: int,
             global_epoch: mp.Value,
             max_epoch: int,
             full_parallel: bool,
             pbt_period: int = 4,
             pbt_monitor: str = 'val_loss',
             logger_info=None,
             dataloaders: Optional[Dict] = None):
    """
    Args:
        pl_trainer: Trainer to reconfigure for this population member.
        model: LightningModule to train.
        population_tasks: queue of tasks shared across the population.
        tune_hparams: hyperparameter search space explored by PBT.
        process_position: index of this worker process.
        global_epoch: shared counter of the global epoch.
        max_epoch: global epoch at which training stops.
        full_parallel: whether the whole population trains in parallel.
        pbt_period: epochs between exploit/explore steps.
        pbt_monitor: metric used to rank population members.
        logger_info: save_dir/name/version used to rebuild the logger.
        dataloaders: optional dict of dataloaders.
    """
    super().__init__()

    # Precision used when formatting the monitored value in checkpoint names.
    monitor_precision = 32

    # Set checkpoint dirpath
    # checkpoint_dirpath = pl_trainer.checkpoint_callback.dirpath
    # period = pl_trainer.checkpoint_callback.period

    # Formatting checkpoints
    checkpoint_format = ('{task:03d}-{'
                         + f'{pbt_monitor}:.{monitor_precision}f' + '}')
    checkpoint_filepath = os.path.join(pl_trainer.logger.log_dir,
                                       checkpoint_format)

    # For TaskSaving
    print(logger_info)
    checkpoint_dirpath = pl_trainer.logger.log_dir
    pl_trainer.checkpoint_callback = TaskSaving(
        filepath=checkpoint_filepath,
        monitor=pbt_monitor,
        population_tasks=population_tasks,
        period=1,
        full_parallel=full_parallel,
    )

    # For EarlyStopping (custom callback keyed on the shared global epoch).
    pl_trainer.early_stop_callback = EarlyStopping(
        global_epoch=global_epoch, max_global_epoch=max_epoch)

    # For TaskLoading
    pl_trainer.callbacks = [
        TaskLoading(population_tasks=population_tasks,
                    global_epoch=global_epoch,
                    filepath=checkpoint_filepath,
                    monitor=pbt_monitor,
                    tune_hparams=tune_hparams,
                    pbt_period=pbt_period)
    ]

    # Alter logger to spec.
    # if isinstance(pl_trainer.logger, pl.loggers.TensorBoardLogger):
    pl_trainer.logger = loggers.TensorBoardLogger(
        save_dir=logger_info['save_dir'],
        name=logger_info['name'],
        version=logger_info['version'],
        task=process_position,
    )

    # Set process_position
    pl_trainer.process_position = process_position
    # pl_trainer.logger._version = f'worker_{process_position}'

    # Store references on self.
    self.trainer = pl_trainer
    self.model = model
    self.global_epoch = global_epoch
    self.population_tasks = population_tasks
    self.max_epoch = max_epoch
    self.dataloaders = dataloaders or {}
    print(dataloaders)
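# Only __init__ is shown above, so the enclosing class (and its run loop) is
# not visible in this section. A hypothetical construction sketch: `PBTWorker`
# stands in for the real class name, `trainer` and `module` are a pre-built
# pl.Trainer and pl.LightningModule, and every literal below is a placeholder.
import multiprocessing as mp

population_tasks = mp.Queue()
global_epoch = mp.Value('i', 0)
worker = PBTWorker(
    pl_trainer=trainer,
    model=module,
    population_tasks=population_tasks,
    tune_hparams={'lr': [1e-4, 1e-3]},  # hypothetical search space
    process_position=0,
    global_epoch=global_epoch,
    max_epoch=50,
    full_parallel=False,
    logger_info={'save_dir': 'lightning_logs', 'name': 'pbt', 'version': 0},
)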