def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=None,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=None,  # avoid a mutable default list, since we append to it below
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs
):
    pl.seed_everything(args.seed)

    # init model
    odir = Path(model.hparams.output_dir)
    odir.mkdir(exist_ok=True)

    extra_callbacks = list(extra_callbacks) if extra_callbacks is not None else []

    # add custom checkpoints
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
        )
    if early_stopping_callback:
        extra_callbacks.append(early_stopping_callback)
    if logging_callback is None:
        logging_callback = LoggingCallback()

    train_params = {}

    # TODO: remove with PyTorch 1.6 since pl uses native amp
    if args.fp16:
        train_params["precision"] = 16
        train_params["amp_level"] = args.fp16_opt_level

    if args.gpus > 1:
        train_params["accelerator"] = "ddp"

    train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
    # train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
    train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)

    trainer = pl.Trainer.from_argparse_args(
        args,
        weights_summary=None,
        callbacks=[logging_callback] + extra_callbacks + [InitCallback(), checkpoint_callback],
        logger=logger,
        plugins=[DDPPlugin(find_unused_parameters=True)],  # needed in newer pytorch-lightning versions
        val_check_interval=1,
        num_sanity_val_steps=2,
        **train_params,
    )

    if args.do_train:
        trainer.fit(model)
    # else:
    #     print("RAG modeling tests with new set functions successfully executed!")

    return trainer
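# Hedged usage sketch (an assumption, not part of the original module): `generic_train` only reads
# a handful of attributes from `args` -- seed, output_dir, fp16, fp16_opt_level, gpus,
# accumulate_grad_batches and do_train -- plus whatever `pl.Trainer.from_argparse_args` consumes.
# `MyTransformerModule` is a hypothetical BaseTransformer subclass standing in for a real model.
def _example_generic_train_call():
    example_args = argparse.Namespace(
        seed=42,
        output_dir="outputs",
        fp16=False,
        fp16_opt_level="O1",
        gpus=1,
        accumulate_grad_batches=1,
        do_train=True,
        max_epochs=1,
    )
    model = MyTransformerModule(example_args)  # hypothetical BaseTransformer subclass
    return generic_train(model, example_args, logger=True)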
def test_v1_6_0_ddp_sync_batchnorm():
    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
        DDPPlugin(sync_batchnorm=False)
def test_v1_6_0_ddp_num_nodes():
    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
        DDPPlugin(num_nodes=1)
def test_v1_6_0_ddp_plugin_task_idx():
    plugin = DDPPlugin()
    with pytest.deprecated_call(match="Use `DDPPlugin.local_rank` instead"):
        _ = plugin.task_idx
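# Hedged sketch of the non-deprecated equivalents the three tests above point at, assuming
# pytorch-lightning ~1.5 (where `Trainer(strategy=...)` exists): `sync_batchnorm` and `num_nodes`
# are passed to the Trainer rather than to DDPPlugin, and `DDPPlugin.local_rank` replaces the
# deprecated `DDPPlugin.task_idx`.
def _example_non_deprecated_ddp_config():
    import pytorch_lightning as pl
    from pytorch_lightning.plugins import DDPPlugin

    plugin = DDPPlugin(find_unused_parameters=False)
    trainer = pl.Trainer(
        num_nodes=1,           # previously DDPPlugin(num_nodes=1)
        sync_batchnorm=False,  # previously DDPPlugin(sync_batchnorm=False)
        strategy=plugin,
    )
    return trainer, plugin.local_rank  # local_rank replaces the deprecated plugin.task_idx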
def train_default_zoobot_from_scratch(
        # absolutely crucial arguments
        save_dir,  # save model here
        schema,  # answer these questions
        # input data - specify *either* catalog (to be split) or the splits themselves
        catalog=None,
        train_catalog=None,
        val_catalog=None,
        test_catalog=None,
        # model training parameters
        model_architecture='efficientnet',
        batch_size=256,
        epochs=1000,
        patience=8,
        # data and augmentation parameters
        # datamodule_class=GalaxyDataModule,  # generic catalog of galaxies, will not download itself. Can replace with any datamodules from pytorch_galaxy_datasets
        color=False,
        resize_size=224,
        crop_scale_bounds=(0.7, 0.8),
        crop_ratio_bounds=(0.9, 1.1),
        # hardware parameters
        accelerator='auto',
        nodes=1,
        gpus=2,
        num_workers=4,
        prefetch_factor=4,
        mixed_precision=False,
        # replication parameters
        random_state=42,
        wandb_logger=None):

    slurm_debugging_logs()

    pl.seed_everything(random_state)

    assert save_dir is not None
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)

    if color:
        logging.warning('Training on color images, not converting to greyscale')
        channels = 3
    else:
        logging.info('Converting images to greyscale before training')
        channels = 1

    strategy = None
    if (gpus is not None) and (gpus > 1):
        # only works as plugins, not strategy
        # strategy = 'ddp'
        strategy = DDPPlugin(find_unused_parameters=False)
        logging.info('Using multi-gpu training')

    if nodes > 1:
        assert gpus == 2
        logging.info('Using multi-node training')
        # this hangs silently on Manchester's slurm cluster - perhaps you will have more success?

    precision = 32
    if mixed_precision:
        logging.info(
            'Training with automatic mixed precision. '
            'Will reduce memory footprint but may cause training instability for e.g. resnet'
        )
        precision = 16

    assert num_workers > 0

    if (gpus is not None) and (num_workers * gpus > os.cpu_count()):
        logging.warning(
            """num_workers * gpus > num cpus.
            You may be spawning more dataloader workers than you have cpus, causing bottlenecks.
            Suggest reducing num_workers.""")

    if num_workers > os.cpu_count():
        logging.warning(
            """num_workers > num cpus.
            You may be spawning more dataloader workers than you have cpus, causing bottlenecks.
            Suggest reducing num_workers.""")

    if catalog is not None:
        assert train_catalog is None
        assert val_catalog is None
        assert test_catalog is None
        catalogs_to_use = {'catalog': catalog}
    else:
        assert catalog is None
        catalogs_to_use = {
            'train_catalog': train_catalog,
            'val_catalog': val_catalog,
            'test_catalog': test_catalog
        }

    datamodule = GalaxyDataModule(
        label_cols=schema.label_cols,
        # can take either a catalog (and split it), or a pre-split catalog
        **catalogs_to_use,
        # augmentation parameters
        album=False,
        greyscale=not color,
        resize_size=resize_size,
        crop_scale_bounds=crop_scale_bounds,
        crop_ratio_bounds=crop_ratio_bounds,
        # hardware parameters
        batch_size=batch_size,  # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch)
        num_workers=num_workers,
        prefetch_factor=prefetch_factor
    )
    datamodule.setup()

    get_architecture, representation_dim = select_base_architecture_func_from_name(model_architecture)

    model = define_model.get_plain_pytorch_zoobot_model(
        output_dim=len(schema.answers),
        include_top=True,
        channels=channels,
        get_architecture=get_architecture,
        representation_dim=representation_dim
    )

    # This just adds schema.question_index_groups as an arg to the usual (labels, preds) loss arg format
    # Would use a lambda, but multi-gpu doesn't support that as a lambda can't be pickled
    def loss_func(preds, labels):  # pytorch convention is preds, labels
        return losses.calculate_multiquestion_loss(
            labels, preds, schema.question_index_groups
        )  # my and sklearn convention is labels, preds

    lightning_model = define_model.GenericLightningModule(model, loss_func)

    callbacks = [
        ModelCheckpoint(
            dirpath=os.path.join(save_dir, 'checkpoints'),
            monitor="val_loss",
            save_weights_only=True,
            mode='min',
            save_top_k=3
        ),
        EarlyStopping(monitor='val_loss', patience=patience, check_finite=True)
    ]

    trainer = pl.Trainer(
        log_every_n_steps=3,
        accelerator=accelerator,
        gpus=gpus,  # per node
        num_nodes=nodes,
        strategy=strategy,
        precision=precision,
        logger=wandb_logger,
        callbacks=callbacks,
        max_epochs=epochs,
        default_root_dir=save_dir
    )

    logging.info((trainer.training_type_plugin, trainer.world_size,
                  trainer.local_rank, trainer.global_rank, trainer.node_rank))

    trainer.fit(lightning_model, datamodule)

    trainer.test(
        model=lightning_model,
        datamodule=datamodule,
        ckpt_path='best'  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
    )
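# Hedged usage sketch (an assumption, not part of the original module): a minimal single-gpu,
# single-node call. `schema` is whatever schema object the surrounding codebase provides (it must
# expose .label_cols, .answers and .question_index_groups, as used above), and `catalog` is a
# pandas DataFrame with image locations plus the schema's label columns; passing `catalog` lets
# the datamodule do the train/val/test split itself.
def _example_zoobot_training_run(schema, catalog, save_dir='results/example_run'):
    train_default_zoobot_from_scratch(
        save_dir=save_dir,
        schema=schema,
        catalog=catalog,          # split internally; or pass train/val/test catalogs instead
        model_architecture='efficientnet',
        batch_size=128,
        epochs=50,
        gpus=1,                   # single gpu, so strategy stays None (no DDP)
        nodes=1,
        num_workers=4,
        mixed_precision=False,
        random_state=42,
        wandb_logger=None,        # optionally pass a WandbLogger for experiment tracking
    )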